fix:修复了Get参数没有被正确转发的问题

This commit is contained in:
2025-08-18 16:10:37 +08:00
parent 481b523fe7
commit 45fd9e40e5

99
main.go
View File

@@ -4,8 +4,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -21,8 +23,38 @@ func main() {
port := flag.Int("port", 8080, "代理服务器监听的端口") port := flag.Int("port", 8080, "代理服务器监听的端口")
debug := flag.Bool("debug", false, "是否启用调试模式") debug := flag.Bool("debug", false, "是否启用调试模式")
logFile := flag.String("log", "", "日志文件路径,默认为标准输出")
flag.Parse() flag.Parse()
// 配置slog
var logger *slog.Logger
// 根据调试模式设置日志级别
var logLevel slog.Level
if *debug {
logLevel = slog.LevelDebug
} else {
logLevel = slog.LevelInfo
}
// 创建处理器选项
opts := &slog.HandlerOptions{
Level: logLevel,
}
if *logFile != "" {
file, err := os.OpenFile(*logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
if err != nil {
slog.Error("无法打开日志文件", "error", err)
os.Exit(1)
}
logger = slog.New(slog.NewJSONHandler(file, opts))
} else {
logger = slog.New(slog.NewJSONHandler(os.Stdout, opts))
}
slog.SetDefault(logger)
if *debug { if *debug {
gin.SetMode(gin.DebugMode) // 启用调试模式 gin.SetMode(gin.DebugMode) // 启用调试模式
} else { } else {
@@ -42,9 +74,9 @@ func main() {
// 检查是否以协议开头的路径 // 检查是否以协议开头的路径
r.Any("/:protocol/*remainder", protocolHandler) r.Any("/:protocol/*remainder", protocolHandler)
fmt.Printf("HTTP 代理服务器启动,监听端口 :%d\n", *port) slog.Info("HTTP 代理服务器启动", "port", *port)
if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil { if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil {
fmt.Printf("启动服务器失败: %v\n", err) slog.Error("启动服务器失败", "error", err)
} }
} }
@@ -69,6 +101,26 @@ func proxyHandler(c *gin.Context) {
// 规范化URL格式 // 规范化URL格式
targetURLStr = normalizeURL(targetURLStr) targetURLStr = normalizeURL(targetURLStr)
// 解析目标URL
parsedURL, err := url.Parse(targetURLStr)
if err != nil {
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
return
}
// 合并查询参数
originalQuery := c.Request.URL.Query()
targetQuery := parsedURL.Query()
for key, values := range originalQuery {
for _, value := range values {
targetQuery.Add(key, value)
}
}
parsedURL.RawQuery = targetQuery.Encode()
// 重新构建目标URL字符串
targetURLStr = parsedURL.String()
// 检查 URL 合法性 // 检查 URL 合法性
if _, err := url.ParseRequestURI(targetURLStr); err != nil { if _, err := url.ParseRequestURI(targetURLStr); err != nil {
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
@@ -96,6 +148,26 @@ func protocolHandler(c *gin.Context) {
// 规范化URL格式 // 规范化URL格式
targetURLStr = normalizeURL(targetURLStr) targetURLStr = normalizeURL(targetURLStr)
// 解析目标URL
parsedURL, err := url.Parse(targetURLStr)
if err != nil {
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
return
}
// 合并查询参数
originalQuery := c.Request.URL.Query()
targetQuery := parsedURL.Query()
for key, values := range originalQuery {
for _, value := range values {
targetQuery.Add(key, value)
}
}
parsedURL.RawQuery = targetQuery.Encode()
// 重新构建目标URL字符串
targetURLStr = parsedURL.String()
// 检查 URL 合法性 // 检查 URL 合法性
if _, err := url.ParseRequestURI(targetURLStr); err != nil { if _, err := url.ParseRequestURI(targetURLStr); err != nil {
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
@@ -109,12 +181,19 @@ func protocolHandler(c *gin.Context) {
// executeProxy 执行实际的代理请求 // executeProxy 执行实际的代理请求
func executeProxy(c *gin.Context, targetURLStr string) { func executeProxy(c *gin.Context, targetURLStr string) {
// 增加请求计数器 // 增加请求计数器
atomic.AddInt64(&requestCounter, 1) reqID := atomic.AddInt64(&requestCounter, 1)
slog.Debug("收到请求",
"reqID", reqID,
"method", c.Request.Method,
"uri", c.Request.RequestURI,
"target", targetURLStr)
// 创建到目标服务器的请求 // 创建到目标服务器的请求
// 注意:我们直接将原始请求的 Body 传递过去 // 注意:我们直接将原始请求的 Body 传递过去
proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body) proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body)
if err != nil { if err != nil {
slog.Error("创建代理请求失败", "reqID", reqID, "error", err)
c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err) c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err)
return return
} }
@@ -126,11 +205,17 @@ func executeProxy(c *gin.Context, targetURLStr string) {
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(proxyReq) resp, err := client.Do(proxyReq)
if err != nil { if err != nil {
c.String(http.StatusBadGateway, "请求目标服务器失败: %v", err) slog.Error("请求目标服务器失败", "reqID", reqID, "error", err)
c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error())
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
slog.Debug("收到响应",
"reqID", reqID,
"status_code", resp.StatusCode,
"status", resp.Status)
// 复制目标服务器响应的 Headers 到原始响应 // 复制目标服务器响应的 Headers 到原始响应
for key, values := range resp.Header { for key, values := range resp.Header {
for _, value := range values { for _, value := range values {
@@ -143,11 +228,11 @@ func executeProxy(c *gin.Context, targetURLStr string) {
// 将目标服务器的响应 Body 直接流式传输到客户端 // 将目标服务器的响应 Body 直接流式传输到客户端
// 使用 io.Copy 更高效,并能处理各种编码(如 chunked // 使用 io.Copy 更高效,并能处理各种编码(如 chunked
_, err = io.Copy(c.Writer, resp.Body) bytesCopied, err := io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
// 如果在写入 body 时发生错误,记录下来 slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err)
fmt.Printf("写入响应 Body 时出错: %v\n", err)
} }
slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied)
} }