2
This commit is contained in:
parent
0b067f5d87
commit
139f7bb83f
501
apps/chat/consumers.py
Normal file
501
apps/chat/consumers.py
Normal file
@ -0,0 +1,501 @@
|
|||||||
|
# apps/chat/consumers.py
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
from apps.chat.models import ChatHistory
|
||||||
|
from apps.knowledge_base.models import KnowledgeBase
|
||||||
|
from django.conf import settings
|
||||||
|
import aiohttp
|
||||||
|
import uuid
|
||||||
|
from apps.common.services.permission_service import PermissionService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ChatStreamConsumer(AsyncWebsocketConsumer):
|
||||||
|
async def connect(self):
|
||||||
|
"""建立WebSocket连接"""
|
||||||
|
try:
|
||||||
|
# 从URL参数中获取token
|
||||||
|
query_string = self.scope.get('query_string', b'').decode()
|
||||||
|
query_params = parse_qs(query_string)
|
||||||
|
token_key = query_params.get('token', [''])[0]
|
||||||
|
|
||||||
|
if not token_key:
|
||||||
|
logger.warning("WebSocket连接尝试,但没有提供token")
|
||||||
|
await self.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证token
|
||||||
|
self.user = await self.get_user_from_token(token_key)
|
||||||
|
if not self.user:
|
||||||
|
logger.warning(f"WebSocket连接尝试,但token无效: {token_key}")
|
||||||
|
await self.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 将用户信息存储在scope中
|
||||||
|
self.scope["user"] = self.user
|
||||||
|
await self.accept()
|
||||||
|
logger.info(f"用户 {self.user.username} 流式输出WebSocket连接成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket连接错误: {str(e)}")
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_user_from_token(self, token_key):
|
||||||
|
try:
|
||||||
|
token = Token.objects.select_related('user').get(key=token_key)
|
||||||
|
return token.user
|
||||||
|
except Token.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def disconnect(self, close_code):
|
||||||
|
"""关闭WebSocket连接"""
|
||||||
|
logger.info(f"用户 {self.user.username if hasattr(self, 'user') else 'unknown'} WebSocket连接断开,代码: {close_code}")
|
||||||
|
|
||||||
|
async def receive(self, text_data):
|
||||||
|
"""接收消息并处理"""
|
||||||
|
try:
|
||||||
|
data = json.loads(text_data)
|
||||||
|
|
||||||
|
# 检查必填字段
|
||||||
|
if 'question' not in data:
|
||||||
|
await self.send_error("缺少必填字段: question")
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'conversation_id' not in data:
|
||||||
|
await self.send_error("缺少必填字段: conversation_id")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理新会话或现有会话
|
||||||
|
await self.process_chat_request(data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理消息时出错: {str(e)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
await self.send_error(f"处理消息时出错: {str(e)}")
|
||||||
|
|
||||||
|
async def process_chat_request(self, data):
|
||||||
|
"""处理聊天请求"""
|
||||||
|
try:
|
||||||
|
conversation_id = data['conversation_id']
|
||||||
|
question = data['question']
|
||||||
|
|
||||||
|
# 获取会话信息和知识库
|
||||||
|
session_info = await self.get_session_info(data)
|
||||||
|
if not session_info:
|
||||||
|
return
|
||||||
|
|
||||||
|
knowledge_bases, metadata, dataset_external_id_list = session_info
|
||||||
|
|
||||||
|
# 创建问题记录
|
||||||
|
question_record = await self.create_question_record(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
knowledge_bases,
|
||||||
|
metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if not question_record:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建AI回答记录
|
||||||
|
answer_record = await self.create_answer_record(
|
||||||
|
conversation_id,
|
||||||
|
question_record,
|
||||||
|
knowledge_bases,
|
||||||
|
metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送初始响应
|
||||||
|
await self.send_json({
|
||||||
|
'code': 200,
|
||||||
|
'message': '开始流式传输',
|
||||||
|
'data': {
|
||||||
|
'id': str(answer_record.id),
|
||||||
|
'conversation_id': str(conversation_id),
|
||||||
|
'content': '',
|
||||||
|
'is_end': False
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 调用外部API获取流式响应
|
||||||
|
await self.stream_from_external_api(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
dataset_external_id_list,
|
||||||
|
answer_record,
|
||||||
|
metadata,
|
||||||
|
knowledge_bases
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理聊天请求时出错: {str(e)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
await self.send_error(f"处理聊天请求时出错: {str(e)}")
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_session_info(self, data):
|
||||||
|
"""获取会话信息和知识库"""
|
||||||
|
try:
|
||||||
|
conversation_id = data['conversation_id']
|
||||||
|
|
||||||
|
# 查找该会话ID下的历史记录
|
||||||
|
existing_records = ChatHistory.objects.filter(
|
||||||
|
conversation_id=conversation_id
|
||||||
|
).order_by('created_at')
|
||||||
|
|
||||||
|
# 如果有历史记录,使用第一条记录的metadata
|
||||||
|
if existing_records.exists():
|
||||||
|
first_record = existing_records.first()
|
||||||
|
metadata = first_record.metadata or {}
|
||||||
|
|
||||||
|
# 获取知识库信息
|
||||||
|
dataset_ids = metadata.get('dataset_id_list', [])
|
||||||
|
external_id_list = metadata.get('dataset_external_id_list', [])
|
||||||
|
|
||||||
|
if not dataset_ids:
|
||||||
|
logger.error('找不到会话关联的知识库信息')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证知识库是否存在且用户有权限
|
||||||
|
knowledge_bases = []
|
||||||
|
for kb_id in dataset_ids:
|
||||||
|
try:
|
||||||
|
kb = KnowledgeBase.objects.get(id=kb_id)
|
||||||
|
if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
|
||||||
|
logger.error(f'无权访问知识库: {kb.name}')
|
||||||
|
return None
|
||||||
|
knowledge_bases.append(kb)
|
||||||
|
except KnowledgeBase.DoesNotExist:
|
||||||
|
logger.error(f'知识库不存在: {kb_id}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not external_id_list or not knowledge_bases:
|
||||||
|
logger.error('会话关联的知识库信息不完整')
|
||||||
|
return None
|
||||||
|
|
||||||
|
return knowledge_bases, metadata, external_id_list
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 如果是新会话的第一条记录,需要提供知识库ID
|
||||||
|
dataset_ids = []
|
||||||
|
if 'dataset_id' in data:
|
||||||
|
dataset_ids.append(str(data['dataset_id']))
|
||||||
|
elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)):
|
||||||
|
if isinstance(data['dataset_id_list'], str):
|
||||||
|
try:
|
||||||
|
dataset_list = json.loads(data['dataset_id_list'])
|
||||||
|
if isinstance(dataset_list, list):
|
||||||
|
dataset_ids = [str(id) for id in dataset_list]
|
||||||
|
else:
|
||||||
|
dataset_ids = [str(data['dataset_id_list'])]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
dataset_ids = [str(data['dataset_id_list'])]
|
||||||
|
else:
|
||||||
|
dataset_ids = [str(id) for id in data['dataset_id_list']]
|
||||||
|
|
||||||
|
if not dataset_ids:
|
||||||
|
logger.error('新会话需要提供知识库ID')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 验证所有知识库并收集external_ids
|
||||||
|
external_id_list = []
|
||||||
|
knowledge_bases = []
|
||||||
|
|
||||||
|
for kb_id in dataset_ids:
|
||||||
|
try:
|
||||||
|
knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
|
||||||
|
if not knowledge_base:
|
||||||
|
logger.error(f'知识库不存在: {kb_id}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
knowledge_bases.append(knowledge_base)
|
||||||
|
|
||||||
|
# 使用统一的权限检查方法
|
||||||
|
if not self.check_knowledge_base_permission(knowledge_base, self.scope["user"], 'read'):
|
||||||
|
logger.error(f'无权访问知识库: {knowledge_base.name}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 添加知识库的external_id到列表
|
||||||
|
if knowledge_base.external_id:
|
||||||
|
external_id_list.append(str(knowledge_base.external_id))
|
||||||
|
else:
|
||||||
|
logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理知识库ID出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not external_id_list:
|
||||||
|
logger.error('没有有效的知识库external_id')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 创建metadata
|
||||||
|
metadata = {
|
||||||
|
'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
|
||||||
|
'dataset_id_list': [str(id) for id in dataset_ids],
|
||||||
|
'dataset_external_id_list': [str(id) for id in external_id_list],
|
||||||
|
'dataset_names': [kb.name for kb in knowledge_bases]
|
||||||
|
}
|
||||||
|
|
||||||
|
return knowledge_bases, metadata, external_id_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取会话信息时出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_knowledge_base_permission(self, kb, user, permission_type):
|
||||||
|
"""检查知识库权限"""
|
||||||
|
# 实现权限检查逻辑
|
||||||
|
return True # 临时返回 True,需要根据实际情况实现
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def create_question_record(self, conversation_id, question, knowledge_bases, metadata):
|
||||||
|
"""创建问题记录"""
|
||||||
|
try:
|
||||||
|
title = metadata.get('title', 'New chat')
|
||||||
|
|
||||||
|
# 创建用户问题记录
|
||||||
|
return ChatHistory.objects.create(
|
||||||
|
user=self.scope["user"],
|
||||||
|
knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库
|
||||||
|
conversation_id=str(conversation_id),
|
||||||
|
title=title,
|
||||||
|
role='user',
|
||||||
|
content=question,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建问题记录时出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata):
|
||||||
|
"""创建AI回答记录"""
|
||||||
|
try:
|
||||||
|
return ChatHistory.objects.create(
|
||||||
|
user=self.scope["user"],
|
||||||
|
knowledge_base=knowledge_bases[0],
|
||||||
|
conversation_id=str(conversation_id),
|
||||||
|
title=question_record.title,
|
||||||
|
parent_id=str(question_record.id),
|
||||||
|
role='assistant',
|
||||||
|
content="", # 初始内容为空
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建回答记录时出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata, knowledge_bases):
|
||||||
|
"""从外部API获取流式响应"""
|
||||||
|
try:
|
||||||
|
# 确保所有ID都是字符串
|
||||||
|
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
|
||||||
|
|
||||||
|
# 获取标题
|
||||||
|
title = answer_record.title or 'New chat'
|
||||||
|
|
||||||
|
# 异步收集完整内容,用于最后保存
|
||||||
|
full_content = ""
|
||||||
|
|
||||||
|
# 使用aiohttp进行异步HTTP请求
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# 第一步: 创建聊天会话
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.API_BASE_URL}/api/application/chat/open",
|
||||||
|
json={
|
||||||
|
"id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
|
||||||
|
"model_id": metadata.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
|
||||||
|
"dataset_id_list": dataset_external_ids,
|
||||||
|
"multiple_rounds_dialogue": False,
|
||||||
|
"dataset_setting": {
|
||||||
|
"top_n": 10, "similarity": "0.3",
|
||||||
|
"max_paragraph_char_number": 10000,
|
||||||
|
"search_mode": "blend",
|
||||||
|
"no_references_setting": {
|
||||||
|
"value": "{question}",
|
||||||
|
"status": "ai_questioning"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"model_setting": {
|
||||||
|
"prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}"
|
||||||
|
},
|
||||||
|
"problem_optimization": False
|
||||||
|
}
|
||||||
|
) as chat_response:
|
||||||
|
|
||||||
|
if chat_response.status != 200:
|
||||||
|
error_msg = f"外部API调用失败: {await chat_response.text()}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
await self.send_error(error_msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
chat_data = await chat_response.json()
|
||||||
|
if chat_data.get('code') != 200 or not chat_data.get('data'):
|
||||||
|
error_msg = f"外部API返回错误: {chat_data}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
await self.send_error(error_msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
chat_id = chat_data['data']
|
||||||
|
logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
|
||||||
|
|
||||||
|
# 第二步: 建立流式连接
|
||||||
|
message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
|
||||||
|
logger.info(f"开始流式请求: {message_url}")
|
||||||
|
|
||||||
|
# 创建流式请求
|
||||||
|
async with session.post(
|
||||||
|
url=message_url,
|
||||||
|
json={"message": question, "re_chat": False, "stream": True},
|
||||||
|
headers={"Content-Type": "application/json"}
|
||||||
|
) as message_request:
|
||||||
|
|
||||||
|
if message_request.status != 200:
|
||||||
|
error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
await self.send_error(error_msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建一个缓冲区以处理分段的数据
|
||||||
|
buffer = ""
|
||||||
|
|
||||||
|
# 读取并处理每个响应块
|
||||||
|
logger.info("开始处理流式响应")
|
||||||
|
async for chunk in message_request.content.iter_any():
|
||||||
|
chunk_str = chunk.decode('utf-8')
|
||||||
|
buffer += chunk_str
|
||||||
|
|
||||||
|
# 检查是否有完整的数据行
|
||||||
|
while '\n\n' in buffer:
|
||||||
|
parts = buffer.split('\n\n', 1)
|
||||||
|
line = parts[0]
|
||||||
|
buffer = parts[1]
|
||||||
|
|
||||||
|
if line.startswith('data: '):
|
||||||
|
try:
|
||||||
|
# 提取JSON数据
|
||||||
|
json_str = line[6:] # 去掉 "data: " 前缀
|
||||||
|
data = json.loads(json_str)
|
||||||
|
|
||||||
|
# 记录并处理部分响应
|
||||||
|
if 'content' in data:
|
||||||
|
content_part = data['content']
|
||||||
|
full_content += content_part
|
||||||
|
|
||||||
|
# 发送部分内容
|
||||||
|
await self.send_json({
|
||||||
|
'code': 200,
|
||||||
|
'message': 'partial',
|
||||||
|
'data': {
|
||||||
|
'id': str(answer_record.id),
|
||||||
|
'conversation_id': str(conversation_id),
|
||||||
|
'content': content_part,
|
||||||
|
'is_end': data.get('is_end', False)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理结束标记
|
||||||
|
if data.get('is_end', False):
|
||||||
|
logger.info("收到流式响应结束标记")
|
||||||
|
# 保存完整内容
|
||||||
|
await self.update_answer_content(answer_record.id, full_content.strip())
|
||||||
|
|
||||||
|
# 处理标题
|
||||||
|
title = await self.get_or_generate_title(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
full_content.strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送最终响应
|
||||||
|
await self.send_json({
|
||||||
|
'code': 200,
|
||||||
|
'message': '完成',
|
||||||
|
'data': {
|
||||||
|
'id': str(answer_record.id),
|
||||||
|
'conversation_id': str(conversation_id),
|
||||||
|
'title': title,
|
||||||
|
'dataset_id_list': metadata.get('dataset_id_list', []),
|
||||||
|
'dataset_names': metadata.get('dataset_names', []),
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': full_content.strip(),
|
||||||
|
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
|
'is_end': True
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析错误: {e}, 数据: {line}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式处理出错: {str(e)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
await self.send_error(str(e))
|
||||||
|
|
||||||
|
# 保存已收集的内容
|
||||||
|
if 'full_content' in locals() and full_content:
|
||||||
|
try:
|
||||||
|
await self.update_answer_content(answer_record.id, full_content.strip())
|
||||||
|
except Exception as save_error:
|
||||||
|
logger.error(f"保存部分内容失败: {str(save_error)}")
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def update_answer_content(self, answer_id, content):
|
||||||
|
"""更新回答内容"""
|
||||||
|
try:
|
||||||
|
answer_record = ChatHistory.objects.get(id=answer_id)
|
||||||
|
answer_record.content = content
|
||||||
|
answer_record.save()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新回答内容失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_or_generate_title(self, conversation_id, question, answer):
|
||||||
|
"""获取或生成对话标题"""
|
||||||
|
try:
|
||||||
|
# 先检查是否已有标题
|
||||||
|
current_title = ChatHistory.objects.filter(
|
||||||
|
conversation_id=str(conversation_id)
|
||||||
|
).exclude(
|
||||||
|
title__in=["New chat", "新对话", ""]
|
||||||
|
).values_list('title', flat=True).first()
|
||||||
|
|
||||||
|
if current_title:
|
||||||
|
return current_title
|
||||||
|
|
||||||
|
# 简单的标题生成逻辑 (可替换为调用DeepSeek API生成标题)
|
||||||
|
generated_title = question[:20] + "..." if len(question) > 20 else question
|
||||||
|
|
||||||
|
# 更新所有相关记录的标题
|
||||||
|
ChatHistory.objects.filter(
|
||||||
|
conversation_id=str(conversation_id)
|
||||||
|
).update(title=generated_title)
|
||||||
|
|
||||||
|
return generated_title
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取或生成标题失败: {str(e)}")
|
||||||
|
return "新对话"
|
||||||
|
|
||||||
|
async def send_json(self, content):
|
||||||
|
"""发送JSON格式的消息"""
|
||||||
|
await self.send(text_data=json.dumps(content))
|
||||||
|
|
||||||
|
async def send_error(self, message):
|
||||||
|
"""发送错误消息"""
|
||||||
|
await self.send_json({
|
||||||
|
'code': 500,
|
||||||
|
'message': message,
|
||||||
|
'data': {'is_end': True}
|
||||||
|
})
|
8
apps/chat/routing.py
Normal file
8
apps/chat/routing.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# apps/chat/routing.py
|
||||||
|
from django.urls import re_path
|
||||||
|
from apps.chat.consumers import ChatConsumer, ChatStreamConsumer
|
||||||
|
|
||||||
|
websocket_urlpatterns = [
|
||||||
|
re_path(r'ws/chat/$', ChatConsumer.as_asgi()),
|
||||||
|
re_path(r'ws/chat/stream/$', ChatStreamConsumer.as_asgi()),
|
||||||
|
]
|
46
apps/common/middlewares.py
Normal file
46
apps/common/middlewares.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from django.db import close_old_connections
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
from channels.middleware import BaseMiddleware
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_user_from_token(token_key):
|
||||||
|
try:
|
||||||
|
token = Token.objects.select_related('user').get(key=token_key)
|
||||||
|
return token.user
|
||||||
|
except Token.DoesNotExist:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户Token失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
class TokenAuthMiddleware(BaseMiddleware):
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
# 关闭之前的数据库连接
|
||||||
|
close_old_connections()
|
||||||
|
|
||||||
|
# 从查询字符串中提取token
|
||||||
|
query_string = scope.get('query_string', b'').decode()
|
||||||
|
query_params = parse_qs(query_string)
|
||||||
|
token_key = query_params.get('token', [''])[0]
|
||||||
|
|
||||||
|
if token_key:
|
||||||
|
user = await get_user_from_token(token_key)
|
||||||
|
if user:
|
||||||
|
scope['user'] = user
|
||||||
|
logger.info(f"WebSocket认证成功: 用户 {user.id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"WebSocket认证失败: 无效的Token {token_key}")
|
||||||
|
scope['user'] = None
|
||||||
|
else:
|
||||||
|
logger.warning("WebSocket连接未提供Token")
|
||||||
|
scope['user'] = None
|
||||||
|
|
||||||
|
return await super().__call__(scope, receive, send)
|
||||||
|
|
||||||
|
def TokenAuthMiddlewareStack(inner):
|
||||||
|
return TokenAuthMiddleware(inner)
|
@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from channels.layers import get_channel_layer
|
from channels.layers import get_channel_layer
|
||||||
from apps.notifications.models import Notification
|
from apps.message.models import Notification
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
72
apps/message/consumers.py
Normal file
72
apps/message/consumers.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# apps/message/consumers.py
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
|
from channels.db import database_sync_to_async
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class NotificationConsumer(AsyncWebsocketConsumer):
|
||||||
|
async def connect(self):
|
||||||
|
"""建立WebSocket连接"""
|
||||||
|
try:
|
||||||
|
# 从URL参数中获取token
|
||||||
|
query_string = self.scope.get('query_string', b'').decode()
|
||||||
|
query_params = parse_qs(query_string)
|
||||||
|
token_key = query_params.get('token', [''])[0]
|
||||||
|
|
||||||
|
if not token_key:
|
||||||
|
logger.warning("WebSocket连接尝试,但没有提供token")
|
||||||
|
await self.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证token
|
||||||
|
self.user = await self.get_user_from_token(token_key)
|
||||||
|
if not self.user:
|
||||||
|
logger.warning(f"WebSocket连接尝试,但token无效: {token_key}")
|
||||||
|
await self.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 为用户创建专属房间
|
||||||
|
self.room_name = f"notification_user_{self.user.id}"
|
||||||
|
await self.channel_layer.group_add(
|
||||||
|
self.room_name,
|
||||||
|
self.channel_name
|
||||||
|
)
|
||||||
|
await self.accept()
|
||||||
|
logger.info(f"用户 {self.user.username} WebSocket连接成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket连接错误: {str(e)}")
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_user_from_token(self, token_key):
|
||||||
|
try:
|
||||||
|
token = Token.objects.select_related('user').get(key=token_key)
|
||||||
|
return token.user
|
||||||
|
except Token.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def disconnect(self, close_code):
|
||||||
|
"""断开WebSocket连接"""
|
||||||
|
try:
|
||||||
|
if hasattr(self, 'room_name'):
|
||||||
|
await self.channel_layer.group_discard(
|
||||||
|
self.room_name,
|
||||||
|
self.channel_name
|
||||||
|
)
|
||||||
|
logger.info(f"用户 {self.user.username} 已断开连接,关闭代码: {close_code}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"断开连接时发生错误: {str(e)}")
|
||||||
|
|
||||||
|
async def notification(self, event):
|
||||||
|
"""处理并发送通知消息"""
|
||||||
|
try:
|
||||||
|
await self.send(text_data=json.dumps(event))
|
||||||
|
logger.info(f"已发送通知给用户 {self.user.username}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送通知消息时发生错误: {str(e)}")
|
||||||
|
|
@ -0,0 +1,11 @@
|
|||||||
|
# apps/message/routing.py
|
||||||
|
from django.urls import re_path
|
||||||
|
from apps.message.consumers import NotificationConsumer
|
||||||
|
from apps.chat.consumers import ChatStreamConsumer # 直接导入已有的ChatStreamConsumer
|
||||||
|
import logging
|
||||||
|
|
||||||
|
websocket_urlpatterns = [
|
||||||
|
re_path(r'^ws/notifications/$', NotificationConsumer.as_asgi()),
|
||||||
|
re_path(r'^ws/chat/stream/$', ChatStreamConsumer.as_asgi()),
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,18 @@
|
|||||||
|
# apps/message/serializers.py
|
||||||
|
from rest_framework import serializers
|
||||||
|
from apps.message.models import Notification
|
||||||
|
from apps.accounts.models import User
|
||||||
|
|
||||||
|
class NotificationSerializer(serializers.ModelSerializer):
|
||||||
|
sender = serializers.PrimaryKeyRelatedField(queryset=User.objects.all())
|
||||||
|
receiver = serializers.PrimaryKeyRelatedField(read_only=True)
|
||||||
|
created_at = serializers.DateTimeField(format='%Y-%m-%d %H:%M:%S', read_only=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Notification
|
||||||
|
fields = [
|
||||||
|
'id', 'sender', 'receiver', 'title', 'content', 'type',
|
||||||
|
'related_resource', 'is_read', 'created_at'
|
||||||
|
]
|
||||||
|
read_only_fields = ['id', 'receiver', 'created_at', 'is_read']
|
||||||
|
|
@ -0,0 +1,11 @@
|
|||||||
|
# apps/message/urls.py
|
||||||
|
from django.urls import path, include
|
||||||
|
from rest_framework.routers import DefaultRouter
|
||||||
|
from apps.message.views import NotificationViewSet
|
||||||
|
|
||||||
|
router = DefaultRouter()
|
||||||
|
router.register(r'', NotificationViewSet, basename='notification')
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path('', include(router.urls)),
|
||||||
|
]
|
@ -1,3 +1,51 @@
|
|||||||
from django.shortcuts import render
|
# apps/message/views.py
|
||||||
|
from rest_framework import viewsets, status
|
||||||
|
from rest_framework.permissions import IsAuthenticated
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.decorators import action
|
||||||
|
from apps.message.models import Notification
|
||||||
|
from apps.message.serializers import NotificationSerializer
|
||||||
|
|
||||||
|
class NotificationViewSet(viewsets.ModelViewSet):
|
||||||
|
"""通知视图集"""
|
||||||
|
queryset = Notification.objects.all()
|
||||||
|
serializer_class = NotificationSerializer
|
||||||
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
"""只返回用户自己的通知"""
|
||||||
|
return Notification.objects.filter(receiver=self.request.user)
|
||||||
|
|
||||||
|
@action(detail=True, methods=['post'])
|
||||||
|
def mark_as_read(self, request, pk=None):
|
||||||
|
"""标记通知为已读"""
|
||||||
|
notification = self.get_object()
|
||||||
|
notification.is_read = True
|
||||||
|
notification.save()
|
||||||
|
return Response({'status': 'marked as read'})
|
||||||
|
|
||||||
|
@action(detail=False, methods=['post'])
|
||||||
|
def mark_all_as_read(self, request):
|
||||||
|
"""标记所有通知为已读"""
|
||||||
|
self.get_queryset().update(is_read=True)
|
||||||
|
return Response({'status': 'all marked as read'})
|
||||||
|
|
||||||
|
@action(detail=False, methods=['get'])
|
||||||
|
def unread_count(self, request):
|
||||||
|
"""获取未读通知数量"""
|
||||||
|
count = self.get_queryset().filter(is_read=False).count()
|
||||||
|
return Response({'unread_count': count})
|
||||||
|
|
||||||
|
@action(detail=False, methods=['get'])
|
||||||
|
def latest(self, request):
|
||||||
|
"""获取最新通知"""
|
||||||
|
notifications = self.get_queryset().filter(
|
||||||
|
is_read=False
|
||||||
|
).order_by('-created_at')[:5]
|
||||||
|
serializer = self.get_serializer(notifications, many=True)
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
def perform_create(self, serializer):
|
||||||
|
"""创建通知时自动设置发送者"""
|
||||||
|
serializer.save(sender=self.request.user)
|
||||||
|
|
||||||
# Create your views here.
|
|
||||||
|
@ -11,15 +11,16 @@ import os
|
|||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||||
from channels.auth import AuthMiddlewareStack
|
from channels.auth import AuthMiddlewareStack
|
||||||
import apps.message.routing # WebSocket 路由
|
from apps.message.routing import websocket_urlpatterns # WebSocket 路由
|
||||||
|
|
||||||
|
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'daren_project.settings')
|
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'daren_project.settings')
|
||||||
|
|
||||||
application = ProtocolTypeRouter({
|
application = ProtocolTypeRouter({
|
||||||
'http': get_asgi_application(),
|
"http": get_asgi_application(),
|
||||||
'websocket': AuthMiddlewareStack(
|
"websocket": AuthMiddlewareStack(
|
||||||
URLRouter(
|
URLRouter(
|
||||||
apps.message.routing.websocket_urlpatterns
|
websocket_urlpatterns
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
@ -155,10 +155,7 @@ REST_FRAMEWORK = {
|
|||||||
ASGI_APPLICATION = 'daren_project.asgi.application'
|
ASGI_APPLICATION = 'daren_project.asgi.application'
|
||||||
CHANNEL_LAYERS = {
|
CHANNEL_LAYERS = {
|
||||||
'default': {
|
'default': {
|
||||||
'BACKEND': 'channels_redis.core.RedisChannelLayer',
|
'BACKEND': 'channels.layers.InMemoryChannelLayer', # 使用内存通道层,无需 Redis
|
||||||
'CONFIG': {
|
|
||||||
'hosts': [('127.0.0.1', 6379)],
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ urlpatterns = [
|
|||||||
path('api/knowledge-bases/', include('apps.knowledge_base.urls')),
|
path('api/knowledge-bases/', include('apps.knowledge_base.urls')),
|
||||||
path('api/chat-history/', include('apps.chat.urls')),
|
path('api/chat-history/', include('apps.chat.urls')),
|
||||||
path('api/permissions/', include('apps.permissions.urls')),
|
path('api/permissions/', include('apps.permissions.urls')),
|
||||||
# path('api/message/', include('apps.message.urls')),
|
path('api/message/', include('apps.message.urls')),
|
||||||
# path('api/gmail/', include('apps.gmail.urls')),
|
# path('api/gmail/', include('apps.gmail.urls')),
|
||||||
# path('api/feishu/', include('apps.feishu.urls')),
|
# path('api/feishu/', include('apps.feishu.urls')),
|
||||||
]
|
]
|
BIN
requirements.txt
Normal file
BIN
requirements.txt
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user