operations_project/apps/brands/views.py

318 lines
13 KiB
Python
Raw Normal View History

2025-05-13 11:58:17 +08:00
from django.core.exceptions import ValidationError
from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response

from .models import Brand, Product, Campaign, BrandChatSession
from .serializers import (
    BrandSerializer,
    ProductSerializer,
    CampaignSerializer,
    BrandChatSessionSerializer,
    BrandDetailSerializer,
)
def api_response(code=200, message="成功", data=None):
    """Build the uniform API response envelope.

    Args:
        code: Application-level status code placed in the response body.
        message: Human-readable status message placed in the response body.
        data: Payload placed under the ``data`` key.

    Returns:
        A ``Response`` whose body has the keys ``code``, ``message`` and
        ``data``. No HTTP status is passed to ``Response`` here, so ``code``
        is carried only in the body.
    """
    envelope = dict(code=code, message=message, data=data)
    return Response(envelope)
class BrandViewSet(viewsets.ModelViewSet):
    """Brand API view set; every action answers in the uniform envelope."""

    queryset = Brand.objects.all()
    serializer_class = BrandSerializer

    def get_serializer_class(self):
        """Pick the detail serializer for ``retrieve``, the base one otherwise."""
        return BrandDetailSerializer if self.action == 'retrieve' else BrandSerializer

    def list(self, request, *args, **kwargs):
        """Return all brands from the filtered queryset."""
        brands = self.filter_queryset(self.get_queryset())
        return api_response(data=self.get_serializer(brands, many=True).data)

    def create(self, request, *args, **kwargs):
        """Create a brand, or report validation errors with code 400."""
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return api_response(code=400, message="创建失败", data=serializer.errors)
        self.perform_create(serializer)
        return api_response(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return a single brand."""
        brand = self.get_object()
        return api_response(data=self.get_serializer(brand).data)

    def update(self, request, *args, **kwargs):
        """Update a brand (full or partial), or report validation errors."""
        is_partial = kwargs.pop('partial', False)
        brand = self.get_object()
        serializer = self.get_serializer(brand, data=request.data, partial=is_partial)
        if not serializer.is_valid():
            return api_response(code=400, message="更新失败", data=serializer.errors)
        self.perform_update(serializer)
        return api_response(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        """Delete a brand and confirm with an empty payload."""
        self.perform_destroy(self.get_object())
        return api_response(message="删除成功", data=None)

    @action(detail=True, methods=['get'])
    def products(self, request, pk=None):
        """Return all active products that belong to this brand."""
        owner = self.get_object()
        active_products = Product.objects.filter(brand=owner, is_active=True)
        payload = ProductSerializer(active_products, many=True).data
        return api_response(data=payload)

    @action(detail=True, methods=['get'])
    def campaigns(self, request, pk=None):
        """Return all active campaigns that belong to this brand."""
        owner = self.get_object()
        active_campaigns = Campaign.objects.filter(brand=owner, is_active=True)
        payload = CampaignSerializer(active_campaigns, many=True).data
        return api_response(data=payload)

    @action(detail=True, methods=['get'])
    def dataset_ids(self, request, pk=None):
        """Return the brand's ``dataset_id_list`` (its knowledge-base IDs)."""
        owner = self.get_object()
        payload = {'dataset_id_list': owner.dataset_id_list}
        return api_response(data=payload)
class ProductViewSet(viewsets.ModelViewSet):
    """Product API view set over active products.

    Every action answers in the uniform envelope. The write hooks also keep
    the owning brand's ``dataset_id_list`` in step with the product's
    ``dataset_id``.
    """

    queryset = Product.objects.filter(is_active=True)
    serializer_class = ProductSerializer

    def list(self, request, *args, **kwargs):
        """Return all products from the filtered queryset."""
        queryset = self.filter_queryset(self.get_queryset())
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)

    def create(self, request, *args, **kwargs):
        """Create a product, or report validation errors with code 400."""
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            self.perform_create(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="创建失败", data=serializer.errors)

    def retrieve(self, request, *args, **kwargs):
        """Return a single product."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return api_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Update a product (full or partial), or report validation errors."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        if serializer.is_valid():
            self.perform_update(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="更新失败", data=serializer.errors)

    def destroy(self, request, *args, **kwargs):
        """Soft-delete a product (see ``perform_destroy``) and confirm."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the product and register its dataset ID on its brand."""
        product = serializer.save()
        brand = product.brand
        # Only touch the brand when the product brings a dataset ID that the
        # brand does not list yet.
        if product.dataset_id and product.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(product.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_update(self, serializer):
        """Save the product and move its dataset ID between brand lists.

        The old dataset ID is removed from the brand the product belonged to
        *before* the update, and the new one is added to the brand it belongs
        to afterwards, so a product that changes brand does not leave a stale
        ID behind on its former brand.
        """
        # Read the pre-update values before save(); save() is expected to
        # mutate this same instance, after which the old values are gone.
        current = serializer.instance
        old_dataset_id = current.dataset_id
        old_brand = current.brand

        product = serializer.save()
        brand = product.brand

        if old_brand.pk != brand.pk:
            # The product moved to another brand: clean up the brand it left.
            if old_dataset_id in old_brand.dataset_id_list:
                old_brand.dataset_id_list.remove(old_dataset_id)
                old_brand.save(update_fields=['dataset_id_list', 'updated_at'])
        elif (old_dataset_id != product.dataset_id
                and old_dataset_id in brand.dataset_id_list):
            # Same brand, different dataset: drop the ID that was replaced.
            # An unchanged ID is left where it is so the list is not reordered.
            brand.dataset_id_list.remove(old_dataset_id)

        if product.dataset_id and product.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(product.dataset_id)
        # Saved unconditionally, matching the previous behaviour of this hook.
        brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_destroy(self, instance):
        """Soft-delete the product and drop its dataset ID from its brand."""
        instance.is_active = False
        instance.save()
        brand = instance.brand
        if instance.dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(instance.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])
class CampaignViewSet(viewsets.ModelViewSet):
    """Campaign API view set over active campaigns.

    Every action answers in the uniform envelope. The write hooks also keep
    the owning brand's ``dataset_id_list`` in step with the campaign's
    ``dataset_id``, and two extra actions link and unlink products.
    """

    queryset = Campaign.objects.filter(is_active=True)
    serializer_class = CampaignSerializer

    def list(self, request, *args, **kwargs):
        """Return all campaigns from the filtered queryset."""
        queryset = self.filter_queryset(self.get_queryset())
        serializer = self.get_serializer(queryset, many=True)
        return api_response(data=serializer.data)

    def create(self, request, *args, **kwargs):
        """Create a campaign, or report validation errors with code 400."""
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            self.perform_create(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="创建失败", data=serializer.errors)

    def retrieve(self, request, *args, **kwargs):
        """Return a single campaign."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return api_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Update a campaign (full or partial), or report validation errors."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        if serializer.is_valid():
            self.perform_update(serializer)
            return api_response(data=serializer.data)
        return api_response(code=400, message="更新失败", data=serializer.errors)

    def destroy(self, request, *args, **kwargs):
        """Soft-delete a campaign (see ``perform_destroy``) and confirm."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the campaign and register its dataset ID on its brand."""
        campaign = serializer.save()
        brand = campaign.brand
        # Only touch the brand when the campaign brings a dataset ID that the
        # brand does not list yet.
        if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(campaign.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_update(self, serializer):
        """Save the campaign and move its dataset ID between brand lists.

        The old dataset ID is removed from the brand the campaign belonged to
        *before* the update, and the new one is added to the brand it belongs
        to afterwards, so a campaign that changes brand does not leave a stale
        ID behind on its former brand.
        """
        # Read the pre-update values before save(); save() is expected to
        # mutate this same instance, after which the old values are gone.
        current = serializer.instance
        old_dataset_id = current.dataset_id
        old_brand = current.brand

        campaign = serializer.save()
        brand = campaign.brand

        if old_brand.pk != brand.pk:
            # The campaign moved to another brand: clean up the brand it left.
            if old_dataset_id in old_brand.dataset_id_list:
                old_brand.dataset_id_list.remove(old_dataset_id)
                old_brand.save(update_fields=['dataset_id_list', 'updated_at'])
        elif (old_dataset_id != campaign.dataset_id
                and old_dataset_id in brand.dataset_id_list):
            # Same brand, different dataset: drop the ID that was replaced.
            # An unchanged ID is left where it is so the list is not reordered.
            brand.dataset_id_list.remove(old_dataset_id)

        if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
            brand.dataset_id_list.append(campaign.dataset_id)
        # Saved unconditionally, matching the previous behaviour of this hook.
        brand.save(update_fields=['dataset_id_list', 'updated_at'])

    def perform_destroy(self, instance):
        """Soft-delete the campaign and drop its dataset ID from its brand."""
        instance.is_active = False
        instance.save()
        brand = instance.brand
        if instance.dataset_id in brand.dataset_id_list:
            brand.dataset_id_list.remove(instance.dataset_id)
            brand.save(update_fields=['dataset_id_list', 'updated_at'])

    @action(detail=True, methods=['post'])
    def add_product(self, request, pk=None):
        """Link an active product to this campaign.

        Reads ``product_id`` from the request body. Answers with code 400 when
        it is missing or malformed, 404 when no active product has that ID,
        and 500 for any other failure.
        """
        campaign = self.get_object()
        product_id = request.data.get('product_id')
        if not product_id:
            return api_response(code=400, message="缺少产品ID", data=None)
        try:
            product = Product.objects.get(id=product_id, is_active=True)
            campaign.link_product.add(product)
            return api_response(message="产品添加成功", data=None)
        except Product.DoesNotExist:
            return api_response(code=404, message="产品不存在", data=None)
        except (TypeError, ValueError, ValidationError):
            # A product_id that cannot be coerced to the key type is a client
            # error, not a server failure.
            return api_response(code=400, message="产品ID无效", data=None)
        except Exception as e:
            return api_response(code=500, message=f"添加产品失败: {str(e)}", data=None)

    @action(detail=True, methods=['post'])
    def remove_product(self, request, pk=None):
        """Unlink a product from this campaign.

        Reads ``product_id`` from the request body. Unlike ``add_product`` the
        lookup does not filter on ``is_active``. Answers with code 400 when
        the ID is missing or malformed, 404 when no product has that ID, and
        500 for any other failure.
        """
        campaign = self.get_object()
        product_id = request.data.get('product_id')
        if not product_id:
            return api_response(code=400, message="缺少产品ID", data=None)
        try:
            product = Product.objects.get(id=product_id)
            campaign.link_product.remove(product)
            return api_response(message="产品移除成功", data=None)
        except Product.DoesNotExist:
            return api_response(code=404, message="产品不存在", data=None)
        except (TypeError, ValueError, ValidationError):
            # A product_id that cannot be coerced to the key type is a client
            # error, not a server failure.
            return api_response(code=400, message="产品ID无效", data=None)
        except Exception as e:
            return api_response(code=500, message=f"移除产品失败: {str(e)}", data=None)
class BrandChatSessionViewSet(viewsets.ModelViewSet):
    """Brand chat session API view set over active sessions."""

    queryset = BrandChatSession.objects.filter(is_active=True)
    serializer_class = BrandChatSessionSerializer

    def list(self, request, *args, **kwargs):
        """Return all chat sessions from the filtered queryset."""
        sessions = self.filter_queryset(self.get_queryset())
        return api_response(data=self.get_serializer(sessions, many=True).data)

    def create(self, request, *args, **kwargs):
        """Create a chat session, or report validation errors with code 400."""
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return api_response(code=400, message="创建失败", data=serializer.errors)
        self.perform_create(serializer)
        return api_response(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return a single chat session."""
        session = self.get_object()
        return api_response(data=self.get_serializer(session).data)

    def update(self, request, *args, **kwargs):
        """Update a chat session (full or partial), or report validation errors."""
        is_partial = kwargs.pop('partial', False)
        session = self.get_object()
        serializer = self.get_serializer(session, data=request.data, partial=is_partial)
        if not serializer.is_valid():
            return api_response(code=400, message="更新失败", data=serializer.errors)
        self.perform_update(serializer)
        return api_response(data=serializer.data)

    def destroy(self, request, *args, **kwargs):
        """Delete a chat session and confirm with an empty payload."""
        self.perform_destroy(self.get_object())
        return api_response(message="删除成功", data=None)

    def perform_create(self, serializer):
        """Save the session, defaulting its dataset IDs to the brand's.

        When the saved session has no ``dataset_id_list`` of its own, it is
        given its brand's ``dataset_id_list`` and saved again.
        """
        session = serializer.save()
        if session.dataset_id_list:
            # The caller supplied explicit dataset IDs; keep them.
            return
        session.dataset_id_list = session.brand.dataset_id_list
        session.save(update_fields=['dataset_id_list', 'updated_at'])