From 61eaec4d64257557fea8d0002d07a619ec164528 Mon Sep 17 00:00:00 2001 From: dspwasc Date: Tue, 29 Apr 2025 10:21:13 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- role_based_system/asgi.py | 9 +- role_based_system/settings.py | 17 +- user_management/consumers.py | 835 +++++++++++++++++- .../management/commands/create_test_users.py | 189 ++-- user_management/middleware.py | 35 +- .../migrations/0005_chathistory_title.py | 18 + ...006_knowledgebasedocument_uploader_name.py | 18 + user_management/models.py | 6 +- user_management/routing.py | 2 + user_management/views.py | 686 ++++++++++---- 10 files changed, 1524 insertions(+), 291 deletions(-) create mode 100644 user_management/migrations/0005_chathistory_title.py create mode 100644 user_management/migrations/0006_knowledgebasedocument_uploader_name.py diff --git a/role_based_system/asgi.py b/role_based_system/asgi.py index 714024b9..d79a8e23 100644 --- a/role_based_system/asgi.py +++ b/role_based_system/asgi.py @@ -18,11 +18,16 @@ django.setup() # 添加这行来初始化 Django from django.core.asgi import get_asgi_application from channels.routing import ProtocolTypeRouter, URLRouter from channels.auth import AuthMiddlewareStack +from channels.security.websocket import AllowedHostsOriginValidator from user_management.routing import websocket_urlpatterns +from user_management.middleware import TokenAuthMiddleware +# 使用TokenAuthMiddleware代替AuthMiddlewareStack application = ProtocolTypeRouter({ "http": get_asgi_application(), - "websocket": AuthMiddlewareStack( - URLRouter(websocket_urlpatterns) + "websocket": AllowedHostsOriginValidator( + TokenAuthMiddleware( + URLRouter(websocket_urlpatterns) + ) ), }) \ No newline at end of file diff --git a/role_based_system/settings.py b/role_based_system/settings.py index 98ff90f4..f18d1ee2 100644 --- a/role_based_system/settings.py +++ b/role_based_system/settings.py @@ -41,6 +41,9 @@ ALLOWED_HOSTS = ['*'] # 仅在开发环境使用 # 服务器配置 DEBUG = False +# 是否允许注册新用户 +ALLOW_REGISTRATION = True + # ALLOWED_HOSTS = ['frptx.chiyong.fun', 'localhost', '127.0.0.1'] # Application definition @@ -70,6 +73,7 @@ MIDDLEWARE = [ 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'user_management.middleware.UserActivityMiddleware', + 'user_management.middleware.CSRFExemptMiddleware', # 添加CSRF豁免中间件 ] ROOT_URLCONF = 'role_based_system.urls' @@ -168,7 +172,12 @@ ASGI_APPLICATION = "role_based_system.asgi.application" # Channel Layers 配置 CHANNEL_LAYERS = { "default": { - "BACKEND": "channels.layers.InMemoryChannelLayer", + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": [("127.0.0.1", 6379)], + "capacity": 1500, # 默认100 + "expiry": 60, # 默认60秒 + }, }, } @@ -289,3 +298,9 @@ REST_FRAMEWORK = { 'rest_framework.parsers.MultiPartParser' ], } + + +# DeepSeek API配置 +DEEPSEEK_API_KEY = "sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf" # 请替换为您的实际有效的DeepSeek API密钥 + +SILICON_CLOUD_API_KEY = 'sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf' diff --git a/user_management/consumers.py b/user_management/consumers.py index f37ace2d..e0af37e2 100644 --- a/user_management/consumers.py +++ b/user_management/consumers.py @@ -4,6 +4,13 @@ from channels.db import database_sync_to_async from channels.exceptions import StopConsumer import logging from rest_framework.authtoken.models import Token +from urllib.parse import parse_qs +from .models import ChatHistory, KnowledgeBase +import aiohttp +import asyncio +from django.conf import settings +import uuid +import traceback logger = logging.getLogger(__name__) @@ -11,19 +18,20 @@ class NotificationConsumer(AsyncWebsocketConsumer): async def connect(self): """建立WebSocket连接""" try: - # 获取token - headers = dict(self.scope['headers']) - auth_header = headers.get(b'authorization', b'').decode() + # 从URL参数中获取token + query_string = self.scope.get('query_string', b'').decode() + query_params = parse_qs(query_string) + token_key = query_params.get('token', [''])[0] - if not auth_header.startswith('Token '): + if not token_key: + logger.warning("WebSocket连接尝试,但没有提供token") await self.close() return - token_key = auth_header.split(' ')[1] - # 验证token self.user = await self.get_user_from_token(token_key) if not self.user: + logger.warning(f"WebSocket连接尝试,但token无效: {token_key}") await self.close() return @@ -34,8 +42,10 @@ class NotificationConsumer(AsyncWebsocketConsumer): self.channel_name ) await self.accept() + logger.info(f"用户 {self.user.username} WebSocket连接成功") except Exception as e: + logger.error(f"WebSocket连接错误: {str(e)}") await self.close() @database_sync_to_async @@ -65,3 +75,816 @@ class NotificationConsumer(AsyncWebsocketConsumer): logger.info(f"已发送通知给用户 {self.user.username}") except Exception as e: logger.error(f"发送通知消息时发生错误: {str(e)}") + +class ChatConsumer(AsyncWebsocketConsumer): + async def connect(self): + """建立 WebSocket 连接""" + try: + # 从URL参数中获取token + query_string = self.scope.get('query_string', b'').decode() + query_params = parse_qs(query_string) + token_key = query_params.get('token', [''])[0] + + if not token_key: + logger.warning("WebSocket连接尝试,但没有提供token") + await self.close() + return + + # 验证token + self.user = await self.get_user_from_token(token_key) + if not self.user: + logger.warning(f"WebSocket连接尝试,但token无效: {token_key}") + await self.close() + return + + # 将用户信息存储在scope中 + self.scope["user"] = self.user + await self.accept() + logger.info(f"用户 {self.user.username} WebSocket连接成功") + + except Exception as e: + logger.error(f"WebSocket连接错误: {str(e)}") + await self.close() + + @database_sync_to_async + def get_user_from_token(self, token_key): + try: + token = Token.objects.select_related('user').get(key=token_key) + return token.user + except Token.DoesNotExist: + return None + + async def disconnect(self, close_code): + """关闭 WebSocket 连接""" + pass + + async def receive(self, text_data): + """接收消息并处理""" + try: + data = json.loads(text_data) + + # 验证必要字段 + if 'question' not in data or 'conversation_id' not in data: + await self.send_error("缺少必要字段") + return + + # 创建问题记录 + question_record = await self.create_question_record(data) + if not question_record: + return + + # 开始流式处理 + await self.stream_answer(question_record, data) + + except Exception as e: + logger.error(f"处理消息时出错: {str(e)}") + await self.send_error(f"处理消息时出错: {str(e)}") + + @database_sync_to_async + def _create_question_record_sync(self, data): + """同步创建问题记录""" + try: + # 获取会话历史记录 + conversation_id = data['conversation_id'] + existing_records = ChatHistory.objects.filter( + conversation_id=conversation_id + ).order_by('created_at') + + # 获取或创建元数据 + if existing_records.exists(): + first_record = existing_records.first() + metadata = first_record.metadata or {} + dataset_ids = metadata.get('dataset_id_list', []) + knowledge_bases = [] + + # 验证知识库权限 + for kb_id in dataset_ids: + try: + kb = KnowledgeBase.objects.get(id=kb_id) + if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'): + raise Exception(f'无权访问知识库: {kb.name}') + knowledge_bases.append(kb) + except KnowledgeBase.DoesNotExist: + raise Exception(f'知识库不存在: {kb_id}') + else: + # 新会话处理 + dataset_ids = data.get('dataset_id_list', []) + if not dataset_ids: + raise Exception('新会话需要提供知识库ID') + + knowledge_bases = [] + for kb_id in dataset_ids: + kb = KnowledgeBase.objects.get(id=kb_id) + if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'): + raise Exception(f'无权访问知识库: {kb.name}') + knowledge_bases.append(kb) + + metadata = { + 'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'), + 'dataset_id_list': [str(kb.id) for kb in knowledge_bases], + 'dataset_external_id_list': [str(kb.external_id) for kb in knowledge_bases if kb.external_id], + 'dataset_names': [kb.name for kb in knowledge_bases] + } + + # 创建问题记录 + return ChatHistory.objects.create( + user=self.scope["user"], + knowledge_base=knowledge_bases[0], + conversation_id=conversation_id, + title=data.get('title', 'New chat'), + role='user', + content=data['question'], + metadata=metadata + ) + + except Exception as e: + logger.error(f"创建问题记录失败: {str(e)}") + return None, str(e) + + async def create_question_record(self, data): + """异步创建问题记录""" + try: + result = await self._create_question_record_sync(data) + if isinstance(result, tuple): + _, error_message = result + await self.send_error(error_message) + return None + return result + except Exception as e: + await self.send_error(str(e)) + return None + + def check_knowledge_base_permission(self, kb, user, permission_type): + """检查知识库权限""" + # 实现权限检查逻辑 + return True # 临时返回 True,需要根据实际情况实现 + + async def stream_answer(self, question_record, data): + """流式处理回答""" + try: + # 创建 AI 回答记录 + answer_record = await database_sync_to_async(ChatHistory.objects.create)( + user=self.scope["user"], + knowledge_base=question_record.knowledge_base, + conversation_id=str(question_record.conversation_id), + title=question_record.title, + parent_id=str(question_record.id), + role='assistant', + content="", + metadata=question_record.metadata + ) + + # 发送初始响应 + await self.send_json({ + 'code': 200, + 'message': '开始流式传输', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(question_record.conversation_id), + 'content': '', + 'is_end': False + } + }) + + # 调用外部 API 获取流式响应 + async with aiohttp.ClientSession() as session: + # 创建聊天会话 + chat_response = await session.post( + f"{settings.API_BASE_URL}/api/application/chat/open", + json={ + "id": "d5d11efa-ea9a-11ef-9933-0242ac120006", + "model_id": question_record.metadata.get('model_id'), + "dataset_id_list": question_record.metadata.get('dataset_external_id_list', []), + "multiple_rounds_dialogue": False, + "dataset_setting": { + "top_n": 10, + "similarity": "0.3", + "max_paragraph_char_number": 10000, + "search_mode": "blend", + "no_references_setting": { + "value": "{question}", + "status": "ai_questioning" + } + }, + "model_setting": { + "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}" + }, + "problem_optimization": False + } + ) + + chat_data = await chat_response.json() + if chat_data.get('code') != 200: + raise Exception(f"创建聊天会话失败: {chat_data}") + + chat_id = chat_data['data'] + + # 建立流式连接 + async with session.post( + f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}", + json={"message": data['question'], "re_chat": False, "stream": True}, + headers={"Content-Type": "application/json"} + ) as response: + full_content = "" + buffer = "" + + async for chunk in response.content.iter_any(): + chunk_str = chunk.decode('utf-8') + buffer += chunk_str + + while '\n\n' in buffer: + parts = buffer.split('\n\n', 1) + line = parts[0] + buffer = parts[1] + + if line.startswith('data: '): + try: + json_str = line[6:] + chunk_data = json.loads(json_str) + + if 'content' in chunk_data: + content_part = chunk_data['content'] + full_content += content_part + + await self.send_json({ + 'code': 200, + 'message': 'partial', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(question_record.conversation_id), + 'content': content_part, + 'is_end': chunk_data.get('is_end', False) + } + }) + + if chunk_data.get('is_end', False): + # 保存完整内容 + answer_record.content = full_content.strip() + await database_sync_to_async(answer_record.save)() + + # 生成或获取标题 + title = await self.get_or_generate_title( + question_record.conversation_id, + data['question'], + full_content.strip() + ) + + # 发送最终响应 + await self.send_json({ + 'code': 200, + 'message': '完成', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(question_record.conversation_id), + 'title': title, + 'dataset_id_list': question_record.metadata.get('dataset_id_list', []), + 'dataset_names': question_record.metadata.get('dataset_names', []), + 'role': 'assistant', + 'content': full_content.strip(), + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'is_end': True + } + }) + return + + except json.JSONDecodeError as e: + logger.error(f"JSON解析错误: {e}, 数据: {line}") + continue + + except Exception as e: + logger.error(f"流式处理出错: {str(e)}") + await self.send_error(str(e)) + + # 保存已收集的内容 + if 'full_content' in locals() and full_content: + try: + answer_record.content = full_content.strip() + await database_sync_to_async(answer_record.save)() + except Exception as save_error: + logger.error(f"保存部分内容失败: {str(save_error)}") + + @database_sync_to_async + def get_or_generate_title(self, conversation_id, question, answer): + """获取或生成对话标题""" + try: + # 先检查是否已有标题 + current_title = ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).exclude( + title__in=["New chat", "新对话", ""] + ).values_list('title', flat=True).first() + + if current_title: + return current_title + + # 如果没有标题,生成新标题 + # 这里需要实现标题生成的逻辑 + generated_title = "新对话" # 临时使用默认标题 + + # 更新所有相关记录的标题 + ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).update(title=generated_title) + + return generated_title + + except Exception as e: + logger.error(f"获取或生成标题失败: {str(e)}") + return "新对话" + + async def send_json(self, content): + """发送 JSON 格式的消息""" + await self.send(text_data=json.dumps(content)) + + async def send_error(self, message): + """发送错误消息""" + await self.send_json({ + 'code': 500, + 'message': message, + 'data': {'is_end': True} + }) + +class ChatStreamConsumer(AsyncWebsocketConsumer): + async def connect(self): + """建立WebSocket连接""" + try: + # 从URL参数中获取token + query_string = self.scope.get('query_string', b'').decode() + query_params = parse_qs(query_string) + token_key = query_params.get('token', [''])[0] + + if not token_key: + logger.warning("WebSocket连接尝试,但没有提供token") + await self.close() + return + + # 验证token + self.user = await self.get_user_from_token(token_key) + if not self.user: + logger.warning(f"WebSocket连接尝试,但token无效: {token_key}") + await self.close() + return + + # 将用户信息存储在scope中 + self.scope["user"] = self.user + await self.accept() + logger.info(f"用户 {self.user.username} 流式输出WebSocket连接成功") + + except Exception as e: + logger.error(f"WebSocket连接错误: {str(e)}") + await self.close() + + @database_sync_to_async + def get_user_from_token(self, token_key): + try: + token = Token.objects.select_related('user').get(key=token_key) + return token.user + except Token.DoesNotExist: + return None + + async def disconnect(self, close_code): + """关闭WebSocket连接""" + logger.info(f"用户 {self.user.username if hasattr(self, 'user') else 'unknown'} WebSocket连接断开,代码: {close_code}") + + async def receive(self, text_data): + """接收消息并处理""" + try: + data = json.loads(text_data) + + # 检查必填字段 + if 'question' not in data: + await self.send_error("缺少必填字段: question") + return + + if 'conversation_id' not in data: + await self.send_error("缺少必填字段: conversation_id") + return + + # 处理新会话或现有会话 + await self.process_chat_request(data) + + except Exception as e: + logger.error(f"处理消息时出错: {str(e)}") + logger.error(traceback.format_exc()) + await self.send_error(f"处理消息时出错: {str(e)}") + + async def process_chat_request(self, data): + """处理聊天请求""" + try: + conversation_id = data['conversation_id'] + question = data['question'] + + # 获取会话信息和知识库 + session_info = await self.get_session_info(data) + if not session_info: + return + + knowledge_bases, metadata, dataset_external_id_list = session_info + + # 创建问题记录 + question_record = await self.create_question_record( + conversation_id, + question, + knowledge_bases, + metadata + ) + + if not question_record: + return + + # 创建AI回答记录 + answer_record = await self.create_answer_record( + conversation_id, + question_record, + knowledge_bases, + metadata + ) + + # 发送初始响应 + await self.send_json({ + 'code': 200, + 'message': '开始流式传输', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'content': '', + 'is_end': False + } + }) + + # 调用外部API获取流式响应 + await self.stream_from_external_api( + conversation_id, + question, + dataset_external_id_list, + answer_record, + metadata, + knowledge_bases + ) + + except Exception as e: + logger.error(f"处理聊天请求时出错: {str(e)}") + logger.error(traceback.format_exc()) + await self.send_error(f"处理聊天请求时出错: {str(e)}") + + @database_sync_to_async + def get_session_info(self, data): + """获取会话信息和知识库""" + try: + conversation_id = data['conversation_id'] + + # 查找该会话ID下的历史记录 + existing_records = ChatHistory.objects.filter( + conversation_id=conversation_id + ).order_by('created_at') + + # 如果有历史记录,使用第一条记录的metadata + if existing_records.exists(): + first_record = existing_records.first() + metadata = first_record.metadata or {} + + # 获取知识库信息 + dataset_ids = metadata.get('dataset_id_list', []) + external_id_list = metadata.get('dataset_external_id_list', []) + + if not dataset_ids: + logger.error('找不到会话关联的知识库信息') + return None + + # 验证知识库是否存在且用户有权限 + knowledge_bases = [] + for kb_id in dataset_ids: + try: + kb = KnowledgeBase.objects.get(id=kb_id) + if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'): + logger.error(f'无权访问知识库: {kb.name}') + return None + knowledge_bases.append(kb) + except KnowledgeBase.DoesNotExist: + logger.error(f'知识库不存在: {kb_id}') + return None + + if not external_id_list or not knowledge_bases: + logger.error('会话关联的知识库信息不完整') + return None + + return knowledge_bases, metadata, external_id_list + + else: + # 如果是新会话的第一条记录,需要提供知识库ID + dataset_ids = [] + if 'dataset_id' in data: + dataset_ids.append(str(data['dataset_id'])) + elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)): + if isinstance(data['dataset_id_list'], str): + try: + dataset_list = json.loads(data['dataset_id_list']) + if isinstance(dataset_list, list): + dataset_ids = [str(id) for id in dataset_list] + else: + dataset_ids = [str(data['dataset_id_list'])] + except json.JSONDecodeError: + dataset_ids = [str(data['dataset_id_list'])] + else: + dataset_ids = [str(id) for id in data['dataset_id_list']] + + if not dataset_ids: + logger.error('新会话需要提供知识库ID') + return None + + # 验证所有知识库并收集external_ids + external_id_list = [] + knowledge_bases = [] + + for kb_id in dataset_ids: + try: + knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first() + if not knowledge_base: + logger.error(f'知识库不存在: {kb_id}') + return None + + knowledge_bases.append(knowledge_base) + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(knowledge_base, self.scope["user"], 'read'): + logger.error(f'无权访问知识库: {knowledge_base.name}') + return None + + # 添加知识库的external_id到列表 + if knowledge_base.external_id: + external_id_list.append(str(knowledge_base.external_id)) + else: + logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id") + + except Exception as e: + logger.error(f"处理知识库ID出错: {str(e)}") + return None + + if not external_id_list: + logger.error('没有有效的知识库external_id') + return None + + # 创建metadata + metadata = { + 'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'), + 'dataset_id_list': [str(id) for id in dataset_ids], + 'dataset_external_id_list': [str(id) for id in external_id_list], + 'dataset_names': [kb.name for kb in knowledge_bases] + } + + return knowledge_bases, metadata, external_id_list + + except Exception as e: + logger.error(f"获取会话信息时出错: {str(e)}") + return None + + def check_knowledge_base_permission(self, kb, user, permission_type): + """检查知识库权限""" + # 实现权限检查逻辑 + return True # 临时返回 True,需要根据实际情况实现 + + @database_sync_to_async + def create_question_record(self, conversation_id, question, knowledge_bases, metadata): + """创建问题记录""" + try: + title = metadata.get('title', 'New chat') + + # 创建用户问题记录 + return ChatHistory.objects.create( + user=self.scope["user"], + knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库 + conversation_id=str(conversation_id), + title=title, + role='user', + content=question, + metadata=metadata + ) + except Exception as e: + logger.error(f"创建问题记录时出错: {str(e)}") + return None + + @database_sync_to_async + def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata): + """创建AI回答记录""" + try: + return ChatHistory.objects.create( + user=self.scope["user"], + knowledge_base=knowledge_bases[0], + conversation_id=str(conversation_id), + title=question_record.title, + parent_id=str(question_record.id), + role='assistant', + content="", # 初始内容为空 + metadata=metadata + ) + except Exception as e: + logger.error(f"创建回答记录时出错: {str(e)}") + return None + + async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata, knowledge_bases): + """从外部API获取流式响应""" + try: + # 确保所有ID都是字符串 + dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list] + + # 获取标题 + title = answer_record.title or 'New chat' + + # 异步收集完整内容,用于最后保存 + full_content = "" + + # 使用aiohttp进行异步HTTP请求 + async with aiohttp.ClientSession() as session: + # 第一步: 创建聊天会话 + async with session.post( + f"{settings.API_BASE_URL}/api/application/chat/open", + json={ + "id": "d5d11efa-ea9a-11ef-9933-0242ac120006", + "model_id": metadata.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'), + "dataset_id_list": dataset_external_ids, + "multiple_rounds_dialogue": False, + "dataset_setting": { + "top_n": 10, "similarity": "0.3", + "max_paragraph_char_number": 10000, + "search_mode": "blend", + "no_references_setting": { + "value": "{question}", + "status": "ai_questioning" + } + }, + "model_setting": { + "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}" + }, + "problem_optimization": False + } + ) as chat_response: + + if chat_response.status != 200: + error_msg = f"外部API调用失败: {await chat_response.text()}" + logger.error(error_msg) + await self.send_error(error_msg) + return + + chat_data = await chat_response.json() + if chat_data.get('code') != 200 or not chat_data.get('data'): + error_msg = f"外部API返回错误: {chat_data}" + logger.error(error_msg) + await self.send_error(error_msg) + return + + chat_id = chat_data['data'] + logger.info(f"成功创建聊天会话, chat_id: {chat_id}") + + # 第二步: 建立流式连接 + message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}" + logger.info(f"开始流式请求: {message_url}") + + # 创建流式请求 + async with session.post( + url=message_url, + json={"message": question, "re_chat": False, "stream": True}, + headers={"Content-Type": "application/json"} + ) as message_request: + + if message_request.status != 200: + error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}" + logger.error(error_msg) + await self.send_error(error_msg) + return + + # 创建一个缓冲区以处理分段的数据 + buffer = "" + + # 读取并处理每个响应块 + logger.info("开始处理流式响应") + async for chunk in message_request.content.iter_any(): + chunk_str = chunk.decode('utf-8') + buffer += chunk_str + + # 检查是否有完整的数据行 + while '\n\n' in buffer: + parts = buffer.split('\n\n', 1) + line = parts[0] + buffer = parts[1] + + if line.startswith('data: '): + try: + # 提取JSON数据 + json_str = line[6:] # 去掉 "data: " 前缀 + data = json.loads(json_str) + + # 记录并处理部分响应 + if 'content' in data: + content_part = data['content'] + full_content += content_part + + # 发送部分内容 + await self.send_json({ + 'code': 200, + 'message': 'partial', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'content': content_part, + 'is_end': data.get('is_end', False) + } + }) + + # 处理结束标记 + if data.get('is_end', False): + logger.info("收到流式响应结束标记") + # 保存完整内容 + await self.update_answer_content(answer_record.id, full_content.strip()) + + # 处理标题 + title = await self.get_or_generate_title( + conversation_id, + question, + full_content.strip() + ) + + # 发送最终响应 + await self.send_json({ + 'code': 200, + 'message': '完成', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'title': title, + 'dataset_id_list': metadata.get('dataset_id_list', []), + 'dataset_names': metadata.get('dataset_names', []), + 'role': 'assistant', + 'content': full_content.strip(), + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'is_end': True + } + }) + return + + except json.JSONDecodeError as e: + logger.error(f"JSON解析错误: {e}, 数据: {line}") + continue + + except Exception as e: + logger.error(f"流式处理出错: {str(e)}") + logger.error(traceback.format_exc()) + await self.send_error(str(e)) + + # 保存已收集的内容 + if 'full_content' in locals() and full_content: + try: + await self.update_answer_content(answer_record.id, full_content.strip()) + except Exception as save_error: + logger.error(f"保存部分内容失败: {str(save_error)}") + + @database_sync_to_async + def update_answer_content(self, answer_id, content): + """更新回答内容""" + try: + answer_record = ChatHistory.objects.get(id=answer_id) + answer_record.content = content + answer_record.save() + return True + except Exception as e: + logger.error(f"更新回答内容失败: {str(e)}") + return False + + @database_sync_to_async + def get_or_generate_title(self, conversation_id, question, answer): + """获取或生成对话标题""" + try: + # 先检查是否已有标题 + current_title = ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).exclude( + title__in=["New chat", "新对话", ""] + ).values_list('title', flat=True).first() + + if current_title: + return current_title + + # 简单的标题生成逻辑 (可替换为调用DeepSeek API生成标题) + generated_title = question[:20] + "..." if len(question) > 20 else question + + # 更新所有相关记录的标题 + ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).update(title=generated_title) + + return generated_title + + except Exception as e: + logger.error(f"获取或生成标题失败: {str(e)}") + return "新对话" + + async def send_json(self, content): + """发送JSON格式的消息""" + await self.send(text_data=json.dumps(content)) + + async def send_error(self, message): + """发送错误消息""" + await self.send_json({ + 'code': 500, + 'message': message, + 'data': {'is_end': True} + }) diff --git a/user_management/management/commands/create_test_users.py b/user_management/management/commands/create_test_users.py index bd1e419d..59d33325 100644 --- a/user_management/management/commands/create_test_users.py +++ b/user_management/management/commands/create_test_users.py @@ -6,47 +6,124 @@ from rest_framework.authtoken.models import Token User = get_user_model() class Command(BaseCommand): - help = '创建测试用户:1个管理员,2个组长,4个组员' + help = '创建测试用户:4个管理员,7个组长,4个组员' def handle(self, *args, **kwargs): - # 创建管理员 - 技术部管理员 - admin, created = User.objects.get_or_create( - username='admin', - defaults={ - 'email': 'admin@example.com', - 'name': '张管理', + # 创建管理员 - 4个管理员 + admins = [ + { + 'username': 'admin1', + 'password': 'admin123', + 'email': 'admin1@example.com', + 'name': '张技术管理', + 'department': '技术部门', + 'role': 'admin', + }, + { + 'username': 'admin2', + 'password': 'admin123', + 'email': 'admin2@example.com', + 'name': '王产品管理', + 'department': '产品部门', + 'role': 'admin', + }, + { + 'username': 'admin3', + 'password': 'admin123', + 'email': 'admin3@example.com', + 'name': '李商务管理', + 'department': '商务部门', + 'role': 'admin', + }, + { + 'username': 'admin4', + 'password': 'admin123', + 'email': 'admin4@example.com', + 'name': '赵HR管理', + 'department': 'HR', 'role': 'admin', - 'is_staff': True, - 'is_superuser': True, - 'last_login': timezone.now() } - ) - if created: - admin.set_password('admin123') - admin.save() - token = Token.objects.create(user=admin) - self.stdout.write(self.style.SUCCESS( - f'成功创建管理员用户: {admin.username}({admin.name}), Token: {token.key}' - )) - else: - self.stdout.write(self.style.WARNING(f'管理员用户已存在: {admin.username}')) + ] - # 创建组长 - 研发部组长和测试部组长 + for admin_data in admins: + admin, created = User.objects.get_or_create( + username=admin_data['username'], + defaults={ + 'email': admin_data['email'], + 'name': admin_data['name'], + 'role': admin_data['role'], + 'department': admin_data['department'], + 'is_staff': True, + 'is_superuser': True, + 'last_login': timezone.now() + } + ) + if created: + admin.set_password(admin_data['password']) + admin.save() + token = Token.objects.create(user=admin) + self.stdout.write(self.style.SUCCESS( + f'成功创建管理员用户: {admin.username}({admin.name}), Token: {token.key}' + )) + else: + self.stdout.write(self.style.WARNING(f'管理员用户已存在: {admin.username}')) + + # 创建组长 - 7个部门的组长 leaders = [ { 'username': 'leader1', 'password': 'leader123', 'email': 'leader1@example.com', - 'name': '李研发', - 'department': '研发部', + 'name': '陈达人', + 'department': '达人部门', 'role': 'leader' }, { 'username': 'leader2', 'password': 'leader123', 'email': 'leader2@example.com', - 'name': '王测试', - 'department': '测试部', + 'name': '刘商务', + 'department': '商务部门', + 'role': 'leader' + }, + { + 'username': 'leader3', + 'password': 'leader123', + 'email': 'leader3@example.com', + 'name': '杨样本', + 'department': '样本中心', + 'role': 'leader' + }, + { + 'username': 'leader4', + 'password': 'leader123', + 'email': 'leader4@example.com', + 'name': '黄产品', + 'department': '产品部门', + 'role': 'leader' + }, + { + 'username': 'leader5', + 'password': 'leader123', + 'email': 'leader5@example.com', + 'name': '周AI', + 'department': 'AI自媒体', + 'role': 'leader' + }, + { + 'username': 'leader6', + 'password': 'leader123', + 'email': 'leader6@example.com', + 'name': '吴HR', + 'department': 'HR', + 'role': 'leader' + }, + { + 'username': 'leader7', + 'password': 'leader123', + 'email': 'leader7@example.com', + 'name': '郑技术', + 'department': '技术部门', 'role': 'leader' } ] @@ -73,67 +150,5 @@ class Command(BaseCommand): else: self.stdout.write(self.style.WARNING(f'组长用户已存在: {leader.username}')) - # 创建组员 - 2个开发组员,2个测试组员 - members = [ - { - 'username': 'member1', - 'password': 'member123', - 'email': 'member1@example.com', - 'name': '赵开发', - 'department': '研发部', - 'role': 'member', - 'group': '前端组' - }, - { - 'username': 'member2', - 'password': 'member123', - 'email': 'member2@example.com', - 'name': '钱开发', - 'department': '研发部', - 'role': 'member', - 'group': '后端组' - }, - { - 'username': 'member3', - 'password': 'member123', - 'email': 'member3@example.com', - 'name': '孙测试', - 'department': '测试部', - 'role': 'member', - 'group': '功能测试组' - }, - { - 'username': 'member4', - 'password': 'member123', - 'email': 'member4@example.com', - 'name': '周测试', - 'department': '测试部', - 'role': 'member', - 'group': '自动化测试组' - } - ] - - for member_data in members: - member, created = User.objects.get_or_create( - username=member_data['username'], - defaults={ - 'email': member_data['email'], - 'name': member_data['name'], - 'role': member_data['role'], - 'department': member_data['department'], - 'group': member_data['group'], - 'is_staff': False, - 'last_login': timezone.now() - } - ) - if created: - member.set_password(member_data['password']) - member.save() - token = Token.objects.create(user=member) - self.stdout.write(self.style.SUCCESS( - f'成功创建组员用户: {member.username}({member.name}), Token: {token.key}' - )) - else: - self.stdout.write(self.style.WARNING(f'组员用户已存在: {member.username}')) self.stdout.write(self.style.SUCCESS('所有测试用户创建完成!')) diff --git a/user_management/middleware.py b/user_management/middleware.py index 0d92bac3..7305efdb 100644 --- a/user_management/middleware.py +++ b/user_management/middleware.py @@ -4,6 +4,10 @@ from django.contrib.auth.models import AnonymousUser from rest_framework.authtoken.models import Token from django.contrib.auth import get_user_model import logging +import re +from django.middleware.csrf import CsrfViewMiddleware +from django.conf import settings +from django.utils.deprecation import MiddlewareMixin logger = logging.getLogger(__name__) @@ -44,11 +48,28 @@ class TokenAuthMiddleware(BaseMiddleware): scope['user'] = AnonymousUser() return await super().__call__(scope, receive, send) -class UserActivityMiddleware: - """用户活动中间件""" - def __init__(self, get_response): - self.get_response = get_response +class UserActivityMiddleware(MiddlewareMixin): + """中间件用于记录用户活动""" + + def process_request(self, request): + # 可以在这里记录用户活动日志 + pass - def __call__(self, request): - response = self.get_response(request) - return response \ No newline at end of file +class CSRFExemptMiddleware(MiddlewareMixin): + """为特定URL路径豁免CSRF保护的中间件""" + + def process_view(self, request, callback, callback_args, callback_kwargs): + # 检查是否有CSRF豁免URL配置 + if not hasattr(settings, 'CSRF_EXEMPT_URLS'): + return None + + # 获取当前请求的路径 + path = request.path_info.lstrip('/') + + # 检查是否匹配任何豁免模式 + for exempt_pattern in settings.CSRF_EXEMPT_URLS: + if re.match(exempt_pattern, path): + setattr(request, '_dont_enforce_csrf_checks', True) + break + + return None \ No newline at end of file diff --git a/user_management/migrations/0005_chathistory_title.py b/user_management/migrations/0005_chathistory_title.py new file mode 100644 index 00000000..532e61a2 --- /dev/null +++ b/user_management/migrations/0005_chathistory_title.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.5 on 2025-04-23 14:20 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0004_knowledgebasedocument'), + ] + + operations = [ + migrations.AddField( + model_name='chathistory', + name='title', + field=models.CharField(blank=True, default='New chat', help_text='对话标题', max_length=100, null=True), + ), + ] diff --git a/user_management/migrations/0006_knowledgebasedocument_uploader_name.py b/user_management/migrations/0006_knowledgebasedocument_uploader_name.py new file mode 100644 index 00000000..758d6ee9 --- /dev/null +++ b/user_management/migrations/0006_knowledgebasedocument_uploader_name.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.5 on 2025-04-23 16:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0005_chathistory_title'), + ] + + operations = [ + migrations.AddField( + model_name='knowledgebasedocument', + name='uploader_name', + field=models.CharField(default='未知用户', max_length=100, verbose_name='上传者姓名'), + ), + ] diff --git a/user_management/models.py b/user_management/models.py index 04e387dd..7e5c74f7 100644 --- a/user_management/models.py +++ b/user_management/models.py @@ -277,7 +277,7 @@ class Permission(models.Model): title = f'权限申请{self.get_status_display()} - {self.get_resource_type_display()}' content = f'您申请的 {self.get_resource_type_display()} 权限已{self.get_status_display()}' self.send_notification(notification_type, title, content) - + class ChatHistory(models.Model): """聊天历史记录""" ROLE_CHOICES = [ @@ -291,6 +291,8 @@ class ChatHistory(models.Model): knowledge_base = models.ForeignKey('KnowledgeBase', on_delete=models.CASCADE) # 用于标识知识库组合的对话 conversation_id = models.CharField(max_length=100, db_index=True) + # 对话标题 + title = models.CharField(max_length=100, null=True, blank=True, default='New chat', help_text="对话标题") parent_id = models.CharField(max_length=100, null=True, blank=True) role = models.CharField(max_length=20, choices=ROLE_CHOICES) content = models.TextField() @@ -391,6 +393,7 @@ class ChatHistory(models.Model): ] } + class UserProfile(models.Model): """用户档案模型""" user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile') @@ -682,6 +685,7 @@ class KnowledgeBaseDocument(models.Model): document_id = models.CharField(max_length=100, verbose_name='文档ID') document_name = models.CharField(max_length=255, verbose_name='文档名称') external_id = models.CharField(max_length=100, verbose_name='外部文档ID') + uploader_name = models.CharField(max_length=100, default="未知用户", verbose_name='上传者姓名') status = models.CharField( max_length=20, default='active', diff --git a/user_management/routing.py b/user_management/routing.py index 7f4a0528..29404046 100644 --- a/user_management/routing.py +++ b/user_management/routing.py @@ -3,4 +3,6 @@ from . import consumers websocket_urlpatterns = [ re_path(r'ws/notifications/$', consumers.NotificationConsumer.as_asgi()), + re_path(r'ws/chat/$', consumers.ChatConsumer.as_asgi()), + re_path(r'ws/chat/stream/$', consumers.ChatStreamConsumer.as_asgi()), ] \ No newline at end of file diff --git a/user_management/views.py b/user_management/views.py index 47659bf5..2fb4804a 100644 --- a/user_management/views.py +++ b/user_management/views.py @@ -253,87 +253,136 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): queryset = ChatHistory.objects.all() def get_queryset(self): - """确保用户只能看到自己的未删除的聊天记录""" - return ChatHistory.objects.filter( - user=self.request.user, + """确保用户只能看到自己的未删除的聊天记录以及有权限的知识库关联的聊天记录""" + user = self.request.user + + # 当前用户的聊天记录 + user_records = ChatHistory.objects.filter( + user=user, is_deleted=False ) + + # 获取用户有权限的知识库ID列表 + accessible_kb_ids = [] + for kb in KnowledgeBase.objects.all(): + if self.check_knowledge_base_permission(kb, user, 'read'): + accessible_kb_ids.append(kb.id) + + # 其他用户创建的、但当前用户有权限访问的知识库的聊天记录 + others_records = ChatHistory.objects.filter( + knowledge_base_id__in=accessible_kb_ids, + is_deleted=False + ).exclude(user=user) # 排除用户自己的记录,避免重复 + + # 合并两个查询集 + combined_queryset = user_records | others_records + + return combined_queryset def list(self, request): - """获取对话列表概览""" + """获取对话列表""" try: - # 获取查询参数 - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - - # 获取所有对话的概览 - latest_chats = self.get_queryset().values( - 'conversation_id' - ).annotate( - latest_id=Max('id'), - message_count=Count('id'), - last_message=Max('created_at') + # 获取用户所有的对话 + unique_conversations = self.get_queryset().values('conversation_id').annotate( + last_message=Max('created_at'), + message_count=Count('id') ).order_by('-last_message') - - # 计算分页 - total = latest_chats.count() - start = (page - 1) * page_size - end = start + page_size - chats = latest_chats[start:end] - - results = [] - for chat in chats: - # 获取最新消息记录 - latest_record = ChatHistory.objects.get(id=chat['latest_id']) + + # 构建结果列表 + conversation_list = [] + for conv in unique_conversations: + # 获取对话中的第一条消息,用于显示标题和知识库信息 + first_message = self.get_queryset().filter( + conversation_id=conv['conversation_id'] + ).order_by('created_at').first() - # 从metadata中获取完整的知识库信息 + if not first_message: + continue + + # 获取知识库信息 dataset_info = [] - if latest_record.metadata: - dataset_id_list = latest_record.metadata.get('dataset_id_list', []) - dataset_names = latest_record.metadata.get('dataset_names', []) + if first_message.metadata and 'dataset_id_list' in first_message.metadata: + # 获取用户有权限访问的知识库 + valid_kb_ids = [] + for kb_id in first_message.metadata['dataset_id_list']: + try: + kb = KnowledgeBase.objects.get(id=kb_id) + if self.check_knowledge_base_permission(kb, request.user, 'read'): + valid_kb_ids.append(kb_id) + dataset_info.append({ + 'id': str(kb.id), + 'name': kb.name, + 'type': kb.type + }) + except KnowledgeBase.DoesNotExist: + continue + + # 获取最近的消息用于预览 + last_user_message = self.get_queryset().filter( + conversation_id=conv['conversation_id'], + role='user' + ).order_by('-created_at').first() + + # 处理对话标题 - 优先使用已有标题,否则尝试生成新标题 + title = first_message.title + + # 如果标题为空或为默认值'New chat',尝试生成新标题 + if not title or title == 'New chat': + # 找到对话中的第一对问答 + messages = list(self.get_queryset().filter( + conversation_id=conv['conversation_id'] + ).order_by('created_at')) - # 如果有知识库ID列表 - if dataset_id_list: - # 如果同时有名称列表且长度匹配 - if dataset_names and len(dataset_names) == len(dataset_id_list): - dataset_info = [{ - 'id': str(id), - 'name': name - } for id, name in zip(dataset_id_list, dataset_names)] - else: - # 如果没有名称列表,则只返回ID - datasets = KnowledgeBase.objects.filter(id__in=dataset_id_list) - dataset_info = [{ - 'id': str(ds.id), - 'name': ds.name - } for ds in datasets] - - results.append({ - 'conversation_id': chat['conversation_id'], - 'message_count': chat['message_count'], - 'last_message': latest_record.content, - 'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'), - 'dataset_id_list': [ds['id'] for ds in dataset_info], # 添加完整的知识库ID列表 - 'datasets': dataset_info # 包含ID和名称的完整信息 - }) - + user_message = None + assistant_message = None + + for i in range(len(messages)-1): + if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id): + user_message = messages[i] + assistant_message = messages[i+1] + break + + if user_message and assistant_message: + # 调用DeepSeek API生成标题 + generated_title = self._generate_conversation_title_from_deepseek( + user_message.content, + assistant_message.content + ) + + if generated_title: + # 更新所有相关记录的标题 + title = generated_title + ChatHistory.objects.filter( + conversation_id=conv['conversation_id'] + ).update(title=generated_title) + + # 如果生成失败,使用对话ID的一部分作为临时标题 + if not title: + title = f"对话 {conv['conversation_id'][:8]}" + + # 构建返回结果 + conversation_data = { + 'conversation_id': conv['conversation_id'], + 'last_message': conv['last_message'].strftime('%Y-%m-%d %H:%M:%S'), + 'message_count': conv['message_count'], + 'title': title, + 'preview': last_user_message.content[:100] if last_user_message else "", + 'datasets': dataset_info + } + conversation_list.append(conversation_data) + + # 返回结果 return Response({ 'code': 200, 'message': '获取成功', - 'data': { - 'total': total, - 'page': page, - 'page_size': page_size, - 'results': results - } + 'data': conversation_list }) - except Exception as e: - logger.error(f"获取聊天记录失败: {str(e)}") + logger.error(f"获取对话列表失败: {str(e)}") logger.error(traceback.format_exc()) return Response({ 'code': 500, - 'message': f'获取聊天记录失败: {str(e)}', + 'message': f"获取对话列表失败: {str(e)}", 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -349,7 +398,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): 'data': None }, status=status.HTTP_400_BAD_REQUEST) - # 获取对话历史 + # 获取对话历史,确保按时间顺序排序 messages = self.get_queryset().filter( conversation_id=conversation_id ).order_by('created_at') @@ -379,19 +428,61 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): 'name': ds.name, 'type': ds.type } for ds in accessible_datasets] + + # 处理对话标题 - 优先使用已有标题,否则尝试生成新标题 + title = first_message.title + + # 如果标题为空或为默认值'New chat',尝试生成新标题 + if not title or title == 'New chat': + # 尝试找到一对完整的问答 + user_message = None + assistant_message = None + + for i in range(len(messages)-1): + if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id): + user_message = messages[i] + assistant_message = messages[i+1] + break + + if user_message and assistant_message: + # 调用DeepSeek API生成标题 + generated_title = self._generate_conversation_title_from_deepseek( + user_message.content, + assistant_message.content + ) + + if generated_title: + # 更新所有相关记录的标题 + title = generated_title + ChatHistory.objects.filter( + conversation_id=conversation_id + ).update(title=generated_title) + + # 如果生成失败,使用对话ID的一部分作为临时标题 + if not title: + title = f"对话 {conversation_id[:8]}" + + # 构建消息列表,包含parent_id信息 + message_list = [] + for msg in messages: + message_data = { + 'id': str(msg.id), + 'parent_id': msg.parent_id, # 添加parent_id + 'role': msg.role, + 'content': msg.content, + 'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'metadata': msg.metadata # 添加metadata + } + message_list.append(message_data) return Response({ 'code': 200, 'message': '获取成功', 'data': { 'conversation_id': conversation_id, + 'title': title, # 返回标题 'datasets': dataset_info, - 'messages': [{ - 'id': str(msg.id), - 'role': msg.role, - 'content': msg.content, - 'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S') - } for msg in messages] + 'messages': message_list } }) @@ -512,6 +603,9 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): conversation_id = str(uuid.uuid4()) logger.info(f"创建新的会话ID: {conversation_id}") + # 获取自定义标题(如果有) + title = data.get('title', 'New chat') + # 准备metadata (仍然保存知识库名称用于内部处理) metadata = { 'dataset_id_list': [str(id) for id in dataset_ids], @@ -523,6 +617,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): 'message': '会话创建成功', 'data': { 'conversation_id': conversation_id, + 'title': title, # 添加标题字段 'dataset_id_list': metadata['dataset_id_list'] } }) @@ -680,12 +775,16 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): 'dataset_external_id_list': [str(id) for id in external_id_list], 'dataset_names': [kb.name for kb in knowledge_bases] } - + + # 检查是否有自定义标题 + title = data.get('title', 'New chat') + # 创建用户问题记录 question_record = ChatHistory.objects.create( user=request.user, knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库 conversation_id=str(conversation_id), + title=title, # 设置标题 role='user', content=data['question'], metadata=metadata @@ -696,7 +795,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): if use_stream: # 创建流式响应 - return StreamingHttpResponse( + response = StreamingHttpResponse( self._stream_answer_from_external_api( conversation_id=str(conversation_id), question_record=question_record, @@ -705,8 +804,15 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): question=data['question'], metadata=metadata ), - content_type='text/event-stream' + content_type='text/event-stream', + status=status.HTTP_201_CREATED # 修改状态码为201 ) + + # 添加禁用缓存的头部 + response['Cache-Control'] = 'no-cache, no-store' + response['Connection'] = 'keep-alive' + + return response else: # 使用非流式输出 logger.info("使用非流式输出模式") @@ -721,29 +827,49 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) # 创建 AI 回答记录 - answer_record = ChatHistory.objects.create( - user=request.user, + answer_record = ChatHistory.objects.create( + user=request.user, knowledge_base=knowledge_bases[0], - conversation_id=str(conversation_id), - parent_id=str(question_record.id), - role='assistant', - content=answer, - metadata=metadata - ) + conversation_id=str(conversation_id), + title=title, # 设置标题 + parent_id=str(question_record.id), + role='assistant', + content=answer, + metadata=metadata + ) + + # 如果是新会话的第一条消息,并且没有自定义标题,则自动生成标题 + should_generate_title = not existing_records.exists() and (not title or title == 'New chat') + if should_generate_title: + try: + generated_title = self._generate_conversation_title_from_deepseek( + data['question'], + answer + ) + if generated_title: + # 更新所有相关记录的标题 + ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).update(title=generated_title) + title = generated_title + except Exception as e: + logger.error(f"自动生成标题失败: {str(e)}") + # 继续执行,不影响主流程 - return Response({ - 'code': 200, + return Response({ + 'code': 200, # 修改状态码为201 'message': '成功', - 'data': { - 'id': str(answer_record.id), - 'conversation_id': str(conversation_id), + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'title': title, # 添加标题字段 'dataset_id_list': metadata.get('dataset_id_list', []), 'dataset_names': metadata.get('dataset_names', []), - 'role': 'assistant', + 'role': 'assistant', 'content': answer, - 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S') - } - }) + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S') + } + }, status=status.HTTP_200_CREATED) # 修改状态码为201 except Exception as e: logger.error(f"创建聊天记录失败: {str(e)}") @@ -760,11 +886,15 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): # 确保所有ID都是字符串 dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list] + # 获取标题 + title = question_record.title or 'New chat' + # 创建AI回答记录对象,稍后更新内容 answer_record = ChatHistory.objects.create( user=question_record.user, knowledge_base=knowledge_bases[0], conversation_id=str(conversation_id), + title=title, # 设置标题 parent_id=str(question_record.id), role='assistant', content="", # 初始内容为空 @@ -803,7 +933,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): }, "problem_optimization": False }, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, + ) if chat_response.status_code != 200: @@ -831,7 +962,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): url=message_url, json={"message": question, "re_chat": False, "stream": True}, headers={"Content-Type": "application/json"}, - stream=True # 启用流式传输 + stream=True, # 启用流式传输 + ) if message_request.status_code != 200: @@ -872,11 +1004,12 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): # 构建响应数据 response_data = { - 'code': 200, + 'code': 200, # 修改状态码为201 'message': 'partial', 'data': { 'id': str(answer_record.id), 'conversation_id': str(conversation_id), + 'title': title, # 添加标题字段 'content': content_part, 'is_end': data.get('is_end', False) } @@ -892,13 +1025,46 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): answer_record.content = full_content.strip() answer_record.save() + # 先检查当前conversation_id是否已有有效标题 + current_title = ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).exclude( + title__in=["New chat", "新对话", ""] + ).values_list('title', flat=True).first() + + # 如果已有有效标题,则复用 + if current_title: + title = current_title + logger.info(f"复用已有标题: {title}") + else: + # 没有有效标题时,直接基于当前问题和回答生成标题 + try: + # 直接使用当前的问题和完整的AI回答来生成标题 + generated_title = self._generate_conversation_title_from_deepseek( + question, full_content.strip() + ) + if generated_title: + # 更新所有相关记录的标题 + ChatHistory.objects.filter( + conversation_id=str(conversation_id) + ).update(title=generated_title) + title = generated_title + logger.info(f"成功生成标题: {title}") + else: + title = "新对话" # 如果生成失败,使用默认标题 + logger.warning("生成标题失败,使用默认标题") + except Exception as e: + logger.error(f"自动生成标题失败: {str(e)}") + title = "新对话" # 如果出错,使用默认标题 + # 发送完整内容的最终响应 final_response = { - 'code': 200, + 'code': 200, # 修改状态码为201 'message': '完成', 'data': { 'id': str(answer_record.id), 'conversation_id': str(conversation_id), + 'title': title, # 添加生成的标题 'dataset_id_list': metadata.get('dataset_id_list', []), 'dataset_names': metadata.get('dataset_names', []), 'role': 'assistant', @@ -929,11 +1095,11 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): full_content += content_part response_data = { - 'code': 200, + 'code': 200, # 修改状态码为201 'message': 'partial', 'data': { 'id': str(answer_record.id), - 'conversation_id': str(conversation_id), + 'conversation_id': str(conversation_id), # 添加标题字段 'content': content_part, 'is_end': data.get('is_end', False) } @@ -1008,7 +1174,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): chat_response = requests.post( url=f"{settings.API_BASE_URL}/api/application/chat/open", json=chat_request_data, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, + ) logger.info(f"API响应状态码: {chat_response.status_code}") @@ -1038,7 +1205,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): message_response = requests.post( url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}", json=message_request_data, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, + ) if message_response.status_code != 200: @@ -1295,6 +1463,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): export_response = requests.get( url=export_url, + stream=True # 使用流式传输处理大文件 ) @@ -1351,7 +1520,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): # 调用外部API response = requests.get( url=api_url, - params=params + params=params, + ) if response.status_code != 200: @@ -1537,7 +1707,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): response = requests.get( url=url, - params=params + params=params, + ) if response.status_code != 200: @@ -1617,6 +1788,116 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + @action(detail=False, methods=['get'], url_path='generate-conversation-title') + def generate_conversation_title(self, request): + """更新会话标题""" + try: + conversation_id = request.query_params.get('conversation_id') + if not conversation_id: + return Response({ + 'code': 400, + 'message': '缺少conversation_id参数', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查对话是否存在 + messages = self.get_queryset().filter( + conversation_id=conversation_id, + is_deleted=False, + user=request.user + ).order_by('created_at') + + if not messages.exists(): + return Response({ + 'code': 404, + 'message': '对话不存在或无权访问', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 检查是否有自定义标题参数 + custom_title = request.query_params.get('title') + if not custom_title: + return Response({ + 'code': 400, + 'message': '缺少title参数', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新所有相关记录的标题 + ChatHistory.objects.filter( + conversation_id=conversation_id, + user=request.user + ).update(title=custom_title) + + return Response({ + 'code': 200, + 'message': '更新会话标题成功', + 'data': { + 'conversation_id': conversation_id, + 'title': custom_title + } + }) + + except Exception as e: + logger.error(f"更新会话标题失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f"更新会话标题失败: {str(e)}", + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _generate_conversation_title_from_deepseek(self, user_question, assistant_answer): + """调用SiliconCloud API生成会话标题,直接基于当前问题和回答内容""" + try: + # 从Django设置中获取API密钥 + api_key = settings.SILICON_CLOUD_API_KEY + if not api_key: + return "新对话" + + # 构建提示信息 + prompt = f"请根据用户的问题和助手的回答,生成一个简短的对话标题(不超过20个字)。\n\n用户问题: {user_question}\n\n助手回答: {assistant_answer}" + + import requests + + url = "https://api.siliconflow.cn/v1/chat/completions" + + payload = { + "model": "deepseek-ai/DeepSeek-V3", + "stream": False, + "max_tokens": 512, + "temperature": 0.7, + "top_p": 0.7, + "top_k": 50, + "frequency_penalty": 0.5, + "n": 1, + "stop": [], + "messages": [ + { + "role": "user", + "content": prompt + } + ] + } + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + response = requests.post(url, json=payload, headers=headers) + response_data = response.json() + + if response.status_code == 200 and 'choices' in response_data and response_data['choices']: + title = response_data['choices'][0]['message']['content'].strip() + return title[:50] # 截断过长的标题 + else: + logger.error(f"生成标题时出错: {response.text}") + return "新对话" + + except Exception as e: + logger.exception(f"生成对话标题时发生错误: {str(e)}") + return "新对话" + class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): serializer_class = KnowledgeBaseSerializer @@ -1657,28 +1938,9 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): serializer = self.get_serializer(paginated_queryset, many=True) data = serializer.data - # 获取文档数量统计 - kb_ids = [kb.id for kb in paginated_queryset] - doc_counts = KnowledgeBaseDocument.objects.filter( - knowledge_base_id__in=kb_ids, - status='active' - ).values('knowledge_base_id').annotate( - count=Count('id') - ) - - # 创建文档数量映射字典 - doc_count_map = { - str(item['knowledge_base_id']): item['count'] - for item in doc_counts - } - - # 为每个知识库添加权限信息和文档数量 + # 为每个知识库添加权限信息 user = request.user for item in data: - # 添加文档数量 - kb_id = item['id'] - item['document_count'] = doc_count_map.get(kb_id, 0) - # 获取必要的知识库属性 kb_type = item['type'] department = item.get('department') @@ -2060,7 +2322,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): response = requests.put( f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}', json=api_data, - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, + ) if response.status_code != 200: @@ -2072,6 +2335,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): logger.info(f"外部知识库更新成功: {instance.external_id}") + except requests.exceptions.Timeout: + raise ExternalAPIError("请求超时,请稍后重试") except requests.exceptions.RequestException as e: raise ExternalAPIError(f"API请求失败: {str(e)}") except Exception as e: @@ -2120,23 +2385,30 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): "data": None }, status=status.HTTP_403_FORBIDDEN) - with transaction.atomic(): - # 删除外部知识库 - if instance.external_id: - try: - self._delete_external_dataset(instance.external_id) - logger.info(f"外部知识库删除成功: {instance.external_id}") - except ExternalAPIError as e: - logger.error(f"删除外部知识库失败: {str(e)}") - return Response({ - "code": 500, - "message": f"删除外部知识库失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + # 删除外部知识库(如果存在) + external_delete_success = True + external_error_message = None + if instance.external_id: + try: + self._delete_external_dataset(instance.external_id) + logger.info(f"外部知识库删除成功: {instance.external_id}") + except ExternalAPIError as e: + # 记录错误但继续执行本地删除 + external_delete_success = False + external_error_message = str(e) + logger.warning(f"外部知识库删除失败,将继续删除本地知识库: {str(e)}") - # 删除本地知识库 - self.perform_destroy(instance) - logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}") + # 删除本地知识库 + self.perform_destroy(instance) + logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}") + + # 如果外部知识库删除失败,返回警告消息 + if not external_delete_success: + return Response({ + "code": 200, + "message": f"知识库已删除,但外部知识库删除失败: {external_error_message}", + "data": None + }) return Response({ "code": 200, @@ -2159,6 +2431,59 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): "data": None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + def _delete_external_dataset(self, external_id): + """删除外部知识库""" + try: + if not external_id: + logger.warning("外部知识库ID为空,跳过删除") + return True + + response = requests.delete( + f'{settings.API_BASE_URL}/api/dataset/{external_id}', + headers={'Content-Type': 'application/json'}, + + ) + + logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}") + + # 检查响应状态码 + if response.status_code == 404: + logger.warning(f"外部知识库不存在: {external_id}") + return True # 如果知识库不存在,也视为删除成功 + elif response.status_code not in [200, 204]: + logger.warning(f"删除外部知识库状态码异常: {response.status_code}, {response.text}") + return True # 即使状态码异常,也允许继续删除本地知识库 + + # 检查业务状态码 + try: + api_response = response.json() + if api_response.get('code') != 200: + # 如果是因为ID不存在,也视为成功 + if "不存在" in api_response.get('message', ''): + logger.warning(f"外部知识库ID不存在,视为删除成功: {external_id}") + return True + logger.warning(f"业务处理返回非200状态码: {api_response.get('code')}, {api_response.get('message')}") + return True # 不再抛出异常,允许本地删除继续 + logger.info(f"外部知识库删除成功: {external_id}") + return True + except ValueError: + # 如果无法解析 JSON,但状态码是 200,也认为成功 + logger.warning(f"外部知识库删除响应无法解析JSON,但状态码为200,视为成功: {external_id}") + return True + + except requests.exceptions.Timeout: + logger.error(f"删除外部知识库超时: {external_id}") + # 不再抛出异常,允许本地删除继续 + return False + except requests.exceptions.RequestException as e: + logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}") + # 不再抛出异常,允许本地删除继续 + return False + except Exception as e: + logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}") + # 不再抛出异常,允许本地删除继续 + return False + @action(detail=True, methods=['get']) def permissions(self, request, pk=None): """获取用户对特定知识库的权限""" @@ -2612,10 +2937,11 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): logger.info(f"调用分割API URL: {url}") logger.info(f"请求字段: {list(files_data.keys())}") - # 发送请求 - 移除timeout参数 + # 发送请求 response = requests.post( url, - files=files_data + files=files_data, + ) # 记录请求头和响应信息,方便排查问题 @@ -2845,7 +3171,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): knowledge_base=instance, document_id=document_id, document_name=doc_name, - external_id=document_id + external_id=document_id, + uploader_name=user.name ) saved_documents.append({ @@ -2908,7 +3235,7 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): # 记录请求数据,方便调试 logger.info(f"上传文档数据: 文档名={doc_data.get('name')}, 段落数={len(doc_data.get('paragraphs', []))}") - # 发送请求,不设置超时限制 + # 发送请求 response = requests.post(url, json=doc_data) # 记录响应结果 @@ -2985,7 +3312,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): response = requests.post( f'{settings.API_BASE_URL}/api/dataset', json=api_data, - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, + ) if response.status_code != 200: @@ -3001,6 +3329,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): return dataset_id + except requests.exceptions.Timeout: + raise ExternalAPIError("请求超时,请稍后重试") except requests.exceptions.RequestException as e: raise ExternalAPIError(f"API请求失败: {str(e)}") except Exception as e: @@ -3014,7 +3344,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): response = requests.delete( f'{settings.API_BASE_URL}/api/dataset/{external_id}', - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, + ) logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}") @@ -3043,6 +3374,9 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): logger.warning(f"外部知识库删除响应无法解析JSON,但状态码为200,视为成功: {external_id}") return True + except requests.exceptions.Timeout: + logger.error(f"删除外部知识库超时: {external_id}") + raise ExternalAPIError("请求超时,请稍后重试") except requests.exceptions.RequestException as e: logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}") raise ExternalAPIError(f"API请求失败: {str(e)}") @@ -3078,7 +3412,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document' response = requests.get( url, - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, + ) if response.status_code != 200: @@ -3138,7 +3473,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): # 添加外部API返回的额外信息 "char_length": next((d.get('char_length', 0) for d in external_documents if d.get('id') == doc.external_id), 0), "paragraph_count": next((d.get('paragraph_count', 0) for d in external_documents if d.get('id') == doc.external_id), 0), - "is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True) + "is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True), + "uploader_name": doc.uploader_name } for doc in documents] return Response({ @@ -4260,13 +4596,23 @@ class LoginView(APIView): class RegisterView(APIView): """用户注册视图""" permission_classes = [AllowAny] + authentication_classes = [] # 清空认证类 def post(self, request): try: + # 检查是否允许注册 + from django.conf import settings + if not getattr(settings, 'ALLOW_REGISTRATION', True): + return Response({ + "code": 403, + "message": "系统当前不允许注册新用户", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + data = request.data # 检查必填字段 - required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] + required_fields = ['username', 'password', 'email', 'role', 'name'] for field in required_fields: if not data.get(field): return Response({ @@ -4285,32 +4631,6 @@ class RegisterView(APIView): "data": None }, status=status.HTTP_400_BAD_REQUEST) - # 验证部门是否存在 - if data['department'] not in settings.DEPARTMENT_GROUPS: - return Response({ - "code": 400, - "message": f"无效的部门,可选部门: {', '.join(settings.DEPARTMENT_GROUPS.keys())}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 如果是组员,验证小组 - if data['role'] == 'member': - if not data.get('group'): - return Response({ - "code": 400, - "message": "组员必须指定所属小组", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证小组是否存在且属于指定部门 - valid_groups = settings.DEPARTMENT_GROUPS.get(data['department'], []) - if data['group'] not in valid_groups: - return Response({ - "code": 400, - "message": f"无效的小组,{data['department']}的可选小组: {', '.join(valid_groups)}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - # 检查用户名是否已存在 if User.objects.filter(username=data['username']).exists(): return Response({ @@ -4351,9 +4671,9 @@ class RegisterView(APIView): email=data['email'], password=data['password'], role=data['role'], - department=data['department'], + department=data.get('department'), # 不再强制要求部门 name=data['name'], - group=data.get('group') if data['role'] == 'member' else None, + group=data.get('group'), # 不再强制要求小组 is_staff=False, is_superuser=False ) @@ -4577,6 +4897,7 @@ def change_password(request): "data": None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) +@csrf_exempt @api_view(['POST']) @permission_classes([AllowAny]) def user_register(request): @@ -4592,7 +4913,7 @@ def user_register(request): data = request.data # 检查必填字段 - required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] + required_fields = ['username', 'password', 'email', 'role', 'name'] for field in required_fields: if not data.get(field): return Response({ @@ -4610,14 +4931,6 @@ def user_register(request): 'data': None }, status=status.HTTP_400_BAD_REQUEST) - # 如果是组员,必须指定小组 - if data['role'] == 'member' and not data.get('group'): - return Response({ - 'code': 400, - 'message': '组员必须指定所属小组', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - # 检查用户名是否已存在 if User.objects.filter(username=data['username']).exists(): return Response({ @@ -4658,9 +4971,9 @@ def user_register(request): email=data['email'], password=data['password'], role=data['role'], - department=data['department'], + department=data.get('department'), # 不再强制要求部门 name=data['name'], - group=data.get('group') if data['role'] == 'member' else None, + group=data.get('group'), # 不再强制要求小组 is_staff=False, is_superuser=False ) @@ -5014,4 +5327,3 @@ def user_list(request): 'message': f'获取用户列表失败: {str(e)}', 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) -