|
|
package com.example.demo.Util;
import com.example.demo.domain.vo.ExecutionContext; import jakarta.servlet.http.HttpServletRequest; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes;
import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.lang.management.ManagementFactory; import java.util.*; import java.util.stream.Collectors;
/**
 * Static helpers that describe where the current code is executing: inside a
 * servlet request (reported as {@code "API"}) or outside of one (reported as
 * {@code "SCRIPT"}).
 */
public class ExecutionContextUtil {

    private static final System.Logger LOGGER =
            System.getLogger(ExecutionContextUtil.class.getName());

    /** Value reported when the launch command cannot be determined. */
    private static final String UNKNOWN_SCRIPT = "UnknownScript";

    /**
     * Builds a snapshot of the current execution environment.
     *
     * <p>The current {@link HttpServletRequest} is looked up through Spring's
     * {@link RequestContextHolder}. When one is bound, the context is filled with the
     * request URL, parameters, token and HTTP method; otherwise it is filled with a file
     * name derived from the JVM launch command.
     *
     * @return a newly created execution context, never {@code null}
     */
    public static ExecutionContext getExecutionContext() {
        ExecutionContext context = new ExecutionContext();
        context.setExecutionTime(new Date());

        HttpServletRequest request = getCurrentHttpRequest();
        if (isWebEnvironment(request)) {
            // Web API environment.
            context.setExecutionType("API");
            context.setApiUrl(getRealRequestUrl(request));
            context.setRequestParams(getRequestParams(request));
            context.setToken(getRequestToken(request));
            context.setMethod(request.getMethod());
        } else {
            // Script (non-web) environment.
            context.setExecutionType("SCRIPT");
            context.setScriptFile(getMainClassFile());
        }
        return context;
    }

