daren_project/test_websocket_stream.py

154 lines
4.6 KiB
Python
Raw Normal View History

2025-04-29 10:22:57 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket流式输出测试脚本
"""
import websocket
import json
import sys
import time
import uuid
from datetime import datetime
import threading
import ssl
import argparse
# 测试配置
WS_URL = "ws://127.0.0.1:8000/ws/chat/stream/" # WebSocket URL
TOKEN = "7831a86588bc08d025e4c9bd668de3b7940f7634" # 替换为你的实际认证令牌
# 测试数据
test_data = {
"question": "什么是流式输出?",
"conversation_id": str(uuid.uuid4()), # 随机生成一个会话ID
"dataset_id_list": ["8390ca43-6e63-4df9-b0b9-6cb20e1b38af"] # 替换为实际的知识库ID
}
# 全局变量
response_count = 0
start_time = None
full_content = ""
is_connected = False
def on_message(ws, message):
"""
处理接收到的WebSocket消息
"""
global response_count, full_content
try:
# 解析JSON响应
response_count += 1
data = json.loads(message)
# 获取当前时间和距开始的时间
current_time = time.time()
elapsed = current_time - start_time
timestamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
# 打印基本信息
print(f"\n[{timestamp}] 响应 #{response_count} (总计: {elapsed:.3f}s)")
# 检查消息类型
msg_type = data.get('message', '')
if msg_type == '开始流式传输':
print("=== 开始接收流式内容 ===")
elif msg_type == 'partial':
# 显示部分内容
if 'data' in data and 'content' in data['data']:
content = data['data']['content']
full_content += content
# 如果内容太长只显示前50个字符
display_content = content[:50] + "..." if len(content) > 50 else content
print(f"部分内容: {display_content}")
elif msg_type == '完成':
# 显示完整信息
if 'data' in data:
if 'title' in data['data']:
print(f"标题: {data['data']['title']}")
if 'content' in data['data']:
print(f"完整内容长度: {len(data['data']['content'])} 字符")
print("=== 流式传输完成 ===")
# 如果是错误消息
if data.get('code') == 500:
print(f"错误: {data.get('message')}")
ws.close()
except json.JSONDecodeError as e:
print(f"JSON解析错误: {str(e)}")
except Exception as e:
print(f"处理消息时出错: {str(e)}")
def on_error(ws, error):
"""处理WebSocket错误"""
print(f"发生错误: {str(error)}")
def on_close(ws, close_status_code, close_msg):
"""处理WebSocket连接关闭"""
global is_connected
is_connected = False
total_time = time.time() - start_time
print(f"\n===== 连接关闭 =====")
print(f"状态码: {close_status_code}, 消息: {close_msg}")
print(f"总响应时间: {total_time:.3f}秒, 共接收 {response_count} 个数据包")
print(f"接收到的完整内容长度: {len(full_content)} 字符")
def on_open(ws):
"""处理WebSocket连接成功"""
global start_time, is_connected
is_connected = True
print("WebSocket连接已建立")
print(f"发送测试数据: {json.dumps(test_data, ensure_ascii=False)}")
# 记录开始时间
start_time = time.time()
# 发送测试数据
ws.send(json.dumps(test_data))
print("数据已发送,等待响应...")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='WebSocket流式输出测试工具')
parser.add_argument('--url', type=str, default=WS_URL, help='WebSocket URL')
parser.add_argument('--token', type=str, default=TOKEN, help='认证令牌')
parser.add_argument('--question', type=str, default=test_data['question'], help='要发送的问题')
args = parser.parse_args()
# 更新测试数据
url = f"{args.url}?token={args.token}"
test_data['question'] = args.question
print(f"连接到: {url}")
# 设置更详细的日志级别(可选)
# websocket.enableTrace(True)
# 创建WebSocket连接
ws = websocket.WebSocketApp(
url,
on_open=on_open,
on_message=on_message,
on_error=on_error,
on_close=on_close
)
# 设置运行超时(可选)
# 如果需要SSL连接
# ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
# 启动WebSocket连接
ws.run_forever()
# 等待一小段时间以确保所有消息都被处理
time.sleep(1)
if __name__ == "__main__":
main()