import logging

from django.core.exceptions import ValidationError
from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response

from .models import Brand, Product, Campaign, BrandChatSession
from .serializers import (
    BrandSerializer,
    ProductSerializer,
    CampaignSerializer,
    BrandChatSessionSerializer,
    BrandDetailSerializer,
)

logger = logging.getLogger(__name__)


def api_response(code=200, message="成功", data=None):
    """Build a response in the project's unified envelope format.

    Args:
        code: Business status code placed in the payload. The HTTP status of
            the returned ``Response`` is left at DRF's default (200), so
            clients are expected to inspect ``code``.
        message: Human-readable message for the client.
        data: Payload body; ``None`` when there is nothing to return.

    Returns:
        A DRF ``Response`` whose body is ``{'code', 'message', 'data'}``.
    """
    return Response({
        'code': code,
        'message': message,
        'data': data,
    })


class _UnifiedResponseMixin:
    """CRUD actions that wrap every result in the ``api_response`` envelope.

    Shared by every viewset in this module. It must precede
    ``viewsets.ModelViewSet`` in the bases so these overrides win in the MRO.
    Validation failures are reported as ``code=400`` inside the envelope
    instead of raising.
    """

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            self.perform_create(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="创建失败", data=serializer.errors)

    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return api_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        # ModelViewSet.partial_update (PATCH) calls this with partial=True.
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        if serializer.is_valid():
            self.perform_update(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="更新失败", data=serializer.errors)

    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return api_response(message="删除成功", data=None)


class _BrandDatasetSyncMixin:
    """Keep ``Brand.dataset_id_list`` in step with a record's ``dataset_id``.

    For viewsets whose model has a ``brand`` relation, a ``dataset_id`` and
    an ``is_active`` flag. Deletion is a soft delete (``is_active=False``).

    NOTE(review): removing an ID assumes no other active record of the same
    brand shares that ``dataset_id`` -- confirm this invariant holds.
    """

    @staticmethod
    def _sync_brand_dataset_ids(brand, remove_id=None, add_id=None):
        """Remove ``remove_id`` from / add ``add_id`` to ``brand``'s list.

        The brand is only saved when its list actually changed. An ID that is
        both removed and re-added keeps its current position.
        """
        if brand is None:
            return
        current = list(brand.dataset_id_list or [])
        updated = list(current)
        if remove_id is not None and remove_id != add_id and remove_id in updated:
            updated.remove(remove_id)
        if add_id and add_id not in updated:
            updated.append(add_id)
        if updated != current:
            brand.dataset_id_list = updated
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_create(self, serializer):
        record = serializer.save()
        self._sync_brand_dataset_ids(record.brand, add_id=record.dataset_id)

    def perform_update(self, serializer):
        # By default ModelSerializer.update() mutates serializer.instance in
        # place, so capture the pre-update brand and dataset_id before saving.
        old_brand = serializer.instance.brand
        old_dataset_id = serializer.instance.dataset_id
        record = serializer.save()
        new_brand = record.brand
        if getattr(old_brand, 'pk', None) == getattr(new_brand, 'pk', None):
            self._sync_brand_dataset_ids(
                new_brand, remove_id=old_dataset_id, add_id=record.dataset_id
            )
        else:
            # The record moved to another brand: clean up the OLD brand's
            # list rather than removing the old ID from the new brand.
            self._sync_brand_dataset_ids(old_brand, remove_id=old_dataset_id)
            self._sync_brand_dataset_ids(new_brand, add_id=record.dataset_id)

    def perform_destroy(self, instance):
        # Soft delete: keep the row, just deactivate it.
        instance.is_active = False
        instance.save()
        self._sync_brand_dataset_ids(instance.brand, remove_id=instance.dataset_id)


