daren_project/user_management/views.py
2025-04-29 10:22:57 +08:00

7447 lines
298 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from rest_framework import viewsets, status
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.permissions import IsAuthenticated, AllowAny, IsAdminUser
from rest_framework.response import Response
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError, NotFound
from rest_framework.authentication import TokenAuthentication
from django.utils import timezone
from django.db import connection
from django.db.models import Q, Max, Count, F
from datetime import timedelta, datetime
import mysql.connector
from django.contrib.auth import get_user_model, authenticate, login, logout
from channels.layers import get_channel_layer
from asgiref.sync import async_to_sync
from rest_framework.authtoken.models import Token
import requests
import json
from django.db import transaction
from django.core.exceptions import ObjectDoesNotExist
import sys
import random
import string
import time
import logging
import os
from rest_framework.test import APIRequestFactory
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.fields import GenericForeignKey
from django.http import Http404, HttpResponse, StreamingHttpResponse, FileResponse
from django.db import IntegrityError
from channels.exceptions import ChannelFull
from django.conf import settings
from django.shortcuts import get_object_or_404
from django.db import models
from rest_framework.views import APIView
from django.core.validators import validate_email
# from django.core.exceptions import ValidationError
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
import uuid
from rest_framework import serializers
import traceback
import requests
import json
import threading
import re
# 添加模型导入
from .models import (
User,
Data, # 替换原来的 AdminData, LeaderData, MemberData
Permission, # 替换原来的 DataPermission, TablePermission
ChatHistory,
KnowledgeBase,
Notification,
KnowledgeBasePermission as KBPermissionModel,
KnowledgeBaseDocument,
GmailCredential,
GmailTalentMapping,
GmailAttachment,
UserProfile
)
from .serializers import (
UserSerializer,
DataSerializer, # 需要更新
PermissionSerializer, # 需要更新
ChatHistorySerializer,
KnowledgeBaseSerializer,
KnowledgePermissionSerializer, # 添加这个导入
NotificationSerializer
)
# 导入自定义权限类
from .permissions import ResourceCRUDPermission, PermissionRequestPermission, DataPermission, KnowledgeBasePermission as KBPermissionClass
from .exceptions import ExternalAPIError
# 获取正确的用户模型
User = get_user_model()
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[
logging.StreamHandler() # 输出到控制台
]
)
class KnowledgeBasePermissionMixin:
"""知识库权限管理混入类"""
def _can_read(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
"""检查读取权限"""
try:
# 1. 检查显式权限表
if knowledge_base_id:
permission = KBPermissionModel.objects.filter(
knowledge_base_id=knowledge_base_id,
user=user,
can_read=True,
status='active'
).first()
if permission:
return True
# 2. 检查角色权限
# 私有知识库
if type == 'private':
return str(user.id) == str(creator_id)
# 成员级知识库
if type == 'member':
return user.department == department
# 部门级知识库
if type == 'leader':
return (user.department == department and
user.role in ['leader', 'admin'])
# 管理级知识库
if type == 'admin':
return True # 所有用户都可以读取
return False
except Exception as e:
logger.error(f"检查读取权限时出错: {str(e)}")
return False
def _can_edit(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
"""检查编辑权限"""
try:
# 1. 检查显式权限表
if knowledge_base_id:
permission = KBPermissionModel.objects.filter(
knowledge_base_id=knowledge_base_id,
user=user,
can_edit=True,
status='active'
).first()
if permission:
return True
# 2. 检查角色权限
# 私有知识库
if type == 'private':
return str(user.id) == str(creator_id)
# 成员级知识库
if type == 'member':
return (user.department == department and
user.role in ['leader', 'admin'])
# 部门级知识库
if type == 'leader':
return (user.department == department and
user.role in ['leader', 'admin'])
# 管理级知识库
if type == 'admin':
return True # 所有用户都可以编辑
return False
except Exception as e:
logger.error(f"检查编辑权限时出错: {str(e)}")
return False
def _can_delete(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
"""检查删除权限"""
try:
# 1. 检查显式权限表
if knowledge_base_id:
permission = KBPermissionModel.objects.filter(
knowledge_base_id=knowledge_base_id,
user=user,
can_delete=True,
status='active'
).first()
if permission:
return True
# 2. 检查角色权限
# 私有知识库
if type == 'private':
return str(user.id) == str(creator_id)
# 成员级知识库
if type == 'member':
return (user.department == department and
user.role == 'admin')
# 部门级知识库
if type == 'leader':
return (user.department == department and
user.role == 'admin')
# 管理级知识库
if type == 'admin':
return True # 所有用户都可以删除
return False
except Exception as e:
logger.error(f"检查删除权限时出错: {str(e)}")
return False
def check_knowledge_base_permission(self, knowledge_base, user, required_permission='read'):
"""统一的知识库权限检查方法"""
if not knowledge_base:
return False
# 1. 首先检查显式权限表
try:
# 检查是否存在显式权限记录
permission = KBPermissionModel.objects.filter(
knowledge_base_id=knowledge_base.id,
user=user,
status='active'
).first()
if permission:
# 根据请求的权限类型返回对应的权限值
if required_permission == 'read':
return permission.can_read
elif required_permission == 'edit':
return permission.can_edit
elif required_permission == 'delete':
return permission.can_delete
except Exception as e:
logger.error(f"检查显式权限时出错: {str(e)}")
# 2. 如果没有显式权限记录或出错,回退到隐式权限逻辑
permission_method = {
'read': self._can_read,
'edit': self._can_edit,
'delete': self._can_delete
}.get(required_permission)
if not permission_method:
return False
return permission_method(
type=knowledge_base.type,
user=user,
department=knowledge_base.department,
group=knowledge_base.group,
creator_id=knowledge_base.user_id,
knowledge_base_id=knowledge_base.id
)
class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
permission_classes = [IsAuthenticated]
queryset = ChatHistory.objects.all()
def get_queryset(self):
"""确保用户只能看到自己的未删除的聊天记录以及有权限的知识库关联的聊天记录"""
user = self.request.user
# 当前用户的聊天记录
user_records = ChatHistory.objects.filter(
user=user,
is_deleted=False
)
# 获取用户有权限的知识库ID列表
accessible_kb_ids = []
for kb in KnowledgeBase.objects.all():
if self.check_knowledge_base_permission(kb, user, 'read'):
accessible_kb_ids.append(kb.id)
# 其他用户创建的、但当前用户有权限访问的知识库的聊天记录
others_records = ChatHistory.objects.filter(
knowledge_base_id__in=accessible_kb_ids,
is_deleted=False
).exclude(user=user) # 排除用户自己的记录,避免重复
# 合并两个查询集
combined_queryset = user_records | others_records
return combined_queryset
def list(self, request):
"""获取对话列表概览"""
try:
# 获取查询参数
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 获取所有对话的概览
latest_chats = self.get_queryset().values(
'conversation_id'
).annotate(
latest_id=Max('id'),
message_count=Count('id'),
last_message=Max('created_at')
).order_by('-last_message')
# 计算分页
total = latest_chats.count()
start = (page - 1) * page_size
end = start + page_size
chats = latest_chats[start:end]
results = []
for chat in chats:
# 获取最新消息记录
latest_record = ChatHistory.objects.get(id=chat['latest_id'])
# 从metadata中获取完整的知识库信息
dataset_info = []
if latest_record.metadata:
dataset_id_list = latest_record.metadata.get('dataset_id_list', [])
dataset_names = latest_record.metadata.get('dataset_names', [])
# 如果有知识库ID列表
if dataset_id_list:
# 如果同时有名称列表且长度匹配
if dataset_names and len(dataset_names) == len(dataset_id_list):
dataset_info = [{
'id': str(id),
'name': name
} for id, name in zip(dataset_id_list, dataset_names)]
else:
# 如果没有名称列表则只返回ID
datasets = KnowledgeBase.objects.filter(id__in=dataset_id_list)
dataset_info = [{
'id': str(ds.id),
'name': ds.name
} for ds in datasets]
results.append({
'conversation_id': chat['conversation_id'],
'message_count': chat['message_count'],
'last_message': latest_record.content,
'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'),
'dataset_id_list': [ds['id'] for ds in dataset_info], # 添加完整的知识库ID列表
'datasets': dataset_info # 包含ID和名称的完整信息
})
return Response({
'code': 200,
'message': '获取成功',
'data': {
'total': total,
'page': page,
'page_size': page_size,
'results': results
}
})
except Exception as e:
logger.error(f"获取聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def conversation_detail(self, request):
"""获取特定对话的详细信息"""
try:
conversation_id = request.query_params.get('conversation_id')
if not conversation_id:
return Response({
'code': 400,
'message': '缺少conversation_id参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取对话历史,确保按时间顺序排序
messages = self.get_queryset().filter(
conversation_id=conversation_id
).order_by('created_at')
if not messages.exists():
return Response({
'code': 404,
'message': '对话不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 获取知识库信息
first_message = messages.first()
dataset_info = []
if first_message and first_message.metadata:
if 'dataset_id_list' in first_message.metadata:
datasets = KnowledgeBase.objects.filter(
id__in=first_message.metadata['dataset_id_list']
)
# 过滤出用户有权限访问的知识库
accessible_datasets = [
ds for ds in datasets
if self.check_knowledge_base_permission(ds, request.user, 'read')
]
dataset_info = [{
'id': str(ds.id),
'name': ds.name,
'type': ds.type
} for ds in accessible_datasets]
# 构建消息列表包含parent_id信息
message_list = []
for msg in messages:
message_data = {
'id': str(msg.id),
'parent_id': msg.parent_id, # 添加parent_id
'role': msg.role,
'content': msg.content,
'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'metadata': msg.metadata # 添加metadata
}
message_list.append(message_data)
return Response({
'code': 200,
'message': '获取成功',
'data': {
'conversation_id': conversation_id,
'datasets': dataset_info,
'messages': message_list
}
})
except Exception as e:
logger.error(f"获取对话详情失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取对话详情失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def available_datasets(self, request):
"""获取用户可访问的知识库列表"""
try:
user = request.user
all_datasets = KnowledgeBase.objects.all()
# 使用统一的权限检查方法
accessible_datasets = [
dataset for dataset in all_datasets
if self.check_knowledge_base_permission(dataset, user, 'read')
]
return Response({
'code': 200,
'message': '获取成功',
'data': [{
'id': str(ds.id),
'name': ds.name,
'type': ds.type,
'department': ds.department,
'description': ds.desc
} for ds in accessible_datasets]
})
except Exception as e:
logger.error(f"获取可用知识库列表失败: {str(e)}")
return Response({
'code': 500,
'message': f'获取可用知识库列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['post'])
def create_conversation(self, request):
"""创建会话 - 先选择知识库创建会话ID不发送问题"""
try:
data = request.data
# 检查知识库ID支持dataset_id或dataset_id_list格式
dataset_ids = []
if 'dataset_id' in data:
dataset_id = data['dataset_id']
# 直接使用标准UUID格式
dataset_ids.append(str(dataset_id))
elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)):
# 处理可能的字符串格式
if isinstance(data['dataset_id_list'], str):
try:
# 尝试解析JSON字符串
dataset_list = json.loads(data['dataset_id_list'])
if isinstance(dataset_list, list):
dataset_ids = [str(id) for id in dataset_list]
except json.JSONDecodeError:
# 如果解析失败可能是单个ID
dataset_ids = [str(data['dataset_id_list'])]
else:
# 如果已经是列表直接使用标准UUID格式
dataset_ids = [str(id) for id in data['dataset_id_list']]
else:
return Response({
'code': 400,
'message': '缺少必填字段: dataset_id 或 dataset_id_list',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if not dataset_ids:
return Response({
'code': 400,
'message': '至少需要提供一个知识库ID',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证所有知识库
user = request.user
knowledge_bases = [] # 存储所有知识库对象
for kb_id in dataset_ids:
try:
knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
if not knowledge_base:
return Response({
'code': 404,
'message': f'知识库不存在: {kb_id}',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
knowledge_bases.append(knowledge_base)
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
return Response({
'code': 403,
'message': f'无权访问知识库: {knowledge_base.name}',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
except Exception as e:
return Response({
'code': 400,
'message': f'处理知识库ID出错: {str(e)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建一个新的会话ID
conversation_id = str(uuid.uuid4())
logger.info(f"创建新的会话ID: {conversation_id}")
# 准备metadata (仍然保存知识库名称用于内部处理)
metadata = {
'dataset_id_list': [str(id) for id in dataset_ids],
'dataset_names': [kb.name for kb in knowledge_bases]
}
return Response({
'code': 200,
'message': '会话创建成功',
'data': {
'conversation_id': conversation_id,
'dataset_id_list': metadata['dataset_id_list']
}
})
except Exception as e:
logger.error(f"创建会话失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'创建会话失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def create(self, request):
"""创建聊天记录"""
try:
data = request.data
# 检查必填字段
if 'question' not in data:
return Response({
'code': 400,
'message': '缺少必填字段: question',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if 'conversation_id' not in data:
return Response({
'code': 400,
'message': '缺少必填字段: conversation_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
conversation_id = data['conversation_id']
# 查找该会话ID下的历史记录获取知识库信息
existing_records = ChatHistory.objects.filter(
conversation_id=conversation_id
).order_by('created_at')
# 如果有历史记录使用第一条记录的metadata
if existing_records.exists():
first_record = existing_records.first()
metadata = first_record.metadata or {}
# 获取知识库信息
dataset_ids = metadata.get('dataset_id_list', [])
external_id_list = metadata.get('dataset_external_id_list', [])
# 验证知识库是否存在且用户有权限
knowledge_bases = []
if not dataset_ids:
return Response({
'code': 400,
'message': '找不到会话关联的知识库信息',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
for kb_id in dataset_ids:
try:
kb = KnowledgeBase.objects.get(id=kb_id)
if not self.check_knowledge_base_permission(kb, request.user, 'read'):
return Response({
'code': 403,
'message': f'无权访问知识库: {kb.name}',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
knowledge_bases.append(kb)
except KnowledgeBase.DoesNotExist:
return Response({
'code': 404,
'message': f'知识库不存在: {kb_id}',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
if not external_id_list or not knowledge_bases:
return Response({
'code': 400,
'message': '会话关联的知识库信息不完整',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
else:
# 如果是新会话的第一条记录需要提供知识库ID
dataset_ids = []
if 'dataset_id' in data:
dataset_ids.append(str(data['dataset_id']))
elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)):
if isinstance(data['dataset_id_list'], str):
try:
dataset_list = json.loads(data['dataset_id_list'])
if isinstance(dataset_list, list):
dataset_ids = [str(id) for id in dataset_list]
else:
dataset_ids = [str(data['dataset_id_list'])]
except json.JSONDecodeError:
dataset_ids = [str(data['dataset_id_list'])]
else:
dataset_ids = [str(id) for id in data['dataset_id_list']]
if not dataset_ids:
return Response({
'code': 400,
'message': '新会话需要提供知识库ID',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证所有知识库并收集external_ids
external_id_list = []
knowledge_bases = []
for kb_id in dataset_ids:
try:
knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
if not knowledge_base:
return Response({
'code': 404,
'message': f'知识库不存在: {kb_id}',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
knowledge_bases.append(knowledge_base)
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'):
return Response({
'code': 403,
'message': f'无权访问知识库: {knowledge_base.name}',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 添加知识库的external_id到列表
if knowledge_base.external_id:
external_id_list.append(str(knowledge_base.external_id))
else:
logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")
except Exception as e:
return Response({
'code': 400,
'message': f'处理知识库ID出错: {str(e)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if not external_id_list:
return Response({
'code': 400,
'message': '没有有效的知识库external_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建metadata
metadata = {
'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
'dataset_id_list': [str(id) for id in dataset_ids],
'dataset_external_id_list': [str(id) for id in external_id_list],
'dataset_names': [kb.name for kb in knowledge_bases]
}
# 检查是否有自定义标题
title = data.get('title', 'New chat')
# 创建用户问题记录
question_record = ChatHistory.objects.create(
user=request.user,
knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库
conversation_id=str(conversation_id),
title=title, # 设置标题
role='user',
content=data['question'],
metadata=metadata
)
# 检查是否需要流式输出
use_stream = data.get('stream', True)
if use_stream:
# 创建流式响应
response = StreamingHttpResponse(
self._stream_answer_from_external_api(
conversation_id=str(conversation_id),
question_record=question_record,
dataset_external_id_list=external_id_list,
knowledge_bases=knowledge_bases,
question=data['question'],
metadata=metadata
),
content_type='text/event-stream',
status=status.HTTP_201_CREATED # 修改状态码为201
)
# 添加禁用缓存的头部
response['Cache-Control'] = 'no-cache, no-store'
response['Connection'] = 'keep-alive'
return response
else:
# 使用非流式输出
logger.info("使用非流式输出模式")
# 调用同步 API 获取回答
answer = self._get_answer_from_external_api(external_id_list, data['question'])
if answer is None:
return Response({
'code': 500,
'message': '获取回答失败',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 创建 AI 回答记录
answer_record = ChatHistory.objects.create(
user=request.user,
knowledge_base=knowledge_bases[0],
conversation_id=str(conversation_id),
title=title, # 设置标题
parent_id=str(question_record.id),
role='assistant',
content=answer,
metadata=metadata
)
# 如果是新会话的第一条消息,并且没有自定义标题,则自动生成标题
should_generate_title = not existing_records.exists() and (not title or title == 'New chat')
if should_generate_title:
try:
generated_title = self._generate_conversation_title_from_deepseek(
data['question'],
answer
)
if generated_title:
# 更新所有相关记录的标题
ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).update(title=generated_title)
title = generated_title
except Exception as e:
logger.error(f"自动生成标题失败: {str(e)}")
# 继续执行,不影响主流程
return Response({
'code': 200, # 修改状态码为201
'message': '成功',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加标题字段
'dataset_id_list': metadata.get('dataset_id_list', []),
'dataset_names': metadata.get('dataset_names', []),
'role': 'assistant',
'content': answer,
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
}
}, status=status.HTTP_200_CREATED) # 修改状态码为201
except Exception as e:
logger.error(f"创建聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'创建聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _stream_answer_from_external_api(self, conversation_id, question_record, dataset_external_id_list, knowledge_bases, question, metadata):
"""流式获取AI回答并实时返回 - 优化版本"""
try:
# 确保所有ID都是字符串
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
# 获取标题
title = question_record.title or 'New chat'
# 创建AI回答记录对象稍后更新内容
answer_record = ChatHistory.objects.create(
user=question_record.user,
knowledge_base=knowledge_bases[0],
conversation_id=str(conversation_id),
title=title, # 设置标题
parent_id=str(question_record.id),
role='assistant',
content="", # 初始内容为空
metadata=metadata
)
# 发送初始响应告知客户端开始流式传输
yield f"data: {json.dumps({'code': 200, 'message': '开始流式传输', 'data': {'id': str(answer_record.id), 'conversation_id': str(conversation_id), 'content': '', 'is_end': False}})}\n\n"
# 异步收集完整内容,用于最后保存
full_content = ""
# 打开与外部API的连接
logger.info(f"开始调用外部API知识库ID列表: {dataset_external_ids}")
try:
# 第一步: 创建聊天会话
chat_response = requests.post(
url=f"{settings.API_BASE_URL}/api/application/chat/open",
json={
"id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
"model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006",
"dataset_id_list": dataset_external_ids,
"multiple_rounds_dialogue": False,
"dataset_setting": {
"top_n": 10, "similarity": "0.3",
"max_paragraph_char_number": 10000,
"search_mode": "blend",
"no_references_setting": {
"value": "{question}",
"status": "ai_questioning"
}
},
"model_setting": {
"prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
},
"problem_optimization": False
},
headers={"Content-Type": "application/json"},
)
if chat_response.status_code != 200:
error_msg = f"外部API调用失败: {chat_response.text}"
logger.error(error_msg)
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
return
chat_data = chat_response.json()
if chat_data.get('code') != 200 or not chat_data.get('data'):
error_msg = f"外部API返回错误: {chat_data}"
logger.error(error_msg)
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
return
chat_id = chat_data['data']
logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
# 第二步: 建立流式连接
message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
logger.info(f"开始流式请求: {message_url}")
# 创建流式请求
message_request = requests.post(
url=message_url,
json={"message": question, "re_chat": False, "stream": True},
headers={"Content-Type": "application/json"},
stream=True, # 启用流式传输
)
if message_request.status_code != 200:
error_msg = f"外部API聊天消息调用失败: {message_request.status_code}, {message_request.text}"
logger.error(error_msg)
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
return
# 创建一个缓冲区以处理分段的数据
buffer = ""
# 读取并处理每个响应块
logger.info("开始处理流式响应")
for chunk in message_request.iter_content(chunk_size=1):
if not chunk:
continue
# 解码字节为字符串
chunk_str = chunk.decode('utf-8')
buffer += chunk_str
# 检查是否有完整的数据行
if '\n\n' in buffer:
lines = buffer.split('\n\n')
# 除了最后一行,其他都是完整的
for line in lines[:-1]:
# 处理完整的数据行
if line.startswith('data: '):
try:
# 提取JSON数据
json_str = line[6:] # 去掉 "data: " 前缀
data = json.loads(json_str)
# 记录并处理部分响应
if 'content' in data:
content_part = data['content']
full_content += content_part
# 构建响应数据
response_data = {
'code': 200, # 修改状态码为201
'message': 'partial',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加标题字段
'content': content_part,
'is_end': data.get('is_end', False)
}
}
# 立即发送每个部分到客户端
yield f"data: {json.dumps(response_data)}\n\n"
# 处理结束标记
if data.get('is_end', False):
logger.info("收到流式响应结束标记")
# 异步保存完整内容
answer_record.content = full_content.strip()
answer_record.save()
# 先检查当前conversation_id是否已有有效标题
current_title = ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).exclude(
title__in=["New chat", "新对话", ""]
).values_list('title', flat=True).first()
# 如果已有有效标题,则复用
if current_title:
title = current_title
logger.info(f"复用已有标题: {title}")
else:
# 没有有效标题时,直接基于当前问题和回答生成标题
try:
# 直接使用当前的问题和完整的AI回答来生成标题
generated_title = self._generate_conversation_title_from_deepseek(
question, full_content.strip()
)
if generated_title:
# 更新所有相关记录的标题
ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).update(title=generated_title)
title = generated_title
logger.info(f"成功生成标题: {title}")
else:
title = "新对话" # 如果生成失败,使用默认标题
logger.warning("生成标题失败,使用默认标题")
except Exception as e:
logger.error(f"自动生成标题失败: {str(e)}")
title = "新对话" # 如果出错,使用默认标题
# 发送完整内容的最终响应
final_response = {
'code': 200, # 修改状态码为201
'message': '完成',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加生成的标题
'dataset_id_list': metadata.get('dataset_id_list', []),
'dataset_names': metadata.get('dataset_names', []),
'role': 'assistant',
'content': full_content.strip(),
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'is_end': True
}
}
yield f"data: {json.dumps(final_response)}\n\n"
return # 结束生成器
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {e}, 数据: {line}")
# 继续处理,跳过此行
# 保留最后一个可能不完整的行
buffer = lines[-1]
# 处理最后可能剩余的缓冲数据
if buffer:
logger.info(f"处理剩余缓冲数据: {buffer}")
if buffer.startswith('data: '):
try:
json_str = buffer[6:] # 去掉 "data: " 前缀
data = json.loads(json_str)
if 'content' in data:
content_part = data['content']
full_content += content_part
response_data = {
'code': 200, # 修改状态码为201
'message': 'partial',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id), # 添加标题字段
'content': content_part,
'is_end': data.get('is_end', False)
}
}
yield f"data: {json.dumps(response_data)}\n\n"
except json.JSONDecodeError:
logger.error(f"处理剩余数据时JSON解析错误: {buffer}")
# 确保在流结束时保存内容到数据库
if full_content:
answer_record.content = full_content.strip()
answer_record.save()
logger.info(f"流结束,保存完整内容到数据库: {len(full_content)} 字符")
except requests.exceptions.RequestException as e:
logger.error(f"请求外部API时发生错误: {str(e)}")
yield f"data: {json.dumps({'code': 500, 'message': f'请求外部API时发生错误: {str(e)}', 'data': {'is_end': True}})}\n\n"
except Exception as e:
logger.error(f"流式处理出错: {str(e)}")
logger.error(traceback.format_exc())
yield f"data: {json.dumps({'code': 500, 'message': f'流式处理出错: {str(e)}', 'data': {'is_end': True}})}\n\n"
# 尝试保存已收集的内容
if 'full_content' in locals() and full_content:
try:
answer_record.content = full_content.strip()
answer_record.save()
except Exception as save_error:
logger.error(f"保存部分内容失败: {str(save_error)}")
def _get_answer_from_external_api(self, dataset_external_id_list, question):
"""调用外部API获取AI回答非流式版本"""
try:
# 确保所有ID都是字符串
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
logger.info(f"准备调用外部API非流式模式知识库ID列表: {dataset_external_ids}")
# 第一个API调用创建聊天
chat_request_data = {
"id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
"model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006",
"dataset_id_list": dataset_external_ids,
"multiple_rounds_dialogue": False,
"dataset_setting": {
"top_n": 10,
"similarity": "0.3",
"max_paragraph_char_number": 10000,
"search_mode": "blend",
"no_references_setting": {
"value": "{question}",
"status": "ai_questioning"
}
},
"model_setting": {
"prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
},
"problem_optimization": False
}
logger.info(f"发送创建聊天请求:{settings.API_BASE_URL}/api/application/chat/open")
try:
# 测试JSON序列化提前捕获可能的错误
json_data = json.dumps(chat_request_data)
logger.debug(f"请求数据序列化成功,长度: {len(json_data)}")
except TypeError as e:
logger.error(f"JSON序列化失败: {str(e)}")
return None
chat_response = requests.post(
url=f"{settings.API_BASE_URL}/api/application/chat/open",
json=chat_request_data,
headers={"Content-Type": "application/json"},
)
logger.info(f"API响应状态码: {chat_response.status_code}")
if chat_response.status_code != 200:
logger.error(f"外部API调用失败: {chat_response.text}")
return None
chat_data = chat_response.json()
logger.debug(f"API响应数据: {chat_data}")
if chat_data.get('code') != 200 or not chat_data.get('data'):
logger.error(f"外部API返回错误: {chat_data}")
return None
chat_id = chat_data['data']
logger.info(f"聊天创建成功chat_id: {chat_id}")
# 第二个API调用发送消息
message_request_data = {
"message": question,
"re_chat": False,
"stream": False # 设置为非流式
}
logger.info(f"发送聊天消息请求(非流式): {settings.API_BASE_URL}/api/application/chat_message/{chat_id}")
message_response = requests.post(
url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}",
json=message_request_data,
headers={"Content-Type": "application/json"},
)
if message_response.status_code != 200:
logger.error(f"外部API聊天消息调用失败: {message_response.status_code}, {message_response.text}")
return None
# 处理非流式响应
try:
response_data = message_response.json()
logger.debug(f"非流式响应数据: {response_data}")
if response_data.get('code') != 200 or 'data' not in response_data:
logger.error(f"外部API返回错误: {response_data}")
return None
# 提取回答内容
answer_content = response_data.get('data', {}).get('content', '')
if not answer_content:
logger.warning("API返回的回答内容为空")
return "无法获取回答内容"
return answer_content
except json.JSONDecodeError as e:
logger.error(f"解析API响应JSON失败: {str(e)}")
return None
except Exception as e:
logger.error(f"处理API响应失败: {str(e)}")
logger.error(traceback.format_exc())
return None
except Exception as e:
logger.error(f"调用外部API获取回答失败: {str(e)}")
logger.error(traceback.format_exc())
return None
def update(self, request, pk=None):
"""更新聊天记录"""
try:
record = self.get_queryset().filter(id=pk).first()
if not record:
return Response({
'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
data = request.data
updateable_fields = ['content', 'metadata']
if 'content' in data:
record.content = data['content']
if 'metadata' in data:
current_metadata = record.metadata or {}
current_metadata.update(data['metadata'])
record.metadata = current_metadata
record.save()
return Response({
'code': 200,
'message': '更新成功',
'data': {
'id': record.id,
'conversation_id': record.conversation_id,
'role': record.role,
'content': record.content,
'metadata': record.metadata,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
})
except Exception as e:
logger.error(f"更新聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'更新聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def destroy(self, request, pk=None):
"""删除聊天记录(软删除)"""
try:
record = self.get_queryset().filter(id=pk).first()
if not record:
return Response({
'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
record.soft_delete()
return Response({
'code': 200,
'message': '删除成功',
'data': None
})
except Exception as e:
logger.error(f"删除聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'删除聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def search(self, request):
"""搜索聊天记录"""
try:
# 获取查询参数
keyword = request.query_params.get('keyword', '').strip()
dataset_id = request.query_params.get('dataset_id')
start_date = request.query_params.get('start_date')
end_date = request.query_params.get('end_date')
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 基础查询
query = self.get_queryset()
# 添加过滤条件
if keyword:
query = query.filter(
Q(content__icontains=keyword) |
Q(knowledge_base__name__icontains=keyword)
)
if dataset_id:
# 检查知识库权限
knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first()
if knowledge_base and not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'):
return Response({
'code': 403,
'message': '无权访问该知识库',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
query = query.filter(knowledge_base__id=dataset_id)
if start_date:
query = query.filter(created_at__gte=start_date)
if end_date:
query = query.filter(created_at__lte=end_date)
# 计算分页
total = query.count()
start = (page - 1) * page_size
end = start + page_size
# 获取分页数据
records = query.order_by('-created_at')[start:end]
# 序列化数据
results = []
for record in records:
result = {
'id': record.id,
'conversation_id': record.conversation_id,
'dataset_id': str(record.knowledge_base.id),
'dataset_name': record.knowledge_base.name,
'role': record.role,
'content': record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'metadata': record.metadata
}
if keyword:
result['highlights'] = {
'content': self._highlight_keyword(record.content, keyword)
}
results.append(result)
return Response({
'code': 200,
'message': '搜索成功',
'data': {
'total': total,
'page': page,
'page_size': page_size,
'results': results
}
})
except Exception as e:
logger.error(f"搜索聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'搜索失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def export(self, request):
"""导出聊天记录为Excel文件"""
try:
# 获取查询参数
conversation_id = request.query_params.get('conversation_id')
dataset_id = request.query_params.get('dataset_id')
history_days = request.query_params.get('history_days', '7') # 默认导出最近7天
# 至少需要一个筛选条件
if not conversation_id and not dataset_id:
return Response({
'code': 400,
'message': '需要提供conversation_id或dataset_id参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证权限
user = request.user
if dataset_id:
knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first()
if not knowledge_base:
return Response({
'code': 404,
'message': '知识库不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
return Response({
'code': 403,
'message': '无权访问该知识库',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 查询确认有聊天记录存在
query = self.get_queryset()
if conversation_id:
records = query.filter(conversation_id=conversation_id)
elif dataset_id:
records = query.filter(knowledge_base__id=dataset_id)
if not records.exists():
return Response({
'code': 404,
'message': '未找到相关对话记录',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 调用外部API导出Excel文件 - 使用GET请求
application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006" # 固定值
export_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat/export?history_day={history_days}"
logger.info(f"发送导出请求:{export_url}")
export_response = requests.get(
url=export_url,
stream=True # 使用流式传输处理大文件
)
# 检查响应状态
if export_response.status_code != 200:
logger.error(f"导出API调用失败: {export_response.status_code}, {export_response.text}")
return Response({
'code': 500,
'message': '导出失败,外部服务返回错误',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 创建响应对象并设置文件下载头
response = HttpResponse(
content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
)
response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
# 将API响应内容写入响应对象
for chunk in export_response.iter_content(chunk_size=8192):
if chunk:
response.write(chunk)
logger.info("导出成功完成")
return response
except Exception as e:
logger.error(f"导出聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'导出聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def chat_list(self, request):
"""获取对话列表"""
try:
# 获取查询参数
history_days = request.query_params.get('history_days', '7') # 默认7天
# 构建API请求
application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006"
api_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat"
# 添加查询参数
params = {
'history_day': history_days
}
logger.info(f"发送获取对话列表请求:{api_url}")
# 调用外部API
response = requests.get(
url=api_url,
params=params,
)
if response.status_code != 200:
logger.error(f"获取对话列表失败: {response.status_code}, {response.text}")
return Response({
'code': 500,
'message': '获取对话列表失败,外部服务返回错误',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 解析响应数据
try:
result = response.json()
if result.get('code') != 200:
logger.error(f"外部API返回错误: {result}")
return Response({
'code': result.get('code', 500),
'message': result.get('message', '获取对话列表失败'),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 处理返回的数据
chat_list = result.get('data', [])
# 格式化返回数据
formatted_chats = []
for chat in chat_list:
formatted_chat = {
'id': chat['id'],
'chat_id': chat['chat_id'],
'abstract': chat['abstract'],
'message_count': chat['chat_record_count'],
'created_at': datetime.fromisoformat(chat['create_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'),
'updated_at': datetime.fromisoformat(chat['update_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'),
'star_count': chat['star_num'],
'trample_count': chat['trample_num'],
'mark_sum': chat['mark_sum'],
'is_deleted': chat['is_deleted']
}
formatted_chats.append(formatted_chat)
return Response({
'code': 200,
'message': '获取成功',
'data': {
'total': len(formatted_chats),
'results': formatted_chats
}
})
except json.JSONDecodeError as e:
logger.error(f"解析响应数据失败: {str(e)}")
return Response({
'code': 500,
'message': '解析响应数据失败',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"获取对话列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取对话列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['post'])
def hit_test(self, request):
"""获取问题与知识库文档的匹配度"""
try:
data = request.data
# 检查必填字段
if 'question' not in data:
return Response({
'code': 400,
'message': '缺少必填字段: question',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if 'dataset_id_list' not in data or not data['dataset_id_list']:
return Response({
'code': 400,
'message': '缺少必填字段: dataset_id_list',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
question = data['question']
dataset_ids = data['dataset_id_list']
# 如果不是列表,转换为列表
if not isinstance(dataset_ids, list):
try:
dataset_ids = json.loads(dataset_ids)
if not isinstance(dataset_ids, list):
dataset_ids = [dataset_ids]
except (json.JSONDecodeError, TypeError):
dataset_ids = [dataset_ids]
# 检查用户是否有权限访问这些知识库
external_id_list = []
for kb_id in dataset_ids:
try:
kb = KnowledgeBase.objects.get(id=kb_id)
if not self.check_knowledge_base_permission(kb, request.user, 'read'):
return Response({
'code': 403,
'message': f'无权访问知识库: {kb.name}',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
if kb.external_id:
external_id_list.append(str(kb.external_id))
else:
logger.warning(f"知识库 {kb.id} ({kb.name}) 没有external_id")
except KnowledgeBase.DoesNotExist:
return Response({
'code': 404,
'message': f'知识库不存在: {kb_id}',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
if not external_id_list:
return Response({
'code': 400,
'message': '没有有效的知识库external_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取所有知识库的匹配文档
all_documents = []
for dataset_id in external_id_list:
doc_info = self._call_hit_test_api(dataset_id, question)
if doc_info:
all_documents.extend(doc_info)
# 按相似度排序
all_documents = sorted(all_documents, key=lambda x: x.get('similarity', 0), reverse=True)
# 返回结果
return Response({
'code': 200,
'message': '成功',
'data': {
'question': question,
'matched_documents': all_documents,
'total_count': len(all_documents)
}
})
except Exception as e:
logger.error(f"hit_test接口调用失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'hit_test接口调用失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _highlight_keyword(self, text, keyword):
"""高亮关键词"""
if not keyword or not text:
return text
return text.replace(
keyword,
f'<em class="highlight">{keyword}</em>'
)
def _call_hit_test_api(self, dataset_id, query_text):
"""调用知识库hit_test接口获取相关文档信息"""
try:
url = f"{settings.API_BASE_URL}/api/dataset/{dataset_id}/hit_test"
params = {
"query_text": query_text,
"top_number": 10,
"similarity": 0.3,
"search_mode": "blend"
}
logger.info(f"调用hit_test接口: {url}, 参数: {params}")
response = requests.get(
url=url,
params=params,
)
if response.status_code != 200:
logger.error(f"hit_test接口调用失败: {response.status_code}, {response.text}")
return None
result = response.json()
if result.get('code') != 200:
logger.error(f"hit_test接口业务错误: {result}")
return None
# 提取文档信息
documents = result.get('data', [])
logger.info(f"hit_test接口返回 {len(documents)} 个相关文档")
# 提取文档名称和相似度等信息
doc_info = []
for doc in documents:
doc_info.append({
"document_name": doc.get("document_name", ""),
"dataset_name": doc.get("dataset_name", ""),
"similarity": doc.get("similarity", 0),
"comprehensive_score": doc.get("comprehensive_score", 0)
})
return doc_info
except Exception as e:
logger.error(f"调用hit_test接口失败: {str(e)}")
logger.error(traceback.format_exc())
return None
@action(detail=False, methods=['delete'])
def delete_conversation(self, request):
"""通过conversation_id删除一组会话"""
try:
# 获取conversation_id
conversation_id = request.query_params.get('conversation_id')
if not conversation_id:
return Response({
'code': 400,
'message': '缺少必要参数: conversation_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 查找该会话下的所有记录
records = self.get_queryset().filter(conversation_id=conversation_id)
if not records.exists():
return Response({
'code': 404,
'message': '未找到该会话或无权限访问',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 获取记录数量
records_count = records.count()
# 批量软删除
for record in records:
record.soft_delete()
return Response({
'code': 200,
'message': '删除成功',
'data': {
'conversation_id': conversation_id,
'deleted_count': records_count
}
})
except Exception as e:
logger.error(f"删除会话失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'删除会话失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'], url_path='generate-recommended-reply')
def generate_recommended_reply(self, request):
"""获取达人消息并生成推荐回复"""
try:
conversation_id = request.query_params.get('conversation_id')
if not conversation_id:
return Response({
'code': 400,
'message': '缺少conversation_id参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取对话历史,确保按时间顺序排序
messages = self.get_queryset().filter(
conversation_id=conversation_id,
is_deleted=False
).order_by('created_at')
if not messages.exists():
return Response({
'code': 404,
'message': '对话不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 获取最新的消息,检查是否是达人发送的
latest_message = messages.last()
# 如果最后一条不是用户消息,尝试查找最新的用户消息
if latest_message.role != 'user':
user_messages = messages.filter(role='user')
if user_messages.exists():
latest_message = user_messages.order_by('-created_at').first()
else:
return Response({
'code': 400,
'message': '对话中没有达人发送的消息,无法生成推荐回复',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 提取对话历史以传递给DeepSeek API
conversation_history = []
for message in messages:
conversation_history.append({
'role': 'user' if message.role == 'user' else 'assistant',
'content': message.content
})
# 调用DeepSeek V3 API生成推荐回复
recommended_reply = self._get_recommended_reply_from_deepseek(conversation_history)
if not recommended_reply:
return Response({
'code': 400, # 改为400表示可以重试
'message': 'API暂时无法生成推荐回复请稍后再试',
'data': {
'conversation_id': conversation_id,
'latest_message': {
'id': str(latest_message.id),
'content': latest_message.content,
'role': latest_message.role,
'created_at': latest_message.created_at.strftime('%Y-%m-%d %H:%M:%S')
},
'recommended_reply': None
}
}, status=status.HTTP_400_BAD_REQUEST)
return Response({
'code': 200,
'message': '生成推荐回复成功',
'data': {
'conversation_id': conversation_id,
'latest_message': {
'id': str(latest_message.id),
'content': latest_message.content,
'role': latest_message.role,
'created_at': latest_message.created_at.strftime('%Y-%m-%d %H:%M:%S')
},
'recommended_reply': recommended_reply
}
})
except Exception as e:
logger.error(f"生成推荐回复失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f"生成推荐回复失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _get_recommended_reply_from_deepseek(self, conversation_history):
"""调用DeepSeek V3 API生成推荐回复"""
try:
# 使用有效的API密钥
api_key = "sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"
# 如果上面的密钥不正确,可以尝试从环境变量或数据库中获取
# 从Django设置中获取密钥
from django.conf import settings
if hasattr(settings, 'DEEPSEEK_API_KEY') and settings.DEEPSEEK_API_KEY:
api_key = settings.DEEPSEEK_API_KEY
url = "https://api.siliconflow.cn/v1/chat/completions"
# 直接使用默认系统消息,不进行复杂处理,尽量模仿文档示例
system_message = {
"role": "system",
"content": "你是一位专业的电商客服和达人助手。你的任务是针对用户最近的消息生成一个有帮助、礼貌且详细的回复。即使用户消息很短或不明确也必须提供有实质内容的回复。禁止返回空白内容。回复应该有至少100个字符。"
}
messages = [system_message]
# 限制对话历史长度只保留最近的5条消息避免超出token限制
recent_messages = conversation_history[-5:] if len(conversation_history) > 5 else conversation_history
messages.extend(recent_messages)
# 确保最后一条消息是用户消息,如果不是,添加一个提示
if not recent_messages or recent_messages[-1]['role'] != 'user':
# 添加一个系统消息作为用户的最后一条消息
messages.append({
"role": "user",
"content": "请针对我之前的消息提供详细的回复建议。"
})
# 完全按照文档提供的参数格式构建请求
payload = {
"model": "deepseek-ai/DeepSeek-V3",
"messages": messages,
"stream": False,
"max_tokens": 1024, # 增加token上限
"temperature": 0.7, # 提高多样性
"top_p": 0.9,
"top_k": 50,
"frequency_penalty": 0.5,
"presence_penalty": 0.2, # 添加新参数
"n": 1,
"stop": [],
"response_format": {
"type": "text"
}
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
logger.info(f"开始调用DeepSeek API生成推荐回复")
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
logger.error(f"DeepSeek API调用失败: {response.status_code}, {response.text}")
return None
result = response.json()
logger.debug(f"DeepSeek API返回: {result}")
# 提取回复内容
if 'choices' in result and len(result['choices']) > 0:
reply = result['choices'][0]['message']['content']
# 如果返回的内容为空直接返回None
if not reply or reply.strip() == '':
logger.warning("DeepSeek API返回的回复内容为空")
return None
return reply
logger.warning(f"DeepSeek API返回格式异常: {result}")
return None
except Exception as e:
logger.error(f"调用DeepSeek API失败: {str(e)}")
logger.error(traceback.format_exc())
return None
def _generate_fallback_reply(self, messages):
"""此功能已禁用,不再生成备用回复"""
return None
@action(detail=False, methods=['post'], url_path='auto-recommend-reply')
def auto_recommend_reply(self, request):
"""
设置自动推荐回复功能
当收到达人消息时自动生成推荐回复并通过WebSocket发送通知
"""
try:
# 获取是否启用自动推荐回复的设置
enable_auto_recommend = request.data.get('enable_auto_recommend', False)
user = request.user
# 更新用户配置
user_profile, created = UserProfile.objects.get_or_create(user=user)
user_profile.auto_recommend_reply = enable_auto_recommend
user_profile.save()
return Response({
'code': 200,
'message': '设置自动推荐回复成功',
'data': {
'enable_auto_recommend': enable_auto_recommend
}
})
except Exception as e:
logger.error(f"设置自动推荐回复失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f"设置自动推荐回复失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'], url_path='get-auto-recommend-setting')
def get_auto_recommend_setting(self, request):
"""获取自动推荐回复设置"""
try:
user = request.user
user_profile, created = UserProfile.objects.get_or_create(user=user)
return Response({
'code': 200,
'message': '获取自动推荐回复设置成功',
'data': {
'enable_auto_recommend': getattr(user_profile, 'auto_recommend_reply', False)
}
})
except Exception as e:
logger.error(f"获取自动推荐回复设置失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f"获取自动推荐回复设置失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'], url_path='generate-conversation-title')
def generate_conversation_title(self, request):
"""更新会话标题"""
try:
conversation_id = request.query_params.get('conversation_id')
if not conversation_id:
return Response({
'code': 400,
'message': '缺少conversation_id参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查对话是否存在
messages = self.get_queryset().filter(
conversation_id=conversation_id,
is_deleted=False,
user=request.user
).order_by('created_at')
if not messages.exists():
return Response({
'code': 404,
'message': '对话不存在或无权访问',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 检查是否有自定义标题参数
custom_title = request.query_params.get('title')
if not custom_title:
return Response({
'code': 400,
'message': '缺少title参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 更新所有相关记录的标题
ChatHistory.objects.filter(
conversation_id=conversation_id,
user=request.user
).update(title=custom_title)
return Response({
'code': 200,
'message': '更新会话标题成功',
'data': {
'conversation_id': conversation_id,
'title': custom_title
}
})
except Exception as e:
logger.error(f"更新会话标题失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f"更新会话标题失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _generate_conversation_title_from_deepseek(self, user_question, assistant_answer):
"""调用SiliconCloud API生成会话标题直接基于当前问题和回答内容"""
try:
# 从Django设置中获取API密钥
api_key = settings.SILICON_CLOUD_API_KEY
if not api_key:
return "新对话"
# 构建提示信息
prompt = f"请根据用户的问题和助手的回答生成一个简短的对话标题不超过20个字\n\n用户问题: {user_question}\n\n助手回答: {assistant_answer}"
import requests
url = "https://api.siliconflow.cn/v1/chat/completions"
payload = {
"model": "deepseek-ai/DeepSeek-V3",
"stream": False,
"max_tokens": 512,
"temperature": 0.7,
"top_p": 0.7,
"top_k": 50,
"frequency_penalty": 0.5,
"n": 1,
"stop": [],
"messages": [
{
"role": "user",
"content": prompt
}
]
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
response_data = response.json()
if response.status_code == 200 and 'choices' in response_data and response_data['choices']:
title = response_data['choices'][0]['message']['content'].strip()
return title[:50] # 截断过长的标题
else:
logger.error(f"生成标题时出错: {response.text}")
return "新对话"
except Exception as e:
logger.exception(f"生成对话标题时发生错误: {str(e)}")
return "新对话"
class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
serializer_class = KnowledgeBaseSerializer
permission_classes = [IsAuthenticated]
def list(self, request, *args, **kwargs):
try:
queryset = self.get_queryset()
# 获取搜索关键字
keyword = request.query_params.get('keyword', '')
# 如果有关键字,构建搜索条件
if keyword:
query = Q(name__icontains=keyword) | \
Q(desc__icontains=keyword) | \
Q(department__icontains=keyword) | \
Q(group__icontains=keyword)
queryset = queryset.filter(query)
# 获取分页参数
try:
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
except ValueError:
page = 1
page_size = 10
# 计算总数量
total = queryset.count()
# 分页处理
start = (page - 1) * page_size
end = start + page_size
paginated_queryset = queryset[start:end]
# 序列化知识库数据
serializer = self.get_serializer(paginated_queryset, many=True)
data = serializer.data
# 为每个知识库添加权限信息
user = request.user
for item in data:
# 获取必要的知识库属性
kb_type = item['type']
department = item.get('department')
group = item.get('group')
creator_id = item.get('user_id')
kb_id = item['id']
# 首先检查权限表中的显式权限
explicit_permission = KBPermissionModel.objects.filter(
knowledge_base_id=kb_id,
user=user,
status='active'
).first()
if explicit_permission:
item['permissions'] = {
'can_read': explicit_permission.can_read,
'can_edit': explicit_permission.can_edit,
'can_delete': explicit_permission.can_delete
}
# 添加知识库的到期时间
item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
else:
# 没有显式权限时使用统一的权限判断方法
item['permissions'] = {
'can_read': self._can_read(kb_type, user, department, group, creator_id, kb_id),
'can_edit': self._can_edit(kb_type, user, department, group, creator_id, kb_id),
'can_delete': self._can_delete(kb_type, user, department, group, creator_id, kb_id)
}
# 对于admin类型的知识库设置expires_at为None
if kb_type == 'admin':
item['expires_at'] = None
else:
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
item['expires_at'] = None
# 处理高亮
if keyword:
if 'name' in item and keyword.lower() in item['name'].lower():
item['highlighted_name'] = item['name'].replace(
keyword, f'<em class="highlight">{keyword}</em>'
)
if 'desc' in item and item.get('desc') is not None:
desc_text = str(item['desc'])
if keyword.lower() in desc_text.lower():
item['highlighted_desc'] = desc_text.replace(
keyword, f'<em class="highlight">{keyword}</em>'
)
return Response({
"code": 200,
"message": "获取知识库列表成功",
"data": {
"total": total,
"page": page,
"page_size": page_size,
"keyword": keyword if keyword else None,
"items": data
}
})
except Exception as e:
logger.error(f"获取知识库列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"获取知识库列表失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def get_queryset(self):
"""获取用户有权限查看的知识库列表"""
user = self.request.user
queryset = KnowledgeBase.objects.all()
# 1. 构建基础权限条件
permission_conditions = Q()
# 2. 所有用户都可以看到 admin 类型的知识库
permission_conditions |= Q(type='admin')
# 3. 用户可以看到自己创建的所有知识库
permission_conditions |= Q(user_id=user.id)
# 4. 添加显式权限条件
# 获取所有活跃的权限记录
active_permissions = KBPermissionModel.objects.filter(
user=user,
can_read=True,
status='active',
expires_at__gt=timezone.now()
).values_list('knowledge_base_id', flat=True)
if active_permissions:
permission_conditions |= Q(id__in=active_permissions)
# 5. 根据用户角色添加隐式权限
if user.role == 'admin':
# 管理员可以看到除了其他用户 private 类型外的所有知识库
permission_conditions |= ~Q(type='private') | Q(user_id=user.id)
elif user.role == 'leader':
# 组长可以查看本部门的 leader 和 member 类型知识库
permission_conditions |= Q(
type__in=['leader', 'member'],
department=user.department
)
elif user.role in ['member', 'user']:
# 成员可以查看本部门的 leader 类型知识库
permission_conditions |= Q(
type='leader',
department=user.department
)
# 成员可以查看本部门本组的 member 类型知识库
permission_conditions |= Q(
type='member',
department=user.department,
group=user.group
)
return queryset.filter(permission_conditions).distinct()
def create(self, request, *args, **kwargs):
try:
# 1. 验证知识库名称
name = request.data.get('name')
if not name:
return Response({
'code': 400,
'message': '知识库名称不能为空',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if KnowledgeBase.objects.filter(name=name).exists():
return Response({
'code': 400,
'message': f'知识库名称 "{name}" 已存在',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 2. 验证用户权限和必填字段
user = request.user
type = request.data.get('type', 'private')
department = request.data.get('department')
group = request.data.get('group')
# 修改权限验证
if type == 'admin':
# 移除管理员权限检查,允许所有用户创建
department = None
group = None
elif type == 'secret':
if user.role != 'admin':
return Response({
'code': 403,
'message': '只有管理员可以创建保密级知识库',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
department = None
group = None
elif type == 'leader':
if user.role != 'admin':
return Response({
'code': 403,
'message': '只有管理员可以创建组长级知识库',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
if not department:
return Response({
'code': 400,
'message': '创建组长级知识库时必须指定部门',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
elif type == 'member':
if user.role not in ['admin', 'leader']:
return Response({
'code': 403,
'message': '只有管理员和组长可以创建成员级知识库',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
if user.role == 'admin' and not department:
return Response({
'code': 400,
'message': '管理员创建成员知识库时必须指定部门',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
elif user.role == 'leader':
department = user.department
if not group:
return Response({
'code': 400,
'message': '创建成员知识库时必须指定组',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
elif type == 'private':
# 对于private类型不保存department和group
department = None
group = None
# 3. 验证请求数据
data = request.data.copy()
data['department'] = department
data['group'] = group
# 不需要手动设置 user_id由序列化器自动处理
serializer = self.get_serializer(data=data)
if not serializer.is_valid():
logger.error(f"数据验证失败: {serializer.errors}")
return Response({
'code': 400,
'message': '数据验证失败',
'data': serializer.errors
}, status=status.HTTP_400_BAD_REQUEST)
with transaction.atomic():
# 4. 创建知识库
try:
knowledge_base = serializer.save()
logger.info(f"知识库创建成功: id={knowledge_base.id}, name={knowledge_base.name}, user_id={knowledge_base.user_id}")
except Exception as e:
logger.error(f"知识库创建失败: {str(e)}")
raise
# 5. 调用外部API创建知识库
try:
external_id = self._create_external_dataset(knowledge_base)
logger.info(f"外部知识库创建成功获取ID: {external_id}")
# 保存外部知识库ID
knowledge_base.external_id = external_id
knowledge_base.save()
logger.info(f"更新knowledge_base的external_id为: {external_id}")
except ExternalAPIError as e:
logger.error(f"外部知识库创建失败: {str(e)}")
raise
# 6. 创建权限记录
try:
# 创建者权限
KBPermissionModel.objects.create(
knowledge_base=knowledge_base,
user=request.user,
can_read=True,
can_edit=True,
can_delete=True,
granted_by=request.user,
status='active'
)
logger.info(f"创建者权限创建成功")
# 根据类型批量创建其他用户权限
permissions = []
if type == 'admin':
users_query = User.objects.exclude(id=request.user.id)
# 为所有用户赋予完全权限(读、写、删)
permissions = [
KBPermissionModel(
knowledge_base=knowledge_base,
user=user,
can_read=True,
can_edit=True,
can_delete=True,
granted_by=request.user,
status='active'
) for user in users_query
]
elif type == 'secret':
users_query = User.objects.filter(role='admin').exclude(id=request.user.id)
permissions = [
KBPermissionModel(
knowledge_base=knowledge_base,
user=user,
can_read=True,
can_edit=self._can_edit(type, user),
can_delete=self._can_delete(type, user),
granted_by=request.user,
status='active'
) for user in users_query
]
elif type == 'leader':
users_query = User.objects.filter(
Q(role='admin') |
Q(role='leader', department=department)
).exclude(id=request.user.id)
permissions = [
KBPermissionModel(
knowledge_base=knowledge_base,
user=user,
can_read=True,
can_edit=self._can_edit(type, user),
can_delete=self._can_delete(type, user),
granted_by=request.user,
status='active'
) for user in users_query
]
elif type == 'member':
users_query = User.objects.filter(
Q(role='admin') |
Q(department=department, role='leader') |
Q(department=department, group=group, role='member')
).exclude(id=request.user.id)
permissions = [
KBPermissionModel(
knowledge_base=knowledge_base,
user=user,
can_read=True,
can_edit=self._can_edit(type, user),
can_delete=self._can_delete(type, user),
granted_by=request.user,
status='active'
) for user in users_query
]
else: # private
users_query = User.objects.none()
if permissions:
KBPermissionModel.objects.bulk_create(permissions)
logger.info(f"{type}类型权限创建完成: {len(permissions)}条记录")
except Exception as e:
logger.error(f"权限创建失败: {str(e)}")
logger.error(traceback.format_exc())
raise
return Response({
'code': 200,
'message': '知识库创建成功',
'data': {
'knowledge_base': serializer.data,
'external_id': knowledge_base.external_id
}
})
except Exception as e:
logger.error(f"创建知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'创建知识库失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def update(self, request, *args, **kwargs):
"""更新知识库"""
try:
instance = self.get_object()
user = request.user
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(instance, user, 'edit'):
return Response({
"code": 403,
"message": "没有编辑权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
with transaction.atomic():
# 执行本地更新
serializer = self.get_serializer(instance, data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
# 更新外部知识库
if instance.external_id:
try:
api_data = {
"name": serializer.validated_data.get('name', instance.name),
"desc": serializer.validated_data.get('desc', instance.desc),
"type": "0", # 保持与创建时一致
"meta": {}, # 保持与创建时一致
"documents": [] # 保持与创建时一致
}
response = requests.put(
f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}',
json=api_data,
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}")
api_response = response.json()
if not api_response.get('code') == 200:
raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}")
logger.info(f"外部知识库更新成功: {instance.external_id}")
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
raise ExternalAPIError(f"更新外部知识库失败: {str(e)}")
return Response({
"code": 200,
"message": "知识库更新成功",
"data": serializer.data
})
except Http404:
return Response({
"code": 404,
"message": "知识库不存在",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
except ExternalAPIError as e:
logger.error(f"更新外部知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": str(e),
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"更新知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"更新知识库失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def destroy(self, request, *args, **kwargs):
"""删除知识库"""
try:
instance = self.get_object()
user = request.user
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(instance, user, 'delete'):
return Response({
"code": 403,
"message": "没有删除权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 删除外部知识库(如果存在)
external_delete_success = True
external_error_message = None
if instance.external_id:
try:
self._delete_external_dataset(instance.external_id)
logger.info(f"外部知识库删除成功: {instance.external_id}")
except ExternalAPIError as e:
# 记录错误但继续执行本地删除
external_delete_success = False
external_error_message = str(e)
logger.warning(f"外部知识库删除失败,将继续删除本地知识库: {str(e)}")
# 删除本地知识库
self.perform_destroy(instance)
logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}")
# 如果外部知识库删除失败,返回警告消息
if not external_delete_success:
return Response({
"code": 200,
"message": f"知识库已删除,但外部知识库删除失败: {external_error_message}",
"data": None
})
return Response({
"code": 200,
"message": "知识库删除成功",
"data": None
})
except Http404:
return Response({
"code": 404,
"message": "知识库不存在",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
logger.error(f"删除知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"删除知识库失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _delete_external_dataset(self, external_id):
"""删除外部知识库"""
try:
if not external_id:
logger.warning("外部知识库ID为空跳过删除")
return True
response = requests.delete(
f'{settings.API_BASE_URL}/api/dataset/{external_id}',
headers={'Content-Type': 'application/json'},
)
logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")
# 检查响应状态码
if response.status_code == 404:
logger.warning(f"外部知识库不存在: {external_id}")
return True # 如果知识库不存在,也视为删除成功
elif response.status_code not in [200, 204]:
logger.warning(f"删除外部知识库状态码异常: {response.status_code}, {response.text}")
return True # 即使状态码异常,也允许继续删除本地知识库
# 检查业务状态码
try:
api_response = response.json()
if api_response.get('code') != 200:
# 如果是因为ID不存在也视为成功
if "不存在" in api_response.get('message', ''):
logger.warning(f"外部知识库ID不存在视为删除成功: {external_id}")
return True
logger.warning(f"业务处理返回非200状态码: {api_response.get('code')}, {api_response.get('message')}")
return True # 不再抛出异常,允许本地删除继续
logger.info(f"外部知识库删除成功: {external_id}")
return True
except ValueError:
# 如果无法解析 JSON但状态码是 200也认为成功
logger.warning(f"外部知识库删除响应无法解析JSON但状态码为200视为成功: {external_id}")
return True
except requests.exceptions.Timeout:
logger.error(f"删除外部知识库超时: {external_id}")
# 不再抛出异常,允许本地删除继续
return False
except requests.exceptions.RequestException as e:
logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
# 不再抛出异常,允许本地删除继续
return False
except Exception as e:
logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}")
# 不再抛出异常,允许本地删除继续
return False
@action(detail=True, methods=['get'])
def permissions(self, request, pk=None):
"""获取用户对特定知识库的权限"""
try:
instance = self.get_object()
user = request.user
# 使用统一的权限检查方法
permissions_data = {
"can_read": self.check_knowledge_base_permission(instance, user, 'read'),
"can_edit": self.check_knowledge_base_permission(instance, user, 'edit'),
"can_delete": self.check_knowledge_base_permission(instance, user, 'delete')
}
return Response({
"code": 200,
"message": "获取权限信息成功",
"data": {
"knowledge_base_id": instance.id,
"knowledge_base_name": instance.name,
"permissions": permissions_data
}
})
except Http404:
return Response({
"code": 404,
"message": "知识库不存在",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
logger.error(f"获取权限信息失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"获取权限信息失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def summary(self, request):
"""获取所有可见知识库的概要信息除了secret类型"""
try:
user = request.user
# 基础查询排除secret类型的知识库
queryset = KnowledgeBase.objects.exclude(type='secret')
summaries = []
for kb in queryset:
# 使用统一的权限判断方法
permissions = {
'can_read': self.check_knowledge_base_permission(kb, user, 'read'),
'can_edit': self.check_knowledge_base_permission(kb, user, 'edit'),
'can_delete': self.check_knowledge_base_permission(kb, user, 'delete')
}
# 获取知识库到期时间
explicit_permission = KBPermissionModel.objects.filter(
knowledge_base_id=kb.id,
user=user,
status='active'
).first()
expires_at = None
if explicit_permission:
expires_at = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
elif kb.type == 'admin':
expires_at = None
# 只返回概要信息
summary = {
'id': str(kb.id),
'name': kb.name,
'desc': kb.desc,
'type': kb.type,
'department': kb.department,
'permissions': permissions,
'expires_at': expires_at
}
summaries.append(summary)
return Response({
'code': 200,
'message': '获取知识库概要信息成功',
'data': summaries
})
except Exception as e:
return Response({
'code': 500,
'message': f'获取知识库概要信息失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def retrieve(self, request, *args, **kwargs):
try:
# 获取知识库对象
instance = self.get_object()
serializer = self.get_serializer(instance)
data = serializer.data
# 获取用户
user = request.user
# 使用统一的权限判断方法
data['permissions'] = {
'can_read': self.check_knowledge_base_permission(instance, user, 'read'),
'can_edit': self.check_knowledge_base_permission(instance, user, 'edit'),
'can_delete': self.check_knowledge_base_permission(instance, user, 'delete')
}
# 添加知识库到期时间
explicit_permission = KBPermissionModel.objects.filter(
knowledge_base_id=instance.id,
user=user,
status='active'
).first()
if explicit_permission:
data['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
else:
# 对于admin类型的知识库设置expires_at为None
if instance.type == 'admin':
data['expires_at'] = None
else:
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
data['expires_at'] = None
return Response({
'code': 200,
'message': '获取知识库详情成功',
'data': data
})
except Exception as e:
logger.error(f"获取知识库详情失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取知识库详情失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def search(self, request):
"""搜索知识库功能"""
try:
# 获取搜索关键字
keyword = request.query_params.get('keyword', '')
if not keyword:
return Response({
"code": 400,
"message": "搜索关键字不能为空",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取分页参数
try:
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
except ValueError:
page = 1
page_size = 10
# 构建搜索条件
query = Q(name__icontains=keyword) | \
Q(desc__icontains=keyword) | \
Q(department__icontains=keyword) | \
Q(group__icontains=keyword)
# 排除 secret 类型的知识库
queryset = KnowledgeBase.objects.filter(query).exclude(type='secret')
# 获取用户
user = request.user
# 获取用户所有有效的知识库权限
active_permissions = KBPermissionModel.objects.filter(
user=user,
status='active',
expires_at__gt=timezone.now()
).select_related('knowledge_base')
# 创建权限映射字典
permission_map = {
str(perm.knowledge_base.id): {
'can_read': perm.can_read,
'can_edit': perm.can_edit,
'can_delete': perm.can_delete
}
for perm in active_permissions
}
# 计算总数量
total = queryset.count()
# 分页处理
start = (page - 1) * page_size
end = start + page_size
paginated_queryset = queryset[start:end]
# 序列化知识库数据
serializer = self.get_serializer(paginated_queryset, many=True)
data = serializer.data
# 处理每个知识库项的权限和返回内容
result_items = []
for item in data:
# 创建一个临时的知识库对象用于权限检查
temp_kb = KnowledgeBase(
id=item['id'],
type=item['type'],
department=item.get('department'),
group=item.get('group'),
user_id=item.get('user_id')
)
# 使用统一的权限判断方法
explicit_permission = KBPermissionModel.objects.filter(
knowledge_base_id=item['id'],
user=user,
status='active'
).first()
if explicit_permission:
kb_permissions = {
'can_read': explicit_permission.can_read,
'can_edit': explicit_permission.can_edit,
'can_delete': explicit_permission.can_delete
}
# 添加知识库的到期时间
item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
else:
# 使用统一的权限判断方法
kb_permissions = {
'can_read': self.check_knowledge_base_permission(temp_kb, user, 'read'),
'can_edit': self.check_knowledge_base_permission(temp_kb, user, 'edit'),
'can_delete': self.check_knowledge_base_permission(temp_kb, user, 'delete')
}
# 对于admin类型的知识库设置expires_at为None
if item['type'] == 'admin':
item['expires_at'] = None
else:
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
item['expires_at'] = None
# 添加权限信息
item['permissions'] = kb_permissions
# 根据权限返回不同级别的信息
if kb_permissions['can_read']:
result_items.append(item)
else:
# 无读取权限,只返回概要信息
summary_info = {
'id': item['id'],
'name': item['name'],
'type': item['type'],
'department': item.get('department'),
'permissions': kb_permissions
}
result_items.append(summary_info)
# 高亮搜索关键字
for item in result_items:
if 'name' in item and keyword.lower() in item['name'].lower():
highlighted = item['name'].replace(
keyword, f'<em class="highlight">{keyword}</em>'
)
item['highlighted_name'] = highlighted
# 确保desc不为None并且是字符串
if 'desc' in item and item.get('desc') is not None:
desc_text = str(item['desc']) # 转换为字符串以确保安全
if keyword.lower() in desc_text.lower():
highlighted = desc_text.replace(
keyword, f'<em class="highlight">{keyword}</em>'
)
item['highlighted_desc'] = highlighted
return Response({
"code": 200,
"message": "搜索知识库成功",
"data": {
"total": total,
"page": page,
"page_size": page_size,
"keyword": keyword,
"items": result_items
}
})
except Exception as e:
logger.error(f"搜索知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"搜索知识库失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=True, methods=['post'])
def change_type(self, request, pk=None):
"""修改知识库类型"""
try:
instance = self.get_object()
user = request.user
# 使用统一的权限检查方法检查编辑权限
if not self.check_knowledge_base_permission(instance, user, 'edit'):
return Response({
"code": 403,
"message": "没有修改权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 其余代码保持不变...
# 获取新类型
new_type = request.data.get('type')
if not new_type:
return Response({
"code": 400,
"message": "新类型不能为空",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证类型是否有效
valid_types = ['private', 'admin', 'secret', 'leader', 'member']
if new_type not in valid_types:
return Response({
"code": 400,
"message": f"无效的知识库类型,可选值: {', '.join(valid_types)}",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 角色特定的类型限制
if new_type == 'leader' and not user.role == 'admin': # 组长且不是管理员
# 组长只能在private和member类型之间切换
if new_type not in ['private', 'member']:
return Response({
"code": 403,
"message": "组长只能将知识库设置为private或member类型",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 处理department和group字段
department = request.data.get('department')
group = request.data.get('group')
# 组长只能设置自己部门
if new_type == 'leader' and not user.role == 'admin':
if department and department != user.department:
return Response({
"code": 403,
"message": "组长只能为本部门设置知识库",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 如果未指定部门,强制设置为组长的部门
department = user.department
# 根据类型验证必填字段
if new_type == 'leader':
if not department:
return Response({
"code": 400,
"message": "组长级知识库必须指定部门",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
if new_type == 'member':
if not department:
return Response({
"code": 400,
"message": "成员级知识库必须指定部门",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
if not group:
return Response({
"code": 400,
"message": "成员级知识库必须指定组",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 如果是admin或secret类型清除department和group
if new_type in ['admin', 'secret']:
department = None
group = None
# 如果是private类型但未指定department和group使用原值
if new_type == 'private':
if department is None:
department = instance.department
if group is None:
group = instance.group
# 更新知识库类型和相关字段
instance.type = new_type
instance.department = department
instance.group = group
instance.save()
return Response({
"code": 200,
"message": f"知识库类型已更新为{new_type}",
"data": {
"id": instance.id,
"name": instance.name,
"type": instance.type,
"department": instance.department,
"group": instance.group
}
})
except Http404:
return Response({
"code": 404,
"message": "知识库不存在",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
logger.error(f"修改知识库类型失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"修改知识库类型失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _call_split_api_multiple(self, files):
"""调用文档分割API - 直接传递多个文件"""
try:
url = f'{settings.API_BASE_URL}/api/dataset/document/split'
# 准备请求数据 - 使用单个"file"字段
file_obj = files[0] # 先只处理第一个文件,便于排查问题
# 重置文件指针位置
if hasattr(file_obj, 'seek'):
file_obj.seek(0)
logger.info(f"准备上传文件: {file_obj.name}, 大小: {file_obj.size}字节, 类型: {file_obj.content_type}")
# 读取文件内容前100个字符进行记录
if hasattr(file_obj, 'read') and hasattr(file_obj, 'seek'):
content_preview = file_obj.read(100).decode('utf-8', errors='ignore')
logger.info(f"文件内容预览: {content_preview}")
file_obj.seek(0) # 重置文件指针
# 使用正确的字段名称发送请求
files_data = {'file': file_obj}
logger.info(f"调用分割API URL: {url}")
logger.info(f"请求字段: {list(files_data.keys())}")
# 发送请求
response = requests.post(
url,
files=files_data,
)
# 记录请求头和响应信息,方便排查问题
logger.info(f"请求头: {response.request.headers}")
logger.info(f"响应状态码: {response.status_code}")
if response.status_code != 200:
logger.error(f"分割API返回错误状态码: {response.status_code}, 响应: {response.text}")
return None
# 解析响应
result = response.json()
logger.info(f"分割API响应详情: {result}")
# 如果数据为空可能是API期望的请求格式不对尝试使用不同的字段名
if len(result.get('data', [])) == 0:
logger.warning("分割API返回的数据为空尝试使用后备方案")
# 创建一个手动构建的文档结构
fallback_data = {
'code': 200,
'message': '成功',
'data': [
{
'name': file_obj.name,
'content': [
{
'title': '文档内容',
'content': '文件内容无法自动分割请检查外部API。这是一个后备内容。'
}
]
}
]
}
logger.info("使用后备数据结构")
return fallback_data
return result
except Exception as e:
logger.error(f"调用分割API失败: {str(e)}")
logger.error(traceback.format_exc())
# 创建一个后备响应
fallback_response = {
'code': 200,
'message': '成功',
'data': []
}
# 如果有文件,为每个文件创建一个基本文档结构
if files:
fallback_response['data'] = [
{
'name': file.name,
'content': [
{
'title': '文档内容',
'content': '文件内容无法自动分割请检查API连接。'
}
]
} for file in files
]
logger.info("由于异常,返回后备响应")
return fallback_response
@action(detail=True, methods=['post'])
def upload_document(self, request, pk=None):
"""上传文档到知识库 - 支持多文件上传"""
try:
instance = self.get_object()
user = request.user
# 使用统一的权限检查方法
if not self.check_knowledge_base_permission(instance, user, 'edit'):
return Response({
"code": 403,
"message": "没有编辑权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 记录请求内容,方便调试
logger.info(f"请求内容: {request.data}")
logger.info(f"请求FILES: {request.FILES}")
# 获取上传的文件,尝试多种可能的字段名
files = []
# 尝试'files'字段(多文件)
if 'files' in request.FILES:
files = request.FILES.getlist('files')
# 尝试'file'字段(多文件)
elif 'file' in request.FILES:
files = request.FILES.getlist('file')
# 尝试files[]格式(常见于前端FormData)
elif any(key.startswith('files[') for key in request.FILES):
files = [file for key, file in request.FILES.items() if key.startswith('files[')]
# 尝试file[]格式
elif any(key.startswith('file[') for key in request.FILES):
files = [file for key, file in request.FILES.items() if key.startswith('file[')]
# 单个文件上传的情况
elif len(request.FILES) > 0:
# 如果有任何文件,就全部使用
files = list(request.FILES.values())
if not files:
return Response({
"code": 400,
"message": "未找到上传文件,请确保表单字段名为'files''file'",
"data": {
"available_fields": list(request.FILES.keys())
}
}, status=status.HTTP_400_BAD_REQUEST)
logger.info(f"接收到 {len(files)} 个文件上传请求")
# 保存所有处理后的文档
saved_documents = []
failed_documents = []
# 验证knowledge_base的external_id是否有效
if not instance.external_id:
return Response({
"code": 400,
"message": "知识库没有有效的external_id请先创建知识库",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 先验证外部知识库是否存在
try:
# 简单的验证请求
verify_url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}'
verify_response = requests.get(verify_url)
if verify_response.status_code != 200:
logger.error(f"外部知识库不存在或无法访问: {instance.external_id}, 状态码: {verify_response.status_code}")
return Response({
"code": 404,
"message": f"外部知识库不存在或无法访问: {instance.external_id}",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
verify_data = verify_response.json()
if verify_data.get('code') != 200:
logger.error(f"验证外部知识库失败: {verify_data.get('message')}")
return Response({
"code": verify_data.get('code', 500),
"message": f"验证外部知识库失败: {verify_data.get('message', '未知错误')}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
logger.info(f"外部知识库验证成功: {instance.external_id}")
except Exception as e:
logger.error(f"验证外部知识库时出错: {str(e)}")
return Response({
"code": 500,
"message": f"验证外部知识库时出错: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 逐个处理每个文件 - 避免一次性传多个文件导致外部API处理失败
for i, file in enumerate(files):
logger.info(f"处理第 {i+1} 个文件: {file.name}")
# 创建只包含当前文件的列表传递给分割API
current_file = [file]
# 调用文档分割API
split_response = self._call_split_api_multiple(current_file)
if not split_response or split_response.get('code') != 200:
error_msg = f"文件 {file.name} 分割失败: {split_response.get('message', '未知错误') if split_response else '请求失败'}"
logger.error(error_msg)
failed_documents.append({
"name": file.name,
"error": error_msg
})
continue
# 处理分割后的文档
documents_data = split_response.get('data', [])
# 如果没有文档数据,使用一个基本结构
if not documents_data:
logger.warning(f"文件 {file.name} 未返回文档数据,创建基本文档结构")
documents_data = [{
'name': file.name,
'content': [{
'title': '文档内容',
'content': '文件内容无法自动分割,请检查文件格式。'
}]
}]
# 遍历所有分割后的文档
for doc in documents_data:
doc_name = doc.get('name', file.name)
doc_content = doc.get('content', [])
logger.info(f"处理文档: {doc_name}, 包含 {len(doc_content)} 个段落")
# 如果没有内容,添加一个默认段落
if not doc_content:
doc_content = [{
'title': '文档内容',
'content': '文件内容无法自动分割,请检查文件格式。'
}]
# 准备文档数据结构
doc_data = {
"name": doc_name,
"paragraphs": []
}
# 将所有段落添加到文档中
for paragraph in doc_content:
doc_data["paragraphs"].append({
"content": paragraph.get('content', ''),
"title": paragraph.get('title', ''),
"is_active": True,
"problem_list": []
})
# 调用文档上传API
upload_response = self._call_upload_api(instance.external_id, doc_data)
if upload_response and upload_response.get('code') == 200 and upload_response.get('data'):
# 上传成功,保存记录到数据库
document_id = upload_response['data']['id']
doc_record = KnowledgeBaseDocument.objects.create(
knowledge_base=instance,
document_id=document_id,
document_name=doc_name,
external_id=document_id,
uploader_name=user.name
)
saved_documents.append({
"id": str(doc_record.id),
"name": doc_record.document_name,
"external_id": doc_record.external_id
})
logger.info(f"文档 '{doc_name}' 上传成功ID: {document_id}")
else:
# 上传失败,记录错误信息
error_msg = upload_response.get('message', '未知错误') if upload_response else '上传API调用失败'
logger.error(f"文档 '{doc_name}' 上传失败: {error_msg}")
failed_documents.append({
"name": doc_name,
"error": error_msg
})
# 返回结果
if saved_documents:
return Response({
"code": 200,
"message": f"文档上传完成,成功: {len(saved_documents)},失败: {len(failed_documents)}",
"data": {
"uploaded_count": len(saved_documents),
"failed_count": len(failed_documents),
"total_files": len(files),
"documents": saved_documents,
"failed_documents": failed_documents
}
})
else:
return Response({
"code": 400,
"message": f"所有文档上传失败",
"data": {
"uploaded_count": 0,
"failed_count": len(failed_documents),
"total_files": len(files),
"documents": [],
"failed_documents": failed_documents
}
}, status=status.HTTP_400_BAD_REQUEST)
except Exception as e:
logger.error(f"文档上传失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"文档上传失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _call_upload_api(self, external_id, doc_data):
"""调用文档上传API"""
try:
url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document'
logger.info(f"调用文档上传API: {url}")
# 记录请求数据,方便调试
logger.info(f"上传文档数据: 文档名={doc_data.get('name')}, 段落数={len(doc_data.get('paragraphs', []))}")
# 发送请求
response = requests.post(url, json=doc_data)
# 记录响应结果
logger.info(f"上传API响应状态码: {response.status_code}")
# 检查响应状态码
if response.status_code != 200:
logger.error(f"上传API HTTP错误: {response.status_code}, 响应: {response.text}")
return {
'code': response.status_code,
'message': f"上传失败HTTP状态码: {response.status_code}",
'data': None
}
# 解析响应JSON
result = response.json()
logger.info(f"上传API响应内容: {result}")
# 检查业务状态码
if result.get('code') != 200:
error_msg = result.get('message', '未知错误')
logger.error(f"上传API业务错误: {error_msg}")
return {
'code': result.get('code', 500),
'message': error_msg,
'data': None
}
return result
except requests.exceptions.RequestException as e:
logger.error(f"调用上传API网络错误: {str(e)}")
return {
'code': 500,
'message': f"网络请求错误: {str(e)}",
'data': None
}
except json.JSONDecodeError as e:
logger.error(f"解析API响应JSON失败: {str(e)}")
return {
'code': 500,
'message': f"解析响应数据失败: {str(e)}",
'data': None
}
except Exception as e:
logger.error(f"调用上传API其他错误: {str(e)}")
return {
'code': 500,
'message': f"上传API调用失败: {str(e)}",
'data': None
}
def _call_delete_document_api(self, external_id, document_id):
"""调用文档删除API"""
try:
url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document/{document_id}'
response = requests.delete(url)
return response.json()
except Exception as e:
logger.error(f"调用删除API失败: {str(e)}")
return None
def _create_external_dataset(self, instance):
"""创建外部知识库"""
try:
api_data = {
"name": instance.name,
"desc": instance.desc,
"type": "0", # 添加必要的type字段
"meta": {}, # 添加必要的meta字段
"documents": [] # 初始化为空列表
}
response = requests.post(
f'{settings.API_BASE_URL}/api/dataset',
json=api_data,
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
raise ExternalAPIError(f"创建失败,状态码: {response.status_code}, 响应: {response.text}")
api_response = response.json()
if not api_response.get('code') == 200:
raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}")
dataset_id = api_response.get('data', {}).get('id')
if not dataset_id:
raise ExternalAPIError("响应数据中缺少dataset id")
return dataset_id
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
raise ExternalAPIError(f"创建外部知识库失败: {str(e)}")
def _delete_external_dataset(self, external_id):
"""删除外部知识库"""
try:
if not external_id:
raise ExternalAPIError("外部知识库ID不能为空")
response = requests.delete(
f'{settings.API_BASE_URL}/api/dataset/{external_id}',
headers={'Content-Type': 'application/json'},
)
logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")
# 检查响应状态码
if response.status_code == 404:
logger.warning(f"外部知识库不存在: {external_id}")
return True # 如果知识库不存在,也视为删除成功
elif response.status_code not in [200, 204]:
raise ExternalAPIError(f"删除失败,状态码: {response.status_code}, 响应: {response.text}")
# 如果是 204 状态码,说明删除成功但无返回内容
if response.status_code == 204:
logger.info(f"外部知识库删除成功: {external_id}")
return True
# 如果是 200 状态码,检查响应内容
try:
api_response = response.json()
if api_response.get('code') != 200:
raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}")
logger.info(f"外部知识库删除成功: {external_id}")
return True
except ValueError:
# 如果无法解析 JSON但状态码是 200也认为成功
logger.warning(f"外部知识库删除响应无法解析JSON但状态码为200视为成功: {external_id}")
return True
except requests.exceptions.Timeout:
logger.error(f"删除外部知识库超时: {external_id}")
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}")
raise ExternalAPIError(f"删除外部知识库失败: {str(e)}")
@action(detail=True, methods=['get'])
def documents(self, request, pk=None):
"""获取知识库的文档列表"""
try:
instance = self.get_object()
user = request.user
# 权限检查
if not self.check_knowledge_base_permission(instance, user, 'read'):
return Response({
"code": 403,
"message": "没有查看权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 检查external_id是否存在
if not instance.external_id:
return Response({
"code": 400,
"message": "知识库没有有效的external_id",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 调用外部API获取文档列表
try:
url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document'
response = requests.get(
url,
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
logger.error(f"获取文档列表API调用失败: {response.status_code}, {response.text}")
return Response({
"code": 500,
"message": f"获取文档列表失败: HTTP {response.status_code}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
result = response.json()
if result.get('code') != 200:
logger.error(f"获取文档列表业务失败: {result.get('message')}")
return Response({
"code": result.get('code', 500),
"message": result.get('message', '获取文档列表失败'),
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 同步外部文档到本地数据库
external_documents = result.get('data', [])
for doc in external_documents:
# 获取外部文档ID和名称
external_id = doc.get('id')
doc_name = doc.get('name')
if external_id and doc_name:
# 检查文档是否已存在
kb_doc, created = KnowledgeBaseDocument.objects.update_or_create(
knowledge_base=instance,
external_id=external_id,
defaults={
'document_id': external_id,
'document_name': doc_name,
'status': 'active' if doc.get('is_active', True) else 'deleted'
}
)
if created:
logger.info(f"同步创建文档: {doc_name}, ID: {external_id}")
else:
logger.info(f"同步更新文档: {doc_name}, ID: {external_id}")
# 获取最新的本地文档数据
documents = KnowledgeBaseDocument.objects.filter(
knowledge_base=instance,
status='active'
).order_by('-create_time')
# 构建响应数据
documents_data = [{
"id": str(doc.id),
"document_id": doc.document_id,
"name": doc.document_name,
"external_id": doc.external_id,
"created_at": doc.create_time.strftime('%Y-%m-%d %H:%M:%S'),
# 添加外部API返回的额外信息
"char_length": next((d.get('char_length', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
"paragraph_count": next((d.get('paragraph_count', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
"is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True),
"uploader_name": doc.uploader_name
} for doc in documents]
return Response({
"code": 200,
"message": "获取文档列表成功",
"data": documents_data
})
except requests.exceptions.RequestException as e:
logger.error(f"获取文档列表网络异常: {str(e)}")
return Response({
"code": 500,
"message": f"获取文档列表失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"获取文档列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"获取文档列表失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=True, methods=['get'])
def document_content(self, request, pk=None):
"""获取文档内容 - 段落列表"""
try:
knowledge_base = self.get_object()
user = request.user
# 权限检查
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
return Response({
"code": 403,
"message": "没有查看权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 获取文档ID
document_id = request.query_params.get('document_id')
if not document_id:
return Response({
"code": 400,
"message": "缺少document_id参数",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证文档存在
document = KnowledgeBaseDocument.objects.filter(
knowledge_base=knowledge_base,
document_id=document_id,
status='active'
).first()
if not document:
return Response({
"code": 404,
"message": "文档不存在或已删除",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
# 调用正确的外部API获取文档段落内容
try:
url = f'{settings.API_BASE_URL}/api/dataset/{knowledge_base.external_id}/document/{document.external_id}/paragraph'
response = requests.get(url)
if response.status_code != 200:
logger.error(f"获取文档段落内容失败: {response.status_code}, {response.text}")
return Response({
"code": 500,
"message": f"获取文档段落内容失败,状态码: {response.status_code}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
api_response = response.json()
if api_response.get('code') != 200:
logger.error(f"获取文档段落内容业务失败: {api_response.get('message')}")
return Response({
"code": api_response.get('code', 500),
"message": api_response.get('message', '获取文档段落内容失败'),
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
paragraphs = api_response.get('data', [])
# 直接返回外部API的段落数据
return Response({
"code": 200,
"message": "获取文档内容成功",
"data": {
"document_id": document_id,
"name": document.document_name,
"paragraphs": paragraphs
}
})
except Exception as e:
logger.error(f"获取文档段落内容API调用失败: {str(e)}")
return Response({
"code": 500,
"message": f"获取文档内容失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"获取文档内容失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"获取文档内容失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=True, methods=['delete'])
def delete_document(self, request, pk=None):
"""删除知识库文档"""
try:
knowledge_base = self.get_object()
user = request.user
# 权限检查
if not self.check_knowledge_base_permission(knowledge_base, user, 'edit'):
return Response({
"code": 403,
"message": "没有编辑权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 获取文档ID
document_id = request.query_params.get('document_id')
if not document_id:
return Response({
"code": 400,
"message": "缺少document_id参数",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证文档存在
document = KnowledgeBaseDocument.objects.filter(
knowledge_base=knowledge_base,
document_id=document_id,
status='active'
).first()
if not document:
return Response({
"code": 404,
"message": "文档不存在或已删除",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
# 调用外部API删除文档
try:
external_id = document.external_id
delete_result = self._call_delete_document_api(knowledge_base.external_id, external_id)
# 无论外部API结果如何都更新本地状态
document.status = 'deleted'
document.save()
if delete_result and delete_result.get('code') != 200:
logger.warning(f"外部API删除文档失败但本地标记已更新: {delete_result.get('message')}")
return Response({
"code": 200,
"message": "文档删除成功",
"data": {
"document_id": document_id,
"name": document.document_name
}
})
except Exception as e:
logger.error(f"调用删除文档API失败: {str(e)}")
# 即使外部API调用失败也更新本地状态
document.status = 'deleted'
document.save()
return Response({
"code": 200,
"message": "文档在系统中已标记为删除但外部API调用失败",
"data": {
"document_id": document_id,
"name": document.document_name,
"error": str(e)
}
})
except Exception as e:
logger.error(f"删除文档失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"删除文档失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
class PermissionViewSet(viewsets.ModelViewSet):
serializer_class = PermissionSerializer
permission_classes = [IsAuthenticated]
def can_manage_knowledge_base(self, user, knowledge_base):
"""检查用户是否是知识库的创建者"""
return str(knowledge_base.user_id) == str(user.id)
def get_queryset(self):
"""
获取权限申请列表:
1. applicant_id 是当前用户 (看到自己发起的申请)
2. approver_id 是当前用户 (看到自己需要审批的申请)
"""
user_id = str(self.request.user.id)
# 构建查询条件:申请人是自己 或 审批人是自己
query = Q(applicant_id=user_id) | Q(approver_id=user_id)
return Permission.objects.filter(query).select_related(
'knowledge_base',
'applicant',
'approver'
)
def list(self, request, *args, **kwargs):
"""获取权限申请列表,包含详细信息"""
try:
queryset = self.get_queryset()
user_id = str(request.user.id)
# 获取分页参数
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 计算总数
total = queryset.count()
# 手动分页
start = (page - 1) * page_size
end = start + page_size
permissions = queryset[start:end]
# 构建响应数据
data = []
for permission in permissions:
# 检查当前用户是否是申请人或审批人
if user_id not in [str(permission.applicant_id), str(permission.approver_id)]:
continue
# 构建响应数据
permission_data = {
'id': str(permission.id),
'knowledge_base': {
'id': str(permission.knowledge_base.id),
'name': permission.knowledge_base.name,
'type': permission.knowledge_base.type,
},
'applicant': {
'id': str(permission.applicant.id),
'username': permission.applicant.username,
'name': permission.applicant.name,
'department': permission.applicant.department,
},
'approver': {
'id': str(permission.approver.id) if permission.approver else '',
'username': permission.approver.username if permission.approver else '',
'name': permission.approver.name if permission.approver else '',
'department': permission.approver.department if permission.approver else '',
},
'permissions': permission.permissions,
'status': permission.status,
'created_at': permission.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None,
'response_message': permission.response_message or '',
# 添加角色标识,用于前端展示
'role': 'applicant' if str(permission.applicant_id) == user_id else 'approver'
}
data.append(permission_data)
return Response({
'code': 200,
'message': '获取权限申请列表成功',
'data': {
'total': len(data), # 使用过滤后的实际数量
'page': page,
'page_size': page_size,
'results': data
}
})
except Exception as e:
logger.error(f"获取权限申请列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取权限申请列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def perform_create(self, serializer):
"""创建权限申请并发送通知给知识库创建者"""
# 获取知识库
# 获取知识库
knowledge_base = serializer.validated_data['knowledge_base']
# 检查是否是申请访问自己的知识库
if str(knowledge_base.user_id) == str(self.request.user.id):
raise ValidationError({
"code": 400,
"message": "您是此知识库的创建者,无需申请权限",
"data": None
})
# 获取知识库创建者作为审批者
approver = User.objects.get(id=knowledge_base.user_id)
# 验证权限请求
requested_permissions = serializer.validated_data.get('permissions', {})
expires_at = serializer.validated_data.get('expires_at')
if not any([requested_permissions.get('can_read'),
requested_permissions.get('can_edit'),
requested_permissions.get('can_delete')]):
raise ValidationError("至少需要申请一种权限(读/改/删)")
if not expires_at:
raise ValidationError("请指定权限到期时间")
# 检查是否已有未过期的权限申请
existing_request = Permission.objects.filter(
knowledge_base=knowledge_base,
applicant=self.request.user,
status='pending'
).first()
if existing_request:
raise ValidationError("您已有一个待处理的权限申请")
# 检查是否已有有效的权限
existing_permission = Permission.objects.filter(
knowledge_base=knowledge_base,
applicant=self.request.user,
status='approved',
expires_at__gt=timezone.now()
).first()
if existing_permission:
raise ValidationError("您已有此知识库的访问权限")
# 保存权限申请,设置审批者
permission = serializer.save(
applicant=self.request.user,
status='pending',
approver=approver # 创建时就设置审批者
)
# 获取权限类型字符串
permission_types = []
if requested_permissions.get('can_read'):
permission_types.append('读取')
if requested_permissions.get('can_edit'):
permission_types.append('编辑')
if requested_permissions.get('can_delete'):
permission_types.append('删除')
permission_str = ''.join(permission_types)
# 发送通知给知识库创建者
owner = User.objects.get(id=knowledge_base.user_id)
self.send_notification(
user=owner,
title="新的权限申请",
content=f"用户 {self.request.user.name} 申请了知识库 '{knowledge_base.name}'{permission_str}权限",
notification_type="permission_request",
related_object_id=permission.id
)
def send_notification(self, user, title, content, notification_type, related_object_id):
"""发送通知"""
try:
notification = Notification.objects.create(
sender=self.request.user,
receiver=user,
title=title,
content=content,
type=notification_type,
related_resource=related_object_id,
)
# 通过WebSocket发送实时通知
channel_layer = get_channel_layer()
async_to_sync(channel_layer.group_send)(
f"notification_user_{user.id}",
{
"type": "notification",
"data": {
"id": str(notification.id),
"title": notification.title,
"content": notification.content,
"type": notification.type,
"created_at": notification.created_at.isoformat(),
"sender": {
"id": str(notification.sender.id),
"name": notification.sender.name
}
}
}
)
except Exception as e:
logger.error(f"发送通知时发生错误: {str(e)}")
@action(detail=True, methods=['post'])
def approve(self, request, pk=None):
try:
# 获取权限申请记录
permission = self.get_object()
# 只检查是否是知识库创建者
if not self.can_manage_knowledge_base(request.user, permission.knowledge_base):
logger.warning(f"用户 {request.user.username} 尝试审批知识库 {permission.knowledge_base.name} 的权限申请,但不是创建者")
return Response({
'code': 403,
'message': '只有知识库创建者可以审批此申请',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 获取审批意见
response_message = request.data.get('response_message', '')
with transaction.atomic():
# 更新权限申请状态
permission.status = 'approved'
permission.approver = request.user
permission.response_message = response_message
permission.save()
# 检查是否已存在权限记录
kb_permission = KBPermissionModel.objects.filter(
knowledge_base=permission.knowledge_base,
user=permission.applicant
).first()
if kb_permission:
# 更新现有权限
kb_permission.can_read = permission.permissions.get('can_read', False)
kb_permission.can_edit = permission.permissions.get('can_edit', False)
kb_permission.can_delete = permission.permissions.get('can_delete', False)
kb_permission.granted_by = request.user
kb_permission.status = 'active'
kb_permission.expires_at = permission.expires_at
kb_permission.save()
logger.info(f"更新知识库权限记录: {kb_permission.id}")
else:
# 创建新的权限记录
kb_permission = KBPermissionModel.objects.create(
knowledge_base=permission.knowledge_base,
user=permission.applicant,
can_read=permission.permissions.get('can_read', False),
can_edit=permission.permissions.get('can_edit', False),
can_delete=permission.permissions.get('can_delete', False),
granted_by=request.user,
status='active',
expires_at=permission.expires_at
)
logger.info(f"创建新的知识库权限记录: {kb_permission.id}")
# 发送通知给申请人
self.send_notification(
user=permission.applicant,
title="权限申请已通过",
content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已通过",
notification_type="permission_approved",
related_object_id=permission.id
)
return Response({
'code': 200,
'message': '权限申请已批准',
'data': None
})
except Permission.DoesNotExist:
return Response({
'code': 404,
'message': '权限申请不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
logger.error(f"处理权限申请失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'处理权限申请失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=True, methods=['post'])
def reject(self, request, pk=None):
"""拒绝权限申请"""
permission = self.get_object()
# 检查是否是知识库创建者
if str(permission.knowledge_base.user_id) != str(request.user.id):
return Response({
'code': 403,
'message': '只有知识库创建者可以审批此申请',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 检查申请是否已被处理
if permission.status != 'pending':
return Response({
'code': 400,
'message': '该申请已被处理',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证拒绝原因
response_message = request.data.get('response_message')
if not response_message:
return Response({
'code': 400,
'message': '请填写拒绝原因',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 更新权限状态
permission.status = 'rejected'
permission.approver = request.user
permission.response_message = response_message
permission.save()
# 发送通知给申请人
self.send_notification(
user=permission.applicant,
title="权限申请已拒绝",
content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已被拒绝\n"
f"拒绝原因:{response_message}",
notification_type="permission_rejected",
related_object_id=permission.id
)
return Response({
'code': 200,
'message': '权限申请已拒绝',
'data': PermissionSerializer(permission).data
})
@action(detail=True, methods=['post'])
def extend(self, request, pk=None):
"""延长权限有效期"""
instance = self.get_object()
user = request.user
# 检查是否有权限延长
if not self.check_extend_permission(instance, user):
return Response({
"code": 403,
"message": "您没有权限延长此权限",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
new_expires_at = request.data.get('expires_at')
if not new_expires_at:
return Response({
"code": 400,
"message": "请设置新的过期时间",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
try:
with transaction.atomic():
# 更新权限申请表的过期时间
instance.expires_at = new_expires_at
instance.save()
# 同步更新知识库权限表的过期时间
kb_permission = KBPermissionModel.objects.get(
knowledge_base=instance.knowledge_base,
user=instance.applicant
)
kb_permission.expires_at = new_expires_at
kb_permission.save()
return Response({
"code": 200,
"message": "权限有效期延长成功",
"data": PermissionSerializer(instance).data
})
except Exception as e:
return Response({
"code": 500,
"message": f"延长权限有效期失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def check_extend_permission(self, permission, user):
"""检查是否有权限延长权限有效期"""
knowledge_base = permission.knowledge_base
# 私人知识库只有拥有者能延长
if knowledge_base.type == 'private':
return knowledge_base.owner == user
# 组长知识库只有管理员能延长
if knowledge_base.type == 'leader':
return user.role == 'admin'
# 组员知识库可以由管理员或本部门组长延长
if knowledge_base.type == 'member':
return (
user.role == 'admin' or
(user.role == 'leader' and user.department == knowledge_base.department)
)
return False
@action(detail=False, methods=['get'])
def user_permissions(self, request):
"""获取指定用户的所有知识库权限"""
try:
# 获取用户名参数
username = request.query_params.get('username')
if not username:
return Response({
'code': 400,
'message': '请提供用户名',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取用户
try:
target_user = User.objects.get(username=username)
except User.DoesNotExist:
return Response({
'code': 404,
'message': f'用户 {username} 不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 获取该用户的所有权限记录
permissions = KBPermissionModel.objects.filter(
user=target_user,
status='active'
).select_related('knowledge_base', 'granted_by')
# 构建响应数据
permissions_data = []
for perm in permissions:
perm_data = {
'id': str(perm.id),
'knowledge_base': {
'id': str(perm.knowledge_base.id),
'name': perm.knowledge_base.name,
'type': perm.knowledge_base.type,
'department': perm.knowledge_base.department,
'group': perm.knowledge_base.group
},
'permissions': {
'can_read': perm.can_read,
'can_edit': perm.can_edit,
'can_delete': perm.can_delete
},
'granted_by': {
'id': str(perm.granted_by.id) if perm.granted_by else None,
'username': perm.granted_by.username if perm.granted_by else None,
'name': perm.granted_by.name if perm.granted_by else None
},
'created_at': perm.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None,
'status': perm.status
}
permissions_data.append(perm_data)
return Response({
'code': 200,
'message': '获取用户权限成功',
'data': {
'user': {
'id': str(target_user.id),
'username': target_user.username,
'name': target_user.name,
'department': target_user.department,
'role': target_user.role
},
'permissions': permissions_data
}
})
except Exception as e:
logger.error(f"获取用户权限失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取用户权限失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def all_permissions(self, request):
"""管理员获取所有用户的知识库权限(不包括私有知识库)"""
try:
# 检查是否是管理员
if request.user.role != 'admin':
return Response({
'code': 403,
'message': '只有管理员可以查看所有权限',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 获取查询参数
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
status_filter = request.query_params.get('status')
department = request.query_params.get('department')
kb_type = request.query_params.get('kb_type')
# 构建基础查询
queryset = KBPermissionModel.objects.filter(
~Q(knowledge_base__type='private')
).select_related(
'user',
'knowledge_base',
'granted_by'
)
# 应用过滤条件
if status_filter == 'active':
queryset = queryset.filter(
Q(expires_at__gt=timezone.now()) | Q(expires_at__isnull=True),
status='active'
)
elif status_filter == 'expired':
queryset = queryset.filter(
Q(expires_at__lte=timezone.now()) | Q(status='inactive')
)
if department:
queryset = queryset.filter(user__department=department)
if kb_type:
queryset = queryset.filter(knowledge_base__type=kb_type)
# 按用户分组处理数据
user_permissions = {}
for perm in queryset:
user_id = str(perm.user.id)
if user_id not in user_permissions:
user_permissions[user_id] = {
'user_info': {
'id': user_id,
'username': perm.user.username,
'name': getattr(perm.user, 'name', perm.user.username),
'department': getattr(perm.user, 'department', None),
'role': getattr(perm.user, 'role', None)
},
'permissions': [],
'stats': {
'total': 0,
'by_type': {
'admin': 0,
'secret': 0,
'leader': 0,
'member': 0
},
'by_permission': {
'read_only': 0,
'read_write': 0,
'full_access': 0
}
}
}
# 添加权限信息
perm_data = {
'id': str(perm.id),
'knowledge_base': {
'id': str(perm.knowledge_base.id),
'name': perm.knowledge_base.name,
'type': perm.knowledge_base.type,
'department': perm.knowledge_base.department,
'group': perm.knowledge_base.group,
'creator': {
'id': str(perm.knowledge_base.user_id),
'name': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'name', None),
'username': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'username', None)
}
},
'permissions': {
'can_read': perm.can_read,
'can_edit': perm.can_edit,
'can_delete': perm.can_delete
},
'granted_by': {
'id': str(perm.granted_by.id) if perm.granted_by else None,
'username': perm.granted_by.username if perm.granted_by else None,
'name': getattr(perm.granted_by, 'name', None) if perm.granted_by else None
},
'granted_at': perm.granted_at.strftime('%Y-%m-%d %H:%M:%S'),
'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None,
'status': perm.status
}
user_permissions[user_id]['permissions'].append(perm_data)
# 更新统计信息
stats = user_permissions[user_id]['stats']
stats['total'] += 1
stats['by_type'][perm.knowledge_base.type] += 1
# 统计权限级别
if perm.can_delete:
stats['by_permission']['full_access'] += 1
elif perm.can_edit:
stats['by_permission']['read_write'] += 1
elif perm.can_read:
stats['by_permission']['read_only'] += 1
# 转换为列表并分页
users_list = list(user_permissions.values())
total = len(users_list)
start = (page - 1) * page_size
end = start + page_size
paginated_users = users_list[start:end]
return Response({
'code': 200,
'message': '获取权限列表成功',
'data': {
'total': total,
'page': page,
'page_size': page_size,
'results': paginated_users
}
})
except Exception as e:
logger.error(f"获取所有权限失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取所有权限失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['post'])
def update_permission(self, request):
"""管理员更新用户的知识库权限"""
try:
# 检查是否是管理员
if request.user.role != 'admin':
return Response({
'code': 403,
'message': '只有管理员可以直接修改权限',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 验证必要参数
user_id = request.data.get('user_id')
knowledge_base_id = request.data.get('knowledge_base_id')
permissions = request.data.get('permissions')
expires_at_str = request.data.get('expires_at')
if not all([user_id, knowledge_base_id, permissions]):
return Response({
'code': 400,
'message': '缺少必要参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证权限参数格式
required_permission_fields = ['can_read', 'can_edit', 'can_delete']
if not all(field in permissions for field in required_permission_fields):
return Response({
'code': 400,
'message': '权限参数格式错误,必须包含 can_read、can_edit、can_delete',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取用户和知识库
try:
user = User.objects.get(id=user_id)
knowledge_base = KnowledgeBase.objects.get(id=knowledge_base_id)
except User.DoesNotExist:
return Response({
'code': 404,
'message': f'用户ID {user_id} 不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
except KnowledgeBase.DoesNotExist:
return Response({
'code': 404,
'message': f'知识库ID {knowledge_base_id} 不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 检查知识库类型和用户角色的匹配
if knowledge_base.type == 'private' and str(knowledge_base.user_id) != str(user.id):
return Response({
'code': 403,
'message': '不能修改其他用户的私有知识库权限',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 处理过期时间
expires_at = None
if expires_at_str:
try:
# 将字符串转换为datetime对象
expires_at = timezone.datetime.strptime(
expires_at_str,
'%Y-%m-%dT%H:%M:%SZ'
)
# 确保时区感知
expires_at = timezone.make_aware(expires_at)
# 检查是否早于当前时间
if expires_at <= timezone.now():
return Response({
'code': 400,
'message': '过期时间不能早于或等于当前时间',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
except ValueError:
return Response({
'code': 400,
'message': '过期时间格式错误,应为 ISO 格式 (YYYY-MM-DDThh:mm:ssZ)',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 根据用户角色限制权限
if user.role == 'member' and permissions.get('can_delete'):
return Response({
'code': 400,
'message': '普通成员不能获得删除权限',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 更新或创建权限记录
try:
with transaction.atomic():
permission, created = KBPermissionModel.objects.update_or_create(
user=user,
knowledge_base=knowledge_base,
defaults={
'can_read': permissions.get('can_read', False),
'can_edit': permissions.get('can_edit', False),
'can_delete': permissions.get('can_delete', False),
'granted_by': request.user,
'status': 'active',
'expires_at': expires_at
}
)
# 发送通知给用户
self.send_notification(
user=user,
title="知识库权限更新",
content=f"管理员已{created and '授予' or '更新'}您对知识库 '{knowledge_base.name}' 的权限",
notification_type="permission_updated",
related_object_id=permission.id
)
except IntegrityError as e:
return Response({
'code': 500,
'message': f'数据库操作失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response({
'code': 200,
'message': f"{'创建' if created else '更新'}权限成功",
'data': {
'id': str(permission.id),
'user': {
'id': str(user.id),
'username': user.username,
'name': user.name,
'department': user.department,
'role': user.role
},
'knowledge_base': {
'id': str(knowledge_base.id),
'name': knowledge_base.name,
'type': knowledge_base.type,
'department': knowledge_base.department,
'group': knowledge_base.group
},
'permissions': {
'can_read': permission.can_read,
'can_edit': permission.can_edit,
'can_delete': permission.can_delete
},
'granted_by': {
'id': str(request.user.id),
'username': request.user.username,
'name': request.user.name
},
'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None,
'created': created
}
})
except Exception as e:
logger.error(f"更新权限失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'更新权限失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
class NotificationViewSet(viewsets.ModelViewSet):
"""通知视图集"""
queryset = Notification.objects.all()
serializer_class = NotificationSerializer
permission_classes = [IsAuthenticated]
def get_queryset(self):
"""只返回用户自己的通知"""
return Notification.objects.filter(receiver=self.request.user)
@action(detail=True, methods=['post'])
def mark_as_read(self, request, pk=None):
"""标记通知为已读"""
notification = self.get_object()
notification.is_read = True
notification.save()
return Response({'status': 'marked as read'})
@action(detail=False, methods=['post'])
def mark_all_as_read(self, request):
"""标记所有通知为已读"""
self.get_queryset().update(is_read=True)
return Response({'status': 'all marked as read'})
@action(detail=False, methods=['get'])
def unread_count(self, request):
"""获取未读通知数量"""
count = self.get_queryset().filter(is_read=False).count()
return Response({'unread_count': count})
@action(detail=False, methods=['get'])
def latest(self, request):
"""获取最新通知"""
notifications = self.get_queryset().filter(
is_read=False
).order_by('-created_at')[:5]
serializer = self.get_serializer(notifications, many=True)
return Response(serializer.data)
def perform_create(self, serializer):
"""创建通知时自动设置发送者"""
serializer.save(sender=self.request.user)
@method_decorator(csrf_exempt, name='dispatch')
class LoginView(APIView):
"""用户登录视图"""
authentication_classes = [] # 清空认证类
permission_classes = [AllowAny]
def post(self, request):
try:
username = request.data.get('username')
password = request.data.get('password')
# 参数验证
if not username or not password:
return Response({
"code": 400,
"message": "请提供用户名和密码",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证用户
user = authenticate(request, username=username, password=password)
if user is not None:
# 获取或创建token
token, _ = Token.objects.get_or_create(user=user)
# 登录用户(可选)
login(request, user)
return Response({
"code": 200,
"message": "登录成功",
"data": {
"id": str(user.id),
"username": user.username,
"email": user.email,
"name": user.name,
"role": user.role,
"department": user.department,
"group": user.group,
"token": token.key
}
})
else:
return Response({
"code": 401,
"message": "用户名或密码错误",
"data": None
}, status=status.HTTP_401_UNAUTHORIZED)
except Exception as e:
import traceback
logger.error(f"登录失败: {str(e)}")
logger.error(f"错误类型: {type(e)}")
logger.error(f"错误堆栈: {traceback.format_exc()}")
return Response({
"code": 500,
"message": "登录失败,请稍后重试",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@method_decorator(csrf_exempt, name='dispatch')
class RegisterView(APIView):
"""用户注册视图"""
permission_classes = [AllowAny]
def post(self, request):
try:
data = request.data
# 检查必填字段
required_fields = ['username', 'password', 'email', 'role', 'name']
for field in required_fields:
if not data.get(field):
return Response({
"code": 400,
"message": f"缺少必填字段: {field}",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证角色
valid_roles = ['admin', 'leader', 'member']
roles_str = ', '.join(valid_roles) # 先构造角色字符串
if data['role'] not in valid_roles:
return Response({
"code": 400,
"message": f"无效的角色,必须是: {roles_str}",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查用户名是否已存在
if User.objects.filter(username=data['username']).exists():
return Response({
"code": 400,
"message": "用户名已存在",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查邮箱是否已存在
if User.objects.filter(email=data['email']).exists():
return Response({
"code": 400,
"message": "邮箱已被注册",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证密码强度
if len(data['password']) < 8:
return Response({
"code": 400,
"message": "密码长度必须至少为8位",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证邮箱格式
try:
validate_email(data['email'])
except ValidationError:
return Response({
"code": 400,
"message": "邮箱格式不正确",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建用户
user = User.objects.create_user(
username=data['username'],
email=data['email'],
password=data['password'],
role=data['role'],
department=data.get('department'), # 不再强制要求部门
name=data['name'],
group=data.get('group'), # 不再强制要求小组
is_staff=False,
is_superuser=False
)
# 生成认证令牌
token, _ = Token.objects.get_or_create(user=user)
return Response({
"code": 200,
"message": "注册成功",
"data": {
"id": str(user.id),
"username": user.username,
"email": user.email,
"role": user.role,
"department": user.department,
"name": user.name,
"group": user.group,
"token": token.key,
"created_at": user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
}
}, status=status.HTTP_201_CREATED)
except Exception as e:
print(f"注册失败: {str(e)}")
print(f"错误类型: {type(e)}")
print(f"错误堆栈: {traceback.format_exc()}")
return Response({
"code": 500,
"message": f"注册失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@method_decorator(csrf_exempt, name='dispatch')
class LogoutView(APIView):
"""用户登出视图"""
permission_classes = [IsAuthenticated]
def post(self, request):
try:
# 删除用户的token
request.user.auth_token.delete()
# 执行django的登出
logout(request)
return Response({
"code": 200,
"message": "登出成功",
"data": None
})
except Exception as e:
return Response({
"code": 500,
"message": f"登出失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET', 'PUT'])
@permission_classes([IsAuthenticated])
def user_profile(request):
"""
获取或更新当前登录用户信息
此接口与user_update的区别:
1. 任何已认证用户可访问
2. 仅能更新当前登录用户自己的信息
3. 不能修改角色等重要字段
4. 不需要指定用户ID自动使用当前用户
"""
try:
if request.method == 'GET':
# 检查用户是否已认证
user = request.user
if not user.is_authenticated:
return Response({
'code': 401,
'message': '用户未认证',
'data': None
}, status=status.HTTP_401_UNAUTHORIZED)
data = {
'id': str(user.id),
'username': user.username,
'email': user.email,
'name': user.name,
'role': user.role,
'department': user.department,
'group': user.group,
'date_joined': user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
}
return Response({
'code': 200,
'message': '获取用户信息成功',
'data': data
})
elif request.method == 'PUT':
# 检查请求数据格式
try:
if not request.data:
return Response({
'code': 400,
'message': '请求数据为空或格式错误',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
except Exception as data_error:
return Response({
'code': 400,
'message': f'请求数据格式错误: {str(data_error)}。请确保提交的是有效的JSON格式数据属性名必须使用双引号。',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
user = request.user
# 只允许更新特定字段
allowed_fields = ['email', 'name', 'phone', 'department', 'group']
updated_fields = []
for field in allowed_fields:
if field in request.data:
setattr(user, field, request.data[field])
updated_fields.append(field)
if updated_fields:
try:
user.save()
return Response({
'code': 200,
'message': f'用户信息更新成功,已更新字段: {", ".join(updated_fields)}',
'data': {
'id': str(user.id),
'username': user.username,
'email': user.email,
'name': user.name,
'role': user.role,
'department': user.department,
'group': user.group,
}
})
except Exception as save_error:
logger.error(f"保存用户数据失败: {str(save_error)}")
return Response({
'code': 500,
'message': f'更新用户信息失败: {str(save_error)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
else:
return Response({
'code': 400,
'message': '没有提供任何可更新的字段',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
else:
return Response({
'code': 405,
'message': f'不支持的请求方法: {request.method}',
'data': None
}, status=status.HTTP_405_METHOD_NOT_ALLOWED)
except Exception as e:
logger.error(f"处理用户信息请求失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'处理请求失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def change_password(request):
"""修改密码"""
try:
old_password = request.data.get('old_password')
new_password = request.data.get('new_password')
# 验证参数
if not old_password or not new_password:
return Response({
"code": 400,
"message": "请提供旧密码和新密码",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证旧密码
user = request.user
if not user.check_password(old_password):
return Response({
"code": 400,
"message": "旧密码错误",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证新密码长度
if len(new_password) < 8:
return Response({
"code": 400,
"message": "新密码长度必须至少为8位",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 修改密码
user.set_password(new_password)
user.save()
# 更新token
user.auth_token.delete()
token, _ = Token.objects.get_or_create(user=user)
return Response({
"code": 200,
"message": "密码修改成功",
"data": {
"token": token.key
}
})
except Exception as e:
return Response({
"code": 500,
"message": f"密码修改失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([AllowAny])
def user_register(request):
"""
[已弃用] 用户注册 - 请使用 RegisterView 类代替
此函数仅保留用于兼容性目的,新代码应该使用 /api/auth/register/ 接口
"""
# 打印弃用警告
logger.warning("使用已弃用的user_register函数请改用RegisterView类")
try:
data = request.data
# 检查必填字段
required_fields = ['username', 'password', 'email', 'role', 'name']
for field in required_fields:
if not data.get(field):
return Response({
'code': 400,
'message': f'缺少必填字段: {field}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证角色
valid_roles = ['admin', 'leader', 'member']
if data['role'] not in valid_roles:
return Response({
'code': 400,
'message': f'无效的角色,必须是: {", ".join(valid_roles)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查用户名是否已存在
if User.objects.filter(username=data['username']).exists():
return Response({
'code': 400,
'message': '用户名已存在',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查邮箱是否已存在
if User.objects.filter(email=data['email']).exists():
return Response({
'code': 400,
'message': '邮箱已被注册',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证密码强度
if len(data['password']) < 8:
return Response({
'code': 400,
'message': '密码长度必须至少为8位',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证邮箱格式
try:
validate_email(data['email'])
except ValidationError:
return Response({
'code': 400,
'message': '邮箱格式不正确',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建用户
user = User.objects.create_user(
username=data['username'],
email=data['email'],
password=data['password'],
role=data['role'],
department=data.get('department'), # 不再强制要求部门
name=data['name'],
group=data.get('group'), # 不再强制要求小组
is_staff=False,
is_superuser=False
)
# 生成认证令牌
token, _ = Token.objects.get_or_create(user=user)
return Response({
'code': 200,
'message': '注册成功',
'data': {
'id': str(user.id),
'username': user.username,
'email': user.email,
'role': user.role,
'department': user.department,
'name': user.name,
'group': user.group,
'token': token.key,
'created_at': user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
}
}, status=status.HTTP_201_CREATED)
except Exception as e:
logger.error(f"注册失败: {str(e)}")
logger.error(f"错误类型: {type(e)}")
logger.error(f"错误堆栈: {traceback.format_exc()}")
return Response({
'code': 500,
'message': f'注册失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def user_detail(request, pk):
"""获取用户详情"""
try:
# 尝试转换为 UUID处理多种可能的格式
try:
if not isinstance(pk, uuid.UUID):
# 移除所有空格,以防万一
pk = pk.strip()
# 处理带连字符和不带连字符的格式
if '-' not in pk and len(pk) == 32:
# 转换没有连字符的UUID格式
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
pk = uuid.UUID(pk_with_hyphens)
else:
# 尝试直接转换
pk = uuid.UUID(pk)
except ValueError:
# 提供更详细的错误信息
return Response({
"code": 400,
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
user = get_object_or_404(User, pk=pk)
return Response({
"code": 200,
"message": "获取用户信息成功",
"data": {
"id": str(user.id),
"username": user.username,
"email": user.email,
"name": user.name,
"role": user.role,
"department": user.department,
"group": user.group
}
})
except Exception as e:
logger.error(f"获取用户信息失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"获取用户信息失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['PUT'])
@permission_classes([IsAdminUser])
def user_update(request, pk):
"""
管理员更新用户信息
此接口与user_profile的区别:
1. 仅管理员可访问
2. 可以更新任何用户的信息
3. 可以修改角色等重要字段
4. 需要在URL中指定用户ID
"""
try:
# 检查请求数据格式
try:
if not request.data:
return Response({
'code': 400,
'message': '请求数据为空或格式错误',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
except Exception as data_error:
return Response({
'code': 400,
'message': f'请求数据格式错误: {str(data_error)}。请确保提交的是有效的JSON格式数据属性名必须使用双引号。',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 尝试转换为 UUID
try:
if not isinstance(pk, uuid.UUID):
pk = pk.strip()
# 处理带连字符和不带连字符的格式
if '-' not in pk and len(pk) == 32:
# 转换没有连字符的UUID格式
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
pk = uuid.UUID(pk_with_hyphens)
else:
# 尝试直接转换
pk = uuid.UUID(pk)
except ValueError:
return Response({
"code": 400,
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
try:
user = User.objects.get(pk=pk)
except User.DoesNotExist:
return Response({
'code': 404,
'message': '用户不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 只允许更新特定字段
allowed_fields = ['email', 'role', 'department', 'group', 'is_active', 'phone', 'name']
updated_fields = []
for field in allowed_fields:
if field in request.data:
setattr(user, field, request.data[field])
updated_fields.append(field)
if not updated_fields:
return Response({
'code': 400,
'message': '没有提供任何可更新的字段',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
user.save()
return Response({
'code': 200,
'message': '用户信息更新成功',
'data': {
'id': str(user.id),
'username': user.username,
'email': user.email,
'name': user.name,
'role': user.role,
'department': user.department,
'group': user.group,
'is_active': user.is_active
}
})
except Exception as e:
logger.error(f"更新用户信息失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'更新用户信息失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['DELETE'])
@permission_classes([IsAdminUser])
def user_delete(request, pk):
"""删除用户"""
try:
# 尝试转换为 UUID
try:
if not isinstance(pk, uuid.UUID):
pk = pk.strip()
# 处理带连字符和不带连字符的格式
if '-' not in pk and len(pk) == 32:
# 转换没有连字符的UUID格式
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
pk = uuid.UUID(pk_with_hyphens)
else:
# 尝试直接转换
pk = uuid.UUID(pk)
except ValueError:
return Response({
"code": 400,
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
try:
user = User.objects.get(pk=pk)
except User.DoesNotExist:
return Response({
'code': 404,
'message': '用户不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 检查是否试图删除管理员账户
if user.is_superuser or user.role == 'admin':
return Response({
'code': 403,
'message': '不允许删除管理员账户',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 删除用户
username = user.username
user.delete()
return Response({
'code': 200,
'message': f'用户 {username} 删除成功',
'data': None
})
except Exception as e:
logger.error(f"删除用户失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'删除用户失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def verify_token(request):
"""验证令牌有效性"""
try:
return Response({
"code": 200,
"message": "令牌有效",
"data": {
"is_valid": True,
"user": {
"id": str(request.user.id),
"username": request.user.username,
"email": request.user.email,
"name": request.user.name,
"role": request.user.role,
"department": request.user.department,
"group": request.user.group
}
}
})
except Exception as e:
logger.error(f"验证令牌失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": f"验证失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def user_list(request):
"""获取用户列表"""
try:
# 获取查询参数
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 20))
keyword = request.query_params.get('keyword', '')
# 根据用户角色获取不同范围的用户列表
user = request.user
base_query = User.objects.all()
if user.role == 'admin':
users_query = base_query
elif user.role == 'leader':
users_query = base_query.filter(department=user.department)
else:
users_query = base_query.filter(id=user.id)
# 添加关键字搜索
if keyword:
users_query = users_query.filter(
Q(username__icontains=keyword) |
Q(email__icontains=keyword) |
Q(name__icontains=keyword) |
Q(department__icontains=keyword)
)
# 计算总数
total = users_query.count()
# 分页
start = (page - 1) * page_size
end = start + page_size
users = users_query[start:end]
# 构造数据
user_data = []
for u in users:
user_data.append({
'id': str(u.id),
'username': u.username,
'email': u.email,
'name': u.name,
'role': u.role,
'department': u.department,
'group': u.group,
'is_active': u.is_active,
'date_joined': u.date_joined.strftime('%Y-%m-%d %H:%M:%S')
})
return Response({
'code': 200,
'message': '获取用户列表成功',
'data': {
'total': total,
'page': page,
'page_size': page_size,
'users': user_data
}
})
except ValueError as e:
# 处理参数转换错误
return Response({
'code': 400,
'message': f'参数错误: {str(e)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
except Exception as e:
logger.error(f"获取用户列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取用户列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework import status
from django.http import FileResponse
import os
from .gmail_integration import GmailIntegration
import traceback
import logging
from django.utils import timezone # 导入timezone模块
from django.conf import settings # 导入settings以检查USE_TZ
logger = logging.getLogger(__name__)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def setup_gmail_integration(request):
"""
设置Gmail集成并加载邮件到知识库
"""
user = request.user
logger.info(f"用户 {user.username} 请求设置Gmail集成")
request_data = request.data
# 检查必填字段
required_fields = ['client_secret_json', 'talent_gmail']
for field in required_fields:
if field not in request_data:
logger.warning(f"缺少必填字段: {field}")
return Response({
'code': 400,
'message': f"缺少必填字段: {field}",
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
client_secret_json = request_data.get('client_secret_json')
talent_gmail = request_data.get('talent_gmail')
logger.info(f"目标Gmail: {talent_gmail}")
# 可选的授权码(如果用户已经获取了授权码)
auth_code = request_data.get('auth_code')
has_auth_code = bool(auth_code)
# 获取可选的代理设置
use_proxy = request_data.get('use_proxy', True)
proxy_url = request_data.get('proxy_url', 'http://127.0.0.1:7890')
# 初始化Gmail集成
gmail_integration = GmailIntegration(
user=user,
email=talent_gmail,
client_secret_json=client_secret_json,
use_proxy=use_proxy,
proxy_url=proxy_url
)
try:
if auth_code:
# 处理授权码完成OAuth流程
logger.info("使用授权码完成Gmail认证")
auth_success = gmail_integration.handle_auth_code(auth_code)
if not auth_success:
logger.error("授权码认证失败")
return Response({
'code': 400,
'message': "授权码认证失败,请确保授权码正确",
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
logger.info("授权码认证成功")
else:
# 尝试常规认证
try:
gmail_integration.authenticate()
logger.info("Gmail认证成功")
except Exception as e:
# 如果异常中包含授权URL返回给前端引导用户授权
error_message = str(e)
if "Please visit this URL" in error_message:
# 从错误消息中提取URL
auth_url = error_message.split("Please visit this URL to authorize: ")[1]
logger.info("需要用户授权返回授权URL")
return Response({
'code': 202,
'message': "请访问提供的URL获取授权码然后与client_secret_json和talent_gmail一起提交",
'data': {
"status": "authorization_required",
"auth_url": auth_url
}
}, status=status.HTTP_202_ACCEPTED)
else:
# 其他错误
logger.error(f"Gmail认证失败: {error_message}")
return Response({
'code': 400,
'message': f"Gmail认证失败: {error_message}",
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建或获取talent知识库
logger.info(f"创建或获取Gmail知识库: {talent_gmail}")
knowledge_base, created = gmail_integration.create_talent_knowledge_base(talent_gmail)
kb_action = "创建" if created else "获取"
logger.info(f"知识库{kb_action}成功: {knowledge_base.id}")
# 获取邮件对话
logger.info(f"获取与Gmail: {talent_gmail} 的邮件对话")
conversations = gmail_integration.get_conversations(talent_gmail)
conversation_count = len(conversations)
logger.info(f"获取到 {conversation_count} 个邮件对话")
if not conversations:
logger.warning("没有找到Gmail邮件对话")
# 生成随机会话ID
conversation_id = f"conv_{uuid.uuid4().hex[:10]}"
return Response({
'code': 200,
'message': f"没有找到与 {talent_gmail} 的邮件对话,请确保您有与该地址的邮件往来",
'data': {
"knowledge_base_id": str(knowledge_base.id),
"conversation_id": conversation_id,
"troubleshooting": {
"check_emails": f"请确认Gmail中是否有与 {talent_gmail} 的往来邮件",
"verify_address": "请确认邮箱地址拼写正确",
"check_permissions": "请确保已授权完整的邮箱访问权限"
}
}
}, status=status.HTTP_200_OK)
# 保存对话到知识库
logger.info(f"{conversation_count} 个邮件对话保存到知识库")
result = gmail_integration.save_conversations_to_knowledge_base(conversations, knowledge_base)
conversation_id = result.get('conversation_id')
logger.info(f"对话已保存到知识库ID: {conversation_id}")
return Response({
'code': 200,
'message': f"Gmail集成成功。已加载与 {talent_gmail}{conversation_count} 个邮件对话到知识库。",
'data': {
"knowledge_base_id": str(knowledge_base.id),
"conversation_id": conversation_id
}
}, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"设置Gmail集成失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f"设置Gmail集成失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def send_gmail_message(request):
"""通过Gmail发送消息给达人"""
try:
data = request.data
conversation_id = data.get('conversation_id')
to_email = data.get('to_email')
subject = data.get('subject')
body = data.get('body')
# 验证必填字段
if not all([conversation_id, to_email, subject, body]):
return Response({
'code': 400,
'message': '缺少必填字段: conversation_id, to_email, subject, body',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 查找对话记录,确认用户有权限 - 确保不使用timezone-aware字段进行过滤
chat_records = ChatHistory.objects.filter(
conversation_id=conversation_id,
user=request.user,
is_deleted=False
).first()
if not chat_records:
return Response({
'code': 404,
'message': '聊天记录不存在或无权访问',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 处理附件上传
attachment_files = []
if 'attachments' in request.FILES:
uploaded_files = request.FILES.getlist('attachments')
# 创建临时目录保存上传的文件
temp_dir = os.path.join(settings.MEDIA_ROOT, 'temp_attachments')
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
# 保存上传的附件
for file in uploaded_files:
filepath = os.path.join(temp_dir, file.name)
with open(filepath, 'wb+') as destination:
for chunk in file.chunks():
destination.write(chunk)
attachment_files.append(filepath)
# 获取可选的代理设置
use_proxy = data.get('use_proxy', True)
proxy_url = data.get('proxy_url', 'http://127.0.0.1:7890')
# 创建Gmail集成实例并认证
gmail_integration = GmailIntegration(
user=request.user,
use_proxy=use_proxy,
proxy_url=proxy_url
)
if not gmail_integration.authenticate():
return Response({
'code': 400,
'message': 'Gmail认证失败请先完成认证',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 发送邮件
message_id = gmail_integration.send_email(
to_email=to_email,
subject=subject,
body=body,
conversation_id=conversation_id,
attachments=attachment_files
)
# 清理临时文件
for file in attachment_files:
if os.path.exists(file):
os.remove(file)
return Response({
'code': 200,
'message': '邮件发送成功',
'data': {
'message_id': message_id,
'conversation_id': conversation_id
}
})
except Exception as e:
logger.error(f"发送Gmail消息失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'发送Gmail消息失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([])
def gmail_webhook(request):
"""Gmail推送通知webhook"""
try:
# 添加更详细的日志
logger.info(f"接收到Gmail webhook请求: 路径={request.path}, 方法={request.method}")
logger.info(f"请求头: {dict(request.headers)}")
logger.info(f"请求数据: {request.data}")
# 验证请求来源(可以添加额外的安全校验)
data = request.data
if not data:
return Response({
'code': 400,
'message': '无效的请求数据',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 处理数据
email_address = None
history_id = None
# 处理Google Pub/Sub消息格式
if isinstance(data, dict) and 'message' in data and 'data' in data['message']:
try:
import base64
import json
logger.info("检测到Google Pub/Sub消息格式")
# Base64解码data字段
encoded_data = data['message']['data']
decoded_data = base64.b64decode(encoded_data).decode('utf-8')
logger.info(f"解码后的数据: {decoded_data}")
# 解析JSON获取email和historyId
json_data = json.loads(decoded_data)
email_address = json_data.get('emailAddress')
history_id = json_data.get('historyId')
logger.info(f"从Pub/Sub消息中提取: email={email_address}, historyId={history_id}")
except Exception as decode_error:
logger.error(f"解析Pub/Sub消息失败: {str(decode_error)}")
logger.error(traceback.format_exc())
# 处理其他格式的数据
elif isinstance(data, dict):
# 直接使用JSON格式数据
logger.info("接收到JSON格式数据")
email_address = data.get('emailAddress')
history_id = data.get('historyId')
elif hasattr(data, 'decode'):
# 尝试解析原始数据
logger.info("接收到原始数据格式,尝试解析")
try:
import json
json_data = json.loads(data.decode('utf-8'))
email_address = json_data.get('emailAddress')
history_id = json_data.get('historyId')
except Exception as parse_error:
logger.error(f"解析请求数据失败: {str(parse_error)}")
email_address = None
history_id = None
else:
# 尝试从请求参数获取
logger.info("尝试从请求参数获取数据")
email_address = request.GET.get('emailAddress') or request.POST.get('emailAddress')
history_id = request.GET.get('historyId') or request.POST.get('historyId')
logger.info(f"提取的邮箱: {email_address}, 历史ID: {history_id}")
if not email_address or not history_id:
return Response({
'code': 400,
'message': '缺少必要的参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 查找用户和认证信息
user = User.objects.filter(email=email_address).first()
if not user:
logger.info(f"没有找到email={email_address}的用户尝试使用gmail_email查找")
# 尝试使用gmail_email字段查找
credential = GmailCredential.objects.filter(gmail_email=email_address, is_active=True).first()
if credential:
user = credential.user
logger.info(f"通过gmail_email找到用户: {user.email}")
else:
logger.error(f"无法找到与{email_address}关联的用户")
return Response({
'code': 404,
'message': f'找不到与 {email_address} 关联的用户',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 查找认证信息
credential = GmailCredential.objects.filter(user=user, is_active=True).first()
if not credential:
return Response({
'code': 404,
'message': f'找不到用户 {email_address} 的Gmail认证信息',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 更新history_id
credential.last_history_id = history_id
credential.save()
# 如果请求中包含达人邮箱,直接处理特定达人的邮件
talent_email = data.get('talent_email') or request.GET.get('talent_email')
if talent_email and user:
logger.info(f"检测到特定达人邮箱: {talent_email},将直接处理其最近邮件")
try:
# 创建Gmail集成实例
integration = GmailIntegration(user=user, email=talent_email)
if integration.authenticate():
# 获取达人最近的邮件
recent_emails = integration.get_recent_emails(
from_email=talent_email,
max_results=5 # 限制获取最近5封
)
if recent_emails:
logger.info(f"找到 {len(recent_emails)} 封来自 {talent_email} 的最近邮件")
# 创建或获取知识库
knowledge_base, created = integration.create_talent_knowledge_base(talent_email)
# 保存对话
result = integration.save_conversations_to_knowledge_base(recent_emails, knowledge_base)
logger.info(f"已处理达人 {talent_email} 的最近邮件: {result}")
else:
logger.info(f"没有找到来自 {talent_email} 的最近邮件")
else:
logger.error("Gmail认证失败")
except Exception as talent_error:
logger.error(f"处理达人邮件失败: {str(talent_error)}")
logger.error(traceback.format_exc())
# 异步处理通知
# 在生产环境中应该使用Celery等异步任务队列处理
try:
integration = GmailIntegration(user=user)
if integration.authenticate():
result = integration.process_notification(data)
# 如果处理成功尝试通过WebSocket发送通知
# 注意: WebSocket通知已集成到_process_new_message方法中这里只是额外的备份通知
if result:
try:
from channels.layers import get_channel_layer
from asgiref.sync import async_to_sync
# 获取Channel Layer
channel_layer = get_channel_layer()
if channel_layer:
# 发送WebSocket消息
async_to_sync(channel_layer.group_send)(
f"notification_user_{user.id}",
{
"type": "notification",
"data": {
"message_type": "gmail_update",
"message": "您的Gmail有新消息已自动处理",
"history_id": history_id,
"timestamp": timezone.now().isoformat()
}
}
)
except Exception as ws_error:
logger.error(f"发送WebSocket通知失败: {str(ws_error)}")
logger.error(traceback.format_exc())
else:
logger.error(f"Gmail认证失败: {email_address}")
except Exception as process_error:
logger.error(f"处理Gmail通知失败: {str(process_error)}")
logger.error(traceback.format_exc())
return Response({
'code': 200,
'message': '通知已处理',
'data': {
'user_id': str(user.id),
'history_id': history_id
}
})
except Exception as e:
logger.error(f"处理Gmail webhook失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'处理通知失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def get_gmail_attachments(request):
"""获取Gmail邮件附件"""
try:
conversation_id = request.query_params.get('conversation_id')
if not conversation_id:
return Response({
'code': 400,
'message': '缺少conversation_id参数',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取可选的代理设置
use_proxy = request.query_params.get('use_proxy', 'true').lower() == 'true'
proxy_url = request.query_params.get('proxy_url', 'http://127.0.0.1:7890')
# 创建Gmail集成实例
gmail_integration = GmailIntegration(
user=request.user,
use_proxy=use_proxy,
proxy_url=proxy_url
)
# 获取附件列表
attachments = gmail_integration.get_attachment_by_conversation(conversation_id)
return Response({
'code': 200,
'message': '获取附件列表成功',
'data': {
'attachments': attachments,
'count': len(attachments)
}
})
except Exception as e:
logger.error(f"获取Gmail附件列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取Gmail附件列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def download_gmail_attachment(request):
"""下载Gmail附件"""
try:
filepath = request.query_params.get('filepath')
if not filepath:
return Response({
'code': 400,
'message': '缺少必填参数: filepath',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证文件存在
if not os.path.exists(filepath):
return Response({
'code': 404,
'message': '附件文件不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 验证文件是否在允许的目录中
attachments_dir = os.path.join(settings.MEDIA_ROOT, 'gmail_attachments')
if not filepath.startswith(attachments_dir):
return Response({
'code': 403,
'message': '无权访问该文件',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 返回文件
filename = os.path.basename(filepath)
response = FileResponse(open(filepath, 'rb'))
response['Content-Disposition'] = f'attachment; filename="{filename}"'
return response
except Exception as e:
logger.error(f"下载Gmail附件失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'下载Gmail附件失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def get_gmail_talents(request):
"""获取用户的所有达人Gmail映射"""
try:
# 获取用户的所有达人映射
from .models import GmailTalentMapping
mappings = GmailTalentMapping.objects.filter(
user=request.user,
is_active=True
).select_related('knowledge_base')
# 格式化返回数据
result = []
for mapping in mappings:
# 查询最新的对话记录
latest_message = ChatHistory.objects.filter(
conversation_id=mapping.conversation_id,
is_deleted=False
).order_by('-created_at').first()
# 查询附件数量
attachment_count = GmailAttachment.objects.filter(
chat_message__conversation_id=mapping.conversation_id
).count()
result.append({
'id': str(mapping.id),
'talent_email': mapping.talent_email,
'knowledge_base_id': str(mapping.knowledge_base.id),
'knowledge_base_name': mapping.knowledge_base.name,
'conversation_id': mapping.conversation_id,
'created_at': mapping.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'updated_at': mapping.updated_at.strftime('%Y-%m-%d %H:%M:%S'),
'latest_message': {
'content': latest_message.content if latest_message else '',
'time': latest_message.created_at.strftime('%Y-%m-%d %H:%M:%S') if latest_message else '',
'role': latest_message.role if latest_message else ''
},
'attachment_count': attachment_count
})
return Response({
'code': 200,
'message': '获取成功',
'data': {
'talents': result,
'count': len(result)
}
})
except Exception as e:
logger.error(f"获取Gmail达人映射失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取Gmail达人映射失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def refresh_gmail_watch(request):
"""刷新Gmail监听"""
try:
# 获取用户的Gmail认证信息
from .models import GmailCredential
credential = GmailCredential.objects.filter(user=request.user, is_active=True).first()
if not credential:
return Response({
'code': 404,
'message': '找不到Gmail认证信息请先设置Gmail集成',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 记录用户信息和Gmail邮箱信息
gmail_email = credential.gmail_email or "未知Gmail账号"
logger.info(f"刷新Gmail监听: 系统用户={request.user.email}, Gmail账号={gmail_email}")
logger.info(f"请确保Gmail账号 {gmail_email} 和 gmail-api-push@system.gserviceaccount.com 都有 Pub/Sub 发布权限")
# 检查监听是否过期
is_expired = False
if credential.watch_expiration:
# 使用timezone.now()获取带时区信息的当前时间
is_expired = credential.watch_expiration < timezone.now()
# 获取可选的代理设置
use_proxy = request.data.get('use_proxy', True)
proxy_url = request.data.get('proxy_url', 'http://127.0.0.1:7890')
# 创建Gmail集成实例
gmail_integration = GmailIntegration(
user=request.user,
use_proxy=use_proxy,
proxy_url=proxy_url
)
# 认证Gmail
if not gmail_integration.authenticate():
return Response({
'code': 400,
'message': 'Gmail认证失败',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
try:
# 重新设置监听
watch_result = gmail_integration.setup_watch()
return Response({
'code': 200,
'message': 'Gmail监听已刷新',
'data': {
'history_id': watch_result['historyId'],
'expiration': watch_result['expiration'],
'was_expired': is_expired,
'gmail_email': gmail_email
}
})
except Exception as watch_error:
logger.error(f"设置监听失败: {str(watch_error)}")
logger.error(traceback.format_exc())
# 返回项目未配置错误
if "settings" in str(watch_error).lower() or "google_cloud_project" in str(watch_error).lower():
return Response({
'code': 500,
'message': 'Gmail监听设置失败: 服务器配置错误请联系管理员配置Google Cloud Project',
'data': {
'error': str(watch_error),
'hint': '请在settings.py中设置GOOGLE_CLOUD_PROJECT'
}
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 返回一般错误
return Response({
'code': 500,
'message': f'Gmail监听设置失败: {str(watch_error)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"刷新Gmail监听失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'刷新Gmail监听失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def check_gmail_auth(request):
"""检查Gmail认证状态"""
try:
# 获取用户的Gmail认证信息
from .models import GmailCredential
credential = GmailCredential.objects.filter(user=request.user, is_active=True).first()
if not credential:
return Response({
'code': 404,
'message': '未找到Gmail认证信息',
'data': {
'authenticated': False,
'needs_setup': True
}
})
# 获取可选的代理设置
use_proxy = request.query_params.get('use_proxy', 'true').lower() == 'true'
proxy_url = request.query_params.get('proxy_url', 'http://127.0.0.1:7890')
# 创建Gmail集成实例
gmail_integration = GmailIntegration(
user=request.user,
use_proxy=use_proxy,
proxy_url=proxy_url
)
# 测试认证
auth_valid = gmail_integration.authenticate()
# 检查监听是否过期
watch_expired = True
if credential.watch_expiration:
watch_expired = credential.watch_expiration < timezone.now()
return Response({
'code': 200,
'message': '认证信息获取成功',
'data': {
'authenticated': auth_valid,
'needs_setup': not auth_valid,
'watch_expired': watch_expired,
'last_history_id': credential.last_history_id,
'watch_expiration': credential.watch_expiration.strftime('%Y-%m-%d %H:%M:%S') if credential.watch_expiration else None
}
})
except Exception as e:
logger.error(f"检查Gmail认证状态失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'检查Gmail认证状态失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAdminUser]) # 只允许管理员访问
def refresh_all_gmail_watches(request):
"""刷新所有即将过期的Gmail监听"""
# 查找2天内将过期的凭证(提前预防)
expire_time = timezone.now() + timedelta(days=2)
credentials = GmailCredential.objects.filter(
is_active=True,
watch_expiration__lt=expire_time
).select_related('user')
results = []
for credential in credentials:
try:
user = credential.user
gmail_integration = GmailIntegration(user=user)
if gmail_integration.authenticate():
watch_result = gmail_integration.setup_watch()
results.append({
'user': user.email,
'success': True,
'expiration': watch_result['expiration']
})
else:
results.append({
'user': user.email,
'success': False,
'error': '认证失败'
})
except Exception as e:
results.append({
'user': credential.user.email if credential.user else 'Unknown',
'success': False,
'error': str(e)
})
return Response({
'code': 200,
'message': f'已刷新 {len(results)} 个Gmail监听',
'data': results
})
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def import_gmail_from_sender(request):
"""手动导入特定发件人的Gmail邮件"""
try:
# 获取请求参数
sender_email = request.data.get('sender_email')
max_results = int(request.data.get('max_results', 10))
if not sender_email:
return Response({
'code': 400,
'message': '缺少必填参数: sender_email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取可选的代理设置
use_proxy = request.data.get('use_proxy', True)
proxy_url = request.data.get('proxy_url', 'http://127.0.0.1:7890')
# 创建Gmail集成实例
gmail_integration = GmailIntegration(
user=request.user,
email=sender_email, # 设置达人邮箱
use_proxy=use_proxy,
proxy_url=proxy_url
)
# 认证Gmail
if not gmail_integration.authenticate():
return Response({
'code': 400,
'message': 'Gmail认证失败',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 构建查询参数
query = f"from:{sender_email}"
# 查询邮件
response = gmail_integration.gmail_service.users().messages().list(
userId='me',
q=query,
maxResults=max_results
).execute()
if 'messages' not in response:
return Response({
'code': 404,
'message': f'未找到来自 {sender_email} 的邮件',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 处理邮件
message_ids = [message['id'] for message in response['messages']]
logger.info(f"找到 {len(message_ids)} 封来自 {sender_email} 的邮件")
# 获取邮件详情
conversations = []
for message_id in message_ids:
message = gmail_integration.gmail_service.users().messages().get(
userId='me', id=message_id
).execute()
# 提取邮件内容
email_data = gmail_integration._extract_email_content(message)
if email_data:
conversations.append(email_data)
if not conversations:
return Response({
'code': 404,
'message': f'无法提取来自 {sender_email} 的邮件内容',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 创建知识库
knowledge_base, created = gmail_integration.create_talent_knowledge_base(sender_email)
kb_action = "创建" if created else "获取"
logger.info(f"知识库{kb_action}成功: {knowledge_base.id}")
# 保存对话到知识库
result = gmail_integration.save_conversations_to_knowledge_base(conversations, knowledge_base)
# 返回结果
return Response({
'code': 200,
'message': f'已成功导入 {len(conversations)} 封来自 {sender_email} 的邮件',
'data': {
'conversation_id': result.get('conversation_id'),
'knowledge_base_id': str(knowledge_base.id),
'message_count': len(conversations)
}
})
except Exception as e:
logger.error(f"导入Gmail邮件失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'导入Gmail邮件失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def sync_talent_emails(request):
"""手动同步特定达人的Gmail邮件"""
try:
# 获取请求参数
talent_email = request.data.get('talent_email')
max_results = int(request.data.get('max_results', 10))
if not talent_email:
return Response({
'code': 400,
'message': '缺少必填参数: talent_email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取代理设置
use_proxy = request.data.get('use_proxy', True)
proxy_url = request.data.get('proxy_url', 'http://127.0.0.1:7890')
# 初始化Gmail集成
gmail_integration = GmailIntegration(
user=request.user,
email=talent_email,
use_proxy=use_proxy,
proxy_url=proxy_url
)
# 认证
if not gmail_integration.authenticate():
return Response({
'code': 400,
'message': 'Gmail认证失败',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取最近邮件
recent_emails = gmail_integration.get_recent_emails(
from_email=talent_email,
max_results=max_results
)
if not recent_emails:
return Response({
'code': 404,
'message': f'未找到来自 {talent_email} 的最近邮件',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 创建或获取知识库
knowledge_base, created = gmail_integration.create_talent_knowledge_base(talent_email)
kb_action = "创建" if created else "获取"
logger.info(f"知识库{kb_action}成功: {knowledge_base.name}")
# 保存到知识库
result = gmail_integration.save_conversations_to_knowledge_base(recent_emails, knowledge_base)
# 响应
return Response({
'code': 200,
'message': f'已成功同步 {len(recent_emails)} 封来自 {talent_email} 的邮件',
'data': {
'conversation_id': result.get('conversation_id'),
'knowledge_base_id': str(knowledge_base.id),
'message_count': len(recent_emails)
}
})
except Exception as e:
logger.error(f"同步达人邮件失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'同步达人邮件失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def manage_user_goal(request):
"""
管理用户总目标的API视图
POST请求参数:
- content: 用户总目标内容(可选,不提供则返回当前总目标)
返回:
- 包含操作结果和总目标信息的JSON响应格式为 {data, code, message}
"""
user = request.user
goal_content = request.data.get('content')
try:
from .gmail_integration import GmailIntegration
# 创建Gmail集成实例
gmail_integration = GmailIntegration(user)
result = gmail_integration.manage_user_goal(goal_content)
# 转换为新的响应格式
response_data = {
'code': 200,
'message': 'success',
'data': result
}
return Response(response_data, status=status.HTTP_200_OK)
except Exception as e:
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def generate_conversation_summary(request):
"""
生成与达人的对话总结的API视图
POST请求参数:
- talent_email: 达人的邮箱地址
返回:
- 包含操作结果和总结信息的JSON响应格式为 {data, code, message}
"""
user = request.user
talent_email = request.data.get('talent_email')
if not talent_email:
return Response({
'code': 400,
'message': '缺少必要参数: talent_email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
try:
from .gmail_integration import GmailIntegration
# 创建Gmail集成实例
gmail_integration = GmailIntegration(user)
result = gmail_integration.generate_conversation_summary(talent_email)
# 转换为新的响应格式
response_data = {
'code': 200,
'message': 'success',
'data': result
}
return Response(response_data, status=status.HTTP_200_OK)
except Exception as e:
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def get_recommended_reply(request):
"""
获取针对达人最后一条消息的推荐回复的API视图
POST请求参数:
- conversation_id: 对话ID
- talent_email: 达人的邮箱地址
返回:
- 包含推荐回复的JSON响应格式为 {data, code, message}
"""
user = request.user
conversation_id = request.data.get('conversation_id')
talent_email = request.data.get('talent_email')
if not conversation_id or not talent_email:
return Response({
'code': 400,
'message': '缺少必要参数: conversation_id 或 talent_email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
try:
from .gmail_integration import GmailIntegration
from .models import ChatHistory
# 获取对话历史
chat_history = ChatHistory.objects.filter(
conversation_id=conversation_id,
user=user,
is_deleted=False
).order_by('created_at')
if not chat_history:
return Response({
'code': 404,
'message': '找不到对话历史',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 转换为DeepSeek API所需的格式
conversation_history = []
for message in chat_history:
role = 'user' if message.role == 'user' else 'assistant'
message_data = {
'role': role,
'content': message.content,
'metadata': {
'conversation_id': conversation_id
}
}
# 确定消息是否来自达人
if role == 'user':
message_data['metadata']['from_email'] = talent_email
conversation_history.append(message_data)
# 创建Gmail集成实例并获取推荐回复
gmail_integration = GmailIntegration(user)
reply = gmail_integration._get_recommended_reply_from_deepseek(conversation_history)
if not reply:
return Response({
'code': 500,
'message': '生成推荐回复失败',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response({
'code': 200,
'message': 'success',
'data': {
'reply': reply
}
}, status=status.HTTP_200_OK)
except Exception as e:
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def feishu_sync_api(request):
"""飞书数据同步API
支持自定义凭据和批量处理
"""
# 检查权限
if not (request.user.has_perm('feishu.can_sync_feishu') or request.user.role in ['admin', 'leader']):
return Response({
'code': 403,
'message': '您没有同步飞书数据的权限',
'data': None
}, status=403)
# 获取参数
app_token = request.data.get('app_token')
table_id = request.data.get('table_id')
user_access_token = request.data.get('user_access_token')
sync_all = request.data.get('sync_all', False)
sync_to_kb = request.data.get('sync_to_kb', False)
creator_ids = request.data.get('creator_ids', [])
# 验证参数
if not sync_all and not creator_ids:
return Response({
'code': 400,
'message': '请指定sync_all=true或提供creator_ids列表',
'data': None
}, status=400)
try:
from feishu.feishu import sync_from_feishu, sync_to_knowledge_base
from user_management.models import FeishuCreator
# 根据参数同步数据
if sync_all:
# 同步所有数据
result = sync_from_feishu(
app_token=app_token,
table_id=table_id,
user_access_token=user_access_token
)
if 'error_message' in result:
return Response({
'code': 500,
'message': f'同步失败: {result["error_message"]}',
'data': result
}, status=500)
# 处理同步到知识库
if sync_to_kb and result.get('created_creators'):
kb_results = []
for creator in result.get('created_creators', []):
if creator.email: # 只处理有邮箱的达人
kb, created = sync_to_knowledge_base(creator_id=creator.id)
if kb:
kb_results.append({
'creator_id': creator.id,
'handle': creator.handle,
'kb_id': kb.id,
'kb_name': kb.name,
'created': created
})
result['kb_sync'] = {
'total': len(kb_results),
'results': kb_results
}
return Response({
'code': 200,
'message': '同步成功',
'data': result
})
else:
# 处理指定的creator_ids
results = []
for creator_id in creator_ids:
try:
creator = FeishuCreator.objects.get(id=creator_id)
# 这里可以添加特定的处理逻辑
# 如果需要同步到知识库
if sync_to_kb:
kb, created = sync_to_knowledge_base(creator_id=creator_id)
results.append({
'creator_id': creator_id,
'handle': creator.handle if creator else None,
'success': True,
'kb_sync': {
'success': kb is not None,
'kb_id': kb.id if kb else None,
'created': created
}
})
else:
results.append({
'creator_id': creator_id,
'handle': creator.handle,
'success': True
})
except FeishuCreator.DoesNotExist:
results.append({
'creator_id': creator_id,
'success': False,
'message': '达人不存在'
})
except Exception as e:
results.append({
'creator_id': creator_id,
'success': False,
'message': str(e)
})
return Response({
'code': 200,
'message': '处理完成',
'data': {
'total': len(results),
'results': results
}
})
except Exception as e:
import traceback
logger.error(f"飞书同步API错误: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'系统错误: {str(e)}',
'data': None
}, status=500)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def feishu_to_kb_api(request):
"""将飞书数据同步到知识库API
支持批量处理
"""
if not (request.user.has_perm('feishu.can_sync_feishu') or request.user.role in ['admin', 'leader']):
return Response({
'code': 403,
'message': '您没有同步飞书数据的权限',
'data': None
}, status=403)
# 获取参数
creator_id = request.data.get('creator_id')
handle = request.data.get('handle')
email = request.data.get('email')
batch_mode = request.data.get('batch_mode', False)
has_email = request.data.get('has_email', False)
no_kb = request.data.get('no_kb', False)
try:
from feishu.feishu import sync_to_knowledge_base
from user_management.models import FeishuCreator
from user_management.models import KnowledgeBase
# 批量模式
if batch_mode:
query = FeishuCreator.objects.all()
# 筛选条件: 只处理有邮箱的达人
if has_email:
query = query.exclude(email__isnull=True).exclude(email='')
# 筛选条件: 只处理没有知识库的达人
if no_kb:
# 这里需要根据实际情况实现筛选逻辑
# 可能需要一个辅助函数来检查哪些达人没有知识库
# 这里是一个简化示例
creators_with_kb = []
from user_management.models import GmailTalentMapping
kb_mappings = GmailTalentMapping.objects.filter(is_active=True)
for mapping in kb_mappings:
email = mapping.talent_email
if email:
creators_with_kb.append(email)
if creators_with_kb:
query = query.exclude(email__in=creators_with_kb)
creators = query.all()
# 处理结果
results = []
success_count = 0
error_count = 0
for creator in creators:
try:
kb, created = sync_to_knowledge_base(creator_id=creator.id)
if kb:
results.append({
'creator_id': creator.id,
'handle': creator.handle,
'email': creator.email,
'success': True,
'kb_id': kb.id,
'kb_name': kb.name,
'created': created
})
success_count += 1
else:
results.append({
'creator_id': creator.id,
'handle': creator.handle,
'email': creator.email,
'success': False,
'message': '知识库创建失败'
})
error_count += 1
except Exception as e:
results.append({
'creator_id': creator.id,
'handle': creator.handle if hasattr(creator, 'handle') else None,
'email': creator.email if hasattr(creator, 'email') else None,
'success': False,
'message': str(e)
})
error_count += 1
return Response({
'code': 200,
'message': '批量处理完成',
'data': {
'total': len(results),
'success': success_count,
'error': error_count,
'results': results
}
})
# 单个处理模式
elif creator_id or handle or email:
kb, created = sync_to_knowledge_base(
creator_id=creator_id,
handle=handle,
email=email
)
if kb:
return Response({
'code': 200,
'message': '同步成功',
'data': {
'kb_id': kb.id,
'kb_name': kb.name,
'created': created
}
})
else:
return Response({
'code': 400,
'message': '同步失败,未能创建或找到知识库',
'data': None
}, status=400)
else:
return Response({
'code': 400,
'message': '请提供creator_id、handle或email参数',
'data': None
}, status=400)
except Exception as e:
import traceback
logger.error(f"飞书数据同步到知识库API错误: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'系统错误: {str(e)}',
'data': None
}, status=500)
@api_view(['GET', 'POST'])
@permission_classes([IsAuthenticated])
def check_creator_kb_api(request):
"""检查达人是否有知识库API
GET方法: 检查单个达人
POST方法: 批量检查多个达人
"""
if not request.user.has_perm('feishu.can_view_feishu'):
return Response({
'code': 403,
'message': '您没有查看飞书数据的权限',
'data': None
}, status=403)
try:
from user_management.models import FeishuCreator
from user_management.models import GmailTalentMapping
from user_management.models import KnowledgeBase
# GET方法: 检查单个达人
if request.method == 'GET':
creator_id = request.query_params.get('creator_id')
handle = request.query_params.get('handle')
email = request.query_params.get('email')
if not any([creator_id, handle, email]):
return Response({
'code': 400,
'message': '请提供creator_id、handle或email参数',
'data': None
}, status=400)
# 查找达人
query = FeishuCreator.objects.all()
if creator_id:
query = query.filter(id=creator_id)
elif handle:
query = query.filter(handle=handle)
elif email:
query = query.filter(email=email)
creator = query.first()
if not creator:
return Response({
'code': 404,
'message': '未找到达人',
'data': None
}, status=404)
# 检查知识库
kb_info = None
# 通过Email检查Gmail映射
if creator.email:
gmail_mapping = GmailTalentMapping.objects.filter(
talent_email=creator.email,
is_active=True
).first()
if gmail_mapping and gmail_mapping.knowledge_base:
kb = gmail_mapping.knowledge_base
kb_info = {
'kb_id': kb.id,
'kb_name': kb.name,
'mapping_type': 'gmail',
'mapping_id': gmail_mapping.id
}
# 也可以检查其他映射方式,如直接通过名称匹配
if not kb_info and creator.handle:
kb_name = f"达人-{creator.handle}"
kb = KnowledgeBase.objects.filter(name=kb_name, is_active=True).first()
if kb:
kb_info = {
'kb_id': kb.id,
'kb_name': kb.name,
'mapping_type': 'direct',
'mapping_id': None
}
return Response({
'code': 200,
'message': '查询成功',
'data': {
'creator_id': creator.id,
'handle': creator.handle,
'email': creator.email,
'has_kb': kb_info is not None,
'kb_info': kb_info
}
})
# POST方法: 批量检查
else:
creator_ids = request.data.get('creator_ids', [])
if not creator_ids:
return Response({
'code': 400,
'message': '请提供creator_ids参数',
'data': None
}, status=400)
results = []
# 预先加载所有Gmail映射
gmail_mappings = {}
for mapping in GmailTalentMapping.objects.filter(is_active=True):
if mapping.talent_email:
gmail_mappings[mapping.talent_email] = mapping
# 处理每个达人
for creator_id in creator_ids:
try:
creator = FeishuCreator.objects.get(id=creator_id)
# 检查知识库
kb_info = None
# 通过Email检查Gmail映射
if creator.email and creator.email in gmail_mappings:
mapping = gmail_mappings[creator.email]
if mapping.knowledge_base:
kb = mapping.knowledge_base
kb_info = {
'kb_id': kb.id,
'kb_name': kb.name,
'mapping_type': 'gmail',
'mapping_id': mapping.id
}
# 也可以检查其他映射方式
if not kb_info and creator.handle:
kb_name = f"达人-{creator.handle}"
kb = KnowledgeBase.objects.filter(name=kb_name, is_active=True).first()
if kb:
kb_info = {
'kb_id': kb.id,
'kb_name': kb.name,
'mapping_type': 'direct',
'mapping_id': None
}
results.append({
'creator_id': creator.id,
'handle': creator.handle,
'email': creator.email,
'has_kb': kb_info is not None,
'kb_info': kb_info
})
except FeishuCreator.DoesNotExist:
results.append({
'creator_id': creator_id,
'success': False,
'message': '达人不存在'
})
except Exception as e:
results.append({
'creator_id': creator_id,
'success': False,
'message': str(e)
})
return Response({
'code': 200,
'message': '批量查询成功',
'data': {
'total': len(results),
'results': results
}
})
except Exception as e:
import traceback
logger.error(f"检查达人知识库API错误: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'系统错误: {str(e)}',
'data': None
}, status=500)
from feishu.feishu_ai_chat import (
fetch_table_records, find_duplicate_email_creators,
process_duplicate_emails, auto_chat_session,
handle_set_goal, handle_check_goal
)
# 添加飞书多维表格AI对话接口
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def process_feishu_table(request):
"""
从飞书多维表格读取数据,处理重复邮箱
请求参数:
table_id: 表格ID
view_id: 视图ID
app_token: 飞书应用TOKEN (可选)
access_token: 用户访问令牌 (可选)
goal_template: 目标内容模板 (可选)
auto_chat: 是否自动执行AI对话 (可选)
turns: 自动对话轮次 (可选)
"""
try:
# 检查用户权限 - 只允许组长使用
if request.user.role != 'leader':
return Response({
'code': 403,
'message': '只有组长角色的用户可以使用此功能',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 获取参数
table_id = request.data.get("table_id")
view_id = request.data.get("view_id")
app_token = request.data.get("app_token", "XYE6bMQUOaZ5y5svj4vcWohGnmg")
access_token = request.data.get("access_token", "u-ecM5BmzKx4uHz3sG0FouQSk1l9kxgl_3Xa00l5Ma24Jy")
goal_template = request.data.get(
"goal_template",
"与达人{handle}(邮箱:{email})建立联系并了解其账号情况,评估合作潜力,处理合作需求,最终目标是达成合作并签约。"
)
auto_chat = request.data.get("auto_chat", False)
turns = request.data.get("turns", 5)
# 验证必要参数
if not table_id:
return Response({
'code': 400,
'message': '缺少参数table_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if not view_id:
return Response({
'code': 400,
'message': '缺少参数view_id',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 从飞书表格获取记录
records = fetch_table_records(
app_token,
table_id,
view_id,
access_token
)
if not records:
return Response({
'code': 404,
'message': '未获取到任何记录',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 查找重复邮箱的创作者
duplicate_emails = find_duplicate_email_creators(records)
if not duplicate_emails:
return Response({
'code': 200,
'message': '未发现重复邮箱',
'data': None
}, status=status.HTTP_200_OK)
# 处理重复邮箱记录
results = process_duplicate_emails(duplicate_emails, goal_template)
# 如果需要自动对话
chat_results = []
if auto_chat and results['success'] > 0:
# 为每个成功创建的记录执行自动对话
for detail in results['details']:
if detail['status'] == 'success':
email = detail['email']
chat_result = auto_chat_session(request.user, email, max_turns=turns)
chat_results.append({
'email': email,
'result': chat_result
})
# 返回处理结果
return Response({
'code': 200,
'message': 'success',
'data': {
'records_count': len(records),
'duplicate_emails_count': len(duplicate_emails),
'processing_results': results,
'chat_results': chat_results
}
}, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"处理飞书表格时出错: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['POST'])
@permission_classes([IsAuthenticated])
def run_auto_chat(request):
"""
为指定邮箱执行自动对话
请求参数:
email: 达人邮箱
turns: 对话轮次 (可选)
"""
try:
# 检查用户权限 - 只允许组长使用
if request.user.role != 'leader':
return Response({
'code': 403,
'message': '只有组长角色的用户可以使用此功能',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 获取参数
email = request.data.get("email")
turns = request.data.get("turns", 5)
# 验证必要参数
if not email:
return Response({
'code': 400,
'message': '缺少参数email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 执行自动对话
result = auto_chat_session(request.user, email, max_turns=turns)
# 返回结果
return Response({
'code': 200,
'message': 'success',
'data': result
}, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"执行自动对话时出错: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET', 'POST'])
@permission_classes([IsAuthenticated])
def feishu_user_goal(request):
"""
设置或获取用户总目标
GET 请求:
获取当前用户总目标
POST 请求参数:
email: 达人邮箱
goal: 目标内容
"""
try:
# 检查用户权限 - 只允许组长使用
if request.user.role != 'leader':
return Response({
'code': 403,
'message': '只有组长角色的用户可以使用此功能',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
if request.method == 'GET':
# 创建Gmail集成实例
gmail_integration = GmailIntegration(request.user)
# 获取总目标
result = gmail_integration.manage_user_goal()
return Response({
'code': 200,
'message': 'success',
'data': result
}, status=status.HTTP_200_OK)
elif request.method == 'POST':
# 获取参数
email = request.data.get("email")
goal = request.data.get("goal")
# 验证必要参数
if not email:
return Response({
'code': 400,
'message': '缺少参数email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
if not goal:
return Response({
'code': 400,
'message': '缺少参数goal',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 设置用户总目标
gmail_integration = GmailIntegration(request.user)
result = gmail_integration.manage_user_goal(goal)
return Response({
'code': 200,
'message': 'success',
'data': result
}, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"管理用户总目标时出错: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def check_goal_status(request):
"""
检查目标完成状态
请求参数:
email: 达人邮箱
"""
try:
# 检查用户权限 - 只允许组长使用
if request.user.role != 'leader':
return Response({
'code': 403,
'message': '只有组长角色的用户可以使用此功能',
'data': None
}, status=status.HTTP_403_FORBIDDEN)
# 获取参数
email = request.query_params.get("email")
# 验证必要参数
if not email:
return Response({
'code': 400,
'message': '缺少参数email',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 查找Gmail映射关系
mapping = GmailTalentMapping.objects.filter(
user=request.user,
talent_email=email,
is_active=True
).first()
if not mapping:
return Response({
'code': 404,
'message': f'找不到与邮箱 {email} 的映射关系',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 获取对话历史中最后的AI回复
last_ai_message = ChatHistory.objects.filter(
user=request.user,
knowledge_base=mapping.knowledge_base,
conversation_id=mapping.conversation_id,
role='assistant',
is_deleted=False
).order_by('-created_at').first()
if not last_ai_message:
return Response({
'code': 404,
'message': f'找不到与邮箱 {email} 的对话历史',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 导入检查函数
from feishu.feishu_ai_chat import check_goal_achieved
# 检查目标是否已达成
goal_achieved = check_goal_achieved(last_ai_message.content)
# 获取对话总结
summary = ConversationSummary.objects.filter(
user=request.user,
talent_email=email,
is_active=True
).order_by('-updated_at').first()
result = {
'status': 'success',
'email': email,
'goal_achieved': goal_achieved,
'last_message_time': last_ai_message.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'last_message': last_ai_message.content,
'summary': summary.summary if summary else None
}
return Response({
'code': 200,
'message': 'success',
'data': result
}, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"检查目标状态时出错: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': str(e),
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)