From b64af631e3b14369bb13113885aa3463e2306a94 Mon Sep 17 00:00:00 2001 From: wanjia Date: Mon, 7 Apr 2025 12:41:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E6=A1=A3=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E5=8A=9F=E8=83=BD=E5=AE=8C=E5=96=84,=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E4=BC=9A=E8=AF=9D=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- role_based_system/settings.py | 592 +-- user_management/urls.py | 91 +- user_management/views.py | 8332 +++++++++++++++++++-------------- 3 files changed, 5046 insertions(+), 3969 deletions(-) diff --git a/role_based_system/settings.py b/role_based_system/settings.py index 9f61610c..ea6ebb79 100644 --- a/role_based_system/settings.py +++ b/role_based_system/settings.py @@ -1,296 +1,296 @@ -""" -Django settings for role_based_system project. - -Generated by 'django-admin startproject' using Django 5.1.5. - -For more information on this file, see -https://docs.djangoproject.com/en/5.1/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/5.1/ref/settings/ -""" - -import os -from pathlib import Path - -# API 配置 -API_BASE_URL = 'http://180.163.88.62:30331' - -DEPARTMENT_GROUPS = { - "技术部": ["开发组", "测试组", "运维组"], - "产品部": ["产品组", "设计组"], - "市场部": ["销售组", "推广组"], - "行政部": ["人事组", "财务组"] -} - -# Build paths inside the project like this: BASE_DIR / 'subdir'. -BASE_DIR = Path(__file__).resolve().parent.parent - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/5.1/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'django-insecure-5f*=0_2did)e(()n58=e#vd5gaf&y$thgt(h6&=p+wm1*r6mb=' - -# SECURITY WARNING: don't run with debug turned on in production! -# 开发配置 -# DEBUG = True - -ALLOWED_HOSTS = ['*'] # 仅在开发环境使用 -# 服务器配置 -DEBUG = False - -# ALLOWED_HOSTS = ['frptx.chiyong.fun', 'localhost', '127.0.0.1'] - -# Application definition - -INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'rest_framework', - 'rest_framework.authtoken', - 'channels', - 'user_management', - 'channels_redis', - 'corsheaders', -] - -MIDDLEWARE = [ - 'corsheaders.middleware.CorsMiddleware', # CORS中间件要放在CommonMiddleware前面 - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', # 确保这行没有被注释 - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'user_management.middleware.UserActivityMiddleware', -] - -ROOT_URLCONF = 'role_based_system.urls' - -TEMPLATES = [ - { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [BASE_DIR / 'templates'] - , - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }, -] - -WSGI_APPLICATION = 'role_based_system.wsgi.application' - - -# Database -# https://docs.djangoproject.com/en/5.1/ref/settings/#databases - - -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.mysql', - 'NAME': 'rolebasedfilemanagement', - 'USER': 'root', - 'PASSWORD': '123456', - 'HOST': 'localhost', - 'PORT': '3306', - 'OPTIONS': { - 'charset': 'utf8mb4', # 使用 utf8mb4 字符集 - 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'; SET innodb_strict_mode=1; SET NAMES utf8mb4;", - }, - 'TEST': { - 'CHARSET': 'utf8mb4', - 'COLLATION': 'utf8mb4_unicode_ci', - }, - } -} -# Password validation -# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - - -LANGUAGE_CODE = 'en-us' - -TIME_ZONE = 'UTC' - -USE_I18N = True - -USE_TZ = True - - - -STATIC_URL = 'static/' - -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' - -AUTH_USER_MODEL = 'user_management.User' - -REST_FRAMEWORK = { - 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'rest_framework.authentication.TokenAuthentication', - 'rest_framework.authentication.SessionAuthentication', - ], - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated', - ] -} - -# Channels 配置 -ASGI_APPLICATION = "role_based_system.asgi.application" - -# Channel Layers 配置 -CHANNEL_LAYERS = { - "default": { - "BACKEND": "channels_redis.core.RedisChannelLayer", - "CONFIG": { - "hosts": [("127.0.0.1", 6379)], - "capacity": 1500, # 消息队列容量 - "expiry": 10, # 消息过期时间(秒) - }, - }, -} - - -# CORS 配置 -CORS_ALLOW_ALL_ORIGINS = True -CORS_ALLOW_CREDENTIALS = True -CORS_ALLOWED_ORIGINS = [ - "http://localhost:8000", - "http://127.0.0.1:8000", - "http://124.222.236.141:8000", - "ws://localhost:8000", # 添加 WebSocket - "ws://127.0.0.1:8000", # 添加 WebSocket - "ws://124.222.236.141:8000", # 添加 WebSocket -] -# 允许的请求头 -CORS_ALLOWED_HEADERS = [ - 'accept', - 'accept-encoding', - 'authorization', - 'content-type', - 'dnt', - 'origin', - 'user-agent', - 'x-csrftoken', - 'x-requested-with', -] - -# 允许的请求方法 -CORS_ALLOWED_METHODS = [ - 'DELETE', - 'GET', - 'OPTIONS', - 'PATCH', - 'POST', - 'PUT', -] - - - -# WebSocket 允许的来源 -CSRF_TRUSTED_ORIGINS = [ - 'http://localhost:8000', - 'http://127.0.0.1:8000', - 'http://124.222.236.141:8000', - 'ws://localhost:8000', # 添加 WebSocket - 'ws://127.0.0.1:8000', # 添加 WebSocket - 'ws://124.222.236.141:8000', # 添加 WebSocket -] -# 服务器配置 -# 静态文件配置 -STATIC_URL = '/static/' -STATIC_ROOT = os.path.join(BASE_DIR, 'static') - -# 媒体文件配置 -MEDIA_URL = '/media/' -MEDIA_ROOT = os.path.join(BASE_DIR, 'media') - -# 文件上传配置 -FILE_UPLOAD_MAX_MEMORY_SIZE = 10 * 1024 * 1024 # 10MB -MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10MB - -# 日志配置 -LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - }, - 'file': { - 'class': 'logging.FileHandler', - 'filename': 'debug.log', - }, - }, - 'root': { - 'handlers': ['console', 'file'], - 'level': 'DEBUG', - }, - 'loggers': { - 'django.security.csrf': { - 'handlers': ['file'], - 'level': 'WARNING', - 'propagate': True, - }, - }, -} - -# CSRF 配置 -CSRF_COOKIE_SECURE = False # 开发环境设置为 False -CSRF_COOKIE_HTTPONLY = False -CSRF_USE_SESSIONS = False -CSRF_COOKIE_SAMESITE = 'Lax' -CSRF_TRUSTED_ORIGINS = [ - 'http://localhost:8000', - 'http://127.0.0.1:8000', - 'ws://localhost:8000', # 添加 WebSocket - 'ws://127.0.0.1:8000' # 添加 WebSocket -] - -# Session 配置 -SESSION_COOKIE_SECURE = False # 开发环境设置为 False -SESSION_COOKIE_HTTPONLY = True -SESSION_COOKIE_SAMESITE = 'Lax' - -# REST Framework 配置 -REST_FRAMEWORK = { - 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'rest_framework.authentication.TokenAuthentication', - 'rest_framework.authentication.SessionAuthentication', # WebSocket 需要 - ], - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated', - ], - 'DEFAULT_PARSER_CLASSES': [ - 'rest_framework.parsers.JSONParser', - 'rest_framework.parsers.FormParser', - 'rest_framework.parsers.MultiPartParser' - ], -} +""" +Django settings for role_based_system project. + +Generated by 'django-admin startproject' using Django 5.1.5. + +For more information on this file, see +https://docs.djangoproject.com/en/5.1/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/5.1/ref/settings/ +""" + +import os +from pathlib import Path + +# API 配置 +API_BASE_URL = 'http://81.69.223.133:48329' + +DEPARTMENT_GROUPS = { + "技术部": ["开发组", "测试组", "运维组"], + "产品部": ["产品组", "设计组"], + "市场部": ["销售组", "推广组"], + "行政部": ["人事组", "财务组"] +} + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/5.1/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'django-insecure-5f*=0_2did)e(()n58=e#vd5gaf&y$thgt(h6&=p+wm1*r6mb=' + +# SECURITY WARNING: don't run with debug turned on in production! +# 开发配置 +# DEBUG = True + +ALLOWED_HOSTS = ['*'] # 仅在开发环境使用 +# 服务器配置 +DEBUG = False + +# ALLOWED_HOSTS = ['frptx.chiyong.fun', 'localhost', '127.0.0.1'] + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'rest_framework', + 'rest_framework.authtoken', + 'channels', + 'user_management', + 'channels_redis', + 'corsheaders', +] + +MIDDLEWARE = [ + 'corsheaders.middleware.CorsMiddleware', # CORS中间件要放在CommonMiddleware前面 + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', # 确保这行没有被注释 + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'user_management.middleware.UserActivityMiddleware', +] + +ROOT_URLCONF = 'role_based_system.urls' + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [BASE_DIR / 'templates'] + , + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] + +WSGI_APPLICATION = 'role_based_system.wsgi.application' + + +# Database +# https://docs.djangoproject.com/en/5.1/ref/settings/#databases + + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.mysql', + 'NAME': 'rolebasedfilemanagement', + 'USER': 'root', + 'PASSWORD': '123456', + 'HOST': '127.0.0.1', + 'PORT': '3306', + 'OPTIONS': { + 'charset': 'utf8mb4', # 使用 utf8mb4 字符集 + 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'; SET innodb_strict_mode=1; SET NAMES utf8mb4;", + }, + 'TEST': { + 'CHARSET': 'utf8mb4', + 'COLLATION': 'utf8mb4_unicode_ci', + }, + } +} +# Password validation +# https://docs.djangoproject.com/en/5.1/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + }, +] + + + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = 'Asia/Shanghai' + +USE_I18N = True + +USE_TZ = False + + + +STATIC_URL = 'static/' + +DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' + +AUTH_USER_MODEL = 'user_management.User' + +REST_FRAMEWORK = { + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_framework.authentication.TokenAuthentication', + 'rest_framework.authentication.SessionAuthentication', + ], + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ] +} + +# Channels 配置 +ASGI_APPLICATION = "role_based_system.asgi.application" + +# Channel Layers 配置 +CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": [("127.0.0.1", 6379)], + "capacity": 1500, # 消息队列容量 + "expiry": 10, # 消息过期时间(秒) + }, + }, +} + + +# CORS 配置 +CORS_ALLOW_ALL_ORIGINS = True +CORS_ALLOW_CREDENTIALS = True +CORS_ALLOWED_ORIGINS = [ + "http://localhost:8000", + "http://127.0.0.1:8000", + "http://124.222.236.141:8000", + "ws://localhost:8000", # 添加 WebSocket + "ws://127.0.0.1:8000", # 添加 WebSocket + "ws://124.222.236.141:8000", # 添加 WebSocket +] +# 允许的请求头 +CORS_ALLOWED_HEADERS = [ + 'accept', + 'accept-encoding', + 'authorization', + 'content-type', + 'dnt', + 'origin', + 'user-agent', + 'x-csrftoken', + 'x-requested-with', +] + +# 允许的请求方法 +CORS_ALLOWED_METHODS = [ + 'DELETE', + 'GET', + 'OPTIONS', + 'PATCH', + 'POST', + 'PUT', +] + + + +# WebSocket 允许的来源 +CSRF_TRUSTED_ORIGINS = [ + 'http://localhost:8000', + 'http://127.0.0.1:8000', + 'http://124.222.236.141:8000', + 'ws://localhost:8000', # 添加 WebSocket + 'ws://127.0.0.1:8000', # 添加 WebSocket + 'ws://124.222.236.141:8000', # 添加 WebSocket +] +# 服务器配置 +# 静态文件配置 +STATIC_URL = '/static/' +STATIC_ROOT = os.path.join(BASE_DIR, 'static') + +# 媒体文件配置 +MEDIA_URL = '/media/' +MEDIA_ROOT = os.path.join(BASE_DIR, 'media') + +# 文件上传配置 +FILE_UPLOAD_MAX_MEMORY_SIZE = 10 * 1024 * 1024 # 10MB +MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10MB + +# 日志配置 +LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + }, + 'file': { + 'class': 'logging.FileHandler', + 'filename': 'debug.log', + }, + }, + 'root': { + 'handlers': ['console', 'file'], + 'level': 'DEBUG', + }, + 'loggers': { + 'django.security.csrf': { + 'handlers': ['file'], + 'level': 'WARNING', + 'propagate': True, + }, + }, +} + +# CSRF 配置 +CSRF_COOKIE_SECURE = False # 开发环境设置为 False +CSRF_COOKIE_HTTPONLY = False +CSRF_USE_SESSIONS = False +CSRF_COOKIE_SAMESITE = 'Lax' +CSRF_TRUSTED_ORIGINS = [ + 'http://localhost:8000', + 'http://127.0.0.1:8000', + 'ws://localhost:8000', # 添加 WebSocket + 'ws://127.0.0.1:8000' # 添加 WebSocket +] + +# Session 配置 +SESSION_COOKIE_SECURE = False # 开发环境设置为 False +SESSION_COOKIE_HTTPONLY = True +SESSION_COOKIE_SAMESITE = 'Lax' + +# REST Framework 配置 +REST_FRAMEWORK = { + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_framework.authentication.TokenAuthentication', + 'rest_framework.authentication.SessionAuthentication', # WebSocket 需要 + ], + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ], + 'DEFAULT_PARSER_CLASSES': [ + 'rest_framework.parsers.JSONParser', + 'rest_framework.parsers.FormParser', + 'rest_framework.parsers.MultiPartParser' + ], +} diff --git a/user_management/urls.py b/user_management/urls.py index 124b149e..f07c09bd 100644 --- a/user_management/urls.py +++ b/user_management/urls.py @@ -1,46 +1,45 @@ -from django.urls import path, include -from rest_framework.routers import DefaultRouter -from .views import ( - KnowledgeBaseViewSet, - PermissionViewSet, - NotificationViewSet, - verify_token, - user_list, - user_detail, - user_update, - user_delete, - change_password, - RegisterView, - LoginView, - LogoutView, - ChatHistoryViewSet -) - -# 创建路由器 -router = DefaultRouter() - -# 注册视图集 -router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-base') -router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-bases') -router.register(r'permissions', PermissionViewSet, basename='permission') -router.register(r'notifications', NotificationViewSet, basename='notification') -router.register(r'chat-history', ChatHistoryViewSet, basename='chat-history') - -# URL patterns -urlpatterns = [ - # API 路由 - path('', include(router.urls)), - - # 用户认证相关 - path('auth/register/', RegisterView.as_view(), name='register'), - path('auth/login/', LoginView.as_view(), name='login'), - path('auth/logout/', LogoutView.as_view(), name='logout'), - path('auth/verify-token/', verify_token, name='verify-token'), - path('auth/change-password/', change_password, name='change-password'), - - # 用户管理相关 - path('users/', user_list, name='user-list'), - path('users//', user_detail, name='user-detail'), - path('users//update/', user_update, name='user-update'), - path('users//delete/', user_delete, name='user-delete'), -] +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from .views import ( + KnowledgeBaseViewSet, + PermissionViewSet, + NotificationViewSet, + verify_token, + user_list, + user_detail, + user_update, + user_delete, + change_password, + RegisterView, + LoginView, + LogoutView, + ChatHistoryViewSet +) + +# 创建路由器 +router = DefaultRouter() + +# 注册视图集 +router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-bases') +router.register(r'permissions', PermissionViewSet, basename='permission') +router.register(r'notifications', NotificationViewSet, basename='notification') +router.register(r'chat-history', ChatHistoryViewSet, basename='chat-history') + +# URL patterns +urlpatterns = [ + # API 路由 + path('', include(router.urls)), + + # 用户认证相关 + path('auth/register/', RegisterView.as_view(), name='register'), + path('auth/login/', LoginView.as_view(), name='login'), + path('auth/logout/', LogoutView.as_view(), name='logout'), + path('auth/verify-token/', verify_token, name='verify-token'), + path('auth/change-password/', change_password, name='change-password'), + + # 用户管理相关 + path('users/', user_list, name='user-list'), + path('users//', user_detail, name='user-detail'), + path('users//update/', user_update, name='user-update'), + path('users//delete/', user_delete, name='user-delete'), +] diff --git a/user_management/views.py b/user_management/views.py index 50661a38..f735e129 100644 --- a/user_management/views.py +++ b/user_management/views.py @@ -1,3627 +1,4705 @@ -from rest_framework import viewsets, status -from rest_framework.decorators import action, api_view, permission_classes -from rest_framework.permissions import IsAuthenticated, AllowAny, IsAdminUser -from rest_framework.response import Response -from rest_framework.exceptions import APIException, PermissionDenied, ValidationError, NotFound -from rest_framework.authentication import TokenAuthentication -from django.utils import timezone -from django.db import connection -from django.db.models import Q, Max, Count, F -from datetime import timedelta, datetime -import mysql.connector -from django.contrib.auth import get_user_model, authenticate, login, logout -from channels.layers import get_channel_layer -from asgiref.sync import async_to_sync -from rest_framework.authtoken.models import Token -import requests -import json -from django.db import transaction -from django.core.exceptions import ObjectDoesNotExist -import sys -import random -import string -import time -import logging -import os -from rest_framework.test import APIRequestFactory -from django.contrib.contenttypes.models import ContentType -from django.contrib.contenttypes.fields import GenericForeignKey -from django.http import Http404, HttpResponse -from django.db import IntegrityError -from channels.exceptions import ChannelFull -from django.conf import settings -from django.shortcuts import get_object_or_404 -from django.db import models -from rest_framework.views import APIView -from django.core.validators import validate_email -# from django.core.exceptions import ValidationError -from django.views.decorators.csrf import csrf_exempt -from django.utils.decorators import method_decorator -import uuid -from rest_framework import serializers -import traceback -import requests -import json - - - -# 添加模型导入 -from .models import ( - User, - Data, # 替换原来的 AdminData, LeaderData, MemberData - Permission, # 替换原来的 DataPermission, TablePermission - ChatHistory, - KnowledgeBase, - Notification, - KnowledgeBasePermission as KBPermissionModel, - KnowledgeBaseDocument -) -from .serializers import ( - UserSerializer, - DataSerializer, # 需要更新 - PermissionSerializer, # 需要更新 - ChatHistorySerializer, - KnowledgeBaseSerializer, - KnowledgePermissionSerializer, # 添加这个导入 - NotificationSerializer -) -# 导入自定义权限类 -from .permissions import ResourceCRUDPermission, PermissionRequestPermission, DataPermission, KnowledgeBasePermission as KBPermissionClass -from .exceptions import ExternalAPIError - -# 获取正确的用户模型 -User = get_user_model() - -logger = logging.getLogger(__name__) -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(levelname)s] %(message)s', - handlers=[ - logging.StreamHandler() # 输出到控制台 - ] -) - - -class KnowledgeBasePermissionMixin: - """知识库权限管理混入类""" - - def _can_read(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): - """检查读取权限""" - try: - # 1. 检查显式权限表 - if knowledge_base_id: - permission = KBPermissionModel.objects.filter( - knowledge_base_id=knowledge_base_id, - user=user, - can_read=True, - status='active' - ).first() - if permission: - return True - - # 2. 检查角色权限 - # 私有知识库 - if type == 'private': - return str(user.id) == str(creator_id) - - # 成员级知识库 - if type == 'member': - return user.department == department - - # 部门级知识库 - if type == 'leader': - return (user.department == department and - user.role in ['leader', 'admin']) - - # 管理级知识库 - if type == 'admin': - return user.role == 'admin' - - return False - - except Exception as e: - logger.error(f"检查读取权限时出错: {str(e)}") - return False - - def _can_edit(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): - """检查编辑权限""" - try: - # 1. 检查显式权限表 - if knowledge_base_id: - permission = KBPermissionModel.objects.filter( - knowledge_base_id=knowledge_base_id, - user=user, - can_edit=True, - status='active' - ).first() - if permission: - return True - - # 2. 检查角色权限 - # 私有知识库 - if type == 'private': - return str(user.id) == str(creator_id) - - # 成员级知识库 - if type == 'member': - return (user.department == department and - user.role in ['leader', 'admin']) - - # 部门级知识库 - if type == 'leader': - return (user.department == department and - user.role in ['leader', 'admin']) - - # 管理级知识库 - if type == 'admin': - return user.role == 'admin' - - return False - - except Exception as e: - logger.error(f"检查编辑权限时出错: {str(e)}") - return False - - def _can_delete(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): - """检查删除权限""" - try: - # 1. 检查显式权限表 - if knowledge_base_id: - permission = KBPermissionModel.objects.filter( - knowledge_base_id=knowledge_base_id, - user=user, - can_delete=True, - status='active' - ).first() - if permission: - return True - - # 2. 检查角色权限 - # 私有知识库 - if type == 'private': - return str(user.id) == str(creator_id) - - # 成员级知识库 - if type == 'member': - return (user.department == department and - user.role == 'admin') - - # 部门级知识库 - if type == 'leader': - return (user.department == department and - user.role == 'admin') - - # 管理级知识库 - if type == 'admin': - return user.role == 'admin' - - return False - - except Exception as e: - logger.error(f"检查删除权限时出错: {str(e)}") - return False - - def check_knowledge_base_permission(self, knowledge_base, user, required_permission='read'): - """统一的知识库权限检查方法""" - if not knowledge_base: - return False - - permission_method = { - 'read': self._can_read, - 'edit': self._can_edit, - 'delete': self._can_delete - }.get(required_permission) - - if not permission_method: - return False - - return permission_method( - type=knowledge_base.type, - user=user, - department=knowledge_base.department, - group=knowledge_base.group, - creator_id=knowledge_base.user_id, - knowledge_base_id=knowledge_base.id - ) - - - -class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): - permission_classes = [IsAuthenticated] - queryset = ChatHistory.objects.all() - - def get_queryset(self): - """确保用户只能看到自己的未删除的聊天记录""" - return ChatHistory.objects.filter( - user=self.request.user, - is_deleted=False - ) - - def list(self, request): - """获取对话列表概览""" - try: - # 获取查询参数 - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - - # 获取所有对话的概览 - latest_chats = self.get_queryset().values( - 'conversation_id' - ).annotate( - latest_id=Max('id'), - message_count=Count('id'), - last_message=Max('created_at') - ).order_by('-last_message') - - # 计算分页 - total = latest_chats.count() - start = (page - 1) * page_size - end = start + page_size - chats = latest_chats[start:end] - - results = [] - for chat in chats: - # 获取最新消息记录 - latest_record = ChatHistory.objects.get(id=chat['latest_id']) - - # 从metadata中获取完整的知识库信息 - dataset_info = [] - if latest_record.metadata: - dataset_id_list = latest_record.metadata.get('dataset_id_list', []) - dataset_names = latest_record.metadata.get('dataset_names', []) - - # 如果有知识库ID列表 - if dataset_id_list: - # 如果同时有名称列表且长度匹配 - if dataset_names and len(dataset_names) == len(dataset_id_list): - dataset_info = [{ - 'id': str(id), - 'name': name - } for id, name in zip(dataset_id_list, dataset_names)] - else: - # 如果没有名称列表,则只返回ID - datasets = KnowledgeBase.objects.filter(id__in=dataset_id_list) - dataset_info = [{ - 'id': str(ds.id), - 'name': ds.name - } for ds in datasets] - - results.append({ - 'conversation_id': chat['conversation_id'], - 'message_count': chat['message_count'], - 'last_message': latest_record.content, - 'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'), - 'dataset_id_list': [ds['id'] for ds in dataset_info], # 添加完整的知识库ID列表 - 'datasets': dataset_info # 包含ID和名称的完整信息 - }) - - return Response({ - 'code': 200, - 'message': '获取成功', - 'data': { - 'total': total, - 'page': page, - 'page_size': page_size, - 'results': results - } - }) - - except Exception as e: - logger.error(f"获取聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取聊天记录失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def conversation_detail(self, request): - """获取特定对话的详细信息""" - try: - conversation_id = request.query_params.get('conversation_id') - if not conversation_id: - return Response({ - 'code': 400, - 'message': '缺少conversation_id参数', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 获取对话历史 - messages = self.get_queryset().filter( - conversation_id=conversation_id - ).order_by('created_at') - - if not messages.exists(): - return Response({ - 'code': 404, - 'message': '对话不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - # 获取知识库信息 - first_message = messages.first() - dataset_info = [] - if first_message and first_message.metadata: - if 'dataset_id_list' in first_message.metadata: - datasets = KnowledgeBase.objects.filter( - id__in=first_message.metadata['dataset_id_list'] - ) - # 过滤出用户有权限访问的知识库 - accessible_datasets = [ - ds for ds in datasets - if self.check_knowledge_base_permission(ds, request.user, 'read') - ] - dataset_info = [{ - 'id': str(ds.id), - 'name': ds.name, - 'type': ds.type - } for ds in accessible_datasets] - - return Response({ - 'code': 200, - 'message': '获取成功', - 'data': { - 'conversation_id': conversation_id, - 'datasets': dataset_info, - 'messages': [{ - 'id': str(msg.id), - 'role': msg.role, - 'content': msg.content, - 'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S') - } for msg in messages] - } - }) - - except Exception as e: - logger.error(f"获取对话详情失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取对话详情失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def available_datasets(self, request): - """获取用户可访问的知识库列表""" - try: - user = request.user - all_datasets = KnowledgeBase.objects.all() - - # 使用统一的权限检查方法 - accessible_datasets = [ - dataset for dataset in all_datasets - if self.check_knowledge_base_permission(dataset, user, 'read') - ] - - return Response({ - 'code': 200, - 'message': '获取成功', - 'data': [{ - 'id': str(ds.id), - 'name': ds.name, - 'type': ds.type, - 'department': ds.department, - 'description': ds.desc - } for ds in accessible_datasets] - }) - - except Exception as e: - logger.error(f"获取可用知识库列表失败: {str(e)}") - return Response({ - 'code': 500, - 'message': f'获取可用知识库列表失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def create(self, request): - """创建聊天记录""" - try: - data = request.data - - # 检查必填字段 - 支持单知识库或多知识库模式 - if 'question' not in data: - return Response({ - 'code': 400, - 'message': '缺少必填字段: question', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 检查知识库ID:支持dataset_id或dataset_id_list格式 - dataset_ids = [] - if 'dataset_id' in data: - dataset_id = data['dataset_id'] - # 直接使用标准UUID格式 - dataset_ids.append(str(dataset_id)) - elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)): - # 处理可能的字符串格式 - if isinstance(data['dataset_id_list'], str): - try: - # 尝试解析JSON字符串 - dataset_list = json.loads(data['dataset_id_list']) - if isinstance(dataset_list, list): - dataset_ids = [str(id) for id in dataset_list] - except json.JSONDecodeError: - # 如果解析失败,可能是单个ID - dataset_ids = [str(data['dataset_id_list'])] - else: - # 如果已经是列表,直接使用标准UUID格式 - dataset_ids = [str(id) for id in data['dataset_id_list']] - else: - return Response({ - 'code': 400, - 'message': '缺少必填字段: dataset_id 或 dataset_id_list', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - if not dataset_ids: - return Response({ - 'code': 400, - 'message': '至少需要提供一个知识库ID', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证所有知识库并收集external_ids - external_id_list = [] - user = request.user - knowledge_bases = [] # 存储所有知识库对象 - - for kb_id in dataset_ids: - try: - knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first() - if not knowledge_base: - return Response({ - 'code': 404, - 'message': f'知识库不存在: {kb_id}', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - knowledge_bases.append(knowledge_base) - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(knowledge_base, user, 'read'): - return Response({ - 'code': 403, - 'message': f'无权访问知识库: {knowledge_base.name}', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 添加知识库的external_id到列表 - if knowledge_base.external_id: - external_id_list.append(knowledge_base.external_id) - else: - logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id") - - except Exception as e: - return Response({ - 'code': 400, - 'message': f'处理知识库ID出错: {str(e)}', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - if not external_id_list: - return Response({ - 'code': 400, - 'message': '没有有效的知识库external_id', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 获取或创建对话ID - conversation_id = data.get('conversation_id') - - # 如果没有提供 conversation_id,根据知识库组合生成新的ID - if not conversation_id: - # 对知识库ID列表排序以确保相同组合生成相同的hash - sorted_kb_ids = sorted(dataset_ids) - # 使用知识库ID组合生成唯一的conversation_id - conversation_id = str(uuid.uuid5( - uuid.NAMESPACE_DNS, - '-'.join(sorted_kb_ids) - )) - logger.info(f"为知识库组合 {sorted_kb_ids} 生成新的conversation_id: {conversation_id}") - else: - logger.info(f"使用现有conversation_id: {conversation_id}") - - # 调用外部API获取答案 (传递多个knowledge base的external_id) - answer = self._get_answer_from_external_api( - dataset_external_id_list=external_id_list, - question=data['question'] - ) - - if not answer: - return Response({ - 'code': 500, - 'message': '获取AI回答失败', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 准备完整的metadata - metadata = { - 'model_id': data.get('model_id', '58c5deb4-f2e2-11ef-9a1b-0242ac120009'), - 'dataset_id_list': [str(id) for id in dataset_ids], - 'dataset_external_id_list': [str(id) for id in external_id_list], - 'dataset_names': [kb.name for kb in knowledge_bases] # 添加知识库名称列表 - } - - # 创建用户问题记录 - question_record = ChatHistory.objects.create( - user=request.user, - knowledge_base=knowledge_bases[0], # 仍然需要一个主知识库,使用第一个 - conversation_id=str(conversation_id), - role='user', - content=data['question'], - metadata=metadata - ) - - # 创建AI回答记录 - answer_record = ChatHistory.objects.create( - user=request.user, - knowledge_base=knowledge_bases[0], # 仍然需要一个主知识库,使用第一个 - conversation_id=str(conversation_id), - parent_id=str(question_record.id), - role='assistant', - content=answer, - metadata=metadata - ) - - # 返回完整的响应 - return Response({ - 'code': 200, - 'message': '创建成功', - 'data': { - 'id': str(answer_record.id), - 'conversation_id': str(conversation_id), - 'dataset_id_list': [str(id) for id in dataset_ids], - 'dataset_names': [kb.name for kb in knowledge_bases], # 返回所有知识库名称 - 'role': 'assistant', - 'content': answer_record.content, - 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S') - } - }, status=status.HTTP_201_CREATED) - - except Exception as e: - logger.error(f"创建聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'创建聊天记录失败: {str(e)}', - 'data': None - }, status.HTTP_500_INTERNAL_SERVER_ERROR) - - def _get_answer_from_external_api(self, dataset_external_id_list, question): - """调用外部API获取AI回答""" - try: - # 确保所有ID都是字符串 - dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list] - - logger.info(f"准备调用外部API,知识库ID列表: {dataset_external_ids}") - - # 第一个API调用创建聊天 - chat_request_data = { - "id": "65031f4d-c86d-430e-8089-d8ff2731a837", - "model_id": "58c5deb4-f2e2-11ef-9a1b-0242ac120009", - "dataset_id_list": dataset_external_ids, - "multiple_rounds_dialogue": False, - "dataset_setting": { - "top_n": 10, - "similarity": "0.3", - "max_paragraph_char_number": 10000, - "search_mode": "blend", - "no_references_setting": { - "value": "{question}", - "status": "ai_questioning" - } - }, - "model_setting": { - "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}" - }, - "problem_optimization": False - } - - logger.info(f"发送创建聊天请求:{settings.API_BASE_URL}/api/application/chat/open") - - try: - # 测试JSON序列化,提前捕获可能的错误 - json_data = json.dumps(chat_request_data) - logger.debug(f"请求数据序列化成功,长度: {len(json_data)}") - except TypeError as e: - logger.error(f"JSON序列化失败: {str(e)}") - return None - - chat_response = requests.post( - url=f"{settings.API_BASE_URL}/api/application/chat/open", - json=chat_request_data, - headers={"Content-Type": "application/json"}, - timeout=30 - ) - - logger.info(f"API响应状态码: {chat_response.status_code}") - - if chat_response.status_code != 200: - logger.error(f"外部API调用失败: {chat_response.text}") - return None - - chat_data = chat_response.json() - logger.debug(f"API响应数据: {chat_data}") - - if chat_data.get('code') != 200 or not chat_data.get('data'): - logger.error(f"外部API返回错误: {chat_data}") - return None - - chat_id = chat_data['data'] - logger.info(f"聊天创建成功,chat_id: {chat_id}") - - # 第二个API调用发送消息 - message_request_data = { - "message": question, - "re_chat": False, - "stream": True - } - - logger.info(f"发送聊天消息请求: {settings.API_BASE_URL}/api/application/chat_message/{chat_id}") - message_response = requests.post( - url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}", - json=message_request_data, - headers={"Content-Type": "application/json"}, - stream=True, - timeout=60 - ) - - if message_response.status_code != 200: - logger.error(f"外部API聊天消息调用失败: {message_response.status_code}, {message_response.text}") - return None - - # 拼接流式响应 - 修复SSE格式解析 - full_content = "" - try: - for line in message_response.iter_lines(): - if line: - line_text = line.decode('utf-8') - # 处理SSE格式 (data: {...}) - if line_text.startswith('data: '): - json_str = line_text[6:] # 去掉 "data: " 前缀 - logger.debug(f"处理SSE数据: {json_str}") - try: - chunk = json.loads(json_str) - if 'content' in chunk: - content_part = chunk['content'] - full_content += content_part - logger.debug(f"追加内容: '{content_part}'") - if chunk.get('is_end', False): - logger.debug("收到结束标记") - except json.JSONDecodeError as e: - logger.error(f"JSON解析错误: {str(e)}, 原始数据: {json_str}") - else: - logger.debug(f"收到非SSE格式数据: {line_text}") - except Exception as e: - logger.error(f"处理流式响应出错: {str(e)}") - if full_content: - logger.info(f"已接收部分内容: {len(full_content)} 字符") - return full_content.strip() - return None - - logger.info(f"聊天回答拼接完成,总长度: {len(full_content)}") - return full_content.strip() if full_content else "未能获取到有效回答" - - except Exception as e: - logger.error(f"调用外部API获取回答失败: {str(e)}") - logger.error(traceback.format_exc()) - return None - - def update(self, request, pk=None): - """更新聊天记录""" - try: - record = self.get_queryset().filter(id=pk).first() - - if not record: - return Response({ - 'code': 404, - 'message': '记录不存在或无权限', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - data = request.data - updateable_fields = ['content', 'metadata'] - - if 'content' in data: - record.content = data['content'] - - if 'metadata' in data: - current_metadata = record.metadata or {} - current_metadata.update(data['metadata']) - record.metadata = current_metadata - - record.save() - - return Response({ - 'code': 200, - 'message': '更新成功', - 'data': { - 'id': record.id, - 'conversation_id': record.conversation_id, - 'role': record.role, - 'content': record.content, - 'metadata': record.metadata, - 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S') - } - }) - - except Exception as e: - logger.error(f"更新聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'更新聊天记录失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def destroy(self, request, pk=None): - """删除聊天记录(软删除)""" - try: - record = self.get_queryset().filter(id=pk).first() - - if not record: - return Response({ - 'code': 404, - 'message': '记录不存在或无权限', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - record.soft_delete() - - return Response({ - 'code': 200, - 'message': '删除成功', - 'data': None - }) - - except Exception as e: - logger.error(f"删除聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'删除聊天记录失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def search(self, request): - """搜索聊天记录""" - try: - # 获取查询参数 - keyword = request.query_params.get('keyword', '').strip() - dataset_id = request.query_params.get('dataset_id') - start_date = request.query_params.get('start_date') - end_date = request.query_params.get('end_date') - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - - # 基础查询 - query = self.get_queryset() - - # 添加过滤条件 - if keyword: - query = query.filter( - Q(content__icontains=keyword) | - Q(knowledge_base__name__icontains=keyword) - ) - - if dataset_id: - # 检查知识库权限 - knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first() - if knowledge_base and not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'): - return Response({ - 'code': 403, - 'message': '无权访问该知识库', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - query = query.filter(knowledge_base__id=dataset_id) - if start_date: - query = query.filter(created_at__gte=start_date) - if end_date: - query = query.filter(created_at__lte=end_date) - - # 计算分页 - total = query.count() - start = (page - 1) * page_size - end = start + page_size - - # 获取分页数据 - records = query.order_by('-created_at')[start:end] - - # 序列化数据 - results = [] - for record in records: - result = { - 'id': record.id, - 'conversation_id': record.conversation_id, - 'dataset_id': str(record.knowledge_base.id), - 'dataset_name': record.knowledge_base.name, - 'role': record.role, - 'content': record.content, - 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'metadata': record.metadata - } - - if keyword: - result['highlights'] = { - 'content': self._highlight_keyword(record.content, keyword) - } - - results.append(result) - - return Response({ - 'code': 200, - 'message': '搜索成功', - 'data': { - 'total': total, - 'page': page, - 'page_size': page_size, - 'results': results - } - }) - - except Exception as e: - logger.error(f"搜索聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'搜索失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def export(self, request): - """导出聊天记录为Excel文件""" - try: - # 获取查询参数 - conversation_id = request.query_params.get('conversation_id') - dataset_id = request.query_params.get('dataset_id') - history_days = request.query_params.get('history_days', '7') # 默认导出最近7天 - - # 至少需要一个筛选条件 - if not conversation_id and not dataset_id: - return Response({ - 'code': 400, - 'message': '需要提供conversation_id或dataset_id参数', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证权限 - user = request.user - if dataset_id: - knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first() - if not knowledge_base: - return Response({ - 'code': 404, - 'message': '知识库不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(knowledge_base, user, 'read'): - return Response({ - 'code': 403, - 'message': '无权访问该知识库', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 查询确认有聊天记录存在 - query = self.get_queryset() - if conversation_id: - records = query.filter(conversation_id=conversation_id) - elif dataset_id: - records = query.filter(knowledge_base__id=dataset_id) - - if not records.exists(): - return Response({ - 'code': 404, - 'message': '未找到相关对话记录', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - # 调用外部API导出Excel文件 - 使用GET请求 - application_id = "65031f4d-c86d-430e-8089-d8ff2731a837" # 固定值 - export_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat/export?history_day={history_days}" - - logger.info(f"发送导出请求:{export_url}") - - export_response = requests.get( - url=export_url, - timeout=60, - stream=True # 使用流式传输处理大文件 - ) - - # 检查响应状态 - if export_response.status_code != 200: - logger.error(f"导出API调用失败: {export_response.status_code}, {export_response.text}") - return Response({ - 'code': 500, - 'message': '导出失败,外部服务返回错误', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 创建响应对象并设置文件下载头 - response = HttpResponse( - content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' - ) - response['Content-Disposition'] = 'attachment; filename="data.xlsx"' - - # 将API响应内容写入响应对象 - for chunk in export_response.iter_content(chunk_size=8192): - if chunk: - response.write(chunk) - - logger.info("导出成功完成") - return response - - except Exception as e: - logger.error(f"导出聊天记录失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'导出聊天记录失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def chat_list(self, request): - """获取对话列表""" - try: - # 获取查询参数 - history_days = request.query_params.get('history_days', '7') # 默认7天 - - # 构建API请求 - application_id = "65031f4d-c86d-430e-8089-d8ff2731a837" - api_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat" - - # 添加查询参数 - params = { - 'history_day': history_days - } - - logger.info(f"发送获取对话列表请求:{api_url}") - - # 调用外部API - response = requests.get( - url=api_url, - params=params, - timeout=30 - ) - - if response.status_code != 200: - logger.error(f"获取对话列表失败: {response.status_code}, {response.text}") - return Response({ - 'code': 500, - 'message': '获取对话列表失败,外部服务返回错误', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 解析响应数据 - try: - result = response.json() - - if result.get('code') != 200: - logger.error(f"外部API返回错误: {result}") - return Response({ - 'code': result.get('code', 500), - 'message': result.get('message', '获取对话列表失败'), - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 处理返回的数据 - chat_list = result.get('data', []) - - # 格式化返回数据 - formatted_chats = [] - for chat in chat_list: - formatted_chat = { - 'id': chat['id'], - 'chat_id': chat['chat_id'], - 'abstract': chat['abstract'], - 'message_count': chat['chat_record_count'], - 'created_at': datetime.fromisoformat(chat['create_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'), - 'updated_at': datetime.fromisoformat(chat['update_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'), - 'star_count': chat['star_num'], - 'trample_count': chat['trample_num'], - 'mark_sum': chat['mark_sum'], - 'is_deleted': chat['is_deleted'] - } - formatted_chats.append(formatted_chat) - - return Response({ - 'code': 200, - 'message': '获取成功', - 'data': { - 'total': len(formatted_chats), - 'results': formatted_chats - } - }) - - except json.JSONDecodeError as e: - logger.error(f"解析响应数据失败: {str(e)}") - return Response({ - 'code': 500, - 'message': '解析响应数据失败', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - except Exception as e: - logger.error(f"获取对话列表失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取对话列表失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def _highlight_keyword(self, text, keyword): - """高亮关键词""" - if not keyword or not text: - return text - return text.replace( - keyword, - f'{keyword}' - ) - - -class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): - serializer_class = KnowledgeBaseSerializer - permission_classes = [IsAuthenticated] - - def list(self, request, *args, **kwargs): - try: - queryset = self.get_queryset() - - # 获取搜索关键字 - keyword = request.query_params.get('keyword', '') - - # 如果有关键字,构建搜索条件 - if keyword: - query = Q(name__icontains=keyword) | \ - Q(desc__icontains=keyword) | \ - Q(department__icontains=keyword) | \ - Q(group__icontains=keyword) - queryset = queryset.filter(query) - - # 获取分页参数 - try: - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - except ValueError: - page = 1 - page_size = 10 - - # 计算总数量 - total = queryset.count() - - # 分页处理 - start = (page - 1) * page_size - end = start + page_size - paginated_queryset = queryset[start:end] - - # 序列化知识库数据 - serializer = self.get_serializer(paginated_queryset, many=True) - data = serializer.data - - # 为每个知识库添加权限信息 - user = request.user - for item in data: - # 获取必要的知识库属性 - kb_type = item['type'] - department = item.get('department') - group = item.get('group') - creator_id = item.get('user_id') - kb_id = item['id'] - - # 使用统一的权限判断方法 - item['permissions'] = { - 'can_read': self._can_read(kb_type, user, department, group, creator_id, kb_id), - 'can_edit': self._can_edit(kb_type, user, department, group, creator_id, kb_id), - 'can_delete': self._can_delete(kb_type, user, department, group, creator_id, kb_id) - } - - # 处理高亮 - if keyword: - if 'name' in item and keyword.lower() in item['name'].lower(): - item['highlighted_name'] = item['name'].replace( - keyword, f'{keyword}' - ) - - if 'desc' in item and item.get('desc') is not None: - desc_text = str(item['desc']) - if keyword.lower() in desc_text.lower(): - item['highlighted_desc'] = desc_text.replace( - keyword, f'{keyword}' - ) - - return Response({ - "code": 200, - "message": "获取知识库列表成功", - "data": { - "total": total, - "page": page, - "page_size": page_size, - "keyword": keyword if keyword else None, - "items": data - } - }) - except Exception as e: - logger.error(f"获取知识库列表失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"获取知识库列表失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def get_queryset(self): - """获取用户有权限查看的知识库列表""" - user = self.request.user - queryset = KnowledgeBase.objects.all() - - # 1. 构建基础权限条件 - permission_conditions = Q() - - # 2. 所有用户都可以看到 admin 类型的知识库 - permission_conditions |= Q(type='admin') - - # 3. 用户可以看到自己创建的所有知识库 - permission_conditions |= Q(user_id=user.id) - - # 4. 添加显式权限条件 - # 获取所有活跃的权限记录 - active_permissions = KBPermissionModel.objects.filter( - user=user, - can_read=True, - status='active', - expires_at__gt=timezone.now() - ).values_list('knowledge_base_id', flat=True) - - if active_permissions: - permission_conditions |= Q(id__in=active_permissions) - - # 5. 根据用户角色添加隐式权限 - if user.role == 'admin': - # 管理员可以看到除了其他用户 private 类型外的所有知识库 - permission_conditions |= ~Q(type='private') | Q(user_id=user.id) - elif user.role == 'leader': - # 组长可以查看本部门的 leader 和 member 类型知识库 - permission_conditions |= Q( - type__in=['leader', 'member'], - department=user.department - ) - elif user.role in ['member', 'user']: - # 成员可以查看本部门的 leader 类型知识库 - permission_conditions |= Q( - type='leader', - department=user.department - ) - # 成员可以查看本部门本组的 member 类型知识库 - permission_conditions |= Q( - type='member', - department=user.department, - group=user.group - ) - - return queryset.filter(permission_conditions).distinct() - - def create(self, request, *args, **kwargs): - try: - # 1. 验证知识库名称 - name = request.data.get('name') - if not name: - return Response({ - 'code': 400, - 'message': '知识库名称不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - if KnowledgeBase.objects.filter(name=name).exists(): - return Response({ - 'code': 400, - 'message': f'知识库名称 "{name}" 已存在', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 2. 验证用户权限和必填字段 - user = request.user - type = request.data.get('type', 'private') - department = request.data.get('department') - group = request.data.get('group') - - # 修改权限验证 - if type == 'admin': - # 移除管理员权限检查,允许所有用户创建 - department = None - group = None - - elif type == 'secret': - if user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以创建保密级知识库', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - department = None - group = None - - elif type == 'leader': - if user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以创建组长级知识库', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - if not department: - return Response({ - 'code': 400, - 'message': '创建组长级知识库时必须指定部门', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - elif type == 'member': - if user.role not in ['admin', 'leader']: - return Response({ - 'code': 403, - 'message': '只有管理员和组长可以创建成员级知识库', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - if user.role == 'admin' and not department: - return Response({ - 'code': 400, - 'message': '管理员创建成员知识库时必须指定部门', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - elif user.role == 'leader': - department = user.department - - if not group: - return Response({ - 'code': 400, - 'message': '创建成员知识库时必须指定组', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - elif type == 'private': - # 对于private类型,不保存department和group - department = None - group = None - - # 3. 验证请求数据 - data = request.data.copy() - data['department'] = department - data['group'] = group - - # 不需要手动设置 user_id,由序列化器自动处理 - serializer = self.get_serializer(data=data) - if not serializer.is_valid(): - logger.error(f"数据验证失败: {serializer.errors}") - return Response({ - 'code': 400, - 'message': '数据验证失败', - 'data': serializer.errors - }, status=status.HTTP_400_BAD_REQUEST) - - with transaction.atomic(): - # 4. 创建知识库 - try: - knowledge_base = serializer.save() - logger.info(f"知识库创建成功: id={knowledge_base.id}, name={knowledge_base.name}, user_id={knowledge_base.user_id}") - except Exception as e: - logger.error(f"知识库创建失败: {str(e)}") - raise - - # 5. 调用外部API创建知识库 - try: - external_response = self._create_external_dataset(knowledge_base) - logger.info(f"外部知识库创建响应: {external_response}") - - # 处理外部API响应 - if isinstance(external_response, str): - knowledge_base.external_id = external_response - knowledge_base.save() - logger.info(f"更新knowledge_base的external_id为: {external_response}") - else: - if external_response.get('code') == 200: - external_id = external_response.get('data', {}).get('id') - if external_id: - knowledge_base.external_id = external_id - knowledge_base.save() - logger.info(f"更新knowledge_base的external_id为: {external_id}") - else: - raise ValueError("外部API响应中未找到知识库ID") - else: - raise ValueError(f"外部API调用失败: {external_response.get('message', '未知错误')}") - - except Exception as e: - logger.error(f"外部知识库创建失败: {str(e)}") - logger.error(f"外部API响应内容: {external_response if locals().get('external_response') else 'No response'}") - raise ExternalAPIError(f"外部知识库创建失败: {str(e)}") - - # 6. 创建权限记录 - try: - # 创建者权限 - KBPermissionModel.objects.create( - knowledge_base=knowledge_base, - user=request.user, - can_read=True, - can_edit=True, - can_delete=True, - granted_by=request.user, - status='active' - ) - logger.info(f"创建者权限创建成功") - - # 根据类型批量创建其他用户权限 - if type == 'admin': - users_query = User.objects.exclude(id=request.user.id) - elif type == 'secret': - users_query = User.objects.filter(role='admin').exclude(id=request.user.id) - elif type == 'leader': - users_query = User.objects.filter( - Q(role='admin') | - Q(role='leader', department=department) - ).exclude(id=request.user.id) - elif type == 'member': - users_query = User.objects.filter( - Q(role='admin') | - Q(department=department, role='leader') | - Q(department=department, group=group, role='member') - ).exclude(id=request.user.id) - else: # private - users_query = User.objects.none() - - if users_query.exists(): - permissions = [ - KBPermissionModel( - knowledge_base=knowledge_base, - user=user, - can_read=True, - can_edit=self._can_edit(type, user), - can_delete=self._can_delete(type, user), - granted_by=request.user, - status='active' - ) for user in users_query - ] - KBPermissionModel.objects.bulk_create(permissions) - logger.info(f"{type}类型权限创建完成: {len(permissions)}条记录") - - except Exception as e: - logger.error(f"权限创建失败: {str(e)}") - logger.error(traceback.format_exc()) - raise - - return Response({ - 'code': 200, - 'message': '知识库创建成功', - 'data': { - 'knowledge_base': serializer.data, - 'external_id': knowledge_base.external_id - } - }) - - except Exception as e: - logger.error(f"创建知识库失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'创建知识库失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def update(self, request, *args, **kwargs): - """更新知识库""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(instance, user, 'edit'): - return Response({ - "code": 403, - "message": "没有编辑权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - with transaction.atomic(): - # 执行本地更新 - serializer = self.get_serializer(instance, data=request.data, partial=True) - serializer.is_valid(raise_exception=True) - self.perform_update(serializer) - - # 更新外部知识库 - if instance.external_id: - try: - api_data = { - "name": serializer.validated_data.get('name', instance.name), - "desc": serializer.validated_data.get('desc', instance.desc), - "type": "0", # 保持与创建时一致 - "meta": {}, # 保持与创建时一致 - "documents": [] # 保持与创建时一致 - } - - response = requests.put( - f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}', - json=api_data, - headers={'Content-Type': 'application/json'}, - timeout=30 - ) - - if response.status_code != 200: - raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}") - - api_response = response.json() - if not api_response.get('code') == 200: - raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}") - - logger.info(f"外部知识库更新成功: {instance.external_id}") - - except requests.exceptions.Timeout: - raise ExternalAPIError("请求超时,请稍后重试") - except requests.exceptions.RequestException as e: - raise ExternalAPIError(f"API请求失败: {str(e)}") - except Exception as e: - raise ExternalAPIError(f"更新外部知识库失败: {str(e)}") - - return Response({ - "code": 200, - "message": "知识库更新成功", - "data": serializer.data - }) - - except Http404: - return Response({ - "code": 404, - "message": "知识库不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - except ExternalAPIError as e: - logger.error(f"更新外部知识库失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": str(e), - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - except Exception as e: - logger.error(f"更新知识库失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"更新知识库失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def destroy(self, request, *args, **kwargs): - """删除知识库""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(instance, user, 'delete'): - return Response({ - "code": 403, - "message": "没有删除权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - with transaction.atomic(): - # 删除外部知识库 - if instance.external_id: - try: - self._delete_external_dataset(instance.external_id) - logger.info(f"外部知识库删除成功: {instance.external_id}") - except ExternalAPIError as e: - logger.error(f"删除外部知识库失败: {str(e)}") - return Response({ - "code": 500, - "message": f"删除外部知识库失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 删除本地知识库 - self.perform_destroy(instance) - logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}") - - return Response({ - "code": 200, - "message": "知识库删除成功", - "data": None - }) - - except Http404: - return Response({ - "code": 404, - "message": "知识库不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - except Exception as e: - logger.error(f"删除知识库失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"删除知识库失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['get']) - def permissions(self, request, pk=None): - """获取用户对特定知识库的权限""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - permissions_data = { - "can_read": self.check_knowledge_base_permission(instance, user, 'read'), - "can_edit": self.check_knowledge_base_permission(instance, user, 'edit'), - "can_delete": self.check_knowledge_base_permission(instance, user, 'delete') - } - - return Response({ - "code": 200, - "message": "获取权限信息成功", - "data": { - "knowledge_base_id": instance.id, - "knowledge_base_name": instance.name, - "permissions": permissions_data - } - }) - - except Http404: - return Response({ - "code": 404, - "message": "知识库不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - except Exception as e: - logger.error(f"获取权限信息失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"获取权限信息失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - @action(detail=False, methods=['get']) - def summary(self, request): - """获取所有可见知识库的概要信息(除了secret类型)""" - try: - user = request.user - - # 基础查询:排除secret类型的知识库 - queryset = KnowledgeBase.objects.exclude(type='secret') - - summaries = [] - for kb in queryset: - # 使用统一的权限判断方法 - permissions = { - 'can_read': self.check_knowledge_base_permission(kb, user, 'read'), - 'can_edit': self.check_knowledge_base_permission(kb, user, 'edit'), - 'can_delete': self.check_knowledge_base_permission(kb, user, 'delete') - } - - # 只返回概要信息 - summary = { - 'id': str(kb.id), - 'name': kb.name, - 'desc': kb.desc, - 'type': kb.type, - 'department': kb.department, - 'permissions': permissions - } - summaries.append(summary) - - return Response({ - 'code': 200, - 'message': '获取知识库概要信息成功', - 'data': summaries - }) - - except Exception as e: - return Response({ - 'code': 500, - 'message': f'获取知识库概要信息失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def retrieve(self, request, *args, **kwargs): - try: - # 获取知识库对象 - instance = self.get_object() - serializer = self.get_serializer(instance) - data = serializer.data - - # 获取用户 - user = request.user - - # 使用统一的权限判断方法 - data['permissions'] = { - 'can_read': self.check_knowledge_base_permission(instance, user, 'read'), - 'can_edit': self.check_knowledge_base_permission(instance, user, 'edit'), - 'can_delete': self.check_knowledge_base_permission(instance, user, 'delete') - } - - return Response({ - 'code': 200, - 'message': '获取知识库详情成功', - 'data': data - }) - except Exception as e: - logger.error(f"获取知识库详情失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取知识库详情失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def search(self, request): - """搜索知识库功能""" - try: - # 获取搜索关键字 - keyword = request.query_params.get('keyword', '') - if not keyword: - return Response({ - "code": 400, - "message": "搜索关键字不能为空", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 获取分页参数 - try: - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - except ValueError: - page = 1 - page_size = 10 - - # 构建搜索条件 - query = Q(name__icontains=keyword) | \ - Q(desc__icontains=keyword) | \ - Q(department__icontains=keyword) | \ - Q(group__icontains=keyword) - - # 排除 secret 类型的知识库 - queryset = KnowledgeBase.objects.filter(query).exclude(type='secret') - - # 获取用户 - user = request.user - - # 获取用户所有有效的知识库权限 - active_permissions = KBPermissionModel.objects.filter( - user=user, - status='active', - expires_at__gt=timezone.now() - ).select_related('knowledge_base') - - # 创建权限映射字典 - permission_map = { - str(perm.knowledge_base.id): { - 'can_read': perm.can_read, - 'can_edit': perm.can_edit, - 'can_delete': perm.can_delete - } - for perm in active_permissions - } - - # 计算总数量 - total = queryset.count() - - # 分页处理 - start = (page - 1) * page_size - end = start + page_size - paginated_queryset = queryset[start:end] - - # 序列化知识库数据 - serializer = self.get_serializer(paginated_queryset, many=True) - data = serializer.data - - # 处理每个知识库项的权限和返回内容 - result_items = [] - for item in data: - # 使用统一的权限判断方法 - kb_permissions = { - 'can_read': self.check_knowledge_base_permission( - type=item['type'], - user=user, - department=item.get('department'), - group=item.get('group'), - creator_id=item.get('user_id'), - knowledge_base_id=item['id'] - ), - 'can_edit': self.check_knowledge_base_permission( - type=item['type'], - user=user, - department=item.get('department'), - group=item.get('group'), - creator_id=item.get('user_id'), - knowledge_base_id=item['id'] - ), - 'can_delete': self.check_knowledge_base_permission( - type=item['type'], - user=user, - department=item.get('department'), - group=item.get('group'), - creator_id=item.get('user_id'), - knowledge_base_id=item['id'] - ) - } - - # 添加权限信息 - item['permissions'] = kb_permissions - - # 根据权限返回不同级别的信息 - if kb_permissions['can_read']: - result_items.append(item) - else: - # 无读取权限,只返回概要信息 - summary_info = { - 'id': item['id'], - 'name': item['name'], - 'type': item['type'], - 'department': item.get('department'), - 'permissions': kb_permissions - } - result_items.append(summary_info) - - # 高亮搜索关键字 - for item in result_items: - if 'name' in item and keyword.lower() in item['name'].lower(): - highlighted = item['name'].replace( - keyword, f'{keyword}' - ) - item['highlighted_name'] = highlighted - - # 确保desc不为None并且是字符串 - if 'desc' in item and item.get('desc') is not None: - desc_text = str(item['desc']) # 转换为字符串以确保安全 - if keyword.lower() in desc_text.lower(): - highlighted = desc_text.replace( - keyword, f'{keyword}' - ) - item['highlighted_desc'] = highlighted - - return Response({ - "code": 200, - "message": "搜索知识库成功", - "data": { - "total": total, - "page": page, - "page_size": page_size, - "keyword": keyword, - "items": result_items - } - }) - except Exception as e: - logger.error(f"搜索知识库失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"搜索知识库失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['post']) - def change_type(self, request, pk=None): - """修改知识库类型""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法检查编辑权限 - if not self.check_knowledge_base_permission(instance, user, 'edit'): - return Response({ - "code": 403, - "message": "没有修改权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 其余代码保持不变... - - # 获取新类型 - new_type = request.data.get('type') - if not new_type: - return Response({ - "code": 400, - "message": "新类型不能为空", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证类型是否有效 - valid_types = ['private', 'admin', 'secret', 'leader', 'member'] - if new_type not in valid_types: - return Response({ - "code": 400, - "message": f"无效的知识库类型,可选值: {', '.join(valid_types)}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 角色特定的类型限制 - if new_type == 'leader' and not user.role == 'admin': # 组长且不是管理员 - # 组长只能在private和member类型之间切换 - if new_type not in ['private', 'member']: - return Response({ - "code": 403, - "message": "组长只能将知识库设置为private或member类型", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 处理department和group字段 - department = request.data.get('department') - group = request.data.get('group') - - # 组长只能设置自己部门 - if new_type == 'leader' and not user.role == 'admin': - if department and department != user.department: - return Response({ - "code": 403, - "message": "组长只能为本部门设置知识库", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - # 如果未指定部门,强制设置为组长的部门 - department = user.department - - # 根据类型验证必填字段 - if new_type == 'leader': - if not department: - return Response({ - "code": 400, - "message": "组长级知识库必须指定部门", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - if new_type == 'member': - if not department: - return Response({ - "code": 400, - "message": "成员级知识库必须指定部门", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - if not group: - return Response({ - "code": 400, - "message": "成员级知识库必须指定组", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 如果是admin或secret类型,清除department和group - if new_type in ['admin', 'secret']: - department = None - group = None - - # 如果是private类型但未指定department和group,使用原值 - if new_type == 'private': - if department is None: - department = instance.department - if group is None: - group = instance.group - - # 更新知识库类型和相关字段 - instance.type = new_type - instance.department = department - instance.group = group - instance.save() - - return Response({ - "code": 200, - "message": f"知识库类型已更新为{new_type}", - "data": { - "id": instance.id, - "name": instance.name, - "type": instance.type, - "department": instance.department, - "group": instance.group - } - }) - - except Http404: - return Response({ - "code": 404, - "message": "知识库不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - except Exception as e: - logger.error(f"修改知识库类型失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"修改知识库类型失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['post']) - def upload_document(self, request, pk=None): - """上传文档到知识库""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(instance, user, 'edit'): - return Response({ - "code": 403, - "message": "没有编辑权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 获取文件 - file = request.FILES.get('file') - if not file: - return Response({ - "code": 400, - "message": "未找到上传文件", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 调用文档分割API - split_response = self._call_split_api(file) - if not split_response or split_response.get('code') != 200: - return Response({ - "code": 500, - "message": "文档分割失败", - "data": split_response - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - # 处理分割后的文档 - documents_data = split_response.get('data', []) - saved_documents = [] - - for doc in documents_data: - # 准备文档数据 - doc_data = { - "name": doc.get('name'), - "paragraphs": [{ - "content": para.get('content', ''), - "title": para.get('title', ''), - "is_active": True, - "problem_list": [] - } for para in doc.get('content', [])] - } - - # 调用文档上传API - upload_response = self._call_upload_api(instance.external_id, doc_data) - if upload_response and upload_response.get('code') == 200: - # 保存文档记录到数据库 - doc_record = KnowledgeBaseDocument.objects.create( - knowledge_base=instance, - document_id=upload_response['data']['id'], - document_name=doc.get('name'), - external_id=upload_response['data']['id'] - ) - saved_documents.append(doc_record) - - return Response({ - "code": 200, - "message": "文档上传成功", - "data": { - "uploaded_count": len(saved_documents), - "documents": [{ - "id": str(doc.id), - "name": doc.document_name, - "external_id": doc.external_id - } for doc in saved_documents] - } - }) - - except Exception as e: - logger.error(f"文档上传失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"文档上传失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['delete']) - def delete_document(self, request, pk=None): - """从知识库中删除文档""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(instance, user, 'delete'): - return Response({ - "code": 403, - "message": "没有删除权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 查找文档记录 - try: - doc_record = KnowledgeBaseDocument.objects.get( - knowledge_base=instance, - document_id=document_id, - status='active' - ) - except KnowledgeBaseDocument.DoesNotExist: - return Response({ - "code": 404, - "message": "文档不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - - # 调用外部API删除文档 - delete_response = self._call_delete_document_api( - instance.external_id, - doc_record.external_id - ) - - if delete_response and delete_response.get('code') == 200: - # 更新文档状态为已删除 - doc_record.status = 'deleted' - doc_record.save() - - return Response({ - "code": 200, - "message": "文档删除成功", - "data": { - "document_id": document_id, - "name": doc_record.document_name - } - }) - else: - raise Exception("外部API删除文档失败") - - except Exception as e: - logger.error(f"删除文档失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"删除文档失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['get']) - def documents(self, request, pk=None): - """获取知识库下的所有文档信息""" - try: - instance = self.get_object() - user = request.user - - # 使用统一的权限检查方法 - if not self.check_knowledge_base_permission(instance, user, 'read'): - return Response({ - "code": 403, - "message": "没有查看权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 2. 获取分页参数 - try: - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - except ValueError: - page = 1 - page_size = 10 - - # 3. 查询文档记录 - documents = KnowledgeBaseDocument.objects.filter( - knowledge_base=instance, - status='active' - ).order_by('-create_time') - - # 4. 计算总数 - total = documents.count() - - # 5. 分页 - start = (page - 1) * page_size - end = start + page_size - documents = documents[start:end] - - # 6. 构建响应数据 - documents_data = [{ - "id": str(doc.id), - "document_id": doc.document_id, - "document_name": doc.document_name, - "external_id": doc.external_id, - "create_time": doc.create_time.strftime("%Y-%m-%d %H:%M:%S"), - "update_time": doc.update_time.strftime("%Y-%m-%d %H:%M:%S") - } for doc in documents] - - return Response({ - "code": 200, - "message": "获取文档列表成功", - "data": { - "total": total, - "page": page, - "page_size": page_size, - "items": documents_data - } - }) - - except Exception as e: - logger.error(f"获取文档列表失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"获取文档列表失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['get']) - def document_content(self, request, pk=None): - """获取文档内容""" - try: - instance = self.get_object() - user = request.user - document_id = request.query_params.get('document_id') - - if not document_id: - return Response({ - "code": 400, - "message": "缺少document_id参数", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 使用统一的权限判断方法 - if not self.check_knowledge_base_permission( - type=instance.type, - user=user, - department=instance.department, - group=instance.group, - creator_id=instance.user_id, - knowledge_base_id=instance.id - ): - return Response({ - "code": 403, - "message": "没有查看权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - # 2. 查找文档记录 - try: - doc_record = KnowledgeBaseDocument.objects.get( - knowledge_base=instance, - document_id=document_id, - status='active' - ) - except KnowledgeBaseDocument.DoesNotExist: - return Response({ - "code": 404, - "message": "文档不存在", - "data": None - }, status=status.HTTP_404_NOT_FOUND) - - # 3. 调用外部API获取文档内容 - try: - response = requests.get( - f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document/{doc_record.external_id}/paragraph', - headers={'Content-Type': 'application/json'}, - timeout=30 - ) - - if response.status_code != 200: - raise Exception(f"获取文档内容失败,状态码: {response.status_code}") - - content_data = response.json() - - # 4. 构建响应数据 - return Response({ - "code": 200, - "message": "获取文档内容成功", - "data": { - "document_info": { - "id": str(doc_record.id), - "name": doc_record.document_name, - "create_time": doc_record.create_time.strftime("%Y-%m-%d %H:%M:%S"), - "update_time": doc_record.update_time.strftime("%Y-%m-%d %H:%M:%S") - }, - "content": content_data.get('data', []) - } - }) - - except requests.exceptions.RequestException as e: - raise Exception(f"调用外部API失败: {str(e)}") - - except Exception as e: - logger.error(f"获取文档内容失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - "code": 500, - "message": f"获取文档内容失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def _call_split_api(self, file): - """调用文档分割API""" - try: - url = f'{settings.API_BASE_URL}/api/dataset/document/split' - files = {'file': file} - response = requests.post(url, files=files) - return response.json() - except Exception as e: - logger.error(f"调用分割API失败: {str(e)}") - return None - - def _call_upload_api(self, external_id, doc_data): - """调用文档上传API""" - try: - url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document' - response = requests.post(url, json=doc_data) - return response.json() - except Exception as e: - logger.error(f"调用上传API失败: {str(e)}") - return None - - def _call_delete_document_api(self, external_id, document_id): - """调用文档删除API""" - try: - url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document/{document_id}' - response = requests.delete(url) - return response.json() - except Exception as e: - logger.error(f"调用删除API失败: {str(e)}") - return None - - -class PermissionViewSet(viewsets.ModelViewSet): - serializer_class = PermissionSerializer - permission_classes = [IsAuthenticated] - - def can_manage_knowledge_base(self, user, knowledge_base): - """检查用户是否是知识库的创建者""" - return str(knowledge_base.user_id) == str(user.id) - - def get_queryset(self): - """ - 获取权限申请列表: - 1. applicant_id 是当前用户 (看到自己发起的申请) - 2. approver_id 是当前用户 (看到自己需要审批的申请) - """ - user_id = str(self.request.user.id) - - # 构建查询条件:申请人是自己 或 审批人是自己 - query = Q(applicant_id=user_id) | Q(approver_id=user_id) - - return Permission.objects.filter(query).select_related( - 'knowledge_base', - 'applicant', - 'approver' - ) - - def list(self, request, *args, **kwargs): - """获取权限申请列表,包含详细信息""" - try: - queryset = self.get_queryset() - user_id = str(request.user.id) - - # 获取分页参数 - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - - # 计算总数 - total = queryset.count() - - # 手动分页 - start = (page - 1) * page_size - end = start + page_size - permissions = queryset[start:end] - - # 构建响应数据 - data = [] - for permission in permissions: - # 检查当前用户是否是申请人或审批人 - if user_id not in [str(permission.applicant_id), str(permission.approver_id)]: - continue - - # 构建响应数据 - permission_data = { - 'id': str(permission.id), - 'knowledge_base': { - 'id': str(permission.knowledge_base.id), - 'name': permission.knowledge_base.name, - 'type': permission.knowledge_base.type, - }, - 'applicant': { - 'id': str(permission.applicant.id), - 'username': permission.applicant.username, - 'name': permission.applicant.name, - 'department': permission.applicant.department, - }, - 'approver': { - 'id': str(permission.approver.id) if permission.approver else '', - 'username': permission.approver.username if permission.approver else '', - 'name': permission.approver.name if permission.approver else '', - 'department': permission.approver.department if permission.approver else '', - }, - 'permissions': permission.permissions, - 'status': permission.status, - 'created_at': permission.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None, - 'response_message': permission.response_message or '', - # 添加角色标识,用于前端展示 - 'role': 'applicant' if str(permission.applicant_id) == user_id else 'approver' - } - - data.append(permission_data) - - return Response({ - 'code': 200, - 'message': '获取权限申请列表成功', - 'data': { - 'total': len(data), # 使用过滤后的实际数量 - 'page': page, - 'page_size': page_size, - 'results': data - } - }) - - except Exception as e: - logger.error(f"获取权限申请列表失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取权限申请列表失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def perform_create(self, serializer): - """创建权限申请并发送通知给知识库创建者""" - # 获取知识库 - # 获取知识库 - knowledge_base = serializer.validated_data['knowledge_base'] - - # 检查是否是申请访问自己的知识库 - if str(knowledge_base.user_id) == str(self.request.user.id): - raise ValidationError({ - "code": 400, - "message": "您是此知识库的创建者,无需申请权限", - "data": None - }) - # 获取知识库创建者作为审批者 - approver = User.objects.get(id=knowledge_base.user_id) - - # 验证权限请求 - requested_permissions = serializer.validated_data.get('permissions', {}) - expires_at = serializer.validated_data.get('expires_at') - - if not any([requested_permissions.get('can_read'), - requested_permissions.get('can_edit'), - requested_permissions.get('can_delete')]): - raise ValidationError("至少需要申请一种权限(读/改/删)") - - if not expires_at: - raise ValidationError("请指定权限到期时间") - - # 检查是否已有未过期的权限申请 - existing_request = Permission.objects.filter( - knowledge_base=knowledge_base, - applicant=self.request.user, - status='pending' - ).first() - - if existing_request: - raise ValidationError("您已有一个待处理的权限申请") - - # 检查是否已有有效的权限 - existing_permission = Permission.objects.filter( - knowledge_base=knowledge_base, - applicant=self.request.user, - status='approved', - expires_at__gt=timezone.now() - ).first() - - if existing_permission: - raise ValidationError("您已有此知识库的访问权限") - - # 保存权限申请,设置审批者 - permission = serializer.save( - applicant=self.request.user, - status='pending', - approver=approver # 创建时就设置审批者 - ) - - # 获取权限类型字符串 - permission_types = [] - if requested_permissions.get('can_read'): - permission_types.append('读取') - if requested_permissions.get('can_edit'): - permission_types.append('编辑') - if requested_permissions.get('can_delete'): - permission_types.append('删除') - permission_str = '、'.join(permission_types) - - # 发送通知给知识库创建者 - owner = User.objects.get(id=knowledge_base.user_id) - self.send_notification( - user=owner, - title="新的权限申请", - content=f"用户 {self.request.user.name} 申请了知识库 '{knowledge_base.name}' 的{permission_str}权限", - notification_type="permission_request", - related_object_id=permission.id - ) - - def send_notification(self, user, title, content, notification_type, related_object_id): - """发送通知""" - try: - notification = Notification.objects.create( - sender=self.request.user, - receiver=user, - title=title, - content=content, - type=notification_type, - related_resource=related_object_id, - ) - - # 通过WebSocket发送实时通知 - channel_layer = get_channel_layer() - async_to_sync(channel_layer.group_send)( - f"notification_user_{user.id}", - { - "type": "notification", - "data": { - "id": str(notification.id), - "title": notification.title, - "content": notification.content, - "type": notification.type, - "created_at": notification.created_at.isoformat(), - "sender": { - "id": str(notification.sender.id), - "name": notification.sender.name - } - } - } - ) - except Exception as e: - logger.error(f"发送通知时发生错误: {str(e)}") - - @action(detail=True, methods=['post']) - def approve(self, request, pk=None): - try: - # 获取权限申请记录 - permission = self.get_object() - - # 只检查是否是知识库创建者 - if not self.can_manage_knowledge_base(request.user, permission.knowledge_base): - logger.warning(f"用户 {request.user.username} 尝试审批知识库 {permission.knowledge_base.name} 的权限申请,但不是创建者") - return Response({ - 'code': 403, - 'message': '只有知识库创建者可以审批此申请', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 获取审批意见 - response_message = request.data.get('response_message', '') - - with transaction.atomic(): - # 更新权限申请状态 - permission.status = 'approved' - permission.approver = request.user - permission.response_message = response_message - permission.save() - - # 检查是否已存在权限记录 - kb_permission = KBPermissionModel.objects.filter( - knowledge_base=permission.knowledge_base, - user=permission.applicant - ).first() - - if kb_permission: - # 更新现有权限 - kb_permission.can_read = permission.permissions.get('can_read', False) - kb_permission.can_edit = permission.permissions.get('can_edit', False) - kb_permission.can_delete = permission.permissions.get('can_delete', False) - kb_permission.granted_by = request.user - kb_permission.status = 'active' - kb_permission.expires_at = permission.expires_at - kb_permission.save() - logger.info(f"更新知识库权限记录: {kb_permission.id}") - else: - # 创建新的权限记录 - kb_permission = KBPermissionModel.objects.create( - knowledge_base=permission.knowledge_base, - user=permission.applicant, - can_read=permission.permissions.get('can_read', False), - can_edit=permission.permissions.get('can_edit', False), - can_delete=permission.permissions.get('can_delete', False), - granted_by=request.user, - status='active', - expires_at=permission.expires_at - ) - logger.info(f"创建新的知识库权限记录: {kb_permission.id}") - - # 发送通知给申请人 - self.send_notification( - user=permission.applicant, - title="权限申请已通过", - content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已通过", - notification_type="permission_approved", - related_object_id=permission.id - ) - - return Response({ - 'code': 200, - 'message': '权限申请已批准', - 'data': None - }) - - except Permission.DoesNotExist: - return Response({ - 'code': 404, - 'message': '权限申请不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - except Exception as e: - logger.error(f"处理权限申请失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'处理权限申请失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=True, methods=['post']) - def reject(self, request, pk=None): - """拒绝权限申请""" - permission = self.get_object() - - # 检查是否是知识库创建者 - if str(permission.knowledge_base.user_id) != str(request.user.id): - return Response({ - 'code': 403, - 'message': '只有知识库创建者可以审批此申请', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 检查申请是否已被处理 - if permission.status != 'pending': - return Response({ - 'code': 400, - 'message': '该申请已被处理', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证拒绝原因 - response_message = request.data.get('response_message') - if not response_message: - return Response({ - 'code': 400, - 'message': '请填写拒绝原因', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 更新权限状态 - permission.status = 'rejected' - permission.approver = request.user - permission.response_message = response_message - permission.save() - - # 发送通知给申请人 - self.send_notification( - user=permission.applicant, - title="权限申请已拒绝", - content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已被拒绝\n" - f"拒绝原因:{response_message}", - notification_type="permission_rejected", - related_object_id=permission.id - ) - - return Response({ - 'code': 200, - 'message': '权限申请已拒绝', - 'data': PermissionSerializer(permission).data - }) - - @action(detail=True, methods=['post']) - def extend(self, request, pk=None): - """延长权限有效期""" - instance = self.get_object() - user = request.user - - # 检查是否有权限延长 - if not self.check_extend_permission(instance, user): - return Response({ - "code": 403, - "message": "您没有权限延长此权限", - "data": None - }, status=status.HTTP_403_FORBIDDEN) - - new_expires_at = request.data.get('expires_at') - if not new_expires_at: - return Response({ - "code": 400, - "message": "请设置新的过期时间", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - try: - with transaction.atomic(): - # 更新权限申请表的过期时间 - instance.expires_at = new_expires_at - instance.save() - - # 同步更新知识库权限表的过期时间 - kb_permission = KBPermissionModel.objects.get( - knowledge_base=instance.knowledge_base, - user=instance.applicant - ) - kb_permission.expires_at = new_expires_at - kb_permission.save() - - return Response({ - "code": 200, - "message": "权限有效期延长成功", - "data": PermissionSerializer(instance).data - }) - except Exception as e: - return Response({ - "code": 500, - "message": f"延长权限有效期失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - def check_extend_permission(self, permission, user): - """检查是否有权限延长权限有效期""" - knowledge_base = permission.knowledge_base - - # 私人知识库只有拥有者能延长 - if knowledge_base.type == 'private': - return knowledge_base.owner == user - - # 组长知识库只有管理员能延长 - if knowledge_base.type == 'leader': - return user.role == 'admin' - - # 组员知识库可以由管理员或本部门组长延长 - if knowledge_base.type == 'member': - return ( - user.role == 'admin' or - (user.role == 'leader' and user.department == knowledge_base.department) - ) - - return False - - @action(detail=False, methods=['get']) - def user_permissions(self, request): - """获取指定用户的所有知识库权限""" - try: - # 获取用户名参数 - username = request.query_params.get('username') - if not username: - return Response({ - 'code': 400, - 'message': '请提供用户名', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 获取用户 - try: - target_user = User.objects.get(username=username) - except User.DoesNotExist: - return Response({ - 'code': 404, - 'message': f'用户 {username} 不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - # 获取该用户的所有权限记录 - permissions = KBPermissionModel.objects.filter( - user=target_user, - status='active' - ).select_related('knowledge_base', 'granted_by') - - # 构建响应数据 - permissions_data = [] - for perm in permissions: - perm_data = { - 'id': str(perm.id), - 'knowledge_base': { - 'id': str(perm.knowledge_base.id), - 'name': perm.knowledge_base.name, - 'type': perm.knowledge_base.type, - 'department': perm.knowledge_base.department, - 'group': perm.knowledge_base.group - }, - 'permissions': { - 'can_read': perm.can_read, - 'can_edit': perm.can_edit, - 'can_delete': perm.can_delete - }, - 'granted_by': { - 'id': str(perm.granted_by.id) if perm.granted_by else None, - 'username': perm.granted_by.username if perm.granted_by else None, - 'name': perm.granted_by.name if perm.granted_by else None - }, - 'created_at': perm.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None, - 'status': perm.status - } - permissions_data.append(perm_data) - - return Response({ - 'code': 200, - 'message': '获取用户权限成功', - 'data': { - 'user': { - 'id': str(target_user.id), - 'username': target_user.username, - 'name': target_user.name, - 'department': target_user.department, - 'role': target_user.role - }, - 'permissions': permissions_data - } - }) - - except Exception as e: - logger.error(f"获取用户权限失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取用户权限失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get']) - def all_permissions(self, request): - """管理员获取所有用户的知识库权限(不包括私有知识库)""" - try: - # 检查是否是管理员 - if request.user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以查看所有权限', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 获取查询参数 - page = int(request.query_params.get('page', 1)) - page_size = int(request.query_params.get('page_size', 10)) - status_filter = request.query_params.get('status') - department = request.query_params.get('department') - kb_type = request.query_params.get('kb_type') - - # 构建基础查询 - queryset = KBPermissionModel.objects.filter( - ~Q(knowledge_base__type='private') - ).select_related( - 'user', - 'knowledge_base', - 'granted_by' - ) - - # 应用过滤条件 - if status_filter == 'active': - queryset = queryset.filter( - Q(expires_at__gt=timezone.now()) | Q(expires_at__isnull=True), - status='active' - ) - elif status_filter == 'expired': - queryset = queryset.filter( - Q(expires_at__lte=timezone.now()) | Q(status='inactive') - ) - - if department: - queryset = queryset.filter(user__department=department) - - if kb_type: - queryset = queryset.filter(knowledge_base__type=kb_type) - - # 按用户分组处理数据 - user_permissions = {} - for perm in queryset: - user_id = str(perm.user.id) - if user_id not in user_permissions: - user_permissions[user_id] = { - 'user_info': { - 'id': user_id, - 'username': perm.user.username, - 'name': getattr(perm.user, 'name', perm.user.username), - 'department': getattr(perm.user, 'department', None), - 'role': getattr(perm.user, 'role', None) - }, - 'permissions': [], - 'stats': { - 'total': 0, - 'by_type': { - 'admin': 0, - 'secret': 0, - 'leader': 0, - 'member': 0 - }, - 'by_permission': { - 'read_only': 0, - 'read_write': 0, - 'full_access': 0 - } - } - } - - # 添加权限信息 - perm_data = { - 'id': str(perm.id), - 'knowledge_base': { - 'id': str(perm.knowledge_base.id), - 'name': perm.knowledge_base.name, - 'type': perm.knowledge_base.type, - 'department': perm.knowledge_base.department, - 'group': perm.knowledge_base.group, - 'creator': { - 'id': str(perm.knowledge_base.user_id), - 'name': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'name', None), - 'username': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'username', None) - } - }, - 'permissions': { - 'can_read': perm.can_read, - 'can_edit': perm.can_edit, - 'can_delete': perm.can_delete - }, - 'granted_by': { - 'id': str(perm.granted_by.id) if perm.granted_by else None, - 'username': perm.granted_by.username if perm.granted_by else None, - 'name': getattr(perm.granted_by, 'name', None) if perm.granted_by else None - }, - 'granted_at': perm.granted_at.strftime('%Y-%m-%d %H:%M:%S'), - 'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None, - 'status': perm.status - } - - user_permissions[user_id]['permissions'].append(perm_data) - - # 更新统计信息 - stats = user_permissions[user_id]['stats'] - stats['total'] += 1 - stats['by_type'][perm.knowledge_base.type] += 1 - - # 统计权限级别 - if perm.can_delete: - stats['by_permission']['full_access'] += 1 - elif perm.can_edit: - stats['by_permission']['read_write'] += 1 - elif perm.can_read: - stats['by_permission']['read_only'] += 1 - - # 转换为列表并分页 - users_list = list(user_permissions.values()) - total = len(users_list) - start = (page - 1) * page_size - end = start + page_size - paginated_users = users_list[start:end] - - return Response({ - 'code': 200, - 'message': '获取权限列表成功', - 'data': { - 'total': total, - 'page': page, - 'page_size': page_size, - 'results': paginated_users - } - }) - - except Exception as e: - logger.error(f"获取所有权限失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'获取所有权限失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['post']) - def update_permission(self, request): - """管理员更新用户的知识库权限""" - try: - # 检查是否是管理员 - if request.user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以直接修改权限', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 验证必要参数 - user_id = request.data.get('user_id') - knowledge_base_id = request.data.get('knowledge_base_id') - permissions = request.data.get('permissions') - expires_at_str = request.data.get('expires_at') - - if not all([user_id, knowledge_base_id, permissions]): - return Response({ - 'code': 400, - 'message': '缺少必要参数', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证权限参数格式 - required_permission_fields = ['can_read', 'can_edit', 'can_delete'] - if not all(field in permissions for field in required_permission_fields): - return Response({ - 'code': 400, - 'message': '权限参数格式错误,必须包含 can_read、can_edit、can_delete', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 获取用户和知识库 - try: - user = User.objects.get(id=user_id) - knowledge_base = KnowledgeBase.objects.get(id=knowledge_base_id) - except User.DoesNotExist: - return Response({ - 'code': 404, - 'message': f'用户ID {user_id} 不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - except KnowledgeBase.DoesNotExist: - return Response({ - 'code': 404, - 'message': f'知识库ID {knowledge_base_id} 不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) - - # 检查知识库类型和用户角色的匹配 - if knowledge_base.type == 'private' and str(knowledge_base.user_id) != str(user.id): - return Response({ - 'code': 403, - 'message': '不能修改其他用户的私有知识库权限', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) - - # 处理过期时间 - expires_at = None - if expires_at_str: - try: - # 将字符串转换为datetime对象 - expires_at = timezone.datetime.strptime( - expires_at_str, - '%Y-%m-%dT%H:%M:%SZ' - ) - # 确保时区感知 - expires_at = timezone.make_aware(expires_at) - - # 检查是否早于当前时间 - if expires_at <= timezone.now(): - return Response({ - 'code': 400, - 'message': '过期时间不能早于或等于当前时间', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - except ValueError: - return Response({ - 'code': 400, - 'message': '过期时间格式错误,应为 ISO 格式 (YYYY-MM-DDThh:mm:ssZ)', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 根据用户角色限制权限 - if user.role == 'member' and permissions.get('can_delete'): - return Response({ - 'code': 400, - 'message': '普通成员不能获得删除权限', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) - - # 更新或创建权限记录 - try: - with transaction.atomic(): - permission, created = KBPermissionModel.objects.update_or_create( - user=user, - knowledge_base=knowledge_base, - defaults={ - 'can_read': permissions.get('can_read', False), - 'can_edit': permissions.get('can_edit', False), - 'can_delete': permissions.get('can_delete', False), - 'granted_by': request.user, - 'status': 'active', - 'expires_at': expires_at - } - ) - - # 发送通知给用户 - self.send_notification( - user=user, - title="知识库权限更新", - content=f"管理员已{created and '授予' or '更新'}您对知识库 '{knowledge_base.name}' 的权限", - notification_type="permission_updated", - related_object_id=permission.id - ) - except IntegrityError as e: - return Response({ - 'code': 500, - 'message': f'数据库操作失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - return Response({ - 'code': 200, - 'message': f"{'创建' if created else '更新'}权限成功", - 'data': { - 'id': str(permission.id), - 'user': { - 'id': str(user.id), - 'username': user.username, - 'name': user.name, - 'department': user.department, - 'role': user.role - }, - 'knowledge_base': { - 'id': str(knowledge_base.id), - 'name': knowledge_base.name, - 'type': knowledge_base.type, - 'department': knowledge_base.department, - 'group': knowledge_base.group - }, - 'permissions': { - 'can_read': permission.can_read, - 'can_edit': permission.can_edit, - 'can_delete': permission.can_delete - }, - 'granted_by': { - 'id': str(request.user.id), - 'username': request.user.username, - 'name': request.user.name - }, - 'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None, - 'created': created - } - }) - - except Exception as e: - logger.error(f"更新权限失败: {str(e)}") - logger.error(traceback.format_exc()) - return Response({ - 'code': 500, - 'message': f'更新权限失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - -class NotificationViewSet(viewsets.ModelViewSet): - """通知视图集""" - queryset = Notification.objects.all() - serializer_class = NotificationSerializer - permission_classes = [IsAuthenticated] - - def get_queryset(self): - """只返回用户自己的通知""" - return Notification.objects.filter(receiver=self.request.user) - - @action(detail=True, methods=['post']) - def mark_as_read(self, request, pk=None): - """标记通知为已读""" - notification = self.get_object() - notification.is_read = True - notification.save() - return Response({'status': 'marked as read'}) - - @action(detail=False, methods=['post']) - def mark_all_as_read(self, request): - """标记所有通知为已读""" - self.get_queryset().update(is_read=True) - return Response({'status': 'all marked as read'}) - - @action(detail=False, methods=['get']) - def unread_count(self, request): - """获取未读通知数量""" - count = self.get_queryset().filter(is_read=False).count() - return Response({'unread_count': count}) - - @action(detail=False, methods=['get']) - def latest(self, request): - """获取最新通知""" - notifications = self.get_queryset().filter( - is_read=False - ).order_by('-created_at')[:5] - serializer = self.get_serializer(notifications, many=True) - return Response(serializer.data) - - def perform_create(self, serializer): - """创建通知时自动设置发送者""" - serializer.save(sender=self.request.user) - - -@method_decorator(csrf_exempt, name='dispatch') -class LoginView(APIView): - """用户登录视图""" - authentication_classes = [] # 清空认证类 - permission_classes = [AllowAny] - - def post(self, request): - try: - username = request.data.get('username') - password = request.data.get('password') - - # 参数验证 - if not username or not password: - return Response({ - "code": 400, - "message": "请提供用户名和密码", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证用户 - user = authenticate(request, username=username, password=password) - - if user is not None: - # 获取或创建token - token, _ = Token.objects.get_or_create(user=user) - - return Response({ - "code": 200, - "message": "登录成功", - "data": { - "id": str(user.id), - "username": user.username, - "email": user.email, - "role": user.role, - "department": user.department, - "name": user.name, - "group": user.group, - "token": token.key - } - }) - else: - return Response({ - "code": 401, - "message": "用户名或密码错误", - "data": None - }, status=status.HTTP_401_UNAUTHORIZED) - - except Exception as e: - import traceback - logger.error(f"登录失败: {str(e)}") - logger.error(f"错误类型: {type(e)}") - logger.error(f"错误堆栈: {traceback.format_exc()}") - - return Response({ - "code": 500, - "message": "登录失败,请稍后重试", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@method_decorator(csrf_exempt, name='dispatch') -class RegisterView(APIView): - """用户注册视图""" - permission_classes = [AllowAny] - - def post(self, request): - try: - data = request.data - - # 检查必填字段 - required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] - for field in required_fields: - if not data.get(field): - return Response({ - "code": 400, - "message": f"缺少必填字段: {field}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证角色 - valid_roles = ['admin', 'leader', 'member'] - roles_str = ', '.join(valid_roles) # 先构造角色字符串 - if data['role'] not in valid_roles: - return Response({ - "code": 400, - "message": f"无效的角色,必须是: {roles_str}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证部门是否存在 - if data['department'] not in settings.DEPARTMENT_GROUPS: - return Response({ - "code": 400, - "message": f"无效的部门,可选部门: {', '.join(settings.DEPARTMENT_GROUPS.keys())}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 如果是组员,验证小组 - if data['role'] == 'member': - if not data.get('group'): - return Response({ - "code": 400, - "message": "组员必须指定所属小组", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证小组是否存在且属于指定部门 - valid_groups = settings.DEPARTMENT_GROUPS.get(data['department'], []) - if data['group'] not in valid_groups: - return Response({ - "code": 400, - "message": f"无效的小组,{data['department']}的可选小组: {', '.join(valid_groups)}", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 检查用户名是否已存在 - if User.objects.filter(username=data['username']).exists(): - return Response({ - "code": 400, - "message": "用户名已存在", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 检查邮箱是否已存在 - if User.objects.filter(email=data['email']).exists(): - return Response({ - "code": 400, - "message": "邮箱已被注册", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证密码强度 - if len(data['password']) < 8: - return Response({ - "code": 400, - "message": "密码长度必须至少为8位", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证邮箱格式 - try: - validate_email(data['email']) - except ValidationError: - return Response({ - "code": 400, - "message": "邮箱格式不正确", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 创建用户 - user = User.objects.create_user( - username=data['username'], - email=data['email'], - password=data['password'], - role=data['role'], - department=data['department'], - name=data['name'], - group=data.get('group') if data['role'] == 'member' else None, - is_staff=False, - is_superuser=False - ) - - # 生成认证令牌 - token, _ = Token.objects.get_or_create(user=user) - - return Response({ - "code": 200, - "message": "注册成功", - "data": { - "id": user.id, - "username": user.username, - "email": user.email, - "role": user.role, - "department": user.department, - "name": user.name, - "group": user.group, - "token": token.key, - "created_at": user.date_joined.strftime('%Y-%m-%d %H:%M:%S') - } - }, status=status.HTTP_201_CREATED) - - except Exception as e: - print(f"注册失败: {str(e)}") - print(f"错误类型: {type(e)}") - print(f"错误堆栈: {traceback.format_exc()}") - return Response({ - "code": 500, - "message": f"注册失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@method_decorator(csrf_exempt, name='dispatch') -class LogoutView(APIView): - """用户登出视图""" - permission_classes = [IsAuthenticated] - - def post(self, request): - try: - # 删除用户的token - request.user.auth_token.delete() - # 执行django的登出 - logout(request) - - return Response({ - "code": 200, - "message": "登出成功", - "data": None - }) - except Exception as e: - return Response({ - "code": 500, - "message": f"登出失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@api_view(['GET', 'PUT']) -@permission_classes([IsAuthenticated]) -def user_profile(request): - """获取或更新用户信息""" - if request.method == 'GET': - data = { - 'id': request.user.id, - 'username': request.user.username, - 'email': request.user.email, - 'role': request.user.role, - 'department': request.user.department, - 'phone': request.user.phone, - 'date_joined': request.user.date_joined - } - return Response(data) - - elif request.method == 'PUT': - user = request.user - # 只允许更新特定字段 - allowed_fields = ['email', 'phone', 'department'] - for field in allowed_fields: - if field in request.data: - setattr(user, field, request.data[field]) - user.save() - return Response({'message': '用户信息更新成功'}) - -@csrf_exempt -@api_view(['POST']) -@permission_classes([IsAuthenticated]) -def change_password(request): - """修改密码""" - try: - old_password = request.data.get('old_password') - new_password = request.data.get('new_password') - - # 验证参数 - if not old_password or not new_password: - return Response({ - "code": 400, - "message": "请提供旧密码和新密码", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证旧密码 - user = request.user - if not user.check_password(old_password): - return Response({ - "code": 400, - "message": "旧密码错误", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证新密码长度 - if len(new_password) < 8: - return Response({ - "code": 400, - "message": "新密码长度必须至少为8位", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - # 修改密码 - user.set_password(new_password) - user.save() - - # 更新token - user.auth_token.delete() - token, _ = Token.objects.get_or_create(user=user) - - return Response({ - "code": 200, - "message": "密码修改成功", - "data": { - "token": token.key - } - }) - - except Exception as e: - return Response({ - "code": 500, - "message": f"密码修改失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@api_view(['POST']) -@permission_classes([AllowAny]) -def user_register(request): - """用户注册""" - try: - data = request.data - - # 检查必填字段 - required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] - for field in required_fields: - if not data.get(field): - return Response({ - 'error': f'缺少必填字段: {field}' - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证角色 - valid_roles = ['admin', 'leader', 'member'] - if data['role'] not in valid_roles: - return Response({ - 'error': f'无效的角色,必须是: {", ".join(valid_roles)}' - }, status=status.HTTP_400_BAD_REQUEST) - - # 如果是组员,必须指定小组 - if data['role'] == 'member' and not data.get('group'): - return Response({ - 'error': '组员必须指定所属小组' - }, status=status.HTTP_400_BAD_REQUEST) - - # 检查用户名是否已存在 - if User.objects.filter(username=data['username']).exists(): - return Response({ - 'error': '用户名已存在' - }, status=status.HTTP_400_BAD_REQUEST) - - # 检查邮箱是否已存在 - if User.objects.filter(email=data['email']).exists(): - return Response({ - 'error': '邮箱已被注册' - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证密码强度 - if len(data['password']) < 8: - return Response({ - 'error': '密码长度必须至少为8位' - }, status=status.HTTP_400_BAD_REQUEST) - - # 验证邮箱格式 - try: - validate_email(data['email']) - except ValidationError: - return Response({ - 'error': '邮箱格式不正确' - }, status=status.HTTP_400_BAD_REQUEST) - - # 创建用户 - user = User.objects.create_user( - username=data['username'], - email=data['email'], - password=data['password'], - role=data['role'], - department=data['department'], - name=data['name'], - group=data.get('group') if data['role'] == 'member' else None, - is_staff=False, - is_superuser=False - ) - - # 生成认证令牌 - token, _ = Token.objects.get_or_create(user=user) - - return Response({ - 'message': '注册成功', - 'data': { - 'id': user.id, - 'username': user.username, - 'email': user.email, - 'role': user.role, - 'department': user.department, - 'name': user.name, - 'group': user.group, - 'token': token.key, - 'created_at': user.date_joined.strftime('%Y-%m-%d %H:%M:%S') - } - }, status=status.HTTP_201_CREATED) - - except Exception as e: - print(f"注册失败: {str(e)}") - print(f"错误类型: {type(e)}") - print(f"错误堆栈: {traceback.format_exc()}") - return Response({ - 'error': f'注册失败: {str(e)}', - 'data': None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@csrf_exempt -@api_view(['POST']) -@permission_classes([IsAuthenticated]) -def verify_token(request): - """验证令牌有效性""" - try: - return Response({ - "code": 200, - "message": "令牌有效", - "data": { - "is_valid": True, - "user": { - "id": request.user.id, - "username": request.user.username, - "email": request.user.email, - "role": request.user.role, - "department": request.user.department, - "name": request.user.name, - "group": request.user.group - } - } - }) - except Exception as e: - return Response({ - "code": 500, - "message": f"验证失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@api_view(['GET']) -@permission_classes([IsAuthenticated]) -def user_list(request): - """获取用户列表""" - user = request.user - if user.role == 'admin': - users = User.objects.all() - elif user.role == 'leader': - users = User.objects.filter(department=user.department) - else: - users = User.objects.filter(id=user.id) - - data = [{ - 'id': u.id, - 'username': u.username, - 'email': u.email, - 'role': u.role, - 'department': u.department, - 'is_active': u.is_active, - 'date_joined': u.date_joined - } for u in users] - - return Response(data) - -@api_view(['GET']) -@permission_classes([IsAuthenticated]) -def user_detail(request, pk): - """获取用户详情""" - try: - # 尝试转换为 UUID - if not isinstance(pk, uuid.UUID): - try: - pk = uuid.UUID(pk) - except ValueError: - return Response({ - "code": 400, - "message": "无效的用户ID格式", - "data": None - }, status=status.HTTP_400_BAD_REQUEST) - - user = get_object_or_404(User, pk=pk) - - return Response({ - "code": 200, - "message": "获取用户信息成功", - "data": { - "id": str(user.id), - "username": user.username, - "email": user.email, - "name": user.name, - "role": user.role, - "department": user.department, - "group": user.group - } - }) - except Exception as e: - return Response({ - "code": 500, - "message": f"获取用户信息失败: {str(e)}", - "data": None - }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - -@api_view(['PUT']) -@permission_classes([IsAdminUser]) -def user_update(request, pk): - """更新用户信息""" - try: - user = User.objects.get(pk=pk) - # 只允许更新特定字段 - allowed_fields = ['email', 'role', 'department', 'is_active', 'phone'] - for field in allowed_fields: - if field in request.data: - setattr(user, field, request.data[field]) - user.save() - return Response({'message': '用户信息更新成功'}) - except User.DoesNotExist: - return Response({'message': '用户不存在'}, status=404) - -@api_view(['DELETE']) -@permission_classes([IsAdminUser]) -def user_delete(request, pk): - """删除用户""" - try: - user = User.objects.get(pk=pk) - user.delete() - return Response({'message': '用户删除成功'}) - except User.DoesNotExist: - return Response({'message': '用户不存在'}, status=404) - +from rest_framework import viewsets, status +from rest_framework.decorators import action, api_view, permission_classes +from rest_framework.permissions import IsAuthenticated, AllowAny, IsAdminUser +from rest_framework.response import Response +from rest_framework.exceptions import APIException, PermissionDenied, ValidationError, NotFound +from rest_framework.authentication import TokenAuthentication +from django.utils import timezone +from django.db import connection +from django.db.models import Q, Max, Count, F +from datetime import timedelta, datetime +import mysql.connector +from django.contrib.auth import get_user_model, authenticate, login, logout +from channels.layers import get_channel_layer +from asgiref.sync import async_to_sync +from rest_framework.authtoken.models import Token +import requests +import json +from django.db import transaction +from django.core.exceptions import ObjectDoesNotExist +import sys +import random +import string +import time +import logging +import os +from rest_framework.test import APIRequestFactory +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.fields import GenericForeignKey +from django.http import Http404, HttpResponse, StreamingHttpResponse +from django.db import IntegrityError +from channels.exceptions import ChannelFull +from django.conf import settings +from django.shortcuts import get_object_or_404 +from django.db import models +from rest_framework.views import APIView +from django.core.validators import validate_email +# from django.core.exceptions import ValidationError +from django.views.decorators.csrf import csrf_exempt +from django.utils.decorators import method_decorator +import uuid +from rest_framework import serializers +import traceback +import requests +import json +import threading + + + +# 添加模型导入 +from .models import ( + User, + Data, # 替换原来的 AdminData, LeaderData, MemberData + Permission, # 替换原来的 DataPermission, TablePermission + ChatHistory, + KnowledgeBase, + Notification, + KnowledgeBasePermission as KBPermissionModel, + KnowledgeBaseDocument +) +from .serializers import ( + UserSerializer, + DataSerializer, # 需要更新 + PermissionSerializer, # 需要更新 + ChatHistorySerializer, + KnowledgeBaseSerializer, + KnowledgePermissionSerializer, # 添加这个导入 + NotificationSerializer +) +# 导入自定义权限类 +from .permissions import ResourceCRUDPermission, PermissionRequestPermission, DataPermission, KnowledgeBasePermission as KBPermissionClass +from .exceptions import ExternalAPIError + +# 获取正确的用户模型 +User = get_user_model() + +logger = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[ + logging.StreamHandler() # 输出到控制台 + ] +) + + +class KnowledgeBasePermissionMixin: + """知识库权限管理混入类""" + + def _can_read(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): + """检查读取权限""" + try: + # 1. 检查显式权限表 + if knowledge_base_id: + permission = KBPermissionModel.objects.filter( + knowledge_base_id=knowledge_base_id, + user=user, + can_read=True, + status='active' + ).first() + if permission: + return True + + # 2. 检查角色权限 + # 私有知识库 + if type == 'private': + return str(user.id) == str(creator_id) + + # 成员级知识库 + if type == 'member': + return user.department == department + + # 部门级知识库 + if type == 'leader': + return (user.department == department and + user.role in ['leader', 'admin']) + + # 管理级知识库 + if type == 'admin': + return True # 所有用户都可以读取 + + return False + + except Exception as e: + logger.error(f"检查读取权限时出错: {str(e)}") + return False + + def _can_edit(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): + """检查编辑权限""" + try: + # 1. 检查显式权限表 + if knowledge_base_id: + permission = KBPermissionModel.objects.filter( + knowledge_base_id=knowledge_base_id, + user=user, + can_edit=True, + status='active' + ).first() + if permission: + return True + + # 2. 检查角色权限 + # 私有知识库 + if type == 'private': + return str(user.id) == str(creator_id) + + # 成员级知识库 + if type == 'member': + return (user.department == department and + user.role in ['leader', 'admin']) + + # 部门级知识库 + if type == 'leader': + return (user.department == department and + user.role in ['leader', 'admin']) + + # 管理级知识库 + if type == 'admin': + return True # 所有用户都可以编辑 + + return False + + except Exception as e: + logger.error(f"检查编辑权限时出错: {str(e)}") + return False + + def _can_delete(self, type, user, department=None, group=None, creator_id=None, knowledge_base_id=None): + """检查删除权限""" + try: + # 1. 检查显式权限表 + if knowledge_base_id: + permission = KBPermissionModel.objects.filter( + knowledge_base_id=knowledge_base_id, + user=user, + can_delete=True, + status='active' + ).first() + if permission: + return True + + # 2. 检查角色权限 + # 私有知识库 + if type == 'private': + return str(user.id) == str(creator_id) + + # 成员级知识库 + if type == 'member': + return (user.department == department and + user.role == 'admin') + + # 部门级知识库 + if type == 'leader': + return (user.department == department and + user.role == 'admin') + + # 管理级知识库 + if type == 'admin': + return True # 所有用户都可以删除 + + return False + + except Exception as e: + logger.error(f"检查删除权限时出错: {str(e)}") + return False + + def check_knowledge_base_permission(self, knowledge_base, user, required_permission='read'): + """统一的知识库权限检查方法""" + if not knowledge_base: + return False + + # 1. 首先检查显式权限表 + try: + # 检查是否存在显式权限记录 + permission = KBPermissionModel.objects.filter( + knowledge_base_id=knowledge_base.id, + user=user, + status='active' + ).first() + + if permission: + # 根据请求的权限类型返回对应的权限值 + if required_permission == 'read': + return permission.can_read + elif required_permission == 'edit': + return permission.can_edit + elif required_permission == 'delete': + return permission.can_delete + except Exception as e: + logger.error(f"检查显式权限时出错: {str(e)}") + + # 2. 如果没有显式权限记录或出错,回退到隐式权限逻辑 + permission_method = { + 'read': self._can_read, + 'edit': self._can_edit, + 'delete': self._can_delete + }.get(required_permission) + + if not permission_method: + return False + + return permission_method( + type=knowledge_base.type, + user=user, + department=knowledge_base.department, + group=knowledge_base.group, + creator_id=knowledge_base.user_id, + knowledge_base_id=knowledge_base.id + ) + + + +class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): + permission_classes = [IsAuthenticated] + queryset = ChatHistory.objects.all() + + def get_queryset(self): + """确保用户只能看到自己的未删除的聊天记录""" + return ChatHistory.objects.filter( + user=self.request.user, + is_deleted=False + ) + + def list(self, request): + """获取对话列表概览""" + try: + # 获取查询参数 + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + + # 获取所有对话的概览 + latest_chats = self.get_queryset().values( + 'conversation_id' + ).annotate( + latest_id=Max('id'), + message_count=Count('id'), + last_message=Max('created_at') + ).order_by('-last_message') + + # 计算分页 + total = latest_chats.count() + start = (page - 1) * page_size + end = start + page_size + chats = latest_chats[start:end] + + results = [] + for chat in chats: + # 获取最新消息记录 + latest_record = ChatHistory.objects.get(id=chat['latest_id']) + + # 从metadata中获取完整的知识库信息 + dataset_info = [] + if latest_record.metadata: + dataset_id_list = latest_record.metadata.get('dataset_id_list', []) + dataset_names = latest_record.metadata.get('dataset_names', []) + + # 如果有知识库ID列表 + if dataset_id_list: + # 如果同时有名称列表且长度匹配 + if dataset_names and len(dataset_names) == len(dataset_id_list): + dataset_info = [{ + 'id': str(id), + 'name': name + } for id, name in zip(dataset_id_list, dataset_names)] + else: + # 如果没有名称列表,则只返回ID + datasets = KnowledgeBase.objects.filter(id__in=dataset_id_list) + dataset_info = [{ + 'id': str(ds.id), + 'name': ds.name + } for ds in datasets] + + results.append({ + 'conversation_id': chat['conversation_id'], + 'message_count': chat['message_count'], + 'last_message': latest_record.content, + 'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'), + 'dataset_id_list': [ds['id'] for ds in dataset_info], # 添加完整的知识库ID列表 + 'datasets': dataset_info # 包含ID和名称的完整信息 + }) + + return Response({ + 'code': 200, + 'message': '获取成功', + 'data': { + 'total': total, + 'page': page, + 'page_size': page_size, + 'results': results + } + }) + + except Exception as e: + logger.error(f"获取聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取聊天记录失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def conversation_detail(self, request): + """获取特定对话的详细信息""" + try: + conversation_id = request.query_params.get('conversation_id') + if not conversation_id: + return Response({ + 'code': 400, + 'message': '缺少conversation_id参数', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取对话历史 + messages = self.get_queryset().filter( + conversation_id=conversation_id + ).order_by('created_at') + + if not messages.exists(): + return Response({ + 'code': 404, + 'message': '对话不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 获取知识库信息 + first_message = messages.first() + dataset_info = [] + if first_message and first_message.metadata: + if 'dataset_id_list' in first_message.metadata: + datasets = KnowledgeBase.objects.filter( + id__in=first_message.metadata['dataset_id_list'] + ) + # 过滤出用户有权限访问的知识库 + accessible_datasets = [ + ds for ds in datasets + if self.check_knowledge_base_permission(ds, request.user, 'read') + ] + dataset_info = [{ + 'id': str(ds.id), + 'name': ds.name, + 'type': ds.type + } for ds in accessible_datasets] + + return Response({ + 'code': 200, + 'message': '获取成功', + 'data': { + 'conversation_id': conversation_id, + 'datasets': dataset_info, + 'messages': [{ + 'id': str(msg.id), + 'role': msg.role, + 'content': msg.content, + 'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S') + } for msg in messages] + } + }) + + except Exception as e: + logger.error(f"获取对话详情失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取对话详情失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def available_datasets(self, request): + """获取用户可访问的知识库列表""" + try: + user = request.user + all_datasets = KnowledgeBase.objects.all() + + # 使用统一的权限检查方法 + accessible_datasets = [ + dataset for dataset in all_datasets + if self.check_knowledge_base_permission(dataset, user, 'read') + ] + + return Response({ + 'code': 200, + 'message': '获取成功', + 'data': [{ + 'id': str(ds.id), + 'name': ds.name, + 'type': ds.type, + 'department': ds.department, + 'description': ds.desc + } for ds in accessible_datasets] + }) + + except Exception as e: + logger.error(f"获取可用知识库列表失败: {str(e)}") + return Response({ + 'code': 500, + 'message': f'获取可用知识库列表失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['post']) + def create_conversation(self, request): + """创建会话 - 先选择知识库创建会话ID,不发送问题""" + try: + data = request.data + + # 检查知识库ID:支持dataset_id或dataset_id_list格式 + dataset_ids = [] + if 'dataset_id' in data: + dataset_id = data['dataset_id'] + # 直接使用标准UUID格式 + dataset_ids.append(str(dataset_id)) + elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)): + # 处理可能的字符串格式 + if isinstance(data['dataset_id_list'], str): + try: + # 尝试解析JSON字符串 + dataset_list = json.loads(data['dataset_id_list']) + if isinstance(dataset_list, list): + dataset_ids = [str(id) for id in dataset_list] + except json.JSONDecodeError: + # 如果解析失败,可能是单个ID + dataset_ids = [str(data['dataset_id_list'])] + else: + # 如果已经是列表,直接使用标准UUID格式 + dataset_ids = [str(id) for id in data['dataset_id_list']] + else: + return Response({ + 'code': 400, + 'message': '缺少必填字段: dataset_id 或 dataset_id_list', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + if not dataset_ids: + return Response({ + 'code': 400, + 'message': '至少需要提供一个知识库ID', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证所有知识库 + user = request.user + knowledge_bases = [] # 存储所有知识库对象 + + for kb_id in dataset_ids: + try: + knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first() + if not knowledge_base: + return Response({ + 'code': 404, + 'message': f'知识库不存在: {kb_id}', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + knowledge_bases.append(knowledge_base) + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(knowledge_base, user, 'read'): + return Response({ + 'code': 403, + 'message': f'无权访问知识库: {knowledge_base.name}', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + except Exception as e: + return Response({ + 'code': 400, + 'message': f'处理知识库ID出错: {str(e)}', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建一个新的会话ID + conversation_id = str(uuid.uuid4()) + logger.info(f"创建新的会话ID: {conversation_id}") + + # 准备metadata (仍然保存知识库名称用于内部处理) + metadata = { + 'dataset_id_list': [str(id) for id in dataset_ids], + 'dataset_names': [kb.name for kb in knowledge_bases] + } + + return Response({ + 'code': 200, + 'message': '会话创建成功', + 'data': { + 'conversation_id': conversation_id, + 'dataset_id_list': metadata['dataset_id_list'] + } + }) + + except Exception as e: + logger.error(f"创建会话失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'创建会话失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def create(self, request): + """创建聊天记录""" + try: + data = request.data + + # 检查必填字段 + if 'question' not in data: + return Response({ + 'code': 400, + 'message': '缺少必填字段: question', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + if 'conversation_id' not in data: + return Response({ + 'code': 400, + 'message': '缺少必填字段: conversation_id', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + conversation_id = data['conversation_id'] + + # 查找该会话ID下的历史记录,获取知识库信息 + existing_records = ChatHistory.objects.filter( + conversation_id=conversation_id + ).order_by('created_at') + + # 如果有历史记录,使用第一条记录的metadata + if existing_records.exists(): + first_record = existing_records.first() + metadata = first_record.metadata or {} + + # 获取知识库信息 + dataset_ids = metadata.get('dataset_id_list', []) + external_id_list = metadata.get('dataset_external_id_list', []) + + # 验证知识库是否存在且用户有权限 + knowledge_bases = [] + if not dataset_ids: + return Response({ + 'code': 400, + 'message': '找不到会话关联的知识库信息', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + for kb_id in dataset_ids: + try: + kb = KnowledgeBase.objects.get(id=kb_id) + if not self.check_knowledge_base_permission(kb, request.user, 'read'): + return Response({ + 'code': 403, + 'message': f'无权访问知识库: {kb.name}', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + knowledge_bases.append(kb) + except KnowledgeBase.DoesNotExist: + return Response({ + 'code': 404, + 'message': f'知识库不存在: {kb_id}', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + if not external_id_list or not knowledge_bases: + return Response({ + 'code': 400, + 'message': '会话关联的知识库信息不完整', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + else: + # 如果是新会话的第一条记录,需要提供知识库ID + dataset_ids = [] + if 'dataset_id' in data: + dataset_ids.append(str(data['dataset_id'])) + elif 'dataset_id_list' in data and isinstance(data['dataset_id_list'], (list, str)): + if isinstance(data['dataset_id_list'], str): + try: + dataset_list = json.loads(data['dataset_id_list']) + if isinstance(dataset_list, list): + dataset_ids = [str(id) for id in dataset_list] + else: + dataset_ids = [str(data['dataset_id_list'])] + except json.JSONDecodeError: + dataset_ids = [str(data['dataset_id_list'])] + else: + dataset_ids = [str(id) for id in data['dataset_id_list']] + + if not dataset_ids: + return Response({ + 'code': 400, + 'message': '新会话需要提供知识库ID', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证所有知识库并收集external_ids + external_id_list = [] + knowledge_bases = [] + + for kb_id in dataset_ids: + try: + knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first() + if not knowledge_base: + return Response({ + 'code': 404, + 'message': f'知识库不存在: {kb_id}', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + knowledge_bases.append(knowledge_base) + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'): + return Response({ + 'code': 403, + 'message': f'无权访问知识库: {knowledge_base.name}', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 添加知识库的external_id到列表 + if knowledge_base.external_id: + external_id_list.append(str(knowledge_base.external_id)) + else: + logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id") + + except Exception as e: + return Response({ + 'code': 400, + 'message': f'处理知识库ID出错: {str(e)}', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + if not external_id_list: + return Response({ + 'code': 400, + 'message': '没有有效的知识库external_id', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建metadata + metadata = { + 'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'), + 'dataset_id_list': [str(id) for id in dataset_ids], + 'dataset_external_id_list': [str(id) for id in external_id_list], + 'dataset_names': [kb.name for kb in knowledge_bases] + } + + # 创建用户问题记录 + question_record = ChatHistory.objects.create( + user=request.user, + knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库 + conversation_id=str(conversation_id), + role='user', + content=data['question'], + metadata=metadata + ) + + # 检查是否需要流式输出 + use_stream = data.get('stream', True) + + if use_stream: + # 创建流式响应 + return StreamingHttpResponse( + self._stream_answer_from_external_api( + conversation_id=str(conversation_id), + question_record=question_record, + dataset_external_id_list=external_id_list, + knowledge_bases=knowledge_bases, + question=data['question'], + metadata=metadata + ), + content_type='text/event-stream' + ) + else: + # 使用非流式输出 + logger.info("使用非流式输出模式") + # 调用同步 API 获取回答 + answer = self._get_answer_from_external_api(external_id_list, data['question']) + + if answer is None: + return Response({ + 'code': 500, + 'message': '获取回答失败', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 创建 AI 回答记录 + answer_record = ChatHistory.objects.create( + user=request.user, + knowledge_base=knowledge_bases[0], + conversation_id=str(conversation_id), + parent_id=str(question_record.id), + role='assistant', + content=answer, + metadata=metadata + ) + + return Response({ + 'code': 200, + 'message': '成功', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'dataset_id_list': metadata.get('dataset_id_list', []), + 'dataset_names': metadata.get('dataset_names', []), + 'role': 'assistant', + 'content': answer, + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S') + } + }) + + except Exception as e: + logger.error(f"创建聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'创建聊天记录失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _stream_answer_from_external_api(self, conversation_id, question_record, dataset_external_id_list, knowledge_bases, question, metadata): + """流式获取AI回答并实时返回 - 优化版本""" + try: + # 确保所有ID都是字符串 + dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list] + + # 创建AI回答记录对象,稍后更新内容 + answer_record = ChatHistory.objects.create( + user=question_record.user, + knowledge_base=knowledge_bases[0], + conversation_id=str(conversation_id), + parent_id=str(question_record.id), + role='assistant', + content="", # 初始内容为空 + metadata=metadata + ) + + # 发送初始响应告知客户端开始流式传输 + yield f"data: {json.dumps({'code': 200, 'message': '开始流式传输', 'data': {'id': str(answer_record.id), 'conversation_id': str(conversation_id), 'content': '', 'is_end': False}})}\n\n" + + # 异步收集完整内容,用于最后保存 + full_content = "" + + # 打开与外部API的连接 + logger.info(f"开始调用外部API,知识库ID列表: {dataset_external_ids}") + + try: + # 第一步: 创建聊天会话 + chat_response = requests.post( + url=f"{settings.API_BASE_URL}/api/application/chat/open", + json={ + "id": "d5d11efa-ea9a-11ef-9933-0242ac120006", + "model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006", + "dataset_id_list": dataset_external_ids, + "multiple_rounds_dialogue": False, + "dataset_setting": { + "top_n": 10, "similarity": "0.3", + "max_paragraph_char_number": 10000, + "search_mode": "blend", + "no_references_setting": { + "value": "{question}", + "status": "ai_questioning" + } + }, + "model_setting": { + "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}" + }, + "problem_optimization": False + }, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + if chat_response.status_code != 200: + error_msg = f"外部API调用失败: {chat_response.text}" + logger.error(error_msg) + yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n" + return + + chat_data = chat_response.json() + if chat_data.get('code') != 200 or not chat_data.get('data'): + error_msg = f"外部API返回错误: {chat_data}" + logger.error(error_msg) + yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n" + return + + chat_id = chat_data['data'] + logger.info(f"成功创建聊天会话, chat_id: {chat_id}") + + # 第二步: 建立流式连接 + message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}" + logger.info(f"开始流式请求: {message_url}") + + # 创建流式请求 + message_request = requests.post( + url=message_url, + json={"message": question, "re_chat": False, "stream": True}, + headers={"Content-Type": "application/json"}, + stream=True, # 启用流式传输 + timeout=60 + ) + + if message_request.status_code != 200: + error_msg = f"外部API聊天消息调用失败: {message_request.status_code}, {message_request.text}" + logger.error(error_msg) + yield f"data: {json.dumps({'code': 500, 'message': error_msg, 'data': {'is_end': True}})}\n\n" + return + + # 创建一个缓冲区以处理分段的数据 + buffer = "" + + # 读取并处理每个响应块 + logger.info("开始处理流式响应") + for chunk in message_request.iter_content(chunk_size=1): + if not chunk: + continue + + # 解码字节为字符串 + chunk_str = chunk.decode('utf-8') + buffer += chunk_str + + # 检查是否有完整的数据行 + if '\n\n' in buffer: + lines = buffer.split('\n\n') + # 除了最后一行,其他都是完整的 + for line in lines[:-1]: + # 处理完整的数据行 + if line.startswith('data: '): + try: + # 提取JSON数据 + json_str = line[6:] # 去掉 "data: " 前缀 + data = json.loads(json_str) + + # 记录并处理部分响应 + if 'content' in data: + content_part = data['content'] + full_content += content_part + + # 构建响应数据 + response_data = { + 'code': 200, + 'message': 'partial', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'content': content_part, + 'is_end': data.get('is_end', False) + } + } + + # 立即发送每个部分到客户端 + yield f"data: {json.dumps(response_data)}\n\n" + + # 处理结束标记 + if data.get('is_end', False): + logger.info("收到流式响应结束标记") + # 异步保存完整内容 + answer_record.content = full_content.strip() + answer_record.save() + + # 发送完整内容的最终响应 + final_response = { + 'code': 200, + 'message': '完成', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'dataset_id_list': metadata.get('dataset_id_list', []), + 'dataset_names': metadata.get('dataset_names', []), + 'role': 'assistant', + 'content': full_content.strip(), + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'is_end': True + } + } + yield f"data: {json.dumps(final_response)}\n\n" + return # 结束生成器 + except json.JSONDecodeError as e: + logger.error(f"JSON解析错误: {e}, 数据: {line}") + # 继续处理,跳过此行 + + # 保留最后一个可能不完整的行 + buffer = lines[-1] + + # 处理最后可能剩余的缓冲数据 + if buffer: + logger.info(f"处理剩余缓冲数据: {buffer}") + if buffer.startswith('data: '): + try: + json_str = buffer[6:] # 去掉 "data: " 前缀 + data = json.loads(json_str) + + if 'content' in data: + content_part = data['content'] + full_content += content_part + + response_data = { + 'code': 200, + 'message': 'partial', + 'data': { + 'id': str(answer_record.id), + 'conversation_id': str(conversation_id), + 'content': content_part, + 'is_end': data.get('is_end', False) + } + } + yield f"data: {json.dumps(response_data)}\n\n" + except json.JSONDecodeError: + logger.error(f"处理剩余数据时JSON解析错误: {buffer}") + + # 确保在流结束时保存内容到数据库 + if full_content: + answer_record.content = full_content.strip() + answer_record.save() + logger.info(f"流结束,保存完整内容到数据库: {len(full_content)} 字符") + + except requests.exceptions.RequestException as e: + logger.error(f"请求外部API时发生错误: {str(e)}") + yield f"data: {json.dumps({'code': 500, 'message': f'请求外部API时发生错误: {str(e)}', 'data': {'is_end': True}})}\n\n" + + except Exception as e: + logger.error(f"流式处理出错: {str(e)}") + logger.error(traceback.format_exc()) + yield f"data: {json.dumps({'code': 500, 'message': f'流式处理出错: {str(e)}', 'data': {'is_end': True}})}\n\n" + + # 尝试保存已收集的内容 + if 'full_content' in locals() and full_content: + try: + answer_record.content = full_content.strip() + answer_record.save() + except Exception as save_error: + logger.error(f"保存部分内容失败: {str(save_error)}") + + def _get_answer_from_external_api(self, dataset_external_id_list, question): + """调用外部API获取AI回答(非流式版本)""" + try: + # 确保所有ID都是字符串 + dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list] + + logger.info(f"准备调用外部API(非流式模式),知识库ID列表: {dataset_external_ids}") + + # 第一个API调用创建聊天 + chat_request_data = { + "id": "d5d11efa-ea9a-11ef-9933-0242ac120006", + "model_id": "7a214d0e-e65e-11ef-9f4a-0242ac120006", + "dataset_id_list": dataset_external_ids, + "multiple_rounds_dialogue": False, + "dataset_setting": { + "top_n": 10, + "similarity": "0.3", + "max_paragraph_char_number": 10000, + "search_mode": "blend", + "no_references_setting": { + "value": "{question}", + "status": "ai_questioning" + } + }, + "model_setting": { + "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}" + }, + "problem_optimization": False + } + + logger.info(f"发送创建聊天请求:{settings.API_BASE_URL}/api/application/chat/open") + + try: + # 测试JSON序列化,提前捕获可能的错误 + json_data = json.dumps(chat_request_data) + logger.debug(f"请求数据序列化成功,长度: {len(json_data)}") + except TypeError as e: + logger.error(f"JSON序列化失败: {str(e)}") + return None + + chat_response = requests.post( + url=f"{settings.API_BASE_URL}/api/application/chat/open", + json=chat_request_data, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + logger.info(f"API响应状态码: {chat_response.status_code}") + + if chat_response.status_code != 200: + logger.error(f"外部API调用失败: {chat_response.text}") + return None + + chat_data = chat_response.json() + logger.debug(f"API响应数据: {chat_data}") + + if chat_data.get('code') != 200 or not chat_data.get('data'): + logger.error(f"外部API返回错误: {chat_data}") + return None + + chat_id = chat_data['data'] + logger.info(f"聊天创建成功,chat_id: {chat_id}") + + # 第二个API调用发送消息 + message_request_data = { + "message": question, + "re_chat": False, + "stream": False # 设置为非流式 + } + + logger.info(f"发送聊天消息请求(非流式): {settings.API_BASE_URL}/api/application/chat_message/{chat_id}") + message_response = requests.post( + url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}", + json=message_request_data, + headers={"Content-Type": "application/json"}, + timeout=60 + ) + + if message_response.status_code != 200: + logger.error(f"外部API聊天消息调用失败: {message_response.status_code}, {message_response.text}") + return None + + # 处理非流式响应 + try: + response_data = message_response.json() + logger.debug(f"非流式响应数据: {response_data}") + + if response_data.get('code') != 200 or 'data' not in response_data: + logger.error(f"外部API返回错误: {response_data}") + return None + + # 提取回答内容 + answer_content = response_data.get('data', {}).get('content', '') + if not answer_content: + logger.warning("API返回的回答内容为空") + return "无法获取回答内容" + + return answer_content + + except json.JSONDecodeError as e: + logger.error(f"解析API响应JSON失败: {str(e)}") + return None + except Exception as e: + logger.error(f"处理API响应失败: {str(e)}") + logger.error(traceback.format_exc()) + return None + + except Exception as e: + logger.error(f"调用外部API获取回答失败: {str(e)}") + logger.error(traceback.format_exc()) + return None + + def update(self, request, pk=None): + """更新聊天记录""" + try: + record = self.get_queryset().filter(id=pk).first() + + if not record: + return Response({ + 'code': 404, + 'message': '记录不存在或无权限', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + data = request.data + updateable_fields = ['content', 'metadata'] + + if 'content' in data: + record.content = data['content'] + + if 'metadata' in data: + current_metadata = record.metadata or {} + current_metadata.update(data['metadata']) + record.metadata = current_metadata + + record.save() + + return Response({ + 'code': 200, + 'message': '更新成功', + 'data': { + 'id': record.id, + 'conversation_id': record.conversation_id, + 'role': record.role, + 'content': record.content, + 'metadata': record.metadata, + 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + } + }) + + except Exception as e: + logger.error(f"更新聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'更新聊天记录失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def destroy(self, request, pk=None): + """删除聊天记录(软删除)""" + try: + record = self.get_queryset().filter(id=pk).first() + + if not record: + return Response({ + 'code': 404, + 'message': '记录不存在或无权限', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + record.soft_delete() + + return Response({ + 'code': 200, + 'message': '删除成功', + 'data': None + }) + + except Exception as e: + logger.error(f"删除聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'删除聊天记录失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def search(self, request): + """搜索聊天记录""" + try: + # 获取查询参数 + keyword = request.query_params.get('keyword', '').strip() + dataset_id = request.query_params.get('dataset_id') + start_date = request.query_params.get('start_date') + end_date = request.query_params.get('end_date') + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + + # 基础查询 + query = self.get_queryset() + + # 添加过滤条件 + if keyword: + query = query.filter( + Q(content__icontains=keyword) | + Q(knowledge_base__name__icontains=keyword) + ) + + if dataset_id: + # 检查知识库权限 + knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first() + if knowledge_base and not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'): + return Response({ + 'code': 403, + 'message': '无权访问该知识库', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + query = query.filter(knowledge_base__id=dataset_id) + if start_date: + query = query.filter(created_at__gte=start_date) + if end_date: + query = query.filter(created_at__lte=end_date) + + # 计算分页 + total = query.count() + start = (page - 1) * page_size + end = start + page_size + + # 获取分页数据 + records = query.order_by('-created_at')[start:end] + + # 序列化数据 + results = [] + for record in records: + result = { + 'id': record.id, + 'conversation_id': record.conversation_id, + 'dataset_id': str(record.knowledge_base.id), + 'dataset_name': record.knowledge_base.name, + 'role': record.role, + 'content': record.content, + 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'metadata': record.metadata + } + + if keyword: + result['highlights'] = { + 'content': self._highlight_keyword(record.content, keyword) + } + + results.append(result) + + return Response({ + 'code': 200, + 'message': '搜索成功', + 'data': { + 'total': total, + 'page': page, + 'page_size': page_size, + 'results': results + } + }) + + except Exception as e: + logger.error(f"搜索聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'搜索失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def export(self, request): + """导出聊天记录为Excel文件""" + try: + # 获取查询参数 + conversation_id = request.query_params.get('conversation_id') + dataset_id = request.query_params.get('dataset_id') + history_days = request.query_params.get('history_days', '7') # 默认导出最近7天 + + # 至少需要一个筛选条件 + if not conversation_id and not dataset_id: + return Response({ + 'code': 400, + 'message': '需要提供conversation_id或dataset_id参数', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证权限 + user = request.user + if dataset_id: + knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first() + if not knowledge_base: + return Response({ + 'code': 404, + 'message': '知识库不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(knowledge_base, user, 'read'): + return Response({ + 'code': 403, + 'message': '无权访问该知识库', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 查询确认有聊天记录存在 + query = self.get_queryset() + if conversation_id: + records = query.filter(conversation_id=conversation_id) + elif dataset_id: + records = query.filter(knowledge_base__id=dataset_id) + + if not records.exists(): + return Response({ + 'code': 404, + 'message': '未找到相关对话记录', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 调用外部API导出Excel文件 - 使用GET请求 + application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006" # 固定值 + export_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat/export?history_day={history_days}" + + logger.info(f"发送导出请求:{export_url}") + + export_response = requests.get( + url=export_url, + timeout=60, + stream=True # 使用流式传输处理大文件 + ) + + # 检查响应状态 + if export_response.status_code != 200: + logger.error(f"导出API调用失败: {export_response.status_code}, {export_response.text}") + return Response({ + 'code': 500, + 'message': '导出失败,外部服务返回错误', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 创建响应对象并设置文件下载头 + response = HttpResponse( + content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + ) + response['Content-Disposition'] = 'attachment; filename="data.xlsx"' + + # 将API响应内容写入响应对象 + for chunk in export_response.iter_content(chunk_size=8192): + if chunk: + response.write(chunk) + + logger.info("导出成功完成") + return response + + except Exception as e: + logger.error(f"导出聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'导出聊天记录失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def chat_list(self, request): + """获取对话列表""" + try: + # 获取查询参数 + history_days = request.query_params.get('history_days', '7') # 默认7天 + + # 构建API请求 + application_id = "d5d11efa-ea9a-11ef-9933-0242ac120006" + api_url = f"{settings.API_BASE_URL}/api/application/{application_id}/chat" + + # 添加查询参数 + params = { + 'history_day': history_days + } + + logger.info(f"发送获取对话列表请求:{api_url}") + + # 调用外部API + response = requests.get( + url=api_url, + params=params, + timeout=30 + ) + + if response.status_code != 200: + logger.error(f"获取对话列表失败: {response.status_code}, {response.text}") + return Response({ + 'code': 500, + 'message': '获取对话列表失败,外部服务返回错误', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 解析响应数据 + try: + result = response.json() + + if result.get('code') != 200: + logger.error(f"外部API返回错误: {result}") + return Response({ + 'code': result.get('code', 500), + 'message': result.get('message', '获取对话列表失败'), + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 处理返回的数据 + chat_list = result.get('data', []) + + # 格式化返回数据 + formatted_chats = [] + for chat in chat_list: + formatted_chat = { + 'id': chat['id'], + 'chat_id': chat['chat_id'], + 'abstract': chat['abstract'], + 'message_count': chat['chat_record_count'], + 'created_at': datetime.fromisoformat(chat['create_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'), + 'updated_at': datetime.fromisoformat(chat['update_time'].replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S'), + 'star_count': chat['star_num'], + 'trample_count': chat['trample_num'], + 'mark_sum': chat['mark_sum'], + 'is_deleted': chat['is_deleted'] + } + formatted_chats.append(formatted_chat) + + return Response({ + 'code': 200, + 'message': '获取成功', + 'data': { + 'total': len(formatted_chats), + 'results': formatted_chats + } + }) + + except json.JSONDecodeError as e: + logger.error(f"解析响应数据失败: {str(e)}") + return Response({ + 'code': 500, + 'message': '解析响应数据失败', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + except Exception as e: + logger.error(f"获取对话列表失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取对话列表失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['post']) + def hit_test(self, request): + """获取问题与知识库文档的匹配度""" + try: + data = request.data + + # 检查必填字段 + if 'question' not in data: + return Response({ + 'code': 400, + 'message': '缺少必填字段: question', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + if 'dataset_id_list' not in data or not data['dataset_id_list']: + return Response({ + 'code': 400, + 'message': '缺少必填字段: dataset_id_list', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + question = data['question'] + dataset_ids = data['dataset_id_list'] + + # 如果不是列表,转换为列表 + if not isinstance(dataset_ids, list): + try: + dataset_ids = json.loads(dataset_ids) + if not isinstance(dataset_ids, list): + dataset_ids = [dataset_ids] + except (json.JSONDecodeError, TypeError): + dataset_ids = [dataset_ids] + + # 检查用户是否有权限访问这些知识库 + external_id_list = [] + for kb_id in dataset_ids: + try: + kb = KnowledgeBase.objects.get(id=kb_id) + if not self.check_knowledge_base_permission(kb, request.user, 'read'): + return Response({ + 'code': 403, + 'message': f'无权访问知识库: {kb.name}', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + if kb.external_id: + external_id_list.append(str(kb.external_id)) + else: + logger.warning(f"知识库 {kb.id} ({kb.name}) 没有external_id") + except KnowledgeBase.DoesNotExist: + return Response({ + 'code': 404, + 'message': f'知识库不存在: {kb_id}', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + if not external_id_list: + return Response({ + 'code': 400, + 'message': '没有有效的知识库external_id', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取所有知识库的匹配文档 + all_documents = [] + for dataset_id in external_id_list: + doc_info = self._call_hit_test_api(dataset_id, question) + if doc_info: + all_documents.extend(doc_info) + + # 按相似度排序 + all_documents = sorted(all_documents, key=lambda x: x.get('similarity', 0), reverse=True) + + # 返回结果 + return Response({ + 'code': 200, + 'message': '成功', + 'data': { + 'question': question, + 'matched_documents': all_documents, + 'total_count': len(all_documents) + } + }) + + except Exception as e: + logger.error(f"hit_test接口调用失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'hit_test接口调用失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _highlight_keyword(self, text, keyword): + """高亮关键词""" + if not keyword or not text: + return text + return text.replace( + keyword, + f'{keyword}' + ) + + def _call_hit_test_api(self, dataset_id, query_text): + """调用知识库hit_test接口获取相关文档信息""" + try: + url = f"{settings.API_BASE_URL}/api/dataset/{dataset_id}/hit_test" + params = { + "query_text": query_text, + "top_number": 10, + "similarity": 0.3, + "search_mode": "blend" + } + + logger.info(f"调用hit_test接口: {url}, 参数: {params}") + + response = requests.get( + url=url, + params=params, + timeout=30 + ) + + if response.status_code != 200: + logger.error(f"hit_test接口调用失败: {response.status_code}, {response.text}") + return None + + result = response.json() + if result.get('code') != 200: + logger.error(f"hit_test接口业务错误: {result}") + return None + + # 提取文档信息 + documents = result.get('data', []) + logger.info(f"hit_test接口返回 {len(documents)} 个相关文档") + + # 提取文档名称和相似度等信息 + doc_info = [] + for doc in documents: + doc_info.append({ + "document_name": doc.get("document_name", ""), + "dataset_name": doc.get("dataset_name", ""), + "similarity": doc.get("similarity", 0), + "comprehensive_score": doc.get("comprehensive_score", 0) + }) + + return doc_info + except Exception as e: + logger.error(f"调用hit_test接口失败: {str(e)}") + logger.error(traceback.format_exc()) + return None + + @action(detail=False, methods=['delete']) + def delete_conversation(self, request): + """通过conversation_id删除一组会话""" + try: + # 获取conversation_id + conversation_id = request.query_params.get('conversation_id') + if not conversation_id: + return Response({ + 'code': 400, + 'message': '缺少必要参数: conversation_id', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 查找该会话下的所有记录 + records = self.get_queryset().filter(conversation_id=conversation_id) + + if not records.exists(): + return Response({ + 'code': 404, + 'message': '未找到该会话或无权限访问', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 获取记录数量 + records_count = records.count() + + # 批量软删除 + for record in records: + record.soft_delete() + + return Response({ + 'code': 200, + 'message': '删除成功', + 'data': { + 'conversation_id': conversation_id, + 'deleted_count': records_count + } + }) + + except Exception as e: + logger.error(f"删除会话失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'删除会话失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet): + serializer_class = KnowledgeBaseSerializer + permission_classes = [IsAuthenticated] + + def list(self, request, *args, **kwargs): + try: + queryset = self.get_queryset() + + # 获取搜索关键字 + keyword = request.query_params.get('keyword', '') + + # 如果有关键字,构建搜索条件 + if keyword: + query = Q(name__icontains=keyword) | \ + Q(desc__icontains=keyword) | \ + Q(department__icontains=keyword) | \ + Q(group__icontains=keyword) + queryset = queryset.filter(query) + + # 获取分页参数 + try: + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + except ValueError: + page = 1 + page_size = 10 + + # 计算总数量 + total = queryset.count() + + # 分页处理 + start = (page - 1) * page_size + end = start + page_size + paginated_queryset = queryset[start:end] + + # 序列化知识库数据 + serializer = self.get_serializer(paginated_queryset, many=True) + data = serializer.data + + # 为每个知识库添加权限信息 + user = request.user + for item in data: + # 获取必要的知识库属性 + kb_type = item['type'] + department = item.get('department') + group = item.get('group') + creator_id = item.get('user_id') + kb_id = item['id'] + + # 首先检查权限表中的显式权限 + explicit_permission = KBPermissionModel.objects.filter( + knowledge_base_id=kb_id, + user=user, + status='active' + ).first() + + if explicit_permission: + item['permissions'] = { + 'can_read': explicit_permission.can_read, + 'can_edit': explicit_permission.can_edit, + 'can_delete': explicit_permission.can_delete + } + # 添加知识库的到期时间 + item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None + else: + # 没有显式权限时使用统一的权限判断方法 + item['permissions'] = { + 'can_read': self._can_read(kb_type, user, department, group, creator_id, kb_id), + 'can_edit': self._can_edit(kb_type, user, department, group, creator_id, kb_id), + 'can_delete': self._can_delete(kb_type, user, department, group, creator_id, kb_id) + } + # 对于admin类型的知识库,设置expires_at为None + if kb_type == 'admin': + item['expires_at'] = None + else: + # 对于其他类型,如果没有显式权限记录,则表示没有到期时间 + item['expires_at'] = None + + # 处理高亮 + if keyword: + if 'name' in item and keyword.lower() in item['name'].lower(): + item['highlighted_name'] = item['name'].replace( + keyword, f'{keyword}' + ) + + if 'desc' in item and item.get('desc') is not None: + desc_text = str(item['desc']) + if keyword.lower() in desc_text.lower(): + item['highlighted_desc'] = desc_text.replace( + keyword, f'{keyword}' + ) + + return Response({ + "code": 200, + "message": "获取知识库列表成功", + "data": { + "total": total, + "page": page, + "page_size": page_size, + "keyword": keyword if keyword else None, + "items": data + } + }) + except Exception as e: + logger.error(f"获取知识库列表失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"获取知识库列表失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def get_queryset(self): + """获取用户有权限查看的知识库列表""" + user = self.request.user + queryset = KnowledgeBase.objects.all() + + # 1. 构建基础权限条件 + permission_conditions = Q() + + # 2. 所有用户都可以看到 admin 类型的知识库 + permission_conditions |= Q(type='admin') + + # 3. 用户可以看到自己创建的所有知识库 + permission_conditions |= Q(user_id=user.id) + + # 4. 添加显式权限条件 + # 获取所有活跃的权限记录 + active_permissions = KBPermissionModel.objects.filter( + user=user, + can_read=True, + status='active', + expires_at__gt=timezone.now() + ).values_list('knowledge_base_id', flat=True) + + if active_permissions: + permission_conditions |= Q(id__in=active_permissions) + + # 5. 根据用户角色添加隐式权限 + if user.role == 'admin': + # 管理员可以看到除了其他用户 private 类型外的所有知识库 + permission_conditions |= ~Q(type='private') | Q(user_id=user.id) + elif user.role == 'leader': + # 组长可以查看本部门的 leader 和 member 类型知识库 + permission_conditions |= Q( + type__in=['leader', 'member'], + department=user.department + ) + elif user.role in ['member', 'user']: + # 成员可以查看本部门的 leader 类型知识库 + permission_conditions |= Q( + type='leader', + department=user.department + ) + # 成员可以查看本部门本组的 member 类型知识库 + permission_conditions |= Q( + type='member', + department=user.department, + group=user.group + ) + + return queryset.filter(permission_conditions).distinct() + + def create(self, request, *args, **kwargs): + try: + # 1. 验证知识库名称 + name = request.data.get('name') + if not name: + return Response({ + 'code': 400, + 'message': '知识库名称不能为空', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + if KnowledgeBase.objects.filter(name=name).exists(): + return Response({ + 'code': 400, + 'message': f'知识库名称 "{name}" 已存在', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 2. 验证用户权限和必填字段 + user = request.user + type = request.data.get('type', 'private') + department = request.data.get('department') + group = request.data.get('group') + + # 修改权限验证 + if type == 'admin': + # 移除管理员权限检查,允许所有用户创建 + department = None + group = None + + elif type == 'secret': + if user.role != 'admin': + return Response({ + 'code': 403, + 'message': '只有管理员可以创建保密级知识库', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + department = None + group = None + + elif type == 'leader': + if user.role != 'admin': + return Response({ + 'code': 403, + 'message': '只有管理员可以创建组长级知识库', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + if not department: + return Response({ + 'code': 400, + 'message': '创建组长级知识库时必须指定部门', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + elif type == 'member': + if user.role not in ['admin', 'leader']: + return Response({ + 'code': 403, + 'message': '只有管理员和组长可以创建成员级知识库', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + if user.role == 'admin' and not department: + return Response({ + 'code': 400, + 'message': '管理员创建成员知识库时必须指定部门', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + elif user.role == 'leader': + department = user.department + + if not group: + return Response({ + 'code': 400, + 'message': '创建成员知识库时必须指定组', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + elif type == 'private': + # 对于private类型,不保存department和group + department = None + group = None + + # 3. 验证请求数据 + data = request.data.copy() + data['department'] = department + data['group'] = group + + # 不需要手动设置 user_id,由序列化器自动处理 + serializer = self.get_serializer(data=data) + if not serializer.is_valid(): + logger.error(f"数据验证失败: {serializer.errors}") + return Response({ + 'code': 400, + 'message': '数据验证失败', + 'data': serializer.errors + }, status=status.HTTP_400_BAD_REQUEST) + + with transaction.atomic(): + # 4. 创建知识库 + try: + knowledge_base = serializer.save() + logger.info(f"知识库创建成功: id={knowledge_base.id}, name={knowledge_base.name}, user_id={knowledge_base.user_id}") + except Exception as e: + logger.error(f"知识库创建失败: {str(e)}") + raise + + # 5. 调用外部API创建知识库 + try: + external_id = self._create_external_dataset(knowledge_base) + logger.info(f"外部知识库创建成功,获取ID: {external_id}") + + # 保存外部知识库ID + knowledge_base.external_id = external_id + knowledge_base.save() + logger.info(f"更新knowledge_base的external_id为: {external_id}") + + except ExternalAPIError as e: + logger.error(f"外部知识库创建失败: {str(e)}") + raise + + # 6. 创建权限记录 + try: + # 创建者权限 + KBPermissionModel.objects.create( + knowledge_base=knowledge_base, + user=request.user, + can_read=True, + can_edit=True, + can_delete=True, + granted_by=request.user, + status='active' + ) + logger.info(f"创建者权限创建成功") + + # 根据类型批量创建其他用户权限 + permissions = [] + if type == 'admin': + users_query = User.objects.exclude(id=request.user.id) + # 为所有用户赋予完全权限(读、写、删) + permissions = [ + KBPermissionModel( + knowledge_base=knowledge_base, + user=user, + can_read=True, + can_edit=True, + can_delete=True, + granted_by=request.user, + status='active' + ) for user in users_query + ] + elif type == 'secret': + users_query = User.objects.filter(role='admin').exclude(id=request.user.id) + permissions = [ + KBPermissionModel( + knowledge_base=knowledge_base, + user=user, + can_read=True, + can_edit=self._can_edit(type, user), + can_delete=self._can_delete(type, user), + granted_by=request.user, + status='active' + ) for user in users_query + ] + elif type == 'leader': + users_query = User.objects.filter( + Q(role='admin') | + Q(role='leader', department=department) + ).exclude(id=request.user.id) + permissions = [ + KBPermissionModel( + knowledge_base=knowledge_base, + user=user, + can_read=True, + can_edit=self._can_edit(type, user), + can_delete=self._can_delete(type, user), + granted_by=request.user, + status='active' + ) for user in users_query + ] + elif type == 'member': + users_query = User.objects.filter( + Q(role='admin') | + Q(department=department, role='leader') | + Q(department=department, group=group, role='member') + ).exclude(id=request.user.id) + permissions = [ + KBPermissionModel( + knowledge_base=knowledge_base, + user=user, + can_read=True, + can_edit=self._can_edit(type, user), + can_delete=self._can_delete(type, user), + granted_by=request.user, + status='active' + ) for user in users_query + ] + else: # private + users_query = User.objects.none() + + if permissions: + KBPermissionModel.objects.bulk_create(permissions) + logger.info(f"{type}类型权限创建完成: {len(permissions)}条记录") + + except Exception as e: + logger.error(f"权限创建失败: {str(e)}") + logger.error(traceback.format_exc()) + raise + + return Response({ + 'code': 200, + 'message': '知识库创建成功', + 'data': { + 'knowledge_base': serializer.data, + 'external_id': knowledge_base.external_id + } + }) + + except Exception as e: + logger.error(f"创建知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'创建知识库失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def update(self, request, *args, **kwargs): + """更新知识库""" + try: + instance = self.get_object() + user = request.user + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(instance, user, 'edit'): + return Response({ + "code": 403, + "message": "没有编辑权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + with transaction.atomic(): + # 执行本地更新 + serializer = self.get_serializer(instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + # 更新外部知识库 + if instance.external_id: + try: + api_data = { + "name": serializer.validated_data.get('name', instance.name), + "desc": serializer.validated_data.get('desc', instance.desc), + "type": "0", # 保持与创建时一致 + "meta": {}, # 保持与创建时一致 + "documents": [] # 保持与创建时一致 + } + + response = requests.put( + f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}', + json=api_data, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code != 200: + raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}") + + api_response = response.json() + if not api_response.get('code') == 200: + raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}") + + logger.info(f"外部知识库更新成功: {instance.external_id}") + + except requests.exceptions.Timeout: + raise ExternalAPIError("请求超时,请稍后重试") + except requests.exceptions.RequestException as e: + raise ExternalAPIError(f"API请求失败: {str(e)}") + except Exception as e: + raise ExternalAPIError(f"更新外部知识库失败: {str(e)}") + + return Response({ + "code": 200, + "message": "知识库更新成功", + "data": serializer.data + }) + + except Http404: + return Response({ + "code": 404, + "message": "知识库不存在", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + except ExternalAPIError as e: + logger.error(f"更新外部知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": str(e), + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + except Exception as e: + logger.error(f"更新知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"更新知识库失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def destroy(self, request, *args, **kwargs): + """删除知识库""" + try: + instance = self.get_object() + user = request.user + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(instance, user, 'delete'): + return Response({ + "code": 403, + "message": "没有删除权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + with transaction.atomic(): + # 删除外部知识库 + if instance.external_id: + try: + self._delete_external_dataset(instance.external_id) + logger.info(f"外部知识库删除成功: {instance.external_id}") + except ExternalAPIError as e: + logger.error(f"删除外部知识库失败: {str(e)}") + return Response({ + "code": 500, + "message": f"删除外部知识库失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 删除本地知识库 + self.perform_destroy(instance) + logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}") + + return Response({ + "code": 200, + "message": "知识库删除成功", + "data": None + }) + + except Http404: + return Response({ + "code": 404, + "message": "知识库不存在", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + logger.error(f"删除知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"删除知识库失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=True, methods=['get']) + def permissions(self, request, pk=None): + """获取用户对特定知识库的权限""" + try: + instance = self.get_object() + user = request.user + + # 使用统一的权限检查方法 + permissions_data = { + "can_read": self.check_knowledge_base_permission(instance, user, 'read'), + "can_edit": self.check_knowledge_base_permission(instance, user, 'edit'), + "can_delete": self.check_knowledge_base_permission(instance, user, 'delete') + } + + return Response({ + "code": 200, + "message": "获取权限信息成功", + "data": { + "knowledge_base_id": instance.id, + "knowledge_base_name": instance.name, + "permissions": permissions_data + } + }) + + except Http404: + return Response({ + "code": 404, + "message": "知识库不存在", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + logger.error(f"获取权限信息失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"获取权限信息失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + @action(detail=False, methods=['get']) + def summary(self, request): + """获取所有可见知识库的概要信息(除了secret类型)""" + try: + user = request.user + + # 基础查询:排除secret类型的知识库 + queryset = KnowledgeBase.objects.exclude(type='secret') + + summaries = [] + for kb in queryset: + # 使用统一的权限判断方法 + permissions = { + 'can_read': self.check_knowledge_base_permission(kb, user, 'read'), + 'can_edit': self.check_knowledge_base_permission(kb, user, 'edit'), + 'can_delete': self.check_knowledge_base_permission(kb, user, 'delete') + } + + # 获取知识库到期时间 + explicit_permission = KBPermissionModel.objects.filter( + knowledge_base_id=kb.id, + user=user, + status='active' + ).first() + + expires_at = None + if explicit_permission: + expires_at = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None + elif kb.type == 'admin': + expires_at = None + + # 只返回概要信息 + summary = { + 'id': str(kb.id), + 'name': kb.name, + 'desc': kb.desc, + 'type': kb.type, + 'department': kb.department, + 'permissions': permissions, + 'expires_at': expires_at + } + summaries.append(summary) + + return Response({ + 'code': 200, + 'message': '获取知识库概要信息成功', + 'data': summaries + }) + + except Exception as e: + return Response({ + 'code': 500, + 'message': f'获取知识库概要信息失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def retrieve(self, request, *args, **kwargs): + try: + # 获取知识库对象 + instance = self.get_object() + serializer = self.get_serializer(instance) + data = serializer.data + + # 获取用户 + user = request.user + + # 使用统一的权限判断方法 + data['permissions'] = { + 'can_read': self.check_knowledge_base_permission(instance, user, 'read'), + 'can_edit': self.check_knowledge_base_permission(instance, user, 'edit'), + 'can_delete': self.check_knowledge_base_permission(instance, user, 'delete') + } + + # 添加知识库到期时间 + explicit_permission = KBPermissionModel.objects.filter( + knowledge_base_id=instance.id, + user=user, + status='active' + ).first() + + if explicit_permission: + data['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None + else: + # 对于admin类型的知识库,设置expires_at为None + if instance.type == 'admin': + data['expires_at'] = None + else: + # 对于其他类型,如果没有显式权限记录,则表示没有到期时间 + data['expires_at'] = None + + return Response({ + 'code': 200, + 'message': '获取知识库详情成功', + 'data': data + }) + except Exception as e: + logger.error(f"获取知识库详情失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取知识库详情失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def search(self, request): + """搜索知识库功能""" + try: + # 获取搜索关键字 + keyword = request.query_params.get('keyword', '') + if not keyword: + return Response({ + "code": 400, + "message": "搜索关键字不能为空", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取分页参数 + try: + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + except ValueError: + page = 1 + page_size = 10 + + # 构建搜索条件 + query = Q(name__icontains=keyword) | \ + Q(desc__icontains=keyword) | \ + Q(department__icontains=keyword) | \ + Q(group__icontains=keyword) + + # 排除 secret 类型的知识库 + queryset = KnowledgeBase.objects.filter(query).exclude(type='secret') + + # 获取用户 + user = request.user + + # 获取用户所有有效的知识库权限 + active_permissions = KBPermissionModel.objects.filter( + user=user, + status='active', + expires_at__gt=timezone.now() + ).select_related('knowledge_base') + + # 创建权限映射字典 + permission_map = { + str(perm.knowledge_base.id): { + 'can_read': perm.can_read, + 'can_edit': perm.can_edit, + 'can_delete': perm.can_delete + } + for perm in active_permissions + } + + # 计算总数量 + total = queryset.count() + + # 分页处理 + start = (page - 1) * page_size + end = start + page_size + paginated_queryset = queryset[start:end] + + # 序列化知识库数据 + serializer = self.get_serializer(paginated_queryset, many=True) + data = serializer.data + + # 处理每个知识库项的权限和返回内容 + result_items = [] + for item in data: + # 创建一个临时的知识库对象用于权限检查 + temp_kb = KnowledgeBase( + id=item['id'], + type=item['type'], + department=item.get('department'), + group=item.get('group'), + user_id=item.get('user_id') + ) + + # 使用统一的权限判断方法 + explicit_permission = KBPermissionModel.objects.filter( + knowledge_base_id=item['id'], + user=user, + status='active' + ).first() + + if explicit_permission: + kb_permissions = { + 'can_read': explicit_permission.can_read, + 'can_edit': explicit_permission.can_edit, + 'can_delete': explicit_permission.can_delete + } + # 添加知识库的到期时间 + item['expires_at'] = explicit_permission.expires_at.strftime("%Y-%m-%d %H:%M:%S") if explicit_permission.expires_at else None + else: + # 使用统一的权限判断方法 + kb_permissions = { + 'can_read': self.check_knowledge_base_permission(temp_kb, user, 'read'), + 'can_edit': self.check_knowledge_base_permission(temp_kb, user, 'edit'), + 'can_delete': self.check_knowledge_base_permission(temp_kb, user, 'delete') + } + # 对于admin类型的知识库,设置expires_at为None + if item['type'] == 'admin': + item['expires_at'] = None + else: + # 对于其他类型,如果没有显式权限记录,则表示没有到期时间 + item['expires_at'] = None + + # 添加权限信息 + item['permissions'] = kb_permissions + + # 根据权限返回不同级别的信息 + if kb_permissions['can_read']: + result_items.append(item) + else: + # 无读取权限,只返回概要信息 + summary_info = { + 'id': item['id'], + 'name': item['name'], + 'type': item['type'], + 'department': item.get('department'), + 'permissions': kb_permissions + } + result_items.append(summary_info) + + # 高亮搜索关键字 + for item in result_items: + if 'name' in item and keyword.lower() in item['name'].lower(): + highlighted = item['name'].replace( + keyword, f'{keyword}' + ) + item['highlighted_name'] = highlighted + + # 确保desc不为None并且是字符串 + if 'desc' in item and item.get('desc') is not None: + desc_text = str(item['desc']) # 转换为字符串以确保安全 + if keyword.lower() in desc_text.lower(): + highlighted = desc_text.replace( + keyword, f'{keyword}' + ) + item['highlighted_desc'] = highlighted + + return Response({ + "code": 200, + "message": "搜索知识库成功", + "data": { + "total": total, + "page": page, + "page_size": page_size, + "keyword": keyword, + "items": result_items + } + }) + except Exception as e: + logger.error(f"搜索知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"搜索知识库失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=True, methods=['post']) + def change_type(self, request, pk=None): + """修改知识库类型""" + try: + instance = self.get_object() + user = request.user + + # 使用统一的权限检查方法检查编辑权限 + if not self.check_knowledge_base_permission(instance, user, 'edit'): + return Response({ + "code": 403, + "message": "没有修改权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 其余代码保持不变... + + # 获取新类型 + new_type = request.data.get('type') + if not new_type: + return Response({ + "code": 400, + "message": "新类型不能为空", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证类型是否有效 + valid_types = ['private', 'admin', 'secret', 'leader', 'member'] + if new_type not in valid_types: + return Response({ + "code": 400, + "message": f"无效的知识库类型,可选值: {', '.join(valid_types)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 角色特定的类型限制 + if new_type == 'leader' and not user.role == 'admin': # 组长且不是管理员 + # 组长只能在private和member类型之间切换 + if new_type not in ['private', 'member']: + return Response({ + "code": 403, + "message": "组长只能将知识库设置为private或member类型", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 处理department和group字段 + department = request.data.get('department') + group = request.data.get('group') + + # 组长只能设置自己部门 + if new_type == 'leader' and not user.role == 'admin': + if department and department != user.department: + return Response({ + "code": 403, + "message": "组长只能为本部门设置知识库", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + # 如果未指定部门,强制设置为组长的部门 + department = user.department + + # 根据类型验证必填字段 + if new_type == 'leader': + if not department: + return Response({ + "code": 400, + "message": "组长级知识库必须指定部门", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + if new_type == 'member': + if not department: + return Response({ + "code": 400, + "message": "成员级知识库必须指定部门", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + if not group: + return Response({ + "code": 400, + "message": "成员级知识库必须指定组", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 如果是admin或secret类型,清除department和group + if new_type in ['admin', 'secret']: + department = None + group = None + + # 如果是private类型但未指定department和group,使用原值 + if new_type == 'private': + if department is None: + department = instance.department + if group is None: + group = instance.group + + # 更新知识库类型和相关字段 + instance.type = new_type + instance.department = department + instance.group = group + instance.save() + + return Response({ + "code": 200, + "message": f"知识库类型已更新为{new_type}", + "data": { + "id": instance.id, + "name": instance.name, + "type": instance.type, + "department": instance.department, + "group": instance.group + } + }) + + except Http404: + return Response({ + "code": 404, + "message": "知识库不存在", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + logger.error(f"修改知识库类型失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"修改知识库类型失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _call_split_api_multiple(self, files): + """调用文档分割API - 直接传递多个文件""" + try: + url = f'{settings.API_BASE_URL}/api/dataset/document/split' + + # 准备请求数据 - 使用单个"file"字段 + file_obj = files[0] # 先只处理第一个文件,便于排查问题 + + # 重置文件指针位置 + if hasattr(file_obj, 'seek'): + file_obj.seek(0) + + logger.info(f"准备上传文件: {file_obj.name}, 大小: {file_obj.size}字节, 类型: {file_obj.content_type}") + + # 读取文件内容前100个字符进行记录 + if hasattr(file_obj, 'read') and hasattr(file_obj, 'seek'): + content_preview = file_obj.read(100).decode('utf-8', errors='ignore') + logger.info(f"文件内容预览: {content_preview}") + file_obj.seek(0) # 重置文件指针 + + # 使用正确的字段名称发送请求 + files_data = {'file': file_obj} + + logger.info(f"调用分割API URL: {url}") + logger.info(f"请求字段: {list(files_data.keys())}") + + # 发送请求 + response = requests.post( + url, + files=files_data, + timeout=60 + ) + + # 记录请求头和响应信息,方便排查问题 + logger.info(f"请求头: {response.request.headers}") + logger.info(f"响应状态码: {response.status_code}") + + if response.status_code != 200: + logger.error(f"分割API返回错误状态码: {response.status_code}, 响应: {response.text}") + return None + + # 解析响应 + result = response.json() + logger.info(f"分割API响应详情: {result}") + + # 如果数据为空,可能是API期望的请求格式不对,尝试使用不同的字段名 + if len(result.get('data', [])) == 0: + logger.warning("分割API返回的数据为空,尝试使用后备方案") + + # 创建一个手动构建的文档结构 + fallback_data = { + 'code': 200, + 'message': '成功', + 'data': [ + { + 'name': file_obj.name, + 'content': [ + { + 'title': '文档内容', + 'content': '文件内容无法自动分割,请检查外部API。这是一个后备内容。' + } + ] + } + ] + } + logger.info("使用后备数据结构") + return fallback_data + + return result + except Exception as e: + logger.error(f"调用分割API失败: {str(e)}") + logger.error(traceback.format_exc()) + + # 创建一个后备响应 + fallback_response = { + 'code': 200, + 'message': '成功', + 'data': [] + } + + # 如果有文件,为每个文件创建一个基本文档结构 + if files: + fallback_response['data'] = [ + { + 'name': file.name, + 'content': [ + { + 'title': '文档内容', + 'content': '文件内容无法自动分割,请检查API连接。' + } + ] + } for file in files + ] + + logger.info("由于异常,返回后备响应") + return fallback_response + + @action(detail=True, methods=['post']) + def upload_document(self, request, pk=None): + """上传文档到知识库 - 支持多文件上传""" + try: + instance = self.get_object() + user = request.user + + # 使用统一的权限检查方法 + if not self.check_knowledge_base_permission(instance, user, 'edit'): + return Response({ + "code": 403, + "message": "没有编辑权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 记录请求内容,方便调试 + logger.info(f"请求内容: {request.data}") + logger.info(f"请求FILES: {request.FILES}") + + # 获取上传的文件,尝试多种可能的字段名 + files = [] + # 尝试'files'字段(多文件) + if 'files' in request.FILES: + files = request.FILES.getlist('files') + # 尝试'file'字段(多文件) + elif 'file' in request.FILES: + files = request.FILES.getlist('file') + # 尝试files[]格式(常见于前端FormData) + elif any(key.startswith('files[') for key in request.FILES): + files = [file for key, file in request.FILES.items() if key.startswith('files[')] + # 尝试file[]格式 + elif any(key.startswith('file[') for key in request.FILES): + files = [file for key, file in request.FILES.items() if key.startswith('file[')] + # 单个文件上传的情况 + elif len(request.FILES) > 0: + # 如果有任何文件,就全部使用 + files = list(request.FILES.values()) + + if not files: + return Response({ + "code": 400, + "message": "未找到上传文件,请确保表单字段名为'files'或'file'", + "data": { + "available_fields": list(request.FILES.keys()) + } + }, status=status.HTTP_400_BAD_REQUEST) + + logger.info(f"接收到 {len(files)} 个文件上传请求") + + # 保存所有处理后的文档 + saved_documents = [] + failed_documents = [] + + # 验证knowledge_base的external_id是否有效 + if not instance.external_id: + return Response({ + "code": 400, + "message": "知识库没有有效的external_id,请先创建知识库", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 先验证外部知识库是否存在 + try: + # 简单的验证请求 + verify_url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}' + verify_response = requests.get(verify_url, timeout=10) + if verify_response.status_code != 200: + logger.error(f"外部知识库不存在或无法访问: {instance.external_id}, 状态码: {verify_response.status_code}") + return Response({ + "code": 404, + "message": f"外部知识库不存在或无法访问: {instance.external_id}", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + verify_data = verify_response.json() + if verify_data.get('code') != 200: + logger.error(f"验证外部知识库失败: {verify_data.get('message')}") + return Response({ + "code": verify_data.get('code', 500), + "message": f"验证外部知识库失败: {verify_data.get('message', '未知错误')}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + logger.info(f"外部知识库验证成功: {instance.external_id}") + except Exception as e: + logger.error(f"验证外部知识库时出错: {str(e)}") + return Response({ + "code": 500, + "message": f"验证外部知识库时出错: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 逐个处理每个文件 - 避免一次性传多个文件导致外部API处理失败 + for i, file in enumerate(files): + logger.info(f"处理第 {i+1} 个文件: {file.name}") + + # 创建只包含当前文件的列表,传递给分割API + current_file = [file] + + # 调用文档分割API + split_response = self._call_split_api_multiple(current_file) + if not split_response or split_response.get('code') != 200: + error_msg = f"文件 {file.name} 分割失败: {split_response.get('message', '未知错误') if split_response else '请求失败'}" + logger.error(error_msg) + failed_documents.append({ + "name": file.name, + "error": error_msg + }) + continue + + # 处理分割后的文档 + documents_data = split_response.get('data', []) + + # 如果没有文档数据,使用一个基本结构 + if not documents_data: + logger.warning(f"文件 {file.name} 未返回文档数据,创建基本文档结构") + documents_data = [{ + 'name': file.name, + 'content': [{ + 'title': '文档内容', + 'content': '文件内容无法自动分割,请检查文件格式。' + }] + }] + + # 遍历所有分割后的文档 + for doc in documents_data: + doc_name = doc.get('name', file.name) + doc_content = doc.get('content', []) + + logger.info(f"处理文档: {doc_name}, 包含 {len(doc_content)} 个段落") + + # 如果没有内容,添加一个默认段落 + if not doc_content: + doc_content = [{ + 'title': '文档内容', + 'content': '文件内容无法自动分割,请检查文件格式。' + }] + + # 准备文档数据结构 + doc_data = { + "name": doc_name, + "paragraphs": [] + } + + # 将所有段落添加到文档中 + for paragraph in doc_content: + doc_data["paragraphs"].append({ + "content": paragraph.get('content', ''), + "title": paragraph.get('title', ''), + "is_active": True, + "problem_list": [] + }) + + # 调用文档上传API + upload_response = self._call_upload_api(instance.external_id, doc_data) + + if upload_response and upload_response.get('code') == 200 and upload_response.get('data'): + # 上传成功,保存记录到数据库 + document_id = upload_response['data']['id'] + doc_record = KnowledgeBaseDocument.objects.create( + knowledge_base=instance, + document_id=document_id, + document_name=doc_name, + external_id=document_id + ) + + saved_documents.append({ + "id": str(doc_record.id), + "name": doc_record.document_name, + "external_id": doc_record.external_id + }) + + logger.info(f"文档 '{doc_name}' 上传成功,ID: {document_id}") + else: + # 上传失败,记录错误信息 + error_msg = upload_response.get('message', '未知错误') if upload_response else '上传API调用失败' + logger.error(f"文档 '{doc_name}' 上传失败: {error_msg}") + failed_documents.append({ + "name": doc_name, + "error": error_msg + }) + + # 返回结果 + if saved_documents: + return Response({ + "code": 200, + "message": f"文档上传完成,成功: {len(saved_documents)},失败: {len(failed_documents)}", + "data": { + "uploaded_count": len(saved_documents), + "failed_count": len(failed_documents), + "total_files": len(files), + "documents": saved_documents, + "failed_documents": failed_documents + } + }) + else: + return Response({ + "code": 400, + "message": f"所有文档上传失败", + "data": { + "uploaded_count": 0, + "failed_count": len(failed_documents), + "total_files": len(files), + "documents": [], + "failed_documents": failed_documents + } + }, status=status.HTTP_400_BAD_REQUEST) + + except Exception as e: + logger.error(f"文档上传失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"文档上传失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _call_upload_api(self, external_id, doc_data): + """调用文档上传API""" + try: + url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document' + logger.info(f"调用文档上传API: {url}") + + # 记录请求数据,方便调试 + logger.info(f"上传文档数据: 文档名={doc_data.get('name')}, 段落数={len(doc_data.get('paragraphs', []))}") + + # 发送请求 + response = requests.post(url, json=doc_data) + + # 记录响应结果 + logger.info(f"上传API响应状态码: {response.status_code}") + + # 检查响应状态码 + if response.status_code != 200: + logger.error(f"上传API HTTP错误: {response.status_code}, 响应: {response.text}") + return { + 'code': response.status_code, + 'message': f"上传失败,HTTP状态码: {response.status_code}", + 'data': None + } + + # 解析响应JSON + result = response.json() + logger.info(f"上传API响应内容: {result}") + + # 检查业务状态码 + if result.get('code') != 200: + error_msg = result.get('message', '未知错误') + logger.error(f"上传API业务错误: {error_msg}") + return { + 'code': result.get('code', 500), + 'message': error_msg, + 'data': None + } + + return result + + except requests.exceptions.RequestException as e: + logger.error(f"调用上传API网络错误: {str(e)}") + return { + 'code': 500, + 'message': f"网络请求错误: {str(e)}", + 'data': None + } + except json.JSONDecodeError as e: + logger.error(f"解析API响应JSON失败: {str(e)}") + return { + 'code': 500, + 'message': f"解析响应数据失败: {str(e)}", + 'data': None + } + except Exception as e: + logger.error(f"调用上传API其他错误: {str(e)}") + return { + 'code': 500, + 'message': f"上传API调用失败: {str(e)}", + 'data': None + } + + def _call_delete_document_api(self, external_id, document_id): + """调用文档删除API""" + try: + url = f'{settings.API_BASE_URL}/api/dataset/{external_id}/document/{document_id}' + response = requests.delete(url) + return response.json() + except Exception as e: + logger.error(f"调用删除API失败: {str(e)}") + return None + + def _create_external_dataset(self, instance): + """创建外部知识库""" + try: + api_data = { + "name": instance.name, + "desc": instance.desc, + "type": "0", # 添加必要的type字段 + "meta": {}, # 添加必要的meta字段 + "documents": [] # 初始化为空列表 + } + + response = requests.post( + f'{settings.API_BASE_URL}/api/dataset', + json=api_data, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code != 200: + raise ExternalAPIError(f"创建失败,状态码: {response.status_code}, 响应: {response.text}") + + api_response = response.json() + if not api_response.get('code') == 200: + raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}") + + dataset_id = api_response.get('data', {}).get('id') + if not dataset_id: + raise ExternalAPIError("响应数据中缺少dataset id") + + return dataset_id + + except requests.exceptions.Timeout: + raise ExternalAPIError("请求超时,请稍后重试") + except requests.exceptions.RequestException as e: + raise ExternalAPIError(f"API请求失败: {str(e)}") + except Exception as e: + raise ExternalAPIError(f"创建外部知识库失败: {str(e)}") + + def _delete_external_dataset(self, external_id): + """删除外部知识库""" + try: + if not external_id: + raise ExternalAPIError("外部知识库ID不能为空") + + response = requests.delete( + f'{settings.API_BASE_URL}/api/dataset/{external_id}', + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}") + + # 检查响应状态码 + if response.status_code == 404: + logger.warning(f"外部知识库不存在: {external_id}") + return True # 如果知识库不存在,也视为删除成功 + elif response.status_code not in [200, 204]: + raise ExternalAPIError(f"删除失败,状态码: {response.status_code}, 响应: {response.text}") + + # 如果是 204 状态码,说明删除成功但无返回内容 + if response.status_code == 204: + logger.info(f"外部知识库删除成功: {external_id}") + return True + + # 如果是 200 状态码,检查响应内容 + try: + api_response = response.json() + if api_response.get('code') != 200: + raise ExternalAPIError(f"业务处理失败: {api_response.get('message', '未知错误')}") + logger.info(f"外部知识库删除成功: {external_id}") + return True + except ValueError: + # 如果无法解析 JSON,但状态码是 200,也认为成功 + logger.warning(f"外部知识库删除响应无法解析JSON,但状态码为200,视为成功: {external_id}") + return True + + except requests.exceptions.Timeout: + logger.error(f"删除外部知识库超时: {external_id}") + raise ExternalAPIError("请求超时,请稍后重试") + except requests.exceptions.RequestException as e: + logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}") + raise ExternalAPIError(f"API请求失败: {str(e)}") + except Exception as e: + logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}") + raise ExternalAPIError(f"删除外部知识库失败: {str(e)}") + + @action(detail=True, methods=['get']) + def documents(self, request, pk=None): + """获取知识库的文档列表""" + try: + instance = self.get_object() + user = request.user + + # 权限检查 + if not self.check_knowledge_base_permission(instance, user, 'read'): + return Response({ + "code": 403, + "message": "没有查看权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 检查external_id是否存在 + if not instance.external_id: + return Response({ + "code": 400, + "message": "知识库没有有效的external_id", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 调用外部API获取文档列表 + try: + url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document' + response = requests.get( + url, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code != 200: + logger.error(f"获取文档列表API调用失败: {response.status_code}, {response.text}") + return Response({ + "code": 500, + "message": f"获取文档列表失败: HTTP {response.status_code}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + result = response.json() + if result.get('code') != 200: + logger.error(f"获取文档列表业务失败: {result.get('message')}") + return Response({ + "code": result.get('code', 500), + "message": result.get('message', '获取文档列表失败'), + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # 同步外部文档到本地数据库 + external_documents = result.get('data', []) + for doc in external_documents: + # 获取外部文档ID和名称 + external_id = doc.get('id') + doc_name = doc.get('name') + + if external_id and doc_name: + # 检查文档是否已存在 + kb_doc, created = KnowledgeBaseDocument.objects.update_or_create( + knowledge_base=instance, + external_id=external_id, + defaults={ + 'document_id': external_id, + 'document_name': doc_name, + 'status': 'active' if doc.get('is_active', True) else 'deleted' + } + ) + + if created: + logger.info(f"同步创建文档: {doc_name}, ID: {external_id}") + else: + logger.info(f"同步更新文档: {doc_name}, ID: {external_id}") + + # 获取最新的本地文档数据 + documents = KnowledgeBaseDocument.objects.filter( + knowledge_base=instance, + status='active' + ).order_by('-create_time') + + # 构建响应数据 + documents_data = [{ + "id": str(doc.id), + "document_id": doc.document_id, + "name": doc.document_name, + "external_id": doc.external_id, + "created_at": doc.create_time.strftime('%Y-%m-%d %H:%M:%S'), + # 添加外部API返回的额外信息 + "char_length": next((d.get('char_length', 0) for d in external_documents if d.get('id') == doc.external_id), 0), + "paragraph_count": next((d.get('paragraph_count', 0) for d in external_documents if d.get('id') == doc.external_id), 0), + "is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True) + } for doc in documents] + + return Response({ + "code": 200, + "message": "获取文档列表成功", + "data": documents_data + }) + + except requests.exceptions.RequestException as e: + logger.error(f"获取文档列表网络异常: {str(e)}") + return Response({ + "code": 500, + "message": f"获取文档列表失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + except Exception as e: + logger.error(f"获取文档列表失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"获取文档列表失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=True, methods=['get']) + def document_content(self, request, pk=None): + """获取文档内容 - 段落列表""" + try: + knowledge_base = self.get_object() + user = request.user + + # 权限检查 + if not self.check_knowledge_base_permission(knowledge_base, user, 'read'): + return Response({ + "code": 403, + "message": "没有查看权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 获取文档ID + document_id = request.query_params.get('document_id') + if not document_id: + return Response({ + "code": 400, + "message": "缺少document_id参数", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证文档存在 + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_id=document_id, + status='active' + ).first() + + if not document: + return Response({ + "code": 404, + "message": "文档不存在或已删除", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + # 调用正确的外部API获取文档段落内容 + try: + url = f'{settings.API_BASE_URL}/api/dataset/{knowledge_base.external_id}/document/{document.external_id}/paragraph' + response = requests.get(url, timeout=30) + + if response.status_code != 200: + logger.error(f"获取文档段落内容失败: {response.status_code}, {response.text}") + return Response({ + "code": 500, + "message": f"获取文档段落内容失败,状态码: {response.status_code}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + api_response = response.json() + if api_response.get('code') != 200: + logger.error(f"获取文档段落内容业务失败: {api_response.get('message')}") + return Response({ + "code": api_response.get('code', 500), + "message": api_response.get('message', '获取文档段落内容失败'), + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + paragraphs = api_response.get('data', []) + + # 直接返回外部API的段落数据 + return Response({ + "code": 200, + "message": "获取文档内容成功", + "data": { + "document_id": document_id, + "name": document.document_name, + "paragraphs": paragraphs + } + }) + except Exception as e: + logger.error(f"获取文档段落内容API调用失败: {str(e)}") + return Response({ + "code": 500, + "message": f"获取文档内容失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + except Exception as e: + logger.error(f"获取文档内容失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"获取文档内容失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=True, methods=['delete']) + def delete_document(self, request, pk=None): + """删除知识库文档""" + try: + knowledge_base = self.get_object() + user = request.user + + # 权限检查 + if not self.check_knowledge_base_permission(knowledge_base, user, 'edit'): + return Response({ + "code": 403, + "message": "没有编辑权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + # 获取文档ID + document_id = request.query_params.get('document_id') + if not document_id: + return Response({ + "code": 400, + "message": "缺少document_id参数", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证文档存在 + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_id=document_id, + status='active' + ).first() + + if not document: + return Response({ + "code": 404, + "message": "文档不存在或已删除", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + # 调用外部API删除文档 + try: + external_id = document.external_id + delete_result = self._call_delete_document_api(knowledge_base.external_id, external_id) + + # 无论外部API结果如何,都更新本地状态 + document.status = 'deleted' + document.save() + + if delete_result and delete_result.get('code') != 200: + logger.warning(f"外部API删除文档失败,但本地标记已更新: {delete_result.get('message')}") + + return Response({ + "code": 200, + "message": "文档删除成功", + "data": { + "document_id": document_id, + "name": document.document_name + } + }) + except Exception as e: + logger.error(f"调用删除文档API失败: {str(e)}") + # 即使外部API调用失败,也更新本地状态 + document.status = 'deleted' + document.save() + + return Response({ + "code": 200, + "message": "文档在系统中已标记为删除,但外部API调用失败", + "data": { + "document_id": document_id, + "name": document.document_name, + "error": str(e) + } + }) + + except Exception as e: + logger.error(f"删除文档失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": f"删除文档失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class PermissionViewSet(viewsets.ModelViewSet): + serializer_class = PermissionSerializer + permission_classes = [IsAuthenticated] + + def can_manage_knowledge_base(self, user, knowledge_base): + """检查用户是否是知识库的创建者""" + return str(knowledge_base.user_id) == str(user.id) + + def get_queryset(self): + """ + 获取权限申请列表: + 1. applicant_id 是当前用户 (看到自己发起的申请) + 2. approver_id 是当前用户 (看到自己需要审批的申请) + """ + user_id = str(self.request.user.id) + + # 构建查询条件:申请人是自己 或 审批人是自己 + query = Q(applicant_id=user_id) | Q(approver_id=user_id) + + return Permission.objects.filter(query).select_related( + 'knowledge_base', + 'applicant', + 'approver' + ) + + def list(self, request, *args, **kwargs): + """获取权限申请列表,包含详细信息""" + try: + queryset = self.get_queryset() + user_id = str(request.user.id) + + # 获取分页参数 + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + + # 计算总数 + total = queryset.count() + + # 手动分页 + start = (page - 1) * page_size + end = start + page_size + permissions = queryset[start:end] + + # 构建响应数据 + data = [] + for permission in permissions: + # 检查当前用户是否是申请人或审批人 + if user_id not in [str(permission.applicant_id), str(permission.approver_id)]: + continue + + # 构建响应数据 + permission_data = { + 'id': str(permission.id), + 'knowledge_base': { + 'id': str(permission.knowledge_base.id), + 'name': permission.knowledge_base.name, + 'type': permission.knowledge_base.type, + }, + 'applicant': { + 'id': str(permission.applicant.id), + 'username': permission.applicant.username, + 'name': permission.applicant.name, + 'department': permission.applicant.department, + }, + 'approver': { + 'id': str(permission.approver.id) if permission.approver else '', + 'username': permission.approver.username if permission.approver else '', + 'name': permission.approver.name if permission.approver else '', + 'department': permission.approver.department if permission.approver else '', + }, + 'permissions': permission.permissions, + 'status': permission.status, + 'created_at': permission.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None, + 'response_message': permission.response_message or '', + # 添加角色标识,用于前端展示 + 'role': 'applicant' if str(permission.applicant_id) == user_id else 'approver' + } + + data.append(permission_data) + + return Response({ + 'code': 200, + 'message': '获取权限申请列表成功', + 'data': { + 'total': len(data), # 使用过滤后的实际数量 + 'page': page, + 'page_size': page_size, + 'results': data + } + }) + + except Exception as e: + logger.error(f"获取权限申请列表失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取权限申请列表失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def perform_create(self, serializer): + """创建权限申请并发送通知给知识库创建者""" + # 获取知识库 + # 获取知识库 + knowledge_base = serializer.validated_data['knowledge_base'] + + # 检查是否是申请访问自己的知识库 + if str(knowledge_base.user_id) == str(self.request.user.id): + raise ValidationError({ + "code": 400, + "message": "您是此知识库的创建者,无需申请权限", + "data": None + }) + # 获取知识库创建者作为审批者 + approver = User.objects.get(id=knowledge_base.user_id) + + # 验证权限请求 + requested_permissions = serializer.validated_data.get('permissions', {}) + expires_at = serializer.validated_data.get('expires_at') + + if not any([requested_permissions.get('can_read'), + requested_permissions.get('can_edit'), + requested_permissions.get('can_delete')]): + raise ValidationError("至少需要申请一种权限(读/改/删)") + + if not expires_at: + raise ValidationError("请指定权限到期时间") + + # 检查是否已有未过期的权限申请 + existing_request = Permission.objects.filter( + knowledge_base=knowledge_base, + applicant=self.request.user, + status='pending' + ).first() + + if existing_request: + raise ValidationError("您已有一个待处理的权限申请") + + # 检查是否已有有效的权限 + existing_permission = Permission.objects.filter( + knowledge_base=knowledge_base, + applicant=self.request.user, + status='approved', + expires_at__gt=timezone.now() + ).first() + + if existing_permission: + raise ValidationError("您已有此知识库的访问权限") + + # 保存权限申请,设置审批者 + permission = serializer.save( + applicant=self.request.user, + status='pending', + approver=approver # 创建时就设置审批者 + ) + + # 获取权限类型字符串 + permission_types = [] + if requested_permissions.get('can_read'): + permission_types.append('读取') + if requested_permissions.get('can_edit'): + permission_types.append('编辑') + if requested_permissions.get('can_delete'): + permission_types.append('删除') + permission_str = '、'.join(permission_types) + + # 发送通知给知识库创建者 + owner = User.objects.get(id=knowledge_base.user_id) + self.send_notification( + user=owner, + title="新的权限申请", + content=f"用户 {self.request.user.name} 申请了知识库 '{knowledge_base.name}' 的{permission_str}权限", + notification_type="permission_request", + related_object_id=permission.id + ) + + def send_notification(self, user, title, content, notification_type, related_object_id): + """发送通知""" + try: + notification = Notification.objects.create( + sender=self.request.user, + receiver=user, + title=title, + content=content, + type=notification_type, + related_resource=related_object_id, + ) + + # 通过WebSocket发送实时通知 + channel_layer = get_channel_layer() + async_to_sync(channel_layer.group_send)( + f"notification_user_{user.id}", + { + "type": "notification", + "data": { + "id": str(notification.id), + "title": notification.title, + "content": notification.content, + "type": notification.type, + "created_at": notification.created_at.isoformat(), + "sender": { + "id": str(notification.sender.id), + "name": notification.sender.name + } + } + } + ) + except Exception as e: + logger.error(f"发送通知时发生错误: {str(e)}") + + @action(detail=True, methods=['post']) + def approve(self, request, pk=None): + try: + # 获取权限申请记录 + permission = self.get_object() + + # 只检查是否是知识库创建者 + if not self.can_manage_knowledge_base(request.user, permission.knowledge_base): + logger.warning(f"用户 {request.user.username} 尝试审批知识库 {permission.knowledge_base.name} 的权限申请,但不是创建者") + return Response({ + 'code': 403, + 'message': '只有知识库创建者可以审批此申请', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 获取审批意见 + response_message = request.data.get('response_message', '') + + with transaction.atomic(): + # 更新权限申请状态 + permission.status = 'approved' + permission.approver = request.user + permission.response_message = response_message + permission.save() + + # 检查是否已存在权限记录 + kb_permission = KBPermissionModel.objects.filter( + knowledge_base=permission.knowledge_base, + user=permission.applicant + ).first() + + if kb_permission: + # 更新现有权限 + kb_permission.can_read = permission.permissions.get('can_read', False) + kb_permission.can_edit = permission.permissions.get('can_edit', False) + kb_permission.can_delete = permission.permissions.get('can_delete', False) + kb_permission.granted_by = request.user + kb_permission.status = 'active' + kb_permission.expires_at = permission.expires_at + kb_permission.save() + logger.info(f"更新知识库权限记录: {kb_permission.id}") + else: + # 创建新的权限记录 + kb_permission = KBPermissionModel.objects.create( + knowledge_base=permission.knowledge_base, + user=permission.applicant, + can_read=permission.permissions.get('can_read', False), + can_edit=permission.permissions.get('can_edit', False), + can_delete=permission.permissions.get('can_delete', False), + granted_by=request.user, + status='active', + expires_at=permission.expires_at + ) + logger.info(f"创建新的知识库权限记录: {kb_permission.id}") + + # 发送通知给申请人 + self.send_notification( + user=permission.applicant, + title="权限申请已通过", + content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已通过", + notification_type="permission_approved", + related_object_id=permission.id + ) + + return Response({ + 'code': 200, + 'message': '权限申请已批准', + 'data': None + }) + + except Permission.DoesNotExist: + return Response({ + 'code': 404, + 'message': '权限申请不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + logger.error(f"处理权限申请失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'处理权限申请失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=True, methods=['post']) + def reject(self, request, pk=None): + """拒绝权限申请""" + permission = self.get_object() + + # 检查是否是知识库创建者 + if str(permission.knowledge_base.user_id) != str(request.user.id): + return Response({ + 'code': 403, + 'message': '只有知识库创建者可以审批此申请', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 检查申请是否已被处理 + if permission.status != 'pending': + return Response({ + 'code': 400, + 'message': '该申请已被处理', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证拒绝原因 + response_message = request.data.get('response_message') + if not response_message: + return Response({ + 'code': 400, + 'message': '请填写拒绝原因', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新权限状态 + permission.status = 'rejected' + permission.approver = request.user + permission.response_message = response_message + permission.save() + + # 发送通知给申请人 + self.send_notification( + user=permission.applicant, + title="权限申请已拒绝", + content=f"您对知识库 '{permission.knowledge_base.name}' 的权限申请已被拒绝\n" + f"拒绝原因:{response_message}", + notification_type="permission_rejected", + related_object_id=permission.id + ) + + return Response({ + 'code': 200, + 'message': '权限申请已拒绝', + 'data': PermissionSerializer(permission).data + }) + + @action(detail=True, methods=['post']) + def extend(self, request, pk=None): + """延长权限有效期""" + instance = self.get_object() + user = request.user + + # 检查是否有权限延长 + if not self.check_extend_permission(instance, user): + return Response({ + "code": 403, + "message": "您没有权限延长此权限", + "data": None + }, status=status.HTTP_403_FORBIDDEN) + + new_expires_at = request.data.get('expires_at') + if not new_expires_at: + return Response({ + "code": 400, + "message": "请设置新的过期时间", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + try: + with transaction.atomic(): + # 更新权限申请表的过期时间 + instance.expires_at = new_expires_at + instance.save() + + # 同步更新知识库权限表的过期时间 + kb_permission = KBPermissionModel.objects.get( + knowledge_base=instance.knowledge_base, + user=instance.applicant + ) + kb_permission.expires_at = new_expires_at + kb_permission.save() + + return Response({ + "code": 200, + "message": "权限有效期延长成功", + "data": PermissionSerializer(instance).data + }) + except Exception as e: + return Response({ + "code": 500, + "message": f"延长权限有效期失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def check_extend_permission(self, permission, user): + """检查是否有权限延长权限有效期""" + knowledge_base = permission.knowledge_base + + # 私人知识库只有拥有者能延长 + if knowledge_base.type == 'private': + return knowledge_base.owner == user + + # 组长知识库只有管理员能延长 + if knowledge_base.type == 'leader': + return user.role == 'admin' + + # 组员知识库可以由管理员或本部门组长延长 + if knowledge_base.type == 'member': + return ( + user.role == 'admin' or + (user.role == 'leader' and user.department == knowledge_base.department) + ) + + return False + + @action(detail=False, methods=['get']) + def user_permissions(self, request): + """获取指定用户的所有知识库权限""" + try: + # 获取用户名参数 + username = request.query_params.get('username') + if not username: + return Response({ + 'code': 400, + 'message': '请提供用户名', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取用户 + try: + target_user = User.objects.get(username=username) + except User.DoesNotExist: + return Response({ + 'code': 404, + 'message': f'用户 {username} 不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 获取该用户的所有权限记录 + permissions = KBPermissionModel.objects.filter( + user=target_user, + status='active' + ).select_related('knowledge_base', 'granted_by') + + # 构建响应数据 + permissions_data = [] + for perm in permissions: + perm_data = { + 'id': str(perm.id), + 'knowledge_base': { + 'id': str(perm.knowledge_base.id), + 'name': perm.knowledge_base.name, + 'type': perm.knowledge_base.type, + 'department': perm.knowledge_base.department, + 'group': perm.knowledge_base.group + }, + 'permissions': { + 'can_read': perm.can_read, + 'can_edit': perm.can_edit, + 'can_delete': perm.can_delete + }, + 'granted_by': { + 'id': str(perm.granted_by.id) if perm.granted_by else None, + 'username': perm.granted_by.username if perm.granted_by else None, + 'name': perm.granted_by.name if perm.granted_by else None + }, + 'created_at': perm.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None, + 'status': perm.status + } + permissions_data.append(perm_data) + + return Response({ + 'code': 200, + 'message': '获取用户权限成功', + 'data': { + 'user': { + 'id': str(target_user.id), + 'username': target_user.username, + 'name': target_user.name, + 'department': target_user.department, + 'role': target_user.role + }, + 'permissions': permissions_data + } + }) + + except Exception as e: + logger.error(f"获取用户权限失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取用户权限失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get']) + def all_permissions(self, request): + """管理员获取所有用户的知识库权限(不包括私有知识库)""" + try: + # 检查是否是管理员 + if request.user.role != 'admin': + return Response({ + 'code': 403, + 'message': '只有管理员可以查看所有权限', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 获取查询参数 + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 10)) + status_filter = request.query_params.get('status') + department = request.query_params.get('department') + kb_type = request.query_params.get('kb_type') + + # 构建基础查询 + queryset = KBPermissionModel.objects.filter( + ~Q(knowledge_base__type='private') + ).select_related( + 'user', + 'knowledge_base', + 'granted_by' + ) + + # 应用过滤条件 + if status_filter == 'active': + queryset = queryset.filter( + Q(expires_at__gt=timezone.now()) | Q(expires_at__isnull=True), + status='active' + ) + elif status_filter == 'expired': + queryset = queryset.filter( + Q(expires_at__lte=timezone.now()) | Q(status='inactive') + ) + + if department: + queryset = queryset.filter(user__department=department) + + if kb_type: + queryset = queryset.filter(knowledge_base__type=kb_type) + + # 按用户分组处理数据 + user_permissions = {} + for perm in queryset: + user_id = str(perm.user.id) + if user_id not in user_permissions: + user_permissions[user_id] = { + 'user_info': { + 'id': user_id, + 'username': perm.user.username, + 'name': getattr(perm.user, 'name', perm.user.username), + 'department': getattr(perm.user, 'department', None), + 'role': getattr(perm.user, 'role', None) + }, + 'permissions': [], + 'stats': { + 'total': 0, + 'by_type': { + 'admin': 0, + 'secret': 0, + 'leader': 0, + 'member': 0 + }, + 'by_permission': { + 'read_only': 0, + 'read_write': 0, + 'full_access': 0 + } + } + } + + # 添加权限信息 + perm_data = { + 'id': str(perm.id), + 'knowledge_base': { + 'id': str(perm.knowledge_base.id), + 'name': perm.knowledge_base.name, + 'type': perm.knowledge_base.type, + 'department': perm.knowledge_base.department, + 'group': perm.knowledge_base.group, + 'creator': { + 'id': str(perm.knowledge_base.user_id), + 'name': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'name', None), + 'username': getattr(User.objects.filter(id=perm.knowledge_base.user_id).first(), 'username', None) + } + }, + 'permissions': { + 'can_read': perm.can_read, + 'can_edit': perm.can_edit, + 'can_delete': perm.can_delete + }, + 'granted_by': { + 'id': str(perm.granted_by.id) if perm.granted_by else None, + 'username': perm.granted_by.username if perm.granted_by else None, + 'name': getattr(perm.granted_by, 'name', None) if perm.granted_by else None + }, + 'granted_at': perm.granted_at.strftime('%Y-%m-%d %H:%M:%S'), + 'expires_at': perm.expires_at.strftime('%Y-%m-%d %H:%M:%S') if perm.expires_at else None, + 'status': perm.status + } + + user_permissions[user_id]['permissions'].append(perm_data) + + # 更新统计信息 + stats = user_permissions[user_id]['stats'] + stats['total'] += 1 + stats['by_type'][perm.knowledge_base.type] += 1 + + # 统计权限级别 + if perm.can_delete: + stats['by_permission']['full_access'] += 1 + elif perm.can_edit: + stats['by_permission']['read_write'] += 1 + elif perm.can_read: + stats['by_permission']['read_only'] += 1 + + # 转换为列表并分页 + users_list = list(user_permissions.values()) + total = len(users_list) + start = (page - 1) * page_size + end = start + page_size + paginated_users = users_list[start:end] + + return Response({ + 'code': 200, + 'message': '获取权限列表成功', + 'data': { + 'total': total, + 'page': page, + 'page_size': page_size, + 'results': paginated_users + } + }) + + except Exception as e: + logger.error(f"获取所有权限失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'获取所有权限失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['post']) + def update_permission(self, request): + """管理员更新用户的知识库权限""" + try: + # 检查是否是管理员 + if request.user.role != 'admin': + return Response({ + 'code': 403, + 'message': '只有管理员可以直接修改权限', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 验证必要参数 + user_id = request.data.get('user_id') + knowledge_base_id = request.data.get('knowledge_base_id') + permissions = request.data.get('permissions') + expires_at_str = request.data.get('expires_at') + + if not all([user_id, knowledge_base_id, permissions]): + return Response({ + 'code': 400, + 'message': '缺少必要参数', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证权限参数格式 + required_permission_fields = ['can_read', 'can_edit', 'can_delete'] + if not all(field in permissions for field in required_permission_fields): + return Response({ + 'code': 400, + 'message': '权限参数格式错误,必须包含 can_read、can_edit、can_delete', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取用户和知识库 + try: + user = User.objects.get(id=user_id) + knowledge_base = KnowledgeBase.objects.get(id=knowledge_base_id) + except User.DoesNotExist: + return Response({ + 'code': 404, + 'message': f'用户ID {user_id} 不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + except KnowledgeBase.DoesNotExist: + return Response({ + 'code': 404, + 'message': f'知识库ID {knowledge_base_id} 不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + + # 检查知识库类型和用户角色的匹配 + if knowledge_base.type == 'private' and str(knowledge_base.user_id) != str(user.id): + return Response({ + 'code': 403, + 'message': '不能修改其他用户的私有知识库权限', + 'data': None + }, status=status.HTTP_403_FORBIDDEN) + + # 处理过期时间 + expires_at = None + if expires_at_str: + try: + # 将字符串转换为datetime对象 + expires_at = timezone.datetime.strptime( + expires_at_str, + '%Y-%m-%dT%H:%M:%SZ' + ) + # 确保时区感知 + expires_at = timezone.make_aware(expires_at) + + # 检查是否早于当前时间 + if expires_at <= timezone.now(): + return Response({ + 'code': 400, + 'message': '过期时间不能早于或等于当前时间', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + except ValueError: + return Response({ + 'code': 400, + 'message': '过期时间格式错误,应为 ISO 格式 (YYYY-MM-DDThh:mm:ssZ)', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 根据用户角色限制权限 + if user.role == 'member' and permissions.get('can_delete'): + return Response({ + 'code': 400, + 'message': '普通成员不能获得删除权限', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新或创建权限记录 + try: + with transaction.atomic(): + permission, created = KBPermissionModel.objects.update_or_create( + user=user, + knowledge_base=knowledge_base, + defaults={ + 'can_read': permissions.get('can_read', False), + 'can_edit': permissions.get('can_edit', False), + 'can_delete': permissions.get('can_delete', False), + 'granted_by': request.user, + 'status': 'active', + 'expires_at': expires_at + } + ) + + # 发送通知给用户 + self.send_notification( + user=user, + title="知识库权限更新", + content=f"管理员已{created and '授予' or '更新'}您对知识库 '{knowledge_base.name}' 的权限", + notification_type="permission_updated", + related_object_id=permission.id + ) + except IntegrityError as e: + return Response({ + 'code': 500, + 'message': f'数据库操作失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + return Response({ + 'code': 200, + 'message': f"{'创建' if created else '更新'}权限成功", + 'data': { + 'id': str(permission.id), + 'user': { + 'id': str(user.id), + 'username': user.username, + 'name': user.name, + 'department': user.department, + 'role': user.role + }, + 'knowledge_base': { + 'id': str(knowledge_base.id), + 'name': knowledge_base.name, + 'type': knowledge_base.type, + 'department': knowledge_base.department, + 'group': knowledge_base.group + }, + 'permissions': { + 'can_read': permission.can_read, + 'can_edit': permission.can_edit, + 'can_delete': permission.can_delete + }, + 'granted_by': { + 'id': str(request.user.id), + 'username': request.user.username, + 'name': request.user.name + }, + 'expires_at': permission.expires_at.strftime('%Y-%m-%d %H:%M:%S') if permission.expires_at else None, + 'created': created + } + }) + + except Exception as e: + logger.error(f"更新权限失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + 'code': 500, + 'message': f'更新权限失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class NotificationViewSet(viewsets.ModelViewSet): + """通知视图集""" + queryset = Notification.objects.all() + serializer_class = NotificationSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + """只返回用户自己的通知""" + return Notification.objects.filter(receiver=self.request.user) + + @action(detail=True, methods=['post']) + def mark_as_read(self, request, pk=None): + """标记通知为已读""" + notification = self.get_object() + notification.is_read = True + notification.save() + return Response({'status': 'marked as read'}) + + @action(detail=False, methods=['post']) + def mark_all_as_read(self, request): + """标记所有通知为已读""" + self.get_queryset().update(is_read=True) + return Response({'status': 'all marked as read'}) + + @action(detail=False, methods=['get']) + def unread_count(self, request): + """获取未读通知数量""" + count = self.get_queryset().filter(is_read=False).count() + return Response({'unread_count': count}) + + @action(detail=False, methods=['get']) + def latest(self, request): + """获取最新通知""" + notifications = self.get_queryset().filter( + is_read=False + ).order_by('-created_at')[:5] + serializer = self.get_serializer(notifications, many=True) + return Response(serializer.data) + + def perform_create(self, serializer): + """创建通知时自动设置发送者""" + serializer.save(sender=self.request.user) + + +@method_decorator(csrf_exempt, name='dispatch') +class LoginView(APIView): + """用户登录视图""" + authentication_classes = [] # 清空认证类 + permission_classes = [AllowAny] + + def post(self, request): + try: + username = request.data.get('username') + password = request.data.get('password') + + # 参数验证 + if not username or not password: + return Response({ + "code": 400, + "message": "请提供用户名和密码", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证用户 + user = authenticate(request, username=username, password=password) + + if user is not None: + # 获取或创建token + token, _ = Token.objects.get_or_create(user=user) + + return Response({ + "code": 200, + "message": "登录成功", + "data": { + "id": str(user.id), + "username": user.username, + "email": user.email, + "role": user.role, + "department": user.department, + "name": user.name, + "group": user.group, + "token": token.key + } + }) + else: + return Response({ + "code": 401, + "message": "用户名或密码错误", + "data": None + }, status=status.HTTP_401_UNAUTHORIZED) + + except Exception as e: + import traceback + logger.error(f"登录失败: {str(e)}") + logger.error(f"错误类型: {type(e)}") + logger.error(f"错误堆栈: {traceback.format_exc()}") + + return Response({ + "code": 500, + "message": "登录失败,请稍后重试", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@method_decorator(csrf_exempt, name='dispatch') +class RegisterView(APIView): + """用户注册视图""" + permission_classes = [AllowAny] + + def post(self, request): + try: + data = request.data + + # 检查必填字段 + required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] + for field in required_fields: + if not data.get(field): + return Response({ + "code": 400, + "message": f"缺少必填字段: {field}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证角色 + valid_roles = ['admin', 'leader', 'member'] + roles_str = ', '.join(valid_roles) # 先构造角色字符串 + if data['role'] not in valid_roles: + return Response({ + "code": 400, + "message": f"无效的角色,必须是: {roles_str}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证部门是否存在 + if data['department'] not in settings.DEPARTMENT_GROUPS: + return Response({ + "code": 400, + "message": f"无效的部门,可选部门: {', '.join(settings.DEPARTMENT_GROUPS.keys())}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 如果是组员,验证小组 + if data['role'] == 'member': + if not data.get('group'): + return Response({ + "code": 400, + "message": "组员必须指定所属小组", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证小组是否存在且属于指定部门 + valid_groups = settings.DEPARTMENT_GROUPS.get(data['department'], []) + if data['group'] not in valid_groups: + return Response({ + "code": 400, + "message": f"无效的小组,{data['department']}的可选小组: {', '.join(valid_groups)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查用户名是否已存在 + if User.objects.filter(username=data['username']).exists(): + return Response({ + "code": 400, + "message": "用户名已存在", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查邮箱是否已存在 + if User.objects.filter(email=data['email']).exists(): + return Response({ + "code": 400, + "message": "邮箱已被注册", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证密码强度 + if len(data['password']) < 8: + return Response({ + "code": 400, + "message": "密码长度必须至少为8位", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证邮箱格式 + try: + validate_email(data['email']) + except ValidationError: + return Response({ + "code": 400, + "message": "邮箱格式不正确", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建用户 + user = User.objects.create_user( + username=data['username'], + email=data['email'], + password=data['password'], + role=data['role'], + department=data['department'], + name=data['name'], + group=data.get('group') if data['role'] == 'member' else None, + is_staff=False, + is_superuser=False + ) + + # 生成认证令牌 + token, _ = Token.objects.get_or_create(user=user) + + return Response({ + "code": 200, + "message": "注册成功", + "data": { + "id": user.id, + "username": user.username, + "email": user.email, + "role": user.role, + "department": user.department, + "name": user.name, + "group": user.group, + "token": token.key, + "created_at": user.date_joined.strftime('%Y-%m-%d %H:%M:%S') + } + }, status=status.HTTP_201_CREATED) + + except Exception as e: + print(f"注册失败: {str(e)}") + print(f"错误类型: {type(e)}") + print(f"错误堆栈: {traceback.format_exc()}") + return Response({ + "code": 500, + "message": f"注册失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@method_decorator(csrf_exempt, name='dispatch') +class LogoutView(APIView): + """用户登出视图""" + permission_classes = [IsAuthenticated] + + def post(self, request): + try: + # 删除用户的token + request.user.auth_token.delete() + # 执行django的登出 + logout(request) + + return Response({ + "code": 200, + "message": "登出成功", + "data": None + }) + except Exception as e: + return Response({ + "code": 500, + "message": f"登出失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@api_view(['GET', 'PUT']) +@permission_classes([IsAuthenticated]) +def user_profile(request): + """获取或更新用户信息""" + if request.method == 'GET': + data = { + 'id': request.user.id, + 'username': request.user.username, + 'email': request.user.email, + 'role': request.user.role, + 'department': request.user.department, + 'phone': request.user.phone, + 'date_joined': request.user.date_joined + } + return Response(data) + + elif request.method == 'PUT': + user = request.user + # 只允许更新特定字段 + allowed_fields = ['email', 'phone', 'department'] + for field in allowed_fields: + if field in request.data: + setattr(user, field, request.data[field]) + user.save() + return Response({'message': '用户信息更新成功'}) + +@csrf_exempt +@api_view(['POST']) +@permission_classes([IsAuthenticated]) +def change_password(request): + """修改密码""" + try: + old_password = request.data.get('old_password') + new_password = request.data.get('new_password') + + # 验证参数 + if not old_password or not new_password: + return Response({ + "code": 400, + "message": "请提供旧密码和新密码", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证旧密码 + user = request.user + if not user.check_password(old_password): + return Response({ + "code": 400, + "message": "旧密码错误", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证新密码长度 + if len(new_password) < 8: + return Response({ + "code": 400, + "message": "新密码长度必须至少为8位", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 修改密码 + user.set_password(new_password) + user.save() + + # 更新token + user.auth_token.delete() + token, _ = Token.objects.get_or_create(user=user) + + return Response({ + "code": 200, + "message": "密码修改成功", + "data": { + "token": token.key + } + }) + + except Exception as e: + return Response({ + "code": 500, + "message": f"密码修改失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@api_view(['POST']) +@permission_classes([AllowAny]) +def user_register(request): + """用户注册""" + try: + data = request.data + + # 检查必填字段 + required_fields = ['username', 'password', 'email', 'role', 'department', 'name'] + for field in required_fields: + if not data.get(field): + return Response({ + 'error': f'缺少必填字段: {field}' + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证角色 + valid_roles = ['admin', 'leader', 'member'] + if data['role'] not in valid_roles: + return Response({ + 'error': f'无效的角色,必须是: {", ".join(valid_roles)}' + }, status=status.HTTP_400_BAD_REQUEST) + + # 如果是组员,必须指定小组 + if data['role'] == 'member' and not data.get('group'): + return Response({ + 'error': '组员必须指定所属小组' + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查用户名是否已存在 + if User.objects.filter(username=data['username']).exists(): + return Response({ + 'error': '用户名已存在' + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查邮箱是否已存在 + if User.objects.filter(email=data['email']).exists(): + return Response({ + 'error': '邮箱已被注册' + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证密码强度 + if len(data['password']) < 8: + return Response({ + 'error': '密码长度必须至少为8位' + }, status=status.HTTP_400_BAD_REQUEST) + + # 验证邮箱格式 + try: + validate_email(data['email']) + except ValidationError: + return Response({ + 'error': '邮箱格式不正确' + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建用户 + user = User.objects.create_user( + username=data['username'], + email=data['email'], + password=data['password'], + role=data['role'], + department=data['department'], + name=data['name'], + group=data.get('group') if data['role'] == 'member' else None, + is_staff=False, + is_superuser=False + ) + + # 生成认证令牌 + token, _ = Token.objects.get_or_create(user=user) + + return Response({ + 'message': '注册成功', + 'data': { + 'id': user.id, + 'username': user.username, + 'email': user.email, + 'role': user.role, + 'department': user.department, + 'name': user.name, + 'group': user.group, + 'token': token.key, + 'created_at': user.date_joined.strftime('%Y-%m-%d %H:%M:%S') + } + }, status=status.HTTP_201_CREATED) + + except Exception as e: + print(f"注册失败: {str(e)}") + print(f"错误类型: {type(e)}") + print(f"错误堆栈: {traceback.format_exc()}") + return Response({ + 'error': f'注册失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@csrf_exempt +@api_view(['POST']) +@permission_classes([IsAuthenticated]) +def verify_token(request): + """验证令牌有效性""" + try: + return Response({ + "code": 200, + "message": "令牌有效", + "data": { + "is_valid": True, + "user": { + "id": request.user.id, + "username": request.user.username, + "email": request.user.email, + "role": request.user.role, + "department": request.user.department, + "name": request.user.name, + "group": request.user.group + } + } + }) + except Exception as e: + return Response({ + "code": 500, + "message": f"验证失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def user_list(request): + """获取用户列表""" + user = request.user + if user.role == 'admin': + users = User.objects.all() + elif user.role == 'leader': + users = User.objects.filter(department=user.department) + else: + users = User.objects.filter(id=user.id) + + data = [{ + 'id': u.id, + 'username': u.username, + 'email': u.email, + 'role': u.role, + 'department': u.department, + 'is_active': u.is_active, + 'date_joined': u.date_joined + } for u in users] + + return Response(data) + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def user_detail(request, pk): + """获取用户详情""" + try: + # 尝试转换为 UUID + if not isinstance(pk, uuid.UUID): + try: + pk = uuid.UUID(pk) + except ValueError: + return Response({ + "code": 400, + "message": "无效的用户ID格式", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + user = get_object_or_404(User, pk=pk) + + return Response({ + "code": 200, + "message": "获取用户信息成功", + "data": { + "id": str(user.id), + "username": user.username, + "email": user.email, + "name": user.name, + "role": user.role, + "department": user.department, + "group": user.group + } + }) + except Exception as e: + return Response({ + "code": 500, + "message": f"获取用户信息失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@api_view(['PUT']) +@permission_classes([IsAdminUser]) +def user_update(request, pk): + """更新用户信息""" + try: + user = User.objects.get(pk=pk) + # 只允许更新特定字段 + allowed_fields = ['email', 'role', 'department', 'is_active', 'phone'] + for field in allowed_fields: + if field in request.data: + setattr(user, field, request.data[field]) + user.save() + return Response({'message': '用户信息更新成功'}) + except User.DoesNotExist: + return Response({'message': '用户不存在'}, status=404) + +@api_view(['DELETE']) +@permission_classes([IsAdminUser]) +def user_delete(request, pk): + """删除用户""" + try: + user = User.objects.get(pk=pk) + user.delete() + return Response({'message': '用户删除成功'}) + except User.DoesNotExist: + return Response({'message': '用户不存在'}, status=404) +