package proxy import ( "bufio" "errors" "fmt" "io" "log/slog" "mime" "net/http" "strings" "sync/atomic" "github.com/gin-gonic/gin" "anyproxy/internal/middleware" ) // 转发的总请求计数器 var totalForwarded atomic.Int64 // Proxy 封装具体的转发逻辑 type Proxy struct { Client *http.Client Log *slog.Logger } func New(client *http.Client, logger *slog.Logger) *Proxy { return &Proxy{Client: client, Log: logger} } // HandleProxyPath 处理 /proxy/*path 形式的请求 func (p *Proxy) HandleProxyPath(c *gin.Context) { urlStr, err := BuildFromProxyPath(c.Param("proxyPath"), c.Request.URL.Query()) if err != nil { p.writeError(c, http.StatusBadRequest, err) return } p.forward(c, urlStr) } // HandleProtocol 处理 /:protocol/*remainder 形式的请求 func (p *Proxy) HandleProtocol(c *gin.Context) { urlStr, err := BuildFromProtocol(c.Param("protocol"), c.Param("remainder"), c.Request.URL.Query()) if err != nil { p.writeError(c, http.StatusBadRequest, err) return } p.forward(c, urlStr) } func (p *Proxy) writeError(c *gin.Context, code int, err error) { c.JSON(code, gin.H{"error": err.Error(), "req_id": middleware.GetReqID(c)}) } func (p *Proxy) forward(c *gin.Context, target string) { reqID := middleware.GetReqID(c) current := totalForwarded.Add(1) p.Log.Debug("开始转发请求", "req_id", reqID, "count", current, "method", c.Request.Method, "target", target, "uri", c.Request.RequestURI, ) // 基于原始上下文创建上游请求(支持客户端断开时取消) upReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, target, c.Request.Body) if err != nil { p.Log.Error("创建上游请求失败", "req_id", reqID, "error", err) p.writeError(c, http.StatusInternalServerError, errors.New("创建上游请求失败")) return } upReq.Header = c.Request.Header.Clone() // 仅在 SSE 时禁用压缩;稍后检测 resp, err := p.Client.Do(upReq) if err != nil { p.Log.Error("上游请求失败", "req_id", reqID, "error", err) p.writeError(c, http.StatusBadGateway, errors.New("上游请求失败")) return } defer resp.Body.Close() mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) isSSE := mediaType == "text/event-stream" p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE) // 复制上游响应头(最小化过滤) for k, vs := range resp.Header { for _, v := range vs { c.Header(k, v) } } if isSSE { c.Writer.Header().Del("Content-Length") c.Writer.Header().Del("Transfer-Encoding") c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") // 确保禁用上游压缩避免 SSE 事件被聚合 upReq.Header.Del("Accept-Encoding") } c.Status(resp.StatusCode) if f, ok := c.Writer.(http.Flusher); ok { f.Flush() } if !isSSE { if _, err := io.Copy(c.Writer, resp.Body); err != nil { p.Log.Error("写入响应体失败", "req_id", reqID, "error", err) } return } reader := bufio.NewReader(resp.Body) w := c.Writer flusher, _ := w.(http.Flusher) for { line, err := reader.ReadBytes('\n') if len(line) > 0 { if _, werr := w.Write(line); werr != nil { p.Log.Warn("SSE写入失败", "req_id", reqID, "error", werr) return } if flusher != nil { flusher.Flush() } } if err != nil { if errors.Is(err, io.EOF) { p.Log.Debug("SSE结束(EOF)", "req_id", reqID) } else { p.Log.Error("SSE读取失败", "req_id", reqID, "error", err) } return } } } // HelloPage 返回简单状态页面 func HelloPage(c *gin.Context) { count := totalForwarded.Load() // 推断外部可见协议与主机(支持反向代理常见头) scheme := "http" if c.Request.TLS != nil { scheme = "https" } if xf := c.GetHeader("X-Forwarded-Proto"); xf != "" { // 取第一个 scheme = strings.TrimSpace(strings.Split(xf, ",")[0]) } host := c.Request.Host if xfh := c.GetHeader("X-Forwarded-Host"); xfh != "" { host = strings.TrimSpace(strings.Split(xfh, ",")[0]) } base := scheme + "://" + host str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count) str += "\n\n使用方法:\n" str += "方式1 - 直接协议路径: \n" str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base) str += fmt.Sprintf(" 目标URL: http://example.com/path --> 代理URL: %s/http/example.com/path\n\n", base) str += "方式2 - 完整URL路径: \n" str += fmt.Sprintf(" 目标URL: https://example.com --> 代理URL: %s/proxy/https://example.com\n", base) str += fmt.Sprintf(" 目标URL: http://example.com --> 代理URL: %s/proxy/http://example.com\n\n", base) str += "目标URL必须以 https:// 或 http:// 开头。\n\n" str += fmt.Sprintf("本机访问基地址: %s\n", base) c.String(200, str) }