You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

226 lines
7.9 KiB

3 days ago
  1. package com.example.demo.Util;
  2. import com.example.demo.domain.vo.ExecutionContext;
  3. import jakarta.servlet.http.HttpServletRequest;
  4. import org.springframework.web.context.request.RequestContextHolder;
  5. import org.springframework.web.context.request.ServletRequestAttributes;
  6. import java.io.BufferedReader;
  7. import java.io.File;
  8. import java.io.IOException;
  9. import java.lang.management.ManagementFactory;
  10. import java.util.*;
  11. import java.util.stream.Collectors;
  12. public class ExecutionContextUtil {
  13. /**
  14. * 获取当前执行环境信息
  15. * @param request 如果是Web请求传入HttpServletRequest
  16. * @return 执行环境信息对象
  17. */
  18. /**
  19. * 从Spring上下文获取当前HttpServletRequest
  20. */
  21. public static ExecutionContext getExecutionContext() {
  22. ExecutionContext context = new ExecutionContext();
  23. context.setExecutionTime(new Date());
  24. HttpServletRequest request = getCurrentHttpRequest();
  25. if (isWebEnvironment(request)) {
  26. // Web API 环境
  27. context.setExecutionType("API");
  28. context.setApiUrl(getRealRequestUrl(request));
  29. context.setRequestParams(getRequestParams(request));
  30. context.setToken(getRequestToken(request));
  31. context.setMethod(request.getMethod());
  32. } else {
  33. // 脚本环境
  34. context.setExecutionType("SCRIPT");
  35. context.setScriptFile(getMainClassFile());
  36. }
  37. return context;
  38. }
  39. private static HttpServletRequest getCurrentHttpRequest() {
  40. try {
  41. return ((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest();
  42. } catch (IllegalStateException e) {
  43. // 不在Web请求上下文中
  44. return null;
  45. }
  46. }
  47. private static boolean isWebEnvironment(HttpServletRequest request) {
  48. return request != null;
  49. }
  50. private static String getRealRequestUrl(HttpServletRequest request) {
  51. // 1. 获取协议(优先从代理头获取)
  52. String protocol = getHeaderWithFallback(request,
  53. Arrays.asList("X-Forwarded-Proto", "X-Forwarded-Protocol"),
  54. request.getScheme()
  55. );
  56. // 2. 获取真实域名(优先从代理头获取原始域名)
  57. String domain = getHeaderWithFallback(request,
  58. Arrays.asList(
  59. "X-Original-Host", // 一些代理服务器设置的原始终端
  60. "X-Real-Host", // 另一个可能的原始主机头
  61. "X-Forwarded-Host", // 转发的主机头
  62. "Host" // 最后回退到常规主机头
  63. ),
  64. request.getServerName()
  65. );
  66. // 3. 获取端口(智能处理默认端口)
  67. Integer port = getRealPort(request, protocol);
  68. // 4. 获取原始路径(包括QueryString)
  69. String path = getOriginalUri(request);
  70. // 组装完整URL
  71. return String.format("%s://%s:%s%s",
  72. protocol,
  73. domain,
  74. port,
  75. path
  76. );
  77. }
  78. // 辅助方法:带fallback的header获取
  79. // 方法1:保持强类型(推荐)
  80. private static String getHeaderWithFallback(
  81. HttpServletRequest request,
  82. List<String> headerNames, // 明确要求String列表
  83. String defaultValue
  84. ) {
  85. return headerNames.stream()
  86. .map(request::getHeader)
  87. .filter(Objects::nonNull)
  88. .findFirst()
  89. .orElse(defaultValue);
  90. }
  91. // 获取真实端口(处理代理情况)
  92. private static int getRealPort(HttpServletRequest request, String protocol) {
  93. // 优先从代理头获取
  94. String forwardedPort = request.getHeader("X-Forwarded-Port");
  95. if (forwardedPort != null) {
  96. return Integer.parseInt(forwardedPort);
  97. }
  98. // 其次从请求获取
  99. int port = request.getServerPort();
  100. // 处理反向代理场景
  101. if (port == 80 && "https".equals(protocol)) {
  102. return 443;
  103. }
  104. if (port == 443 && "http".equals(protocol)) {
  105. return 80;
  106. }
  107. return port;
  108. }
  109. // 获取原始URI(包含QueryString)
  110. private static String getOriginalUri(HttpServletRequest request) {
  111. // 优先从代理头获取原始URI
  112. String originalUri = request.getHeader("X-Original-URI");
  113. if (originalUri != null) {
  114. return originalUri;
  115. }
  116. // 默认从request获取
  117. String queryString = request.getQueryString();
  118. return request.getRequestURI() +
  119. (queryString != null ? "?" + queryString : "");
  120. }
  121. private static String getRequestParams(HttpServletRequest request) {
  122. try {
  123. // 1. 优先读取Query String(无需缓存)
  124. String queryString = request.getQueryString();
  125. if (queryString != null) return queryString;
  126. // 2. 检查表单参数(GET/POST都适用)
  127. Map<String, String[]> params = request.getParameterMap();
  128. if (!params.isEmpty()) return formatParams(params);
  129. // 3. 只有明确是JSON请求时才尝试读取body
  130. if (isJsonRequest(request)) {
  131. return readJsonBodyOnDemand(request);
  132. }
  133. return "{}";
  134. } catch (Exception e) {
  135. return "{\"error\":\"failed to read params\"}";
  136. }
  137. }
  138. private static String readJsonBodyOnDemand(HttpServletRequest request) throws IOException {
  139. // 关键点:直接读取原始InputStream(不缓存)
  140. try (BufferedReader reader = request.getReader()) {
  141. String body = reader.lines().collect(Collectors.joining());
  142. return body.isEmpty() ? "{}" : body;
  143. }
  144. }
  145. private static boolean isJsonRequest(HttpServletRequest request) {
  146. String contentType = request.getContentType();
  147. return contentType != null && contentType.contains("application/json");
  148. }
  149. private static String formatParams(Map<String, String[]> params) {
  150. // 优化后的参数格式化方法
  151. return params.entrySet().stream()
  152. .map(entry -> {
  153. String key = escapeJson(entry.getKey());
  154. String[] values = entry.getValue();
  155. if (values.length == 1) {
  156. return "\"" + key + "\":\"" + escapeJson(values[0]) + "\"";
  157. }
  158. return "\"" + key + "\":[" +
  159. Arrays.stream(values)
  160. .map(v -> "\"" + escapeJson(v) + "\"")
  161. .collect(Collectors.joining(",")) +
  162. "]";
  163. })
  164. .collect(Collectors.joining(",", "{", "}"));
  165. }
  166. private static String escapeJson(String raw) {
  167. return raw.replace("\\", "\\\\")
  168. .replace("\"", "\\\"")
  169. .replace("\n", "\\n");
  170. }
  171. private static String getRequestToken(HttpServletRequest request) {
  172. String token = request.getHeader("Authorization");
  173. if (token == null) {
  174. token = request.getHeader("token");
  175. }
  176. return token;
  177. }
  178. private static String getMainClassFile() {
  179. try {
  180. // 获取主类名
  181. String mainClass = ManagementFactory.getRuntimeMXBean().getSystemProperties().get("sun.java.command");
  182. if (mainClass != null) {
  183. // 简单处理,提取主类名
  184. String className = mainClass.split(" ")[0];
  185. // 转换为文件路径
  186. return className.replace('.', File.separatorChar) + ".java";
  187. }
  188. } catch (Exception e) {
  189. e.printStackTrace();
  190. }
  191. return "UnknownScript";
  192. }
  193. }