| import asyncio import json import logging from typing import Dict, List, Any, Optional from mcp import ClientSession from mcp.client.sse import sse_client from openai import OpenAI
# httpx logs every HTTP request at INFO; raise its threshold to WARNING so the
# console output below (server banners, tool traces) stays readable.
logging.getLogger("httpx").setLevel(logging.WARNING)
class MCPGatewayClient:
    """Aggregate several MCP servers (SSE transport) behind one gateway.

    On ``start()`` the gateway connects to every configured server, indexes the
    tools / resources / prompts each one exposes, and later routes calls to the
    owning server by tool name, resource URI or prompt name.
    """

    def __init__(self, config: Dict):
        """Store the server configuration and prepare empty routing tables.

        Args:
            config: dict shaped like ``{"mcpServers": {name: {"url": ...}}}``.
        """
        self.server_configs = config.get("mcpServers", {})
        # name -> live ClientSession.  (Kept un-annotated on purpose: a
        # complex-target annotation such as ``Dict[str, ClientSession]`` is
        # evaluated at runtime per PEP 526 and would couple construction to
        # third-party names.)
        self.sessions = {}
        # Opened sse_client context managers; stop() exits them in reverse.
        self.contexts = []
        # Routing tables: capability identifier -> owning server name.
        self.tool_to_server = {}
        self.resource_to_server = {}
        self.prompt_to_server = {}
        # Flat prompt descriptors consumed by the LLM router in run_task().
        self.prompts_metadata: List[Dict] = []

    async def start(self):
        """Connect to every configured server and index its capabilities.

        A server that fails to connect is reported and skipped; the gateway
        still starts with the remaining servers.
        """
        print(f"\n{'='*70}")
        print(f"🚀 MCP 智能网关启动中... (正在连接 {len(self.server_configs)} 个节点)")
        print(f"{'='*70}")
        for name, cfg in self.server_configs.items():
            try:
                # The async context managers are entered manually (not via
                # ``async with``) so the connections outlive this method;
                # stop() performs the matching __aexit__ calls.
                ctx = sse_client(cfg["url"])
                read, write = await ctx.__aenter__()
                self.contexts.append(ctx)
                session = ClientSession(read, write)
                await session.__aenter__()
                await session.initialize()
                self.sessions[name] = session
                print(f"\n📂 服务器: 【{name}】")

                tools_res = await session.list_tools()
                for t in tools_res.tools:
                    self.tool_to_server[str(t.name)] = name
                print(f" ├─ 🛠️ Tools: {', '.join([str(t.name) for t in tools_res.tools]) or 'None'}")

                res_res = await session.list_resources()
                for r in res_res.resources:
                    self.resource_to_server[str(r.uri)] = name
                print(f" ├─ 📖 Resources: {', '.join([str(r.uri) for r in res_res.resources]) or 'None'}")

                prompts_res = await session.list_prompts()
                for p in prompts_res.prompts:
                    p_name = str(p.name)
                    self.prompt_to_server[p_name] = name
                    self.prompts_metadata.append({
                        "name": p_name,
                        "description": getattr(p, 'description', '专业领域指令'),
                        "server": name,
                    })
                print(f" └─ 💡 Prompts: {', '.join([str(p.name) for p in prompts_res.prompts]) or 'None'}")
            except Exception as e:
                print(f"❌ 服务器 [{name}] 失败: {str(e)}")
        print(f"\n{'='*70}\n✅ 网关就绪。\n")

    async def get_combined_tools(self) -> List[Dict]:
        """Return every server's tools as one OpenAI function-calling list.

        When at least one resource is known, a synthetic ``read_mcp_resource``
        tool is appended so the model can fetch internal documents.
        """
        all_tools = []
        for name, session in self.sessions.items():
            mcp_tools = await session.list_tools()
            for t in mcp_tools.tools:
                all_tools.append({
                    "type": "function",
                    "function": {
                        "name": str(t.name),
                        "description": t.description or "",
                        "parameters": t.inputSchema,
                    },
                })
        if self.resource_to_server:
            all_tools.append({
                "type": "function",
                "function": {
                    "name": "read_mcp_resource",
                    "description": f"读取内部资料。路径: {list(self.resource_to_server.keys())}",
                    "parameters": {
                        "type": "object",
                        "properties": {"uri": {"type": "string"}},
                        "required": ["uri"],
                    },
                },
            })
        return all_tools

    def robust_parse_args(self, args: Any) -> Dict:
        """Normalize model-emitted tool arguments into a dict.

        Handles the double-JSON-encoding some models produce (a JSON string
        whose payload is itself a JSON document).  Unparseable strings are
        wrapped as ``{"raw_input": ...}``; any other type yields ``{}``.
        """
        if isinstance(args, dict):
            return args
        if isinstance(args, str):
            try:
                parsed = json.loads(args)
                if isinstance(parsed, str):
                    # Double-encoded: decode the inner JSON payload too.
                    return json.loads(parsed)
                return parsed
            except Exception:
                return {"raw_input": args}
        return {}

    async def call_action(self, name: str, args: Any):
        """Dispatch one model-requested action to the owning server.

        Returns the textual result, or a human-readable error string —
        never raises, so gather() in the caller cannot be torn down by one
        failing tool.
        """
        parsed_args = self.robust_parse_args(args)
        try:
            if name == "read_mcp_resource":
                uri = parsed_args.get("uri", "").strip()
                srv = self.resource_to_server.get(uri)
                if not srv:
                    return f"Error: 找不到资源 {uri}"
                # NOTE(review): read_resource may expect an AnyUrl rather than
                # a plain str depending on the mcp SDK version — confirm.
                res = await self.sessions[srv].read_resource(uri)
                return res.contents[0].text
            else:
                srv = self.tool_to_server.get(name)
                if not srv:
                    return f"Error: 找不到工具 {name}"
                res = await self.sessions[srv].call_tool(name, parsed_args)
                return res.content[0].text if hasattr(res, 'content') else str(res)
        except Exception as e:
            return f"执行失败: {str(e)}"

    async def get_prompt_template(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None) -> Optional[str]:
        """Fetch a named prompt template from the server that owns it.

        Returns the first message's text, or ``None`` when the prompt is
        unknown or the fetch fails.
        """
        srv_name = self.prompt_to_server.get(prompt_name)
        if not srv_name:
            return None
        try:
            session = self.sessions[srv_name]
            result = await session.get_prompt(prompt_name, arguments=arguments or {})
            if result.messages:
                # Assumes the first message carries TextContent — TODO confirm
                # against the serving prompt implementations.
                return result.messages[0].content.text
            return None
        except Exception as e:
            print(f"⚠️ 获取 Prompt 【{prompt_name}】 失败: {str(e)}")
            return None

    async def stop(self):
        """Close every session and transport context (best effort)."""
        print("\n正在释放 MCP 资源...")
        for s in self.sessions.values():
            try:
                await s.__aexit__(None, None, None)
            except Exception:
                # BUGFIX: was a bare ``except:`` which also swallowed
                # CancelledError/KeyboardInterrupt during shutdown.
                pass
        # Exit transports in reverse order of creation.
        for c in reversed(self.contexts):
            try:
                await c.__aexit__(None, None, None)
            except Exception:
                pass
        print("✅ 连接已安全关闭。")
