Compare commits
3 Commits
1.1.0
...
a6ed3f8f4c
| Author | SHA1 | Date | |
|---|---|---|---|
| a6ed3f8f4c | |||
| 0f2d550a14 | |||
| 0f2fb51065 |
@@ -50,7 +50,8 @@ WantedBy=multi-user.target
|
|||||||
| 参数 | 是否可选 | 默认值 | 数据类型 | 解释 |
|
| 参数 | 是否可选 | 默认值 | 数据类型 | 解释 |
|
||||||
| -------- | -------: | :---------------: | -------- | --------------------------------- |
|
| -------- | -------: | :---------------: | -------- | --------------------------------- |
|
||||||
| -port | 是 | 8080 | int | 代理服务器监听端口 |
|
| -port | 是 | 8080 | int | 代理服务器监听端口 |
|
||||||
| -debug | 是 | false | bool | 调试模式(debug 级别日志) |
|
| -debug | 是 | false | bool | 调试模式(等价于未指定 -log-level 时将日志等级提升为 debug) |
|
||||||
|
| -log-level | 是 | warn | string | 日志等级: debug / info / warn / error |
|
||||||
| -log | 是 | (输出到 stderr) | string | 日志文件路径(默认输出到 stderr) |
|
| -log | 是 | (输出到 stderr) | string | 日志文件路径(默认输出到 stderr) |
|
||||||
| -grace | 是 | 10 | int | 优雅停机等待秒数 |
|
| -grace | 是 | 10 | int | 优雅停机等待秒数 |
|
||||||
| -timeout | 是 | 0 | int | 单次上游请求超时秒(0 = 不设置) |
|
| -timeout | 是 | 0 | int | 单次上游请求超时秒(0 = 不设置) |
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type Config struct {
|
|||||||
LogFile string
|
LogFile string
|
||||||
ShutdownGrace int // 优雅停机等待秒数
|
ShutdownGrace int // 优雅停机等待秒数
|
||||||
RequestTimeout int // 上游整体请求超时时间(秒)
|
RequestTimeout int // 上游整体请求超时时间(秒)
|
||||||
|
LogLevel string // 日志等级: debug|info|warn|error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse 解析命令行参数返回配置
|
// Parse 解析命令行参数返回配置
|
||||||
@@ -22,7 +23,11 @@ func Parse() *Config {
|
|||||||
flag.StringVar(&cfg.LogFile, "log", "", "日志文件路径 (默认输出到 stderr)")
|
flag.StringVar(&cfg.LogFile, "log", "", "日志文件路径 (默认输出到 stderr)")
|
||||||
flag.IntVar(&cfg.ShutdownGrace, "grace", 10, "优雅停机等待秒数")
|
flag.IntVar(&cfg.ShutdownGrace, "grace", 10, "优雅停机等待秒数")
|
||||||
flag.IntVar(&cfg.RequestTimeout, "timeout", 0, "单次上游请求超时秒(0=不设置)")
|
flag.IntVar(&cfg.RequestTimeout, "timeout", 0, "单次上游请求超时秒(0=不设置)")
|
||||||
|
flag.StringVar(&cfg.LogLevel, "log-level", "warn", "日志等级: debug|info|warn|error (默认 warn)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// 兼容旧的 -debug 参数: 当 -debug 为 true 且未显式指定其它日志等级(仍为默认 warn) 时,提升为 debug
|
||||||
|
if cfg.Debug && cfg.LogLevel == "warn" { cfg.LogLevel = "debug" }
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
63
internal/metrics/metrics.go
Normal file
63
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 环形秒级窗口,用于计算 QPS / QPM。
|
||||||
|
// 只针对转发请求调用 Inc。
|
||||||
|
|
||||||
|
type bucket struct {
|
||||||
|
second atomic.Int64 // Unix 秒
|
||||||
|
count atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
buckets [60]bucket
|
||||||
|
total atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
// Inc 增加一次请求计数
|
||||||
|
func Inc() {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
idx := now % 60
|
||||||
|
b := &buckets[idx]
|
||||||
|
for {
|
||||||
|
sec := b.second.Load()
|
||||||
|
if sec == now {
|
||||||
|
b.count.Add(1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if b.second.CompareAndSwap(sec, now) {
|
||||||
|
b.count.Store(1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
total.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QPS 返回当前秒内的请求数
|
||||||
|
func QPS() int64 {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
idx := now % 60
|
||||||
|
b := &buckets[idx]
|
||||||
|
if b.second.Load() == now { return b.count.Load() }
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// QPM 返回最近 60 秒内的请求总数
|
||||||
|
func QPM() int64 {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
var sum int64
|
||||||
|
for i := range 60 {
|
||||||
|
sec := buckets[i].second.Load()
|
||||||
|
if sec <= now && now-sec < 60 { // 在窗口内
|
||||||
|
sum += buckets[i].count.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// Total 返回累计转发请求数
|
||||||
|
func Total() int64 { return total.Load() }
|
||||||
23
internal/middleware/metrics.go
Normal file
23
internal/middleware/metrics.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"anyproxy/internal/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MetricsHandler 输出当前指标
|
||||||
|
func MetricsHandler(c *gin.Context) {
|
||||||
|
qps := metrics.QPS()
|
||||||
|
qpm := metrics.QPM()
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"qps_current": qps,
|
||||||
|
"qps_avg_60s": float64(qpm) / 60.0,
|
||||||
|
"qpm_current": qpm,
|
||||||
|
"qpm_avg_60m": float64(qpm),
|
||||||
|
"total": metrics.Total(),
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -16,7 +16,7 @@ func RequestID() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := globalReqID.Add(1)
|
id := globalReqID.Add(1)
|
||||||
c.Set(RequestIDKey, id)
|
c.Set(RequestIDKey, id)
|
||||||
c.Writer.Header().Set("X-Request-ID", fmt.Sprintf("%d", id))
|
c.Writer.Header().Set("X-Request-ID", strconv.FormatInt(id, 10))
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,19 +6,24 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"mime"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"anyproxy/internal/metrics"
|
||||||
"anyproxy/internal/middleware"
|
"anyproxy/internal/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 转发的总请求计数器
|
// 转发的总请求计数器
|
||||||
var totalForwarded atomic.Int64
|
var totalForwarded atomic.Int64
|
||||||
|
|
||||||
|
var copyBufPool = sync.Pool{
|
||||||
|
New: func() any { return make([]byte, 32*1024) },
|
||||||
|
}
|
||||||
|
|
||||||
// Proxy 封装具体的转发逻辑
|
// Proxy 封装具体的转发逻辑
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
Client *http.Client
|
Client *http.Client
|
||||||
@@ -72,8 +77,10 @@ func (p *Proxy) forward(c *gin.Context, target string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
upReq.Header = c.Request.Header.Clone()
|
upReq.Header = c.Request.Header.Clone()
|
||||||
|
if strings.Contains(strings.ToLower(c.GetHeader("Accept")), "text/event-stream") {
|
||||||
// 仅在 SSE 时禁用压缩;稍后检测
|
// SSE 禁用压缩
|
||||||
|
upReq.Header.Del("Accept-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := p.Client.Do(upReq)
|
resp, err := p.Client.Do(upReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -83,15 +90,17 @@ func (p *Proxy) forward(c *gin.Context, target string) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
// 仅在真正进行了一次上游转发并得到响应后计数
|
||||||
isSSE := mediaType == "text/event-stream"
|
metrics.Inc()
|
||||||
|
|
||||||
|
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||||
|
isSSE := strings.HasPrefix(contentType, "text/event-stream")
|
||||||
|
|
||||||
p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE)
|
p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE)
|
||||||
|
|
||||||
// 复制上游响应头(最小化过滤)
|
// 复制上游响应头
|
||||||
for k, vs := range resp.Header {
|
dstHeader := c.Writer.Header()
|
||||||
for _, v := range vs { c.Header(k, v) }
|
for k, vs := range resp.Header { dstHeader[k] = vs }
|
||||||
}
|
|
||||||
if isSSE {
|
if isSSE {
|
||||||
c.Writer.Header().Del("Content-Length")
|
c.Writer.Header().Del("Content-Length")
|
||||||
c.Writer.Header().Del("Transfer-Encoding")
|
c.Writer.Header().Del("Transfer-Encoding")
|
||||||
@@ -99,14 +108,15 @@ func (p *Proxy) forward(c *gin.Context, target string) {
|
|||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
// 确保禁用上游压缩避免 SSE 事件被聚合
|
|
||||||
upReq.Header.Del("Accept-Encoding")
|
|
||||||
}
|
}
|
||||||
c.Status(resp.StatusCode)
|
c.Status(resp.StatusCode)
|
||||||
if f, ok := c.Writer.(http.Flusher); ok { f.Flush() }
|
if f, ok := c.Writer.(http.Flusher); ok { f.Flush() }
|
||||||
|
|
||||||
if !isSSE {
|
if !isSSE {
|
||||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
buf := copyBufPool.Get().([]byte)
|
||||||
|
_, err := io.CopyBuffer(c.Writer, resp.Body, buf)
|
||||||
|
copyBufPool.Put(buf)
|
||||||
|
if err != nil {
|
||||||
p.Log.Error("写入响应体失败", "req_id", reqID, "error", err)
|
p.Log.Error("写入响应体失败", "req_id", reqID, "error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -137,7 +147,9 @@ func (p *Proxy) forward(c *gin.Context, target string) {
|
|||||||
|
|
||||||
// HelloPage 返回简单状态页面
|
// HelloPage 返回简单状态页面
|
||||||
func HelloPage(c *gin.Context) {
|
func HelloPage(c *gin.Context) {
|
||||||
count := totalForwarded.Load()
|
count := metrics.Total()
|
||||||
|
qps := metrics.QPS()
|
||||||
|
qpm := metrics.QPM()
|
||||||
|
|
||||||
// 推断外部可见协议与主机(支持反向代理常见头)
|
// 推断外部可见协议与主机(支持反向代理常见头)
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
@@ -152,7 +164,7 @@ func HelloPage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
base := scheme + "://" + host
|
base := scheme + "://" + host
|
||||||
|
|
||||||
str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count)
|
str := fmt.Sprintf("AnyProxy 服务器正在运行...\n累计转发(不含本页): %d\n当前QPS: %d\n最近1分钟QPM: %d", count, qps, qpm)
|
||||||
str += "\n\n使用方法:\n"
|
str += "\n\n使用方法:\n"
|
||||||
str += "方式1 - 直接协议路径: \n"
|
str += "方式1 - 直接协议路径: \n"
|
||||||
str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base)
|
str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base)
|
||||||
|
|||||||
@@ -38,14 +38,17 @@ func BuildFromProtocol(protocol, remainder string, originalQuery url.Values) (st
|
|||||||
func mergeQuery(raw string, original url.Values) (string, error) {
|
func mergeQuery(raw string, original url.Values) (string, error) {
|
||||||
parsed, err := url.Parse(raw)
|
parsed, err := url.Parse(raw)
|
||||||
if err != nil { return "", err }
|
if err != nil { return "", err }
|
||||||
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||||
|
return "", errors.New("不支持的协议")
|
||||||
|
}
|
||||||
|
if parsed.Host == "" {
|
||||||
|
return "", errors.New("目标地址无效")
|
||||||
|
}
|
||||||
// 合并 query
|
// 合并 query
|
||||||
q := parsed.Query()
|
q := parsed.Query()
|
||||||
for k, vs := range original {
|
for k, vs := range original {
|
||||||
for _, v := range vs { q.Add(k, v) }
|
for _, v := range vs { q.Add(k, v) }
|
||||||
}
|
}
|
||||||
parsed.RawQuery = q.Encode()
|
parsed.RawQuery = q.Encode()
|
||||||
if _, err := url.ParseRequestURI(parsed.String()); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return parsed.String(), nil
|
return parsed.String(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
34
main.go
34
main.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -22,9 +23,21 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
cfg := config.Parse()
|
cfg := config.Parse()
|
||||||
|
|
||||||
// 日志初始化设置
|
// 日志初始化设置 (支持显式日志等级)
|
||||||
levelVar := new(slog.LevelVar)
|
levelVar := new(slog.LevelVar)
|
||||||
if cfg.Debug { levelVar.Set(slog.LevelDebug) } else { levelVar.Set(slog.LevelInfo) }
|
lvlStr := strings.ToLower(cfg.LogLevel)
|
||||||
|
switch lvlStr {
|
||||||
|
case "debug":
|
||||||
|
levelVar.Set(slog.LevelDebug)
|
||||||
|
case "info":
|
||||||
|
levelVar.Set(slog.LevelInfo)
|
||||||
|
case "warn", "warning":
|
||||||
|
levelVar.Set(slog.LevelWarn)
|
||||||
|
case "error", "err":
|
||||||
|
levelVar.Set(slog.LevelError)
|
||||||
|
default:
|
||||||
|
levelVar.Set(slog.LevelWarn) // 回退到默认 warn
|
||||||
|
}
|
||||||
var writer io.Writer = os.Stderr
|
var writer io.Writer = os.Stderr
|
||||||
if cfg.LogFile != "" {
|
if cfg.LogFile != "" {
|
||||||
f, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
f, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
@@ -35,10 +48,18 @@ func main() {
|
|||||||
logger := slog.New(h)
|
logger := slog.New(h)
|
||||||
slog.SetDefault(logger)
|
slog.SetDefault(logger)
|
||||||
|
|
||||||
if cfg.Debug { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) }
|
if cfg.Debug || lvlStr == "debug" { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) }
|
||||||
|
|
||||||
// 可复用的 HTTP 客户端(保持连接复用)
|
// 可复用的 HTTP 客户端
|
||||||
transport := &http.Transport{Proxy: http.ProxyFromEnvironment, DisableCompression: true}
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 512,
|
||||||
|
MaxIdleConnsPerHost: 128,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
client := &http.Client{Transport: transport}
|
client := &http.Client{Transport: transport}
|
||||||
if cfg.RequestTimeout > 0 { client.Timeout = time.Duration(cfg.RequestTimeout) * time.Second }
|
if cfg.RequestTimeout > 0 { client.Timeout = time.Duration(cfg.RequestTimeout) * time.Second }
|
||||||
|
|
||||||
@@ -48,10 +69,11 @@ func main() {
|
|||||||
r.Use(middleware.Recovery(logger), middleware.RequestID(), middleware.Logger(logger))
|
r.Use(middleware.Recovery(logger), middleware.RequestID(), middleware.Logger(logger))
|
||||||
|
|
||||||
r.GET("/", proxy.HelloPage) // 欢迎页面
|
r.GET("/", proxy.HelloPage) // 欢迎页面
|
||||||
|
r.GET("/metrics", middleware.MetricsHandler) // 指标接口
|
||||||
r.Any("/proxy/*proxyPath", p.HandleProxyPath) // 处理 /proxy/*path 形式的请求
|
r.Any("/proxy/*proxyPath", p.HandleProxyPath) // 处理 /proxy/*path 形式的请求
|
||||||
r.Any(":protocol/*remainder", p.HandleProtocol) // 处理 /:protocol/*remainder 形式的请求
|
r.Any(":protocol/*remainder", p.HandleProtocol) // 处理 /:protocol/*remainder 形式的请求
|
||||||
|
|
||||||
logger.Info("服务器启动", "addr", cfg.Addr(), "debug", cfg.Debug, "version", version.Version, "commit", version.GitCommit)
|
logger.Info("服务器启动", "addr", cfg.Addr(), "debug", cfg.Debug, "log_level", lvlStr, "version", version.Version, "commit", version.GitCommit)
|
||||||
|
|
||||||
// 优雅停机设置:监听系统信号,执行平滑关闭
|
// 优雅停机设置:监听系统信号,执行平滑关闭
|
||||||
srv := &http.Server{Addr: cfg.Addr(), Handler: r}
|
srv := &http.Server{Addr: cfg.Addr(), Handler: r}
|
||||||
|
|||||||
Reference in New Issue
Block a user