154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
WebSocket流式输出测试脚本
|
||
"""
|
||
import websocket
|
||
import json
|
||
import sys
|
||
import time
|
||
import uuid
|
||
from datetime import datetime
|
||
import threading
|
||
import ssl
|
||
import argparse
|
||
|
||
# 测试配置
|
||
WS_URL = "ws://127.0.0.1:8000/ws/chat/stream/" # WebSocket URL
|
||
TOKEN = "7831a86588bc08d025e4c9bd668de3b7940f7634" # 替换为你的实际认证令牌
|
||
|
||
# 测试数据
|
||
test_data = {
|
||
"question": "什么是流式输出?",
|
||
"conversation_id": str(uuid.uuid4()), # 随机生成一个会话ID
|
||
"dataset_id_list": ["8390ca43-6e63-4df9-b0b9-6cb20e1b38af"] # 替换为实际的知识库ID
|
||
}
|
||
|
||
# 全局变量
|
||
response_count = 0
|
||
start_time = None
|
||
full_content = ""
|
||
is_connected = False
|
||
|
||
|
||
def on_message(ws, message):
|
||
"""
|
||
处理接收到的WebSocket消息
|
||
"""
|
||
global response_count, full_content
|
||
|
||
try:
|
||
# 解析JSON响应
|
||
response_count += 1
|
||
data = json.loads(message)
|
||
|
||
# 获取当前时间和距开始的时间
|
||
current_time = time.time()
|
||
elapsed = current_time - start_time
|
||
timestamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
|
||
|
||
# 打印基本信息
|
||
print(f"\n[{timestamp}] 响应 #{response_count} (总计: {elapsed:.3f}s)")
|
||
|
||
# 检查消息类型
|
||
msg_type = data.get('message', '')
|
||
if msg_type == '开始流式传输':
|
||
print("=== 开始接收流式内容 ===")
|
||
elif msg_type == 'partial':
|
||
# 显示部分内容
|
||
if 'data' in data and 'content' in data['data']:
|
||
content = data['data']['content']
|
||
full_content += content
|
||
# 如果内容太长,只显示前50个字符
|
||
display_content = content[:50] + "..." if len(content) > 50 else content
|
||
print(f"部分内容: {display_content}")
|
||
elif msg_type == '完成':
|
||
# 显示完整信息
|
||
if 'data' in data:
|
||
if 'title' in data['data']:
|
||
print(f"标题: {data['data']['title']}")
|
||
if 'content' in data['data']:
|
||
print(f"完整内容长度: {len(data['data']['content'])} 字符")
|
||
print("=== 流式传输完成 ===")
|
||
|
||
# 如果是错误消息
|
||
if data.get('code') == 500:
|
||
print(f"错误: {data.get('message')}")
|
||
ws.close()
|
||
|
||
except json.JSONDecodeError as e:
|
||
print(f"JSON解析错误: {str(e)}")
|
||
except Exception as e:
|
||
print(f"处理消息时出错: {str(e)}")
|
||
|
||
|
||
def on_error(ws, error):
|
||
"""处理WebSocket错误"""
|
||
print(f"发生错误: {str(error)}")
|
||
|
||
|
||
def on_close(ws, close_status_code, close_msg):
|
||
"""处理WebSocket连接关闭"""
|
||
global is_connected
|
||
is_connected = False
|
||
total_time = time.time() - start_time
|
||
print(f"\n===== 连接关闭 =====")
|
||
print(f"状态码: {close_status_code}, 消息: {close_msg}")
|
||
print(f"总响应时间: {total_time:.3f}秒, 共接收 {response_count} 个数据包")
|
||
print(f"接收到的完整内容长度: {len(full_content)} 字符")
|
||
|
||
|
||
def on_open(ws):
|
||
"""处理WebSocket连接成功"""
|
||
global start_time, is_connected
|
||
is_connected = True
|
||
print("WebSocket连接已建立")
|
||
print(f"发送测试数据: {json.dumps(test_data, ensure_ascii=False)}")
|
||
|
||
# 记录开始时间
|
||
start_time = time.time()
|
||
|
||
# 发送测试数据
|
||
ws.send(json.dumps(test_data))
|
||
print("数据已发送,等待响应...")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
parser = argparse.ArgumentParser(description='WebSocket流式输出测试工具')
|
||
parser.add_argument('--url', type=str, default=WS_URL, help='WebSocket URL')
|
||
parser.add_argument('--token', type=str, default=TOKEN, help='认证令牌')
|
||
parser.add_argument('--question', type=str, default=test_data['question'], help='要发送的问题')
|
||
args = parser.parse_args()
|
||
|
||
# 更新测试数据
|
||
url = f"{args.url}?token={args.token}"
|
||
test_data['question'] = args.question
|
||
|
||
print(f"连接到: {url}")
|
||
|
||
# 设置更详细的日志级别(可选)
|
||
# websocket.enableTrace(True)
|
||
|
||
# 创建WebSocket连接
|
||
ws = websocket.WebSocketApp(
|
||
url,
|
||
on_open=on_open,
|
||
on_message=on_message,
|
||
on_error=on_error,
|
||
on_close=on_close
|
||
)
|
||
|
||
# 设置运行超时(可选)
|
||
# 如果需要SSL连接
|
||
# ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||
|
||
# 启动WebSocket连接
|
||
ws.run_forever()
|
||
|
||
# 等待一小段时间以确保所有消息都被处理
|
||
time.sleep(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |