品牌和模板的字段

This commit is contained in:
wanjia 2025-05-30 11:49:59 +08:00
parent 0e62ef70ef
commit 10313f02b5
9 changed files with 322 additions and 156 deletions

View File

@ -0,0 +1,23 @@
# Generated by Django 5.2.1 on 2025-05-30 02:53
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('brands', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='campaign',
name='dataset_id',
field=models.CharField(blank=True, help_text='外部知识库系统中的ID', max_length=100, null=True, verbose_name='知识库ID'),
),
migrations.AlterField(
model_name='product',
name='dataset_id',
field=models.CharField(blank=True, help_text='外部知识库系统中的ID', max_length=100, null=True, verbose_name='知识库ID'),
),
]

View File

@ -56,7 +56,7 @@ class Product(models.Model):
collab_creators = models.IntegerField(default=0, verbose_name='合作创作者数') collab_creators = models.IntegerField(default=0, verbose_name='合作创作者数')
tiktok_shop = models.BooleanField(default=False, verbose_name='是否TikTok商店') tiktok_shop = models.BooleanField(default=False, verbose_name='是否TikTok商店')
dataset_id = models.CharField(max_length=100, verbose_name='知识库ID', dataset_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='知识库ID',
help_text='外部知识库系统中的ID') help_text='外部知识库系统中的ID')
external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID', external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID',
help_text='外部系统中的唯一标识') help_text='外部系统中的唯一标识')
@ -116,7 +116,7 @@ class Campaign(models.Model):
end_date = models.DateTimeField(blank=True, null=True, verbose_name='结束日期') end_date = models.DateTimeField(blank=True, null=True, verbose_name='结束日期')
# 知识库信息 # 知识库信息
dataset_id = models.CharField(max_length=100, verbose_name='知识库ID', dataset_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='知识库ID',
help_text='外部知识库系统中的ID') help_text='外部知识库系统中的ID')
external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID', external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID',
help_text='外部系统中的唯一标识') help_text='外部系统中的唯一标识')

View File

@ -42,7 +42,16 @@ class BrandViewSet(viewsets.ModelViewSet):
return BrandSerializer return BrandSerializer
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) data = request.data.copy()
# 处理source字段将前端的value值转换为后端存储
if 'source' in data and data['source']:
# 前端可能传递的是对象或直接是值
if isinstance(data['source'], dict) and 'value' in data['source']:
data['source'] = data['source']['value']
# 否则认为直接传递的是值
serializer = self.get_serializer(data=data)
if serializer.is_valid(): if serializer.is_valid():
self.perform_create(serializer) self.perform_create(serializer)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
@ -62,7 +71,17 @@ class BrandViewSet(viewsets.ModelViewSet):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
data = request.data.copy()
# 处理source字段将前端的value值转换为后端存储
if 'source' in data and data['source']:
# 前端可能传递的是对象或直接是值
if isinstance(data['source'], dict) and 'value' in data['source']:
data['source'] = data['source']['value']
# 否则认为直接传递的是值
serializer = self.get_serializer(instance, data=data, partial=partial)
if serializer.is_valid(): if serializer.is_valid():
self.perform_update(serializer) self.perform_update(serializer)
return api_response(data=serializer.data) return api_response(data=serializer.data)
@ -112,6 +131,29 @@ class BrandViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True) serializer = self.get_serializer(queryset, many=True)
return api_response(data=serializer.data) return api_response(data=serializer.data)
@action(detail=False, methods=['get'])
def source_options(self, request):
"""获取品牌来源选项列表"""
source_options = [
{
'value': 'tks_official',
'name': 'TKS Official',
},
{
'value': 'third_party_agency',
'name': 'Third-party Agency',
},
{
'value': 'offline_event',
'name': 'Offline Event',
},
{
'value': 'social_media',
'name': 'Social Media',
},
]
return api_response(data=source_options)
class ProductViewSet(viewsets.ModelViewSet): class ProductViewSet(viewsets.ModelViewSet):
"""产品API视图集""" """产品API视图集"""
@ -172,7 +214,7 @@ class ProductViewSet(viewsets.ModelViewSet):
brand = product.brand brand = product.brand
# 从品牌的dataset_id_list中移除旧的dataset_id添加新的dataset_id # 从品牌的dataset_id_list中移除旧的dataset_id添加新的dataset_id
if old_dataset_id in brand.dataset_id_list: if old_dataset_id and old_dataset_id in brand.dataset_id_list:
brand.dataset_id_list.remove(old_dataset_id) brand.dataset_id_list.remove(old_dataset_id)
if product.dataset_id and product.dataset_id not in brand.dataset_id_list: if product.dataset_id and product.dataset_id not in brand.dataset_id_list:
@ -267,7 +309,39 @@ class CampaignViewSet(viewsets.ModelViewSet):
return api_response(data=serializer.data) return api_response(data=serializer.data)
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) data = request.data.copy()
# 处理budget字段将[min, max]数组转为字符串格式
if 'budget' in data and isinstance(data['budget'], list) and len(data['budget']) == 2:
data['budget'] = f"{data['budget'][0]}-{data['budget'][1]}"
# 处理followers字段将[min, max]数组转为字符串格式
if 'followers' in data and isinstance(data['followers'], list) and len(data['followers']) == 2:
data['followers'] = f"{data['followers'][0]}-{data['followers'][1]}"
# 处理views字段将[min, max]数组转为字符串格式
if 'views' in data and isinstance(data['views'], list) and len(data['views']) == 2:
data['views'] = f"{data['views'][0]}-{data['views'][1]}"
# 处理品牌ID字段从brand_id转为brand
if 'brand_id' in data and data['brand_id']:
data['brand'] = data.pop('brand_id')
# 将creator_count转换为creators_count
if 'creator_count' in data:
data['creators_count'] = data.pop('creator_count')
# 处理service字段将前端的对象转换为值
if 'service' in data and data['service']:
if isinstance(data['service'], dict) and 'value' in data['service']:
data['service'] = data['service']['value']
# 处理creator_type字段将前端的对象转换为值
if 'creator_type' in data and data['creator_type']:
if isinstance(data['creator_type'], dict) and 'value' in data['creator_type']:
data['creator_type'] = data['creator_type']['value']
serializer = self.get_serializer(data=data)
if serializer.is_valid(): if serializer.is_valid():
self.perform_create(serializer) self.perform_create(serializer)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
@ -282,7 +356,40 @@ class CampaignViewSet(viewsets.ModelViewSet):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
data = request.data.copy()
# 处理budget字段将[min, max]数组转为字符串格式
if 'budget' in data and isinstance(data['budget'], list) and len(data['budget']) == 2:
data['budget'] = f"{data['budget'][0]}-{data['budget'][1]}"
# 处理followers字段将[min, max]数组转为字符串格式
if 'followers' in data and isinstance(data['followers'], list) and len(data['followers']) == 2:
data['followers'] = f"{data['followers'][0]}-{data['followers'][1]}"
# 处理views字段将[min, max]数组转为字符串格式
if 'views' in data and isinstance(data['views'], list) and len(data['views']) == 2:
data['views'] = f"{data['views'][0]}-{data['views'][1]}"
# 处理品牌ID字段从brand_id转为brand
if 'brand_id' in data and data['brand_id']:
data['brand'] = data.pop('brand_id')
# 将creator_count转换为creators_count
if 'creator_count' in data:
data['creators_count'] = data.pop('creator_count')
# 处理service字段将前端的对象转换为值
if 'service' in data and data['service']:
if isinstance(data['service'], dict) and 'value' in data['service']:
data['service'] = data['service']['value']
# 处理creator_type字段将前端的对象转换为值
if 'creator_type' in data and data['creator_type']:
if isinstance(data['creator_type'], dict) and 'value' in data['creator_type']:
data['creator_type'] = data['creator_type']['value']
serializer = self.get_serializer(instance, data=data, partial=partial)
if serializer.is_valid(): if serializer.is_valid():
self.perform_update(serializer) self.perform_update(serializer)
return api_response(data=serializer.data) return api_response(data=serializer.data)
@ -294,11 +401,27 @@ class CampaignViewSet(viewsets.ModelViewSet):
return api_response(message="删除成功", data=None) return api_response(message="删除成功", data=None)
def perform_create(self, serializer): def perform_create(self, serializer):
# 创建活动时自动更新品牌的dataset_id_list # 保存获取link_product字段以便后续添加产品关联
link_product_ids = self.request.data.get('link_product', [])
if isinstance(link_product_ids, str):
link_product_ids = [link_product_ids]
# 创建活动
campaign = serializer.save() campaign = serializer.save()
brand = campaign.brand
# 确保dataset_id添加到品牌的dataset_id_list中 # 处理产品关联
if link_product_ids:
for product_id in link_product_ids:
try:
product = Product.objects.get(id=product_id)
campaign.link_product.add(product)
except Product.DoesNotExist:
logger.warning(f"产品ID {product_id} 不存在")
except Exception as e:
logger.error(f"添加产品关联时出错: {str(e)}")
# 更新品牌的dataset_id_list
brand = campaign.brand
if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list: if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
brand.dataset_id_list.append(campaign.dataset_id) brand.dataset_id_list.append(campaign.dataset_id)
brand.save(update_fields=['dataset_id_list', 'updated_at']) brand.save(update_fields=['dataset_id_list', 'updated_at'])
@ -313,7 +436,7 @@ class CampaignViewSet(viewsets.ModelViewSet):
brand = campaign.brand brand = campaign.brand
# 从品牌的dataset_id_list中移除旧的dataset_id添加新的dataset_id # 从品牌的dataset_id_list中移除旧的dataset_id添加新的dataset_id
if old_dataset_id in brand.dataset_id_list: if old_dataset_id and old_dataset_id in brand.dataset_id_list:
brand.dataset_id_list.remove(old_dataset_id) brand.dataset_id_list.remove(old_dataset_id)
if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list: if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list:
@ -793,6 +916,48 @@ class CampaignViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True) serializer = self.get_serializer(queryset, many=True)
return api_response(data=serializer.data) return api_response(data=serializer.data)
@action(detail=False, methods=['get'])
def service_options(self, request):
"""获取活动服务类型选项列表"""
service_options = [
{
'value': 'short_video_paid',
'name': '达人短视频(付费)',
},
{
'value': 'short_video_affiliate',
'name': '达人短视频(纯佣)',
},
{
'value': 'live_stream_brand_hosted',
'name': '直播(代播)',
},
{
'value': 'live_stream_influencer_hosted',
'name': '直播(达播)',
},
{
'value': 'short_video_material_only',
'name': '纯素材短视频',
},
]
return api_response(data=service_options)
@action(detail=False, methods=['get'])
def creator_type_options(self, request):
"""获取创作者类型选项列表"""
creator_type_options = [
{
'value': 'product_promotion',
'name': '带货类达人',
},
{
'value': 'exposure_focused',
'name': '曝光类达人',
},
]
return api_response(data=creator_type_options)
class BrandChatSessionViewSet(viewsets.ModelViewSet): class BrandChatSessionViewSet(viewsets.ModelViewSet):
"""品牌聊天会话API视图集""" """品牌聊天会话API视图集"""

View File

@ -9,8 +9,6 @@ class TemplateFilter(django_filters.FilterSet):
platform = django_filters.CharFilter(field_name='platform') platform = django_filters.CharFilter(field_name='platform')
collaboration_type = django_filters.CharFilter(field_name='collaboration_type') collaboration_type = django_filters.CharFilter(field_name='collaboration_type')
service = django_filters.CharFilter(field_name='service') service = django_filters.CharFilter(field_name='service')
category = django_filters.NumberFilter(field_name='category__id')
category_name = django_filters.CharFilter(field_name='category__name', lookup_expr='icontains')
created_by = django_filters.NumberFilter(field_name='created_by__id') created_by = django_filters.NumberFilter(field_name='created_by__id')
is_public = django_filters.BooleanFilter(field_name='is_public') is_public = django_filters.BooleanFilter(field_name='is_public')
created_after = django_filters.DateTimeFilter(field_name='created_at', lookup_expr='gte') created_after = django_filters.DateTimeFilter(field_name='created_at', lookup_expr='gte')
@ -20,7 +18,7 @@ class TemplateFilter(django_filters.FilterSet):
model = Template model = Template
fields = [ fields = [
'title', 'content', 'mission', 'platform', 'title', 'content', 'mission', 'platform',
'collaboration_type', 'service', 'category', 'collaboration_type', 'service',
'category_name', 'created_by', 'is_public', 'created_by', 'is_public',
'created_after', 'created_before' 'created_after', 'created_before'
] ]