async def run_task(user_input: str):
    """Run one gateway-assisted chat task end to end.

    Starts the gateway against three local SSE servers, optionally routes the
    request to a server-side prompt template via an LLM router call, then
    drives an OpenAI tool-calling loop for at most 5 rounds.  All errors are
    reported, and the gateway is always shut down in ``finally``.
    """
    config = {"mcpServers": {
        "math_server": {"url": "http://127.0.0.1:8029/sse"},
        "ticket_server": {"url": "http://127.0.0.1:8030/sse"},
        "data_server": {"url": "http://127.0.0.1:8031/sse"},
    }}
    client = MCPGatewayClient(config)
    await client.start()
    try:
        # Local import keeps credentials out of module import time.
        from config import OPENAI_BASE_URL, OPENAI_API_KEY, LLM_MODEL
        oai = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)

        system_instruction = "你是一个全能助手。"
        if client.prompts_metadata:
            # Router call: ask the model whether a server-side expert prompt
            # applies; it answers with a template name or NONE.
            router_prompt = f"我有以下专家角色库:{json.dumps(client.prompts_metadata, ensure_ascii=False)}。用户的问题涉及文档审阅、逻辑判断或数据分析吗?如果是,请返回对应的模板名称(name),否则返回 NONE。用户问题:{user_input}"
            resp = oai.chat.completions.create(model=LLM_MODEL, messages=[{"role": "user", "content": router_prompt}])
            decision = resp.choices[0].message.content.strip()
            if "NONE" not in decision:
                # First whitespace-separated token is taken as the template name.
                tokens = decision.split()
                prompt_name = tokens[0] if tokens else decision
                p_content = await client.get_prompt_template(prompt_name, arguments={"topic": user_input})
                if p_content:
                    system_instruction = p_content
                    print(f"🎯 路由命中: 加载【{prompt_name}】模板")

        messages = [{"role": "system", "content": system_instruction},
                    {"role": "user", "content": user_input}]
        print(f"🚀 任务开始 (system_instruction: {system_instruction})")

        for i in range(5):
            tools = await client.get_combined_tools()
            response = oai.chat.completions.create(model=LLM_MODEL, messages=messages, tools=tools)
            msg = response.choices[0].message
            # The SDK message object is accepted alongside plain dicts.
            messages.append(msg)

            if not msg.tool_calls:
                print(f"\n🏁 最终回答:\n{msg.content}")
                break

            print(f"\n🔄 第 {i+1} 轮:模型请求 {len(msg.tool_calls)} 个并发操作")
            tool_call_info = list(msg.tool_calls)
            call_tasks = [client.call_action(tc.function.name, tc.function.arguments)
                          for tc in tool_call_info]
            # Execute every requested tool call concurrently; call_action
            # never raises, so gather() cannot be torn down by one failure.
            results = await asyncio.gather(*call_tasks)

            for tc, res_text in zip(tool_call_info, results):
                print(f"{'-'*40}\n▶️ {tc.function.name}\n参数: {tc.function.arguments[:100]}...\n◀️ 结果: {str(res_text)[:120].replace(chr(10),' ')}\n{'-'*40}")
                # BUGFIX: coerce to str — the chat API requires string tool
                # content, and call_action may surface non-str payloads.
                messages.append({"role": "tool", "tool_call_id": tc.id, "content": str(res_text)})
        else:
            # BUGFIX: previously exhausting all rounds ended the task silently
            # with no final answer printed.
            print("\n⚠️ 已达到最大轮数,任务未在限定轮内完成。")
    except Exception as e:
        print(f"\n❌ 运行中发生错误: {str(e)}")
    finally:
        await client.stop()
if __name__ == "__main__":
    # Demo query: exercises resource reading (IT manual), a weather tool,
    # and conditional ticket booking in a single multi-step task.
    demo_query = "帮我审阅一下 IT 手册里的安全准则,并帮我查一下成都的天气,如果成都的天气晴朗,就帮小王订一张从今天晚上12点北京飞往成都的机票,如果北京的天气不晴朗,就算了。"
    asyncio.run(run_task(demo_query))