diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 105637d..a9f6f4f 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,7 +1,6 @@ package metrics import ( - "sync" "sync/atomic" "time" ) @@ -10,13 +9,12 @@ import ( // 只针对转发请求调用 Inc。 type bucket struct { - second int64 // Unix 秒 - count int64 + second atomic.Int64 // Unix 秒 + count atomic.Int64 } var ( buckets [60]bucket - mu sync.Mutex total atomic.Int64 ) @@ -24,14 +22,18 @@ var ( func Inc() { now := time.Now().Unix() idx := now % 60 - mu.Lock() b := &buckets[idx] - if b.second != now { // 该槽位属于旧秒,重置 - b.second = now - b.count = 0 + for { + sec := b.second.Load() + if sec == now { + b.count.Add(1) + break + } + if b.second.CompareAndSwap(sec, now) { + b.count.Store(1) + break + } } - b.count++ - mu.Unlock() total.Add(1) } @@ -39,10 +41,8 @@ func Inc() { func QPS() int64 { now := time.Now().Unix() idx := now % 60 - mu.Lock() - b := buckets[idx] - mu.Unlock() - if b.second == now { return b.count } + b := &buckets[idx] + if b.second.Load() == now { return b.count.Load() } return 0 } @@ -50,14 +50,12 @@ func QPS() int64 { func QPM() int64 { now := time.Now().Unix() var sum int64 - mu.Lock() - for i := 0; i < 60; i++ { - b := buckets[i] - if now-b.second < 60 { // 在窗口内 - sum += b.count + for i := range 60 { + sec := buckets[i].second.Load() + if sec <= now && now-sec < 60 { // 在窗口内 + sum += buckets[i].count.Load() } } - mu.Unlock() return sum } diff --git a/internal/middleware/metrics.go b/internal/middleware/metrics.go index 1db8e69..f513feb 100644 --- a/internal/middleware/metrics.go +++ b/internal/middleware/metrics.go @@ -1,127 +1,23 @@ package middleware import ( - "sync/atomic" "time" "github.com/gin-gonic/gin" + + "anyproxy/internal/metrics" ) -// 简易 QPS / QPM 统计:使用滑动窗口环形数组按秒/按分钟聚合 -// secondBuckets: 最近 60 秒每秒的请求计数 -// minuteBuckets: 最近 60 分钟每分钟的请求计数 - -var ( - secondBuckets [60]atomic.Int64 - minuteBuckets [60]atomic.Int64 - lastSecond int64 - lastMinute int64 - // 总请求数 (复用可选) - totalRequests atomic.Int64 -) - -func init() { - now := time.Now() - lastSecond = now.Unix() - lastMinute = now.Unix() / 60 -} - -// AddRequest 在收到一个请求时调用,通常在请求完成后计数 -func AddRequest() { - now := time.Now() - sec := now.Unix() - min := sec / 60 - - // 处理秒级 bucket - oldSec := atomic.LoadInt64(&lastSecond) - if sec != oldSec { - // 跨秒:清理可能跨越多个秒的间隙 - if atomic.CompareAndSwapInt64(&lastSecond, oldSec, sec) { - steps := int(sec - oldSec) - if steps > 60 { steps = 60 } - for i := 1; i <= steps; i++ { - idx := int((oldSec+int64(i)) % 60) - secondBuckets[idx].Store(0) - } - } - } - secIdx := int(sec % 60) - secondBuckets[secIdx].Add(1) - - // 处理分钟级 bucket - oldMin := atomic.LoadInt64(&lastMinute) - if min != oldMin { - if atomic.CompareAndSwapInt64(&lastMinute, oldMin, min) { - steps := int(min - oldMin) - if steps > 60 { steps = 60 } - for i := 1; i <= steps; i++ { - idx := int((oldMin+int64(i)) % 60) - minuteBuckets[idx].Store(0) - } - } - } - minIdx := int(min % 60) - minuteBuckets[minIdx].Add(1) - - totalRequests.Add(1) -} - -// CurrentQPS 返回最近 1 秒(当前秒)的请求数 -func CurrentQPS() int64 { - sec := time.Now().Unix() - if sec != atomic.LoadInt64(&lastSecond) { return 0 } - return secondBuckets[sec%60].Load() -} - -// AvgQPSRecent60 返回最近 60 秒平均 QPS -func AvgQPSRecent60() float64 { - sec := time.Now().Unix() - total := int64(0) - last := atomic.LoadInt64(&lastSecond) - for i := 0; i < 60; i++ { - // 只统计在窗口内(未被清零)的 bucket - bucketSec := sec - int64(i) - if bucketSec <= last && last-bucketSec < 60 { - idx := bucketSec % 60 - total += secondBuckets[idx].Load() - } - } - return float64(total) / 60.0 -} - -// CurrentQPM 返回当前分钟的请求数 -func CurrentQPM() int64 { - min := time.Now().Unix() / 60 - if min != atomic.LoadInt64(&lastMinute) { return 0 } - return minuteBuckets[min%60].Load() -} - -// AvgQPMRecent60 返回最近 60 分钟的平均 QPM -func AvgQPMRecent60() float64 { - min := time.Now().Unix() / 60 - total := int64(0) - last := atomic.LoadInt64(&lastMinute) - for i := 0; i < 60; i++ { - bucketMin := min - int64(i) - if bucketMin <= last && last-bucketMin < 60 { - idx := bucketMin % 60 - total += minuteBuckets[idx].Load() - } - } - return float64(total) / 60.0 -} - -// TotalRequests 返回总请求量(从进程启动以来) -func TotalRequests() int64 { return totalRequests.Load() } - // MetricsHandler 输出当前指标 func MetricsHandler(c *gin.Context) { + qps := metrics.QPS() + qpm := metrics.QPM() c.JSON(200, gin.H{ - "qps_current": CurrentQPS(), - "qps_avg_60s": AvgQPSRecent60(), - "qpm_current": CurrentQPM(), - "qpm_avg_60m": AvgQPMRecent60(), - "total": TotalRequests(), + "qps_current": qps, + "qps_avg_60s": float64(qpm) / 60.0, + "qpm_current": qpm, + "qpm_avg_60m": float64(qpm), + "total": metrics.Total(), "timestamp": time.Now().Unix(), }) } diff --git a/internal/middleware/requestid.go b/internal/middleware/requestid.go index 75a1654..727a236 100644 --- a/internal/middleware/requestid.go +++ b/internal/middleware/requestid.go @@ -1,7 +1,7 @@ package middleware import ( - "fmt" + "strconv" "sync/atomic" "github.com/gin-gonic/gin" @@ -16,7 +16,7 @@ func RequestID() gin.HandlerFunc { return func(c *gin.Context) { id := globalReqID.Add(1) c.Set(RequestIDKey, id) - c.Writer.Header().Set("X-Request-ID", fmt.Sprintf("%d", id)) + c.Writer.Header().Set("X-Request-ID", strconv.FormatInt(id, 10)) c.Next() } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8f1bde3..39e41fa 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -6,9 +6,9 @@ import ( "fmt" "io" "log/slog" - "mime" "net/http" "strings" + "sync" "sync/atomic" "github.com/gin-gonic/gin" @@ -20,6 +20,10 @@ import ( // 转发的总请求计数器 var totalForwarded atomic.Int64 +var copyBufPool = sync.Pool{ + New: func() any { return make([]byte, 32*1024) }, +} + // Proxy 封装具体的转发逻辑 type Proxy struct { Client *http.Client @@ -73,8 +77,10 @@ func (p *Proxy) forward(c *gin.Context, target string) { return } upReq.Header = c.Request.Header.Clone() - - // 仅在 SSE 时禁用压缩;稍后检测 + if strings.Contains(strings.ToLower(c.GetHeader("Accept")), "text/event-stream") { + // SSE 禁用压缩 + upReq.Header.Del("Accept-Encoding") + } resp, err := p.Client.Do(upReq) if err != nil { @@ -87,15 +93,14 @@ func (p *Proxy) forward(c *gin.Context, target string) { // 仅在真正进行了一次上游转发并得到响应后计数 metrics.Inc() - mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type")) - isSSE := mediaType == "text/event-stream" + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + isSSE := strings.HasPrefix(contentType, "text/event-stream") p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE) - // 复制上游响应头(最小化过滤) - for k, vs := range resp.Header { - for _, v := range vs { c.Header(k, v) } - } + // 复制上游响应头 + dstHeader := c.Writer.Header() + for k, vs := range resp.Header { dstHeader[k] = vs } if isSSE { c.Writer.Header().Del("Content-Length") c.Writer.Header().Del("Transfer-Encoding") @@ -103,14 +108,15 @@ func (p *Proxy) forward(c *gin.Context, target string) { c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") - // 确保禁用上游压缩避免 SSE 事件被聚合 - upReq.Header.Del("Accept-Encoding") } c.Status(resp.StatusCode) if f, ok := c.Writer.(http.Flusher); ok { f.Flush() } if !isSSE { - if _, err := io.Copy(c.Writer, resp.Body); err != nil { + buf := copyBufPool.Get().([]byte) + _, err := io.CopyBuffer(c.Writer, resp.Body, buf) + copyBufPool.Put(buf) + if err != nil { p.Log.Error("写入响应体失败", "req_id", reqID, "error", err) } return diff --git a/internal/proxy/url.go b/internal/proxy/url.go index e68c3bd..89f35f9 100644 --- a/internal/proxy/url.go +++ b/internal/proxy/url.go @@ -38,14 +38,17 @@ func BuildFromProtocol(protocol, remainder string, originalQuery url.Values) (st func mergeQuery(raw string, original url.Values) (string, error) { parsed, err := url.Parse(raw) if err != nil { return "", err } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", errors.New("不支持的协议") + } + if parsed.Host == "" { + return "", errors.New("目标地址无效") + } // 合并 query q := parsed.Query() for k, vs := range original { for _, v := range vs { q.Add(k, v) } } parsed.RawQuery = q.Encode() - if _, err := url.ParseRequestURI(parsed.String()); err != nil { - return "", err - } return parsed.String(), nil } diff --git a/main.go b/main.go index e65cc5b..61310ab 100644 --- a/main.go +++ b/main.go @@ -50,8 +50,16 @@ func main() { if cfg.Debug || lvlStr == "debug" { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) } - // 可复用的 HTTP 客户端(保持连接复用) - transport := &http.Transport{Proxy: http.ProxyFromEnvironment, DisableCompression: true} + // 可复用的 HTTP 客户端 + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + ForceAttemptHTTP2: true, + MaxIdleConns: 512, + MaxIdleConnsPerHost: 128, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } client := &http.Client{Transport: transport} if cfg.RequestTimeout > 0 { client.Timeout = time.Duration(cfg.RequestTimeout) * time.Second }