You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
7.9 KiB

package com.example.demo.Util;
import com.example.demo.domain.vo.ExecutionContext;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Static helpers that describe the environment the current code runs in: either an HTTP request
 * bound to the current thread through Spring's {@link RequestContextHolder} ("API"), or a plain
 * JVM launch ("SCRIPT").
 */
public class ExecutionContextUtil {

    /** Placeholder returned when the JVM launch target cannot be determined. */
    private static final String UNKNOWN_SCRIPT = "UnknownScript";

    /**
     * Builds an {@link ExecutionContext} describing the current thread's execution environment.
     *
     * <p>When a servlet request is bound to the current thread, the context is filled in as type
     * {@code "API"} with the client-facing URL, the request parameters, the request token and the
     * HTTP method. Otherwise it is filled in as type {@code "SCRIPT"} with a best-effort source
     * path derived from the JVM launch command.
     *
     * @return a newly created execution context, never {@code null}
     */
    public static ExecutionContext getExecutionContext() {
        ExecutionContext context = new ExecutionContext();
        context.setExecutionTime(new Date());
        HttpServletRequest request = getCurrentHttpRequest();
        if (isWebEnvironment(request)) {
            // Web API environment
            context.setExecutionType("API");
            context.setApiUrl(getRealRequestUrl(request));
            context.setRequestParams(getRequestParams(request));
            // NOTE(review): this stores the raw credential header value; consider masking it if
            // the context is logged or persisted.
            context.setToken(getRequestToken(request));
            context.setMethod(request.getMethod());
        } else {
            // Script / non-web environment
            context.setExecutionType("SCRIPT");
            context.setScriptFile(getMainClassFile());
        }
        return context;
    }

