daren/apps/brands/views.py
2025-05-29 16:11:38 +08:00

845 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated
from apps.user.authentication import CustomTokenAuthentication
import logging
from django.db.models import Q
from .models import Brand, Product, Campaign, BrandChatSession
from .serializers import (
BrandSerializer,
ProductSerializer,
CampaignSerializer,
BrandChatSessionSerializer,
BrandDetailSerializer
)
from .services.status_polling_service import polling_service
from .services.offer_status_service import OfferStatusService
logger = logging.getLogger(__name__)


def api_response(code=200, message="成功", data=None, headers=None):
    """Wrap a payload in the project's unified JSON envelope.

    Args:
        code: Business status code written into the body (default 200).
        message: Human-readable message written into the body.
        data: Payload placed under the ``data`` key.
        headers: Optional extra HTTP headers for the response.

    Returns:
        A DRF ``Response`` whose body is ``{'code', 'message', 'data'}``.
        No HTTP status is passed to ``Response``, so the transport status is
        DRF's default regardless of ``code``.
    """
    envelope = dict(code=code, message=message, data=data)
    return Response(envelope, headers=headers)
class BrandViewSet(viewsets.ModelViewSet):
    """Brand API viewset: CRUD plus product/campaign/dataset lookups and search.

    Every handler answers through ``api_response`` so clients always receive
    the ``{'code', 'message', 'data'}`` envelope.
    """
    queryset = Brand.objects.all()
    serializer_class = BrandSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_serializer_class(self):
        # The detail endpoint uses the richer serializer; all other actions the basic one.
        return BrandDetailSerializer if self.action == 'retrieve' else BrandSerializer

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return api_response(code=400, message="创建失败", data=serializer.errors)
        self.perform_create(serializer)
        success_headers = self.get_success_headers(serializer.data)
        return api_response(data=serializer.data, headers=success_headers)

    def list(self, request, *args, **kwargs):
        brands = self.filter_queryset(self.get_queryset())
        return api_response(data=self.get_serializer(brands, many=True).data)

    def retrieve(self, request, *args, **kwargs):
        return api_response(data=self.get_serializer(self.get_object()).data)

    def update(self, request, *args, **kwargs):
        is_partial = kwargs.pop('partial', False)
        serializer = self.get_serializer(self.get_object(), data=request.data, partial=is_partial)
        if not serializer.is_valid():
            return api_response(code=400, message="更新失败", data=serializer.errors)
        self.perform_update(serializer)
        return api_response(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        self.perform_destroy(self.get_object())
        return api_response(message="删除成功", data=None)

    @action(detail=True, methods=['get'])
    def products(self, request, pk=None):
        """Return the brand's active products."""
        active_products = Product.objects.filter(brand=self.get_object(), is_active=True)
        return api_response(data=ProductSerializer(active_products, many=True).data)

    @action(detail=True, methods=['get'])
    def campaigns(self, request, pk=None):
        """Return the brand's active campaigns."""
        active_campaigns = Campaign.objects.filter(brand=self.get_object(), is_active=True)
        return api_response(data=CampaignSerializer(active_campaigns, many=True).data)

    @action(detail=True, methods=['get'])
    def dataset_ids(self, request, pk=None):
        """Return the brand's knowledge-base dataset id list."""
        return api_response(data={'dataset_id_list': self.get_object().dataset_id_list})

    @action(detail=False, methods=['get'])
    def search(self, request):
        """Case-insensitive keyword search over name, description, category and source."""
        keyword = request.query_params.get('keyword', '')
        if not keyword:
            return api_response(code=400, message="缺少关键字参数", data=None)
        # OR together one icontains lookup per searchable field.
        condition = Q()
        for field_name in ('name', 'description', 'category', 'source'):
            condition |= Q(**{f'{field_name}__icontains': keyword})
        matches = self.get_queryset().filter(condition)
        return api_response(data=self.get_serializer(matches, many=True).data)
class ProductViewSet(viewsets.ModelViewSet):
    """Product API viewset.

    Only active products are listed/retrieved; deletion is a soft delete.
    Create/update/delete keep the owning brand's ``dataset_id_list`` in sync
    with the product's ``dataset_id``.
    """
    queryset = Product.objects.filter(is_active=True)
    serializer_class = ProductSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            self.perform_create(serializer)
            headers = self.get_success_headers(serializer.data)
            return api_response(data=serializer.data, headers=headers)
        return api_response(code=400, message="创建失败", data=serializer.errors)

    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return api_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        if serializer.is_valid():
            self.perform_update(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="更新失败", data=serializer.errors)

    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the new product and register its dataset_id on the owning brand."""
        product = serializer.save()
        brand = product.brand
        # Only add a non-empty id that the brand does not already list.
        if product.dataset_id and product.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(product.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_update(self, serializer):
        """Save the product and move its dataset_id in the brand(s)' dataset_id_list.

        The previous dataset_id is removed from the brand the product belonged
        to *before* the update, and the current dataset_id is added to the
        brand it belongs to afterwards. Earlier code removed the old id from
        the new brand, which left a stale id behind when the brand changed.
        """
        # serializer.instance is the object being edited. It only receives the
        # validated data inside save(), so read the pre-update values now.
        # This avoids a second get_object() query.
        instance = serializer.instance
        old_dataset_id = instance.dataset_id
        old_brand = instance.brand
        product = serializer.save()
        brand = product.brand
        if old_brand is not None and old_brand.pk != brand.pk:
            # The product moved to another brand: clean up the previous brand.
            if old_dataset_id in old_brand.dataset_id_list:
                old_brand.dataset_id_list.remove(old_dataset_id)
                old_brand.save(update_fields=['dataset_id_list', 'updated_at'])
        elif old_dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(old_dataset_id)
        if product.dataset_id and product.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(product.dataset_id)
        brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_destroy(self, instance):
        """Soft-delete the product and drop its dataset_id from the brand's list."""
        instance.is_active = False
        instance.save()
        brand = instance.brand
        if instance.dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(instance.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    @action(detail=False, methods=['get'])
    def search(self, request):
        """Case-insensitive keyword search over name, description, pid and brand name."""
        keyword = request.query_params.get('keyword', '')
        if not keyword:
            return api_response(code=400, message="缺少关键字参数", data=None)
        queryset = self.get_queryset().filter(
            Q(name__icontains=keyword) |
            Q(description__icontains=keyword) |
            Q(pid__icontains=keyword) |
            Q(brand__name__icontains=keyword)
        )
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)
class CampaignViewSet(viewsets.ModelViewSet):
    """Campaign API viewset.

    Provides:
    - CRUD over active campaigns (deletion is a soft delete);
    - product linking;
    - creator listings with offer-status refresh;
    - status-polling control;
    - WebSocket URL helpers.

    Create/update/delete keep the owning brand's ``dataset_id_list`` in sync
    with the campaign's ``dataset_id``.
    """
    queryset = Campaign.objects.filter(is_active=True)
    serializer_class = CampaignSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    # Interval in seconds passed to polling_service.start_polling by the
    # creator-listing endpoints.
    _POLLING_INTERVAL = 30

    def get_permissions(self):
        """Return no permission classes for the polling-control and token-info actions."""
        # NOTE(review): this lets anyone call 'stop_polling' (which can stop
        # every campaign's polling) and 'active_pollings' without being
        # authenticated. Confirm this is intended before exposing the API.
        if self.action in ['stop_polling', 'active_pollings', 'token_info']:
            return []
        return super().get_permissions()

    @action(detail=False, methods=['get'], url_path='token-info')
    def token_info(self, request):
        """Return the caller's token, its expiry and example WebSocket URLs."""
        # Permissions are disabled for this action (see get_permissions), so
        # authentication is checked here.
        if not request.user.is_authenticated:
            return api_response(code=401, message="未授权,请先登录", data=None)
        from apps.user.models import UserToken
        user_token = UserToken.objects.filter(user=request.user).first()
        token = user_token.token if user_token else None
        if not token:
            return api_response(code=404, message="未找到有效的token请重新登录", data=None)
        base_url = request.get_host()
        ws_protocol = 'wss' if request.is_secure() else 'ws'
        ws_examples = {
            "活动状态WebSocket": f"{ws_protocol}://{base_url}/ws/campaigns/1/status/?token={token}",
            "活动产品状态WebSocket": f"{ws_protocol}://{base_url}/ws/campaigns/1/products/123/status/?token={token}",
        }
        # The attribute may be absent on the model or null on the row. Both
        # cases map to None. The previous hasattr() check crashed on a null value.
        expired_at = getattr(user_token, 'expired_at', None)
        data = {
            "user_id": request.user.id,
            "email": request.user.email,
            "token": token,
            "token_expired_at": expired_at.strftime('%Y-%m-%d %H:%M:%S') if expired_at is not None else None,
            "websocket_examples": ws_examples,
        }
        return api_response(data=data)

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            self.perform_create(serializer)
            headers = self.get_success_headers(serializer.data)
            return api_response(data=serializer.data, headers=headers)
        return api_response(code=400, message="创建失败", data=serializer.errors)

    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return api_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        if serializer.is_valid():
            self.perform_update(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="更新失败", data=serializer.errors)

    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the new campaign and register its dataset_id on the owning brand."""
        campaign = serializer.save()
        brand = campaign.brand
        if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(campaign.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_update(self, serializer):
        """Save the campaign and move its dataset_id in the brand(s)' dataset_id_list.

        The previous dataset_id is removed from the brand the campaign belonged
        to before the update (not from the new brand, which previously left a
        stale id behind when the brand changed). The current dataset_id is
        added to the campaign's current brand.
        """
        # serializer.instance still holds the pre-update values until save().
        instance = serializer.instance
        old_dataset_id = instance.dataset_id
        old_brand = instance.brand
        campaign = serializer.save()
        brand = campaign.brand
        if old_brand is not None and old_brand.pk != brand.pk:
            if old_dataset_id in old_brand.dataset_id_list:
                old_brand.dataset_id_list.remove(old_dataset_id)
                old_brand.save(update_fields=['dataset_id_list', 'updated_at'])
        elif old_dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(old_dataset_id)
        if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(campaign.dataset_id)
        brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_destroy(self, instance):
        """Soft-delete the campaign and drop its dataset_id from the brand's list."""
        instance.is_active = False
        instance.save()
        brand = instance.brand
        if instance.dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(instance.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    @action(detail=True, methods=['post'])
    def add_product(self, request, pk=None):
        """Link an active product (``product_id`` in the body) to this campaign."""
        campaign = self.get_object()
        product_id = request.data.get('product_id')
        if not product_id:
            return api_response(code=400, message="缺少产品ID", data=None)
        try:
            product = Product.objects.get(id=product_id, is_active=True)
            campaign.link_product.add(product)
            return api_response(message="产品添加成功", data=None)
        except Product.DoesNotExist:
            return api_response(code=404, message="产品不存在", data=None)
        except Exception as e:
            return api_response(code=500, message=f"添加产品失败: {str(e)}", data=None)

    @action(detail=True, methods=['post'])
    def remove_product(self, request, pk=None):
        """Unlink a product (``product_id`` in the body, active or not) from this campaign."""
        campaign = self.get_object()
        product_id = request.data.get('product_id')
        if not product_id:
            return api_response(code=400, message="缺少产品ID", data=None)
        try:
            product = Product.objects.get(id=product_id)
            campaign.link_product.remove(product)
            return api_response(message="产品移除成功", data=None)
        except Product.DoesNotExist:
            return api_response(code=404, message="产品不存在", data=None)
        except Exception as e:
            return api_response(code=500, message=f"移除产品失败: {str(e)}", data=None)

    # ------------------------------------------------------------------
    # Private helpers shared by the creator-listing endpoints
    # ------------------------------------------------------------------

    @staticmethod
    def _format_creator_metrics(creator):
        """Return (followers, gmv, views, pricing) display strings for a creator.

        followers and avg_video_views are divided by 1000, truncated with
        int() and suffixed with "k". gmv renders as "$<gmv>k" and pricing as
        "$<pricing>". Falsy values render as "0" or "$0".
        """
        followers = f"{int(creator.followers / 1000)}k" if creator.followers else "0"
        gmv = f"${creator.gmv}k" if creator.gmv else "$0"
        views = f"{int(creator.avg_video_views / 1000)}k" if creator.avg_video_views else "0"
        pricing = f"${creator.pricing}" if creator.pricing else "$0"
        return followers, gmv, views, pricing

    @staticmethod
    def _refresh_creator_status(creator_campaign, product_id):
        """Fetch the latest offer status for one creator link and product.

        A fetched status is saved on the link and returned. If nothing is
        fetched, the stored status is returned unchanged.
        """
        latest = OfferStatusService.fetch_status(creator_campaign.creator.id, product_id)
        if not latest:
            return creator_campaign.status
        creator_campaign.status = latest
        creator_campaign.save(update_fields=['status', 'update_time'])
        return latest

    @classmethod
    def _product_creator_row(cls, creator, offer_status):
        """Build one creator entry of the product_creators response."""
        followers, gmv, views, pricing = cls._format_creator_metrics(creator)
        return {
            "creator_name": creator.name,
            "category": creator.category,
            "followers": followers,
            "gmv_achieved": gmv,
            "views_achieved": views,
            "pricing": pricing,
            "status": offer_status,
        }

    def _collect_product_creators(self, creator_campaigns, product_id):
        """Refresh every link's status for product_id and return the response rows."""
        rows = []
        for cc in creator_campaigns:
            offer_status = self._refresh_creator_status(cc, product_id)
            rows.append(self._product_creator_row(cc.creator, offer_status))
        return rows

    def _start_polling(self, campaign_id, creator_product_pairs):
        """Best-effort start of status polling for (creator_id, product_id) pairs.

        Does nothing for an empty list. Errors are logged and swallowed, so
        the listing endpoints still return their data.
        """
        if not creator_product_pairs:
            return
        try:
            polling_service.start_polling(
                campaign_id=campaign_id,
                creator_product_pairs=creator_product_pairs,
                interval=self._POLLING_INTERVAL,
            )
        except Exception as e:
            logger.exception("启动状态轮询时出错: %s", e)

    # ------------------------------------------------------------------

    @action(detail=True, methods=['get'])
    def creator_list(self, request, pk=None):
        """Return campaign info plus the creators linked to the campaign.

        Statuses come from the stored CreatorCampaign rows and are not
        refreshed. Status polling is started for every (creator, product)
        pair; the campaign itself stands in as the product when none are
        linked.
        """
        campaign = self.get_object()
        from apps.daren_detail.models import CreatorCampaign
        try:
            products = list(campaign.link_product.all()) or [campaign]
            # Creator links belong to the campaign, not to a product, so each
            # creator is listed once. Previously every creator was repeated
            # once per linked product, which also inflated creators_count.
            creator_campaigns = list(
                CreatorCampaign.objects.filter(campaign_id=campaign.id).select_related('creator')
            )
            all_creator_list = []
            for cc in creator_campaigns:
                creator = cc.creator
                followers, gmv, views, pricing = self._format_creator_metrics(creator)
                all_creator_list.append({
                    "name": creator.name,
                    "category": creator.category,
                    "followers": followers,
                    "GMV Achieved": gmv,
                    "Views Achieved": views,
                    "Pricing": pricing,
                    "Status": cc.status,
                })
            self._start_polling(
                campaign.id,
                [(cc.creator_id, product.id) for product in products for cc in creator_campaigns],
            )
            campaign_info = {
                "name": campaign.name,
                "description": campaign.description,
                "image_url": campaign.image_url,
                "service": campaign.service,
                "creator_type": campaign.creator_type,
                "creator_level": campaign.creator_level,
                "creator_category": campaign.creator_category,
                "creators_count": len(all_creator_list),
                "gmv": campaign.gmv,
                "followers": campaign.followers,
                "views": campaign.views,
                "budget": campaign.budget,
                "start_date": campaign.start_date.strftime('%Y-%m-%d') if campaign.start_date else None,
                "end_date": campaign.end_date.strftime('%Y-%m-%d') if campaign.end_date else None,
                "status": campaign.status,
            }
            return api_response(data={
                "campaign": campaign_info,
                "creators": all_creator_list,
            })
        except Exception as e:
            logger.exception("获取活动达人列表失败: %s", e)
            return api_response(code=500, message=f"获取活动达人列表失败: {str(e)}", data=None)

    @action(detail=True, methods=['post'])
    def update_creator_status(self, request, pk=None):
        """Fetch and store one creator's latest offer status, then push a WebSocket update.

        Body: ``creator_id`` (required) and ``product_id`` (optional). Without
        a product_id, the first linked product is used, or the campaign id if
        no product is linked.
        """
        campaign = self.get_object()
        from apps.daren_detail.models import CreatorCampaign
        creator_id = request.data.get('creator_id')
        product_id = request.data.get('product_id')
        if not creator_id:
            return api_response(code=400, message="缺少必要参数: creator_id", data=None)
        try:
            creator_campaign = CreatorCampaign.objects.get(
                campaign_id=campaign.id,
                creator_id=creator_id
            )
            if not product_id:
                first_product = campaign.link_product.first()
                product_id = str(first_product.id) if first_product is not None else str(campaign.id)
            # Named new_status so it does not shadow the module-level
            # rest_framework ``status`` import.
            new_status = OfferStatusService.fetch_status(creator_id, product_id)
            if not new_status:
                return api_response(code=500, message="获取状态失败", data=None)
            creator_campaign.status = new_status
            creator_campaign.save()
            creator_list = OfferStatusService.get_campaign_creator_data(campaign.id)
            OfferStatusService.send_status_update(campaign.id, creator_id, new_status, product_id)
            return api_response(message="状态已更新", data=creator_list)
        except CreatorCampaign.DoesNotExist:
            return api_response(code=404, message="找不到达人与活动的关联", data=None)
        except Exception as e:
            logger.exception("更新达人状态时出错: %s", e)
            return api_response(code=500, message=f"更新状态失败: {str(e)}", data=None)

    @action(detail=True, methods=['get'])
    def product_creators(self, request, pk=None):
        """List the campaign's creators with refreshed offer statuses, per product.

        Query params:
            product_id: optional. If given, only that product is reported.
                Otherwise every linked product is reported, or the campaign
                itself stands in as the product when none are linked.

        Each creator's status is re-fetched (and saved when available), and
        status polling is started for the reported (creator, product) pairs.
        """
        campaign = self.get_object()
        product_id = request.query_params.get('product_id')
        try:
            from apps.daren_detail.models import CreatorCampaign
            # Load the links once; they are reused for every product below.
            creator_campaigns = list(
                CreatorCampaign.objects.filter(campaign_id=campaign.id).select_related('creator')
            )

            if product_id:
                # Use filter().first() instead of get_object_or_404: the
                # Http404 it raised was caught by the generic handler below
                # and reported as a code-500 error.
                product = Product.objects.filter(id=product_id).first()
                if product is None:
                    return api_response(code=404, message="产品不存在", data=None)
                creators = self._collect_product_creators(creator_campaigns, product_id)
                self._start_polling(
                    campaign.id,
                    [(cc.creator.id, product_id) for cc in creator_campaigns],
                )
                return api_response(data={
                    "campaign_id": str(campaign.id),
                    "product_id": str(product.id),
                    "product_name": product.name,
                    "creators": creators,
                })

            products = list(campaign.link_product.all())
            if not products:
                # No linked products: the campaign itself acts as the product.
                fallback_product_id = str(campaign.id)
                creators = self._collect_product_creators(creator_campaigns, fallback_product_id)
                self._start_polling(
                    campaign.id,
                    [(cc.creator_id, fallback_product_id) for cc in creator_campaigns],
                )
                return api_response(data={
                    "campaign_id": str(campaign.id),
                    "product_id": fallback_product_id,
                    "product_name": campaign.name,
                    "creators": creators,
                })

            products_data = []
            all_creator_product_pairs = []
            for product in products:
                current_product_id = str(product.id)
                products_data.append({
                    "product_id": current_product_id,
                    "product_name": product.name,
                    "creators": self._collect_product_creators(creator_campaigns, current_product_id),
                })
                all_creator_product_pairs.extend(
                    (cc.creator.id, current_product_id) for cc in creator_campaigns
                )
            self._start_polling(campaign.id, all_creator_product_pairs)
            return api_response(data={
                "campaign_id": str(campaign.id),
                "products": products_data,
            })
        except Exception as e:
            logger.exception("获取活动产品达人列表时出错: %s", e)
            return api_response(code=500, message=f"获取活动产品达人列表失败: {str(e)}", data=None)

    @action(detail=False, methods=['post'], url_path='stop-polling')
    def stop_polling(self, request):
        """Stop polling for ``campaign_id`` from the body, or stop all polling when it is omitted."""
        campaign_id = request.data.get('campaign_id')
        if not campaign_id:
            count = polling_service.stop_all()
            return api_response(message=f"已停止 {count} 个活动的状态轮询")
        if polling_service.stop_polling(campaign_id):
            return api_response(message=f"已停止活动 {campaign_id} 的状态轮询")
        return api_response(code=404, message=f"未找到活动 {campaign_id} 的轮询任务")

    @action(detail=False, methods=['get'], url_path='active-pollings')
    def active_pollings(self, request):
        """Return whatever polling_service reports as the running polling tasks."""
        pollings = polling_service.get_active_pollings()
        return api_response(data=pollings)

    @action(detail=True, methods=['get'], url_path='websocket-url')
    def get_websocket_url(self, request, pk=None):
        """Return a token-authenticated WebSocket URL for the campaign or one of its products."""
        campaign = self.get_object()
        product_id = request.query_params.get('product_id')
        from apps.user.models import UserToken
        token = None
        if request.user.is_authenticated:
            user_token = UserToken.objects.filter(user=request.user).first()
            if user_token:
                token = user_token.token
        if not token:
            return api_response(code=401, message="未授权,请先登录", data=None)
        base_url = request.get_host()
        ws_protocol = 'wss' if request.is_secure() else 'ws'
        if product_id:
            ws_url = f"{ws_protocol}://{base_url}/ws/campaigns/{campaign.id}/products/{product_id}/status/?token={token}"
        else:
            ws_url = f"{ws_protocol}://{base_url}/ws/campaigns/{campaign.id}/status/?token={token}"
        return api_response(data={"websocket_url": ws_url})

    @action(detail=False, methods=['get'])
    def search(self, request):
        """Case-insensitive keyword search over campaign text fields, brand name and status."""
        keyword = request.query_params.get('keyword', '')
        if not keyword:
            return api_response(code=400, message="缺少关键字参数", data=None)
        queryset = self.get_queryset().filter(
            Q(name__icontains=keyword) |
            Q(description__icontains=keyword) |
            Q(service__icontains=keyword) |
            Q(creator_type__icontains=keyword) |
            Q(creator_level__icontains=keyword) |
            Q(creator_category__icontains=keyword) |
            Q(brand__name__icontains=keyword) |
            Q(status__icontains=keyword)
        )
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)
class BrandChatSessionViewSet(viewsets.ModelViewSet):
    """Brand chat-session API viewset over active sessions.

    Responses use the ``api_response`` envelope. New sessions without their
    own ``dataset_id_list`` inherit the brand's list.
    """
    queryset = BrandChatSession.objects.filter(is_active=True)
    serializer_class = BrandChatSessionSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def list(self, request, *args, **kwargs):
        sessions = self.filter_queryset(self.get_queryset())
        return api_response(data=self.get_serializer(sessions, many=True).data)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return api_response(code=400, message="创建失败", data=serializer.errors)
        self.perform_create(serializer)
        success_headers = self.get_success_headers(serializer.data)
        return api_response(data=serializer.data, headers=success_headers)

    def retrieve(self, request, *args, **kwargs):
        return api_response(data=self.get_serializer(self.get_object()).data)

    def update(self, request, *args, **kwargs):
        is_partial = kwargs.pop('partial', False)
        serializer = self.get_serializer(self.get_object(), data=request.data, partial=is_partial)
        if not serializer.is_valid():
            return api_response(code=400, message="更新失败", data=serializer.errors)
        self.perform_update(serializer)
        return api_response(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        self.perform_destroy(self.get_object())
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the session; when its dataset_id_list is empty, copy in the brand's list."""
        session = serializer.save()
        if session.dataset_id_list:
            return
        session.dataset_id_list = session.brand.dataset_id_list
        session.save(update_fields=['dataset_id_list', 'updated_at'])