diff --git a/main.go b/main.go index 5edc1c8..9e88b6f 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,10 @@ import ( "flag" "fmt" "io" + "log/slog" "net/http" "net/url" + "os" "strings" "sync/atomic" @@ -21,8 +23,38 @@ func main() { port := flag.Int("port", 8080, "代理服务器监听的端口") debug := flag.Bool("debug", false, "是否启用调试模式") + logFile := flag.String("log", "", "日志文件路径,默认为标准输出") flag.Parse() + // 配置slog + var logger *slog.Logger + + // 根据调试模式设置日志级别 + var logLevel slog.Level + if *debug { + logLevel = slog.LevelDebug + } else { + logLevel = slog.LevelInfo + } + + // 创建处理器选项 + opts := &slog.HandlerOptions{ + Level: logLevel, + } + + if *logFile != "" { + file, err := os.OpenFile(*logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + slog.Error("无法打开日志文件", "error", err) + os.Exit(1) + } + logger = slog.New(slog.NewJSONHandler(file, opts)) + } else { + logger = slog.New(slog.NewJSONHandler(os.Stdout, opts)) + } + slog.SetDefault(logger) + + if *debug { gin.SetMode(gin.DebugMode) // 启用调试模式 } else { @@ -42,9 +74,9 @@ func main() { // 检查是否以协议开头的路径 r.Any("/:protocol/*remainder", protocolHandler) - fmt.Printf("HTTP 代理服务器启动,监听端口 :%d\n", *port) + slog.Info("HTTP 代理服务器启动", "port", *port) if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil { - fmt.Printf("启动服务器失败: %v\n", err) + slog.Error("启动服务器失败", "error", err) } } @@ -69,6 +101,26 @@ func proxyHandler(c *gin.Context) { // 规范化URL格式 targetURLStr = normalizeURL(targetURLStr) + // 解析目标URL + parsedURL, err := url.Parse(targetURLStr) + if err != nil { + c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) + return + } + + // 合并查询参数 + originalQuery := c.Request.URL.Query() + targetQuery := parsedURL.Query() + for key, values := range originalQuery { + for _, value := range values { + targetQuery.Add(key, value) + } + } + parsedURL.RawQuery = targetQuery.Encode() + + // 重新构建目标URL字符串 + targetURLStr = parsedURL.String() + // 检查 URL 合法性 if _, err := url.ParseRequestURI(targetURLStr); err != nil { c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) @@ -96,6 +148,26 @@ func protocolHandler(c *gin.Context) { // 规范化URL格式 targetURLStr = normalizeURL(targetURLStr) + // 解析目标URL + parsedURL, err := url.Parse(targetURLStr) + if err != nil { + c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) + return + } + + // 合并查询参数 + originalQuery := c.Request.URL.Query() + targetQuery := parsedURL.Query() + for key, values := range originalQuery { + for _, value := range values { + targetQuery.Add(key, value) + } + } + parsedURL.RawQuery = targetQuery.Encode() + + // 重新构建目标URL字符串 + targetURLStr = parsedURL.String() + // 检查 URL 合法性 if _, err := url.ParseRequestURI(targetURLStr); err != nil { c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) @@ -109,12 +181,19 @@ func protocolHandler(c *gin.Context) { // executeProxy 执行实际的代理请求 func executeProxy(c *gin.Context, targetURLStr string) { // 增加请求计数器 - atomic.AddInt64(&requestCounter, 1) + reqID := atomic.AddInt64(&requestCounter, 1) + + slog.Debug("收到请求", + "reqID", reqID, + "method", c.Request.Method, + "uri", c.Request.RequestURI, + "target", targetURLStr) // 创建到目标服务器的请求 // 注意:我们直接将原始请求的 Body 传递过去 proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body) if err != nil { + slog.Error("创建代理请求失败", "reqID", reqID, "error", err) c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err) return } @@ -126,11 +205,17 @@ func executeProxy(c *gin.Context, targetURLStr string) { client := &http.Client{} resp, err := client.Do(proxyReq) if err != nil { - c.String(http.StatusBadGateway, "请求目标服务器失败: %v", err) + slog.Error("请求目标服务器失败", "reqID", reqID, "error", err) + c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error()) return } defer resp.Body.Close() + slog.Debug("收到响应", + "reqID", reqID, + "status_code", resp.StatusCode, + "status", resp.Status) + // 复制目标服务器响应的 Headers 到原始响应 for key, values := range resp.Header { for _, value := range values { @@ -143,11 +228,11 @@ func executeProxy(c *gin.Context, targetURLStr string) { // 将目标服务器的响应 Body 直接流式传输到客户端 // 使用 io.Copy 更高效,并能处理各种编码(如 chunked) - _, err = io.Copy(c.Writer, resp.Body) + bytesCopied, err := io.Copy(c.Writer, resp.Body) if err != nil { - // 如果在写入 body 时发生错误,记录下来 - fmt.Printf("写入响应 Body 时出错: %v\n", err) + slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err) } + slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied) }