#!/usr/bin/env python # -*- coding: utf-8 -*- """ WebSocket流式输出测试脚本 """ import websocket import json import sys import time import uuid from datetime import datetime import threading import ssl import argparse # 测试配置 WS_URL = "ws://127.0.0.1:8000/ws/chat/stream/" # WebSocket URL TOKEN = "7831a86588bc08d025e4c9bd668de3b7940f7634" # 替换为你的实际认证令牌 # 测试数据 test_data = { "question": "什么是流式输出?", "conversation_id": str(uuid.uuid4()), # 随机生成一个会话ID "dataset_id_list": ["8390ca43-6e63-4df9-b0b9-6cb20e1b38af"] # 替换为实际的知识库ID } # 全局变量 response_count = 0 start_time = None full_content = "" is_connected = False def on_message(ws, message): """ 处理接收到的WebSocket消息 """ global response_count, full_content try: # 解析JSON响应 response_count += 1 data = json.loads(message) # 获取当前时间和距开始的时间 current_time = time.time() elapsed = current_time - start_time timestamp = datetime.now().strftime('%H:%M:%S.%f')[:-3] # 打印基本信息 print(f"\n[{timestamp}] 响应 #{response_count} (总计: {elapsed:.3f}s)") # 检查消息类型 msg_type = data.get('message', '') if msg_type == '开始流式传输': print("=== 开始接收流式内容 ===") elif msg_type == 'partial': # 显示部分内容 if 'data' in data and 'content' in data['data']: content = data['data']['content'] full_content += content # 如果内容太长,只显示前50个字符 display_content = content[:50] + "..." if len(content) > 50 else content print(f"部分内容: {display_content}") elif msg_type == '完成': # 显示完整信息 if 'data' in data: if 'title' in data['data']: print(f"标题: {data['data']['title']}") if 'content' in data['data']: print(f"完整内容长度: {len(data['data']['content'])} 字符") print("=== 流式传输完成 ===") # 如果是错误消息 if data.get('code') == 500: print(f"错误: {data.get('message')}") ws.close() except json.JSONDecodeError as e: print(f"JSON解析错误: {str(e)}") except Exception as e: print(f"处理消息时出错: {str(e)}") def on_error(ws, error): """处理WebSocket错误""" print(f"发生错误: {str(error)}") def on_close(ws, close_status_code, close_msg): """处理WebSocket连接关闭""" global is_connected is_connected = False total_time = time.time() - start_time print(f"\n===== 连接关闭 =====") print(f"状态码: {close_status_code}, 消息: {close_msg}") print(f"总响应时间: {total_time:.3f}秒, 共接收 {response_count} 个数据包") print(f"接收到的完整内容长度: {len(full_content)} 字符") def on_open(ws): """处理WebSocket连接成功""" global start_time, is_connected is_connected = True print("WebSocket连接已建立") print(f"发送测试数据: {json.dumps(test_data, ensure_ascii=False)}") # 记录开始时间 start_time = time.time() # 发送测试数据 ws.send(json.dumps(test_data)) print("数据已发送,等待响应...") def main(): """主函数""" parser = argparse.ArgumentParser(description='WebSocket流式输出测试工具') parser.add_argument('--url', type=str, default=WS_URL, help='WebSocket URL') parser.add_argument('--token', type=str, default=TOKEN, help='认证令牌') parser.add_argument('--question', type=str, default=test_data['question'], help='要发送的问题') args = parser.parse_args() # 更新测试数据 url = f"{args.url}?token={args.token}" test_data['question'] = args.question print(f"连接到: {url}") # 设置更详细的日志级别(可选) # websocket.enableTrace(True) # 创建WebSocket连接 ws = websocket.WebSocketApp( url, on_open=on_open, on_message=on_message, on_error=on_error, on_close=on_close ) # 设置运行超时(可选) # 如果需要SSL连接 # ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) # 启动WebSocket连接 ws.run_forever() # 等待一小段时间以确保所有消息都被处理 time.sleep(1) if __name__ == "__main__": main()