    /**
     * Returns the servlet request bound to the current thread, or {@code null} when no
     * request attributes are bound or they are not servlet-based.
     */
    private static HttpServletRequest getCurrentHttpRequest() {
        // getRequestAttributes() returns null instead of throwing when nothing is bound,
        // and the instanceof check avoids a ClassCastException for other attribute types.
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) attributes).getRequest();
        }
        return null;
    }

    /** Returns {@code true} when a servlet request is available. */
    private static boolean isWebEnvironment(HttpServletRequest request) {
        return request != null;
    }

    /**
     * Reconstructs the URL of the request, preferring proxy-supplied headers over the
     * values reported by the request itself.
     *
     * @param request the current request
     * @return a URL of the form {@code protocol://host:port/path?query}
     */
    private static String getRealRequestUrl(HttpServletRequest request) {
        // 1. Protocol: proxy headers first, then the request scheme.
        String protocol = getHeaderWithFallback(
                request,
                Arrays.asList("X-Forwarded-Proto", "X-Forwarded-Protocol"),
                request.getScheme());

        // 2. Host: proxy headers first, the regular Host header last, then the server name.
        String hostHeader = getHeaderWithFallback(
                request,
                Arrays.asList("X-Original-Host", "X-Real-Host", "X-Forwarded-Host", "Host"),
                request.getServerName());
        // Host-style headers may already contain ":port"; the port is appended separately
        // below, so drop it here to avoid "host:8080:8080".
        String domain = stripPort(hostHeader);

        // 3. Port, taking proxy headers into account.
        int port = getRealPort(request, protocol);

        // 4. Original path, including the query string.
        String path = getOriginalUri(request);

        return String.format("%s://%s:%s%s", protocol, domain, port, path);
    }

    /**
     * Returns the first usable value among the given headers, or {@code defaultValue}
     * when none of them carries one.
     *
     * @param request      the current request
     * @param headerNames  header names in priority order
     * @param defaultValue value returned when no header has a non-blank value
     */
    private static String getHeaderWithFallback(
            HttpServletRequest request,
            List<String> headerNames,
            String defaultValue) {
        return headerNames.stream()
                .map(request::getHeader)
                .map(ExecutionContextUtil::firstListValue)
                .filter(Objects::nonNull)
                .findFirst()
                .orElse(defaultValue);
    }

    /**
     * Returns the first comma-separated element of a header value, trimmed, or
     * {@code null} when the header is absent or that element is blank. Forwarding
     * headers can hold one entry per proxy hop; only the first entry is used.
     */
    private static String firstListValue(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        int comma = headerValue.indexOf(',');
        String first = (comma >= 0 ? headerValue.substring(0, comma) : headerValue).trim();
        return first.isEmpty() ? null : first;
    }

    /**
     * Removes a trailing {@code :port} from a host value. Bracketed IPv6 literals keep
     * their brackets; an unbracketed value with several colons is returned unchanged.
     */
    private static String stripPort(String host) {
        if (host == null) {
            return null;
        }
        String trimmed = host.trim();
        if (trimmed.startsWith("[")) {
            int closing = trimmed.indexOf(']');
            return closing >= 0 ? trimmed.substring(0, closing + 1) : trimmed;
        }
        int colon = trimmed.indexOf(':');
        if (colon >= 0 && colon == trimmed.lastIndexOf(':')) {
            return trimmed.substring(0, colon);
        }
        return trimmed;
    }

    /**
     * Determines the port to report for the request.
     *
     * <p>A valid {@code X-Forwarded-Port} header wins. Otherwise the request's server
     * port is used, with 80/443 swapped when it contradicts the given protocol.
     *
     * @param request  the current request
     * @param protocol the protocol already resolved for this request
     */
    private static int getRealPort(HttpServletRequest request, String protocol) {
        // The header is client-controlled: a malformed value must not abort the caller.
        OptionalInt forwardedPort = parsePort(firstListValue(request.getHeader("X-Forwarded-Port")));
        if (forwardedPort.isPresent()) {
            return forwardedPort.getAsInt();
        }

        int port = request.getServerPort();

        // Reverse-proxy case: the default port reported locally does not match the protocol.
        if (port == 80 && "https".equalsIgnoreCase(protocol)) {
            return 443;
        }
        if (port == 443 && "http".equalsIgnoreCase(protocol)) {
            return 80;
        }
        return port;
    }

    /** Parses a port number in the range 1-65535; returns empty for anything else. */
    private static OptionalInt parsePort(String text) {
        if (text == null) {
            return OptionalInt.empty();
        }
        try {
            int port = Integer.parseInt(text);
            return (port >= 1 && port <= 65535) ? OptionalInt.of(port) : OptionalInt.empty();
        } catch (NumberFormatException ignored) {
            // Not a number: treat the header as absent.
            return OptionalInt.empty();
        }
    }

    /**
     * Returns the request URI including the query string, preferring the
     * {@code X-Original-URI} header when it is present.
     */
    private static String getOriginalUri(HttpServletRequest request) {
        String originalUri = request.getHeader("X-Original-URI");
        if (originalUri != null) {
            return originalUri;
        }
        String queryString = request.getQueryString();
        return request.getRequestURI() + (queryString != null ? "?" + queryString : "");
    }

    /**
     * Returns a textual description of the request parameters.
     *
     * <p>Resolution order: the raw query string (returned as-is, not as JSON), then the
     * parameter map rendered as a JSON object, then the request body when the content
     * type is JSON, and finally {@code "{}"}. Any exception yields a fixed error JSON
     * so that parameter capture stays best-effort.
     */
    private static String getRequestParams(HttpServletRequest request) {
        try {
            // 1. Query string first.
            String queryString = request.getQueryString();
            if (queryString != null) {
                return queryString;
            }

            // 2. Form parameters.
            Map<String, String[]> params = request.getParameterMap();
            if (!params.isEmpty()) {
                return formatParams(params);
            }

            // 3. Only read the body when the request declares JSON content.
            if (isJsonRequest(request)) {
                return readJsonBodyOnDemand(request);
            }
            return "{}";
        } catch (Exception e) {
            return "{\"error\":\"failed to read params\"}";
        }
    }

    /**
     * Reads the request body through {@code request.getReader()} and returns it with
     * line breaks removed, or {@code "{}"} when the body is empty.
     *
     * <p>NOTE(review): this reads and closes the request's reader directly; a later
     * consumer of the same request presumably cannot read the body again unless the
     * request is wrapped by a caching wrapper - confirm against the filter chain.
     *
     * @throws IOException if the body cannot be read
     */
    private static String readJsonBodyOnDemand(HttpServletRequest request) throws IOException {
        try (BufferedReader reader = request.getReader()) {
            String body = reader.lines().collect(Collectors.joining());
            return body.isEmpty() ? "{}" : body;
        }
    }

    /** Returns {@code true} when the content type contains {@code application/json}, ignoring case. */
    private static boolean isJsonRequest(HttpServletRequest request) {
        String contentType = request.getContentType();
        return contentType != null
                && contentType.toLowerCase(Locale.ROOT).contains("application/json");
    }

    /**
     * Renders a parameter map as a JSON object. A parameter with exactly one value
     * becomes a string member; any other number of values becomes an array of strings.
     */
    private static String formatParams(Map<String, String[]> params) {
        return params.entrySet().stream()
                .map(entry -> formatEntry(entry.getKey(), entry.getValue()))
                .collect(Collectors.joining(",", "{", "}"));
    }

    /** Renders one parameter as a {@code "name":value} JSON member. */
    private static String formatEntry(String name, String[] values) {
        String prefix = "\"" + escapeJson(name) + "\":";
        if (values.length == 1) {
            return prefix + quote(values[0]);
        }
        return prefix + Arrays.stream(values)
                .map(ExecutionContextUtil::quote)
                .collect(Collectors.joining(",", "[", "]"));
    }

    /** Wraps the escaped value in double quotes. */
    private static String quote(String value) {
        return "\"" + escapeJson(value) + "\"";
    }

    /**
     * Escapes a string for use inside a JSON string literal: backslash, double quote
     * and every control character below U+0020.
     */
    private static String escapeJson(String raw) {
        StringBuilder escaped = new StringBuilder(raw.length() + 8);
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            switch (c) {
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                case '\b':
                    escaped.append("\\b");
                    break;
                case '\f':
                    escaped.append("\\f");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters have no short escape form.
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
                    break;
            }
        }
        return escaped.toString();
    }

    /**
     * Returns the {@code Authorization} header, or the {@code token} header when the
     * former is absent; {@code null} when neither is present.
     */
    private static String getRequestToken(HttpServletRequest request) {
        String token = request.getHeader("Authorization");
        if (token == null) {
            token = request.getHeader("token");
        }
        return token;
    }

    /**
     * Derives a file name from the first token of the {@code sun.java.command} system
     * property.
     *
     * <p>A token ending in {@code .jar} is returned unchanged. Any other token is
     * treated as a class name: dots become file separators and {@code .java} is
     * appended. {@code "UnknownScript"} is returned when the property is missing,
     * blank, or cannot be read.
     */
    private static String getMainClassFile() {
        try {
            String command = ManagementFactory.getRuntimeMXBean()
                    .getSystemProperties()
                    .get("sun.java.command");
            if (command == null || command.trim().isEmpty()) {
                return UNKNOWN_SCRIPT;
            }

            // The first token is the launch target; the remainder are program arguments.
            String launchTarget = command.trim().split(" ")[0];
            if (launchTarget.toLowerCase(Locale.ROOT).endsWith(".jar")) {
                // Launched with -jar: the target is already a file name.
                return launchTarget;
            }
            return launchTarget.replace('.', File.separatorChar) + ".java";
        } catch (RuntimeException e) {
            LOGGER.log(System.Logger.Level.WARNING, "Failed to determine the launch command", e);
            return UNKNOWN_SCRIPT;
        }
    }
}
|