update: 整理项目结构 重写中间件 提高复用性
This commit is contained in:
358
main.go
358
main.go
@@ -1,339 +1,75 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lmittmann/tint"
|
||||
|
||||
"anyproxy/internal/config"
|
||||
"anyproxy/internal/middleware"
|
||||
"anyproxy/internal/proxy"
|
||||
"anyproxy/internal/version"
|
||||
)
|
||||
|
||||
// 全局请求计数器,使用原子操作确保线程安全
|
||||
var requestCounter int64
|
||||
|
||||
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8080, "代理服务器监听的端口")
|
||||
debug := flag.Bool("debug", false, "是否启用调试模式")
|
||||
logFile := flag.String("log", "", "日志文件路径,默认为控制台彩色输出")
|
||||
flag.Parse()
|
||||
cfg := config.Parse()
|
||||
|
||||
// 使用 tint + LevelVar
|
||||
var levelVar = new(slog.LevelVar)
|
||||
if *debug {
|
||||
levelVar.Set(slog.LevelDebug)
|
||||
} else {
|
||||
levelVar.Set(slog.LevelInfo)
|
||||
}
|
||||
|
||||
// 组合输出 writer
|
||||
var writer io.Writer = os.Stderr // 默认彩色输出到 stderr
|
||||
if *logFile != "" {
|
||||
f, err := os.OpenFile(*logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "无法打开日志文件: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// 同时输出到彩色终端和文件(文件里不需要颜色,tint 会根据是否是终端决定)
|
||||
writer = io.MultiWriter(os.Stderr, f)
|
||||
}
|
||||
|
||||
handler := tint.NewHandler(writer, &tint.Options{
|
||||
AddSource: true,
|
||||
Level: levelVar,
|
||||
TimeFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
slog.SetDefault(slog.New(handler))
|
||||
|
||||
if *debug {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
} else {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
r := gin.New() // 不使用默认 Logger,改为自定义 slog 统一输出
|
||||
r.Use(SlogLogger(), SlogRecovery())
|
||||
r.GET("/", HelloPage)
|
||||
r.Any("/proxy/*proxyPath", proxyHandler)
|
||||
r.Any(":protocol/*remainder", protocolHandler)
|
||||
|
||||
slog.Info("HTTP 代理服务器启动", "port", *port, "debug", *debug)
|
||||
if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil {
|
||||
slog.Error("启动服务器失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeURL 规范化URL格式,处理缺少斜杠的情况
|
||||
func normalizeURL(rawURL string) string {
|
||||
// 处理 https:/example.com 或 http:/example.com 的情况
|
||||
if strings.HasPrefix(rawURL, "https:/") && !strings.HasPrefix(rawURL, "https://") {
|
||||
return strings.Replace(rawURL, "https:/", "https://", 1)
|
||||
// 日志初始化设置
|
||||
levelVar := new(slog.LevelVar)
|
||||
if cfg.Debug { levelVar.Set(slog.LevelDebug) } else { levelVar.Set(slog.LevelInfo) }
|
||||
var writer io.Writer = os.Stderr
|
||||
if cfg.LogFile != "" {
|
||||
f, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil { panic(err) }
|
||||
writer = io.MultiWriter(os.Stderr, f)
|
||||
}
|
||||
if strings.HasPrefix(rawURL, "http:/") && !strings.HasPrefix(rawURL, "http://") {
|
||||
return strings.Replace(rawURL, "http:/", "http://", 1)
|
||||
}
|
||||
return rawURL
|
||||
}
|
||||
h := tint.NewHandler(writer, &tint.Options{AddSource: true, Level: levelVar, TimeFormat: "2006-01-02 15:04:05"})
|
||||
logger := slog.New(h)
|
||||
slog.SetDefault(logger)
|
||||
|
||||
func proxyHandler(c *gin.Context) {
|
||||
// 从路径参数中获取目标 URL
|
||||
targetURLStr := c.Param("proxyPath")
|
||||
// 移除前导斜杠
|
||||
targetURLStr = strings.TrimPrefix(targetURLStr, "/")
|
||||
|
||||
// 规范化URL格式
|
||||
targetURLStr = normalizeURL(targetURLStr)
|
||||
if cfg.Debug { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) }
|
||||
|
||||
// 解析目标URL
|
||||
parsedURL, err := url.Parse(targetURLStr)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并查询参数
|
||||
originalQuery := c.Request.URL.Query()
|
||||
targetQuery := parsedURL.Query()
|
||||
for key, values := range originalQuery {
|
||||
for _, value := range values {
|
||||
targetQuery.Add(key, value)
|
||||
}
|
||||
}
|
||||
parsedURL.RawQuery = targetQuery.Encode()
|
||||
|
||||
// 重新构建目标URL字符串
|
||||
targetURLStr = parsedURL.String()
|
||||
|
||||
// 检查 URL 合法性
|
||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行代理请求
|
||||
executeProxy(c, targetURLStr)
|
||||
}
|
||||
|
||||
// protocolHandler 处理直接以协议开头的URL请求 (如 /https/example.com/path)
|
||||
func protocolHandler(c *gin.Context) {
|
||||
protocol := c.Param("protocol")
|
||||
remainder := c.Param("remainder")
|
||||
|
||||
// 只处理 http 和 https 协议
|
||||
if protocol != "http" && protocol != "https" {
|
||||
c.String(http.StatusBadRequest, "不支持的协议: %s", protocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建完整的URL
|
||||
targetURLStr := protocol + ":/" + remainder
|
||||
|
||||
// 规范化URL格式
|
||||
targetURLStr = normalizeURL(targetURLStr)
|
||||
|
||||
// 解析目标URL
|
||||
parsedURL, err := url.Parse(targetURLStr)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并查询参数
|
||||
originalQuery := c.Request.URL.Query()
|
||||
targetQuery := parsedURL.Query()
|
||||
for key, values := range originalQuery {
|
||||
for _, value := range values {
|
||||
targetQuery.Add(key, value)
|
||||
}
|
||||
}
|
||||
parsedURL.RawQuery = targetQuery.Encode()
|
||||
|
||||
// 重新构建目标URL字符串
|
||||
targetURLStr = parsedURL.String()
|
||||
|
||||
// 检查 URL 合法性
|
||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 执行代理请求
|
||||
executeProxy(c, targetURLStr)
|
||||
}
|
||||
|
||||
// executeProxy 执行实际的代理请求
|
||||
func executeProxy(c *gin.Context, targetURLStr string) {
|
||||
// 增加请求计数器
|
||||
reqID := atomic.AddInt64(&requestCounter, 1)
|
||||
|
||||
slog.Debug("收到请求",
|
||||
"reqID", reqID,
|
||||
"method", c.Request.Method,
|
||||
"uri", c.Request.RequestURI,
|
||||
"target", targetURLStr)
|
||||
|
||||
// 自定义 Transport,禁止自动压缩(避免 gzip 聚合导致 SSE 延迟)
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DisableCompression: true,
|
||||
}
|
||||
// 可复用的 HTTP 客户端(保持连接复用)
|
||||
transport := &http.Transport{Proxy: http.ProxyFromEnvironment, DisableCompression: true}
|
||||
client := &http.Client{Transport: transport}
|
||||
if cfg.RequestTimeout > 0 { client.Timeout = time.Duration(cfg.RequestTimeout) * time.Second }
|
||||
|
||||
// 创建到目标服务器的请求
|
||||
proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body)
|
||||
if err != nil {
|
||||
slog.Error("创建代理请求失败", "reqID", reqID, "error", err)
|
||||
c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err)
|
||||
return
|
||||
}
|
||||
p := proxy.New(client, logger)
|
||||
|
||||
// 复制原始请求的 Headers (Clone 避免引用共享)
|
||||
proxyReq.Header = c.Request.Header.Clone()
|
||||
// 禁止上游压缩,保证事件粒度
|
||||
proxyReq.Header.Del("Accept-Encoding")
|
||||
r := gin.New()
|
||||
r.Use(middleware.Recovery(logger), middleware.RequestID(), middleware.Logger(logger))
|
||||
|
||||
resp, err := client.Do(proxyReq)
|
||||
if err != nil {
|
||||
slog.Error("请求目标服务器失败", "reqID", reqID, "error", err)
|
||||
c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
r.GET("/", proxy.HelloPage) // 欢迎页面
|
||||
r.Any("/proxy/*proxyPath", p.HandleProxyPath) // 处理 /proxy/*path 形式的请求
|
||||
r.Any(":protocol/*remainder", p.HandleProtocol) // 处理 /:protocol/*remainder 形式的请求
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
isSSE := strings.HasPrefix(contentType, "text/event-stream")
|
||||
logger.Info("服务器启动", "addr", cfg.Addr(), "debug", cfg.Debug, "version", version.Version, "commit", version.GitCommit)
|
||||
|
||||
slog.Debug("收到响应", "reqID", reqID, "status_code", resp.StatusCode, "status", resp.Status, "isSSE", isSSE)
|
||||
|
||||
// 复制响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
// 优雅停机设置:监听系统信号,执行平滑关闭
|
||||
srv := &http.Server{Addr: cfg.Addr(), Handler: r}
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("服务器监听错误", "error", err)
|
||||
}
|
||||
}
|
||||
// SSE 需要去掉不合适的头并设置必要头
|
||||
if isSSE {
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Del("Transfer-Encoding")
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no") // 防止某些反向代理缓冲
|
||||
}
|
||||
}()
|
||||
|
||||
// 设置状态码
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
// 立即 flush 头部,尤其是 SSE
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if !isSSE {
|
||||
// 普通请求直接复制主体
|
||||
bytesCopied, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err)
|
||||
}
|
||||
slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied)
|
||||
return
|
||||
}
|
||||
|
||||
// SSE 模式:逐行读取并 flush,保持事件实时性
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
w := c.Writer
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
if _, werr := w.Write(line); werr != nil {
|
||||
slog.Warn("SSE 写失败", "reqID", reqID, "error", werr)
|
||||
return
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
slog.Debug("SSE 结束(EOF)", "reqID", reqID)
|
||||
} else {
|
||||
slog.Error("读取 SSE 失败", "reqID", reqID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-stop
|
||||
logger.Info("开始关闭 (收到退出信号)")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.ShutdownGrace)*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
logger.Error("关闭出错", "error", err)
|
||||
} else {
|
||||
logger.Info("关闭完成")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func HelloPage(c *gin.Context) {
|
||||
// 获取当前的请求计数
|
||||
count := atomic.LoadInt64(&requestCounter)
|
||||
str := fmt.Sprintf("AnyProxy 服务器正在运行... 已转发 %d 个请求", count)
|
||||
str += "\n\n使用方法:\n"
|
||||
str += "方式1 - 直接协议路径: \n"
|
||||
str += " 目标URL: https://example.com/path --> 代理URL: http://AnyproxyIP/https/example.com/path\n"
|
||||
str += " 目标URL: http://example.com/path --> 代理URL: http://AnyproxyIP/http/example.com/path\n\n"
|
||||
str += "方式2 - 完整URL路径: \n"
|
||||
str += " 目标URL: https://example.com --> 代理URL: http://AnyproxyIP/proxy/https://example.com\n\n"
|
||||
str += "目标URL必须以 https:// 或 http:// 开头。\n\n"
|
||||
c.String(200, str)
|
||||
}
|
||||
|
||||
// SlogLogger 统一请求日志中间件
|
||||
func SlogLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
rawQuery := c.Request.URL.RawQuery
|
||||
c.Next()
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
size := c.Writer.Size()
|
||||
method := c.Request.Method
|
||||
ip := c.ClientIP()
|
||||
if rawQuery != "" {
|
||||
path = path + "?" + rawQuery
|
||||
}
|
||||
slog.Log(c, slog.LevelInfo, "HTTP 请求",
|
||||
slog.String("method", method),
|
||||
slog.String("path", path),
|
||||
slog.Int("status", status),
|
||||
slog.Duration("latency", latency),
|
||||
slog.Int("size", size),
|
||||
slog.String("ip", ip),
|
||||
slog.String("ua", c.GetHeader("User-Agent")),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// SlogRecovery 捕获 panic,输出堆栈
|
||||
func SlogRecovery() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
stack := debug.Stack()
|
||||
slog.Error("发生 panic",
|
||||
"error", rcv,
|
||||
"stack", string(stack),
|
||||
"path", c.Request.URL.Path,
|
||||
)
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user