    /**
     * Returns the servlet request bound to the current thread by Spring.
     *
     * @return the current request, or {@code null} when the thread is not serving a servlet request
     */
    private static HttpServletRequest getCurrentHttpRequest() {
        // getRequestAttributes() returns null (instead of throwing) when nothing is bound, and the
        // instanceof check avoids a ClassCastException for non-servlet RequestAttributes.
        var attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes servletAttributes) {
            return servletAttributes.getRequest();
        }
        return null;
    }

    private static boolean isWebEnvironment(HttpServletRequest request) {
        return request != null;
    }

    /**
     * Reconstructs the URL as the client saw it, preferring reverse-proxy headers over what the
     * servlet container observed directly.
     *
     * <p>The port is omitted when it is the scheme's default (80 for http, 443 for https) or is
     * unknown.
     */
    private static String getRealRequestUrl(HttpServletRequest request) {
        // 1. Scheme: proxy headers win; normalise case so "HTTPS" compares equal to "https".
        String protocol = getHeaderWithFallback(request,
                Arrays.asList("X-Forwarded-Proto", "X-Forwarded-Protocol"),
                request.getScheme()
        ).toLowerCase(Locale.ROOT);

        // 2. Host: proxy headers first, then the regular Host header, then the server name.
        //    These values may carry a ":port" suffix (e.g. "Host: localhost:8080"), which is split
        //    off here so the port is not emitted twice.
        String hostValue = getHeaderWithFallback(request,
                Arrays.asList(
                        "X-Original-Host",  // original host set by some proxies
                        "X-Real-Host",      // another possible original-host header
                        "X-Forwarded-Host", // forwarded host header
                        "Host"              // finally the regular Host header
                ),
                request.getServerName()
        );
        int separator = portSeparatorIndex(hostValue);
        String domain = separator >= 0 ? hostValue.substring(0, separator) : hostValue;
        int hostPort = separator >= 0 ? parsePort(hostValue.substring(separator + 1)) : -1;

        // 3. Port (proxy header, then the port embedded in the host value, then the container).
        int port = getRealPort(request, protocol, hostPort);

        // 4. Original path including the query string.
        String path = getOriginalUri(request);

        StringBuilder url = new StringBuilder(protocol).append("://").append(domain);
        if (!isDefaultOrUnknownPort(protocol, port)) {
            url.append(':').append(port);
        }
        return url.append(path).toString();
    }

    /**
     * Returns the value of the first header in {@code headerNames} that is present and non-blank,
     * or {@code defaultValue} when none is.
     *
     * <p>Proxies chained together append to these headers as a comma-separated list; only the
     * first (client-most) entry is used.
     */
    private static String getHeaderWithFallback(
            HttpServletRequest request,
            List<String> headerNames,
            String defaultValue
    ) {
        return headerNames.stream()
                .map(request::getHeader)
                .filter(Objects::nonNull)
                .map(ExecutionContextUtil::firstListValue)
                .filter(value -> !value.isEmpty())
                .findFirst()
                .orElse(defaultValue);
    }

    /** Returns the first entry of a comma-separated header value, trimmed. */
    private static String firstListValue(String headerValue) {
        int comma = headerValue.indexOf(',');
        return (comma >= 0 ? headerValue.substring(0, comma) : headerValue).trim();
    }

    /**
     * Returns the index of the ':' that separates a host from its port, or -1 when the value has
     * no port. Bracketed IPv6 literals ("[::1]:8080") are handled; an unbracketed value with more
     * than one ':' is treated as a bare IPv6 address without a port.
     */
    private static int portSeparatorIndex(String host) {
        if (host.startsWith("[")) {
            int close = host.indexOf(']');
            boolean hasPort = close >= 0 && close + 1 < host.length() && host.charAt(close + 1) == ':';
            return hasPort ? close + 1 : -1;
        }
        int colon = host.indexOf(':');
        return (colon >= 0 && host.indexOf(':', colon + 1) < 0) ? colon : -1;
    }

    /** Parses a TCP port, returning -1 when the text is not an integer in 1..65535. */
    private static int parsePort(String text) {
        try {
            int port = Integer.parseInt(text.trim());
            return (port >= 1 && port <= 65535) ? port : -1;
        } catch (NumberFormatException e) {
            return -1;
        }
    }

    /** True when {@code port} is unknown (non-positive) or the default for {@code protocol}. */
    private static boolean isDefaultOrUnknownPort(String protocol, int port) {
        return port <= 0
                || (port == 80 && "http".equals(protocol))
                || (port == 443 && "https".equals(protocol));
    }

    /**
     * Determines the client-facing port.
     *
     * @param request  the current request
     * @param protocol the lower-cased scheme already resolved for this request
     * @param hostPort the port embedded in the resolved host value, or -1 if there was none
     * @return the port to report; may be the container's port when nothing better is known
     */
    private static int getRealPort(HttpServletRequest request, String protocol, int hostPort) {
        // Prefer the proxy header; a malformed value is ignored instead of failing the whole lookup.
        String forwardedPort = request.getHeader("X-Forwarded-Port");
        if (forwardedPort != null) {
            int parsed = parsePort(firstListValue(forwardedPort));
            if (parsed > 0) {
                return parsed;
            }
        }
        if (hostPort > 0) {
            return hostPort;
        }
        int port = request.getServerPort();
        // Reverse-proxy heuristic: the container port disagrees with the externally used scheme.
        if (port == 80 && "https".equals(protocol)) {
            return 443;
        }
        if (port == 443 && "http".equals(protocol)) {
            return 80;
        }
        return port;
    }

    /** Returns the original request URI including the query string, preferring X-Original-URI. */
    private static String getOriginalUri(HttpServletRequest request) {
        String originalUri = request.getHeader("X-Original-URI");
        if (originalUri != null && !originalUri.isBlank()) {
            return originalUri;
        }
        String queryString = request.getQueryString();
        return request.getRequestURI() + (queryString != null ? "?" + queryString : "");
    }

    /**
     * Describes the request's parameters as a string.
     *
     * <p>Order of preference: the raw query string; otherwise the parameter map rendered as a JSON
     * object; otherwise, for JSON requests, the request body; otherwise {@code "{}"}. Any failure
     * yields a fixed JSON error placeholder.
     */
    private static String getRequestParams(HttpServletRequest request) {
        try {
            // 1. The query string needs no body access.
            String queryString = request.getQueryString();
            if (queryString != null) {
                return queryString;
            }
            // 2. Form/URL parameters (GET and POST).
            Map<String, String[]> params = request.getParameterMap();
            if (!params.isEmpty()) {
                return formatParams(params);
            }
            // 3. Only read the body for explicitly JSON requests.
            if (isJsonRequest(request)) {
                return readJsonBodyOnDemand(request);
            }
            return "{}";
        } catch (Exception e) {
            // Boundary catch: parameter capture is diagnostic and must not break the caller.
            return "{\"error\":\"failed to read params\"}";
        }
    }

    /**
     * Reads the raw request body through {@link HttpServletRequest#getReader()}; lines are joined
     * without separators.
     *
     * <p>NOTE(review): this reads the underlying stream directly (no caching). Per the Servlet API,
     * the body can be read only once, so calling this before the handler reads the body would
     * leave the handler an empty body, and calling it after getInputStream() was used makes
     * getReader() throw IllegalStateException (reported as the error placeholder by the caller).
     */
    private static String readJsonBodyOnDemand(HttpServletRequest request) throws IOException {
        try (BufferedReader reader = request.getReader()) {
            String body = reader.lines().collect(Collectors.joining());
            return body.isEmpty() ? "{}" : body;
        }
    }

    /** True when the Content-Type is application/json or a {@code +json} media type (case-insensitive). */
    private static boolean isJsonRequest(HttpServletRequest request) {
        String contentType = request.getContentType();
        if (contentType == null) {
            return false;
        }
        String mediaType = contentType.toLowerCase(Locale.ROOT);
        return mediaType.contains("application/json") || mediaType.contains("+json");
    }

    /**
     * Renders a parameter map as a JSON object: single-valued parameters become strings,
     * multi-valued parameters become arrays of strings.
     */
    private static String formatParams(Map<String, String[]> params) {
        return params.entrySet().stream()
                .map(entry -> {
                    String key = escapeJson(entry.getKey());
                    String[] values = entry.getValue();
                    if (values.length == 1) {
                        return "\"" + key + "\":\"" + escapeJson(values[0]) + "\"";
                    }
                    return "\"" + key + "\":["
                            + Arrays.stream(values)
                                    .map(v -> "\"" + escapeJson(v) + "\"")
                                    .collect(Collectors.joining(","))
                            + "]";
                })
                .collect(Collectors.joining(",", "{", "}"));
    }

    /**
     * Escapes a string for inclusion inside a JSON string literal: quote, backslash and every
     * control character below U+0020 are escaped, as JSON requires.
     */
    private static String escapeJson(String raw) {
        StringBuilder escaped = new StringBuilder(raw.length() + 8);
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '"' -> escaped.append("\\\"");
                case '\\' -> escaped.append("\\\\");
                case '\n' -> escaped.append("\\n");
                case '\r' -> escaped.append("\\r");
                case '\t' -> escaped.append("\\t");
                case '\b' -> escaped.append("\\b");
                case '\f' -> escaped.append("\\f");
                default -> {
                    if (c < 0x20) {
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
                }
            }
        }
        return escaped.toString();
    }

    /** Returns the Authorization header, falling back to a custom "token" header; may be null. */
    private static String getRequestToken(HttpServletRequest request) {
        String token = request.getHeader("Authorization");
        if (token == null) {
            token = request.getHeader("token");
        }
        return token;
    }

    /**
     * Derives a best-effort source path for the JVM's launch target from the
     * {@code sun.java.command} system property.
     *
     * <p>A main class name such as {@code com.example.App} becomes {@code com/example/App.java}
     * (using the platform separator); a {@code -jar} launch returns the jar path unchanged.
     *
     * @return the derived path, or {@code "UnknownScript"} when it cannot be determined
     */
    private static String getMainClassFile() {
        String command;
        try {
            command = ManagementFactory.getRuntimeMXBean().getSystemProperties().get("sun.java.command");
        } catch (RuntimeException ignored) {
            // Best effort (e.g. SecurityException): report the placeholder instead of failing.
            return UNKNOWN_SCRIPT;
        }
        if (command == null || command.isBlank()) {
            return UNKNOWN_SCRIPT;
        }
        // The first token is the main class (or jar path); the rest are program arguments.
        String launchTarget = command.trim().split("\\s+")[0];
        if (launchTarget.toLowerCase(Locale.ROOT).endsWith(".jar")) {
            // "java -jar app.jar": the token is a file path, not a dotted class name.
            return launchTarget;
        }
        return launchTarget.replace('.', File.separatorChar) + ".java";
    }
}