update: 整理项目结构 重写中间件 提高复用性

This commit is contained in:
2025-09-26 12:29:57 +08:00
parent 5daf6df318
commit d15352a18b
9 changed files with 430 additions and 324 deletions

166
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,166 @@
package proxy
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"strings"
"sync/atomic"
"github.com/gin-gonic/gin"
"anyproxy/internal/middleware"
)
// 转发的总请求计数器
var totalForwarded atomic.Int64
// Proxy 封装具体的转发逻辑
type Proxy struct {
Client *http.Client
Log *slog.Logger
}
func New(client *http.Client, logger *slog.Logger) *Proxy {
return &Proxy{Client: client, Log: logger}
}
// HandleProxyPath 处理 /proxy/*path 形式的请求
func (p *Proxy) HandleProxyPath(c *gin.Context) {
urlStr, err := BuildFromProxyPath(c.Param("proxyPath"), c.Request.URL.Query())
if err != nil {
p.writeError(c, http.StatusBadRequest, err)
return
}
p.forward(c, urlStr)
}
// HandleProtocol 处理 /:protocol/*remainder 形式的请求
func (p *Proxy) HandleProtocol(c *gin.Context) {
urlStr, err := BuildFromProtocol(c.Param("protocol"), c.Param("remainder"), c.Request.URL.Query())
if err != nil {
p.writeError(c, http.StatusBadRequest, err)
return
}
p.forward(c, urlStr)
}
func (p *Proxy) writeError(c *gin.Context, code int, err error) {
c.JSON(code, gin.H{"error": err.Error(), "req_id": middleware.GetReqID(c)})
}
func (p *Proxy) forward(c *gin.Context, target string) {
reqID := middleware.GetReqID(c)
current := totalForwarded.Add(1)
p.Log.Debug("开始转发请求",
"req_id", reqID,
"count", current,
"method", c.Request.Method,
"target", target,
"uri", c.Request.RequestURI,
)
// 基于原始上下文创建上游请求(支持客户端断开时取消)
upReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, target, c.Request.Body)
if err != nil {
p.Log.Error("创建上游请求失败", "req_id", reqID, "error", err)
p.writeError(c, http.StatusInternalServerError, errors.New("创建上游请求失败"))
return
}
upReq.Header = c.Request.Header.Clone()
// 仅在 SSE 时禁用压缩;稍后检测
resp, err := p.Client.Do(upReq)
if err != nil {
p.Log.Error("上游请求失败", "req_id", reqID, "error", err)
p.writeError(c, http.StatusBadGateway, errors.New("上游请求失败"))
return
}
defer resp.Body.Close()
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
isSSE := mediaType == "text/event-stream"
p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE)
// 复制上游响应头(最小化过滤)
for k, vs := range resp.Header {
for _, v := range vs { c.Header(k, v) }
}
if isSSE {
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Del("Transfer-Encoding")
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 确保禁用上游压缩避免 SSE 事件被聚合
upReq.Header.Del("Accept-Encoding")
}
c.Status(resp.StatusCode)
if f, ok := c.Writer.(http.Flusher); ok { f.Flush() }
if !isSSE {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
p.Log.Error("写入响应体失败", "req_id", reqID, "error", err)
}
return
}
reader := bufio.NewReader(resp.Body)
w := c.Writer
flusher, _ := w.(http.Flusher)
for {
line, err := reader.ReadBytes('\n')
if len(line) > 0 {
if _, werr := w.Write(line); werr != nil {
p.Log.Warn("SSE写入失败", "req_id", reqID, "error", werr)
return
}
if flusher != nil { flusher.Flush() }
}
if err != nil {
if errors.Is(err, io.EOF) {
p.Log.Debug("SSE结束(EOF)", "req_id", reqID)
} else {
p.Log.Error("SSE读取失败", "req_id", reqID, "error", err)
}
return
}
}
}
// HelloPage 返回简单状态页面
func HelloPage(c *gin.Context) {
count := totalForwarded.Load()
// 推断外部可见协议与主机(支持反向代理常见头)
scheme := "http"
if c.Request.TLS != nil { scheme = "https" }
if xf := c.GetHeader("X-Forwarded-Proto"); xf != "" {
// 取第一个
scheme = strings.TrimSpace(strings.Split(xf, ",")[0])
}
host := c.Request.Host
if xfh := c.GetHeader("X-Forwarded-Host"); xfh != "" {
host = strings.TrimSpace(strings.Split(xfh, ",")[0])
}
base := scheme + "://" + host
str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count)
str += "\n\n使用方法:\n"
str += "方式1 - 直接协议路径: \n"
str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base)
str += fmt.Sprintf(" 目标URL: http://example.com/path --> 代理URL: %s/http/example.com/path\n\n", base)
str += "方式2 - 完整URL路径: \n"
str += fmt.Sprintf(" 目标URL: https://example.com --> 代理URL: %s/proxy/https://example.com\n", base)
str += fmt.Sprintf(" 目标URL: http://example.com --> 代理URL: %s/proxy/http://example.com\n\n", base)
str += "目标URL必须以 https:// 或 http:// 开头。\n\n"
str += fmt.Sprintf("本机访问基地址: %s\n", base)
c.String(200, str)
}

51
internal/proxy/url.go Normal file
View File

@@ -0,0 +1,51 @@
package proxy
import (
"errors"
"net/url"
"strings"
)
// normalizeURL 规范化URL格式处理缺少斜杠的情况
func normalizeURL(rawURL string) string {
if strings.HasPrefix(rawURL, "https:/") && !strings.HasPrefix(rawURL, "https://") {
return strings.Replace(rawURL, "https:/", "https://", 1)
}
if strings.HasPrefix(rawURL, "http:/") && !strings.HasPrefix(rawURL, "http://") {
return strings.Replace(rawURL, "http:/", "http://", 1)
}
return rawURL
}
// BuildFromProxyPath 构建 /proxy/*path 形式传入的 URL
func BuildFromProxyPath(pathPart string, originalQuery url.Values) (string, error) {
pathPart = strings.TrimPrefix(pathPart, "/")
if pathPart == "" { return "", errors.New("目标为空") }
pathPart = normalizeURL(pathPart)
return mergeQuery(pathPart, originalQuery)
}
// BuildFromProtocol 构建 /:protocol/*remainder 形式
func BuildFromProtocol(protocol, remainder string, originalQuery url.Values) (string, error) {
if protocol != "http" && protocol != "https" {
return "", errors.New("不支持的协议")
}
full := protocol + ":/" + remainder
full = normalizeURL(full)
return mergeQuery(full, originalQuery)
}
func mergeQuery(raw string, original url.Values) (string, error) {
parsed, err := url.Parse(raw)
if err != nil { return "", err }
// 合并 query
q := parsed.Query()
for k, vs := range original {
for _, v := range vs { q.Add(k, v) }
}
parsed.RawQuery = q.Encode()
if _, err := url.ParseRequestURI(parsed.String()); err != nil {
return "", err
}
return parsed.String(), nil
}