View File

@ -0,0 +1,25 @@
# Generated by Django 5.2.1 on 2025-05-30 03:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('template', '0001_initial'),
]
operations = [
migrations.RemoveField(
model_name='template',
name='category',
),
migrations.AlterField(
model_name='template',
name='mission',
field=models.CharField(choices=[('all', '全部'), ('initial_contact', '初步建联'), ('negotiation', '砍价邮件'), ('script', '脚本邮件'), ('follow_up', '合作追踪')], default='initial_contact', max_length=50, verbose_name='任务类型'),
),
migrations.DeleteModel(
name='TemplateCategory',
),
]

View File

@ -1,28 +1,14 @@
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
class TemplateCategory(models.Model):
"""模板分类模型"""
name = models.CharField(_('分类名称'), max_length=100)
description = models.TextField(_('分类描述'), blank=True, null=True)
created_at = models.DateTimeField(_('创建时间'), auto_now_add=True)
updated_at = models.DateTimeField(_('更新时间'), auto_now=True)
class Meta:
verbose_name = _('模板分类')
verbose_name_plural = _('模板分类')
def __str__(self):
return self.name
class Template(models.Model): class Template(models.Model):
"""模板模型""" """模板模型"""
MISSION_CHOICES = [ MISSION_CHOICES = [
('initial_contact', '初步联系'), ('all', '全部'),
('follow_up', '跟进'), ('initial_contact', '初步建联'),
('negotiation', '谈判'), ('negotiation', '砍价邮件'),
('closing', '成交'), ('script', '脚本邮件'),
('other', '其他'), ('follow_up', '合作追踪'),
] ]
PLATFORM_CHOICES = [ PLATFORM_CHOICES = [
@ -53,7 +39,6 @@ class Template(models.Model):
title = models.CharField(_('模板标题'), max_length=200) title = models.CharField(_('模板标题'), max_length=200)
content = models.TextField(_('模板内容')) content = models.TextField(_('模板内容'))
preview = models.TextField(_('内容预览'), blank=True, null=True) preview = models.TextField(_('内容预览'), blank=True, null=True)
category = models.ForeignKey(TemplateCategory, on_delete=models.CASCADE, related_name='templates', verbose_name=_('模板分类'))
mission = models.CharField(_('任务类型'), max_length=50, choices=MISSION_CHOICES, default='initial_contact') mission = models.CharField(_('任务类型'), max_length=50, choices=MISSION_CHOICES, default='initial_contact')
platform = models.CharField(_('平台'), max_length=50, choices=PLATFORM_CHOICES, default='tiktok') platform = models.CharField(_('平台'), max_length=50, choices=PLATFORM_CHOICES, default='tiktok')
collaboration_type = models.CharField(_('合作模式'), max_length=50, choices=COLLABORATION_CHOICES, default='paid_promotion') collaboration_type = models.CharField(_('合作模式'), max_length=50, choices=COLLABORATION_CHOICES, default='paid_promotion')

View File

@ -1,16 +1,8 @@
from rest_framework import serializers from rest_framework import serializers
from .models import Template, TemplateCategory from .models import Template
class TemplateCategorySerializer(serializers.ModelSerializer):
"""模板分类序列化器"""
class Meta:
model = TemplateCategory
fields = ['id', 'name', 'description', 'created_at', 'updated_at']
read_only_fields = ['created_at', 'updated_at']
class TemplateListSerializer(serializers.ModelSerializer): class TemplateListSerializer(serializers.ModelSerializer):
"""模板列表序列化器(简化版)""" """模板列表序列化器(简化版)"""
category_name = serializers.CharField(source='category.name', read_only=True)
mission_display = serializers.CharField(source='get_mission_display', read_only=True) mission_display = serializers.CharField(source='get_mission_display', read_only=True)
platform_display = serializers.CharField(source='get_platform_display', read_only=True) platform_display = serializers.CharField(source='get_platform_display', read_only=True)
collaboration_type_display = serializers.CharField(source='get_collaboration_type_display', read_only=True) collaboration_type_display = serializers.CharField(source='get_collaboration_type_display', read_only=True)
@ -19,7 +11,7 @@ class TemplateListSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Template model = Template
fields = [ fields = [
'id', 'title', 'preview', 'category_name', 'id', 'title', 'preview',
'mission', 'mission_display', 'mission', 'mission_display',
'platform', 'platform_display', 'platform', 'platform_display',
'service', 'service_display', 'service', 'service_display',
@ -30,8 +22,6 @@ class TemplateListSerializer(serializers.ModelSerializer):
class TemplateDetailSerializer(serializers.ModelSerializer): class TemplateDetailSerializer(serializers.ModelSerializer):
"""模板详情序列化器""" """模板详情序列化器"""
category = TemplateCategorySerializer(read_only=True)
category_id = serializers.IntegerField(write_only=True, required=False)
mission_display = serializers.CharField(source='get_mission_display', read_only=True) mission_display = serializers.CharField(source='get_mission_display', read_only=True)
platform_display = serializers.CharField(source='get_platform_display', read_only=True) platform_display = serializers.CharField(source='get_platform_display', read_only=True)
collaboration_type_display = serializers.CharField(source='get_collaboration_type_display', read_only=True) collaboration_type_display = serializers.CharField(source='get_collaboration_type_display', read_only=True)
@ -40,8 +30,7 @@ class TemplateDetailSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Template model = Template
fields = [ fields = [
'id', 'title', 'content', 'preview', 'id', 'title', 'content', 'preview',
'category', 'category_id',
'mission', 'mission_display', 'mission', 'mission_display',
'platform', 'platform_display', 'platform', 'platform_display',
'service', 'service_display', 'service', 'service_display',
@ -49,29 +38,14 @@ class TemplateDetailSerializer(serializers.ModelSerializer):
'is_public', 'created_at', 'updated_at' 'is_public', 'created_at', 'updated_at'
] ]
read_only_fields = ['created_at', 'updated_at', 'preview'] read_only_fields = ['created_at', 'updated_at', 'preview']
def create(self, validated_data):
"""创建模板"""
# 处理category_id字段
category_id = validated_data.pop('category_id', None)
if category_id:
try:
category = TemplateCategory.objects.get(id=category_id)
validated_data['category'] = category
except TemplateCategory.DoesNotExist:
# 如果分类不存在,创建一个默认分类
category = TemplateCategory.objects.create(name="默认分类")
validated_data['category'] = category
return super().create(validated_data)
class TemplateCreateUpdateSerializer(serializers.ModelSerializer): class TemplateCreateUpdateSerializer(serializers.ModelSerializer):
"""模板创建和更新序列化器""" """模板创建和更新序列化器"""
class Meta: class Meta:
model = Template model = Template
fields = [ fields = [
'id', 'title', 'content', 'id', 'title', 'content',
'category', 'mission', 'platform', 'mission', 'platform',
'service', 'collaboration_type', 'service', 'collaboration_type',
'is_public' 'is_public'
] ]
@ -79,12 +53,6 @@ class TemplateCreateUpdateSerializer(serializers.ModelSerializer):
def validate(self, data): def validate(self, data):
"""验证数据,处理测试期间可能缺失的字段""" """验证数据,处理测试期间可能缺失的字段"""
# 处理category字段确保有有效的分类
if 'category' not in data:
# 获取或创建默认分类
category, created = TemplateCategory.objects.get_or_create(name="默认分类")
data['category'] = category
# 确保其他必填字段有默认值 # 确保其他必填字段有默认值
if 'mission' not in data: if 'mission' not in data:
data['mission'] = 'initial_contact' data['mission'] = 'initial_contact'

View File

@ -1,9 +1,8 @@
from django.urls import path, include from django.urls import path, include
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from .views import TemplateViewSet, TemplateCategoryViewSet from .views import TemplateViewSet
router = DefaultRouter() router = DefaultRouter()
router.register(r'categories', TemplateCategoryViewSet)
router.register(r'', TemplateViewSet) router.register(r'', TemplateViewSet)
urlpatterns = [ urlpatterns = [

View File

@ -3,12 +3,11 @@ from rest_framework import viewsets, permissions, status, filters
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from .models import Template, TemplateCategory from .models import Template
from .serializers import ( from .serializers import (
TemplateListSerializer, TemplateListSerializer,
TemplateDetailSerializer, TemplateDetailSerializer,
TemplateCreateUpdateSerializer, TemplateCreateUpdateSerializer
TemplateCategorySerializer
) )
from .filters import TemplateFilter from .filters import TemplateFilter
from .utils import ApiResponse from .utils import ApiResponse
@ -17,82 +16,6 @@ from rest_framework.permissions import IsAuthenticated
from apps.user.authentication import CustomTokenAuthentication from apps.user.authentication import CustomTokenAuthentication
# Create your views here. # Create your views here.
class TemplateCategoryViewSet(viewsets.ModelViewSet):
"""
模板分类视图集
提供模板分类的增删改查功能
"""
queryset = TemplateCategory.objects.all()
serializer_class = TemplateCategorySerializer
authentication_classes = [CustomTokenAuthentication]
permission_classes = [IsAuthenticated]
pagination_class = StandardResultsSetPagination
def list(self, request, *args, **kwargs):
"""获取所有模板分类"""
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return ApiResponse.success(
data=serializer.data,
message="获取模板分类列表成功"
)
def retrieve(self, request, *args, **kwargs):
"""获取单个模板分类详情"""
instance = self.get_object()
serializer = self.get_serializer(instance)
return ApiResponse.success(
data=serializer.data,
message="获取模板分类详情成功"
)
def create(self, request, *args, **kwargs):
"""创建模板分类"""
serializer = self.get_serializer(data=request.data)
if serializer.is_valid():
self.perform_create(serializer)
return ApiResponse.success(
data=serializer.data,
message="模板分类创建成功",
code=status.HTTP_201_CREATED
)
return ApiResponse.error(
message="模板分类创建失败",
data=serializer.errors,
code=status.HTTP_400_BAD_REQUEST
)
def update(self, request, *args, **kwargs):
"""更新模板分类"""
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
if serializer.is_valid():
self.perform_update(serializer)
return ApiResponse.success(
data=serializer.data,
message="模板分类更新成功"
)
return ApiResponse.error(
message="模板分类更新失败",
data=serializer.errors,
code=status.HTTP_400_BAD_REQUEST
)
def destroy(self, request, *args, **kwargs):
"""删除模板分类"""
instance = self.get_object()
self.perform_destroy(instance)
return ApiResponse.success(
data=None,
message="模板分类删除成功"
)
class TemplateViewSet(viewsets.ModelViewSet): class TemplateViewSet(viewsets.ModelViewSet):
""" """
模板视图集 模板视图集
@ -108,6 +31,7 @@ class TemplateViewSet(viewsets.ModelViewSet):
ordering_fields = ['created_at', 'updated_at', 'title'] ordering_fields = ['created_at', 'updated_at', 'title']
ordering = ['-created_at'] ordering = ['-created_at']
pagination_class = StandardResultsSetPagination pagination_class = StandardResultsSetPagination
http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options']
def get_queryset(self): def get_queryset(self):
""" """
@ -124,8 +48,14 @@ class TemplateViewSet(viewsets.ModelViewSet):
return TemplateDetailSerializer return TemplateDetailSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
"""获取所有模板""" """获取所有模板支持通过query参数mission进行筛选"""
queryset = self.filter_queryset(self.get_queryset()) # 从query_params中获取mission参数
mission = request.query_params.get('mission', None)
if mission and mission != 'all':
queryset = self.filter_queryset(self.get_queryset().filter(mission=mission))
else:
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
if page is not None: if page is not None:
serializer = self.get_serializer(page, many=True) serializer = self.get_serializer(page, many=True)
@ -225,7 +155,11 @@ class TemplateViewSet(viewsets.ModelViewSet):
code=status.HTTP_400_BAD_REQUEST code=status.HTTP_400_BAD_REQUEST
) )
queryset = self.get_queryset().filter(mission=mission) if mission == 'all':
queryset = self.get_queryset()
else:
queryset = self.get_queryset().filter(mission=mission)
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
if page is not None: if page is not None:
serializer = self.get_serializer(page, many=True) serializer = self.get_serializer(page, many=True)
@ -298,3 +232,72 @@ class TemplateViewSet(viewsets.ModelViewSet):
data=serializer.data, data=serializer.data,
message="按服务类型获取模板成功" message="按服务类型获取模板成功"
) )
@action(detail=False, methods=['get'])
def mission_options(self, request):
"""获取任务类型选项列表"""
mission_options = [
{'value': 'all', 'name': '全部'},
{'value': 'initial_contact', 'name': '初步建联'},
{'value': 'negotiation', 'name': '砍价邮件'},
{'value': 'script', 'name': '脚本邮件'},
{'value': 'follow_up', 'name': '合作追踪'},
]
return ApiResponse.success(
data=mission_options,
message="获取任务类型选项列表成功"
)
@action(detail=False, methods=['post'])
def search_templates(self, request):
"""
通过POST请求查询模板
请求体参数:
- mission: 任务类型
- page: 页码可选默认1
- page_size: 每页数量可选默认10
"""
# 获取查询参数
mission = request.data.get('mission', None)
page = request.data.get('page', 1)
page_size = request.data.get('page_size', 10)
try:
page = int(page)
page_size = int(page_size)
except (TypeError, ValueError):
return ApiResponse.error(
message="页码和每页数量必须是整数",
code=status.HTTP_400_BAD_REQUEST
)
# 构建查询集
if mission and mission != 'all':
queryset = self.filter_queryset(self.get_queryset().filter(mission=mission))
else:
queryset = self.filter_queryset(self.get_queryset())
# 设置分页
paginator = self.pagination_class()
paginator.page_size = page_size
paginated_queryset = paginator.paginate_queryset(queryset, request)
if paginated_queryset is not None:
serializer = self.get_serializer(paginated_queryset, many=True)
response_data = {
'count': paginator.page.paginator.count,
'next': paginator.get_next_link(),
'previous': paginator.get_previous_link(),
'results': serializer.data
}
return ApiResponse.success(
data=response_data,
message="查询模板成功"
)
serializer = self.get_serializer(queryset, many=True)
return ApiResponse.success(
data=serializer.data,
message="查询模板成功"
)