daren_project/test_websocket_stream.py
2025-04-29 10:22:57 +08:00

154 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket流式输出测试脚本
"""
import websocket
import json
import sys
import time
import uuid
from datetime import datetime
import threading
import ssl
import argparse
# 测试配置
WS_URL = "ws://127.0.0.1:8000/ws/chat/stream/" # WebSocket URL
TOKEN = "7831a86588bc08d025e4c9bd668de3b7940f7634" # 替换为你的实际认证令牌
# 测试数据
test_data = {
"question": "什么是流式输出?",
"conversation_id": str(uuid.uuid4()), # 随机生成一个会话ID
"dataset_id_list": ["8390ca43-6e63-4df9-b0b9-6cb20e1b38af"] # 替换为实际的知识库ID
}
# 全局变量
response_count = 0
start_time = None
full_content = ""
is_connected = False
def on_message(ws, message):
"""
处理接收到的WebSocket消息
"""
global response_count, full_content
try:
# 解析JSON响应
response_count += 1
data = json.loads(message)
# 获取当前时间和距开始的时间
current_time = time.time()
elapsed = current_time - start_time
timestamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
# 打印基本信息
print(f"\n[{timestamp}] 响应 #{response_count} (总计: {elapsed:.3f}s)")
# 检查消息类型
msg_type = data.get('message', '')
if msg_type == '开始流式传输':
print("=== 开始接收流式内容 ===")
elif msg_type == 'partial':
# 显示部分内容
if 'data' in data and 'content' in data['data']:
content = data['data']['content']
full_content += content
# 如果内容太长只显示前50个字符
display_content = content[:50] + "..." if len(content) > 50 else content
print(f"部分内容: {display_content}")
elif msg_type == '完成':
# 显示完整信息
if 'data' in data:
if 'title' in data['data']:
print(f"标题: {data['data']['title']}")
if 'content' in data['data']:
print(f"完整内容长度: {len(data['data']['content'])} 字符")
print("=== 流式传输完成 ===")
# 如果是错误消息
if data.get('code') == 500:
print(f"错误: {data.get('message')}")
ws.close()
except json.JSONDecodeError as e:
print(f"JSON解析错误: {str(e)}")
except Exception as e:
print(f"处理消息时出错: {str(e)}")
def on_error(ws, error):
"""处理WebSocket错误"""
print(f"发生错误: {str(error)}")
def on_close(ws, close_status_code, close_msg):
"""处理WebSocket连接关闭"""
global is_connected
is_connected = False
total_time = time.time() - start_time
print(f"\n===== 连接关闭 =====")
print(f"状态码: {close_status_code}, 消息: {close_msg}")
print(f"总响应时间: {total_time:.3f}秒, 共接收 {response_count} 个数据包")
print(f"接收到的完整内容长度: {len(full_content)} 字符")
def on_open(ws):
"""处理WebSocket连接成功"""
global start_time, is_connected
is_connected = True
print("WebSocket连接已建立")
print(f"发送测试数据: {json.dumps(test_data, ensure_ascii=False)}")
# 记录开始时间
start_time = time.time()
# 发送测试数据
ws.send(json.dumps(test_data))
print("数据已发送,等待响应...")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='WebSocket流式输出测试工具')
parser.add_argument('--url', type=str, default=WS_URL, help='WebSocket URL')
parser.add_argument('--token', type=str, default=TOKEN, help='认证令牌')
parser.add_argument('--question', type=str, default=test_data['question'], help='要发送的问题')
args = parser.parse_args()
# 更新测试数据
url = f"{args.url}?token={args.token}"
test_data['question'] = args.question
print(f"连接到: {url}")
# 设置更详细的日志级别(可选)
# websocket.enableTrace(True)
# 创建WebSocket连接
ws = websocket.WebSocketApp(
url,
on_open=on_open,
on_message=on_message,
on_error=on_error,
on_close=on_close
)
# 设置运行超时(可选)
# 如果需要SSL连接
# ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
# 启动WebSocket连接
ws.run_forever()
# 等待一小段时间以确保所有消息都被处理
time.sleep(1)
if __name__ == "__main__":
main()