5330 lines
221 KiB
Python
5330 lines
221 KiB
Python
from rest_framework import viewsets, status
|
||
from rest_framework.decorators import action, api_view, permission_classes
|
||
from rest_framework.permissions import IsAuthenticated, AllowAny, IsAdminUser
|
||
from rest_framework.response import Response
|
||
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError, NotFound
|
||
from rest_framework.authentication import TokenAuthentication
|
||
from django.utils import timezone
|
||
from django.db import connection
|
||
from django.db.models import Q, Max, Count, F, Sum
|
||
from datetime import timedelta, datetime
|
||
import mysql.connector
|
||
from django.contrib.auth import get_user_model, authenticate, login, logout
|
||
from channels.layers import get_channel_layer
|
||
from asgiref.sync import async_to_sync
|
||
from rest_framework.authtoken.models import Token
|
||
import requests
|
||
import json
|
||
from django.db import transaction
|
||
from django.core.exceptions import ObjectDoesNotExist
|
||
import sys
|
||
import random
|
||
import string
|
||
import time
|
||
import logging
|
||
import os
|
||
from rest_framework.test import APIRequestFactory
|
||
from django.contrib.contenttypes.models import ContentType
|
||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||
from django.http import Http404, HttpResponse, StreamingHttpResponse
|
||
from django.db import IntegrityError
|
||
from channels.exceptions import ChannelFull
|
||
from django.conf import settings
|
||
from django.shortcuts import get_object_or_404
|
||
from django.db import models
|
||
from rest_framework.views import APIView
|
||
from django.core.validators import validate_email
|
||
# from django.core.exceptions import ValidationError
|
||
from django.views.decorators.csrf import csrf_exempt
|
||
from django.utils.decorators import method_decorator
|
||
import uuid
|
||
from rest_framework import serializers
|
||
import traceback
|
||
import requests
|
||
import json
|
||
import threading
|
||
|
||
|
||
|
||
# 添加模型导入
|
||
from .models import (
|
||
User,
|
||
Data, # 替换原来的 AdminData, LeaderData, MemberData
|
||
Permission, # 替换原来的 DataPermission, TablePermission
|
||
ChatHistory,
|
||
KnowledgeBase,
|
||
Notification,
|
||
KnowledgeBasePermission as KBPermissionModel,
|
||
KnowledgeBaseDocument
|
||
)
|
||
from .serializers import (
|
||
UserSerializer,
|
||
DataSerializer, # 需要更新
|
||
PermissionSerializer, # 需要更新
|
||
ChatHistorySerializer,
|
||
KnowledgeBaseSerializer,
|
||
KnowledgePermissionSerializer, # 添加这个导入
|
||
NotificationSerializer
|
||
)
|
||
# 导入自定义权限类
|
||
from .permissions import ResourceCRUDPermission, PermissionRequestPermission, DataPermission, KnowledgeBasePermission as KBPermissionClass
|
||
from .exceptions import ExternalAPIError
|
||
|
||
# 获取正确的用户模型
|
||
User = get_user_model()
|
||
|
||
logger = logging.getLogger(__name__)
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s [%(levelname)s] %(message)s',
|
||
handlers=[
|
||
logging.StreamHandler() # 输出到控制台
|
||
]
|
||
)
|
||
|
||
|
||
class KnowledgeBasePermissionMixin:
|
||
"""知识库权限管理混入类"""
|
||
|
||
def _can_read(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
|
||
"""检查读取权限"""
|
||
try:
|
||
# 1. 检查显式权限表
|
||
if knowledge_base_id:
|
||
permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=knowledge_base_id,
|
||
user=user,
|
||
can_read=True,
|
||
status='active'
|
||
).first()
|
||
if permission:
|
||
return True
|
||
|
||
# 2. 检查角色权限
|
||
# 私有知识库
|
||
if type == 'private':
|
||
return str(user.id) == str(creator_id)
|
||
|
||
# 成员级知识库
|
||
if type == 'member':
|
||
return user.department == department
|
||
|
||
# 部门级知识库
|
||
if type == 'leader':
|
||
return (user.department == department and
|
||
user.role in ['leader', 'admin'])
|
||
|
||
# 管理级知识库
|
||
if type == 'admin':
|
||
return True # 所有用户都可以读取
|
||
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"检查读取权限时出错: {str(e)}")
|
||
return False
|
||
|
||
def _can_edit(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
|
||
"""检查编辑权限"""
|
||
try:
|
||
# 1. 检查显式权限表
|
||
if knowledge_base_id:
|
||
permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=knowledge_base_id,
|
||
user=user,
|
||
can_edit=True,
|
||
status='active'
|
||
).first()
|
||
if permission:
|
||
return True
|
||
|
||
# 2. 检查角色权限
|
||
# 私有知识库
|
||
if type == 'private':
|
||
return str(user.id) == str(creator_id)
|
||
|
||
# 成员级知识库
|
||
if type == 'member':
|
||
return (user.department == department and
|
||
user.role in ['leader', 'admin'])
|
||
|
||
# 部门级知识库
|
||
if type == 'leader':
|
||
return (user.department == department and
|
||
user.role in ['leader', 'admin'])
|
||
|
||
# 管理级知识库
|
||
if type == 'admin':
|
||
return True # 所有用户都可以编辑
|
||
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"检查编辑权限时出错: {str(e)}")
|
||
return False
|
||
|
||
def _can_delete(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None):
|
||
"""检查删除权限"""
|
||
try:
|
||
# 1. 检查显式权限表
|
||
if knowledge_base_id:
|
||
permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=knowledge_base_id,
|
||
user=user,
|
||
can_delete=True,
|
||
status='active'
|
||
).first()
|
||
if permission:
|
||
return True
|
||
|
||
# 2. 检查角色权限
|
||
# 私有知识库
|
||
if type == 'private':
|
||
return str(user.id) == str(creator_id)
|
||
|
||
# 成员级知识库
|
||
if type == 'member':
|
||
return (user.department == department and
|
||
user.role == 'admin')
|
||
|
||
# 部门级知识库
|
||
if type == 'leader':
|
||
return (user.department == department and
|
||
user.role == 'admin')
|
||
|
||
# 管理级知识库
|
||
if type == 'admin':
|
||
return True # 所有用户都可以删除
|
||
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"检查删除权限时出错: {str(e)}")
|
||
return False
|
||
|
||
def check_knowledge_base_permission(self, knowledge_base, user, required_permission='read'):
|
||
"""统一的知识库权限检查方法"""
|
||
if not knowledge_base:
|
||
return False
|
||
|
||
# 1. 首先检查显式权限表
|
||
try:
|
||
# 检查是否存在显式权限记录
|
||
permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=knowledge_base.id,
|
||
user=user,
|
||
status='active'
|
||
).first()
|
||
|
||
if permission:
|
||
# 根据请求的权限类型返回对应的权限值
|
||
if required_permission == 'read':
|
||
return permission.can_read
|
||
elif required_permission == 'edit':
|
||
return permission.can_edit
|
||
elif required_permission == 'delete':
|
||
return permission.can_delete
|
||
except Exception as e:
|
||
logger.error(f"检查显式权限时出错: {str(e)}")
|
||
|
||
# 2. 如果没有显式权限记录或出错,回退到隐式权限逻辑
|
||
permission_method = {
|
||
'read': self._can_read,
|
||
'edit': self._can_edit,
|
||
'delete': self._can_delete
|
||
}.get(required_permission)
|
||
|
||
if not permission_method:
|
||
return False
|
||
|
||
return permission_method(
|
||
type=knowledge_base.type,
|
||
user=user,
|
||
department=knowledge_base.department,
|
||
group=knowledge_base.group,
|
||
creator_id=knowledge_base.user_id,
|
||
knowledge_base_id=knowledge_base.id
|
||
)
|
||
|
||
|
||
|
||
class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
|
||
permission_classes = [IsAuthenticated]
|
||
queryset = ChatHistory.objects.all()
|
||
|
||
def get_queryset(self):
|
||
"""确保用户只能看到自己的未删除的聊天记录以及有权限的知识库关联的聊天记录"""
|
||
user = self.request.user
|
||
|
||
# 当前用户的聊天记录
|
||
user_records = ChatHistory.objects.filter(
|
||
user=user,
|
||
is_deleted=False
|
||
)
|
||
|
||
# 获取用户有权限的知识库ID列表
|
||
accessible_kb_ids = []
|
||
for kb in KnowledgeBase.objects.all():
|
||
if self.check_knowledge_base_permission(kb, user, 'read'):
|
||
accessible_kb_ids.append(kb.id)
|
||
|
||
# 其他用户创建的、但当前用户有权限访问的知识库的聊天记录
|
||
others_records = ChatHistory.objects.filter(
|
||
knowledge_base_id__in=accessible_kb_ids,
|
||
is_deleted=False
|
||
).exclude(user=user) # 排除用户自己的记录,避免重复
|
||
|
||
# 合并两个查询集
|
||
combined_queryset = user_records | others_records
|
||
|
||
return combined_queryset
|
||
|
||
def list(self, request):
|
||
"""获取对话列表"""
|
||
try:
|
||
# 获取用户所有的对话
|
||
unique_conversations = self.get_queryset().values('conversation_id').annotate(
|
||
last_message=Max('created_at'),
|
||
message_count=Count('id')
|
||
).order_by('-last_message')
|
||
|
||
# 构建结果列表
|
||
conversation_list = []
|
||
for conv in unique_conversations:
|
||
# 获取对话中的第一条消息,用于显示标题和知识库信息
|
||
first_message = self.get_queryset().filter(
|
||
conversation_id=conv['conversation_id']
|
||
).order_by('created_at').first()
|
||
|
||
if not first_message:
|
||
continue
|
||
|
||
# 获取知识库信息
|
||
dataset_info = []
|
||
if first_message.metadata and 'dataset_id_list' in first_message.metadata:
|
||
# 获取用户有权限访问的知识库
|
||
valid_kb_ids = []
|
||
for kb_id in first_message.metadata['dataset_id_list']:
|
||
try:
|
||
kb = KnowledgeBase.objects.get(id=kb_id)
|
||
if self.check_knowledge_base_permission(kb, request.user, 'read'):
|
||
valid_kb_ids.append(kb_id)
|
||
dataset_info.append({
|
||
'id': str(kb.id),
|
||
'name': kb.name,
|
||
'type': kb.type
|
||
})
|
||
except KnowledgeBase.DoesNotExist:
|
||
continue
|
||
|
||
# 获取最近的消息用于预览
|
||
last_user_message = self.get_queryset().filter(
|
||
conversation_id=conv['conversation_id'],
|
||
role='user'
|
||
).order_by('-created_at').first()
|
||
|
||
# 处理对话标题 - 优先使用已有标题,否则尝试生成新标题
|
||
title = first_message.title
|
||
|
||
# 如果标题为空或为默认值'New chat',尝试生成新标题
|
||
if not title or title == 'New chat':
|
||
# 找到对话中的第一对问答
|
||
messages = list(self.get_queryset().filter(
|
||
conversation_id=conv['conversation_id']
|
||
).order_by('created_at'))
|
||
|
||
user_message = None
|
||
assistant_message = None
|
||
|
||
for i in range(len(messages)-1):
|
||
if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id):
|
||
user_message = messages[i]
|
||
assistant_message = messages[i+1]
|
||
break
|
||
|
||
if user_message and assistant_message:
|
||
# 调用DeepSeek API生成标题
|
||
generated_title = self._generate_conversation_title_from_deepseek(
|
||
user_message.content,
|
||
assistant_message.content
|
||
)
|
||
|
||
if generated_title:
|
||
# 更新所有相关记录的标题
|
||
title = generated_title
|
||
ChatHistory.objects.filter(
|
||
conversation_id=conv['conversation_id']
|
||
).update(title=generated_title)
|
||
|
||
# 如果生成失败,使用对话ID的一部分作为临时标题
|
||
if not title:
|
||
title = f"对话 {conv['conversation_id'][:8]}"
|
||
|
||
# 构建返回结果
|
||
conversation_data = {
|
||
'conversation_id': conv['conversation_id'],
|
||
'last_message': conv['last_message'].strftime('%Y-%m-%d %H:%M:%S'),
|
||
'message_count': conv['message_count'],
|
||
'title': title,
|
||
'preview': last_user_message.content[:100] if last_user_message else "",
|
||
'datasets': dataset_info
|
||
}
|
||
conversation_list.append(conversation_data)
|
||
|
||
# 返回结果
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取成功',
|
||
'data': conversation_list
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"获取对话列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f"获取对话列表失败: {str(e)}",
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def conversation_detail(self, request):
|
||
"""获取特定对话的详细信息"""
|
||
try:
|
||
conversation_id = request.query_params.get('conversation_id')
|
||
if not conversation_id:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少conversation_id参数',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 获取对话历史,确保按时间顺序排序
|
||
messages = self.get_queryset().filter(
|
||
conversation_id=conversation_id
|
||
).order_by('created_at')
|
||
|
||
if not messages.exists():
|
||
return Response({
|
||
'code': 404,
|
||
'message': '对话不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 获取知识库信息
|
||
first_message = messages.first()
|
||
dataset_info = []
|
||
if first_message and first_message.metadata:
|
||
if 'dataset_id_list' in first_message.metadata:
|
||
datasets = KnowledgeBase.objects.filter(
|
||
id__in=first_message.metadata['dataset_id_list']
|
||
)
|
||
# 过滤出用户有权限访问的知识库
|
||
accessible_datasets = [
|
||
ds for ds in datasets
|
||
if self.check_knowledge_base_permission(ds, request.user, 'read')
|
||
]
|
||
dataset_info = [{
|
||
'id': str(ds.id),
|
||
'name': ds.name,
|
||
'type': ds.type
|
||
} for ds in accessible_datasets]
|
||
|
||
# 处理对话标题 - 优先使用已有标题,否则尝试生成新标题
|
||
title = first_message.title
|
||
|
||
# 如果标题为空或为默认值'New chat',尝试生成新标题
|
||
if not title or title == 'New chat':
|
||
# 尝试找到一对完整的问答
|
||
user_message = None
|
||
assistant_message = None
|
||
|
||
for i in range(len(messages)-1):
|
||
if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id):
|
||
user_message = messages[i]
|
||
assistant_message = messages[i+1]
|
||
break
|
||
|
||
if user_message and assistant_message:
|
||
# 调用DeepSeek API生成标题
|
||
generated_title = self._generate_conversation_title_from_deepseek(
|
||
user_message.content,
|
||
assistant_message.content
|
||
)
|
||
|
||
if generated_title:
|
||
# 更新所有相关记录的标题
|
||
title = generated_title
|
||
ChatHistory.objects.filter(
|
||
conversation_id=conversation_id
|
||
).update(title=generated_title)
|
||
|
||
# 如果生成失败,使用对话ID的一部分作为临时标题
|
||
if not title:
|
||
title = f"对话 {conversation_id[:8]}"
|
||
|
||
# 构建消息列表,包含parent_id信息
|
||
message_list = []
|
||
for msg in messages:
|
||
message_data = {
|
||
'id': str(msg.id),
|
||
'parent_id': msg.parent_id, # 添加parent_id
|
||
'role': msg.role,
|
||
'content': msg.content,
|
||
'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'metadata': msg.metadata # 添加metadata
|
||
}
|
||
message_list.append(message_data)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取成功',
|
||
'data': {
|
||
'conversation_id': conversation_id,
|
||
'title': title, # 返回标题
|
||
'datasets': dataset_info,
|
||
'messages': message_list
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取对话详情失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取对话详情失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def available_datasets(self, request):
|
||
"""获取用户可访问的知识库列表"""
|
||
try:
|
||
user = request.user
|
||
all_datasets = KnowledgeBase.objects.all()
|
||
|
||
# 使用统一的权限检查方法
|
||
accessible_datasets = [
|
||
dataset for dataset in all_datasets
|
||
if self.check_knowledge_base_permission(dataset, user, 'read')
|
||
]
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取成功',
|
||
'data': [{
|
||
'id': str(ds.id),
|
||
'name': ds.name,
|
||
'type': ds.type,
|
||
'department': ds.department,
|
||
'description': ds.desc
|
||
} for ds in accessible_datasets]
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取可用知识库列表失败: {str(e)}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取可用知识库列表失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['post'])
|
||
def create_conversation(self, request):
|
||
"""创建会话 - 先选择知识库创建会话ID,不发送问题"""
|
||
try:
|
||
data = request.data
|
||
|
||
# 检查知识库ID:支持dataset_id或dataset_id_list格式
|
||
dataset_ids = []
|
||
if 'dataset_id' in data:
|
||
dataset_id = data['dataset_id']
|
||
# 直接使用标准UUID格式
|
||
dataset_ids.append(str(dataset_id))
|
||
elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)):
|
||
# 处理可能的字符串格式
|
||
if isinstance(data['dataset_id_list'], str):
|
||
try:
|
||
# 尝试解析JSON字符串
|
||
dataset_list = json.loads(data['dataset_id_list'])
|
||
if isinstance(dataset_list, list):
|
||
dataset_ids = [str(id) for id in dataset_list]
|
||
except json.JSONDecodeError:
|
||
# 如果解析失败,可能是单个ID
|
||
dataset_ids = [str(data['dataset_id_list'])]
|
||
else:
|
||
# 如果已经是列表,直接使用标准UUID格式
|
||
dataset_ids = [str(id) for id in data['dataset_id_list']]
|
||
else:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必填字段: dataset_id 或 dataset_id_list',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if not dataset_ids:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '至少需要提供一个知识库ID',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证所有知识库
|
||
user = request.user
|
||
knowledge_bases = [] # 存储所有知识库对象
|
||
|
||
for kb_id in dataset_ids:
|
||
try:
|
||
knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
|
||
if not knowledge_base:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'知识库不存在: {kb_id}',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
knowledge_bases.append(knowledge_base)
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': f'无权访问知识库: {knowledge_base.name}',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
except Exception as e:
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'处理知识库ID出错: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 创建一个新的会话ID
|
||
conversation_id = str(uuid.uuid4())
|
||
logger.info(f"创建新的会话ID: {conversation_id}")
|
||
|
||
# 获取自定义标题(如果有)
|
||
title = data.get('title', 'New chat')
|
||
|
||
# 准备metadata (仍然保存知识库名称用于内部处理)
|
||
metadata = {
|
||
'dataset_id_list': [str(id) for id in dataset_ids],
|
||
'dataset_names': [kb.name for kb in knowledge_bases]
|
||
}
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '会话创建成功',
|
||
'data': {
|
||
'conversation_id': conversation_id,
|
||
'title': title, # 添加标题字段
|
||
'dataset_id_list': metadata['dataset_id_list']
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建会话失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'创建会话失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def create(self, request):
|
||
"""创建聊天记录"""
|
||
try:
|
||
data = request.data
|
||
|
||
# 检查必填字段
|
||
if 'question' not in data:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必填字段: question',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if 'conversation_id' not in data:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必填字段: conversation_id',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
conversation_id = data['conversation_id']
|
||
|
||
# 查找该会话ID下的历史记录,获取知识库信息
|
||
existing_records = ChatHistory.objects.filter(
|
||
conversation_id=conversation_id
|
||
).order_by('created_at')
|
||
|
||
# 如果有历史记录,使用第一条记录的metadata
|
||
if existing_records.exists():
|
||
first_record = existing_records.first()
|
||
metadata = first_record.metadata or {}
|
||
|
||
# 获取知识库信息
|
||
dataset_ids = metadata.get('dataset_id_list', [])
|
||
external_id_list = metadata.get('dataset_external_id_list', [])
|
||
|
||
# 验证知识库是否存在且用户有权限
|
||
knowledge_bases = []
|
||
if not dataset_ids:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '找不到会话关联的知识库信息',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
for kb_id in dataset_ids:
|
||
try:
|
||
kb = KnowledgeBase.objects.get(id=kb_id)
|
||
if not self.check_knowledge_base_permission(kb, request.user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': f'无权访问知识库: {kb.name}',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
knowledge_bases.append(kb)
|
||
except KnowledgeBase.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'知识库不存在: {kb_id}',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
if not external_id_list or not knowledge_bases:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '会话关联的知识库信息不完整',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
else:
|
||
# 如果是新会话的第一条记录,需要提供知识库ID
|
||
dataset_ids = []
|
||
if 'dataset_id' in data:
|
||
dataset_ids.append(str(data['dataset_id']))
|
||
elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)):
|
||
if isinstance(data['dataset_id_list'], str):
|
||
try:
|
||
dataset_list = json.loads(data['dataset_id_list'])
|
||
if isinstance(dataset_list, list):
|
||
dataset_ids = [str(id) for id in dataset_list]
|
||
else:
|
||
dataset_ids = [str(data['dataset_id_list'])]
|
||
except json.JSONDecodeError:
|
||
dataset_ids = [str(data['dataset_id_list'])]
|
||
else:
|
||
dataset_ids = [str(id) for id in data['dataset_id_list']]
|
||
|
||
if not dataset_ids:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '新会话需要提供知识库ID',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证所有知识库并收集external_ids
|
||
external_id_list = []
|
||
knowledge_bases = []
|
||
|
||
for kb_id in dataset_ids:
|
||
try:
|
||
knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
|
||
if not knowledge_base:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'知识库不存在: {kb_id}',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
knowledge_bases.append(knowledge_base)
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': f'无权访问知识库: {knowledge_base.name}',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 添加知识库的external_id到列表
|
||
if knowledge_base.external_id:
|
||
external_id_list.append(str(knowledge_base.external_id))
|
||
else:
|
||
logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")
|
||
|
||
except Exception as e:
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'处理知识库ID出错: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if not external_id_list:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '没有有效的知识库external_id',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 创建metadata
|
||
metadata = {
|
||
'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
|
||
'dataset_id_list': [str(id) for id in dataset_ids],
|
||
'dataset_external_id_list': [str(id) for id in external_id_list],
|
||
'dataset_names': [kb.name for kb in knowledge_bases]
|
||
}
|
||
|
||
# 检查是否有自定义标题
|
||
title = data.get('title', 'New chat')
|
||
|
||
# 创建用户问题记录
|
||
question_record = ChatHistory.objects.create(
|
||
user=request.user,
|
||
knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库
|
||
conversation_id=str(conversation_id),
|
||
title=title, # 设置标题
|
||
role='user',
|
||
content=data['question'],
|
||
metadata=metadata
|
||
)
|
||
|
||
# 检查是否需要流式输出
|
||
use_stream = data.get('stream', True)
|
||
|
||
if use_stream:
|
||
# 创建流式响应
|
||
response = StreamingHttpResponse(
|
||
self._stream_answer_from_external_api(
|
||
conversation_id=str(conversation_id),
|
||
question_record=question_record,
|
||
dataset_external_id_list=external_id_list,
|
||
knowledge_bases=knowledge_bases,
|
||
question=data['question'],
|
||
metadata=metadata
|
||
),
|
||
content_type='text/event-stream',
|
||
status=status.HTTP_201_CREATED # 修改状态码为201
|
||
)
|
||
|
||
# 添加禁用缓存的头部
|
||
response['Cache-Control'] = 'no-cache, no-store'
|
||
response['Connection'] = 'keep-alive'
|
||
|
||
return response
|
||
else:
|
||
# 使用非流式输出
|
||
logger.info("使用非流式输出模式")
|
||
# 调用同步 API 获取回答
|
||
answer = self._get_answer_from_external_api(external_id_list, data['question'])
|
||
|
||
if answer is None:
|
||
return Response({
|
||
'code': 500,
|
||
'message': '获取回答失败',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 创建 AI 回答记录
|
||
answer_record = ChatHistory.objects.create(
|
||
user=request.user,
|
||
knowledge_base=knowledge_bases[0],
|
||
conversation_id=str(conversation_id),
|
||
title=title, # 设置标题
|
||
parent_id=str(question_record.id),
|
||
role='assistant',
|
||
content=answer,
|
||
metadata=metadata
|
||
)
|
||
|
||
# 如果是新会话的第一条消息,并且没有自定义标题,则自动生成标题
|
||
should_generate_title = not existing_records.exists() and (not title or title == 'New chat')
|
||
if should_generate_title:
|
||
try:
|
||
generated_title = self._generate_conversation_title_from_deepseek(
|
||
data['question'],
|
||
answer
|
||
)
|
||
if generated_title:
|
||
# 更新所有相关记录的标题
|
||
ChatHistory.objects.filter(
|
||
conversation_id=str(conversation_id)
|
||
).update(title=generated_title)
|
||
title = generated_title
|
||
except Exception as e:
|
||
logger.error(f"自动生成标题失败: {str(e)}")
|
||
# 继续执行,不影响主流程
|
||
|
||
return Response({
|
||
'code': 200, # 修改状态码为201
|
||
'message': '成功',
|
||
'data': {
|
||
'id': str(answer_record.id),
|
||
'conversation_id': str(conversation_id),
|
||
'title': title, # 添加标题字段
|
||
'dataset_id_list': metadata.get('dataset_id_list', []),
|
||
'dataset_names': metadata.get('dataset_names', []),
|
||
'role': 'assistant',
|
||
'content': answer,
|
||
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
}, status=status.HTTP_200_CREATED) # 修改状态码为201
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建聊天记录失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'创建聊天记录失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _stream_answer_from_external_api(self, conversation_id, question_record, dataset_external_id_list, knowledge_bases, question, metadata):
|
||
"""流式获取AI回答并实时返回 - 优化版本"""
|
||
try:
|
||
# 确保所有ID都是字符串
|
||
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
|
||
|
||
# 获取标题
|
||
title = question_record.title or 'New chat'
|
||
|
||
# 创建AI回答记录对象,稍后更新内容
|
||
answer_record = ChatHistory.objects.create(
|
||
user=question_record.user,
|
||
knowledge_base=knowledge_bases[0],
|
||
conversation_id=str(conversation_id),
|
||
title=title, # 设置标题
|
||
parent_id=str(question_record.id),
|
||
role='assistant',
|
||
content="", # 初始内容为空
|
||
metadata=metadata
|
||
)
|
||
|
||
# 发送初始响应告知客户端开始流式传输
|
||
yield f"data: {json.dumps({'code': 200, 'message': '开始流式传输', 'data': {'id': str(answer_record.id), 'conversation_id': str(conversation_id), 'content': '', 'is_end': False}})}\n\n"
|
||
|
||
# 异步收集完整内容,用于最后保存
|
||
full_content = ""
|
||
|
||
# 打开与外部API的连接
|
||
logger.info(f"开始调用外部API,知识库ID列表: {dataset_external_ids}")
|
||
|
||
try:
|
||
# 第一步: 创建聊天会话
|
||
chat_response = requests.post(
|
||
url=f"{settings.API_BASE_URL}/api/application/chat/open",
|
||
json={
|
||
"id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
|
||
"model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006",
|
||
"dataset_id_list": dataset_external_ids,
|
||
"multiple_rounds_dialogue": False,
|
||
"dataset_setting": {
|
||
"top_n": 10, "similarity": "0.3",
|
||
"max_paragraph_char_number": 10000,
|
||
"search_mode": "blend",
|
||
"no_references_setting": {
|
||
"value": "{question}",
|
||
"status": "ai_questioning"
|
||
}
|
||
},
|
||
"model_setting": {
|
||
"prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}"
|
||
},
|
||
"problem_optimization": False
|
||
},
|
||
headers={"Content-Type": "application/json"},
|
||
|
||
)
|
||
|
||
if chat_response.status_code != 200:
|
||
error_msg = f"外部API调用失败: {chat_response.text}"
|
||
logger.error(error_msg)
|
||
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
|
||
return
|
||
|
||
chat_data = chat_response.json()
|
||
if chat_data.get('code') != 200 or not chat_data.get('data'):
|
||
error_msg = f"外部API返回错误: {chat_data}"
|
||
logger.error(error_msg)
|
||
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
|
||
return
|
||
|
||
chat_id = chat_data['data']
|
||
logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
|
||
|
||
# 第二步: 建立流式连接
|
||
message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
|
||
logger.info(f"开始流式请求: {message_url}")
|
||
|
||
# 创建流式请求
|
||
message_request = requests.post(
|
||
url=message_url,
|
||
json={"message": question, "re_chat": False, "stream": True},
|
||
headers={"Content-Type": "application/json"},
|
||
stream=True, # 启用流式传输
|
||
|
||
)
|
||
|
||
if message_request.status_code != 200:
|
||
error_msg = f"外部API聊天消息调用失败: {message_request.status_code}, {message_request.text}"
|
||
logger.error(error_msg)
|
||
yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n"
|
||
return
|
||
|
||
# 创建一个缓冲区以处理分段的数据
|
||
buffer = ""
|
||
|
||
# 读取并处理每个响应块
|
||
logger.info("开始处理流式响应")
|
||
for chunk in message_request.iter_content(chunk_size=1):
|
||
if not chunk:
|
||
continue
|
||
|
||
# 解码字节为字符串
|
||
chunk_str = chunk.decode('utf-8')
|
||
buffer += chunk_str
|
||
|
||
# 检查是否有完整的数据行
|
||
if '\n\n' in buffer:
|
||
lines = buffer.split('\n\n')
|
||
# 除了最后一行,其他都是完整的
|
||
for line in lines[:-1]:
|
||
# 处理完整的数据行
|
||
if line.startswith('data: '):
|
||
try:
|
||
# 提取JSON数据
|
||
json_str = line[6:] # 去掉 "data: " 前缀
|
||
data = json.loads(json_str)
|
||
|
||
# 记录并处理部分响应
|
||
if 'content' in data:
|
||
content_part = data['content']
|
||
full_content += content_part
|
||
|
||
# 构建响应数据
|
||
response_data = {
|
||
'code': 200, # 修改状态码为201
|
||
'message': 'partial',
|
||
'data': {
|
||
'id': str(answer_record.id),
|
||
'conversation_id': str(conversation_id),
|
||
'title': title, # 添加标题字段
|
||
'content': content_part,
|
||
'is_end': data.get('is_end', False)
|
||
}
|
||
}
|
||
|
||
# 立即发送每个部分到客户端
|
||
yield f"data: {json.dumps(response_data)}\n\n"
|
||
|
||
# 处理结束标记
|
||
if data.get('is_end', False):
|
||
logger.info("收到流式响应结束标记")
|
||
# 异步保存完整内容
|
||
answer_record.content = full_content.strip()
|
||
answer_record.save()
|
||
|
||
# 先检查当前conversation_id是否已有有效标题
|
||
current_title = ChatHistory.objects.filter(
|
||
conversation_id=str(conversation_id)
|
||
).exclude(
|
||
title__in=["New chat", "新对话", ""]
|
||
).values_list('title', flat=True).first()
|
||
|
||
# 如果已有有效标题,则复用
|
||
if current_title:
|
||
title = current_title
|
||
logger.info(f"复用已有标题: {title}")
|
||
else:
|
||
# 没有有效标题时,直接基于当前问题和回答生成标题
|
||
try:
|
||
# 直接使用当前的问题和完整的AI回答来生成标题
|
||
generated_title = self._generate_conversation_title_from_deepseek(
|
||
question, full_content.strip()
|
||
)
|
||
if generated_title:
|
||
# 更新所有相关记录的标题
|
||
ChatHistory.objects.filter(
|
||
conversation_id=str(conversation_id)
|
||
).update(title=generated_title)
|
||
title = generated_title
|
||
logger.info(f"成功生成标题: {title}")
|
||
else:
|
||
title = "新对话" # 如果生成失败,使用默认标题
|
||
logger.warning("生成标题失败,使用默认标题")
|
||
except Exception as e:
|
||
logger.error(f"自动生成标题失败: {str(e)}")
|
||
title = "新对话" # 如果出错,使用默认标题
|
||
|
||
# 发送完整内容的最终响应
|
||
final_response = {
|
||
'code': 200, # 修改状态码为201
|
||
'message': '完成',
|
||
'data': {
|
||
'id': str(answer_record.id),
|
||
'conversation_id': str(conversation_id),
|
||
'title': title, # 添加生成的标题
|
||
'dataset_id_list': metadata.get('dataset_id_list', []),
|
||
'dataset_names': metadata.get('dataset_names', []),
|
||
'role': 'assistant',
|
||
'content': full_content.strip(),
|
||
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'is_end': True
|
||
}
|
||
}
|
||
yield f"data: {json.dumps(final_response)}\n\n"
|
||
return # 结束生成器
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON解析错误: {e}, 数据: {line}")
|
||
# 继续处理,跳过此行
|
||
|
||
# 保留最后一个可能不完整的行
|
||
buffer = lines[-1]
|
||
|
||
# 处理最后可能剩余的缓冲数据
|
||
if buffer:
|
||
logger.info(f"处理剩余缓冲数据: {buffer}")
|
||
if buffer.startswith('data: '):
|
||
try:
|
||
json_str = buffer[6:] # 去掉 "data: " 前缀
|
||
data = json.loads(json_str)
|
||
|
||
if 'content' in data:
|
||
content_part = data['content']
|
||
full_content += content_part
|
||
|
||
response_data = {
|
||
'code': 200, # 修改状态码为201
|
||
'message': 'partial',
|
||
'data': {
|
||
'id': str(answer_record.id),
|
||
'conversation_id': str(conversation_id), # 添加标题字段
|
||
'content': content_part,
|
||
'is_end': data.get('is_end', False)
|
||
}
|
||
}
|
||
yield f"data: {json.dumps(response_data)}\n\n"
|
||
except json.JSONDecodeError:
|
||
logger.error(f"处理剩余数据时JSON解析错误: {buffer}")
|
||
|
||
# 确保在流结束时保存内容到数据库
|
||
if full_content:
|
||
answer_record.content = full_content.strip()
|
||
answer_record.save()
|
||
logger.info(f"流结束,保存完整内容到数据库: {len(full_content)} 字符")
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
logger.error(f"请求外部API时发生错误: {str(e)}")
|
||
yield f"data: {json.dumps({'code': 500, 'message': f'请求外部API时发生错误: {str(e)}', 'data': {'is_end': True}})}\n\n"
|
||
|
||
except Exception as e:
|
||
logger.error(f"流式处理出错: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
yield f"data: {json.dumps({'code': 500, 'message': f'流式处理出错: {str(e)}', 'data': {'is_end': True}})}\n\n"
|
||
|
||
# 尝试保存已收集的内容
|
||
if 'full_content' in locals() and full_content:
|
||
try:
|
||
answer_record.content = full_content.strip()
|
||
answer_record.save()
|
||
except Exception as save_error:
|
||
logger.error(f"保存部分内容失败: {str(save_error)}")
|
||
|
||
def _get_answer_from_external_api(self, dataset_external_id_list, question):
|
||
"""调用外部API获取AI回答(非流式版本)"""
|
||
try:
|
||
# 确保所有ID都是字符串
|
||
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
|
||
|
||
logger.info(f"准备调用外部API(非流式模式),知识库ID列表: {dataset_external_ids}")
|
||
|
||
# 第一个API调用创建聊天
|
||
chat_request_data = {
|
||
"id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
|
||
"model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006",
|
||
"dataset_id_list": dataset_external_ids,
|
||
"multiple_rounds_dialogue": False,
|
||
"dataset_setting": {
|
||
"top_n": 10,
|
||
"similarity": "0.3",
|
||
"max_paragraph_char_number": 10000,
|
||
"search_mode": "blend",
|
||
"no_references_setting": {
|
||
"value": "{question}",
|
||
"status": "ai_questioning"
|
||
}
|
||
},
|
||
"model_setting": {
|
||
"prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}"
|
||
},
|
||
"problem_optimization": False
|
||
}
|
||
|
||
logger.info(f"发送创建聊天请求:{settings.API_BASE_URL}/api/application/chat/open")
|
||
|
||
try:
|
||
# 测试JSON序列化,提前捕获可能的错误
|
||
json_data = json.dumps(chat_request_data)
|
||
logger.debug(f"请求数据序列化成功,长度: {len(json_data)}")
|
||
except TypeError as e:
|
||
logger.error(f"JSON序列化失败: {str(e)}")
|
||
return None
|
||
|
||
chat_response = requests.post(
|
||
url=f"{settings.API_BASE_URL}/api/application/chat/open",
|
||
json=chat_request_data,
|
||
headers={"Content-Type": "application/json"},
|
||
|
||
)
|
||
|
||
logger.info(f"API响应状态码: {chat_response.status_code}")
|
||
|
||
if chat_response.status_code != 200:
|
||
logger.error(f"外部API调用失败: {chat_response.text}")
|
||
return None
|
||
|
||
chat_data = chat_response.json()
|
||
logger.debug(f"API响应数据: {chat_data}")
|
||
|
||
if chat_data.get('code') != 200 or not chat_data.get('data'):
|
||
logger.error(f"外部API返回错误: {chat_data}")
|
||
return None
|
||
|
||
chat_id = chat_data['data']
|
||
logger.info(f"聊天创建成功,chat_id: {chat_id}")
|
||
|
||
# 第二个API调用发送消息
|
||
message_request_data = {
|
||
"message": question,
|
||
"re_chat": False,
|
||
"stream": False # 设置为非流式
|
||
}
|
||
|
||
logger.info(f"发送聊天消息请求(非流式): {settings.API_BASE_URL}/api/application/chat_message/{chat_id}")
|
||
message_response = requests.post(
|
||
url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}",
|
||
json=message_request_data,
|
||
headers={"Content-Type": "application/json"},
|
||
|
||
)
|
||
|
||
if message_response.status_code != 200:
|
||
logger.error(f"外部API聊天消息调用失败: {message_response.status_code}, {message_response.text}")
|
||
return None
|
||
|
||
# 处理非流式响应
|
||
try:
|
||
response_data = message_response.json()
|
||
logger.debug(f"非流式响应数据: {response_data}")
|
||
|
||
if response_data.get('code') != 200 or 'data' not in response_data:
|
||
logger.error(f"外部API返回错误: {response_data}")
|
||
return None
|
||
|
||
# 提取回答内容
|
||
answer_content = response_data.get('data', {}).get('content', '')
|
||
if not answer_content:
|
||
logger.warning("API返回的回答内容为空")
|
||
return "无法获取回答内容"
|
||
|
||
return answer_content
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"解析API响应JSON失败: {str(e)}")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"处理API响应失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"调用外部API获取回答失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return None
|
||
|
||
def update(self, request, pk=None):
|
||
"""更新聊天记录"""
|
||
try:
|
||
record = self.get_queryset().filter(id=pk).first()
|
||
|
||
if not record:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '记录不存在或无权限',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
data = request.data
|
||
updateable_fields = ['content', 'metadata']
|
||
|
||
if 'content' in data:
|
||
record.content = data['content']
|
||
|
||
if 'metadata' in data:
|
||
current_metadata = record.metadata or {}
|
||
current_metadata.update(data['metadata'])
|
||
record.metadata = current_metadata
|
||
|
||
record.save()
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '更新成功',
|
||
'data': {
|
||
'id': record.id,
|
||
'conversation_id': record.conversation_id,
|
||
'role': record.role,
|
||
'content': record.content,
|
||
'metadata': record.metadata,
|
||
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"更新聊天记录失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'更新聊天记录失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def destroy(self, request, pk=None):
|
||
"""删除聊天记录(软删除)"""
|
||
try:
|
||
record = self.get_queryset().filter(id=pk).first()
|
||
|
||
if not record:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '记录不存在或无权限',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
record.soft_delete()
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '删除成功',
|
||
'data': None
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"删除聊天记录失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'删除聊天记录失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def search(self, request):
|
||
"""搜索聊天记录"""
|
||
try:
|
||
# 获取查询参数
|
||
keyword = request.query_params.get('keyword', '').strip()
|
||
dataset_id = request.query_params.get('dataset_id')
|
||
start_date = request.query_params.get('start_date')
|
||
end_date = request.query_params.get('end_date')
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 10))
|
||
|
||
# 基础查询
|
||
query = self.get_queryset()
|
||
|
||
# 添加过滤条件
|
||
if keyword:
|
||
query = query.filter(
|
||
Q(content__icontains=keyword) |
|
||
Q(knowledge_base__name__icontains=keyword)
|
||
)
|
||
|
||
if dataset_id:
|
||
# 检查知识库权限
|
||
knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first()
|
||
if knowledge_base and not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': '无权访问该知识库',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
query = query.filter(knowledge_base__id=dataset_id)
|
||
if start_date:
|
||
query = query.filter(created_at__gte=start_date)
|
||
if end_date:
|
||
query = query.filter(created_at__lte=end_date)
|
||
|
||
# 计算分页
|
||
total = query.count()
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
|
||
# 获取分页数据
|
||
records = query.order_by('-created_at')[start:end]
|
||
|
||
# 序列化数据
|
||
results = []
|
||
for record in records:
|
||
result = {
|
||
'id': record.id,
|
||
'conversation_id': record.conversation_id,
|
||
'dataset_id': str(record.knowledge_base.id),
|
||
'dataset_name': record.knowledge_base.name,
|
||
'role': record.role,
|
||
'content': record.content,
|
||
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'metadata': record.metadata
|
||
}
|
||
|
||
if keyword:
|
||
result['highlights'] = {
|
||
'content': self._highlight_keyword(record.content, keyword)
|
||
}
|
||
|
||
results.append(result)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '搜索成功',
|
||
'data': {
|
||
'total': total,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'results': results
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"搜索聊天记录失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'搜索失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def export(self, request):
|
||
"""导出聊天记录为Excel文件"""
|
||
try:
|
||
# 获取查询参数
|
||
conversation_id = request.query_params.get('conversation_id')
|
||
dataset_id = request.query_params.get('dataset_id')
|
||
history_days = request.query_params.get('history_days', '7') # 默认导出最近7天
|
||
|
||
# 至少需要一个筛选条件
|
||
if not conversation_id and not dataset_id:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '需要提供conversation_id或dataset_id参数',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证权限
|
||
user = request.user
|
||
if dataset_id:
|
||
knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first()
|
||
if not knowledge_base:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '知识库不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': '无权访问该知识库',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 查询确认有聊天记录存在
|
||
query = self.get_queryset()
|
||
if conversation_id:
|
||
records = query.filter(conversation_id=conversation_id)
|
||
elif dataset_id:
|
||
records = query.filter(knowledge_base__id=dataset_id)
|
||
|
||
if not records.exists():
|
||
return Response({
|
||
'code': 404,
|
||
'message': '未找到相关对话记录',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 调用外部API导出Excel文件 - 使用GET请求
|
||
application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006" # 固定值
|
||
export_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat/export?history_day={history_days}"
|
||
|
||
logger.info(f"发送导出请求:{export_url}")
|
||
|
||
export_response = requests.get(
|
||
url=export_url,
|
||
|
||
stream=True # 使用流式传输处理大文件
|
||
)
|
||
|
||
# 检查响应状态
|
||
if export_response.status_code != 200:
|
||
logger.error(f"导出API调用失败: {export_response.status_code}, {export_response.text}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': '导出失败,外部服务返回错误',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 创建响应对象并设置文件下载头
|
||
response = HttpResponse(
|
||
content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
||
)
|
||
response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
|
||
|
||
# 将API响应内容写入响应对象
|
||
for chunk in export_response.iter_content(chunk_size=8192):
|
||
if chunk:
|
||
response.write(chunk)
|
||
|
||
logger.info("导出成功完成")
|
||
return response
|
||
|
||
except Exception as e:
|
||
logger.error(f"导出聊天记录失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'导出聊天记录失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def chat_list(self, request):
|
||
"""获取对话列表"""
|
||
try:
|
||
# 获取查询参数
|
||
history_days = request.query_params.get('history_days', '7') # 默认7天
|
||
|
||
# 构建API请求
|
||
application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006"
|
||
api_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat"
|
||
|
||
# 添加查询参数
|
||
params = {
|
||
'history_day': history_days
|
||
}
|
||
|
||
logger.info(f"发送获取对话列表请求:{api_url}")
|
||
|
||
# 调用外部API
|
||
response = requests.get(
|
||
url=api_url,
|
||
params=params,
|
||
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"获取对话列表失败: {response.status_code}, {response.text}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': '获取对话列表失败,外部服务返回错误',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 解析响应数据
|
||
try:
|
||
result = response.json()
|
||
|
||
if result.get('code') != 200:
|
||
logger.error(f"外部API返回错误: {result}")
|
||
return Response({
|
||
'code': result.get('code', 500),
|
||
'message': result.get('message', '获取对话列表失败'),
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 处理返回的数据
|
||
chat_list = result.get('data', [])
|
||
|
||
# 格式化返回数据
|
||
formatted_chats = []
|
||
for chat in chat_list:
|
||
formatted_chat = {
|
||
'id': chat['id'],
|
||
'chat_id': chat['chat_id'],
|
||
'abstract': chat['abstract'],
|
||
'message_count': chat['chat_record_count'],
|
||
'created_at': datetime.fromisoformat(chat['create_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'),
|
||
'updated_at': datetime.fromisoformat(chat['update_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'),
|
||
'star_count': chat['star_num'],
|
||
'trample_count': chat['trample_num'],
|
||
'mark_sum': chat['mark_sum'],
|
||
'is_deleted': chat['is_deleted']
|
||
}
|
||
formatted_chats.append(formatted_chat)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取成功',
|
||
'data': {
|
||
'total': len(formatted_chats),
|
||
'results': formatted_chats
|
||
}
|
||
})
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"解析响应数据失败: {str(e)}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': '解析响应数据失败',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取对话列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取对话列表失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['post'])
|
||
def hit_test(self, request):
|
||
"""获取问题与知识库文档的匹配度"""
|
||
try:
|
||
data = request.data
|
||
|
||
# 检查必填字段
|
||
if 'question' not in data:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必填字段: question',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if 'dataset_id_list' not in data or not data['dataset_id_list']:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必填字段: dataset_id_list',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
question = data['question']
|
||
dataset_ids = data['dataset_id_list']
|
||
|
||
# 如果不是列表,转换为列表
|
||
if not isinstance(dataset_ids, list):
|
||
try:
|
||
dataset_ids = json.loads(dataset_ids)
|
||
if not isinstance(dataset_ids, list):
|
||
dataset_ids = [dataset_ids]
|
||
except (json.JSONDecodeError, TypeError):
|
||
dataset_ids = [dataset_ids]
|
||
|
||
# 检查用户是否有权限访问这些知识库
|
||
external_id_list = []
|
||
for kb_id in dataset_ids:
|
||
try:
|
||
kb = KnowledgeBase.objects.get(id=kb_id)
|
||
if not self.check_knowledge_base_permission(kb, request.user, 'read'):
|
||
return Response({
|
||
'code': 403,
|
||
'message': f'无权访问知识库: {kb.name}',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
if kb.external_id:
|
||
external_id_list.append(str(kb.external_id))
|
||
else:
|
||
logger.warning(f"知识库 {kb.id} ({kb.name}) 没有external_id")
|
||
except KnowledgeBase.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'知识库不存在: {kb_id}',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
if not external_id_list:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '没有有效的知识库external_id',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 获取所有知识库的匹配文档
|
||
all_documents = []
|
||
for dataset_id in external_id_list:
|
||
doc_info = self._call_hit_test_api(dataset_id, question)
|
||
if doc_info:
|
||
all_documents.extend(doc_info)
|
||
|
||
# 按相似度排序
|
||
all_documents = sorted(all_documents, key=lambda x: x.get('similarity', 0), reverse=True)
|
||
|
||
# 返回结果
|
||
return Response({
|
||
'code': 200,
|
||
'message': '成功',
|
||
'data': {
|
||
'question': question,
|
||
'matched_documents': all_documents,
|
||
'total_count': len(all_documents)
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"hit_test接口调用失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'hit_test接口调用失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _highlight_keyword(self, text, keyword):
|
||
"""高亮关键词"""
|
||
if not keyword or not text:
|
||
return text
|
||
return text.replace(
|
||
keyword,
|
||
f'<em class="highlight">{keyword}</em>'
|
||
)
|
||
|
||
def _call_hit_test_api(self, dataset_id, query_text):
|
||
"""调用知识库hit_test接口获取相关文档信息"""
|
||
try:
|
||
url = f"{settings.API_BASE_URL}/api/dataset/{dataset_id}/hit_test"
|
||
params = {
|
||
"query_text": query_text,
|
||
"top_number": 10,
|
||
"similarity": 0.3,
|
||
"search_mode": "blend"
|
||
}
|
||
|
||
logger.info(f"调用hit_test接口: {url}, 参数: {params}")
|
||
|
||
response = requests.get(
|
||
url=url,
|
||
params=params,
|
||
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"hit_test接口调用失败: {response.status_code}, {response.text}")
|
||
return None
|
||
|
||
result = response.json()
|
||
if result.get('code') != 200:
|
||
logger.error(f"hit_test接口业务错误: {result}")
|
||
return None
|
||
|
||
# 提取文档信息
|
||
documents = result.get('data', [])
|
||
logger.info(f"hit_test接口返回 {len(documents)} 个相关文档")
|
||
|
||
# 提取文档名称和相似度等信息
|
||
doc_info = []
|
||
for doc in documents:
|
||
doc_info.append({
|
||
"document_name": doc.get("document_name", ""),
|
||
"dataset_name": doc.get("dataset_name", ""),
|
||
"similarity": doc.get("similarity", 0),
|
||
"comprehensive_score": doc.get("comprehensive_score", 0)
|
||
})
|
||
|
||
return doc_info
|
||
except Exception as e:
|
||
logger.error(f"调用hit_test接口失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return None
|
||
|
||
@action(detail=False, methods=['delete'])
|
||
def delete_conversation(self, request):
|
||
"""通过conversation_id删除一组会话"""
|
||
try:
|
||
# 获取conversation_id
|
||
conversation_id = request.query_params.get('conversation_id')
|
||
if not conversation_id:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必要参数: conversation_id',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 查找该会话下的所有记录
|
||
records = self.get_queryset().filter(conversation_id=conversation_id)
|
||
|
||
if not records.exists():
|
||
return Response({
|
||
'code': 404,
|
||
'message': '未找到该会话或无权限访问',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 获取记录数量
|
||
records_count = records.count()
|
||
|
||
# 批量软删除
|
||
for record in records:
|
||
record.soft_delete()
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '删除成功',
|
||
'data': {
|
||
'conversation_id': conversation_id,
|
||
'deleted_count': records_count
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"删除会话失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'删除会话失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'], url_path='generate-conversation-title')
|
||
def generate_conversation_title(self, request):
|
||
"""更新会话标题"""
|
||
try:
|
||
conversation_id = request.query_params.get('conversation_id')
|
||
if not conversation_id:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少conversation_id参数',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 检查对话是否存在
|
||
messages = self.get_queryset().filter(
|
||
conversation_id=conversation_id,
|
||
is_deleted=False,
|
||
user=request.user
|
||
).order_by('created_at')
|
||
|
||
if not messages.exists():
|
||
return Response({
|
||
'code': 404,
|
||
'message': '对话不存在或无权访问',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 检查是否有自定义标题参数
|
||
custom_title = request.query_params.get('title')
|
||
if not custom_title:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少title参数',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 更新所有相关记录的标题
|
||
ChatHistory.objects.filter(
|
||
conversation_id=conversation_id,
|
||
user=request.user
|
||
).update(title=custom_title)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '更新会话标题成功',
|
||
'data': {
|
||
'conversation_id': conversation_id,
|
||
'title': custom_title
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"更新会话标题失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f"更新会话标题失败: {str(e)}",
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _generate_conversation_title_from_deepseek(self, user_question, assistant_answer):
|
||
"""调用SiliconCloud API生成会话标题,直接基于当前问题和回答内容"""
|
||
try:
|
||
# 从Django设置中获取API密钥
|
||
api_key = settings.SILICON_CLOUD_API_KEY
|
||
if not api_key:
|
||
return "新对话"
|
||
|
||
# 构建提示信息
|
||
prompt = f"请根据用户的问题和助手的回答,生成一个简短的对话标题(不超过20个字)。\n\n用户问题: {user_question}\n\n助手回答: {assistant_answer}"
|
||
|
||
import requests
|
||
|
||
url = "https://api.siliconflow.cn/v1/chat/completions"
|
||
|
||
payload = {
|
||
"model": "deepseek-ai/DeepSeek-V3",
|
||
"stream": False,
|
||
"max_tokens": 512,
|
||
"temperature": 0.7,
|
||
"top_p": 0.7,
|
||
"top_k": 50,
|
||
"frequency_penalty": 0.5,
|
||
"n": 1,
|
||
"stop": [],
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": prompt
|
||
}
|
||
]
|
||
}
|
||
headers = {
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
response = requests.post(url, json=payload, headers=headers)
|
||
response_data = response.json()
|
||
|
||
if response.status_code == 200 and 'choices' in response_data and response_data['choices']:
|
||
title = response_data['choices'][0]['message']['content'].strip()
|
||
return title[:50] # 截断过长的标题
|
||
else:
|
||
logger.error(f"生成标题时出错: {response.text}")
|
||
return "新对话"
|
||
|
||
except Exception as e:
|
||
logger.exception(f"生成对话标题时发生错误: {str(e)}")
|
||
return "新对话"
|
||
|
||
|
||
class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
|
||
serializer_class = KnowledgeBaseSerializer
|
||
permission_classes = [IsAuthenticated]
|
||
|
||
def list(self, request, *args, **kwargs):
|
||
try:
|
||
queryset = self.get_queryset()
|
||
|
||
# 获取搜索关键字
|
||
keyword = request.query_params.get('keyword', '')
|
||
|
||
# 如果有关键字,构建搜索条件
|
||
if keyword:
|
||
query = Q(name__icontains=keyword) | \
|
||
Q(desc__icontains=keyword) | \
|
||
Q(department__icontains=keyword) | \
|
||
Q(group__icontains=keyword)
|
||
queryset = queryset.filter(query)
|
||
|
||
# 获取分页参数
|
||
try:
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 10))
|
||
except ValueError:
|
||
page = 1
|
||
page_size = 10
|
||
|
||
# 计算总数量
|
||
total = queryset.count()
|
||
|
||
# 分页处理
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
paginated_queryset = queryset[start:end]
|
||
|
||
# 序列化知识库数据
|
||
serializer = self.get_serializer(paginated_queryset, many=True)
|
||
data = serializer.data
|
||
|
||
# 为每个知识库添加权限信息
|
||
user = request.user
|
||
for item in data:
|
||
# 获取必要的知识库属性
|
||
kb_type = item['type']
|
||
department = item.get('department')
|
||
group = item.get('group')
|
||
creator_id = item.get('user_id')
|
||
kb_id = item['id']
|
||
|
||
# 首先检查权限表中的显式权限
|
||
explicit_permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=kb_id,
|
||
user=user,
|
||
status='active'
|
||
).first()
|
||
|
||
if explicit_permission:
|
||
item['permissions'] = {
|
||
'can_read': explicit_permission.can_read,
|
||
'can_edit': explicit_permission.can_edit,
|
||
'can_delete': explicit_permission.can_delete
|
||
}
|
||
# 添加知识库的到期时间
|
||
item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
|
||
else:
|
||
# 没有显式权限时使用统一的权限判断方法
|
||
item['permissions'] = {
|
||
'can_read': self._can_read(kb_type, user, department, group, creator_id, kb_id),
|
||
'can_edit': self._can_edit(kb_type, user, department, group, creator_id, kb_id),
|
||
'can_delete': self._can_delete(kb_type, user, department, group, creator_id, kb_id)
|
||
}
|
||
# 对于admin类型的知识库,设置expires_at为None
|
||
if kb_type == 'admin':
|
||
item['expires_at'] = None
|
||
else:
|
||
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
|
||
item['expires_at'] = None
|
||
|
||
# 处理高亮
|
||
if keyword:
|
||
if 'name' in item and keyword.lower() in item['name'].lower():
|
||
item['highlighted_name'] = item['name'].replace(
|
||
keyword, f'<em class="highlight">{keyword}</em>'
|
||
)
|
||
|
||
if 'desc' in item and item.get('desc') is not None:
|
||
desc_text = str(item['desc'])
|
||
if keyword.lower() in desc_text.lower():
|
||
item['highlighted_desc'] = desc_text.replace(
|
||
keyword, f'<em class="highlight">{keyword}</em>'
|
||
)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "获取知识库列表成功",
|
||
"data": {
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": page_size,
|
||
"keyword": keyword if keyword else None,
|
||
"items": data
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"获取知识库列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取知识库列表失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def get_queryset(self):
|
||
"""获取用户有权限查看的知识库列表"""
|
||
user = self.request.user
|
||
queryset = KnowledgeBase.objects.all()
|
||
|
||
# 1. 构建基础权限条件
|
||
permission_conditions = Q()
|
||
|
||
# 2. 所有用户都可以看到 admin 类型的知识库
|
||
permission_conditions |= Q(type='admin')
|
||
|
||
# 3. 用户可以看到自己创建的所有知识库
|
||
permission_conditions |= Q(user_id=user.id)
|
||
|
||
# 4. 添加显式权限条件
|
||
# 获取所有活跃的权限记录
|
||
active_permissions = KBPermissionModel.objects.filter(
|
||
user=user,
|
||
can_read=True,
|
||
status='active',
|
||
expires_at__gt=timezone.now()
|
||
).values_list('knowledge_base_id', flat=True)
|
||
|
||
if active_permissions:
|
||
permission_conditions |= Q(id__in=active_permissions)
|
||
|
||
# 5. 根据用户角色添加隐式权限
|
||
if user.role == 'admin':
|
||
# 管理员可以看到除了其他用户 private 类型外的所有知识库
|
||
permission_conditions |= ~Q(type='private') | Q(user_id=user.id)
|
||
elif user.role == 'leader':
|
||
# 组长可以查看本部门的 leader 和 member 类型知识库
|
||
permission_conditions |= Q(
|
||
type__in=['leader', 'member'],
|
||
department=user.department
|
||
)
|
||
elif user.role in ['member', 'user']:
|
||
# 成员可以查看本部门的 leader 类型知识库
|
||
permission_conditions |= Q(
|
||
type='leader',
|
||
department=user.department
|
||
)
|
||
# 成员可以查看本部门本组的 member 类型知识库
|
||
permission_conditions |= Q(
|
||
type='member',
|
||
department=user.department,
|
||
group=user.group
|
||
)
|
||
|
||
return queryset.filter(permission_conditions).distinct()
|
||
|
||
def create(self, request, *args, **kwargs):
|
||
try:
|
||
# 1. 验证知识库名称
|
||
name = request.data.get('name')
|
||
if not name:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '知识库名称不能为空',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if KnowledgeBase.objects.filter(name=name).exists():
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'知识库名称 "{name}" 已存在',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 2. 验证用户权限和必填字段
|
||
user = request.user
|
||
type = request.data.get('type', 'private')
|
||
department = request.data.get('department')
|
||
group = request.data.get('group')
|
||
|
||
# 修改权限验证
|
||
if type == 'admin':
|
||
# 移除管理员权限检查,允许所有用户创建
|
||
department = None
|
||
group = None
|
||
|
||
elif type == 'secret':
|
||
if user.role != 'admin':
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有管理员可以创建保密级知识库',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
department = None
|
||
group = None
|
||
|
||
elif type == 'leader':
|
||
if user.role != 'admin':
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有管理员可以创建组长级知识库',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
if not department:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '创建组长级知识库时必须指定部门',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
elif type == 'member':
|
||
if user.role not in ['admin', 'leader']:
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有管理员和组长可以创建成员级知识库',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
if user.role == 'admin' and not department:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '管理员创建成员知识库时必须指定部门',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
elif user.role == 'leader':
|
||
department = user.department
|
||
|
||
if not group:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '创建成员知识库时必须指定组',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
elif type == 'private':
|
||
# 对于private类型,不保存department和group
|
||
department = None
|
||
group = None
|
||
|
||
# 3. 验证请求数据
|
||
data = request.data.copy()
|
||
data['department'] = department
|
||
data['group'] = group
|
||
|
||
# 不需要手动设置 user_id,由序列化器自动处理
|
||
serializer = self.get_serializer(data=data)
|
||
if not serializer.is_valid():
|
||
logger.error(f"数据验证失败: {serializer.errors}")
|
||
return Response({
|
||
'code': 400,
|
||
'message': '数据验证失败',
|
||
'data': serializer.errors
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
with transaction.atomic():
|
||
# 4. 创建知识库
|
||
try:
|
||
knowledge_base = serializer.save()
|
||
logger.info(f"知识库创建成功: id={knowledge_base.id}, name={knowledge_base.name}, user_id={knowledge_base.user_id}")
|
||
except Exception as e:
|
||
logger.error(f"知识库创建失败: {str(e)}")
|
||
raise
|
||
|
||
# 5. 调用外部API创建知识库
|
||
try:
|
||
external_id = self._create_external_dataset(knowledge_base)
|
||
logger.info(f"外部知识库创建成功,获取ID: {external_id}")
|
||
|
||
# 保存外部知识库ID
|
||
knowledge_base.external_id = external_id
|
||
knowledge_base.save()
|
||
logger.info(f"更新knowledge_base的external_id为: {external_id}")
|
||
|
||
except ExternalAPIError as e:
|
||
logger.error(f"外部知识库创建失败: {str(e)}")
|
||
raise
|
||
|
||
# 6. 创建权限记录
|
||
try:
|
||
# 创建者权限
|
||
KBPermissionModel.objects.create(
|
||
knowledge_base=knowledge_base,
|
||
user=request.user,
|
||
can_read=True,
|
||
can_edit=True,
|
||
can_delete=True,
|
||
granted_by=request.user,
|
||
status='active'
|
||
)
|
||
logger.info(f"创建者权限创建成功")
|
||
|
||
# 根据类型批量创建其他用户权限
|
||
permissions = []
|
||
if type == 'admin':
|
||
users_query = User.objects.exclude(id=request.user.id)
|
||
# 为所有用户赋予完全权限(读、写、删)
|
||
permissions = [
|
||
KBPermissionModel(
|
||
knowledge_base=knowledge_base,
|
||
user=user,
|
||
can_read=True,
|
||
can_edit=True,
|
||
can_delete=True,
|
||
granted_by=request.user,
|
||
status='active'
|
||
) for user in users_query
|
||
]
|
||
elif type == 'secret':
|
||
users_query = User.objects.filter(role='admin').exclude(id=request.user.id)
|
||
permissions = [
|
||
KBPermissionModel(
|
||
knowledge_base=knowledge_base,
|
||
user=user,
|
||
can_read=True,
|
||
can_edit=self._can_edit(type, user),
|
||
can_delete=self._can_delete(type, user),
|
||
granted_by=request.user,
|
||
status='active'
|
||
) for user in users_query
|
||
]
|
||
elif type == 'leader':
|
||
users_query = User.objects.filter(
|
||
Q(role='admin') |
|
||
Q(role='leader', department=department)
|
||
).exclude(id=request.user.id)
|
||
permissions = [
|
||
KBPermissionModel(
|
||
knowledge_base=knowledge_base,
|
||
user=user,
|
||
can_read=True,
|
||
can_edit=self._can_edit(type, user),
|
||
can_delete=self._can_delete(type, user),
|
||
granted_by=request.user,
|
||
status='active'
|
||
) for user in users_query
|
||
]
|
||
elif type == 'member':
|
||
users_query = User.objects.filter(
|
||
Q(role='admin') |
|
||
Q(department=department, role='leader') |
|
||
Q(department=department, group=group, role='member')
|
||
).exclude(id=request.user.id)
|
||
permissions = [
|
||
KBPermissionModel(
|
||
knowledge_base=knowledge_base,
|
||
user=user,
|
||
can_read=True,
|
||
can_edit=self._can_edit(type, user),
|
||
can_delete=self._can_delete(type, user),
|
||
granted_by=request.user,
|
||
status='active'
|
||
) for user in users_query
|
||
]
|
||
else: # private
|
||
users_query = User.objects.none()
|
||
|
||
if permissions:
|
||
KBPermissionModel.objects.bulk_create(permissions)
|
||
logger.info(f"{type}类型权限创建完成: {len(permissions)}条记录")
|
||
|
||
except Exception as e:
|
||
logger.error(f"权限创建失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
raise
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '知识库创建成功',
|
||
'data': {
|
||
'knowledge_base': serializer.data,
|
||
'external_id': knowledge_base.external_id
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建知识库失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'创建知识库失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def update(self, request, *args, **kwargs):
|
||
"""更新知识库"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(instance, user, 'edit'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有编辑权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
with transaction.atomic():
|
||
# 执行本地更新
|
||
serializer = self.get_serializer(instance, data=request.data, partial=True)
|
||
serializer.is_valid(raise_exception=True)
|
||
self.perform_update(serializer)
|
||
|
||
# 更新外部知识库
|
||
if instance.external_id:
|
||
try:
|
||
api_data = {
|
||
"name": serializer.validated_data.get('name', instance.name),
|
||
"desc": serializer.validated_data.get('desc', instance.desc),
|
||
"type": "0", # 保持与创建时一致
|
||
"meta": {}, # 保持与创建时一致
|
||
"documents": [] # 保持与创建时一致
|
||
}
|
||
|
||
response = requests.put(
|
||
f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}',
|
||
json=api_data,
|
||
headers={'Content-Type': 'application/json'},
|
||
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}")
|
||
|
||
api_response = response.json()
|
||
if not api_response.get('code') == 200:
|
||
raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}")
|
||
|
||
logger.info(f"外部知识库更新成功: {instance.external_id}")
|
||
|
||
except requests.exceptions.Timeout:
|
||
raise ExternalAPIError("请求超时,请稍后重试")
|
||
except requests.exceptions.RequestException as e:
|
||
raise ExternalAPIError(f"API请求失败: {str(e)}")
|
||
except Exception as e:
|
||
raise ExternalAPIError(f"更新外部知识库失败: {str(e)}")
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "知识库更新成功",
|
||
"data": serializer.data
|
||
})
|
||
|
||
except Http404:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "知识库不存在",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except ExternalAPIError as e:
|
||
logger.error(f"更新外部知识库失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": str(e),
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
except Exception as e:
|
||
logger.error(f"更新知识库失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"更新知识库失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def destroy(self, request, *args, **kwargs):
|
||
"""删除知识库"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(instance, user, 'delete'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有删除权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 删除外部知识库(如果存在)
|
||
external_delete_success = True
|
||
external_error_message = None
|
||
if instance.external_id:
|
||
try:
|
||
self._delete_external_dataset(instance.external_id)
|
||
logger.info(f"外部知识库删除成功: {instance.external_id}")
|
||
except ExternalAPIError as e:
|
||
# 记录错误但继续执行本地删除
|
||
external_delete_success = False
|
||
external_error_message = str(e)
|
||
logger.warning(f"外部知识库删除失败,将继续删除本地知识库: {str(e)}")
|
||
|
||
# 删除本地知识库
|
||
self.perform_destroy(instance)
|
||
logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}")
|
||
|
||
# 如果外部知识库删除失败,返回警告消息
|
||
if not external_delete_success:
|
||
return Response({
|
||
"code": 200,
|
||
"message": f"知识库已删除,但外部知识库删除失败: {external_error_message}",
|
||
"data": None
|
||
})
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "知识库删除成功",
|
||
"data": None
|
||
})
|
||
|
||
except Http404:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "知识库不存在",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except Exception as e:
|
||
logger.error(f"删除知识库失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"删除知识库失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _delete_external_dataset(self, external_id):
|
||
"""删除外部知识库"""
|
||
try:
|
||
if not external_id:
|
||
logger.warning("外部知识库ID为空,跳过删除")
|
||
return True
|
||
|
||
response = requests.delete(
|
||
f'{settings.API_BASE_URL}/api/dataset/{external_id}',
|
||
headers={'Content-Type': 'application/json'},
|
||
|
||
)
|
||
|
||
logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")
|
||
|
||
# 检查响应状态码
|
||
if response.status_code == 404:
|
||
logger.warning(f"外部知识库不存在: {external_id}")
|
||
return True # 如果知识库不存在,也视为删除成功
|
||
elif response.status_code not in [200, 204]:
|
||
logger.warning(f"删除外部知识库状态码异常: {response.status_code}, {response.text}")
|
||
return True # 即使状态码异常,也允许继续删除本地知识库
|
||
|
||
# 检查业务状态码
|
||
try:
|
||
api_response = response.json()
|
||
if api_response.get('code') != 200:
|
||
# 如果是因为ID不存在,也视为成功
|
||
if "不存在" in api_response.get('message', ''):
|
||
logger.warning(f"外部知识库ID不存在,视为删除成功: {external_id}")
|
||
return True
|
||
logger.warning(f"业务处理返回非200状态码: {api_response.get('code')}, {api_response.get('message')}")
|
||
return True # 不再抛出异常,允许本地删除继续
|
||
logger.info(f"外部知识库删除成功: {external_id}")
|
||
return True
|
||
except ValueError:
|
||
# 如果无法解析 JSON,但状态码是 200,也认为成功
|
||
logger.warning(f"外部知识库删除响应无法解析JSON,但状态码为200,视为成功: {external_id}")
|
||
return True
|
||
|
||
except requests.exceptions.Timeout:
|
||
logger.error(f"删除外部知识库超时: {external_id}")
|
||
# 不再抛出异常,允许本地删除继续
|
||
return False
|
||
except requests.exceptions.RequestException as e:
|
||
logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
|
||
# 不再抛出异常,允许本地删除继续
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}")
|
||
# 不再抛出异常,允许本地删除继续
|
||
return False
|
||
|
||
@action(detail=True, methods=['get'])
|
||
def permissions(self, request, pk=None):
|
||
"""获取用户对特定知识库的权限"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 使用统一的权限检查方法
|
||
permissions_data = {
|
||
"can_read": self.check_knowledge_base_permission(instance, user, 'read'),
|
||
"can_edit": self.check_knowledge_base_permission(instance, user, 'edit'),
|
||
"can_delete": self.check_knowledge_base_permission(instance, user, 'delete')
|
||
}
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "获取权限信息成功",
|
||
"data": {
|
||
"knowledge_base_id": instance.id,
|
||
"knowledge_base_name": instance.name,
|
||
"permissions": permissions_data
|
||
}
|
||
})
|
||
|
||
except Http404:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "知识库不存在",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except Exception as e:
|
||
logger.error(f"获取权限信息失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取权限信息失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
@action(detail=False, methods=['get'])
|
||
def summary(self, request):
|
||
"""获取所有可见知识库的概要信息(除了secret类型)"""
|
||
try:
|
||
user = request.user
|
||
|
||
# 基础查询:排除secret类型的知识库
|
||
queryset = KnowledgeBase.objects.exclude(type='secret')
|
||
|
||
summaries = []
|
||
for kb in queryset:
|
||
# 使用统一的权限判断方法
|
||
permissions = {
|
||
'can_read': self.check_knowledge_base_permission(kb, user, 'read'),
|
||
'can_edit': self.check_knowledge_base_permission(kb, user, 'edit'),
|
||
'can_delete': self.check_knowledge_base_permission(kb, user, 'delete')
|
||
}
|
||
|
||
# 获取知识库到期时间
|
||
explicit_permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=kb.id,
|
||
user=user,
|
||
status='active'
|
||
).first()
|
||
|
||
expires_at = None
|
||
if explicit_permission:
|
||
expires_at = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
|
||
elif kb.type == 'admin':
|
||
expires_at = None
|
||
|
||
# 只返回概要信息
|
||
summary = {
|
||
'id': str(kb.id),
|
||
'name': kb.name,
|
||
'desc': kb.desc,
|
||
'type': kb.type,
|
||
'department': kb.department,
|
||
'permissions': permissions,
|
||
'expires_at': expires_at
|
||
}
|
||
summaries.append(summary)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取知识库概要信息成功',
|
||
'data': summaries
|
||
})
|
||
|
||
except Exception as e:
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取知识库概要信息失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def retrieve(self, request, *args, **kwargs):
|
||
try:
|
||
# 获取知识库对象
|
||
instance = self.get_object()
|
||
serializer = self.get_serializer(instance)
|
||
data = serializer.data
|
||
|
||
# 获取用户
|
||
user = request.user
|
||
|
||
# 使用统一的权限判断方法
|
||
data['permissions'] = {
|
||
'can_read': self.check_knowledge_base_permission(instance, user, 'read'),
|
||
'can_edit': self.check_knowledge_base_permission(instance, user, 'edit'),
|
||
'can_delete': self.check_knowledge_base_permission(instance, user, 'delete')
|
||
}
|
||
|
||
# 添加知识库到期时间
|
||
explicit_permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=instance.id,
|
||
user=user,
|
||
status='active'
|
||
).first()
|
||
|
||
if explicit_permission:
|
||
data['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
|
||
else:
|
||
# 对于admin类型的知识库,设置expires_at为None
|
||
if instance.type == 'admin':
|
||
data['expires_at'] = None
|
||
else:
|
||
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
|
||
data['expires_at'] = None
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取知识库详情成功',
|
||
'data': data
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"获取知识库详情失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取知识库详情失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def search(self, request):
|
||
"""搜索知识库功能"""
|
||
try:
|
||
# 获取搜索关键字
|
||
keyword = request.query_params.get('keyword', '')
|
||
if not keyword:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "搜索关键字不能为空",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 获取分页参数
|
||
try:
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 10))
|
||
except ValueError:
|
||
page = 1
|
||
page_size = 10
|
||
|
||
# 构建搜索条件
|
||
query = Q(name__icontains=keyword) | \
|
||
Q(desc__icontains=keyword) | \
|
||
Q(department__icontains=keyword) | \
|
||
Q(group__icontains=keyword)
|
||
|
||
# 排除 secret 类型的知识库
|
||
queryset = KnowledgeBase.objects.filter(query).exclude(type='secret')
|
||
|
||
# 获取用户
|
||
user = request.user
|
||
|
||
# 获取用户所有有效的知识库权限
|
||
active_permissions = KBPermissionModel.objects.filter(
|
||
user=user,
|
||
status='active',
|
||
expires_at__gt=timezone.now()
|
||
).select_related('knowledge_base')
|
||
|
||
# 创建权限映射字典
|
||
permission_map = {
|
||
str(perm.knowledge_base.id): {
|
||
'can_read': perm.can_read,
|
||
'can_edit': perm.can_edit,
|
||
'can_delete': perm.can_delete
|
||
}
|
||
for perm in active_permissions
|
||
}
|
||
|
||
# 计算总数量
|
||
total = queryset.count()
|
||
|
||
# 分页处理
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
paginated_queryset = queryset[start:end]
|
||
|
||
# 序列化知识库数据
|
||
serializer = self.get_serializer(paginated_queryset, many=True)
|
||
data = serializer.data
|
||
|
||
# 处理每个知识库项的权限和返回内容
|
||
result_items = []
|
||
for item in data:
|
||
# 创建一个临时的知识库对象用于权限检查
|
||
temp_kb = KnowledgeBase(
|
||
id=item['id'],
|
||
type=item['type'],
|
||
department=item.get('department'),
|
||
group=item.get('group'),
|
||
user_id=item.get('user_id')
|
||
)
|
||
|
||
# 使用统一的权限判断方法
|
||
explicit_permission = KBPermissionModel.objects.filter(
|
||
knowledge_base_id=item['id'],
|
||
user=user,
|
||
status='active'
|
||
).first()
|
||
|
||
if explicit_permission:
|
||
kb_permissions = {
|
||
'can_read': explicit_permission.can_read,
|
||
'can_edit': explicit_permission.can_edit,
|
||
'can_delete': explicit_permission.can_delete
|
||
}
|
||
# 添加知识库的到期时间
|
||
item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None
|
||
else:
|
||
# 使用统一的权限判断方法
|
||
kb_permissions = {
|
||
'can_read': self.check_knowledge_base_permission(temp_kb, user, 'read'),
|
||
'can_edit': self.check_knowledge_base_permission(temp_kb, user, 'edit'),
|
||
'can_delete': self.check_knowledge_base_permission(temp_kb, user, 'delete')
|
||
}
|
||
# 对于admin类型的知识库,设置expires_at为None
|
||
if item['type'] == 'admin':
|
||
item['expires_at'] = None
|
||
else:
|
||
# 对于其他类型,如果没有显式权限记录,则表示没有到期时间
|
||
item['expires_at'] = None
|
||
|
||
# 添加权限信息
|
||
item['permissions'] = kb_permissions
|
||
|
||
# 根据权限返回不同级别的信息
|
||
if kb_permissions['can_read']:
|
||
result_items.append(item)
|
||
else:
|
||
# 无读取权限,只返回概要信息
|
||
summary_info = {
|
||
'id': item['id'],
|
||
'name': item['name'],
|
||
'type': item['type'],
|
||
'department': item.get('department'),
|
||
'permissions': kb_permissions
|
||
}
|
||
result_items.append(summary_info)
|
||
|
||
# 高亮搜索关键字
|
||
for item in result_items:
|
||
if 'name' in item and keyword.lower() in item['name'].lower():
|
||
highlighted = item['name'].replace(
|
||
keyword, f'<em class="highlight">{keyword}</em>'
|
||
)
|
||
item['highlighted_name'] = highlighted
|
||
|
||
# 确保desc不为None并且是字符串
|
||
if 'desc' in item and item.get('desc') is not None:
|
||
desc_text = str(item['desc']) # 转换为字符串以确保安全
|
||
if keyword.lower() in desc_text.lower():
|
||
highlighted = desc_text.replace(
|
||
keyword, f'<em class="highlight">{keyword}</em>'
|
||
)
|
||
item['highlighted_desc'] = highlighted
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "搜索知识库成功",
|
||
"data": {
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": page_size,
|
||
"keyword": keyword,
|
||
"items": result_items
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"搜索知识库失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"搜索知识库失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def change_type(self, request, pk=None):
|
||
"""修改知识库类型"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 使用统一的权限检查方法检查编辑权限
|
||
if not self.check_knowledge_base_permission(instance, user, 'edit'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有修改权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 其余代码保持不变...
|
||
|
||
# 获取新类型
|
||
new_type = request.data.get('type')
|
||
if not new_type:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "新类型不能为空",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证类型是否有效
|
||
valid_types = ['private', 'admin', 'secret', 'leader', 'member']
|
||
if new_type not in valid_types:
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"无效的知识库类型,可选值: {', '.join(valid_types)}",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 角色特定的类型限制
|
||
if new_type == 'leader' and not user.role == 'admin': # 组长且不是管理员
|
||
# 组长只能在private和member类型之间切换
|
||
if new_type not in ['private', 'member']:
|
||
return Response({
|
||
"code": 403,
|
||
"message": "组长只能将知识库设置为private或member类型",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 处理department和group字段
|
||
department = request.data.get('department')
|
||
group = request.data.get('group')
|
||
|
||
# 组长只能设置自己部门
|
||
if new_type == 'leader' and not user.role == 'admin':
|
||
if department and department != user.department:
|
||
return Response({
|
||
"code": 403,
|
||
"message": "组长只能为本部门设置知识库",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
# 如果未指定部门,强制设置为组长的部门
|
||
department = user.department
|
||
|
||
# 根据类型验证必填字段
|
||
if new_type == 'leader':
|
||
if not department:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "组长级知识库必须指定部门",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
if new_type == 'member':
|
||
if not department:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "成员级知识库必须指定部门",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
if not group:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "成员级知识库必须指定组",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 如果是admin或secret类型,清除department和group
|
||
if new_type in ['admin', 'secret']:
|
||
department = None
|
||
group = None
|
||
|
||
# 如果是private类型但未指定department和group,使用原值
|
||
if new_type == 'private':
|
||
if department is None:
|
||
department = instance.department
|
||
if group is None:
|
||
group = instance.group
|
||
|
||
# 更新知识库类型和相关字段
|
||
instance.type = new_type
|
||
instance.department = department
|
||
instance.group = group
|
||
instance.save()
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": f"知识库类型已更新为{new_type}",
|
||
"data": {
|
||
"id": instance.id,
|
||
"name": instance.name,
|
||
"type": instance.type,
|
||
"department": instance.department,
|
||
"group": instance.group
|
||
}
|
||
})
|
||
|
||
except Http404:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "知识库不存在",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except Exception as e:
|
||
logger.error(f"修改知识库类型失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"修改知识库类型失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _call_split_api_multiple(self, files):
|
||
"""调用文档分割API - 直接传递多个文件"""
|
||
try:
|
||
url = f'{settings.API_BASE_URL}/api/dataset/document/split'
|
||
|
||
# 准备请求数据 - 使用单个"file"字段
|
||
file_obj = files[0] # 先只处理第一个文件,便于排查问题
|
||
|
||
# 重置文件指针位置
|
||
if hasattr(file_obj, 'seek'):
|
||
file_obj.seek(0)
|
||
|
||
logger.info(f"准备上传文件: {file_obj.name}, 大小: {file_obj.size}字节, 类型: {file_obj.content_type}")
|
||
|
||
# 读取文件内容前100个字符进行记录
|
||
if hasattr(file_obj, 'read') and hasattr(file_obj, 'seek'):
|
||
content_preview = file_obj.read(100).decode('utf-8', errors='ignore')
|
||
logger.info(f"文件内容预览: {content_preview}")
|
||
file_obj.seek(0) # 重置文件指针
|
||
|
||
# 使用正确的字段名称发送请求
|
||
files_data = {'file': file_obj}
|
||
|
||
logger.info(f"调用分割API URL: {url}")
|
||
logger.info(f"请求字段: {list(files_data.keys())}")
|
||
|
||
# 发送请求
|
||
response = requests.post(
|
||
url,
|
||
files=files_data,
|
||
|
||
)
|
||
|
||
# 记录请求头和响应信息,方便排查问题
|
||
logger.info(f"请求头: {response.request.headers}")
|
||
logger.info(f"响应状态码: {response.status_code}")
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"分割API返回错误状态码: {response.status_code}, 响应: {response.text}")
|
||
return None
|
||
|
||
# 解析响应
|
||
result = response.json()
|
||
logger.info(f"分割API响应详情: {result}")
|
||
|
||
# 如果数据为空,可能是API期望的请求格式不对,尝试使用不同的字段名
|
||
if len(result.get('data', [])) == 0:
|
||
logger.warning("分割API返回的数据为空,尝试使用后备方案")
|
||
|
||
# 创建一个手动构建的文档结构
|
||
fallback_data = {
|
||
'code': 200,
|
||
'message': '成功',
|
||
'data': [
|
||
{
|
||
'name': file_obj.name,
|
||
'content': [
|
||
{
|
||
'title': '文档内容',
|
||
'content': '文件内容无法自动分割,请检查外部API。这是一个后备内容。'
|
||
}
|
||
]
|
||
}
|
||
]
|
||
}
|
||
logger.info("使用后备数据结构")
|
||
return fallback_data
|
||
|
||
return result
|
||
except Exception as e:
|
||
logger.error(f"调用分割API失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
|
||
# 创建一个后备响应
|
||
fallback_response = {
|
||
'code': 200,
|
||
'message': '成功',
|
||
'data': []
|
||
}
|
||
|
||
# 如果有文件,为每个文件创建一个基本文档结构
|
||
if files:
|
||
fallback_response['data'] = [
|
||
{
|
||
'name': file.name,
|
||
'content': [
|
||
{
|
||
'title': '文档内容',
|
||
'content': '文件内容无法自动分割,请检查API连接。'
|
||
}
|
||
]
|
||
} for file in files
|
||
]
|
||
|
||
logger.info("由于异常,返回后备响应")
|
||
return fallback_response
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def upload_document(self, request, pk=None):
|
||
"""上传文档到知识库 - 支持多文件上传"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 使用统一的权限检查方法
|
||
if not self.check_knowledge_base_permission(instance, user, 'edit'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有编辑权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 记录请求内容,方便调试
|
||
logger.info(f"请求内容: {request.data}")
|
||
logger.info(f"请求FILES: {request.FILES}")
|
||
|
||
# 获取上传的文件,尝试多种可能的字段名
|
||
files = []
|
||
# 尝试'files'字段(多文件)
|
||
if 'files' in request.FILES:
|
||
files = request.FILES.getlist('files')
|
||
# 尝试'file'字段(多文件)
|
||
elif 'file' in request.FILES:
|
||
files = request.FILES.getlist('file')
|
||
# 尝试files[]格式(常见于前端FormData)
|
||
elif any(key.startswith('files[') for key in request.FILES):
|
||
files = [file for key, file in request.FILES.items() if key.startswith('files[')]
|
||
# 尝试file[]格式
|
||
elif any(key.startswith('file[') for key in request.FILES):
|
||
files = [file for key, file in request.FILES.items() if key.startswith('file[')]
|
||
# 单个文件上传的情况
|
||
elif len(request.FILES) > 0:
|
||
# 如果有任何文件,就全部使用
|
||
files = list(request.FILES.values())
|
||
|
||
if not files:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "未找到上传文件,请确保表单字段名为'files'或'file'",
|
||
"data": {
|
||
"available_fields": list(request.FILES.keys())
|
||
}
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
logger.info(f"接收到 {len(files)} 个文件上传请求")
|
||
|
||
# 保存所有处理后的文档
|
||
saved_documents = []
|
||
failed_documents = []
|
||
|
||
# 验证knowledge_base的external_id是否有效
|
||
if not instance.external_id:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "知识库没有有效的external_id,请先创建知识库",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 先验证外部知识库是否存在
|
||
try:
|
||
# 简单的验证请求
|
||
verify_url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}'
|
||
verify_response = requests.get(verify_url)
|
||
if verify_response.status_code != 200:
|
||
logger.error(f"外部知识库不存在或无法访问: {instance.external_id}, 状态码: {verify_response.status_code}")
|
||
return Response({
|
||
"code": 404,
|
||
"message": f"外部知识库不存在或无法访问: {instance.external_id}",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
verify_data = verify_response.json()
|
||
if verify_data.get('code') != 200:
|
||
logger.error(f"验证外部知识库失败: {verify_data.get('message')}")
|
||
return Response({
|
||
"code": verify_data.get('code', 500),
|
||
"message": f"验证外部知识库失败: {verify_data.get('message', '未知错误')}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
logger.info(f"外部知识库验证成功: {instance.external_id}")
|
||
except Exception as e:
|
||
logger.error(f"验证外部知识库时出错: {str(e)}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"验证外部知识库时出错: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 逐个处理每个文件 - 避免一次性传多个文件导致外部API处理失败
|
||
for i, file in enumerate(files):
|
||
logger.info(f"处理第 {i+1} 个文件: {file.name}")
|
||
|
||
# 创建只包含当前文件的列表,传递给分割API
|
||
current_file = [file]
|
||
|
||
# 调用文档分割API
|
||
split_response = self._call_split_api_multiple(current_file)
|
||
if not split_response or split_response.get('code') != 200:
|
||
error_msg = f"文件 {file.name} 分割失败: {split_response.get('message', '未知错误') if split_response else '请求失败'}"
|
||
logger.error(error_msg)
|
||
failed_documents.append({
|
||
"name": file.name,
|
||
"error": error_msg
|
||
})
|
||
continue
|
||
|
||
# 处理分割后的文档
|
||
documents_data = split_response.get('data', [])
|
||
|
||
# 如果没有文档数据,使用一个基本结构
|
||
if not documents_data:
|
||
logger.warning(f"文件 {file.name} 未返回文档数据,创建基本文档结构")
|
||
documents_data = [{
|
||
'name': file.name,
|
||
'content': [{
|
||
'title': '文档内容',
|
||
'content': '文件内容无法自动分割,请检查文件格式。'
|
||
}]
|
||
}]
|
||
|
||
# 遍历所有分割后的文档
|
||
for doc in documents_data:
|
||
doc_name = doc.get('name', file.name)
|
||
doc_content = doc.get('content', [])
|
||
|
||
logger.info(f"处理文档: {doc_name}, 包含 {len(doc_content)} 个段落")
|
||
|
||
# 如果没有内容,添加一个默认段落
|
||
if not doc_content:
|
||
doc_content = [{
|
||
'title': '文档内容',
|
||
'content': '文件内容无法自动分割,请检查文件格式。'
|
||
}]
|
||
|
||
# 准备文档数据结构
|
||
doc_data = {
|
||
"name": doc_name,
|
||
"paragraphs": []
|
||
}
|
||
|
||
# 将所有段落添加到文档中
|
||
for paragraph in doc_content:
|
||
doc_data["paragraphs"].append({
|
||
"content": paragraph.get('content', ''),
|
||
"title": paragraph.get('title', ''),
|
||
"is_active": True,
|
||
"problem_list": []
|
||
})
|
||
|
||
# 调用文档上传API
|
||
upload_response = self._call_upload_api(instance.external_id, doc_data)
|
||
|
||
if upload_response and upload_response.get('code') == 200 and upload_response.get('data'):
|
||
# 上传成功,保存记录到数据库
|
||
document_id = upload_response['data']['id']
|
||
doc_record = KnowledgeBaseDocument.objects.create(
|
||
knowledge_base=instance,
|
||
document_id=document_id,
|
||
document_name=doc_name,
|
||
external_id=document_id,
|
||
uploader_name=user.name
|
||
)
|
||
|
||
saved_documents.append({
|
||
"id": str(doc_record.id),
|
||
"name": doc_record.document_name,
|
||
"external_id": doc_record.external_id
|
||
})
|
||
|
||
logger.info(f"文档 '{doc_name}' 上传成功,ID: {document_id}")
|
||
else:
|
||
# 上传失败,记录错误信息
|
||
error_msg = upload_response.get('message', '未知错误') if upload_response else '上传API调用失败'
|
||
logger.error(f"文档 '{doc_name}' 上传失败: {error_msg}")
|
||
failed_documents.append({
|
||
"name": doc_name,
|
||
"error": error_msg
|
||
})
|
||
|
||
# 返回结果
|
||
if saved_documents:
|
||
return Response({
|
||
"code": 200,
|
||
"message": f"文档上传完成,成功: {len(saved_documents)},失败: {len(failed_documents)}",
|
||
"data": {
|
||
"uploaded_count": len(saved_documents),
|
||
"failed_count": len(failed_documents),
|
||
"total_files": len(files),
|
||
"documents": saved_documents,
|
||
"failed_documents": failed_documents
|
||
}
|
||
})
|
||
else:
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"所有文档上传失败",
|
||
"data": {
|
||
"uploaded_count": 0,
|
||
"failed_count": len(failed_documents),
|
||
"total_files": len(files),
|
||
"documents": [],
|
||
"failed_documents": failed_documents
|
||
}
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
except Exception as e:
|
||
logger.error(f"文档上传失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"文档上传失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def _call_upload_api(self, external_id, doc_data):
|
||
"""调用文档上传API"""
|
||
try:
|
||
url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document'
|
||
logger.info(f"调用文档上传API: {url}")
|
||
|
||
# 记录请求数据,方便调试
|
||
logger.info(f"上传文档数据: 文档名={doc_data.get('name')}, 段落数={len(doc_data.get('paragraphs', []))}")
|
||
|
||
# 发送请求
|
||
response = requests.post(url, json=doc_data)
|
||
|
||
# 记录响应结果
|
||
logger.info(f"上传API响应状态码: {response.status_code}")
|
||
|
||
# 检查响应状态码
|
||
if response.status_code != 200:
|
||
logger.error(f"上传API HTTP错误: {response.status_code}, 响应: {response.text}")
|
||
return {
|
||
'code': response.status_code,
|
||
'message': f"上传失败,HTTP状态码: {response.status_code}",
|
||
'data': None
|
||
}
|
||
|
||
# 解析响应JSON
|
||
result = response.json()
|
||
logger.info(f"上传API响应内容: {result}")
|
||
|
||
# 检查业务状态码
|
||
if result.get('code') != 200:
|
||
error_msg = result.get('message', '未知错误')
|
||
logger.error(f"上传API业务错误: {error_msg}")
|
||
return {
|
||
'code': result.get('code', 500),
|
||
'message': error_msg,
|
||
'data': None
|
||
}
|
||
|
||
return result
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
logger.error(f"调用上传API网络错误: {str(e)}")
|
||
return {
|
||
'code': 500,
|
||
'message': f"网络请求错误: {str(e)}",
|
||
'data': None
|
||
}
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"解析API响应JSON失败: {str(e)}")
|
||
return {
|
||
'code': 500,
|
||
'message': f"解析响应数据失败: {str(e)}",
|
||
'data': None
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"调用上传API其他错误: {str(e)}")
|
||
return {
|
||
'code': 500,
|
||
'message': f"上传API调用失败: {str(e)}",
|
||
'data': None
|
||
}
|
||
|
||
def _call_delete_document_api(self, external_id, document_id):
|
||
"""调用文档删除API"""
|
||
try:
|
||
url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document/{document_id}'
|
||
response = requests.delete(url)
|
||
return response.json()
|
||
except Exception as e:
|
||
logger.error(f"调用删除API失败: {str(e)}")
|
||
return None
|
||
|
||
def _create_external_dataset(self, instance):
|
||
"""创建外部知识库"""
|
||
try:
|
||
api_data = {
|
||
"name": instance.name,
|
||
"desc": instance.desc,
|
||
"type": "0", # 添加必要的type字段
|
||
"meta": {}, # 添加必要的meta字段
|
||
"documents": [] # 初始化为空列表
|
||
}
|
||
|
||
response = requests.post(
|
||
f'{settings.API_BASE_URL}/api/dataset',
|
||
json=api_data,
|
||
headers={'Content-Type': 'application/json'},
|
||
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
raise ExternalAPIError(f"创建失败,状态码: {response.status_code}, 响应: {response.text}")
|
||
|
||
api_response = response.json()
|
||
if not api_response.get('code') == 200:
|
||
raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}")
|
||
|
||
dataset_id = api_response.get('data', {}).get('id')
|
||
if not dataset_id:
|
||
raise ExternalAPIError("响应数据中缺少dataset id")
|
||
|
||
return dataset_id
|
||
|
||
except requests.exceptions.Timeout:
|
||
raise ExternalAPIError("请求超时,请稍后重试")
|
||
except requests.exceptions.RequestException as e:
|
||
raise ExternalAPIError(f"API请求失败: {str(e)}")
|
||
except Exception as e:
|
||
raise ExternalAPIError(f"创建外部知识库失败: {str(e)}")
|
||
|
||
def _delete_external_dataset(self, external_id):
|
||
"""删除外部知识库"""
|
||
try:
|
||
if not external_id:
|
||
raise ExternalAPIError("外部知识库ID不能为空")
|
||
|
||
response = requests.delete(
|
||
f'{settings.API_BASE_URL}/api/dataset/{external_id}',
|
||
headers={'Content-Type': 'application/json'},
|
||
|
||
)
|
||
|
||
logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")
|
||
|
||
# 检查响应状态码
|
||
if response.status_code == 404:
|
||
logger.warning(f"外部知识库不存在: {external_id}")
|
||
return True # 如果知识库不存在,也视为删除成功
|
||
elif response.status_code not in [200, 204]:
|
||
raise ExternalAPIError(f"删除失败,状态码: {response.status_code}, 响应: {response.text}")
|
||
|
||
# 如果是 204 状态码,说明删除成功但无返回内容
|
||
if response.status_code == 204:
|
||
logger.info(f"外部知识库删除成功: {external_id}")
|
||
return True
|
||
|
||
# 如果是 200 状态码,检查响应内容
|
||
try:
|
||
api_response = response.json()
|
||
if api_response.get('code') != 200:
|
||
raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}")
|
||
logger.info(f"外部知识库删除成功: {external_id}")
|
||
return True
|
||
except ValueError:
|
||
# 如果无法解析 JSON,但状态码是 200,也认为成功
|
||
logger.warning(f"外部知识库删除响应无法解析JSON,但状态码为200,视为成功: {external_id}")
|
||
return True
|
||
|
||
except requests.exceptions.Timeout:
|
||
logger.error(f"删除外部知识库超时: {external_id}")
|
||
raise ExternalAPIError("请求超时,请稍后重试")
|
||
except requests.exceptions.RequestException as e:
|
||
logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
|
||
raise ExternalAPIError(f"API请求失败: {str(e)}")
|
||
except Exception as e:
|
||
logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}")
|
||
raise ExternalAPIError(f"删除外部知识库失败: {str(e)}")
|
||
|
||
@action(detail=True, methods=['get'])
|
||
def documents(self, request, pk=None):
|
||
"""获取知识库的文档列表"""
|
||
try:
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 权限检查
|
||
if not self.check_knowledge_base_permission(instance, user, 'read'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有查看权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 检查external_id是否存在
|
||
if not instance.external_id:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "知识库没有有效的external_id",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 调用外部API获取文档列表
|
||
try:
|
||
url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document'
|
||
response = requests.get(
|
||
url,
|
||
headers={'Content-Type': 'application/json'},
|
||
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"获取文档列表API调用失败: {response.status_code}, {response.text}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档列表失败: HTTP {response.status_code}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
result = response.json()
|
||
if result.get('code') != 200:
|
||
logger.error(f"获取文档列表业务失败: {result.get('message')}")
|
||
return Response({
|
||
"code": result.get('code', 500),
|
||
"message": result.get('message', '获取文档列表失败'),
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 同步外部文档到本地数据库
|
||
external_documents = result.get('data', [])
|
||
for doc in external_documents:
|
||
# 获取外部文档ID和名称
|
||
external_id = doc.get('id')
|
||
doc_name = doc.get('name')
|
||
|
||
if external_id and doc_name:
|
||
# 检查文档是否已存在
|
||
kb_doc, created = KnowledgeBaseDocument.objects.update_or_create(
|
||
knowledge_base=instance,
|
||
external_id=external_id,
|
||
defaults={
|
||
'document_id': external_id,
|
||
'document_name': doc_name,
|
||
'status': 'active' if doc.get('is_active', True) else 'deleted'
|
||
}
|
||
)
|
||
|
||
if created:
|
||
logger.info(f"同步创建文档: {doc_name}, ID: {external_id}")
|
||
else:
|
||
logger.info(f"同步更新文档: {doc_name}, ID: {external_id}")
|
||
|
||
# 获取最新的本地文档数据
|
||
documents = KnowledgeBaseDocument.objects.filter(
|
||
knowledge_base=instance,
|
||
status='active'
|
||
).order_by('-create_time')
|
||
|
||
# 构建响应数据
|
||
documents_data = [{
|
||
"id": str(doc.id),
|
||
"document_id": doc.document_id,
|
||
"name": doc.document_name,
|
||
"external_id": doc.external_id,
|
||
"created_at": doc.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||
# 添加外部API返回的额外信息
|
||
"char_length": next((d.get('char_length', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
|
||
"paragraph_count": next((d.get('paragraph_count', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
|
||
"is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True),
|
||
"uploader_name": doc.uploader_name
|
||
} for doc in documents]
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "获取文档列表成功",
|
||
"data": documents_data
|
||
})
|
||
|
||
except requests.exceptions.RequestException as e:
|
||
logger.error(f"获取文档列表网络异常: {str(e)}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档列表失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文档列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档列表失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=True, methods=['get'])
|
||
def document_content(self, request, pk=None):
|
||
"""获取文档内容 - 段落列表"""
|
||
try:
|
||
knowledge_base = self.get_object()
|
||
user = request.user
|
||
|
||
# 权限检查
|
||
if not self.check_knowledge_base_permission(knowledge_base, user, 'read'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有查看权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 获取文档ID
|
||
document_id = request.query_params.get('document_id')
|
||
if not document_id:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "缺少document_id参数",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证文档存在
|
||
document = KnowledgeBaseDocument.objects.filter(
|
||
knowledge_base=knowledge_base,
|
||
document_id=document_id,
|
||
status='active'
|
||
).first()
|
||
|
||
if not document:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "文档不存在或已删除",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 调用正确的外部API获取文档段落内容
|
||
try:
|
||
url = f'{settings.API_BASE_URL}/api/dataset/{knowledge_base.external_id}/document/{document.external_id}/paragraph'
|
||
response = requests.get(url)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"获取文档段落内容失败: {response.status_code}, {response.text}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档段落内容失败,状态码: {response.status_code}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
api_response = response.json()
|
||
if api_response.get('code') != 200:
|
||
logger.error(f"获取文档段落内容业务失败: {api_response.get('message')}")
|
||
return Response({
|
||
"code": api_response.get('code', 500),
|
||
"message": api_response.get('message', '获取文档段落内容失败'),
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
paragraphs = api_response.get('data', [])
|
||
|
||
# 直接返回外部API的段落数据
|
||
return Response({
|
||
"code": 200,
|
||
"message": "获取文档内容成功",
|
||
"data": {
|
||
"document_id": document_id,
|
||
"name": document.document_name,
|
||
"paragraphs": paragraphs
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"获取文档段落内容API调用失败: {str(e)}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档内容失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文档内容失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取文档内容失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=True, methods=['delete'])
|
||
def delete_document(self, request, pk=None):
|
||
"""删除知识库文档"""
|
||
try:
|
||
knowledge_base = self.get_object()
|
||
user = request.user
|
||
|
||
# 权限检查
|
||
if not self.check_knowledge_base_permission(knowledge_base, user, 'edit'):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "没有编辑权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 获取文档ID
|
||
document_id = request.query_params.get('document_id')
|
||
if not document_id:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "缺少document_id参数",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证文档存在
|
||
document = KnowledgeBaseDocument.objects.filter(
|
||
knowledge_base=knowledge_base,
|
||
document_id=document_id,
|
||
status='active'
|
||
).first()
|
||
|
||
if not document:
|
||
return Response({
|
||
"code": 404,
|
||
"message": "文档不存在或已删除",
|
||
"data": None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 调用外部API删除文档
|
||
try:
|
||
external_id = document.external_id
|
||
delete_result = self._call_delete_document_api(knowledge_base.external_id, external_id)
|
||
|
||
# 无论外部API结果如何,都更新本地状态
|
||
document.status = 'deleted'
|
||
document.save()
|
||
|
||
if delete_result and delete_result.get('code') != 200:
|
||
logger.warning(f"外部API删除文档失败,但本地标记已更新: {delete_result.get('message')}")
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "文档删除成功",
|
||
"data": {
|
||
"document_id": document_id,
|
||
"name": document.document_name
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"调用删除文档API失败: {str(e)}")
|
||
# 即使外部API调用失败,也更新本地状态
|
||
document.status = 'deleted'
|
||
document.save()
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "文档在系统中已标记为删除,但外部API调用失败",
|
||
"data": {
|
||
"document_id": document_id,
|
||
"name": document.document_name,
|
||
"error": str(e)
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"删除文档失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"删除文档失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
|
||
class PermissionViewSet(viewsets.ModelViewSet):
|
||
serializer_class = PermissionSerializer
|
||
permission_classes = [IsAuthenticated]
|
||
|
||
def can_manage_knowledge_base(self, user, knowledge_base):
|
||
"""检查用户是否是知识库的创建者"""
|
||
return str(knowledge_base.user_id) == str(user.id)
|
||
|
||
def get_queryset(self):
|
||
"""
|
||
获取权限申请列表:
|
||
1. applicant_id 是当前用户 (看到自己发起的申请)
|
||
2. approver_id 是当前用户 (看到自己需要审批的申请)
|
||
"""
|
||
user_id = str(self.request.user.id)
|
||
|
||
# 构建查询条件:申请人是自己 或 审批人是自己
|
||
query = Q(applicant_id=user_id) | Q(approver_id=user_id)
|
||
|
||
return Permission.objects.filter(query).select_related(
|
||
'knowledge_base',
|
||
'applicant',
|
||
'approver'
|
||
)
|
||
|
||
def list(self, request, *args, **kwargs):
|
||
"""获取权限申请列表,包含详细信息"""
|
||
try:
|
||
queryset = self.get_queryset()
|
||
user_id = str(request.user.id)
|
||
|
||
# 获取分页参数
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 10))
|
||
|
||
# 计算总数
|
||
total = queryset.count()
|
||
|
||
# 手动分页
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
permissions = queryset[start:end]
|
||
|
||
# 构建响应数据
|
||
data = []
|
||
for permission in permissions:
|
||
# 检查当前用户是否是申请人或审批人
|
||
if user_id not in [str(permission.applicant_id), str(permission.approver_id)]:
|
||
continue
|
||
|
||
# 构建响应数据
|
||
permission_data = {
|
||
'id': str(permission.id),
|
||
'knowledge_base': {
|
||
'id': str(permission.knowledge_base.id),
|
||
'name': permission.knowledge_base.name,
|
||
'type': permission.knowledge_base.type,
|
||
},
|
||
'applicant': {
|
||
'id': str(permission.applicant.id),
|
||
'username': permission.applicant.username,
|
||
'name': permission.applicant.name,
|
||
'department': permission.applicant.department,
|
||
},
|
||
'approver': {
|
||
'id': str(permission.approver.id) if permission.approver else '',
|
||
'username': permission.approver.username if permission.approver else '',
|
||
'name': permission.approver.name if permission.approver else '',
|
||
'department': permission.approver.department if permission.approver else '',
|
||
},
|
||
'permissions': permission.permissions,
|
||
'status': permission.status,
|
||
'created_at': permission.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None,
|
||
'response_message': permission.response_message or '',
|
||
# 添加角色标识,用于前端展示
|
||
'role': 'applicant' if str(permission.applicant_id) == user_id else 'approver'
|
||
}
|
||
|
||
data.append(permission_data)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取权限申请列表成功',
|
||
'data': {
|
||
'total': len(data), # 使用过滤后的实际数量
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'results': data
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取权限申请列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取权限申请列表失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def perform_create(self, serializer):
|
||
"""创建权限申请并发送通知给知识库创建者"""
|
||
# 获取知识库
|
||
# 获取知识库
|
||
knowledge_base = serializer.validated_data['knowledge_base']
|
||
|
||
# 检查是否是申请访问自己的知识库
|
||
if str(knowledge_base.user_id) == str(self.request.user.id):
|
||
raise ValidationError({
|
||
"code": 400,
|
||
"message": "您是此知识库的创建者,无需申请权限",
|
||
"data": None
|
||
})
|
||
# 获取知识库创建者作为审批者
|
||
approver = User.objects.get(id=knowledge_base.user_id)
|
||
|
||
# 验证权限请求
|
||
requested_permissions = serializer.validated_data.get('permissions', {})
|
||
expires_at = serializer.validated_data.get('expires_at')
|
||
|
||
if not any([requested_permissions.get('can_read'),
|
||
requested_permissions.get('can_edit'),
|
||
requested_permissions.get('can_delete')]):
|
||
raise ValidationError("至少需要申请一种权限(读/改/删)")
|
||
|
||
if not expires_at:
|
||
raise ValidationError("请指定权限到期时间")
|
||
|
||
# 检查是否已有未过期的权限申请
|
||
existing_request = Permission.objects.filter(
|
||
knowledge_base=knowledge_base,
|
||
applicant=self.request.user,
|
||
status='pending'
|
||
).first()
|
||
|
||
if existing_request:
|
||
raise ValidationError("您已有一个待处理的权限申请")
|
||
|
||
# 检查是否已有有效的权限
|
||
existing_permission = Permission.objects.filter(
|
||
knowledge_base=knowledge_base,
|
||
applicant=self.request.user,
|
||
status='approved',
|
||
expires_at__gt=timezone.now()
|
||
).first()
|
||
|
||
if existing_permission:
|
||
raise ValidationError("您已有此知识库的访问权限")
|
||
|
||
# 保存权限申请,设置审批者
|
||
permission = serializer.save(
|
||
applicant=self.request.user,
|
||
status='pending',
|
||
approver=approver # 创建时就设置审批者
|
||
)
|
||
|
||
# 获取权限类型字符串
|
||
permission_types = []
|
||
if requested_permissions.get('can_read'):
|
||
permission_types.append('读取')
|
||
if requested_permissions.get('can_edit'):
|
||
permission_types.append('编辑')
|
||
if requested_permissions.get('can_delete'):
|
||
permission_types.append('删除')
|
||
permission_str = '、'.join(permission_types)
|
||
|
||
# 发送通知给知识库创建者
|
||
owner = User.objects.get(id=knowledge_base.user_id)
|
||
self.send_notification(
|
||
user=owner,
|
||
title="新的权限申请",
|
||
content=f"用户 {self.request.user.name} 申请了知识库 '{knowledge_base.name}' 的{permission_str}权限",
|
||
notification_type="permission_request",
|
||
related_object_id=permission.id
|
||
)
|
||
|
||
def send_notification(self, user, title, content, notification_type, related_object_id):
|
||
"""发送通知"""
|
||
try:
|
||
notification = Notification.objects.create(
|
||
sender=self.request.user,
|
||
receiver=user,
|
||
title=title,
|
||
content=content,
|
||
type=notification_type,
|
||
related_resource=related_object_id,
|
||
)
|
||
|
||
# 通过WebSocket发送实时通知
|
||
channel_layer = get_channel_layer()
|
||
async_to_sync(channel_layer.group_send)(
|
||
f"notification_user_{user.id}",
|
||
{
|
||
"type": "notification",
|
||
"data": {
|
||
"id": str(notification.id),
|
||
"title": notification.title,
|
||
"content": notification.content,
|
||
"type": notification.type,
|
||
"created_at": notification.created_at.isoformat(),
|
||
"sender": {
|
||
"id": str(notification.sender.id),
|
||
"name": notification.sender.name
|
||
}
|
||
}
|
||
}
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"发送通知时发生错误: {str(e)}")
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def approve(self, request, pk=None):
|
||
try:
|
||
# 获取权限申请记录
|
||
permission = self.get_object()
|
||
|
||
# 只检查是否是知识库创建者
|
||
if not self.can_manage_knowledge_base(request.user, permission.knowledge_base):
|
||
logger.warning(f"用户 {request.user.username} 尝试审批知识库 {permission.knowledge_base.name} 的权限申请,但不是创建者")
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有知识库创建者可以审批此申请',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 获取审批意见
|
||
response_message = request.data.get('response_message', '')
|
||
|
||
with transaction.atomic():
|
||
# 更新权限申请状态
|
||
permission.status = 'approved'
|
||
permission.approver = request.user
|
||
permission.response_message = response_message
|
||
permission.save()
|
||
|
||
# 检查是否已存在权限记录
|
||
kb_permission = KBPermissionModel.objects.filter(
|
||
knowledge_base=permission.knowledge_base,
|
||
user=permission.applicant
|
||
).first()
|
||
|
||
if kb_permission:
|
||
# 更新现有权限
|
||
kb_permission.can_read = permission.permissions.get('can_read', False)
|
||
kb_permission.can_edit = permission.permissions.get('can_edit', False)
|
||
kb_permission.can_delete = permission.permissions.get('can_delete', False)
|
||
kb_permission.granted_by = request.user
|
||
kb_permission.status = 'active'
|
||
kb_permission.expires_at = permission.expires_at
|
||
kb_permission.save()
|
||
logger.info(f"更新知识库权限记录: {kb_permission.id}")
|
||
else:
|
||
# 创建新的权限记录
|
||
kb_permission = KBPermissionModel.objects.create(
|
||
knowledge_base=permission.knowledge_base,
|
||
user=permission.applicant,
|
||
can_read=permission.permissions.get('can_read', False),
|
||
can_edit=permission.permissions.get('can_edit', False),
|
||
can_delete=permission.permissions.get('can_delete', False),
|
||
granted_by=request.user,
|
||
status='active',
|
||
expires_at=permission.expires_at
|
||
)
|
||
logger.info(f"创建新的知识库权限记录: {kb_permission.id}")
|
||
|
||
# 发送通知给申请人
|
||
self.send_notification(
|
||
user=permission.applicant,
|
||
title="权限申请已通过",
|
||
content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已通过",
|
||
notification_type="permission_approved",
|
||
related_object_id=permission.id
|
||
)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '权限申请已批准',
|
||
'data': None
|
||
})
|
||
|
||
except Permission.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '权限申请不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except Exception as e:
|
||
logger.error(f"处理权限申请失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'处理权限申请失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def reject(self, request, pk=None):
|
||
"""拒绝权限申请"""
|
||
permission = self.get_object()
|
||
|
||
# 检查是否是知识库创建者
|
||
if str(permission.knowledge_base.user_id) != str(request.user.id):
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有知识库创建者可以审批此申请',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 检查申请是否已被处理
|
||
if permission.status != 'pending':
|
||
return Response({
|
||
'code': 400,
|
||
'message': '该申请已被处理',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证拒绝原因
|
||
response_message = request.data.get('response_message')
|
||
if not response_message:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '请填写拒绝原因',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 更新权限状态
|
||
permission.status = 'rejected'
|
||
permission.approver = request.user
|
||
permission.response_message = response_message
|
||
permission.save()
|
||
|
||
# 发送通知给申请人
|
||
self.send_notification(
|
||
user=permission.applicant,
|
||
title="权限申请已拒绝",
|
||
content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已被拒绝\n"
|
||
f"拒绝原因:{response_message}",
|
||
notification_type="permission_rejected",
|
||
related_object_id=permission.id
|
||
)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '权限申请已拒绝',
|
||
'data': PermissionSerializer(permission).data
|
||
})
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def extend(self, request, pk=None):
|
||
"""延长权限有效期"""
|
||
instance = self.get_object()
|
||
user = request.user
|
||
|
||
# 检查是否有权限延长
|
||
if not self.check_extend_permission(instance, user):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "您没有权限延长此权限",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
new_expires_at = request.data.get('expires_at')
|
||
if not new_expires_at:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "请设置新的过期时间",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
try:
|
||
with transaction.atomic():
|
||
# 更新权限申请表的过期时间
|
||
instance.expires_at = new_expires_at
|
||
instance.save()
|
||
|
||
# 同步更新知识库权限表的过期时间
|
||
kb_permission = KBPermissionModel.objects.get(
|
||
knowledge_base=instance.knowledge_base,
|
||
user=instance.applicant
|
||
)
|
||
kb_permission.expires_at = new_expires_at
|
||
kb_permission.save()
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "权限有效期延长成功",
|
||
"data": PermissionSerializer(instance).data
|
||
})
|
||
except Exception as e:
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"延长权限有效期失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
def check_extend_permission(self, permission, user):
|
||
"""检查是否有权限延长权限有效期"""
|
||
knowledge_base = permission.knowledge_base
|
||
|
||
# 私人知识库只有拥有者能延长
|
||
if knowledge_base.type == 'private':
|
||
return knowledge_base.owner == user
|
||
|
||
# 组长知识库只有管理员能延长
|
||
if knowledge_base.type == 'leader':
|
||
return user.role == 'admin'
|
||
|
||
# 组员知识库可以由管理员或本部门组长延长
|
||
if knowledge_base.type == 'member':
|
||
return (
|
||
user.role == 'admin' or
|
||
(user.role == 'leader' and user.department == knowledge_base.department)
|
||
)
|
||
|
||
return False
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def user_permissions(self, request):
|
||
"""获取指定用户的所有知识库权限"""
|
||
try:
|
||
# 获取用户名参数
|
||
username = request.query_params.get('username')
|
||
if not username:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '请提供用户名',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 获取用户
|
||
try:
|
||
target_user = User.objects.get(username=username)
|
||
except User.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'用户 {username} 不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 获取该用户的所有权限记录
|
||
permissions = KBPermissionModel.objects.filter(
|
||
user=target_user,
|
||
status='active'
|
||
).select_related('knowledge_base', 'granted_by')
|
||
|
||
# 构建响应数据
|
||
permissions_data = []
|
||
for perm in permissions:
|
||
perm_data = {
|
||
'id': str(perm.id),
|
||
'knowledge_base': {
|
||
'id': str(perm.knowledge_base.id),
|
||
'name': perm.knowledge_base.name,
|
||
'type': perm.knowledge_base.type,
|
||
'department': perm.knowledge_base.department,
|
||
'group': perm.knowledge_base.group
|
||
},
|
||
'permissions': {
|
||
'can_read': perm.can_read,
|
||
'can_edit': perm.can_edit,
|
||
'can_delete': perm.can_delete
|
||
},
|
||
'granted_by': {
|
||
'id': str(perm.granted_by.id) if perm.granted_by else None,
|
||
'username': perm.granted_by.username if perm.granted_by else None,
|
||
'name': perm.granted_by.name if perm.granted_by else None
|
||
},
|
||
'created_at': perm.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None,
|
||
'status': perm.status
|
||
}
|
||
permissions_data.append(perm_data)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取用户权限成功',
|
||
'data': {
|
||
'user': {
|
||
'id': str(target_user.id),
|
||
'username': target_user.username,
|
||
'name': target_user.name,
|
||
'department': target_user.department,
|
||
'role': target_user.role
|
||
},
|
||
'permissions': permissions_data
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户权限失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取用户权限失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def all_permissions(self, request):
|
||
"""管理员获取所有用户的知识库权限(不包括私有知识库)"""
|
||
try:
|
||
# 检查是否是管理员
|
||
if request.user.role != 'admin':
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有管理员可以查看所有权限',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 获取查询参数
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 10))
|
||
status_filter = request.query_params.get('status')
|
||
department = request.query_params.get('department')
|
||
kb_type = request.query_params.get('kb_type')
|
||
|
||
# 构建基础查询
|
||
queryset = KBPermissionModel.objects.filter(
|
||
~Q(knowledge_base__type='private')
|
||
).select_related(
|
||
'user',
|
||
'knowledge_base',
|
||
'granted_by'
|
||
)
|
||
|
||
# 应用过滤条件
|
||
if status_filter == 'active':
|
||
queryset = queryset.filter(
|
||
Q(expires_at__gt=timezone.now()) | Q(expires_at__isnull=True),
|
||
status='active'
|
||
)
|
||
elif status_filter == 'expired':
|
||
queryset = queryset.filter(
|
||
Q(expires_at__lte=timezone.now()) | Q(status='inactive')
|
||
)
|
||
|
||
if department:
|
||
queryset = queryset.filter(user__department=department)
|
||
|
||
if kb_type:
|
||
queryset = queryset.filter(knowledge_base__type=kb_type)
|
||
|
||
# 按用户分组处理数据
|
||
user_permissions = {}
|
||
for perm in queryset:
|
||
user_id = str(perm.user.id)
|
||
if user_id not in user_permissions:
|
||
user_permissions[user_id] = {
|
||
'user_info': {
|
||
'id': user_id,
|
||
'username': perm.user.username,
|
||
'name': getattr(perm.user, 'name', perm.user.username),
|
||
'department': getattr(perm.user, 'department', None),
|
||
'role': getattr(perm.user, 'role', None)
|
||
},
|
||
'permissions': [],
|
||
'stats': {
|
||
'total': 0,
|
||
'by_type': {
|
||
'admin': 0,
|
||
'secret': 0,
|
||
'leader': 0,
|
||
'member': 0
|
||
},
|
||
'by_permission': {
|
||
'read_only': 0,
|
||
'read_write': 0,
|
||
'full_access': 0
|
||
}
|
||
}
|
||
}
|
||
|
||
# 添加权限信息
|
||
perm_data = {
|
||
'id': str(perm.id),
|
||
'knowledge_base': {
|
||
'id': str(perm.knowledge_base.id),
|
||
'name': perm.knowledge_base.name,
|
||
'type': perm.knowledge_base.type,
|
||
'department': perm.knowledge_base.department,
|
||
'group': perm.knowledge_base.group,
|
||
'creator': {
|
||
'id': str(perm.knowledge_base.user_id),
|
||
'name': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'name', None),
|
||
'username': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'username', None)
|
||
}
|
||
},
|
||
'permissions': {
|
||
'can_read': perm.can_read,
|
||
'can_edit': perm.can_edit,
|
||
'can_delete': perm.can_delete
|
||
},
|
||
'granted_by': {
|
||
'id': str(perm.granted_by.id) if perm.granted_by else None,
|
||
'username': perm.granted_by.username if perm.granted_by else None,
|
||
'name': getattr(perm.granted_by, 'name', None) if perm.granted_by else None
|
||
},
|
||
'granted_at': perm.granted_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||
'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None,
|
||
'status': perm.status
|
||
}
|
||
|
||
user_permissions[user_id]['permissions'].append(perm_data)
|
||
|
||
# 更新统计信息
|
||
stats = user_permissions[user_id]['stats']
|
||
stats['total'] += 1
|
||
stats['by_type'][perm.knowledge_base.type] += 1
|
||
|
||
# 统计权限级别
|
||
if perm.can_delete:
|
||
stats['by_permission']['full_access'] += 1
|
||
elif perm.can_edit:
|
||
stats['by_permission']['read_write'] += 1
|
||
elif perm.can_read:
|
||
stats['by_permission']['read_only'] += 1
|
||
|
||
# 转换为列表并分页
|
||
users_list = list(user_permissions.values())
|
||
total = len(users_list)
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
paginated_users = users_list[start:end]
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取权限列表成功',
|
||
'data': {
|
||
'total': total,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'results': paginated_users
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取所有权限失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取所有权限失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=False, methods=['post'])
|
||
def update_permission(self, request):
|
||
"""管理员更新用户的知识库权限"""
|
||
try:
|
||
# 检查是否是管理员
|
||
if request.user.role != 'admin':
|
||
return Response({
|
||
'code': 403,
|
||
'message': '只有管理员可以直接修改权限',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 验证必要参数
|
||
user_id = request.data.get('user_id')
|
||
knowledge_base_id = request.data.get('knowledge_base_id')
|
||
permissions = request.data.get('permissions')
|
||
expires_at_str = request.data.get('expires_at')
|
||
|
||
if not all([user_id, knowledge_base_id, permissions]):
|
||
return Response({
|
||
'code': 400,
|
||
'message': '缺少必要参数',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证权限参数格式
|
||
required_permission_fields = ['can_read', 'can_edit', 'can_delete']
|
||
if not all(field in permissions for field in required_permission_fields):
|
||
return Response({
|
||
'code': 400,
|
||
'message': '权限参数格式错误,必须包含 can_read、can_edit、can_delete',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 获取用户和知识库
|
||
try:
|
||
user = User.objects.get(id=user_id)
|
||
knowledge_base = KnowledgeBase.objects.get(id=knowledge_base_id)
|
||
except User.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'用户ID {user_id} 不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
except KnowledgeBase.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': f'知识库ID {knowledge_base_id} 不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 检查知识库类型和用户角色的匹配
|
||
if knowledge_base.type == 'private' and str(knowledge_base.user_id) != str(user.id):
|
||
return Response({
|
||
'code': 403,
|
||
'message': '不能修改其他用户的私有知识库权限',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 处理过期时间
|
||
expires_at = None
|
||
if expires_at_str:
|
||
try:
|
||
# 将字符串转换为datetime对象
|
||
expires_at = timezone.datetime.strptime(
|
||
expires_at_str,
|
||
'%Y-%m-%dT%H:%M:%SZ'
|
||
)
|
||
# 确保时区感知
|
||
expires_at = timezone.make_aware(expires_at)
|
||
|
||
# 检查是否早于当前时间
|
||
if expires_at <= timezone.now():
|
||
return Response({
|
||
'code': 400,
|
||
'message': '过期时间不能早于或等于当前时间',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
except ValueError:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '过期时间格式错误,应为 ISO 格式 (YYYY-MM-DDThh:mm:ssZ)',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 根据用户角色限制权限
|
||
if user.role == 'member' and permissions.get('can_delete'):
|
||
return Response({
|
||
'code': 400,
|
||
'message': '普通成员不能获得删除权限',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 更新或创建权限记录
|
||
try:
|
||
with transaction.atomic():
|
||
permission, created = KBPermissionModel.objects.update_or_create(
|
||
user=user,
|
||
knowledge_base=knowledge_base,
|
||
defaults={
|
||
'can_read': permissions.get('can_read', False),
|
||
'can_edit': permissions.get('can_edit', False),
|
||
'can_delete': permissions.get('can_delete', False),
|
||
'granted_by': request.user,
|
||
'status': 'active',
|
||
'expires_at': expires_at
|
||
}
|
||
)
|
||
|
||
# 发送通知给用户
|
||
self.send_notification(
|
||
user=user,
|
||
title="知识库权限更新",
|
||
content=f"管理员已{created and '授予' or '更新'}您对知识库 '{knowledge_base.name}' 的权限",
|
||
notification_type="permission_updated",
|
||
related_object_id=permission.id
|
||
)
|
||
except IntegrityError as e:
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'数据库操作失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': f"{'创建' if created else '更新'}权限成功",
|
||
'data': {
|
||
'id': str(permission.id),
|
||
'user': {
|
||
'id': str(user.id),
|
||
'username': user.username,
|
||
'name': user.name,
|
||
'department': user.department,
|
||
'role': user.role
|
||
},
|
||
'knowledge_base': {
|
||
'id': str(knowledge_base.id),
|
||
'name': knowledge_base.name,
|
||
'type': knowledge_base.type,
|
||
'department': knowledge_base.department,
|
||
'group': knowledge_base.group
|
||
},
|
||
'permissions': {
|
||
'can_read': permission.can_read,
|
||
'can_edit': permission.can_edit,
|
||
'can_delete': permission.can_delete
|
||
},
|
||
'granted_by': {
|
||
'id': str(request.user.id),
|
||
'username': request.user.username,
|
||
'name': request.user.name
|
||
},
|
||
'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None,
|
||
'created': created
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"更新权限失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'更新权限失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
|
||
class NotificationViewSet(viewsets.ModelViewSet):
|
||
"""通知视图集"""
|
||
queryset = Notification.objects.all()
|
||
serializer_class = NotificationSerializer
|
||
permission_classes = [IsAuthenticated]
|
||
|
||
def get_queryset(self):
|
||
"""只返回用户自己的通知"""
|
||
return Notification.objects.filter(receiver=self.request.user)
|
||
|
||
@action(detail=True, methods=['post'])
|
||
def mark_as_read(self, request, pk=None):
|
||
"""标记通知为已读"""
|
||
notification = self.get_object()
|
||
notification.is_read = True
|
||
notification.save()
|
||
return Response({'status': 'marked as read'})
|
||
|
||
@action(detail=False, methods=['post'])
|
||
def mark_all_as_read(self, request):
|
||
"""标记所有通知为已读"""
|
||
self.get_queryset().update(is_read=True)
|
||
return Response({'status': 'all marked as read'})
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def unread_count(self, request):
|
||
"""获取未读通知数量"""
|
||
count = self.get_queryset().filter(is_read=False).count()
|
||
return Response({'unread_count': count})
|
||
|
||
@action(detail=False, methods=['get'])
|
||
def latest(self, request):
|
||
"""获取最新通知"""
|
||
notifications = self.get_queryset().filter(
|
||
is_read=False
|
||
).order_by('-created_at')[:5]
|
||
serializer = self.get_serializer(notifications, many=True)
|
||
return Response(serializer.data)
|
||
|
||
def perform_create(self, serializer):
|
||
"""创建通知时自动设置发送者"""
|
||
serializer.save(sender=self.request.user)
|
||
|
||
|
||
@method_decorator(csrf_exempt, name='dispatch')
|
||
class LoginView(APIView):
|
||
"""用户登录视图"""
|
||
authentication_classes = [] # 清空认证类
|
||
permission_classes = [AllowAny]
|
||
|
||
def post(self, request):
|
||
try:
|
||
username = request.data.get('username')
|
||
password = request.data.get('password')
|
||
|
||
# 参数验证
|
||
if not username or not password:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "请提供用户名和密码",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证用户
|
||
user = authenticate(request, username=username, password=password)
|
||
|
||
if user is not None:
|
||
# 获取或创建token
|
||
token, _ = Token.objects.get_or_create(user=user)
|
||
|
||
# 登录用户(可选)
|
||
login(request, user)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "登录成功",
|
||
"data": {
|
||
"id": str(user.id),
|
||
"username": user.username,
|
||
"email": user.email,
|
||
"name": user.name,
|
||
"role": user.role,
|
||
"department": user.department,
|
||
"group": user.group,
|
||
"token": token.key
|
||
}
|
||
})
|
||
else:
|
||
return Response({
|
||
"code": 401,
|
||
"message": "用户名或密码错误",
|
||
"data": None
|
||
}, status=status.HTTP_401_UNAUTHORIZED)
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
logger.error(f"登录失败: {str(e)}")
|
||
logger.error(f"错误类型: {type(e)}")
|
||
logger.error(f"错误堆栈: {traceback.format_exc()}")
|
||
|
||
return Response({
|
||
"code": 500,
|
||
"message": "登录失败,请稍后重试",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@method_decorator(csrf_exempt, name='dispatch')
|
||
class RegisterView(APIView):
|
||
"""用户注册视图"""
|
||
permission_classes = [AllowAny]
|
||
authentication_classes = [] # 清空认证类
|
||
|
||
def post(self, request):
|
||
try:
|
||
# 检查是否允许注册
|
||
from django.conf import settings
|
||
if not getattr(settings, 'ALLOW_REGISTRATION', True):
|
||
return Response({
|
||
"code": 403,
|
||
"message": "系统当前不允许注册新用户",
|
||
"data": None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
data = request.data
|
||
|
||
# 检查必填字段
|
||
required_fields = ['username', 'password', 'email', 'role', 'name']
|
||
for field in required_fields:
|
||
if not data.get(field):
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"缺少必填字段: {field}",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证角色
|
||
valid_roles = ['admin', 'leader', 'member']
|
||
roles_str = ', '.join(valid_roles) # 先构造角色字符串
|
||
if data['role'] not in valid_roles:
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"无效的角色,必须是: {roles_str}",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 检查用户名是否已存在
|
||
if User.objects.filter(username=data['username']).exists():
|
||
return Response({
|
||
"code": 400,
|
||
"message": "用户名已存在",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 检查邮箱是否已存在
|
||
if User.objects.filter(email=data['email']).exists():
|
||
return Response({
|
||
"code": 400,
|
||
"message": "邮箱已被注册",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证密码强度
|
||
if len(data['password']) < 8:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "密码长度必须至少为8位",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证邮箱格式
|
||
try:
|
||
validate_email(data['email'])
|
||
except ValidationError:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "邮箱格式不正确",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 创建用户
|
||
user = User.objects.create_user(
|
||
username=data['username'],
|
||
email=data['email'],
|
||
password=data['password'],
|
||
role=data['role'],
|
||
department=data.get('department'), # 不再强制要求部门
|
||
name=data['name'],
|
||
group=data.get('group'), # 不再强制要求小组
|
||
is_staff=False,
|
||
is_superuser=False
|
||
)
|
||
|
||
# 生成认证令牌
|
||
token, _ = Token.objects.get_or_create(user=user)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "注册成功",
|
||
"data": {
|
||
"id": str(user.id),
|
||
"username": user.username,
|
||
"email": user.email,
|
||
"role": user.role,
|
||
"department": user.department,
|
||
"name": user.name,
|
||
"group": user.group,
|
||
"token": token.key,
|
||
"created_at": user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
}, status=status.HTTP_201_CREATED)
|
||
|
||
except Exception as e:
|
||
print(f"注册失败: {str(e)}")
|
||
print(f"错误类型: {type(e)}")
|
||
print(f"错误堆栈: {traceback.format_exc()}")
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"注册失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@method_decorator(csrf_exempt, name='dispatch')
|
||
class LogoutView(APIView):
|
||
"""用户登出视图"""
|
||
permission_classes = [IsAuthenticated]
|
||
|
||
def post(self, request):
|
||
try:
|
||
# 删除用户的token
|
||
request.user.auth_token.delete()
|
||
# 执行django的登出
|
||
logout(request)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "登出成功",
|
||
"data": None
|
||
})
|
||
except Exception as e:
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"登出失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@api_view(['GET', 'PUT'])
|
||
@permission_classes([IsAuthenticated])
|
||
def user_profile(request):
|
||
"""
|
||
获取或更新当前登录用户信息
|
||
|
||
此接口与user_update的区别:
|
||
1. 任何已认证用户可访问
|
||
2. 仅能更新当前登录用户自己的信息
|
||
3. 不能修改角色等重要字段
|
||
4. 不需要指定用户ID,自动使用当前用户
|
||
"""
|
||
try:
|
||
if request.method == 'GET':
|
||
# 检查用户是否已认证
|
||
user = request.user
|
||
if not user.is_authenticated:
|
||
return Response({
|
||
'code': 401,
|
||
'message': '用户未认证',
|
||
'data': None
|
||
}, status=status.HTTP_401_UNAUTHORIZED)
|
||
|
||
data = {
|
||
'id': str(user.id),
|
||
'username': user.username,
|
||
'email': user.email,
|
||
'name': user.name,
|
||
'role': user.role,
|
||
'department': user.department,
|
||
'group': user.group,
|
||
'date_joined': user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取用户信息成功',
|
||
'data': data
|
||
})
|
||
|
||
elif request.method == 'PUT':
|
||
# 检查请求数据格式
|
||
try:
|
||
if not request.data:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '请求数据为空或格式错误',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
except Exception as data_error:
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'请求数据格式错误: {str(data_error)}。请确保提交的是有效的JSON格式数据,属性名必须使用双引号。',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
user = request.user
|
||
# 只允许更新特定字段
|
||
allowed_fields = ['email', 'name', 'phone', 'department', 'group']
|
||
updated_fields = []
|
||
|
||
for field in allowed_fields:
|
||
if field in request.data:
|
||
setattr(user, field, request.data[field])
|
||
updated_fields.append(field)
|
||
|
||
if updated_fields:
|
||
try:
|
||
user.save()
|
||
return Response({
|
||
'code': 200,
|
||
'message': f'用户信息更新成功,已更新字段: {", ".join(updated_fields)}',
|
||
'data': {
|
||
'id': str(user.id),
|
||
'username': user.username,
|
||
'email': user.email,
|
||
'name': user.name,
|
||
'role': user.role,
|
||
'department': user.department,
|
||
'group': user.group,
|
||
}
|
||
})
|
||
except Exception as save_error:
|
||
logger.error(f"保存用户数据失败: {str(save_error)}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'更新用户信息失败: {str(save_error)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
else:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '没有提供任何可更新的字段',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
else:
|
||
return Response({
|
||
'code': 405,
|
||
'message': f'不支持的请求方法: {request.method}',
|
||
'data': None
|
||
}, status=status.HTTP_405_METHOD_NOT_ALLOWED)
|
||
except Exception as e:
|
||
logger.error(f"处理用户信息请求失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'处理请求失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@csrf_exempt
|
||
@api_view(['POST'])
|
||
@permission_classes([IsAuthenticated])
|
||
def change_password(request):
|
||
"""修改密码"""
|
||
try:
|
||
old_password = request.data.get('old_password')
|
||
new_password = request.data.get('new_password')
|
||
|
||
# 验证参数
|
||
if not old_password or not new_password:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "请提供旧密码和新密码",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证旧密码
|
||
user = request.user
|
||
if not user.check_password(old_password):
|
||
return Response({
|
||
"code": 400,
|
||
"message": "旧密码错误",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证新密码长度
|
||
if len(new_password) < 8:
|
||
return Response({
|
||
"code": 400,
|
||
"message": "新密码长度必须至少为8位",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 修改密码
|
||
user.set_password(new_password)
|
||
user.save()
|
||
|
||
# 更新token
|
||
user.auth_token.delete()
|
||
token, _ = Token.objects.get_or_create(user=user)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "密码修改成功",
|
||
"data": {
|
||
"token": token.key
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"密码修改失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@csrf_exempt
|
||
@api_view(['POST'])
|
||
@permission_classes([AllowAny])
|
||
def user_register(request):
|
||
"""
|
||
[已弃用] 用户注册 - 请使用 RegisterView 类代替
|
||
|
||
此函数仅保留用于兼容性目的,新代码应该使用 /api/auth/register/ 接口
|
||
"""
|
||
# 打印弃用警告
|
||
logger.warning("使用已弃用的user_register函数,请改用RegisterView类")
|
||
|
||
try:
|
||
data = request.data
|
||
|
||
# 检查必填字段
|
||
required_fields = ['username', 'password', 'email', 'role', 'name']
|
||
for field in required_fields:
|
||
if not data.get(field):
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'缺少必填字段: {field}',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证角色
|
||
valid_roles = ['admin', 'leader', 'member']
|
||
if data['role'] not in valid_roles:
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'无效的角色,必须是: {", ".join(valid_roles)}',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 检查用户名是否已存在
|
||
if User.objects.filter(username=data['username']).exists():
|
||
return Response({
|
||
'code': 400,
|
||
'message': '用户名已存在',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 检查邮箱是否已存在
|
||
if User.objects.filter(email=data['email']).exists():
|
||
return Response({
|
||
'code': 400,
|
||
'message': '邮箱已被注册',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证密码强度
|
||
if len(data['password']) < 8:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '密码长度必须至少为8位',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 验证邮箱格式
|
||
try:
|
||
validate_email(data['email'])
|
||
except ValidationError:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '邮箱格式不正确',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 创建用户
|
||
user = User.objects.create_user(
|
||
username=data['username'],
|
||
email=data['email'],
|
||
password=data['password'],
|
||
role=data['role'],
|
||
department=data.get('department'), # 不再强制要求部门
|
||
name=data['name'],
|
||
group=data.get('group'), # 不再强制要求小组
|
||
is_staff=False,
|
||
is_superuser=False
|
||
)
|
||
|
||
# 生成认证令牌
|
||
token, _ = Token.objects.get_or_create(user=user)
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '注册成功',
|
||
'data': {
|
||
'id': str(user.id),
|
||
'username': user.username,
|
||
'email': user.email,
|
||
'role': user.role,
|
||
'department': user.department,
|
||
'name': user.name,
|
||
'group': user.group,
|
||
'token': token.key,
|
||
'created_at': user.date_joined.strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
}, status=status.HTTP_201_CREATED)
|
||
|
||
except Exception as e:
|
||
logger.error(f"注册失败: {str(e)}")
|
||
logger.error(f"错误类型: {type(e)}")
|
||
logger.error(f"错误堆栈: {traceback.format_exc()}")
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'注册失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@api_view(['GET'])
|
||
@permission_classes([IsAuthenticated])
|
||
def user_detail(request, pk):
|
||
"""获取用户详情"""
|
||
try:
|
||
# 尝试转换为 UUID,处理多种可能的格式
|
||
try:
|
||
if not isinstance(pk, uuid.UUID):
|
||
# 移除所有空格,以防万一
|
||
pk = pk.strip()
|
||
|
||
# 处理带连字符和不带连字符的格式
|
||
if '-' not in pk and len(pk) == 32:
|
||
# 转换没有连字符的UUID格式
|
||
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
|
||
pk = uuid.UUID(pk_with_hyphens)
|
||
else:
|
||
# 尝试直接转换
|
||
pk = uuid.UUID(pk)
|
||
except ValueError:
|
||
# 提供更详细的错误信息
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
user = get_object_or_404(User, pk=pk)
|
||
|
||
return Response({
|
||
"code": 200,
|
||
"message": "获取用户信息成功",
|
||
"data": {
|
||
"id": str(user.id),
|
||
"username": user.username,
|
||
"email": user.email,
|
||
"name": user.name,
|
||
"role": user.role,
|
||
"department": user.department,
|
||
"group": user.group
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"获取用户信息失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"获取用户信息失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@api_view(['PUT'])
|
||
@permission_classes([IsAdminUser])
|
||
def user_update(request, pk):
|
||
"""
|
||
管理员更新用户信息
|
||
|
||
此接口与user_profile的区别:
|
||
1. 仅管理员可访问
|
||
2. 可以更新任何用户的信息
|
||
3. 可以修改角色等重要字段
|
||
4. 需要在URL中指定用户ID
|
||
"""
|
||
try:
|
||
# 检查请求数据格式
|
||
try:
|
||
if not request.data:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '请求数据为空或格式错误',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
except Exception as data_error:
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'请求数据格式错误: {str(data_error)}。请确保提交的是有效的JSON格式数据,属性名必须使用双引号。',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
# 尝试转换为 UUID
|
||
try:
|
||
if not isinstance(pk, uuid.UUID):
|
||
pk = pk.strip()
|
||
|
||
# 处理带连字符和不带连字符的格式
|
||
if '-' not in pk and len(pk) == 32:
|
||
# 转换没有连字符的UUID格式
|
||
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
|
||
pk = uuid.UUID(pk_with_hyphens)
|
||
else:
|
||
# 尝试直接转换
|
||
pk = uuid.UUID(pk)
|
||
except ValueError:
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
try:
|
||
user = User.objects.get(pk=pk)
|
||
except User.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '用户不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 只允许更新特定字段
|
||
allowed_fields = ['email', 'role', 'department', 'group', 'is_active', 'phone', 'name']
|
||
updated_fields = []
|
||
|
||
for field in allowed_fields:
|
||
if field in request.data:
|
||
setattr(user, field, request.data[field])
|
||
updated_fields.append(field)
|
||
|
||
if not updated_fields:
|
||
return Response({
|
||
'code': 400,
|
||
'message': '没有提供任何可更新的字段',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
user.save()
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '用户信息更新成功',
|
||
'data': {
|
||
'id': str(user.id),
|
||
'username': user.username,
|
||
'email': user.email,
|
||
'name': user.name,
|
||
'role': user.role,
|
||
'department': user.department,
|
||
'group': user.group,
|
||
'is_active': user.is_active
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"更新用户信息失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'更新用户信息失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@api_view(['DELETE'])
|
||
@permission_classes([IsAdminUser])
|
||
def user_delete(request, pk):
|
||
"""删除用户"""
|
||
try:
|
||
# 尝试转换为 UUID
|
||
try:
|
||
if not isinstance(pk, uuid.UUID):
|
||
pk = pk.strip()
|
||
|
||
# 处理带连字符和不带连字符的格式
|
||
if '-' not in pk and len(pk) == 32:
|
||
# 转换没有连字符的UUID格式
|
||
pk_with_hyphens = f"{pk[0:8]}-{pk[8:12]}-{pk[12:16]}-{pk[16:20]}-{pk[20:]}"
|
||
pk = uuid.UUID(pk_with_hyphens)
|
||
else:
|
||
# 尝试直接转换
|
||
pk = uuid.UUID(pk)
|
||
except ValueError:
|
||
return Response({
|
||
"code": 400,
|
||
"message": f"无效的用户ID格式: {pk}。用户ID应为有效的UUID格式。",
|
||
"data": None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
try:
|
||
user = User.objects.get(pk=pk)
|
||
except User.DoesNotExist:
|
||
return Response({
|
||
'code': 404,
|
||
'message': '用户不存在',
|
||
'data': None
|
||
}, status=status.HTTP_404_NOT_FOUND)
|
||
|
||
# 检查是否试图删除管理员账户
|
||
if user.is_superuser or user.role == 'admin':
|
||
return Response({
|
||
'code': 403,
|
||
'message': '不允许删除管理员账户',
|
||
'data': None
|
||
}, status=status.HTTP_403_FORBIDDEN)
|
||
|
||
# 删除用户
|
||
username = user.username
|
||
user.delete()
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': f'用户 {username} 删除成功',
|
||
'data': None
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"删除用户失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'删除用户失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@csrf_exempt
|
||
@api_view(['POST'])
|
||
@permission_classes([IsAuthenticated])
|
||
def verify_token(request):
|
||
"""验证令牌有效性"""
|
||
try:
|
||
return Response({
|
||
"code": 200,
|
||
"message": "令牌有效",
|
||
"data": {
|
||
"is_valid": True,
|
||
"user": {
|
||
"id": str(request.user.id),
|
||
"username": request.user.username,
|
||
"email": request.user.email,
|
||
"name": request.user.name,
|
||
"role": request.user.role,
|
||
"department": request.user.department,
|
||
"group": request.user.group
|
||
}
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"验证令牌失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
"code": 500,
|
||
"message": f"验证失败: {str(e)}",
|
||
"data": None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@api_view(['GET'])
|
||
@permission_classes([IsAuthenticated])
|
||
def user_list(request):
|
||
"""获取用户列表"""
|
||
try:
|
||
# 获取查询参数
|
||
page = int(request.query_params.get('page', 1))
|
||
page_size = int(request.query_params.get('page_size', 20))
|
||
keyword = request.query_params.get('keyword', '')
|
||
|
||
# 根据用户角色获取不同范围的用户列表
|
||
user = request.user
|
||
base_query = User.objects.all()
|
||
|
||
if user.role == 'admin':
|
||
users_query = base_query
|
||
elif user.role == 'leader':
|
||
users_query = base_query.filter(department=user.department)
|
||
else:
|
||
users_query = base_query.filter(id=user.id)
|
||
|
||
# 添加关键字搜索
|
||
if keyword:
|
||
users_query = users_query.filter(
|
||
Q(username__icontains=keyword) |
|
||
Q(email__icontains=keyword) |
|
||
Q(name__icontains=keyword) |
|
||
Q(department__icontains=keyword)
|
||
)
|
||
|
||
# 计算总数
|
||
total = users_query.count()
|
||
|
||
# 分页
|
||
start = (page - 1) * page_size
|
||
end = start + page_size
|
||
users = users_query[start:end]
|
||
|
||
# 构造数据
|
||
user_data = []
|
||
for u in users:
|
||
user_data.append({
|
||
'id': str(u.id),
|
||
'username': u.username,
|
||
'email': u.email,
|
||
'name': u.name,
|
||
'role': u.role,
|
||
'department': u.department,
|
||
'group': u.group,
|
||
'is_active': u.is_active,
|
||
'date_joined': u.date_joined.strftime('%Y-%m-%d %H:%M:%S')
|
||
})
|
||
|
||
return Response({
|
||
'code': 200,
|
||
'message': '获取用户列表成功',
|
||
'data': {
|
||
'total': total,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'users': user_data
|
||
}
|
||
})
|
||
|
||
except ValueError as e:
|
||
# 处理参数转换错误
|
||
return Response({
|
||
'code': 400,
|
||
'message': f'参数错误: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户列表失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return Response({
|
||
'code': 500,
|
||
'message': f'获取用户列表失败: {str(e)}',
|
||
'data': None
|
||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|