update: 整理项目结构 重写中间件 提高复用性
This commit is contained in:
29
internal/config/config.go
Normal file
29
internal/config/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Config 保存程序配置
|
||||
type Config struct {
|
||||
Port int
|
||||
Debug bool
|
||||
LogFile string
|
||||
ShutdownGrace int // 优雅停机等待秒数
|
||||
RequestTimeout int // 上游整体请求超时时间(秒)
|
||||
}
|
||||
|
||||
// Parse 解析命令行参数返回配置
|
||||
func Parse() *Config {
|
||||
cfg := &Config{}
|
||||
flag.IntVar(&cfg.Port, "port", 8080, "代理服务器监听端口")
|
||||
flag.BoolVar(&cfg.Debug, "debug", false, "调试模式 (debug level log)")
|
||||
flag.StringVar(&cfg.LogFile, "log", "", "日志文件路径 (默认输出到 stderr)")
|
||||
flag.IntVar(&cfg.ShutdownGrace, "grace", 10, "优雅停机等待秒数")
|
||||
flag.IntVar(&cfg.RequestTimeout, "timeout", 0, "单次上游请求超时秒(0=不设置)")
|
||||
flag.Parse()
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *Config) Addr() string { return fmt.Sprintf(":%d", c.Port) }
|
||||
31
internal/middleware/logging.go
Normal file
31
internal/middleware/logging.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Logger 使用 slog 输出结构化访问日志
|
||||
func Logger(logger *slog.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
c.Next()
|
||||
if raw != "" { path = path + "?" + raw }
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
logger.Info("HTTP请求",
|
||||
"req_id", GetReqID(c),
|
||||
"method", c.Request.Method,
|
||||
"path", path,
|
||||
"status", status,
|
||||
"latency_ms", latency.Milliseconds(),
|
||||
"size", c.Writer.Size(),
|
||||
"ip", c.ClientIP(),
|
||||
"ua", c.GetHeader("User-Agent"),
|
||||
)
|
||||
}
|
||||
}
|
||||
30
internal/middleware/recovery.go
Normal file
30
internal/middleware/recovery.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery 捕获 panic 并记录堆栈信息
|
||||
func Recovery(logger *slog.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
logger.Error("发生Panic",
|
||||
"req_id", GetReqID(c),
|
||||
"error", rcv,
|
||||
"stack", string(debug.Stack()),
|
||||
"path", c.Request.URL.Path,
|
||||
)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "内部服务器错误",
|
||||
"req_id": GetReqID(c),
|
||||
})
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
32
internal/middleware/requestid.go
Normal file
32
internal/middleware/requestid.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const RequestIDKey = "reqID"
|
||||
|
||||
var globalReqID atomic.Int64
|
||||
|
||||
// RequestID 生成自增的请求 ID 并注入上下文及响应头
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
id := globalReqID.Add(1)
|
||||
c.Set(RequestIDKey, id)
|
||||
c.Writer.Header().Set("X-Request-ID", fmt.Sprintf("%d", id))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetReqID 从上下文中获取请求 ID
|
||||
func GetReqID(c *gin.Context) int64 {
|
||||
if v, ok := c.Get(RequestIDKey); ok {
|
||||
if id, ok2 := v.(int64); ok2 {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
166
internal/proxy/proxy.go
Normal file
166
internal/proxy/proxy.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"anyproxy/internal/middleware"
|
||||
)
|
||||
|
||||
// 转发的总请求计数器
|
||||
var totalForwarded atomic.Int64
|
||||
|
||||
// Proxy 封装具体的转发逻辑
|
||||
type Proxy struct {
|
||||
Client *http.Client
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
func New(client *http.Client, logger *slog.Logger) *Proxy {
|
||||
return &Proxy{Client: client, Log: logger}
|
||||
}
|
||||
|
||||
// HandleProxyPath 处理 /proxy/*path 形式的请求
|
||||
func (p *Proxy) HandleProxyPath(c *gin.Context) {
|
||||
urlStr, err := BuildFromProxyPath(c.Param("proxyPath"), c.Request.URL.Query())
|
||||
if err != nil {
|
||||
p.writeError(c, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
p.forward(c, urlStr)
|
||||
}
|
||||
|
||||
// HandleProtocol 处理 /:protocol/*remainder 形式的请求
|
||||
func (p *Proxy) HandleProtocol(c *gin.Context) {
|
||||
urlStr, err := BuildFromProtocol(c.Param("protocol"), c.Param("remainder"), c.Request.URL.Query())
|
||||
if err != nil {
|
||||
p.writeError(c, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
p.forward(c, urlStr)
|
||||
}
|
||||
|
||||
func (p *Proxy) writeError(c *gin.Context, code int, err error) {
|
||||
c.JSON(code, gin.H{"error": err.Error(), "req_id": middleware.GetReqID(c)})
|
||||
}
|
||||
|
||||
func (p *Proxy) forward(c *gin.Context, target string) {
|
||||
reqID := middleware.GetReqID(c)
|
||||
current := totalForwarded.Add(1)
|
||||
p.Log.Debug("开始转发请求",
|
||||
"req_id", reqID,
|
||||
"count", current,
|
||||
"method", c.Request.Method,
|
||||
"target", target,
|
||||
"uri", c.Request.RequestURI,
|
||||
)
|
||||
|
||||
// 基于原始上下文创建上游请求(支持客户端断开时取消)
|
||||
upReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, target, c.Request.Body)
|
||||
if err != nil {
|
||||
p.Log.Error("创建上游请求失败", "req_id", reqID, "error", err)
|
||||
p.writeError(c, http.StatusInternalServerError, errors.New("创建上游请求失败"))
|
||||
return
|
||||
}
|
||||
upReq.Header = c.Request.Header.Clone()
|
||||
|
||||
// 仅在 SSE 时禁用压缩;稍后检测
|
||||
|
||||
resp, err := p.Client.Do(upReq)
|
||||
if err != nil {
|
||||
p.Log.Error("上游请求失败", "req_id", reqID, "error", err)
|
||||
p.writeError(c, http.StatusBadGateway, errors.New("上游请求失败"))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
isSSE := mediaType == "text/event-stream"
|
||||
|
||||
p.Log.Debug("上游响应", "req_id", reqID, "status", resp.StatusCode, "sse", isSSE)
|
||||
|
||||
// 复制上游响应头(最小化过滤)
|
||||
for k, vs := range resp.Header {
|
||||
for _, v := range vs { c.Header(k, v) }
|
||||
}
|
||||
if isSSE {
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Del("Transfer-Encoding")
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
// 确保禁用上游压缩避免 SSE 事件被聚合
|
||||
upReq.Header.Del("Accept-Encoding")
|
||||
}
|
||||
c.Status(resp.StatusCode)
|
||||
if f, ok := c.Writer.(http.Flusher); ok { f.Flush() }
|
||||
|
||||
if !isSSE {
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
p.Log.Error("写入响应体失败", "req_id", reqID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
w := c.Writer
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
if _, werr := w.Write(line); werr != nil {
|
||||
p.Log.Warn("SSE写入失败", "req_id", reqID, "error", werr)
|
||||
return
|
||||
}
|
||||
if flusher != nil { flusher.Flush() }
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
p.Log.Debug("SSE结束(EOF)", "req_id", reqID)
|
||||
} else {
|
||||
p.Log.Error("SSE读取失败", "req_id", reqID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HelloPage 返回简单状态页面
|
||||
func HelloPage(c *gin.Context) {
|
||||
count := totalForwarded.Load()
|
||||
|
||||
// 推断外部可见协议与主机(支持反向代理常见头)
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil { scheme = "https" }
|
||||
if xf := c.GetHeader("X-Forwarded-Proto"); xf != "" {
|
||||
// 取第一个
|
||||
scheme = strings.TrimSpace(strings.Split(xf, ",")[0])
|
||||
}
|
||||
host := c.Request.Host
|
||||
if xfh := c.GetHeader("X-Forwarded-Host"); xfh != "" {
|
||||
host = strings.TrimSpace(strings.Split(xfh, ",")[0])
|
||||
}
|
||||
base := scheme + "://" + host
|
||||
|
||||
str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count)
|
||||
str += "\n\n使用方法:\n"
|
||||
str += "方式1 - 直接协议路径: \n"
|
||||
str += fmt.Sprintf(" 目标URL: https://example.com/path --> 代理URL: %s/https/example.com/path\n", base)
|
||||
str += fmt.Sprintf(" 目标URL: http://example.com/path --> 代理URL: %s/http/example.com/path\n\n", base)
|
||||
str += "方式2 - 完整URL路径: \n"
|
||||
str += fmt.Sprintf(" 目标URL: https://example.com --> 代理URL: %s/proxy/https://example.com\n", base)
|
||||
str += fmt.Sprintf(" 目标URL: http://example.com --> 代理URL: %s/proxy/http://example.com\n\n", base)
|
||||
str += "目标URL必须以 https:// 或 http:// 开头。\n\n"
|
||||
str += fmt.Sprintf("本机访问基地址: %s\n", base)
|
||||
c.String(200, str)
|
||||
}
|
||||
51
internal/proxy/url.go
Normal file
51
internal/proxy/url.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// normalizeURL 规范化URL格式,处理缺少斜杠的情况
|
||||
func normalizeURL(rawURL string) string {
|
||||
if strings.HasPrefix(rawURL, "https:/") && !strings.HasPrefix(rawURL, "https://") {
|
||||
return strings.Replace(rawURL, "https:/", "https://", 1)
|
||||
}
|
||||
if strings.HasPrefix(rawURL, "http:/") && !strings.HasPrefix(rawURL, "http://") {
|
||||
return strings.Replace(rawURL, "http:/", "http://", 1)
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
|
||||
// BuildFromProxyPath 构建 /proxy/*path 形式传入的 URL
|
||||
func BuildFromProxyPath(pathPart string, originalQuery url.Values) (string, error) {
|
||||
pathPart = strings.TrimPrefix(pathPart, "/")
|
||||
if pathPart == "" { return "", errors.New("目标为空") }
|
||||
pathPart = normalizeURL(pathPart)
|
||||
return mergeQuery(pathPart, originalQuery)
|
||||
}
|
||||
|
||||
// BuildFromProtocol 构建 /:protocol/*remainder 形式
|
||||
func BuildFromProtocol(protocol, remainder string, originalQuery url.Values) (string, error) {
|
||||
if protocol != "http" && protocol != "https" {
|
||||
return "", errors.New("不支持的协议")
|
||||
}
|
||||
full := protocol + ":/" + remainder
|
||||
full = normalizeURL(full)
|
||||
return mergeQuery(full, originalQuery)
|
||||
}
|
||||
|
||||
func mergeQuery(raw string, original url.Values) (string, error) {
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil { return "", err }
|
||||
// 合并 query
|
||||
q := parsed.Query()
|
||||
for k, vs := range original {
|
||||
for _, v := range vs { q.Add(k, v) }
|
||||
}
|
||||
parsed.RawQuery = q.Encode()
|
||||
if _, err := url.ParseRequestURI(parsed.String()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
15
internal/version/version.go
Normal file
15
internal/version/version.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package version
|
||||
|
||||
import "runtime/debug"
|
||||
|
||||
var (
|
||||
Version = "1.1.0-rc"
|
||||
GitCommit = ""
|
||||
BuildInfo = ""
|
||||
)
|
||||
|
||||
func init() {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
BuildInfo = info.Main.Version
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user