Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 45fd9e40e5 |
99
main.go
99
main.go
@@ -4,8 +4,10 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -21,8 +23,38 @@ func main() {
|
||||
|
||||
port := flag.Int("port", 8080, "代理服务器监听的端口")
|
||||
debug := flag.Bool("debug", false, "是否启用调试模式")
|
||||
logFile := flag.String("log", "", "日志文件路径,默认为标准输出")
|
||||
flag.Parse()
|
||||
|
||||
// 配置slog
|
||||
var logger *slog.Logger
|
||||
|
||||
// 根据调试模式设置日志级别
|
||||
var logLevel slog.Level
|
||||
if *debug {
|
||||
logLevel = slog.LevelDebug
|
||||
} else {
|
||||
logLevel = slog.LevelInfo
|
||||
}
|
||||
|
||||
// 创建处理器选项
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: logLevel,
|
||||
}
|
||||
|
||||
if *logFile != "" {
|
||||
file, err := os.OpenFile(*logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
slog.Error("无法打开日志文件", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger = slog.New(slog.NewJSONHandler(file, opts))
|
||||
} else {
|
||||
logger = slog.New(slog.NewJSONHandler(os.Stdout, opts))
|
||||
}
|
||||
slog.SetDefault(logger)
|
||||
|
||||
|
||||
if *debug {
|
||||
gin.SetMode(gin.DebugMode) // 启用调试模式
|
||||
} else {
|
||||
@@ -42,9 +74,9 @@ func main() {
|
||||
// 检查是否以协议开头的路径
|
||||
r.Any("/:protocol/*remainder", protocolHandler)
|
||||
|
||||
fmt.Printf("HTTP 代理服务器启动,监听端口 :%d\n", *port)
|
||||
slog.Info("HTTP 代理服务器启动", "port", *port)
|
||||
if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil {
|
||||
fmt.Printf("启动服务器失败: %v\n", err)
|
||||
slog.Error("启动服务器失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +101,26 @@ func proxyHandler(c *gin.Context) {
|
||||
// 规范化URL格式
|
||||
targetURLStr = normalizeURL(targetURLStr)
|
||||
|
||||
// 解析目标URL
|
||||
parsedURL, err := url.Parse(targetURLStr)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并查询参数
|
||||
originalQuery := c.Request.URL.Query()
|
||||
targetQuery := parsedURL.Query()
|
||||
for key, values := range originalQuery {
|
||||
for _, value := range values {
|
||||
targetQuery.Add(key, value)
|
||||
}
|
||||
}
|
||||
parsedURL.RawQuery = targetQuery.Encode()
|
||||
|
||||
// 重新构建目标URL字符串
|
||||
targetURLStr = parsedURL.String()
|
||||
|
||||
// 检查 URL 合法性
|
||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
@@ -96,6 +148,26 @@ func protocolHandler(c *gin.Context) {
|
||||
// 规范化URL格式
|
||||
targetURLStr = normalizeURL(targetURLStr)
|
||||
|
||||
// 解析目标URL
|
||||
parsedURL, err := url.Parse(targetURLStr)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 合并查询参数
|
||||
originalQuery := c.Request.URL.Query()
|
||||
targetQuery := parsedURL.Query()
|
||||
for key, values := range originalQuery {
|
||||
for _, value := range values {
|
||||
targetQuery.Add(key, value)
|
||||
}
|
||||
}
|
||||
parsedURL.RawQuery = targetQuery.Encode()
|
||||
|
||||
// 重新构建目标URL字符串
|
||||
targetURLStr = parsedURL.String()
|
||||
|
||||
// 检查 URL 合法性
|
||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||
@@ -109,12 +181,19 @@ func protocolHandler(c *gin.Context) {
|
||||
// executeProxy 执行实际的代理请求
|
||||
func executeProxy(c *gin.Context, targetURLStr string) {
|
||||
// 增加请求计数器
|
||||
atomic.AddInt64(&requestCounter, 1)
|
||||
reqID := atomic.AddInt64(&requestCounter, 1)
|
||||
|
||||
slog.Debug("收到请求",
|
||||
"reqID", reqID,
|
||||
"method", c.Request.Method,
|
||||
"uri", c.Request.RequestURI,
|
||||
"target", targetURLStr)
|
||||
|
||||
// 创建到目标服务器的请求
|
||||
// 注意:我们直接将原始请求的 Body 传递过去
|
||||
proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body)
|
||||
if err != nil {
|
||||
slog.Error("创建代理请求失败", "reqID", reqID, "error", err)
|
||||
c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -126,11 +205,17 @@ func executeProxy(c *gin.Context, targetURLStr string) {
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(proxyReq)
|
||||
if err != nil {
|
||||
c.String(http.StatusBadGateway, "请求目标服务器失败: %v", err)
|
||||
slog.Error("请求目标服务器失败", "reqID", reqID, "error", err)
|
||||
c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
slog.Debug("收到响应",
|
||||
"reqID", reqID,
|
||||
"status_code", resp.StatusCode,
|
||||
"status", resp.Status)
|
||||
|
||||
// 复制目标服务器响应的 Headers 到原始响应
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
@@ -143,11 +228,11 @@ func executeProxy(c *gin.Context, targetURLStr string) {
|
||||
|
||||
// 将目标服务器的响应 Body 直接流式传输到客户端
|
||||
// 使用 io.Copy 更高效,并能处理各种编码(如 chunked)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
bytesCopied, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
// 如果在写入 body 时发生错误,记录下来
|
||||
fmt.Printf("写入响应 Body 时出错: %v\n", err)
|
||||
slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err)
|
||||
}
|
||||
slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied)
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user