You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

124 lines
6.6 KiB

  1. package com.deepchart;
  2. import com.alibaba.dashscope.aigc.generation.Generation;
  3. import com.alibaba.dashscope.aigc.generation.GenerationParam;
  4. import com.alibaba.dashscope.aigc.generation.GenerationResult;
  5. import com.alibaba.dashscope.common.Message;
  6. import com.alibaba.dashscope.common.Role;
  7. import com.alibaba.dashscope.exception.InputRequiredException;
  8. import com.alibaba.dashscope.exception.NoApiKeyException;
  9. import com.alibaba.dashscope.tools.FunctionDefinition;
  10. import com.alibaba.dashscope.tools.ToolCallBase;
  11. import com.alibaba.dashscope.tools.ToolCallFunction;
  12. import com.alibaba.dashscope.tools.ToolFunction;
  13. import com.alibaba.dashscope.utils.JsonUtils;
  14. import com.fasterxml.jackson.databind.JsonNode;
  15. import com.fasterxml.jackson.databind.ObjectMapper;
  16. import java.util.ArrayList;
  17. import java.util.Arrays;
  18. import java.util.List;
  19. import java.util.Random;
  20. public class Main {
  21. // 若使用新加坡地域的模型,请释放下列注释
  22. // static {Constants.baseHttpApiUrl="https://dashscope-intl.aliyuncs.com/api/v1";}
  23. /**
  24. * 第一步定义工具的本地实现
  25. * @param arguments 模型传入的包含工具所需参数的JSON字符串
  26. * @return 工具执行后的结果字符串
  27. */
  28. public static String getCurrentWeather(String arguments) {
  29. try {
  30. // 模型提供的参数是JSON格式的,需要我们手动解析。
  31. ObjectMapper objectMapper = new ObjectMapper();
  32. JsonNode argsNode = objectMapper.readTree(arguments);
  33. String location = argsNode.get("location").asText();
  34. // 用随机结果来模拟真实的API调用或业务逻辑。
  35. List<String> weatherConditions = Arrays.asList("晴天", "多云", "雨天");
  36. String randomWeather = weatherConditions.get(new Random().nextInt(weatherConditions.size()));
  37. return location + "今天是" + randomWeather + "。";
  38. } catch (Exception e) {
  39. // 异常处理,确保程序健壮性。
  40. return "无法解析地点参数。";
  41. }
  42. }
  43. public static void main(String[] args) {
  44. try {
  45. // 第二步:向模型描述(注册)我们的工具。
  46. String weatherParamsSchema =
  47. "{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"城市或县区,比如北京市、杭州市、余杭区等。\"}},\"required\":[\"location\"]}";
  48. FunctionDefinition weatherFunction = FunctionDefinition.builder()
  49. .name("get_current_weather") // 工具的唯一标识名,必须与本地实现对应。
  50. .description("当你想查询指定城市的天气时非常有用。") // 清晰的描述能帮助模型更好地决定何时使用该工具。
  51. .parameters(JsonUtils.parseString(weatherParamsSchema).getAsJsonObject())
  52. .build();
  53. Generation gen = new Generation();
  54. String userInput = "北京天气咋样";
  55. List<Message> messages = new ArrayList<>();
  56. messages.add(Message.builder().role(Role.USER.getValue()).content(userInput).build());
  57. // 第四步:首次调用模型。将用户的请求和我们定义好的工具列表一同发送给模型。
  58. GenerationParam param = GenerationParam.builder()
  59. .model("qwen-plus") // 指定需要调用的模型。
  60. .apiKey("sk-5bf09a590daf408cb1126c2c6684e982") // 从环境变量中获取API Key。
  61. .messages(messages) // 传入当前的对话历史。
  62. .tools(Arrays.asList(ToolFunction.builder().function(weatherFunction).build())) // 传入可用的工具列表。
  63. .resultFormat(GenerationParam.ResultFormat.MESSAGE)
  64. .build();
  65. GenerationResult result = gen.call(param);
  66. Message assistantOutput = result.getOutput().getChoices().get(0).getMessage();
  67. messages.add(assistantOutput); // 将模型的首次回复也加入到对话历史中。
  68. // 第五步:检查模型的回复,判断它是否请求调用工具。
  69. if (assistantOutput.getToolCalls() == null || assistantOutput.getToolCalls().isEmpty()) {
  70. // 情况A:模型没有调用工具,而是直接给出了回答。
  71. System.out.println("无需调用天气查询工具,直接回复:" + assistantOutput.getContent());
  72. } else {
  73. // 情况B:模型决定调用工具。
  74. // 使用 while 循环可以处理模型连续调用多次工具的场景。
  75. while (assistantOutput.getToolCalls() != null && !assistantOutput.getToolCalls().isEmpty()) {
  76. ToolCallBase toolCall = assistantOutput.getToolCalls().get(0);
  77. // 从模型的回复中解析出工具调用的具体信息(要调用的函数名、参数)。
  78. ToolCallFunction functionCall = (ToolCallFunction) toolCall;
  79. String funcName = functionCall.getFunction().getName();
  80. String arguments = functionCall.getFunction().getArguments();
  81. System.out.println("正在调用工具 [" + funcName + "],参数:" + arguments);
  82. // 根据工具名,在本地执行对应的Java方法。
  83. String toolResult = getCurrentWeather(arguments);
  84. // 构造一个 role 为 "tool" 的消息,其中包含工具的执行结果。
  85. Message toolMessage = Message.builder()
  86. .role("tool")
  87. .toolCallId(toolCall.getId())
  88. .content(toolResult)
  89. .build();
  90. System.out.println("工具返回:" + toolMessage.getContent());
  91. messages.add(toolMessage); // 将工具的返回结果也加入到对话历史中。
  92. // 第六步:再次调用模型。
  93. param.setMessages(messages);
  94. result = gen.call(param);
  95. assistantOutput = result.getOutput().getChoices().get(0).getMessage();
  96. messages.add(assistantOutput);
  97. }
  98. // 第七步:打印模型经过总结后,生成的最终回复。
  99. System.out.println("助手最终回复:" + assistantOutput.getContent());
  100. }
  101. } catch (NoApiKeyException | InputRequiredException e) {
  102. System.err.println("错误: " + e.getMessage());
  103. } catch (Exception e) {
  104. e.printStackTrace();
  105. }
  106. }
  107. }