162 lines
7.0 KiB
Python
162 lines
7.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""流式响应测试脚本"""
|
|
|
|
import argparse
import codecs
import json
import logging.config
import sys
import time
from datetime import datetime

import requests

from test_config import API_BASE_URL, AUTH_TOKEN, TEST_CASES, LOG_CONFIG, HEADERS
|
|
|
|
# Logging: apply the shared dict-based configuration, then fetch the
# named logger used throughout this script.
logging.config.dictConfig(config=LOG_CONFIG)
logger = logging.getLogger(name='stream_test')
|
|
|
|
class StreamResponseTester:
    """Run streaming chat test cases against ``{base_url}/api/chat-history/``.

    The response body is parsed as a sequence of events separated by a blank
    line; events starting with ``"data: "`` carry a JSON object whose
    ``data`` member may hold ``content`` (a text fragment) and ``is_end``
    (end-of-stream flag).  This is the format the parser below expects —
    confirm against the server implementation.
    """

    # Blank line that terminates one event in the stream body.
    _EVENT_SEPARATOR = '\n\n'
    # Prefix marking an event whose remainder is a JSON payload.
    _DATA_PREFIX = 'data: '

    def __init__(self, base_url, headers, timeout=None):
        """Store connection settings.

        Args:
            base_url: API root URL; ``/api/chat-history/`` is appended to it.
            headers: HTTP headers sent with every request.
            timeout: optional ``requests`` timeout (seconds, or a
                ``(connect, read)`` tuple).  ``None`` — the default — waits
                indefinitely.
        """
        self.base_url = base_url
        self.headers = headers
        self.timeout = timeout

    def run_test(self, test_case):
        """Run one test case and return whether it passed.

        Args:
            test_case: dict with required keys ``name``, ``question``,
                ``conversation_id`` and ``dataset_id``; optional
                ``expected_error`` (a non-200 status or an unparsable event
                is then tolerated) and ``expected_response_time`` (seconds;
                exceeding it only logs a warning).

        Returns:
            True if an ``is_end`` event arrived, if the stream closed without
            one (a warning is logged), or if an expected error occurred.
            False on an unexpected non-200 status, an unexpected JSON
            decoding error, or any exception raised while requesting or
            reading the stream.

        Raises:
            KeyError: if a required key is missing from ``test_case``.
        """
        logger.info(f"\n开始测试: {test_case['name']}")

        # Built outside the try block so a malformed test case (missing key)
        # raises instead of being reported as an ordinary test failure.
        test_data = self._build_payload(test_case)

        try:
            url = f"{self.base_url}/api/chat-history/"
            logger.info(f"请求URL: {url}")
            logger.info(f"请求数据: {json.dumps(test_data, ensure_ascii=False)}")

            start_time = time.perf_counter()
            # ``with`` closes the streamed connection on every exit path,
            # including the early returns below.
            with requests.post(
                url=url,
                json=test_data,
                headers=self.headers,
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    if test_case.get('expected_error', False):
                        logger.info("测试通过:预期的错误响应")
                        return True
                    logger.error(f"请求失败: {response.status_code}, {response.text}")
                    return False
                return self._consume_stream(response, test_case, start_time)

        except Exception as e:
            # Boundary for one test case: log with traceback and count it as
            # a failure so the remaining cases still run.
            logger.exception(f"测试执行错误: {e}")
            return False

    @staticmethod
    def _build_payload(test_case):
        """Build the JSON request body for ``test_case``."""
        return {
            "question": test_case["question"],
            "conversation_id": test_case["conversation_id"],
            "dataset_id_list": [test_case["dataset_id"]],
            "stream": True,
        }

    def _consume_stream(self, response, test_case, start_time):
        """Read events from ``response`` and return the test verdict.

        Returns True once an event with ``is_end`` set arrives (after logging
        a summary); False on a JSON decoding error unless the case sets
        ``expected_error``.  If the stream closes without ``is_end`` a
        warning is logged and True is returned.
        """
        expected_error = test_case.get('expected_error', False)
        response_count = 0
        content_length = 0
        content_parts = []
        last_time = start_time

        # chunk_size=1 hands bytes over as soon as they arrive, keeping the
        # per-event interval measurement accurate; _iter_events reassembles
        # multi-byte characters that straddle chunk boundaries.
        for event in self._iter_events(response.iter_content(chunk_size=1)):
            now = time.perf_counter()
            interval, last_time = now - last_time, now

            if not event.startswith(self._DATA_PREFIX):
                continue
            response_count += 1

            try:
                data = json.loads(event[len(self._DATA_PREFIX):])
            except json.JSONDecodeError as e:
                logger.error(f"JSON解析错误: {e}")
                if not expected_error:
                    return False
                continue

            payload = data.get('data') if isinstance(data, dict) else None
            if not isinstance(payload, dict):
                continue

            if 'content' in payload:
                content = payload['content'] or ''
                content_length += len(content)
                content_parts.append(content)
                logger.debug(
                    f"响应 #{response_count}: "
                    f"+{len(content)} 字符, "
                    f"间隔: {interval:.3f}s"
                )

            # Checked independently of ``content`` so a terminal event that
            # carries no text still ends the test.
            if payload.get('is_end', False):
                total_time = time.perf_counter() - start_time
                self._report(
                    test_case,
                    total_time,
                    response_count,
                    content_length,
                    ''.join(content_parts),
                )
                return True

        logger.warning("流式响应在收到 is_end 之前结束")
        return True

    @classmethod
    def _iter_events(cls, chunks):
        """Yield the text of each event found in an iterable of byte chunks.

        Decoding uses an incremental UTF-8 decoder, so a multi-byte character
        split across chunks is buffered until complete instead of failing on
        each partial byte.  Invalid byte sequences become U+FFFD.  A trailing
        event without a terminating blank line is yielded once the chunks
        are exhausted.
        """
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        buffer = ''
        for chunk in chunks:
            if not chunk:
                continue
            buffer += decoder.decode(chunk)
            if cls._EVENT_SEPARATOR in buffer:
                *complete, buffer = buffer.split(cls._EVENT_SEPARATOR)
                yield from complete
        buffer += decoder.decode(b'', final=True)
        if buffer.strip():
            yield buffer

    @staticmethod
    def _report(test_case, total_time, response_count, content_length, full_content):
        """Log the summary of a completed stream and check the time budget."""
        logger.info("\n测试结果:")
        logger.info(f"总响应时间: {total_time:.3f}秒")
        logger.info(f"数据包数量: {response_count}")
        logger.info(f"内容长度: {content_length} 字符")
        logger.info(f"完整内容: {full_content}")

        expected_time = test_case.get('expected_response_time')
        if expected_time is None:
            return
        if total_time <= expected_time:
            logger.info("响应时间符合预期")
        else:
            # Exceeding the budget is only a warning; the test still passes.
            logger.warning(
                f"响应时间超出预期: "
                f"{total_time:.3f}s > {expected_time}s"
            )
|
|
|
|
def main(argv=None):
    """Command-line entry point: run one or all configured test cases.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        Process exit status: 0 when every selected test case passed,
        1 when at least one failed, 2 when ``--test-case`` is out of range.
    """
    parser = argparse.ArgumentParser(description='流式响应测试工具')
    parser.add_argument('--test-case', type=int, help='指定要运行的测试用例索引(从0开始)')
    args = parser.parse_args(argv)

    tester = StreamResponseTester(API_BASE_URL, HEADERS)

    if args.test_case is not None:
        # Run a single, explicitly selected test case (0-based index).
        if not 0 <= args.test_case < len(TEST_CASES):
            logger.error(f"无效的测试用例索引: {args.test_case}")
            return 2
        success = tester.run_test(TEST_CASES[args.test_case])
        logger.info(f"\n测试用例 {args.test_case} {'通过' if success else '失败'}")
        return 0 if success else 1

    # Run every test case; progress is reported 1-based.
    total_cases = len(TEST_CASES)
    passed_cases = 0
    for position, test_case in enumerate(TEST_CASES, start=1):
        logger.info(f"\n运行测试用例 {position}/{total_cases}")
        if tester.run_test(test_case):
            passed_cases += 1

    logger.info(f"\n测试完成: {passed_cases}/{total_cases} 个测试用例通过")
    return 0 if passed_cases == total_cases else 1


if __name__ == '__main__':
    # Propagate the result so shells/CI see a non-zero status on failure.
    sys.exit(main())
|