Compare commits
2 Commits
v1.0.0
...
45fd9e40e5
| Author | SHA1 | Date | |
|---|---|---|---|
| 45fd9e40e5 | |||
| 481b523fe7 |
101
main.go
101
main.go
@@ -4,8 +4,10 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -19,9 +21,39 @@ var requestCounter int64
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
port := flag.Int("port", 8080, "代理服务器监听的端口")
|
port := flag.Int("port", 8080, "代理服务器监听的端口")
|
||||||
debug := flag.Bool("debug", false, "是否启用调试模式")
|
debug := flag.Bool("debug", false, "是否启用调试模式")
|
||||||
|
logFile := flag.String("log", "", "日志文件路径,默认为标准输出")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// 配置slog
|
||||||
|
var logger *slog.Logger
|
||||||
|
|
||||||
|
// 根据调试模式设置日志级别
|
||||||
|
var logLevel slog.Level
|
||||||
|
if *debug {
|
||||||
|
logLevel = slog.LevelDebug
|
||||||
|
} else {
|
||||||
|
logLevel = slog.LevelInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建处理器选项
|
||||||
|
opts := &slog.HandlerOptions{
|
||||||
|
Level: logLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
if *logFile != "" {
|
||||||
|
file, err := os.OpenFile(*logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("无法打开日志文件", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
logger = slog.New(slog.NewJSONHandler(file, opts))
|
||||||
|
} else {
|
||||||
|
logger = slog.New(slog.NewJSONHandler(os.Stdout, opts))
|
||||||
|
}
|
||||||
|
slog.SetDefault(logger)
|
||||||
|
|
||||||
|
|
||||||
if *debug {
|
if *debug {
|
||||||
gin.SetMode(gin.DebugMode) // 启用调试模式
|
gin.SetMode(gin.DebugMode) // 启用调试模式
|
||||||
@@ -42,9 +74,9 @@ func main() {
|
|||||||
// 检查是否以协议开头的路径
|
// 检查是否以协议开头的路径
|
||||||
r.Any("/:protocol/*remainder", protocolHandler)
|
r.Any("/:protocol/*remainder", protocolHandler)
|
||||||
|
|
||||||
fmt.Printf("HTTP 代理服务器启动,监听端口 :%d\n", *port)
|
slog.Info("HTTP 代理服务器启动", "port", *port)
|
||||||
if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil {
|
if err := r.Run(fmt.Sprintf(":%d", *port)); err != nil {
|
||||||
fmt.Printf("启动服务器失败: %v\n", err)
|
slog.Error("启动服务器失败", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,6 +101,26 @@ func proxyHandler(c *gin.Context) {
|
|||||||
// 规范化URL格式
|
// 规范化URL格式
|
||||||
targetURLStr = normalizeURL(targetURLStr)
|
targetURLStr = normalizeURL(targetURLStr)
|
||||||
|
|
||||||
|
// 解析目标URL
|
||||||
|
parsedURL, err := url.Parse(targetURLStr)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并查询参数
|
||||||
|
originalQuery := c.Request.URL.Query()
|
||||||
|
targetQuery := parsedURL.Query()
|
||||||
|
for key, values := range originalQuery {
|
||||||
|
for _, value := range values {
|
||||||
|
targetQuery.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedURL.RawQuery = targetQuery.Encode()
|
||||||
|
|
||||||
|
// 重新构建目标URL字符串
|
||||||
|
targetURLStr = parsedURL.String()
|
||||||
|
|
||||||
// 检查 URL 合法性
|
// 检查 URL 合法性
|
||||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||||
@@ -96,6 +148,26 @@ func protocolHandler(c *gin.Context) {
|
|||||||
// 规范化URL格式
|
// 规范化URL格式
|
||||||
targetURLStr = normalizeURL(targetURLStr)
|
targetURLStr = normalizeURL(targetURLStr)
|
||||||
|
|
||||||
|
// 解析目标URL
|
||||||
|
parsedURL, err := url.Parse(targetURLStr)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并查询参数
|
||||||
|
originalQuery := c.Request.URL.Query()
|
||||||
|
targetQuery := parsedURL.Query()
|
||||||
|
for key, values := range originalQuery {
|
||||||
|
for _, value := range values {
|
||||||
|
targetQuery.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedURL.RawQuery = targetQuery.Encode()
|
||||||
|
|
||||||
|
// 重新构建目标URL字符串
|
||||||
|
targetURLStr = parsedURL.String()
|
||||||
|
|
||||||
// 检查 URL 合法性
|
// 检查 URL 合法性
|
||||||
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
if _, err := url.ParseRequestURI(targetURLStr); err != nil {
|
||||||
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
c.String(http.StatusBadRequest, "无效的目标 URL: %v", err)
|
||||||
@@ -109,12 +181,19 @@ func protocolHandler(c *gin.Context) {
|
|||||||
// executeProxy 执行实际的代理请求
|
// executeProxy 执行实际的代理请求
|
||||||
func executeProxy(c *gin.Context, targetURLStr string) {
|
func executeProxy(c *gin.Context, targetURLStr string) {
|
||||||
// 增加请求计数器
|
// 增加请求计数器
|
||||||
atomic.AddInt64(&requestCounter, 1)
|
reqID := atomic.AddInt64(&requestCounter, 1)
|
||||||
|
|
||||||
|
slog.Debug("收到请求",
|
||||||
|
"reqID", reqID,
|
||||||
|
"method", c.Request.Method,
|
||||||
|
"uri", c.Request.RequestURI,
|
||||||
|
"target", targetURLStr)
|
||||||
|
|
||||||
// 创建到目标服务器的请求
|
// 创建到目标服务器的请求
|
||||||
// 注意:我们直接将原始请求的 Body 传递过去
|
// 注意:我们直接将原始请求的 Body 传递过去
|
||||||
proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body)
|
proxyReq, err := http.NewRequest(c.Request.Method, targetURLStr, c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
slog.Error("创建代理请求失败", "reqID", reqID, "error", err)
|
||||||
c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err)
|
c.String(http.StatusInternalServerError, "创建代理请求失败: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -126,11 +205,17 @@ func executeProxy(c *gin.Context, targetURLStr string) {
|
|||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(proxyReq)
|
resp, err := client.Do(proxyReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadGateway, "请求目标服务器失败: %v", err)
|
slog.Error("请求目标服务器失败", "reqID", reqID, "error", err)
|
||||||
|
c.String(http.StatusBadGateway, "请求目标服务器失败: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
slog.Debug("收到响应",
|
||||||
|
"reqID", reqID,
|
||||||
|
"status_code", resp.StatusCode,
|
||||||
|
"status", resp.Status)
|
||||||
|
|
||||||
// 复制目标服务器响应的 Headers 到原始响应
|
// 复制目标服务器响应的 Headers 到原始响应
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
@@ -143,11 +228,11 @@ func executeProxy(c *gin.Context, targetURLStr string) {
|
|||||||
|
|
||||||
// 将目标服务器的响应 Body 直接流式传输到客户端
|
// 将目标服务器的响应 Body 直接流式传输到客户端
|
||||||
// 使用 io.Copy 更高效,并能处理各种编码(如 chunked)
|
// 使用 io.Copy 更高效,并能处理各种编码(如 chunked)
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
bytesCopied, err := io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 如果在写入 body 时发生错误,记录下来
|
slog.Error("写入响应 Body 时出错", "reqID", reqID, "error", err)
|
||||||
fmt.Printf("写入响应 Body 时出错: %v\n", err)
|
|
||||||
}
|
}
|
||||||
|
slog.Debug("响应写入完成", "reqID", reqID, "bytes_copied", bytesCopied)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user