diff --git a/README.md b/README.md index da60849..33afcc4 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,35 @@ # AnyProxy 简介 AnyProxy 是一个简单的 HTTP/HTTPS 代理服务器。它可以帮助你转发和代理请求。 +支持 GET / POST / PUT / DELETE / HEAD / OPTIONS 请求。 +兼容 SSE 流式请求。 ## 使用方法 -1. **直接协议路径** - - 目标URL: `https://example.com/path` - 代理URL: `http://AnyproxyIP/https/example.com/path` - - 目标URL: `http://example.com/path` - 代理URL: `http://AnyproxyIP/http/example.com/path` +1. **直接协议路径** -2. **完整URL路径** - - 目标URL: `https://example.com` - 代理URL: `http://AnyproxyIP/proxy/https://example.com` + - 目标 URL: `https://example.com/path` + 代理 URL: `http://AnyproxyIP/https/example.com/path` + - 目标 URL: `http://example.com/path` + 代理 URL: `http://AnyproxyIP/http/example.com/path` -> 目标URL 必须以 `https://` 或 `http://` 开头。 +2. **完整 URL 路径** + - 目标 URL: `https://example.com` + 代理 URL: `http://AnyproxyIP/proxy/https://example.com` + +> 目标 URL 必须以 `https://` 或 `http://` 开头。 + +> 访问根路径可以查看使用方式 ## 安装 -1. 下载对应平台的二进制Relase文件 +1. 下载对应平台的二进制 Relase 文件 2. 运行二进制文件 3. (可选) 配置为系统服务 -系统服务参考(Systemd) -~~~ ini +### 系统服务参考(Systemd) + +```ini # /etc/systemd/system/anyproxy.service [Unit] Description=AnyProxy Service @@ -37,4 +43,14 @@ User=root [Install] WantedBy=multi-user.target -~~~ +``` + +## 可选参数 + +| 参数 | 是否可选 | 默认值 | 数据类型 | 解释 | +| -------- | -------: | :---------------: | -------- | --------------------------------- | +| -port | 是 | 8080 | int | 代理服务器监听端口 | +| -debug | 是 | false | bool | 调试模式(debug 级别日志) | +| -log | 是 | (输出到 stderr) | string | 日志文件路径(默认输出到 stderr) | +| -grace | 是 | 10 | int | 优雅停机等待秒数 | +| -timeout | 是 | 0 | int | 单次上游请求超时秒(0 = 不设置) | diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..14b874d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "flag" + "fmt" +) + +// Config 保存程序配置 +type Config struct { + Port int + Debug bool + LogFile string + ShutdownGrace int // 优雅停机等待秒数 + RequestTimeout int // 上游整体请求超时时间(秒) +} + +// Parse 解析命令行参数返回配置 +func Parse() *Config { + cfg := &Config{} + flag.IntVar(&cfg.Port, "port", 8080, "代理服务器监听端口") + flag.BoolVar(&cfg.Debug, "debug", false, "调试模式 (debug level log)") + flag.StringVar(&cfg.LogFile, "log", "", "日志文件路径 (默认输出到 stderr)") + flag.IntVar(&cfg.ShutdownGrace, "grace", 10, "优雅停机等待秒数") + flag.IntVar(&cfg.RequestTimeout, "timeout", 0, "单次上游请求超时秒(0=不设置)") + flag.Parse() + return cfg +} + +func (c *Config) Addr() string { return fmt.Sprintf(":%d", c.Port) } diff --git a/internal/middleware/logging.go b/internal/middleware/logging.go new file mode 100644 index 0000000..91e3bed --- /dev/null +++ b/internal/middleware/logging.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "log/slog" + "time" + + "github.com/gin-gonic/gin" +) + +// Logger 使用 slog 输出结构化访问日志 +func Logger(logger *slog.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + raw := c.Request.URL.RawQuery + c.Next() + if raw != "" { path = path + "?" + raw } + latency := time.Since(start) + status := c.Writer.Status() + logger.Info("HTTP请求", + "req_id", GetReqID(c), + "method", c.Request.Method, + "path", path, + "status", status, + "latency_ms", latency.Milliseconds(), + "size", c.Writer.Size(), + "ip", c.ClientIP(), + "ua", c.GetHeader("User-Agent"), + ) + } +} diff --git a/internal/middleware/recovery.go b/internal/middleware/recovery.go new file mode 100644 index 0000000..b9989ec --- /dev/null +++ b/internal/middleware/recovery.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "log/slog" + "net/http" + "runtime/debug" + + "github.com/gin-gonic/gin" +) + +// Recovery 捕获 panic 并记录堆栈信息 +func Recovery(logger *slog.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if rcv := recover(); rcv != nil { + logger.Error("发生Panic", + "req_id", GetReqID(c), + "error", rcv, + "stack", string(debug.Stack()), + "path", c.Request.URL.Path, + ) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "error": "内部服务器错误", + "req_id": GetReqID(c), + }) + } + }() + c.Next() + } +} diff --git a/internal/middleware/requestid.go b/internal/middleware/requestid.go new file mode 100644 index 0000000..75a1654 --- /dev/null +++ b/internal/middleware/requestid.go @@ -0,0 +1,32 @@ +package middleware + +import ( + "fmt" + "sync/atomic" + + "github.com/gin-gonic/gin" +) + +const RequestIDKey = "reqID" + +var globalReqID atomic.Int64 + +// RequestID 生成自增的请求 ID 并注入上下文及响应头 +func RequestID() gin.HandlerFunc { + return func(c *gin.Context) { + id := globalReqID.Add(1) + c.Set(RequestIDKey, id) + c.Writer.Header().Set("X-Request-ID", fmt.Sprintf("%d", id)) + c.Next() + } +} + +// GetReqID 从上下文中获取请求 ID +func GetReqID(c *gin.Context) int64 { + if v, ok := c.Get(RequestIDKey); ok { + if id, ok2 := v.(int64); ok2 { + return id + } + } + return 0 +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..83c3870 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,166 @@ +package proxy + +import ( + "bufio" + "errors" + "fmt" + "io" + "log/slog" + "mime" + "net/http" + "strings" + "sync/atomic" + + "github.com/gin-gonic/gin" + + "anyproxy/internal/middleware" +) + +// 转发的总请求计数器 +var totalForwarded atomic.Int64 + +// Proxy 封装具体的转发逻辑 +type Proxy struct { + Client *http.Client + Log *slog.Logger +} + +func New(client *http.Client, logger *slog.Logger) *Proxy { + return &Proxy{Client: client, Log: logger} +} + +// HandleProxyPath 处理 /proxy/*path 形式的请求 +func (p *Proxy) HandleProxyPath(c *gin.Context) { + urlStr, err := BuildFromProxyPath(c.Param("proxyPath"), c.Request.URL.Query()) + if err != nil { + p.writeError(c, http.StatusBadRequest, err) + return + } + p.forward(c, urlStr) +} + +// HandleProtocol 处理 /:protocol/*remainder 形式的请求 +func (p *Proxy) HandleProtocol(c *gin.Context) { + urlStr, err := BuildFromProtocol(c.Param("protocol"), c.Param("remainder"), c.Request.URL.Query()) + if err != nil { + p.writeError(c, http.StatusBadRequest, err) + return + } + p.forward(c, urlStr) +} + +func (p *Proxy) writeError(c *gin.Context, code int, err error) { + c.JSON(code, gin.H{"error": err.Error(), "req_id": middleware.GetReqID(c)}) +} + +func (p *Proxy) forward(c *gin.Context, target string) { + reqID := middleware.GetReqID(c) + current := totalForwarded.Add(1) + p.Log.Debug("开始转发请求", + "req_id", reqID, + "count", current, + "method", c.Request.Method, + "target", target, + "uri", c.Request.RequestURI, + ) + + // 基于原始上下文创建上游请求(支持客户端断开时取消) + upReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, target, c.Request.Body) + if err != nil { + p.Log.Error("创建上游请求失败", "req_id", reqID, "error", err) + p.writeError(c, http.StatusInternalServerError, errors.New("创建上游请求失败")) + return + } + upReq.Header = c.Request.Header.Clone() + + // 仅在 SSE 时禁用压缩;稍后检测 + + resp, err := p.Client.Do(upReq) + if err != nil { + p.Log.Error("上游请求失败", "req_id", reqID, "error", err) + p.writeError(c, http.StatusBadGateway, errors.New("上游请求失败")) + return + } + defer resp.Body.Close() + + mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) + isSSE := mediaType == "text/event-stream" + + p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE) + + // 复制上游响应头(最小化过滤) + for k, vs := range resp.Header { + for _, v := range vs { c.Header(k, v) } + } + if isSSE { + c.Writer.Header().Del("Content-Length") + c.Writer.Header().Del("Transfer-Encoding") + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + // 确保禁用上游压缩避免 SSE 事件被聚合 + upReq.Header.Del("Accept-Encoding") + } + c.Status(resp.StatusCode) + if f, ok := c.Writer.(http.Flusher); ok { f.Flush() } + + if !isSSE { + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + p.Log.Error("写入响应体失败", "req_id", reqID, "error", err) + } + return + } + + reader := bufio.NewReader(resp.Body) + w := c.Writer + flusher, _ := w.(http.Flusher) + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + if _, werr := w.Write(line); werr != nil { + p.Log.Warn("SSE写入失败", "req_id", reqID, "error", werr) + return + } + if flusher != nil { flusher.Flush() } + } + if err != nil { + if errors.Is(err, io.EOF) { + p.Log.Debug("SSE结束(EOF)", "req_id", reqID) + } else { + p.Log.Error("SSE读取失败", "req_id", reqID, "error", err) + } + return + } + } +} + +// HelloPage 返回简单状态页面 +func HelloPage(c *gin.Context) { + count := totalForwarded.Load() + + // 推断外部可见协议与主机(支持反向代理常见头) + scheme := "http" + if c.Request.TLS != nil { scheme = "https" } + if xf := c.GetHeader("X-Forwarded-Proto"); xf != "" { + // 取第一个 + scheme = strings.TrimSpace(strings.Split(xf, ",")[0]) + } + host := c.Request.Host + if xfh := c.GetHeader("X-Forwarded-Host"); xfh != "" { + host = strings.TrimSpace(strings.Split(xfh, ",")[0]) + } + base := scheme + "://" + host + + str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count) + str += "\n\n使用方法:\n" + str += "方式1 - 直接协议路径: \n" + str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base) + str += fmt.Sprintf(" 目标URL: http://example.com/path --> 代理URL: %s/http/example.com/path\n\n", base) + str += "方式2 - 完整URL路径: \n" + str += fmt.Sprintf(" 目标URL: https://example.com --> 代理URL: %s/proxy/https://example.com\n", base) + str += fmt.Sprintf(" 目标URL: http://example.com --> 代理URL: %s/proxy/http://example.com\n\n", base) + str += "目标URL必须以 https:// 或 http:// 开头。\n\n" + str += fmt.Sprintf("本机访问基地址: %s\n", base) + c.String(200, str) +} diff --git a/internal/proxy/url.go b/internal/proxy/url.go new file mode 100644 index 0000000..e68c3bd --- /dev/null +++ b/internal/proxy/url.go @@ -0,0 +1,51 @@ +package proxy + +import ( + "errors" + "net/url" + "strings" +) + +// normalizeURL 规范化URL格式,处理缺少斜杠的情况 +func normalizeURL(rawURL string) string { + if strings.HasPrefix(rawURL, "https:/") && !strings.HasPrefix(rawURL, "https://") { + return strings.Replace(rawURL, "https:/", "https://", 1) + } + if strings.HasPrefix(rawURL, "http:/") && !strings.HasPrefix(rawURL, "http://") { + return strings.Replace(rawURL, "http:/", "http://", 1) + } + return rawURL +} + +// BuildFromProxyPath 构建 /proxy/*path 形式传入的 URL +func BuildFromProxyPath(pathPart string, originalQuery url.Values) (string, error) { + pathPart = strings.TrimPrefix(pathPart, "/") + if pathPart == "" { return "", errors.New("目标为空") } + pathPart = normalizeURL(pathPart) + return mergeQuery(pathPart, originalQuery) +} + +// BuildFromProtocol 构建 /:protocol/*remainder 形式 +func BuildFromProtocol(protocol, remainder string, originalQuery url.Values) (string, error) { + if protocol != "http" && protocol != "https" { + return "", errors.New("不支持的协议") + } + full := protocol + ":/" + remainder + full = normalizeURL(full) + return mergeQuery(full, originalQuery) +} + +func mergeQuery(raw string, original url.Values) (string, error) { + parsed, err := url.Parse(raw) + if err != nil { return "", err } + // 合并 query + q := parsed.Query() + for k, vs := range original { + for _, v := range vs { q.Add(k, v) } + } + parsed.RawQuery = q.Encode() + if _, err := url.ParseRequestURI(parsed.String()); err != nil { + return "", err + } + return parsed.String(), nil +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000..c870823 --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,15 @@ +package version + +import "runtime/debug" + +var ( + Version = "1.1.0-rc" + GitCommit = "" + BuildInfo = "" +) + +func init() { + if info, ok := debug.ReadBuildInfo(); ok { + BuildInfo = info.Main.Version + } +} diff --git a/main.go b/main.go index e676c6b..972fb54 100644 --- a/main.go +++ b/main.go @@ -1,339 +1,75 @@ package main import ( - "bufio" - "errors" - "flag" - "fmt" + "context" "io" "log/slog" "net/http" - "net/url" "os" - "runtime/debug" - "strings" - "sync/atomic" + "os/signal" + "syscall" "time" "github.com/gin-gonic/gin" "github.com/lmittmann/tint" + + "anyproxy/internal/config" + "anyproxy/internal/middleware" + "anyproxy/internal/proxy" + "anyproxy/internal/version" ) -// 全局请求计数器,使用原子操作确保线程安全 -var requestCounter int64 - - - func main() { - port := flag.Int("port", 8080, "代理服务器监听的端口") - debug := flag.Bool("debug", false, "是否启用调试模式") - logFile := flag.String("log", "", "日志文件路径,默认为控制台彩色输出") - flag.Parse() + cfg := config.Parse() - // 使用 tint + LevelVar - var levelVar = new(slog.LevelVar) - if *debug { - levelVar.Set(slog.LevelDebug) - } else { - levelVar.Set(slog.LevelInfo) - } - - // 组合输出 writer - var writer io.Writer = os.Stderr // 默认彩色输出到 stderr - if *logFile != "" { - f, err := os.OpenFile(*logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - fmt.Fprintf(os.Stderr, "无法打开日志文件: %v\n", err) - os.Exit(1) - } - // 同时输出到彩色终端和文件(文件里不需要颜色,tint 会根据是否是终端决定) - writer = io.MultiWriter(os.Stderr, f) - } - - handler := tint.NewHandler(writer, &tint.Options{ - AddSource: true, - Level: levelVar, - TimeFormat: "2006-01-02 15:04:05", - }) - slog.SetDefault(slog.New(handler)) - - if *debug { - gin.SetMode(gin.DebugMode) - } else { - gin.SetMode(gin.ReleaseMode) - } - - r := gin.New() // 不使用默认 Logger,改为自定义 slog 统一输出 - r.Use(SlogLogger(), SlogRecovery()) - r.GET("/", HelloPage) - r.Any("/proxy/*proxyPath", proxyHandler) - r.Any(":protocol/*remainder", protocolHandler) - - slog.Info("HTTP 代理服务器启动", "port", *port, "debug", *debug) - if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil { - slog.Error("启动服务器失败", "error", err) - } -} - -// normalizeURL 规范化URL格式,处理缺少斜杠的情况 -func normalizeURL(rawURL string) string { - // 处理 https:/example.com 或 http:/example.com 的情况 - if strings.HasPrefix(rawURL, "https:/") && !strings.HasPrefix(rawURL, "https://") { - return strings.Replace(rawURL, "https:/", "https://", 1) + // 日志初始化设置 + levelVar := new(slog.LevelVar) + if cfg.Debug { levelVar.Set(slog.LevelDebug) } else { levelVar.Set(slog.LevelInfo) } + var writer io.Writer = os.Stderr + if cfg.LogFile != "" { + f, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { panic(err) } + writer = io.MultiWriter(os.Stderr, f) } - if strings.HasPrefix(rawURL, "http:/") && !strings.HasPrefix(rawURL, "http://") { - return strings.Replace(rawURL, "http:/", "http://", 1) - } - return rawURL -} + h := tint.NewHandler(writer, &tint.Options{AddSource: true, Level: levelVar, TimeFormat: "2006-01-02 15:04:05"}) + logger := slog.New(h) + slog.SetDefault(logger) -func proxyHandler(c *gin.Context) { - // 从路径参数中获取目标 URL - targetURLStr := c.Param("proxyPath") - // 移除前导斜杠 - targetURLStr = strings.TrimPrefix(targetURLStr, "/") - - // 规范化URL格式 - targetURLStr = normalizeURL(targetURLStr) + if cfg.Debug { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) } - // 解析目标URL - parsedURL, err := url.Parse(targetURLStr) - if err != nil { - c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) - return - } - - // 合并查询参数 - originalQuery := c.Request.URL.Query() - targetQuery := parsedURL.Query() - for key, values := range originalQuery { - for _, value := range values { - targetQuery.Add(key, value) - } - } - parsedURL.RawQuery = targetQuery.Encode() - - // 重新构建目标URL字符串 - targetURLStr = parsedURL.String() - - // 检查 URL 合法性 - if _, err := url.ParseRequestURI(targetURLStr); err != nil { - c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) - return - } - - // 执行代理请求 - executeProxy(c, targetURLStr) -} - -// protocolHandler 处理直接以协议开头的URL请求 (如 /https/example.com/path) -func protocolHandler(c *gin.Context) { - protocol := c.Param("protocol") - remainder := c.Param("remainder") - - // 只处理 http 和 https 协议 - if protocol != "http" && protocol != "https" { - c.String(http.StatusBadRequest, "不支持的协议: %s", protocol) - return - } - - // 构建完整的URL - targetURLStr := protocol + ":/" + remainder - - // 规范化URL格式 - targetURLStr = normalizeURL(targetURLStr) - - // 解析目标URL - parsedURL, err := url.Parse(targetURLStr) - if err != nil { - c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) - return - } - - // 合并查询参数 - originalQuery := c.Request.URL.Query() - targetQuery := parsedURL.Query() - for key, values := range originalQuery { - for _, value := range values { - targetQuery.Add(key, value) - } - } - parsedURL.RawQuery = targetQuery.Encode() - - // 重新构建目标URL字符串 - targetURLStr = parsedURL.String() - - // 检查 URL 合法性 - if _, err := url.ParseRequestURI(targetURLStr); err != nil { - c.String(http.StatusBadRequest, "无效的目标 URL: %v", err) - return - } - - // 执行代理请求 - executeProxy(c, targetURLStr) -} - -// executeProxy 执行实际的代理请求 -func executeProxy(c *gin.Context, targetURLStr string) { - // 增加请求计数器 - reqID := atomic.AddInt64(&requestCounter, 1) - - slog.Debug("收到请求", - "reqID", reqID, - "method", c.Request.Method, - "uri", c.Request.RequestURI, - "target", targetURLStr) - - // 自定义 Transport,禁止自动压缩(避免 gzip 聚合导致 SSE 延迟) - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DisableCompression: true, - } + // 可复用的 HTTP 客户端(保持连接复用) + transport := &http.Transport{Proxy: http.ProxyFromEnvironment, DisableCompression: true} client := &http.Client{Transport: transport} + if cfg.RequestTimeout > 0 { client.Timeout = time.Duration(cfg.RequestTimeout) * time.Second } - // 创建到目标服务器的请求 - proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body) - if err != nil { - slog.Error("创建代理请求失败", "reqID", reqID, "error", err) - c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err) - return - } + p := proxy.New(client, logger) - // 复制原始请求的 Headers (Clone 避免引用共享) - proxyReq.Header = c.Request.Header.Clone() - // 禁止上游压缩,保证事件粒度 - proxyReq.Header.Del("Accept-Encoding") + r := gin.New() + r.Use(middleware.Recovery(logger), middleware.RequestID(), middleware.Logger(logger)) - resp, err := client.Do(proxyReq) - if err != nil { - slog.Error("请求目标服务器失败", "reqID", reqID, "error", err) - c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error()) - return - } - defer resp.Body.Close() + r.GET("/", proxy.HelloPage) // 欢迎页面 + r.Any("/proxy/*proxyPath", p.HandleProxyPath) // 处理 /proxy/*path 形式的请求 + r.Any(":protocol/*remainder", p.HandleProtocol) // 处理 /:protocol/*remainder 形式的请求 - contentType := resp.Header.Get("Content-Type") - isSSE := strings.HasPrefix(contentType, "text/event-stream") + logger.Info("服务器启动", "addr", cfg.Addr(), "debug", cfg.Debug, "version", version.Version, "commit", version.GitCommit) - slog.Debug("收到响应", "reqID", reqID, "status_code", resp.StatusCode, "status", resp.Status, "isSSE", isSSE) - - // 复制响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) + // 优雅停机设置:监听系统信号,执行平滑关闭 + srv := &http.Server{Addr: cfg.Addr(), Handler: r} + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error("服务器监听错误", "error", err) } - } - // SSE 需要去掉不合适的头并设置必要头 - if isSSE { - c.Writer.Header().Del("Content-Length") - c.Writer.Header().Del("Transfer-Encoding") - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") // 防止某些反向代理缓冲 - } + }() - // 设置状态码 - c.Status(resp.StatusCode) - - // 立即 flush 头部,尤其是 SSE - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - - if !isSSE { - // 普通请求直接复制主体 - bytesCopied, err := io.Copy(c.Writer, resp.Body) - if err != nil { - slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err) - } - slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied) - return - } - - // SSE 模式:逐行读取并 flush,保持事件实时性 - reader := bufio.NewReader(resp.Body) - w := c.Writer - flusher, _ := w.(http.Flusher) - - for { - line, err := reader.ReadBytes('\n') - if len(line) > 0 { - if _, werr := w.Write(line); werr != nil { - slog.Warn("SSE 写失败", "reqID", reqID, "error", werr) - return - } - if flusher != nil { - flusher.Flush() - } - } - if err != nil { - if errors.Is(err, io.EOF) { - slog.Debug("SSE 结束(EOF)", "reqID", reqID) - } else { - slog.Error("读取 SSE 失败", "reqID", reqID, "error", err) - } - return - } + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + <-stop + logger.Info("开始关闭 (收到退出信号)") + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.ShutdownGrace)*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + logger.Error("关闭出错", "error", err) + } else { + logger.Info("关闭完成") } } - - -func HelloPage(c *gin.Context) { - // 获取当前的请求计数 - count := atomic.LoadInt64(&requestCounter) - str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count) - str += "\n\n使用方法:\n" - str += "方式1 - 直接协议路径: \n" - str += " 目标URL: https://example.com/path --> 代理URL: http://AnyproxyIP/https/example.com/path\n" - str += " 目标URL: http://example.com/path --> 代理URL: http://AnyproxyIP/http/example.com/path\n\n" - str += "方式2 - 完整URL路径: \n" - str += " 目标URL: https://example.com --> 代理URL: http://AnyproxyIP/proxy/https://example.com\n\n" - str += "目标URL必须以 https:// 或 http:// 开头。\n\n" - c.String(200, str) -} - -// SlogLogger 统一请求日志中间件 -func SlogLogger() gin.HandlerFunc { - return func(c *gin.Context) { - start := time.Now() - path := c.Request.URL.Path - rawQuery := c.Request.URL.RawQuery - c.Next() - latency := time.Since(start) - status := c.Writer.Status() - size := c.Writer.Size() - method := c.Request.Method - ip := c.ClientIP() - if rawQuery != "" { - path = path + "?" + rawQuery - } - slog.Log(c, slog.LevelInfo, "HTTP 请求", - slog.String("method", method), - slog.String("path", path), - slog.Int("status", status), - slog.Duration("latency", latency), - slog.Int("size", size), - slog.String("ip", ip), - slog.String("ua", c.GetHeader("User-Agent")), - ) - } -} - -// SlogRecovery 捕获 panic,输出堆栈 -func SlogRecovery() gin.HandlerFunc { - return func(c *gin.Context) { - defer func() { - if rcv := recover(); rcv != nil { - stack := debug.Stack() - slog.Error("发生 panic", - "error", rcv, - "stack", string(stack), - "path", c.Request.URL.Path, - ) - c.AbortWithStatus(http.StatusInternalServerError) - } - }() - c.Next() - } -} \ No newline at end of file