class BrandViewSet(_UnifiedResponseMixin, viewsets.ModelViewSet):
    """Brand API viewset (brands are hard-deleted by the default hook)."""

    queryset = Brand.objects.all()
    serializer_class = BrandSerializer

    def get_serializer_class(self):
        # The detail view uses the richer serializer.
        if self.action == 'retrieve':
            return BrandDetailSerializer
        return BrandSerializer

    @action(detail=True, methods=['get'])
    def products(self, request, pk=None):
        """Return all active products of this brand."""
        brand = self.get_object()
        products = Product.objects.filter(brand=brand, is_active=True)
        serializer = ProductSerializer(
            products, many=True, context=self.get_serializer_context()
        )
        return api_response(data=serializer.data)

    @action(detail=True, methods=['get'])
    def campaigns(self, request, pk=None):
        """Return all active campaigns of this brand."""
        brand = self.get_object()
        campaigns = Campaign.objects.filter(brand=brand, is_active=True)
        serializer = CampaignSerializer(
            campaigns, many=True, context=self.get_serializer_context()
        )
        return api_response(data=serializer.data)

    @action(detail=True, methods=['get'])
    def dataset_ids(self, request, pk=None):
        """Return the brand's knowledge-base (dataset) ID list."""
        brand = self.get_object()
        return api_response(data={'dataset_id_list': brand.dataset_id_list})


class ProductViewSet(_UnifiedResponseMixin, _BrandDatasetSyncMixin, viewsets.ModelViewSet):
    """Product API viewset; keeps the brand's dataset_id_list in sync."""

    queryset = Product.objects.filter(is_active=True)
    serializer_class = ProductSerializer


class CampaignViewSet(_UnifiedResponseMixin, _BrandDatasetSyncMixin, viewsets.ModelViewSet):
    """Campaign API viewset; keeps the brand's dataset_id_list in sync."""

    queryset = Campaign.objects.filter(is_active=True)
    serializer_class = CampaignSerializer

    def _change_product_link(self, request, *, filters, mutate, success_message, failure_label):
        """Shared body of ``add_product`` / ``remove_product``.

        Args:
            request: Request whose body carries ``product_id``.
            filters: Extra lookup filters applied to ``Product``.
            mutate: Callable ``(campaign, product)`` that changes the link.
            success_message: Envelope message on success.
            failure_label: Prefix of the envelope message on unexpected errors.
        """
        campaign = self.get_object()
        product_id = request.data.get('product_id')
        if not product_id:
            return api_response(code=400, message="缺少产品ID", data=None)

        def fail(exc):
            # Called from inside an ``except`` clause, so the traceback is logged.
            logger.exception(
                "%s: campaign=%s product_id=%s", failure_label, campaign.pk, product_id
            )
            return api_response(code=500, message=f"{failure_label}: {str(exc)}", data=None)

        try:
            product = Product.objects.get(id=product_id, **filters)
        except Product.DoesNotExist:
            return api_response(code=404, message="产品不存在", data=None)
        except (ValueError, TypeError, ValidationError):
            # A malformed ID (e.g. non-numeric for an integer pk) is a client
            # error, not a server failure.
            return api_response(code=400, message="产品ID无效", data=None)
        except Exception as e:
            return fail(e)

        try:
            mutate(campaign, product)
        except Exception as e:
            return fail(e)
        return api_response(message=success_message, data=None)

    @action(detail=True, methods=['post'])
    def add_product(self, request, pk=None):
        """Link an active product (``product_id`` in the body) to this campaign."""
        return self._change_product_link(
            request,
            filters={'is_active': True},
            mutate=lambda campaign, product: campaign.link_product.add(product),
            success_message="产品添加成功",
            failure_label="添加产品失败",
        )

    @action(detail=True, methods=['post'])
    def remove_product(self, request, pk=None):
        """Unlink a product (``product_id`` in the body), active or not."""
        return self._change_product_link(
            request,
            filters={},
            mutate=lambda campaign, product: campaign.link_product.remove(product),
            success_message="产品移除成功",
            failure_label="移除产品失败",
        )


class BrandChatSessionViewSet(_UnifiedResponseMixin, viewsets.ModelViewSet):
    """Brand chat session API viewset."""

    queryset = BrandChatSession.objects.filter(is_active=True)
    serializer_class = BrandChatSessionSerializer

    def perform_create(self, serializer):
        """Create a session; default its datasets to the brand's if none given."""
        chat_session = serializer.save()
        if not chat_session.dataset_id_list:
            # Copy rather than share the brand's list object, so an in-place
            # edit on one in-memory instance cannot leak into the other.
            chat_session.dataset_id_list = list(chat_session.brand.dataset_id_list or [])
            chat_session.save(update_fields=['dataset_id_list', 'updated_at'])