You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
7.9 KiB
227 lines
7.9 KiB
package com.example.demo.Util;
|
|
|
|
|
|
import com.example.demo.domain.vo.ExecutionContext;
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
import org.springframework.web.context.request.RequestContextHolder;
|
|
import org.springframework.web.context.request.ServletRequestAttributes;
|
|
|
|
|
|
import java.io.BufferedReader;
|
|
import java.io.File;
|
|
import java.io.IOException;
|
|
import java.lang.management.ManagementFactory;
|
|
import java.util.*;
|
|
import java.util.stream.Collectors;
|
|
|
|
public class ExecutionContextUtil {
|
|
|
|
/**
|
|
* 获取当前执行环境信息
|
|
* @param request 如果是Web请求,传入HttpServletRequest
|
|
* @return 执行环境信息对象
|
|
*/
|
|
/**
|
|
* 从Spring上下文获取当前HttpServletRequest
|
|
*/
|
|
public static ExecutionContext getExecutionContext() {
|
|
ExecutionContext context = new ExecutionContext();
|
|
context.setExecutionTime(new Date());
|
|
|
|
HttpServletRequest request = getCurrentHttpRequest();
|
|
|
|
if (isWebEnvironment(request)) {
|
|
// Web API 环境
|
|
context.setExecutionType("API");
|
|
context.setApiUrl(getRealRequestUrl(request));
|
|
context.setRequestParams(getRequestParams(request));
|
|
context.setToken(getRequestToken(request));
|
|
context.setMethod(request.getMethod());
|
|
} else {
|
|
// 脚本环境
|
|
context.setExecutionType("SCRIPT");
|
|
context.setScriptFile(getMainClassFile());
|
|
}
|
|
|
|
return context;
|
|
}
|
|
|
|
private static HttpServletRequest getCurrentHttpRequest() {
|
|
try {
|
|
return ((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest();
|
|
} catch (IllegalStateException e) {
|
|
// 不在Web请求上下文中
|
|
return null;
|
|
}
|
|
}
|
|
|
|
private static boolean isWebEnvironment(HttpServletRequest request) {
|
|
return request != null;
|
|
}
|
|
|
|
private static String getRealRequestUrl(HttpServletRequest request) {
|
|
// 1. 获取协议(优先从代理头获取)
|
|
String protocol = getHeaderWithFallback(request,
|
|
Arrays.asList("X-Forwarded-Proto", "X-Forwarded-Protocol"),
|
|
request.getScheme()
|
|
);
|
|
|
|
// 2. 获取真实域名(优先从代理头获取原始域名)
|
|
String domain = getHeaderWithFallback(request,
|
|
Arrays.asList(
|
|
"X-Original-Host", // 一些代理服务器设置的原始终端
|
|
"X-Real-Host", // 另一个可能的原始主机头
|
|
"X-Forwarded-Host", // 转发的主机头
|
|
"Host" // 最后回退到常规主机头
|
|
),
|
|
request.getServerName()
|
|
);
|
|
|
|
// 3. 获取端口(智能处理默认端口)
|
|
Integer port = getRealPort(request, protocol);
|
|
|
|
// 4. 获取原始路径(包括QueryString)
|
|
String path = getOriginalUri(request);
|
|
|
|
// 组装完整URL
|
|
return String.format("%s://%s:%s%s",
|
|
protocol,
|
|
domain,
|
|
port,
|
|
path
|
|
);
|
|
}
|
|
|
|
// 辅助方法:带fallback的header获取
|
|
// 方法1:保持强类型(推荐)
|
|
private static String getHeaderWithFallback(
|
|
HttpServletRequest request,
|
|
List<String> headerNames, // 明确要求String列表
|
|
String defaultValue
|
|
) {
|
|
return headerNames.stream()
|
|
.map(request::getHeader)
|
|
.filter(Objects::nonNull)
|
|
.findFirst()
|
|
.orElse(defaultValue);
|
|
}
|
|
|
|
// 获取真实端口(处理代理情况)
|
|
private static int getRealPort(HttpServletRequest request, String protocol) {
|
|
// 优先从代理头获取
|
|
String forwardedPort = request.getHeader("X-Forwarded-Port");
|
|
if (forwardedPort != null) {
|
|
return Integer.parseInt(forwardedPort);
|
|
}
|
|
|
|
// 其次从请求获取
|
|
int port = request.getServerPort();
|
|
|
|
// 处理反向代理场景
|
|
if (port == 80 && "https".equals(protocol)) {
|
|
return 443;
|
|
}
|
|
if (port == 443 && "http".equals(protocol)) {
|
|
return 80;
|
|
}
|
|
return port;
|
|
}
|
|
|
|
// 获取原始URI(包含QueryString)
|
|
private static String getOriginalUri(HttpServletRequest request) {
|
|
// 优先从代理头获取原始URI
|
|
String originalUri = request.getHeader("X-Original-URI");
|
|
if (originalUri != null) {
|
|
return originalUri;
|
|
}
|
|
|
|
// 默认从request获取
|
|
String queryString = request.getQueryString();
|
|
return request.getRequestURI() +
|
|
(queryString != null ? "?" + queryString : "");
|
|
}
|
|
|
|
private static String getRequestParams(HttpServletRequest request) {
|
|
try {
|
|
// 1. 优先读取Query String(无需缓存)
|
|
String queryString = request.getQueryString();
|
|
if (queryString != null) return queryString;
|
|
|
|
// 2. 检查表单参数(GET/POST都适用)
|
|
Map<String, String[]> params = request.getParameterMap();
|
|
if (!params.isEmpty()) return formatParams(params);
|
|
|
|
// 3. 只有明确是JSON请求时才尝试读取body
|
|
if (isJsonRequest(request)) {
|
|
return readJsonBodyOnDemand(request);
|
|
}
|
|
|
|
return "{}";
|
|
} catch (Exception e) {
|
|
return "{\"error\":\"failed to read params\"}";
|
|
}
|
|
}
|
|
|
|
private static String readJsonBodyOnDemand(HttpServletRequest request) throws IOException {
|
|
// 关键点:直接读取原始InputStream(不缓存)
|
|
try (BufferedReader reader = request.getReader()) {
|
|
String body = reader.lines().collect(Collectors.joining());
|
|
return body.isEmpty() ? "{}" : body;
|
|
}
|
|
}
|
|
|
|
|
|
private static boolean isJsonRequest(HttpServletRequest request) {
|
|
String contentType = request.getContentType();
|
|
return contentType != null && contentType.contains("application/json");
|
|
}
|
|
|
|
|
|
private static String formatParams(Map<String, String[]> params) {
|
|
// 优化后的参数格式化方法
|
|
return params.entrySet().stream()
|
|
.map(entry -> {
|
|
String key = escapeJson(entry.getKey());
|
|
String[] values = entry.getValue();
|
|
if (values.length == 1) {
|
|
return "\"" + key + "\":\"" + escapeJson(values[0]) + "\"";
|
|
}
|
|
return "\"" + key + "\":[" +
|
|
Arrays.stream(values)
|
|
.map(v -> "\"" + escapeJson(v) + "\"")
|
|
.collect(Collectors.joining(",")) +
|
|
"]";
|
|
})
|
|
.collect(Collectors.joining(",", "{", "}"));
|
|
}
|
|
|
|
private static String escapeJson(String raw) {
|
|
return raw.replace("\\", "\\\\")
|
|
.replace("\"", "\\\"")
|
|
.replace("\n", "\\n");
|
|
}
|
|
|
|
private static String getRequestToken(HttpServletRequest request) {
|
|
String token = request.getHeader("Authorization");
|
|
if (token == null) {
|
|
token = request.getHeader("token");
|
|
}
|
|
return token;
|
|
}
|
|
|
|
private static String getMainClassFile() {
|
|
try {
|
|
// 获取主类名
|
|
String mainClass = ManagementFactory.getRuntimeMXBean().getSystemProperties().get("sun.java.command");
|
|
if (mainClass != null) {
|
|
// 简单处理,提取主类名
|
|
String className = mainClass.split(" ")[0];
|
|
// 转换为文件路径
|
|
return className.replace('.', File.separatorChar) + ".java";
|
|
}
|
|
} catch (Exception e) {
|
|
e.printStackTrace();
|
|
}
|
|
return "UnknownScript";
|
|
}
|
|
}
|