diff --git a/apps/accounts/migrations/0003_userprofile.py b/apps/accounts/migrations/0003_userprofile.py new file mode 100644 index 0000000..37369cf --- /dev/null +++ b/apps/accounts/migrations/0003_userprofile.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2 on 2025-05-09 03:11 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('accounts', '0002_delete_userprofile'), + ] + + operations = [ + migrations.CreateModel( + name='UserProfile', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('department', models.CharField(blank=True, help_text='部门', max_length=100)), + ('group', models.CharField(blank=True, help_text='小组', max_length=100)), + ('auto_recommend_reply', models.BooleanField(default=False, help_text='是否启用自动推荐回复功能')), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='profile', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'db_table': 'user_profiles', + }, + ), + ] diff --git a/apps/accounts/migrations/0004_delete_usergoal.py b/apps/accounts/migrations/0004_delete_usergoal.py new file mode 100644 index 0000000..890319d --- /dev/null +++ b/apps/accounts/migrations/0004_delete_usergoal.py @@ -0,0 +1,16 @@ +# Generated by Django 5.2 on 2025-05-12 04:43 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('accounts', '0003_userprofile'), + ] + + operations = [ + migrations.DeleteModel( + name='UserGoal', + ), + ] diff --git a/apps/accounts/models.py b/apps/accounts/models.py index 8086786..74734b4 100644 --- a/apps/accounts/models.py +++ b/apps/accounts/models.py @@ -98,20 +98,3 @@ class UserProfile(models.Model): def __str__(self): return f"{self.user.username}的个人资料" - -class UserGoal(models.Model): - """用户总目标模型""" - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='goals') - content = models.TextField(verbose_name='总目标内容') - is_active = models.BooleanField(default=True, verbose_name='是否激活') - created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') - updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') - - class Meta: - db_table = 'user_goals' - verbose_name = '用户总目标' - verbose_name_plural = '用户总目标' - - def __str__(self): - return f"{self.user.username}的总目标 - {self.content[:50]}..." \ No newline at end of file diff --git a/apps/accounts/serializers.py b/apps/accounts/serializers.py new file mode 100644 index 0000000..35344d7 --- /dev/null +++ b/apps/accounts/serializers.py @@ -0,0 +1,63 @@ +from rest_framework import serializers +from apps.accounts.models import User, UserProfile + +class UserProfileSerializer(serializers.ModelSerializer): + """用户档案序列化器""" + + class Meta: + model = UserProfile + fields = ['department', 'group', 'auto_recommend_reply'] + + +class UserSerializer(serializers.ModelSerializer): + """用户序列化器""" + profile = UserProfileSerializer(read_only=True) + + class Meta: + model = User + fields = [ + 'id', 'username', 'email', 'name', 'role', + 'department', 'group', 'profile', 'is_active', + 'date_joined', 'last_login' + ] + read_only_fields = ['id', 'date_joined', 'last_login'] + + +class UserCreateSerializer(serializers.ModelSerializer): + """创建用户的序列化器""" + password = serializers.CharField(write_only=True, required=True, style={'input_type': 'password'}) + + class Meta: + model = User + fields = [ + 'id', 'username', 'email', 'name', 'password', + 'role', 'department', 'group', 'is_active' + ] + read_only_fields = ['id'] + + def create(self, validated_data): + password = validated_data.pop('password') + user = User(**validated_data) + user.set_password(password) + user.save() + + # 创建用户档案 + UserProfile.objects.create( + user=user, + department=user.department, + group=user.group + ) + + return user + + +class PasswordChangeSerializer(serializers.Serializer): + """修改密码序列化器""" + old_password = serializers.CharField(required=True) + new_password = serializers.CharField(required=True) + + def validate_old_password(self, value): + user = self.context['request'].user + if not user.check_password(value): + raise serializers.ValidationError("旧密码不正确") + return value \ No newline at end of file diff --git a/apps/accounts/views.py b/apps/accounts/views.py index afa7768..f326223 100644 --- a/apps/accounts/views.py +++ b/apps/accounts/views.py @@ -228,6 +228,16 @@ def user_profile(request): for field in allowed_fields: if field in request.data: + # 检查name字段是否重名 + if field == 'name' and request.data['name'] != user.name: + # 检查是否有其他用户使用相同name + if User.objects.filter(name=request.data['name']).exclude(id=user.id).exists(): + return Response({ + 'code': 400, + 'message': '用户名称已存在', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + setattr(user, field, request.data[field]) updated_fields.append(field) diff --git a/apps/message/__init__.py b/apps/brands/__init__.py similarity index 100% rename from apps/message/__init__.py rename to apps/brands/__init__.py diff --git a/apps/brands/admin.py b/apps/brands/admin.py new file mode 100644 index 0000000..3b3c106 --- /dev/null +++ b/apps/brands/admin.py @@ -0,0 +1,96 @@ +from django.contrib import admin +from .models import Brand, Product, Campaign, BrandChatSession + +@admin.register(Brand) +class BrandAdmin(admin.ModelAdmin): + list_display = ('name', 'category', 'source', 'collab_count', 'creators_count', 'total_gmv_achieved', 'total_views_achieved', 'shop_overall_rating', 'created_at', 'is_active') + search_fields = ('name', 'description', 'category', 'source') + list_filter = ('is_active', 'created_at', 'category', 'source') + readonly_fields = ('id', 'created_at', 'updated_at') + fieldsets = ( + ('基本信息', { + 'fields': ('id', 'name', 'description', 'logo_url', 'is_active') + }), + ('分类信息', { + 'fields': ('category', 'source', 'collab_count', 'creators_count', 'campaign_id') + }), + ('统计信息', { + 'fields': ('total_gmv_achieved', 'total_views_achieved', 'shop_overall_rating') + }), + ('知识库关联', { + 'fields': ('dataset_id_list',) + }), + ('时间信息', { + 'fields': ('created_at', 'updated_at') + }), + ) + +@admin.register(Product) +class ProductAdmin(admin.ModelAdmin): + list_display = ('name', 'brand', 'pid', 'commission_rate', 'stock', 'items_sold', 'product_rating', 'created_at', 'is_active') + search_fields = ('name', 'description', 'brand__name', 'pid') + list_filter = ('brand', 'is_active', 'created_at', 'tiktok_shop') + readonly_fields = ('id', 'created_at', 'updated_at') + fieldsets = ( + ('基本信息', { + 'fields': ('id', 'name', 'brand', 'description', 'image_url', 'is_active') + }), + ('产品详情', { + 'fields': ('pid', 'commission_rate', 'open_collab', 'available_samples', + 'sales_price_min', 'sales_price_max', 'stock', 'items_sold', + 'product_rating', 'reviews_count', 'collab_creators', 'tiktok_shop') + }), + ('知识库信息', { + 'fields': ('dataset_id', 'external_id') + }), + ('时间信息', { + 'fields': ('created_at', 'updated_at') + }), + ) + +@admin.register(Campaign) +class CampaignAdmin(admin.ModelAdmin): + list_display = ('name', 'brand', 'service', 'creator_type', 'start_date', 'end_date', 'is_active') + search_fields = ('name', 'description', 'brand__name', 'service', 'creator_type') + list_filter = ('brand', 'is_active', 'start_date', 'end_date', 'service', 'creator_type') + readonly_fields = ('id', 'created_at', 'updated_at') + filter_horizontal = ('link_product',) + fieldsets = ( + ('基本信息', { + 'fields': ('id', 'name', 'brand', 'description', 'image_url', 'is_active') + }), + ('活动详情', { + 'fields': ('service', 'creator_type', 'creator_level', 'creator_category', + 'creators_count', 'gmv', 'followers', 'views', 'budget') + }), + ('关联产品', { + 'fields': ('link_product',) + }), + ('活动时间', { + 'fields': ('start_date', 'end_date') + }), + ('知识库信息', { + 'fields': ('dataset_id', 'external_id') + }), + ('时间信息', { + 'fields': ('created_at', 'updated_at') + }), + ) + +@admin.register(BrandChatSession) +class BrandChatSessionAdmin(admin.ModelAdmin): + list_display = ('title', 'brand', 'session_id', 'created_at', 'is_active') + search_fields = ('title', 'session_id', 'brand__name') + list_filter = ('brand', 'is_active', 'created_at') + readonly_fields = ('id', 'created_at', 'updated_at') + fieldsets = ( + ('基本信息', { + 'fields': ('id', 'title', 'brand', 'session_id', 'is_active') + }), + ('知识库信息', { + 'fields': ('dataset_id_list',) + }), + ('时间信息', { + 'fields': ('created_at', 'updated_at') + }), + ) diff --git a/apps/brands/apps.py b/apps/brands/apps.py new file mode 100644 index 0000000..24a1db0 --- /dev/null +++ b/apps/brands/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class BrandsConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.brands' diff --git a/apps/brands/migrations/0001_initial.py b/apps/brands/migrations/0001_initial.py new file mode 100644 index 0000000..c94f925 --- /dev/null +++ b/apps/brands/migrations/0001_initial.py @@ -0,0 +1,99 @@ +# Generated by Django 5.2 on 2025-05-09 03:55 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='Brand', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100, unique=True, verbose_name='品牌名称')), + ('description', models.TextField(blank=True, null=True, verbose_name='品牌描述')), + ('logo_url', models.CharField(blank=True, max_length=255, null=True, verbose_name='品牌Logo')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), + ('dataset_id_list', models.JSONField(blank=True, default=list, help_text='所有关联的知识库ID列表', verbose_name='知识库ID列表')), + ], + options={ + 'verbose_name': '品牌', + 'verbose_name_plural': '品牌', + 'db_table': 'brands', + }, + ), + migrations.CreateModel( + name='Activity', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100, verbose_name='活动名称')), + ('description', models.TextField(blank=True, null=True, verbose_name='活动描述')), + ('image_url', models.CharField(blank=True, max_length=255, null=True, verbose_name='活动图片')), + ('start_date', models.DateTimeField(blank=True, null=True, verbose_name='开始日期')), + ('end_date', models.DateTimeField(blank=True, null=True, verbose_name='结束日期')), + ('dataset_id', models.CharField(help_text='外部知识库系统中的ID', max_length=100, verbose_name='知识库ID')), + ('external_id', models.CharField(blank=True, help_text='外部系统中的唯一标识', max_length=100, null=True, verbose_name='外部ID')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), + ('brand', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='activities', to='brands.brand', verbose_name='所属品牌')), + ], + options={ + 'verbose_name': '活动', + 'verbose_name_plural': '活动', + 'db_table': 'activities', + 'indexes': [models.Index(fields=['brand'], name='activities_brand_i_9a57da_idx'), models.Index(fields=['dataset_id'], name='activities_dataset_c873ab_idx'), models.Index(fields=['is_active'], name='activities_is_acti_cff4bc_idx'), models.Index(fields=['start_date'], name='activities_start_d_fe1952_idx'), models.Index(fields=['end_date'], name='activities_end_dat_9cb2d8_idx')], + 'unique_together': {('brand', 'name')}, + }, + ), + migrations.CreateModel( + name='BrandChatSession', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('session_id', models.CharField(max_length=100, unique=True, verbose_name='会话ID')), + ('title', models.CharField(default='新对话', max_length=200, verbose_name='会话标题')), + ('dataset_id_list', models.JSONField(blank=True, default=list, verbose_name='知识库ID列表')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), + ('brand', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='chat_sessions', to='brands.brand', verbose_name='品牌')), + ], + options={ + 'verbose_name': '品牌聊天会话', + 'verbose_name_plural': '品牌聊天会话', + 'db_table': 'brand_chat_sessions', + 'indexes': [models.Index(fields=['brand'], name='brand_chat__brand_i_83752e_idx'), models.Index(fields=['session_id'], name='brand_chat__session_4bf9b0_idx'), models.Index(fields=['created_at'], name='brand_chat__created_957266_idx')], + }, + ), + migrations.CreateModel( + name='Product', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100, verbose_name='产品名称')), + ('description', models.TextField(blank=True, null=True, verbose_name='产品描述')), + ('image_url', models.CharField(blank=True, max_length=255, null=True, verbose_name='产品图片')), + ('dataset_id', models.CharField(help_text='外部知识库系统中的ID', max_length=100, verbose_name='知识库ID')), + ('external_id', models.CharField(blank=True, help_text='外部系统中的唯一标识', max_length=100, null=True, verbose_name='外部ID')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), + ('brand', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='products', to='brands.brand', verbose_name='所属品牌')), + ], + options={ + 'verbose_name': '产品', + 'verbose_name_plural': '产品', + 'db_table': 'products', + 'indexes': [models.Index(fields=['brand'], name='products_brand_i_0d1950_idx'), models.Index(fields=['dataset_id'], name='products_dataset_faf62a_idx'), models.Index(fields=['is_active'], name='products_is_acti_cb485f_idx')], + 'unique_together': {('brand', 'name')}, + }, + ), + ] diff --git a/apps/brands/migrations/0002_campaign_alter_activity_unique_together_and_more.py b/apps/brands/migrations/0002_campaign_alter_activity_unique_together_and_more.py new file mode 100644 index 0000000..2e0172a --- /dev/null +++ b/apps/brands/migrations/0002_campaign_alter_activity_unique_together_and_more.py @@ -0,0 +1,194 @@ +# Generated by Django 5.2 on 2025-05-13 02:45 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('brands', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Campaign', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100, verbose_name='活动名称')), + ('description', models.TextField(blank=True, null=True, verbose_name='活动描述')), + ('image_url', models.CharField(blank=True, max_length=255, null=True, verbose_name='活动图片')), + ('service', models.CharField(blank=True, max_length=100, null=True, verbose_name='服务类型')), + ('creator_type', models.CharField(blank=True, max_length=100, null=True, verbose_name='创作者类型')), + ('creator_level', models.CharField(blank=True, max_length=100, null=True, verbose_name='创作者等级')), + ('creator_category', models.CharField(blank=True, max_length=100, null=True, verbose_name='创作者分类')), + ('creators_count', models.IntegerField(default=0, verbose_name='创作者数量')), + ('gmv', models.CharField(blank=True, max_length=100, null=True, verbose_name='GMV范围')), + ('followers', models.CharField(blank=True, max_length=100, null=True, verbose_name='粉丝数范围')), + ('views', models.CharField(blank=True, max_length=100, null=True, verbose_name='浏览量范围')), + ('budget', models.CharField(blank=True, max_length=100, null=True, verbose_name='预算范围')), + ('start_date', models.DateTimeField(blank=True, null=True, verbose_name='开始日期')), + ('end_date', models.DateTimeField(blank=True, null=True, verbose_name='结束日期')), + ('dataset_id', models.CharField(help_text='外部知识库系统中的ID', max_length=100, verbose_name='知识库ID')), + ('external_id', models.CharField(blank=True, help_text='外部系统中的唯一标识', max_length=100, null=True, verbose_name='外部ID')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), + ], + options={ + 'verbose_name': '活动', + 'verbose_name_plural': '活动', + 'db_table': 'campaigns', + }, + ), + migrations.AlterUniqueTogether( + name='activity', + unique_together=None, + ), + migrations.RemoveField( + model_name='activity', + name='brand', + ), + migrations.AddField( + model_name='brand', + name='campaign_id', + field=models.CharField(blank=True, max_length=100, null=True, verbose_name='活动ID'), + ), + migrations.AddField( + model_name='brand', + name='category', + field=models.CharField(blank=True, max_length=100, null=True, verbose_name='品牌分类'), + ), + migrations.AddField( + model_name='brand', + name='collab_count', + field=models.IntegerField(default=0, verbose_name='合作数量'), + ), + migrations.AddField( + model_name='brand', + name='creators_count', + field=models.IntegerField(default=0, verbose_name='创作者数量'), + ), + migrations.AddField( + model_name='brand', + name='shop_overall_rating', + field=models.DecimalField(decimal_places=1, default=0.0, max_digits=3, verbose_name='店铺评分'), + ), + migrations.AddField( + model_name='brand', + name='source', + field=models.CharField(blank=True, max_length=100, null=True, verbose_name='来源'), + ), + migrations.AddField( + model_name='brand', + name='total_gmv_achieved', + field=models.DecimalField(decimal_places=2, default=0, max_digits=12, verbose_name='总GMV'), + ), + migrations.AddField( + model_name='brand', + name='total_views_achieved', + field=models.DecimalField(decimal_places=2, default=0, max_digits=12, verbose_name='总浏览量'), + ), + migrations.AddField( + model_name='product', + name='available_samples', + field=models.IntegerField(default=0, verbose_name='可用样品数'), + ), + migrations.AddField( + model_name='product', + name='collab_creators', + field=models.IntegerField(default=0, verbose_name='合作创作者数'), + ), + migrations.AddField( + model_name='product', + name='commission_rate', + field=models.DecimalField(decimal_places=2, default=0, max_digits=5, verbose_name='佣金率'), + ), + migrations.AddField( + model_name='product', + name='items_sold', + field=models.IntegerField(default=0, verbose_name='已售数量'), + ), + migrations.AddField( + model_name='product', + name='open_collab', + field=models.DecimalField(decimal_places=2, default=0, max_digits=5, verbose_name='开放合作率'), + ), + migrations.AddField( + model_name='product', + name='pid', + field=models.CharField(blank=True, max_length=100, null=True, verbose_name='产品ID'), + ), + migrations.AddField( + model_name='product', + name='product_rating', + field=models.DecimalField(decimal_places=1, default=0, max_digits=3, verbose_name='产品评分'), + ), + migrations.AddField( + model_name='product', + name='reviews_count', + field=models.IntegerField(default=0, verbose_name='评价数量'), + ), + migrations.AddField( + model_name='product', + name='sales_price_max', + field=models.DecimalField(decimal_places=2, default=0, max_digits=10, verbose_name='最高销售价'), + ), + migrations.AddField( + model_name='product', + name='sales_price_min', + field=models.DecimalField(decimal_places=2, default=0, max_digits=10, verbose_name='最低销售价'), + ), + migrations.AddField( + model_name='product', + name='stock', + field=models.IntegerField(default=0, verbose_name='库存'), + ), + migrations.AddField( + model_name='product', + name='tiktok_shop', + field=models.BooleanField(default=False, verbose_name='是否TikTok商店'), + ), + migrations.AddIndex( + model_name='product', + index=models.Index(fields=['pid'], name='products_pid_99aab2_idx'), + ), + migrations.AddField( + model_name='campaign', + name='brand', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='campaigns', to='brands.brand', verbose_name='所属品牌'), + ), + migrations.AddField( + model_name='campaign', + name='link_product', + field=models.ManyToManyField(blank=True, related_name='campaigns', to='brands.product', verbose_name='关联产品'), + ), + migrations.DeleteModel( + name='Activity', + ), + migrations.AddIndex( + model_name='campaign', + index=models.Index(fields=['brand'], name='campaigns_brand_i_c2d4bd_idx'), + ), + migrations.AddIndex( + model_name='campaign', + index=models.Index(fields=['dataset_id'], name='campaigns_dataset_bfbb68_idx'), + ), + migrations.AddIndex( + model_name='campaign', + index=models.Index(fields=['is_active'], name='campaigns_is_acti_6c57d0_idx'), + ), + migrations.AddIndex( + model_name='campaign', + index=models.Index(fields=['start_date'], name='campaigns_start_d_5c2c6b_idx'), + ), + migrations.AddIndex( + model_name='campaign', + index=models.Index(fields=['end_date'], name='campaigns_end_dat_6aaba4_idx'), + ), + migrations.AlterUniqueTogether( + name='campaign', + unique_together={('brand', 'name')}, + ), + ] diff --git a/apps/brands/migrations/__init__.py b/apps/brands/migrations/__init__.py new file mode 100644 index 0000000..b28b04f --- /dev/null +++ b/apps/brands/migrations/__init__.py @@ -0,0 +1,3 @@ + + + diff --git a/apps/brands/models.py b/apps/brands/models.py new file mode 100644 index 0000000..e30edda --- /dev/null +++ b/apps/brands/models.py @@ -0,0 +1,181 @@ +from django.db import models +import uuid +from django.utils import timezone + +class Brand(models.Model): + """品牌模型""" + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + name = models.CharField(max_length=100, unique=True, verbose_name='品牌名称') + description = models.TextField(blank=True, null=True, verbose_name='品牌描述') + logo_url = models.CharField(max_length=255, blank=True, null=True, verbose_name='品牌Logo') + category = models.CharField(max_length=100, blank=True, null=True, verbose_name='品牌分类') + source = models.CharField(max_length=100, blank=True, null=True, verbose_name='来源') + collab_count = models.IntegerField(default=0, verbose_name='合作数量') + creators_count = models.IntegerField(default=0, verbose_name='创作者数量') + campaign_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='活动ID') + + # 添加数据统计字段 + total_gmv_achieved = models.DecimalField(max_digits=12, decimal_places=2, default=0, verbose_name='总GMV') + total_views_achieved = models.DecimalField(max_digits=12, decimal_places=2, default=0, verbose_name='总浏览量') + shop_overall_rating = models.DecimalField(max_digits=3, decimal_places=1, default=0.0, verbose_name='店铺评分') + + # 存储关联到此品牌的所有产品和活动知识库ID列表 + dataset_id_list = models.JSONField(default=list, blank=True, verbose_name='知识库ID列表', + help_text='所有关联的知识库ID列表') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + is_active = models.BooleanField(default=True, verbose_name='是否激活') + + class Meta: + db_table = 'brands' + verbose_name = '品牌' + verbose_name_plural = '品牌' + + def __str__(self): + return self.name + + +class Product(models.Model): + """产品模型 - 作为一个知识库""" + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + brand = models.ForeignKey(Brand, on_delete=models.CASCADE, related_name='products', verbose_name='所属品牌') + name = models.CharField(max_length=100, verbose_name='产品名称') + description = models.TextField(blank=True, null=True, verbose_name='产品描述') + image_url = models.CharField(max_length=255, blank=True, null=True, verbose_name='产品图片') + + # 添加产品详情字段 + pid = models.CharField(max_length=100, blank=True, null=True, verbose_name='产品ID') + commission_rate = models.DecimalField(max_digits=5, decimal_places=2, default=0, verbose_name='佣金率') + open_collab = models.DecimalField(max_digits=5, decimal_places=2, default=0, verbose_name='开放合作率') + available_samples = models.IntegerField(default=0, verbose_name='可用样品数') + sales_price_min = models.DecimalField(max_digits=10, decimal_places=2, default=0, verbose_name='最低销售价') + sales_price_max = models.DecimalField(max_digits=10, decimal_places=2, default=0, verbose_name='最高销售价') + stock = models.IntegerField(default=0, verbose_name='库存') + items_sold = models.IntegerField(default=0, verbose_name='已售数量') + product_rating = models.DecimalField(max_digits=3, decimal_places=1, default=0, verbose_name='产品评分') + reviews_count = models.IntegerField(default=0, verbose_name='评价数量') + collab_creators = models.IntegerField(default=0, verbose_name='合作创作者数') + tiktok_shop = models.BooleanField(default=False, verbose_name='是否TikTok商店') + + dataset_id = models.CharField(max_length=100, verbose_name='知识库ID', + help_text='外部知识库系统中的ID') + external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID', + help_text='外部系统中的唯一标识') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + is_active = models.BooleanField(default=True, verbose_name='是否激活') + + class Meta: + db_table = 'products' + verbose_name = '产品' + verbose_name_plural = '产品' + unique_together = ['brand', 'name'] + indexes = [ + models.Index(fields=['brand']), + models.Index(fields=['dataset_id']), + models.Index(fields=['is_active']), + models.Index(fields=['pid']), + ] + + def __str__(self): + return f"{self.brand.name} - {self.name}" + + def save(self, *args, **kwargs): + """重写save方法,更新品牌的dataset_id_list""" + is_new = self.pk is None + super().save(*args, **kwargs) + + # 刷新品牌的dataset_id_list + if is_new and self.is_active and self.dataset_id: + brand = self.brand + if self.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(self.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + +class Campaign(models.Model): + """活动模型 - 作为一个知识库""" + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + brand = models.ForeignKey(Brand, on_delete=models.CASCADE, related_name='campaigns', verbose_name='所属品牌') + name = models.CharField(max_length=100, verbose_name='活动名称') + description = models.TextField(blank=True, null=True, verbose_name='活动描述') + image_url = models.CharField(max_length=255, blank=True, null=True, verbose_name='活动图片') + + # 活动相关字段 + service = models.CharField(max_length=100, blank=True, null=True, verbose_name='服务类型') + creator_type = models.CharField(max_length=100, blank=True, null=True, verbose_name='创作者类型') + creator_level = models.CharField(max_length=100, blank=True, null=True, verbose_name='创作者等级') + creator_category = models.CharField(max_length=100, blank=True, null=True, verbose_name='创作者分类') + creators_count = models.IntegerField(default=0, verbose_name='创作者数量') + gmv = models.CharField(max_length=100, blank=True, null=True, verbose_name='GMV范围') + followers = models.CharField(max_length=100, blank=True, null=True, verbose_name='粉丝数范围') + views = models.CharField(max_length=100, blank=True, null=True, verbose_name='浏览量范围') + budget = models.CharField(max_length=100, blank=True, null=True, verbose_name='预算范围') + link_product = models.ManyToManyField(Product, blank=True, related_name='campaigns', verbose_name='关联产品') + + # 时间信息 + start_date = models.DateTimeField(blank=True, null=True, verbose_name='开始日期') + end_date = models.DateTimeField(blank=True, null=True, verbose_name='结束日期') + + # 知识库信息 + dataset_id = models.CharField(max_length=100, verbose_name='知识库ID', + help_text='外部知识库系统中的ID') + external_id = models.CharField(max_length=100, blank=True, null=True, verbose_name='外部ID', + help_text='外部系统中的唯一标识') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + is_active = models.BooleanField(default=True, verbose_name='是否激活') + + class Meta: + db_table = 'campaigns' + verbose_name = '活动' + verbose_name_plural = '活动' + unique_together = ['brand', 'name'] + indexes = [ + models.Index(fields=['brand']), + models.Index(fields=['dataset_id']), + models.Index(fields=['is_active']), + models.Index(fields=['start_date']), + models.Index(fields=['end_date']), + ] + + def __str__(self): + return f"{self.brand.name} - {self.name}" + + def save(self, *args, **kwargs): + """重写save方法,更新品牌的dataset_id_list""" + is_new = self.pk is None + super().save(*args, **kwargs) + + # 刷新品牌的dataset_id_list + if is_new and self.is_active and self.dataset_id: + brand = self.brand + if self.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(self.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + +class BrandChatSession(models.Model): + """品牌聊天会话模型""" + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + brand = models.ForeignKey(Brand, on_delete=models.CASCADE, related_name='chat_sessions', verbose_name='品牌') + session_id = models.CharField(max_length=100, unique=True, verbose_name='会话ID') + title = models.CharField(max_length=200, default='新对话', verbose_name='会话标题') + # 存储此次会话使用的所有知识库ID + dataset_id_list = models.JSONField(default=list, blank=True, verbose_name='知识库ID列表') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + is_active = models.BooleanField(default=True, verbose_name='是否激活') + + class Meta: + db_table = 'brand_chat_sessions' + verbose_name = '品牌聊天会话' + verbose_name_plural = '品牌聊天会话' + indexes = [ + models.Index(fields=['brand']), + models.Index(fields=['session_id']), + models.Index(fields=['created_at']), + ] + + def __str__(self): + return f"{self.brand.name} - {self.title}" diff --git a/apps/brands/serializers.py b/apps/brands/serializers.py new file mode 100644 index 0000000..90e0fd0 --- /dev/null +++ b/apps/brands/serializers.py @@ -0,0 +1,67 @@ +from rest_framework import serializers +from .models import Brand, Product, Campaign, BrandChatSession + +class BrandSerializer(serializers.ModelSerializer): + """品牌序列化器""" + class Meta: + model = Brand + fields = ['id', 'name', 'description', 'logo_url', 'category', 'source', + 'collab_count', 'creators_count', 'campaign_id', 'total_gmv_achieved', + 'total_views_achieved', 'shop_overall_rating', 'dataset_id_list', + 'created_at', 'updated_at', 'is_active'] + read_only_fields = ['id', 'created_at', 'updated_at', 'dataset_id_list'] + + +class ProductSerializer(serializers.ModelSerializer): + """产品序列化器""" + brand_name = serializers.CharField(source='brand.name', read_only=True) + + class Meta: + model = Product + fields = ['id', 'brand', 'brand_name', 'name', 'description', 'image_url', + 'pid', 'commission_rate', 'open_collab', 'available_samples', + 'sales_price_min', 'sales_price_max', 'stock', 'items_sold', + 'product_rating', 'reviews_count', 'collab_creators', 'tiktok_shop', + 'dataset_id', 'external_id', 'created_at', 'updated_at', 'is_active'] + read_only_fields = ['id', 'created_at', 'updated_at'] + + +class CampaignSerializer(serializers.ModelSerializer): + """活动序列化器""" + brand_name = serializers.CharField(source='brand.name', read_only=True) + link_product_details = ProductSerializer(source='link_product', many=True, read_only=True) + + class Meta: + model = Campaign + fields = ['id', 'brand', 'brand_name', 'name', 'description', 'image_url', + 'service', 'creator_type', 'creator_level', 'creator_category', + 'creators_count', 'gmv', 'followers', 'views', 'budget', + 'link_product', 'link_product_details', + 'start_date', 'end_date', 'dataset_id', 'external_id', + 'created_at', 'updated_at', 'is_active'] + read_only_fields = ['id', 'created_at', 'updated_at'] + + +class BrandChatSessionSerializer(serializers.ModelSerializer): + """品牌聊天会话序列化器""" + brand_name = serializers.CharField(source='brand.name', read_only=True) + + class Meta: + model = BrandChatSession + fields = ['id', 'brand', 'brand_name', 'session_id', 'title', + 'dataset_id_list', 'created_at', 'updated_at', 'is_active'] + read_only_fields = ['id', 'created_at', 'updated_at'] + + +class BrandDetailSerializer(serializers.ModelSerializer): + """品牌详情序列化器""" + products = ProductSerializer(many=True, read_only=True) + campaigns = CampaignSerializer(many=True, read_only=True) + + class Meta: + model = Brand + fields = ['id', 'name', 'description', 'logo_url', 'category', 'source', + 'collab_count', 'creators_count', 'campaign_id', 'total_gmv_achieved', + 'total_views_achieved', 'shop_overall_rating', 'dataset_id_list', + 'products', 'campaigns', 'created_at', 'updated_at', 'is_active'] + read_only_fields = ['id', 'created_at', 'updated_at', 'dataset_id_list'] \ No newline at end of file diff --git a/apps/brands/services/__init__.py b/apps/brands/services/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/apps/brands/services/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/message/tests.py b/apps/brands/tests.py similarity index 100% rename from apps/message/tests.py rename to apps/brands/tests.py diff --git a/apps/brands/urls.py b/apps/brands/urls.py new file mode 100644 index 0000000..740e70c --- /dev/null +++ b/apps/brands/urls.py @@ -0,0 +1,18 @@ +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from .views import ( + BrandViewSet, + ProductViewSet, + CampaignViewSet, + BrandChatSessionViewSet +) + +router = DefaultRouter() +router.register(r'brands', BrandViewSet) +router.register(r'products', ProductViewSet) +router.register(r'campaigns', CampaignViewSet) +router.register(r'chat-sessions', BrandChatSessionViewSet) + +urlpatterns = [ + path('', include(router.urls)), +] \ No newline at end of file diff --git a/apps/brands/views.py b/apps/brands/views.py new file mode 100644 index 0000000..a1ab38b --- /dev/null +++ b/apps/brands/views.py @@ -0,0 +1,322 @@ +from django.shortcuts import render, get_object_or_404 +from rest_framework import viewsets, status +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.permissions import IsAuthenticated + +from .models import Brand, Product, Campaign, BrandChatSession +from .serializers import ( + BrandSerializer, + ProductSerializer, + CampaignSerializer, + BrandChatSessionSerializer, + BrandDetailSerializer +) + +def api_response(code=200, message="成功", data=None): + """统一API响应格式""" + return Response({ + 'code': code, + 'message': message, + 'data': data + }) + +class BrandViewSet(viewsets.ModelViewSet): + """品牌API视图集""" + queryset = Brand.objects.all() + serializer_class = BrandSerializer + permission_classes = [IsAuthenticated] + + def get_serializer_class(self): + if self.action == 'retrieve': + return BrandDetailSerializer + return BrandSerializer + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.get_serializer(queryset, many=True) + return api_response(data=serializer.data) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + self.perform_create(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="创建失败", data=serializer.errors) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return api_response(data=serializer.data) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + if serializer.is_valid(): + self.perform_update(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="更新失败", data=serializer.errors) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return api_response(message="删除成功", data=None) + + @action(detail=True, methods=['get']) + def products(self, request, pk=None): + """获取品牌下的所有产品""" + brand = self.get_object() + products = Product.objects.filter(brand=brand, is_active=True) + serializer = ProductSerializer(products, many=True) + return api_response(data=serializer.data) + + @action(detail=True, methods=['get']) + def campaigns(self, request, pk=None): + """获取品牌下的所有活动""" + brand = self.get_object() + campaigns = Campaign.objects.filter(brand=brand, is_active=True) + serializer = CampaignSerializer(campaigns, many=True) + return api_response(data=serializer.data) + + @action(detail=True, methods=['get']) + def dataset_ids(self, request, pk=None): + """获取品牌的所有知识库ID""" + brand = self.get_object() + return api_response(data={'dataset_id_list': brand.dataset_id_list}) + + +class ProductViewSet(viewsets.ModelViewSet): + """产品API视图集""" + queryset = Product.objects.filter(is_active=True) + serializer_class = ProductSerializer + permission_classes = [IsAuthenticated] + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.get_serializer(queryset, many=True) + return api_response(data=serializer.data) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + self.perform_create(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="创建失败", data=serializer.errors) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return api_response(data=serializer.data) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + if serializer.is_valid(): + self.perform_update(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="更新失败", data=serializer.errors) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return api_response(message="删除成功", data=None) + + def perform_create(self, serializer): + # 创建产品时自动更新品牌的dataset_id_list + product = serializer.save() + brand = product.brand + + # 确保dataset_id添加到品牌的dataset_id_list中 + if product.dataset_id and product.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(product.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + def perform_update(self, serializer): + # 获取原始产品信息 + old_product = self.get_object() + old_dataset_id = old_product.dataset_id + + # 保存更新后的产品 + product = serializer.save() + brand = product.brand + + # 从品牌的dataset_id_list中移除旧的dataset_id,添加新的dataset_id + if old_dataset_id in brand.dataset_id_list: + brand.dataset_id_list.remove(old_dataset_id) + + if product.dataset_id and product.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(product.dataset_id) + + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + def perform_destroy(self, instance): + # 软删除产品,并从品牌的dataset_id_list中移除对应的ID + instance.is_active = False + instance.save() + + brand = instance.brand + if instance.dataset_id in brand.dataset_id_list: + brand.dataset_id_list.remove(instance.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + +class CampaignViewSet(viewsets.ModelViewSet): + """活动API视图集""" + queryset = Campaign.objects.filter(is_active=True) + serializer_class = CampaignSerializer + permission_classes = [IsAuthenticated] + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.get_serializer(queryset, many=True) + return api_response(data=serializer.data) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + self.perform_create(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="创建失败", data=serializer.errors) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return api_response(data=serializer.data) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + if serializer.is_valid(): + self.perform_update(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="更新失败", data=serializer.errors) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return api_response(message="删除成功", data=None) + + def perform_create(self, serializer): + # 创建活动时自动更新品牌的dataset_id_list + campaign = serializer.save() + brand = campaign.brand + + # 确保dataset_id添加到品牌的dataset_id_list中 + if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(campaign.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + def perform_update(self, serializer): + # 获取原始活动信息 + old_campaign = self.get_object() + old_dataset_id = old_campaign.dataset_id + + # 保存更新后的活动 + campaign = serializer.save() + brand = campaign.brand + + # 从品牌的dataset_id_list中移除旧的dataset_id,添加新的dataset_id + if old_dataset_id in brand.dataset_id_list: + brand.dataset_id_list.remove(old_dataset_id) + + if campaign.dataset_id and campaign.dataset_id not in brand.dataset_id_list: + brand.dataset_id_list.append(campaign.dataset_id) + + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + def perform_destroy(self, instance): + # 软删除活动,并从品牌的dataset_id_list中移除对应的ID + instance.is_active = False + instance.save() + + brand = instance.brand + if instance.dataset_id in brand.dataset_id_list: + brand.dataset_id_list.remove(instance.dataset_id) + brand.save(update_fields=['dataset_id_list', 'updated_at']) + + @action(detail=True, methods=['post']) + def add_product(self, request, pk=None): + """将产品添加到活动中""" + campaign = self.get_object() + product_id = request.data.get('product_id') + + if not product_id: + return api_response(code=400, message="缺少产品ID", data=None) + + try: + product = Product.objects.get(id=product_id, is_active=True) + campaign.link_product.add(product) + return api_response(message="产品添加成功", data=None) + except Product.DoesNotExist: + return api_response(code=404, message="产品不存在", data=None) + except Exception as e: + return api_response(code=500, message=f"添加产品失败: {str(e)}", data=None) + + @action(detail=True, methods=['post']) + def remove_product(self, request, pk=None): + """从活动中移除产品""" + campaign = self.get_object() + product_id = request.data.get('product_id') + + if not product_id: + return api_response(code=400, message="缺少产品ID", data=None) + + try: + product = Product.objects.get(id=product_id) + campaign.link_product.remove(product) + return api_response(message="产品移除成功", data=None) + except Product.DoesNotExist: + return api_response(code=404, message="产品不存在", data=None) + except Exception as e: + return api_response(code=500, message=f"移除产品失败: {str(e)}", data=None) + + +class BrandChatSessionViewSet(viewsets.ModelViewSet): + """品牌聊天会话API视图集""" + queryset = BrandChatSession.objects.filter(is_active=True) + serializer_class = BrandChatSessionSerializer + permission_classes = [IsAuthenticated] + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + serializer = self.get_serializer(queryset, many=True) + return api_response(data=serializer.data) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + if serializer.is_valid(): + self.perform_create(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="创建失败", data=serializer.errors) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return api_response(data=serializer.data) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + if serializer.is_valid(): + self.perform_update(serializer) + return api_response(data=serializer.data) + return api_response(code=400, message="更新失败", data=serializer.errors) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return api_response(message="删除成功", data=None) + + def perform_create(self, serializer): + # 创建聊天会话时,可以设置使用特定品牌下的所有知识库 + chat_session = serializer.save() + + # 如果没有提供dataset_id_list,则使用品牌的dataset_id_list + if not chat_session.dataset_id_list: + brand = chat_session.brand + chat_session.dataset_id_list = brand.dataset_id_list + chat_session.save(update_fields=['dataset_id_list', 'updated_at']) diff --git a/apps/chat/consumers.py b/apps/chat/consumers.py index bc7c4d9..862ebfd 100644 --- a/apps/chat/consumers.py +++ b/apps/chat/consumers.py @@ -1,17 +1,16 @@ # apps/chat/consumers.py +from channels.generic.websocket import AsyncWebsocketConsumer import json +from channels.db import database_sync_to_async +from apps.knowledge_base.models import KnowledgeBase +from apps.chat.models import ChatHistory +from rest_framework.authtoken.models import Token +from django.conf import settings import logging import traceback -from channels.generic.websocket import AsyncWebsocketConsumer -from channels.db import database_sync_to_async -from rest_framework.authtoken.models import Token -from urllib.parse import parse_qs -from apps.chat.models import ChatHistory -from apps.knowledge_base.models import KnowledgeBase -from django.conf import settings -import aiohttp import uuid -from apps.common.services.permission_service import PermissionService +import aiohttp +from urllib.parse import parse_qs logger = logging.getLogger(__name__) diff --git a/apps/chat/routing.py b/apps/chat/routing.py index 278b16f..0d84cfb 100644 --- a/apps/chat/routing.py +++ b/apps/chat/routing.py @@ -1,8 +1,7 @@ # apps/chat/routing.py from django.urls import re_path -from apps.chat.consumers import ChatConsumer, ChatStreamConsumer +from apps.chat.consumers import ChatStreamConsumer websocket_urlpatterns = [ - re_path(r'ws/chat/$', ChatConsumer.as_asgi()), re_path(r'ws/chat/stream/$', ChatStreamConsumer.as_asgi()), ] diff --git a/apps/common/services/chat_service.py b/apps/common/services/chat_service.py index f8ccea1..88d284a 100644 --- a/apps/common/services/chat_service.py +++ b/apps/common/services/chat_service.py @@ -7,6 +7,7 @@ from apps.accounts.models import User from apps.knowledge_base.models import KnowledgeBase from apps.chat.models import ChatHistory from apps.permissions.services.permission_service import KnowledgeBasePermissionMixin +from django.db.models import Q logger = logging.getLogger(__name__) diff --git a/apps/common/services/notification_service.py b/apps/common/services/notification_service.py index 023112c..4f333f5 100644 --- a/apps/common/services/notification_service.py +++ b/apps/common/services/notification_service.py @@ -2,7 +2,7 @@ import logging from asgiref.sync import async_to_sync from channels.layers import get_channel_layer -from apps.message.models import Notification +from apps.notification.models import Notification logger = logging.getLogger(__name__) @@ -19,25 +19,33 @@ class NotificationService: related_resource=related_object_id, ) + # 准备发送到WebSocket的数据 + notification_data = { + "id": str(notification.id), + "title": notification.title, + "content": notification.content, + "type": notification.type, + "created_at": notification.created_at.isoformat(), + } + + # 只有当sender不为None时才添加sender信息 + if notification.sender: + notification_data["sender"] = { + "id": str(notification.sender.id), + "name": notification.sender.name + } + channel_layer = get_channel_layer() async_to_sync(channel_layer.group_send)( f"notification_user_{user.id}", { "type": "notification", - "data": { - "id": str(notification.id), - "title": notification.title, - "content": notification.content, - "type": notification.type, - "created_at": notification.created_at.isoformat(), - "sender": { - "id": str(notification.sender.id), - "name": notification.sender.name - } if notification.sender else None - } + "data": notification_data } ) + return notification except Exception as e: logger.error(f"发送通知失败: {str(e)}") + return None \ No newline at end of file diff --git a/apps/gmail/migrations/0001_initial.py b/apps/gmail/migrations/0001_initial.py index 700be04..a7dcff2 100644 --- a/apps/gmail/migrations/0001_initial.py +++ b/apps/gmail/migrations/0001_initial.py @@ -1,7 +1,6 @@ -# Generated by Django 5.2 on 2025-05-07 03:40 +# Generated by Django 5.2 on 2025-05-12 06:56 import django.db.models.deletion -import uuid from django.conf import settings from django.db import migrations, models @@ -11,72 +10,24 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ('chat', '0001_initial'), - ('knowledge_base', '0001_initial'), migrations.swappable_dependency(settings.AUTH_USER_MODEL), ] operations = [ - migrations.CreateModel( - name='GmailAttachment', - fields=[ - ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), - ('gmail_message_id', models.CharField(max_length=100, verbose_name='Gmail消息ID')), - ('filename', models.CharField(max_length=255, verbose_name='文件名')), - ('filepath', models.CharField(max_length=500, verbose_name='文件路径')), - ('mimetype', models.CharField(max_length=100, verbose_name='MIME类型')), - ('filesize', models.IntegerField(default=0, verbose_name='文件大小')), - ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), - ('chat_message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='gmail_attachments', to='chat.chathistory')), - ], - options={ - 'verbose_name': 'Gmail附件', - 'verbose_name_plural': 'Gmail附件', - 'db_table': 'gmail_attachments', - }, - ), migrations.CreateModel( name='GmailCredential', fields=[ - ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), - ('gmail_email', models.EmailField(default='your_default_email@example.com', max_length=255, verbose_name='Gmail邮箱')), - ('name', models.CharField(default='默认Gmail', max_length=100, verbose_name='名称')), - ('credentials', models.TextField(blank=True, null=True, verbose_name='凭证JSON')), - ('token_path', models.CharField(blank=True, max_length=255, null=True, verbose_name='令牌路径')), - ('is_default', models.BooleanField(default=False, verbose_name='是否默认')), - ('last_history_id', models.CharField(blank=True, max_length=100, null=True, verbose_name='最后历史ID')), - ('watch_expiration', models.DateTimeField(blank=True, null=True, verbose_name='监听过期时间')), - ('is_active', models.BooleanField(default=True, verbose_name='是否活跃')), - ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), - ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), - ('gmail_credential_id', models.CharField(blank=True, max_length=255, null=True, verbose_name='Gmail凭证ID')), - ('needs_reauth', models.BooleanField(default=False, verbose_name='需要重新授权')), + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('email', models.EmailField(help_text='Gmail email address', max_length=254, unique=True)), + ('credentials', models.TextField(help_text='Serialized OAuth2 credentials (JSON)')), + ('is_default', models.BooleanField(default=False, help_text='Default Gmail account for user')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('is_valid', models.BooleanField(default=True, help_text='Whether the credential is valid')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='gmail_credentials', to=settings.AUTH_USER_MODEL)), ], options={ - 'verbose_name': 'Gmail凭证', - 'verbose_name_plural': 'Gmail凭证', - 'ordering': ['-is_default', '-updated_at'], - 'unique_together': {('user', 'gmail_email')}, - }, - ), - migrations.CreateModel( - name='GmailTalentMapping', - fields=[ - ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), - ('talent_email', models.EmailField(max_length=254, verbose_name='达人邮箱')), - ('conversation_id', models.CharField(max_length=100, verbose_name='对话ID')), - ('is_active', models.BooleanField(default=True, verbose_name='是否激活')), - ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), - ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), - ('knowledge_base', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='gmail_mappings', to='knowledge_base.knowledgebase')), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='gmail_talent_mappings', to=settings.AUTH_USER_MODEL)), - ], - options={ - 'verbose_name': 'Gmail达人映射', - 'verbose_name_plural': 'Gmail达人映射', - 'db_table': 'gmail_talent_mappings', - 'unique_together': {('user', 'talent_email')}, + 'unique_together': {('user', 'email')}, }, ), ] diff --git a/apps/gmail/migrations/0002_gmailconversation_gmailattachment.py b/apps/gmail/migrations/0002_gmailconversation_gmailattachment.py new file mode 100644 index 0000000..0707fc7 --- /dev/null +++ b/apps/gmail/migrations/0002_gmailconversation_gmailattachment.py @@ -0,0 +1,55 @@ +# Generated by Django 5.2 on 2025-05-12 08:22 + +import django.db.models.deletion +import django.utils.timezone +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('gmail', '0001_initial'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='GmailConversation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_email', models.EmailField(help_text='用户Gmail邮箱', max_length=254)), + ('influencer_email', models.EmailField(help_text='达人Gmail邮箱', max_length=254)), + ('conversation_id', models.CharField(help_text='关联到chat_history的会话ID', max_length=100, unique=True)), + ('title', models.CharField(default='Gmail对话', help_text='对话标题', max_length=100)), + ('last_sync_time', models.DateTimeField(default=django.utils.timezone.now, help_text='最后同步时间')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('is_active', models.BooleanField(default=True)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='gmail_conversations', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['-updated_at'], + 'unique_together': {('user', 'user_email', 'influencer_email')}, + }, + ), + migrations.CreateModel( + name='GmailAttachment', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('email_message_id', models.CharField(help_text='Gmail邮件ID', max_length=100)), + ('attachment_id', models.CharField(help_text='Gmail附件ID', max_length=100)), + ('filename', models.CharField(help_text='原始文件名', max_length=255)), + ('file_path', models.CharField(help_text='保存在服务器上的路径', max_length=255)), + ('content_type', models.CharField(help_text='MIME类型', max_length=100)), + ('size', models.IntegerField(default=0, help_text='文件大小(字节)')), + ('sender_email', models.EmailField(help_text='发送者邮箱', max_length=254)), + ('chat_message_id', models.CharField(blank=True, help_text='关联到ChatHistory的消息ID', max_length=100, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='attachments', to='gmail.gmailconversation')), + ], + options={ + 'ordering': ['-created_at'], + }, + ), + ] diff --git a/apps/gmail/models.py b/apps/gmail/models.py index e4be75d..058f363 100644 --- a/apps/gmail/models.py +++ b/apps/gmail/models.py @@ -1,72 +1,90 @@ -# apps/gmail/models.py from django.db import models -from django.utils import timezone -import uuid from apps.accounts.models import User -from apps.knowledge_base.models import KnowledgeBase -from apps.chat.models import ChatHistory # 更新导入路径 +import json +import os +from django.utils import timezone class GmailCredential(models.Model): - """Gmail账号凭证""" - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='gmail_credentials') - gmail_email = models.EmailField(verbose_name='Gmail邮箱', max_length=255, default='your_default_email@example.com') - name = models.CharField(verbose_name='名称', max_length=100, default='默认Gmail') - credentials = models.TextField(verbose_name='凭证JSON', blank=True, null=True) - token_path = models.CharField(verbose_name='令牌路径', max_length=255, blank=True, null=True) - is_default = models.BooleanField(verbose_name='是否默认', default=False) - last_history_id = models.CharField(verbose_name='最后历史ID', max_length=100, blank=True, null=True) - watch_expiration = models.DateTimeField(verbose_name='监听过期时间', blank=True, null=True) - is_active = models.BooleanField(verbose_name='是否活跃', default=True) - created_at = models.DateTimeField(verbose_name='创建时间', auto_now_add=True) - updated_at = models.DateTimeField(verbose_name='更新时间', auto_now=True) - gmail_credential_id = models.CharField(verbose_name='Gmail凭证ID', max_length=255, blank=True, null=True) - needs_reauth = models.BooleanField(verbose_name='需要重新授权', default=False) - - def __str__(self): - return f"{self.name} ({self.gmail_email})" - - class Meta: - verbose_name = 'Gmail凭证' - verbose_name_plural = 'Gmail凭证' - unique_together = ('user', 'gmail_email') - ordering = ['-is_default', '-updated_at'] + email = models.EmailField(unique=True, help_text="Gmail email address") + credentials = models.TextField(help_text="Serialized OAuth2 credentials (JSON)") + is_default = models.BooleanField(default=False, help_text="Default Gmail account for user") + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + is_valid = models.BooleanField(default=True, help_text="Whether the credential is valid") -class GmailTalentMapping(models.Model): - """Gmail达人映射关系模型""" - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='gmail_talent_mappings') - talent_email = models.EmailField(verbose_name='达人邮箱') - knowledge_base = models.ForeignKey(KnowledgeBase, on_delete=models.CASCADE, related_name='gmail_mappings') - conversation_id = models.CharField(max_length=100, verbose_name='对话ID') - is_active = models.BooleanField(default=True, verbose_name='是否激活') - created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') - updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + class Meta: + unique_together = ('user', 'email') + + def set_credentials(self, credentials): + self.credentials = json.dumps({ + 'token': credentials.token, + 'refresh_token': credentials.refresh_token, + 'token_uri': credentials.token_uri, + 'client_id': credentials.client_id, + 'client_secret': credentials.client_secret, + 'scopes': credentials.scopes + }) + self.is_valid = True + + def get_credentials(self): + from google.oauth2.credentials import Credentials + creds_data = json.loads(self.credentials) + return Credentials( + token=creds_data['token'], + refresh_token=creds_data['refresh_token'], + token_uri=creds_data['token_uri'], + client_id=creds_data['client_id'], + client_secret=creds_data['client_secret'], + scopes=creds_data['scopes'] + ) + + def __str__(self): + return f"{self.user.username} - {self.email}" + +class GmailConversation(models.Model): + """Gmail对话记录,跟踪用户和达人之间的邮件交互""" + user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='gmail_conversations') + user_email = models.EmailField(help_text="用户Gmail邮箱") + influencer_email = models.EmailField(help_text="达人Gmail邮箱") + conversation_id = models.CharField(max_length=100, unique=True, help_text="关联到chat_history的会话ID") + title = models.CharField(max_length=100, default="Gmail对话", help_text="对话标题") + last_sync_time = models.DateTimeField(default=timezone.now, help_text="最后同步时间") + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + is_active = models.BooleanField(default=True) + + def __str__(self): + return f"{self.user.username}: {self.user_email} - {self.influencer_email}" class Meta: - db_table = 'gmail_talent_mappings' - unique_together = ['user', 'talent_email'] - verbose_name = 'Gmail达人映射' - verbose_name_plural = 'Gmail达人映射' - - def __str__(self): - return f"{self.user.username} - {self.talent_email}" + ordering = ['-updated_at'] + unique_together = ('user', 'user_email', 'influencer_email') class GmailAttachment(models.Model): - """Gmail附件模型""" - id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - chat_message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='gmail_attachments') - gmail_message_id = models.CharField(max_length=100, verbose_name='Gmail消息ID') - filename = models.CharField(max_length=255, verbose_name='文件名') - filepath = models.CharField(max_length=500, verbose_name='文件路径') - mimetype = models.CharField(max_length=100, verbose_name='MIME类型') - filesize = models.IntegerField(default=0, verbose_name='文件大小') - created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + """Gmail附件记录""" + conversation = models.ForeignKey(GmailConversation, on_delete=models.CASCADE, related_name='attachments') + email_message_id = models.CharField(max_length=100, help_text="Gmail邮件ID") + attachment_id = models.CharField(max_length=100, help_text="Gmail附件ID") + filename = models.CharField(max_length=255, help_text="原始文件名") + file_path = models.CharField(max_length=255, help_text="保存在服务器上的路径") + content_type = models.CharField(max_length=100, help_text="MIME类型") + size = models.IntegerField(default=0, help_text="文件大小(字节)") + sender_email = models.EmailField(help_text="发送者邮箱") + chat_message_id = models.CharField(max_length=100, blank=True, null=True, help_text="关联到ChatHistory的消息ID") + created_at = models.DateTimeField(auto_now_add=True) + + def __str__(self): + return f"{self.filename} ({self.size} bytes)" + + def get_file_extension(self): + """获取文件扩展名""" + _, ext = os.path.splitext(self.filename) + return ext.lower() + + def get_absolute_url(self): + """获取文件URL""" + return f"/media/gmail_attachments/{os.path.basename(self.file_path)}" class Meta: - db_table = 'gmail_attachments' - verbose_name = 'Gmail附件' - verbose_name_plural = 'Gmail附件' - - def __str__(self): - return f"{self.filename} ({self.filesize} bytes)" \ No newline at end of file + ordering = ['-created_at'] \ No newline at end of file diff --git a/apps/gmail/serializers.py b/apps/gmail/serializers.py index e69de29..e236ee2 100644 --- a/apps/gmail/serializers.py +++ b/apps/gmail/serializers.py @@ -0,0 +1,69 @@ +from rest_framework import serializers +from .models import GmailCredential +import json + +class GmailCredentialSerializer(serializers.ModelSerializer): + client_secret_json = serializers.JSONField(write_only=True, required=False, allow_null=True) + client_secret_file = serializers.FileField(write_only=True, required=False, allow_null=True) + auth_code = serializers.CharField(write_only=True, required=False, allow_blank=True) + email = serializers.EmailField(required=False, allow_blank=True) # Make email optional + + class Meta: + model = GmailCredential + fields = ['id', 'email', 'is_default', 'created_at', 'updated_at', 'is_valid', + 'client_secret_json', 'client_secret_file', 'auth_code'] + read_only_fields = ['created_at', 'updated_at', 'is_valid'] + + def validate(self, data): + """Validate client_secret input (either JSON or file).""" + client_secret_json = data.get('client_secret_json') + client_secret_file = data.get('client_secret_file') + auth_code = data.get('auth_code') + + # For auth initiation, only client_secret is required + if not auth_code: # Initiation phase + if not client_secret_json and not client_secret_file: + raise serializers.ValidationError( + "Either client_secret_json or client_secret_file is required." + ) + if client_secret_json and client_secret_file: + raise serializers.ValidationError( + "Provide only one of client_secret_json or client_secret_file." + ) + + # For auth completion, both auth_code and client_secret are required + if auth_code and not (client_secret_json or client_secret_file): + raise serializers.ValidationError( + "client_secret_json or client_secret_file is required with auth_code." + ) + + # Parse client_secret_json if provided + if client_secret_json: + try: + json.dumps(client_secret_json) + except (TypeError, ValueError): + raise serializers.ValidationError("client_secret_json must be valid JSON.") + + # Parse client_secret_file if provided + if client_secret_file: + try: + content = client_secret_file.read().decode('utf-8') + client_secret_json = json.loads(content) + data['client_secret_json'] = client_secret_json + except (json.JSONDecodeError, UnicodeDecodeError): + raise serializers.ValidationError("client_secret_file must contain valid JSON.") + + return data + + def validate_email(self, value): + """Ensure email is unique for the user (only for completion).""" + if not value: # Email is optional during initiation + return value + user = self.context['request'].user + if self.instance: # Update case + if GmailCredential.objects.filter(user=user, email=value).exclude(id=self.instance.id).exists(): + raise serializers.ValidationError("This Gmail account is already added.") + else: # Create case + if GmailCredential.objects.filter(user=user, email=value).exists(): + raise serializers.ValidationError("This Gmail account is already added.") + return value \ No newline at end of file diff --git a/apps/gmail/services/__init__.py b/apps/gmail/services/__init__.py index e69de29..b28b04f 100644 --- a/apps/gmail/services/__init__.py +++ b/apps/gmail/services/__init__.py @@ -0,0 +1,3 @@ + + + diff --git a/apps/gmail/services/gmail_service.py b/apps/gmail/services/gmail_service.py new file mode 100644 index 0000000..2a60b10 --- /dev/null +++ b/apps/gmail/services/gmail_service.py @@ -0,0 +1,1084 @@ +import os +import json +import logging +import base64 +import email +from email.utils import parseaddr +import datetime +import shutil +import uuid +from google_auth_oauthlib.flow import InstalledAppFlow +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError +from django.conf import settings +from django.utils import timezone +from django.db import transaction +from ..models import GmailCredential, GmailConversation, GmailAttachment +from apps.chat.models import ChatHistory +from apps.knowledge_base.models import KnowledgeBase +import requests +from google.cloud import pubsub_v1 +from apps.common.services.notification_service import NotificationService +import threading +import time +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from email.mime.application import MIMEApplication +from email.header import Header +import mimetypes + +# 配置日志记录器 +logger = logging.getLogger(__name__) + +# 全局设置环境变量代理,用于 HTTP 和 HTTPS 请求 +proxy_url = getattr(settings, 'PROXY_URL', None) +if proxy_url: + os.environ['HTTP_PROXY'] = proxy_url + os.environ['HTTPS_PROXY'] = proxy_url + logger.info(f"Gmail服务已设置全局代理环境变量: {proxy_url}") + + +class GmailService: + # 定义 Gmail API 所需的 OAuth 2.0 权限范围 + SCOPES = ['https://www.googleapis.com/auth/gmail.modify', 'https://www.googleapis.com/auth/pubsub'] + # 定义存储临时客户端密钥文件的目录 + TOKEN_DIR = os.path.join(settings.BASE_DIR, 'gmail_tokens') + # 附件存储目录 + ATTACHMENT_DIR = os.path.join(settings.BASE_DIR, 'media', 'gmail_attachments') + # Gmail 监听 Pub/Sub 主题和订阅 + PUBSUB_TOPIC = getattr(settings, 'GMAIL_PUBSUB_TOPIC', 'projects/{project_id}/topics/gmail-notifications') + PUBSUB_SUBSCRIPTION = getattr(settings, 'GMAIL_PUBSUB_SUBSCRIPTION', 'projects/{project_id}/subscriptions/gmail-notifications-sub') + + @staticmethod + def initiate_authentication(user, client_secret_json): + """ + 启动 Gmail API 的 OAuth 2.0 认证流程,生成授权 URL。 + + Args: + user: Django 用户对象,用于关联认证。 + client_secret_json: 包含客户端密钥的 JSON 字典,通常由 Google Cloud Console 获取。 + + Returns: + str: 授权 URL,用户需访问该 URL 进行认证并获取授权代码。 + + Raises: + Exception: 如果创建临时文件、生成授权 URL 或其他操作失败。 + """ + try: + # 确保临时文件目录存在 + os.makedirs(GmailService.TOKEN_DIR, exist_ok=True) + # 创建临时客户端密钥文件路径,基于用户 ID 避免冲突 + temp_client_secret_path = os.path.join(GmailService.TOKEN_DIR, f'client_secret_{user.id}.json') + + # 将客户端密钥 JSON 写入临时文件 + with open(temp_client_secret_path, 'w') as f: + json.dump(client_secret_json, f) + + # 初始化 OAuth 2.0 流程,使用临时密钥文件和指定权限范围 + # 代理通过环境变量自动应用 + flow = InstalledAppFlow.from_client_secrets_file( + temp_client_secret_path, + scopes=GmailService.SCOPES, + redirect_uri='urn:ietf:wg:oauth:2.0:oob' # 使用 OOB 流程,适合非 Web 应用 + ) + + # 生成授权 URL,强制用户同意权限 + auth_url, _ = flow.authorization_url(prompt='consent') + logger.info(f"Generated auth URL for user {user.id}: {auth_url}") + return auth_url + except Exception as e: + # 记录错误并抛出异常 + logger.error(f"Error initiating Gmail authentication for user {user.id}: {str(e)}") + raise + finally: + # 清理临时文件,确保不留下敏感信息 + if os.path.exists(temp_client_secret_path): + os.remove(temp_client_secret_path) + + @staticmethod + def complete_authentication(user, auth_code, client_secret_json): + """ + 完成 Gmail API 的 OAuth 2.0 认证流程,使用授权代码获取凭证并保存。 + + Args: + user: Django 用户对象,用于关联凭证。 + auth_code: 用户在授权 URL 页面获取的授权代码。 + client_secret_json: 包含客户端密钥的 JSON 字典。 + + Returns: + GmailCredential: 保存的 Gmail 凭证对象,包含用户邮箱和认证信息。 + + Raises: + HttpError: 如果 Gmail API 请求失败(如无效授权代码)。 + Exception: 如果保存凭证或文件操作失败。 + """ + try: + # 创建临时客户端密钥文件路径 + temp_client_secret_path = os.path.join(GmailService.TOKEN_DIR, f'client_secret_{user.id}.json') + with open(temp_client_secret_path, 'w') as f: + json.dump(client_secret_json, f) + + # 初始化 OAuth 2.0 流程,代理通过环境变量自动应用 + flow = InstalledAppFlow.from_client_secrets_file( + temp_client_secret_path, + scopes=GmailService.SCOPES, + redirect_uri='urn:ietf:wg:oauth:2.0:oob' + ) + + # 使用授权代码获取访问令牌和刷新令牌 + flow.fetch_token(code=auth_code) + credentials = flow.credentials + + # 创建 Gmail API 服务并获取用户邮箱 + service = build('gmail', 'v1', credentials=credentials) + profile = service.users().getProfile(userId='me').execute() + email = profile['emailAddress'] + + # 保存或更新 Gmail 凭证到数据库 + credential, created = GmailCredential.objects.get_or_create( + user=user, + email=email, + defaults={'is_default': not GmailCredential.objects.filter(user=user).exists()} + ) + credential.set_credentials(credentials) + credential.save() + + # 确保只有一个默认凭证 + if credential.is_default: + GmailCredential.objects.filter(user=user).exclude(id=credential.id).update(is_default=False) + + logger.info(f"Gmail credential saved for user {user.id}, email: {email}") + return credential + except HttpError as e: + # 记录 Gmail API 错误并抛出 + logger.error(f"Gmail API error for user {user.id}: {str(e)}") + raise + except Exception as e: + # 记录其他错误并抛出 + logger.error(f"Error completing Gmail authentication for user {user.id}: {str(e)}") + raise + finally: + # 清理临时文件 + if os.path.exists(temp_client_secret_path): + os.remove(temp_client_secret_path) + + @staticmethod + def get_service(credential): + """ + 使用存储的凭证创建 Gmail API 服务实例。 + + Args: + credential: GmailCredential 对象,包含用户的 OAuth 2.0 凭证。 + + Returns: + googleapiclient.discovery.Resource: Gmail API 服务实例,用于后续 API 调用。 + + Raises: + Exception: 如果凭证无效或创建服务失败。 + """ + try: + # 从数据库凭证中获取 Google API 凭证对象 + credentials = credential.get_credentials() + + # 创建 Gmail API 服务,代理通过环境变量自动应用 + return build('gmail', 'v1', credentials=credentials) + + except Exception as e: + # 记录错误并抛出 + logger.error(f"Error creating Gmail service: {str(e)}") + raise + + @staticmethod + def get_conversations(user, user_email, influencer_email): + """ + 获取用户和达人之间的Gmail对话 + + Args: + user: 当前用户对象 + user_email: 用户的Gmail邮箱 (已授权) + influencer_email: 达人的Gmail邮箱 + + Returns: + tuple: (对话列表, 错误信息) + """ + try: + # 确保附件目录存在 + os.makedirs(GmailService.ATTACHMENT_DIR, exist_ok=True) + + # 获取凭证 + credential = GmailCredential.objects.filter(user=user, email=user_email).first() + if not credential: + return None, f"未找到{user_email}的授权信息" + + # 获取Gmail服务 + service = GmailService.get_service(credential) + + # 构建搜索查询 - 查找与达人的所有邮件往来 + query = f"from:({user_email} OR {influencer_email}) to:({user_email} OR {influencer_email})" + logger.info(f"Gmail搜索查询: {query}") + + # 获取满足条件的所有邮件 + response = service.users().messages().list(userId='me', q=query).execute() + messages = [] + + if 'messages' in response: + messages.extend(response['messages']) + + # 如果有更多页,继续获取 + while 'nextPageToken' in response: + page_token = response['nextPageToken'] + response = service.users().messages().list( + userId='me', + q=query, + pageToken=page_token + ).execute() + messages.extend(response['messages']) + + logger.info(f"找到 {len(messages)} 封邮件") + + # 获取每封邮件的详细内容 + conversations = [] + for msg in messages: + try: + message = service.users().messages().get(userId='me', id=msg['id']).execute() + email_data = GmailService._parse_email_content(message) + if email_data: + conversations.append(email_data) + except Exception as e: + logger.error(f"处理邮件 {msg['id']} 时出错: {str(e)}") + + # 按时间排序 + conversations.sort(key=lambda x: x['date']) + + return conversations, None + + except Exception as e: + logger.error(f"获取Gmail对话失败: {str(e)}") + return None, f"获取Gmail对话失败: {str(e)}" + + @staticmethod + def _parse_email_content(message): + """ + 解析邮件内容 + + Args: + message: Gmail API返回的邮件对象 + + Returns: + dict: 邮件内容字典 + """ + try: + message_id = message['id'] + payload = message['payload'] + headers = payload['headers'] + + # 提取基本信息 + email_data = { + 'id': message_id, + 'subject': '', + 'from': '', + 'from_email': '', + 'to': '', + 'to_email': '', + 'date': '', + 'body': '', + 'attachments': [] + } + + # 提取邮件头信息 + for header in headers: + name = header['name'].lower() + if name == 'subject': + email_data['subject'] = header['value'] + elif name == 'from': + email_data['from'] = header['value'] + _, email_data['from_email'] = parseaddr(header['value']) + elif name == 'to': + email_data['to'] = header['value'] + _, email_data['to_email'] = parseaddr(header['value']) + elif name == 'date': + try: + date_value = header['value'] + # 解析日期格式并转换为标准格式 + date_obj = email.utils.parsedate_to_datetime(date_value) + email_data['date'] = date_obj.strftime('%Y-%m-%d %H:%M:%S') + except Exception as e: + logger.error(f"解析日期失败: {str(e)}") + email_data['date'] = header['value'] + + # 处理邮件正文和附件 + GmailService._process_email_parts(payload, email_data) + + return email_data + + except Exception as e: + logger.error(f"解析邮件内容失败: {str(e)}") + return None + + @staticmethod + def _process_email_parts(part, email_data, is_root=True): + """ + 递归处理邮件部分,提取正文和附件 + + Args: + part: 邮件部分 + email_data: 邮件数据字典 + is_root: 是否为根部分 + """ + if 'parts' in part: + for sub_part in part['parts']: + GmailService._process_email_parts(sub_part, email_data, False) + + # 处理附件 + if not is_root and 'filename' in part.get('body', {}) and part.get('filename'): + attachment = { + 'filename': part.get('filename', ''), + 'mimeType': part.get('mimeType', ''), + 'size': part['body'].get('size', 0) + } + + if 'attachmentId' in part['body']: + attachment['attachmentId'] = part['body']['attachmentId'] + + email_data['attachments'].append(attachment) + + # 处理正文 + mime_type = part.get('mimeType', '') + if mime_type == 'text/plain' and 'data' in part.get('body', {}): + data = part['body'].get('data', '') + if data: + try: + text = base64.urlsafe_b64decode(data).decode('utf-8') + email_data['body'] = text + except Exception as e: + logger.error(f"解码邮件正文失败: {str(e)}") + + @staticmethod + def download_attachment(user, gmail_credential, message_id, attachment_id, filename): + """ + 下载邮件附件 + + Args: + user: 当前用户 + gmail_credential: Gmail凭证 + message_id: 邮件ID + attachment_id: 附件ID + filename: 文件名 + + Returns: + str: 保存的文件路径 + """ + try: + service = GmailService.get_service(gmail_credential) + + attachment = service.users().messages().attachments().get( + userId='me', + messageId=message_id, + id=attachment_id + ).execute() + + data = attachment['data'] + file_data = base64.urlsafe_b64decode(data) + + # 安全处理文件名 + safe_filename = GmailService._safe_filename(filename) + timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + unique_filename = f"{user.id}_{timestamp}_{safe_filename}" + + # 保存附件 + filepath = os.path.join(GmailService.ATTACHMENT_DIR, unique_filename) + with open(filepath, 'wb') as f: + f.write(file_data) + + logger.info(f"附件已保存: {filepath}") + return filepath + + except Exception as e: + logger.error(f"下载附件失败: {str(e)}") + return None + + @staticmethod + def _safe_filename(filename): + """ + 生成安全的文件名 + + Args: + filename: 原始文件名 + + Returns: + str: 安全的文件名 + """ + # 替换不安全字符 + unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'] + for char in unsafe_chars: + filename = filename.replace(char, '_') + + # 确保文件名长度合理 + if len(filename) > 100: + base, ext = os.path.splitext(filename) + filename = base[:100] + ext + + return filename + + @staticmethod + @transaction.atomic + def save_conversations_to_chat(user, user_email, influencer_email, kb_id=None): + """ + 保存Gmail对话到聊天记录 + + Args: + user: 当前用户 + user_email: 用户Gmail邮箱 + influencer_email: 达人Gmail邮箱 + kb_id: 知识库ID (可选) + + Returns: + tuple: (对话ID, 错误信息) + """ + try: + # 获取Gmail凭证 + gmail_credential = GmailCredential.objects.filter(user=user, email=user_email).first() + if not gmail_credential: + return None, f"未找到{user_email}的授权信息" + + # 获取对话 + conversations, error = GmailService.get_conversations(user, user_email, influencer_email) + if error: + return None, error + + if not conversations: + return None, "未找到与该达人的对话记录" + + # 获取或创建默认知识库 + if not kb_id: + knowledge_base = KnowledgeBase.objects.filter(user_id=user.id, type='private').first() + if not knowledge_base: + return None, "未找到默认知识库,请先创建一个知识库" + else: + knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first() + if not knowledge_base: + return None, f"知识库ID {kb_id} 不存在" + + # 创建会话ID + conversation_id = f"gmail_{user.id}_{str(uuid.uuid4())[:8]}" + + # 创建或更新Gmail对话记录 + gmail_conversation, created = GmailConversation.objects.get_or_create( + user=user, + user_email=user_email, + influencer_email=influencer_email, + defaults={ + 'conversation_id': conversation_id, + 'title': f"与 {influencer_email} 的Gmail对话", + 'last_sync_time': timezone.now() + } + ) + + if not created: + # 使用现有的会话ID + conversation_id = gmail_conversation.conversation_id + gmail_conversation.last_sync_time = timezone.now() + gmail_conversation.save() + + # 逐个保存邮件到聊天历史 + chat_messages = [] + for email_data in conversations: + # 确定发送者角色 (user 或 assistant) + is_from_user = email_data['from_email'].lower() == user_email.lower() + role = 'user' if is_from_user else 'assistant' + + # 准备内容文本 + content = f"主题: {email_data['subject']}\n\n{email_data['body']}" + + # 创建聊天消息 + chat_message = ChatHistory.objects.create( + user=user, + knowledge_base=knowledge_base, + conversation_id=conversation_id, + title=gmail_conversation.title, + role=role, + content=content, + metadata={ + 'gmail_message_id': email_data['id'], + 'from': email_data['from'], + 'to': email_data['to'], + 'date': email_data['date'], + 'source': 'gmail' + } + ) + + chat_messages.append(chat_message) + + # 处理附件 + if email_data['attachments']: + for attachment in email_data['attachments']: + if 'attachmentId' in attachment: + # 下载附件 + file_path = GmailService.download_attachment( + user, + gmail_credential, + email_data['id'], + attachment['attachmentId'], + attachment['filename'] + ) + + if file_path: + # 保存附件记录 + gmail_attachment = GmailAttachment.objects.create( + conversation=gmail_conversation, + email_message_id=email_data['id'], + attachment_id=attachment['attachmentId'], + filename=attachment['filename'], + file_path=file_path, + content_type=attachment['mimeType'], + size=attachment['size'], + sender_email=email_data['from_email'], + chat_message_id=str(chat_message.id) + ) + + # 更新聊天消息,添加附件信息 + metadata = chat_message.metadata or {} + if 'attachments' not in metadata: + metadata['attachments'] = [] + + metadata['attachments'].append({ + 'id': str(gmail_attachment.id), + 'filename': attachment['filename'], + 'size': attachment['size'], + 'mime_type': attachment['mimeType'], + 'url': gmail_attachment.get_absolute_url() + }) + + chat_message.metadata = metadata + chat_message.save() + + return conversation_id, None + + except Exception as e: + logger.error(f"保存Gmail对话到聊天记录失败: {str(e)}") + return None, f"保存Gmail对话到聊天记录失败: {str(e)}" + + @staticmethod + def setup_gmail_push_notification(user, user_email, topic_name=None, subscription_name=None): + """ + 为Gmail账户设置Pub/Sub推送通知 + + Args: + user: 当前用户 + user_email: 用户Gmail邮箱 + topic_name: 自定义主题名称 (可选) + subscription_name: 自定义订阅名称 (可选) + + Returns: + tuple: (成功标志, 错误信息) + """ + try: + # 获取Gmail凭证 + credential = GmailCredential.objects.filter(user=user, email=user_email).first() + if not credential: + return False, f"未找到{user_email}的授权信息" + + # 获取Gmail服务 + service = GmailService.get_service(credential) + + # 设置Pub/Sub主题和订阅名称 + project_id = getattr(settings, 'GOOGLE_CLOUD_PROJECT_ID', '') + if not project_id: + return False, "未配置Google Cloud项目ID" + + topic = topic_name or GmailService.PUBSUB_TOPIC.format(project_id=project_id) + subscription = subscription_name or GmailService.PUBSUB_SUBSCRIPTION.format(project_id=project_id) + + # 为Gmail账户启用推送通知 + request = { + 'labelIds': ['INBOX'], + 'topicName': topic + } + + try: + # 先停止现有的监听 + service.users().stop(userId='me').execute() + logger.info(f"已停止现有的监听: {user_email}") + except: + pass + + # 启动新的监听 + service.users().watch(userId='me', body=request).execute() + logger.info(f"已为 {user_email} 设置Gmail推送通知,主题: {topic}") + + return True, None + + except Exception as e: + logger.error(f"设置Gmail推送通知失败: {str(e)}") + return False, f"设置Gmail推送通知失败: {str(e)}" + + @staticmethod + def start_pubsub_listener(user_id=None): + """ + 启动Pub/Sub监听器,监听Gmail新消息通知 + + Args: + user_id: 指定要监听的用户ID (可选,如果不指定则监听所有用户) + + Returns: + None + """ + try: + project_id = getattr(settings, 'GOOGLE_CLOUD_PROJECT_ID', '') + if not project_id: + logger.error("未配置Google Cloud项目ID,无法启动Gmail监听") + return + + subscription_name = GmailService.PUBSUB_SUBSCRIPTION.format(project_id=project_id) + + # 创建订阅者客户端 + subscriber = pubsub_v1.SubscriberClient() + subscription_path = subscriber.subscription_path(project_id, subscription_name.split('/')[-1]) + + def callback(message): + """处理接收到的Pub/Sub消息""" + try: + # 解析消息数据 + data = json.loads(message.data.decode('utf-8')) + logger.info(f"接收到Gmail推送通知: {data}") + + # 确认消息已处理 + message.ack() + + # 获取关键信息 + email_address = data.get('emailAddress') + history_id = data.get('historyId') + + if not email_address or not history_id: + logger.error("推送通知缺少必要信息") + return + + # 获取用户凭证 + query = GmailCredential.objects.filter(email=email_address) + if user_id: + query = query.filter(user_id=user_id) + + credential = query.first() + if not credential: + logger.error(f"未找到匹配的Gmail凭证: {email_address}") + return + + # 处理新收到的邮件 + GmailService.process_new_emails(credential.user, credential, history_id) + + except Exception as e: + logger.error(f"处理Gmail推送通知失败: {str(e)}") + # 确认消息,避免重复处理 + message.ack() + + # 设置订阅流 + streaming_pull_future = subscriber.subscribe(subscription_path, callback=callback) + logger.info(f"Gmail Pub/Sub监听器已启动: {subscription_path}") + + # 保持监听状态 + try: + streaming_pull_future.result() + except Exception as e: + streaming_pull_future.cancel() + logger.error(f"Gmail Pub/Sub监听器中断: {str(e)}") + + except Exception as e: + logger.error(f"启动Gmail Pub/Sub监听器失败: {str(e)}") + + @staticmethod + def start_pubsub_listener_thread(user_id=None): + """在后台线程中启动Pub/Sub监听器""" + t = threading.Thread(target=GmailService.start_pubsub_listener, args=(user_id,)) + t.daemon = True + t.start() + return t + + @staticmethod + def process_new_emails(user, credential, history_id): + """ + 处理新收到的邮件 + + Args: + user: 用户对象 + credential: Gmail凭证对象 + history_id: Gmail历史记录ID + + Returns: + None + """ + try: + # 获取Gmail服务 + service = GmailService.get_service(credential) + + # 获取历史记录变更 + history_results = service.users().history().list( + userId='me', + startHistoryId=history_id, + historyTypes=['messageAdded'] + ).execute() + + if 'history' not in history_results: + return + + # 获取活跃对话 + active_conversations = GmailConversation.objects.filter( + user=user, + user_email=credential.email, + is_active=True + ) + + influencer_emails = [conv.influencer_email for conv in active_conversations] + if not influencer_emails: + logger.info(f"用户 {user.username} 没有活跃的Gmail对话") + return + + # 处理每个历史变更 + for history in history_results.get('history', []): + for message_added in history.get('messagesAdded', []): + message_id = message_added.get('message', {}).get('id') + if not message_id: + continue + + # 获取完整邮件内容 + message = service.users().messages().get(userId='me', id=message_id).execute() + email_data = GmailService._parse_email_content(message) + + if not email_data: + continue + + # 检查是否是来自达人的邮件 + if email_data['from_email'] in influencer_emails: + # 查找相关对话 + conversation = active_conversations.filter( + influencer_email=email_data['from_email'] + ).first() + + if conversation: + # 将新邮件保存到聊天历史 + GmailService._save_email_to_chat( + user, + credential, + conversation, + email_data + ) + + # 发送通知 + NotificationService().send_notification( + user=user, + title="收到新邮件", + content=f"您收到来自 {email_data['from_email']} 的新邮件: {email_data['subject']}", + notification_type="gmail", + related_object_id=conversation.conversation_id + ) + + logger.info(f"已处理来自 {email_data['from_email']} 的新邮件") + + except Exception as e: + logger.error(f"处理Gmail新消息失败: {str(e)}") + + @staticmethod + def _save_email_to_chat(user, credential, conversation, email_data): + """ + 保存一封邮件到聊天历史 + + Args: + user: 用户对象 + credential: Gmail凭证对象 + conversation: Gmail对话对象 + email_data: 邮件数据 + + Returns: + bool: 成功标志 + """ + try: + # 查找关联的知识库 + first_message = ChatHistory.objects.filter( + conversation_id=conversation.conversation_id + ).first() + + if not first_message: + knowledge_base = KnowledgeBase.objects.filter(user_id=user.id, type='private').first() + if not knowledge_base: + logger.error("未找到默认知识库") + return False + else: + knowledge_base = first_message.knowledge_base + + # 确定发送者角色 (user 或 assistant) + is_from_user = email_data['from_email'].lower() == credential.email.lower() + role = 'user' if is_from_user else 'assistant' + + # 准备内容文本 + content = f"主题: {email_data['subject']}\n\n{email_data['body']}" + + # 创建聊天消息 + chat_message = ChatHistory.objects.create( + user=user, + knowledge_base=knowledge_base, + conversation_id=conversation.conversation_id, + title=conversation.title, + role=role, + content=content, + metadata={ + 'gmail_message_id': email_data['id'], + 'from': email_data['from'], + 'to': email_data['to'], + 'date': email_data['date'], + 'source': 'gmail' + } + ) + + # 更新对话的同步时间 + conversation.last_sync_time = timezone.now() + conversation.save() + + # 处理附件 + if email_data['attachments']: + for attachment in email_data['attachments']: + if 'attachmentId' in attachment: + # 下载附件 + file_path = GmailService.download_attachment( + user, + credential, + email_data['id'], + attachment['attachmentId'], + attachment['filename'] + ) + + if file_path: + # 保存附件记录 + gmail_attachment = GmailAttachment.objects.create( + conversation=conversation, + email_message_id=email_data['id'], + attachment_id=attachment['attachmentId'], + filename=attachment['filename'], + file_path=file_path, + content_type=attachment['mimeType'], + size=attachment['size'], + sender_email=email_data['from_email'], + chat_message_id=str(chat_message.id) + ) + + # 更新聊天消息,添加附件信息 + metadata = chat_message.metadata or {} + if 'attachments' not in metadata: + metadata['attachments'] = [] + + metadata['attachments'].append({ + 'id': str(gmail_attachment.id), + 'filename': attachment['filename'], + 'size': attachment['size'], + 'mime_type': attachment['mimeType'], + 'url': gmail_attachment.get_absolute_url() + }) + + chat_message.metadata = metadata + chat_message.save() + + return True + + except Exception as e: + logger.error(f"保存Gmail新邮件到聊天记录失败: {str(e)}") + return False + + @staticmethod + def send_email(user, user_email, to_email, subject, body, attachments=None): + """ + 发送Gmail邮件,支持附件 + + Args: + user: 用户对象 + user_email: 发件人Gmail邮箱(已授权) + to_email: 收件人邮箱 + subject: 邮件主题 + body: 邮件正文 + attachments: 附件列表,格式为 [{'path': '本地文件路径', 'filename': '文件名称(可选)'}] + + Returns: + tuple: (成功标志, 消息ID或错误信息) + """ + try: + # 获取凭证 + credential = GmailCredential.objects.filter(user=user, email=user_email).first() + if not credential: + return False, f"未找到{user_email}的授权信息" + + # 获取Gmail服务 + service = GmailService.get_service(credential) + + # 创建邮件 + message = MIMEMultipart() + message['to'] = to_email + message['from'] = user_email + message['subject'] = Header(subject, 'utf-8').encode() + + # 添加正文 + text_part = MIMEText(body, 'plain', 'utf-8') + message.attach(text_part) + + # 添加附件 + if attachments and isinstance(attachments, list): + for attachment in attachments: + if 'path' not in attachment: + continue + + filepath = attachment['path'] + filename = attachment.get('filename', os.path.basename(filepath)) + + if not os.path.exists(filepath): + logger.warning(f"附件文件不存在: {filepath}") + continue + + # 确定MIME类型 + content_type, encoding = mimetypes.guess_type(filepath) + if content_type is None: + content_type = 'application/octet-stream' + + main_type, sub_type = content_type.split('/', 1) + + try: + with open(filepath, 'rb') as f: + file_data = f.read() + + file_part = MIMEApplication(file_data, Name=filename) + file_part['Content-Disposition'] = f'attachment; filename="{filename}"' + message.attach(file_part) + logger.info(f"已添加附件: {filename}") + except Exception as e: + logger.error(f"处理附件时出错: {str(e)}") + + # 编码邮件为base64 + raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode('utf-8') + + # 发送邮件 + result = service.users().messages().send( + userId='me', + body={'raw': raw_message} + ).execute() + + message_id = result.get('id') + + # 查找或创建对话 + conversation = None + try: + # 查找现有对话 + conversation = GmailConversation.objects.filter( + user=user, + user_email=user_email, + influencer_email=to_email + ).first() + + if not conversation: + # 创建新对话 + conversation_id = f"gmail_{user.id}_{str(uuid.uuid4())[:8]}" + conversation = GmailConversation.objects.create( + user=user, + user_email=user_email, + influencer_email=to_email, + conversation_id=conversation_id, + title=f"与 {to_email} 的Gmail对话", + last_sync_time=timezone.now() + ) + else: + # 更新最后同步时间 + conversation.last_sync_time = timezone.now() + conversation.save() + + # 保存到聊天历史 + if conversation: + # 获取知识库 + knowledge_base = KnowledgeBase.objects.filter(user_id=user.id, type='private').first() + if not knowledge_base: + logger.warning(f"未找到默认知识库,邮件发送成功但未保存到聊天记录") + else: + # 创建聊天消息 + chat_message = ChatHistory.objects.create( + user=user, + knowledge_base=knowledge_base, + conversation_id=conversation.conversation_id, + title=conversation.title, + role='user', + content=f"主题: {subject}\n\n{body}", + metadata={ + 'gmail_message_id': message_id, + 'from': user_email, + 'to': to_email, + 'date': timezone.now().strftime('%Y-%m-%d %H:%M:%S'), + 'source': 'gmail' + } + ) + + # 如果有附件,保存附件信息 + if attachments and isinstance(attachments, list): + metadata = chat_message.metadata or {} + if 'attachments' not in metadata: + metadata['attachments'] = [] + + for attachment in attachments: + if 'path' not in attachment: + continue + + filepath = attachment['path'] + filename = attachment.get('filename', os.path.basename(filepath)) + + if not os.path.exists(filepath): + continue + + # 复制附件到Gmail附件目录 + try: + # 确保目录存在 + os.makedirs(GmailService.ATTACHMENT_DIR, exist_ok=True) + + # 生成唯一文件名 + unique_filename = f"{uuid.uuid4()}_{filename}" + target_path = os.path.join(GmailService.ATTACHMENT_DIR, unique_filename) + + # 复制文件 + shutil.copy2(filepath, target_path) + + # 获取文件大小和类型 + filesize = os.path.getsize(filepath) + content_type, _ = mimetypes.guess_type(filepath) + if content_type is None: + content_type = 'application/octet-stream' + + # 创建附件记录 + gmail_attachment = GmailAttachment.objects.create( + conversation=conversation, + email_message_id=message_id, + attachment_id=f"outgoing_{uuid.uuid4()}", + filename=filename, + file_path=target_path, + content_type=content_type, + size=filesize, + sender_email=user_email, + chat_message_id=str(chat_message.id) + ) + + # 更新聊天消息,添加附件信息 + metadata['attachments'].append({ + 'id': str(gmail_attachment.id), + 'filename': filename, + 'size': filesize, + 'mime_type': content_type, + 'url': gmail_attachment.get_absolute_url() + }) + except Exception as e: + logger.error(f"处理发送邮件附件时出错: {str(e)}") + + # 保存更新的元数据 + if metadata['attachments']: + chat_message.metadata = metadata + chat_message.save() + except Exception as e: + logger.error(f"保存发送的邮件到聊天记录失败: {str(e)}") + + logger.info(f"成功发送邮件到 {to_email}") + return True, message_id + + except Exception as e: + logger.error(f"发送Gmail邮件失败: {str(e)}") + return False, f"发送Gmail邮件失败: {str(e)}" + + + \ No newline at end of file diff --git a/apps/gmail/urls.py b/apps/gmail/urls.py index e69de29..64ec33d 100644 --- a/apps/gmail/urls.py +++ b/apps/gmail/urls.py @@ -0,0 +1,27 @@ +from django.urls import path +from .views import ( + GmailAuthInitiateView, + GmailAuthCompleteView, + GmailCredentialListView, + GmailCredentialDetailView, + GmailConversationView, + GmailAttachmentListView, + GmailPubSubView, + GmailNotificationStartView, + GmailSendEmailView +) + +app_name = 'gmail' + +urlpatterns = [ + path('auth/initiate/', GmailAuthInitiateView.as_view(), name='auth_initiate'), + path('auth/complete/', GmailAuthCompleteView.as_view(), name='auth_complete'), + path('credentials/', GmailCredentialListView.as_view(), name='credential_list'), + path('credentials//', GmailCredentialDetailView.as_view(), name='credential_detail'), + path('conversations/', GmailConversationView.as_view(), name='conversation_list'), + path('attachments/', GmailAttachmentListView.as_view(), name='attachment_list'), + path('attachments//', GmailAttachmentListView.as_view(), name='attachment_list_by_conversation'), + path('notifications/setup/', GmailPubSubView.as_view(), name='pubsub_setup'), + path('notifications/start/', GmailNotificationStartView.as_view(), name='notification_start'), + path('send/', GmailSendEmailView.as_view(), name='send_email'), +] \ No newline at end of file diff --git a/apps/gmail/views.py b/apps/gmail/views.py index 91ea44a..0fd7f26 100644 --- a/apps/gmail/views.py +++ b/apps/gmail/views.py @@ -1,3 +1,598 @@ -from django.shortcuts import render +from rest_framework.views import APIView +from rest_framework.response import Response +from rest_framework.permissions import IsAuthenticated +from rest_framework import status +from .serializers import GmailCredentialSerializer +from .services.gmail_service import GmailService +from .models import GmailCredential, GmailConversation, GmailAttachment +from django.shortcuts import get_object_or_404 +import logging +import os +from django.conf import settings +from django.core.files.storage import default_storage +from django.core.files.base import ContentFile -# Create your views here. +# 配置日志记录器,用于记录视图操作的调试、警告和错误信息 +logger = logging.getLogger(__name__) + +class GmailAuthInitiateView(APIView): + """ + API 视图,用于启动 Gmail OAuth2 认证流程。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def post(self, request): + """ + 处理 POST 请求,启动 Gmail OAuth2 认证并返回授权 URL。 + + Args: + request: Django REST Framework 请求对象,包含客户端密钥 JSON 数据。 + + Returns: + Response: 包含授权 URL 的 JSON 响应(成功时),或错误信息(失败时)。 + + Status Codes: + 200: 成功生成授权 URL。 + 400: 请求数据无效。 + 500: 服务器内部错误(如认证服务失败)。 + """ + logger.debug(f"Received auth initiate request: {request.data}") + serializer = GmailCredentialSerializer(data=request.data, context={'request': request}) + if serializer.is_valid(): + try: + # 从请求数据中提取客户端密钥 JSON + client_secret_json = serializer.validated_data['client_secret_json'] + # 调用 GmailService 生成授权 URL + auth_url = GmailService.initiate_authentication(request.user, client_secret_json) + logger.info(f"Generated auth URL for user {request.user.id}") + return Response({'auth_url': auth_url}, status=status.HTTP_200_OK) + except Exception as e: + # 记录错误并返回服务器错误响应 + logger.error(f"Error initiating authentication for user {request.user.id}: {str(e)}") + return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + # 记录无效请求数据并返回错误响应 + logger.warning(f"Invalid request data: {serializer.errors}") + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class GmailAuthCompleteView(APIView): + """ + API 视图,用于完成 Gmail OAuth2 认证流程。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def post(self, request): + """ + 处理 POST 请求,使用授权代码完成 Gmail OAuth2 认证并保存凭证。 + + Args: + request: Django REST Framework 请求对象,包含授权代码和客户端密钥 JSON。 + + Returns: + Response: 包含已保存凭证数据的 JSON 响应(成功时),或错误信息(失败时)。 + + Status Codes: + 201: 成功保存凭证。 + 400: 请求数据无效。 + 500: 服务器内部错误(如认证失败)。 + """ + logger.debug(f"Received auth complete request: {request.data}") + serializer = GmailCredentialSerializer(data=request.data, context={'request': request}) + if serializer.is_valid(): + try: + # 提取授权代码和客户端密钥 JSON + auth_code = serializer.validated_data['auth_code'] + client_secret_json = serializer.validated_data['client_secret_json'] + # 完成认证并保存凭证 + credential = GmailService.complete_authentication(request.user, auth_code, client_secret_json) + # 序列化凭证数据以返回 + serializer = GmailCredentialSerializer(credential, context={'request': request}) + logger.info(f"Authentication completed for user {request.user.id}, email: {credential.email}") + return Response(serializer.data, status=status.HTTP_201_CREATED) + except Exception as e: + # 记录错误并返回服务器错误响应 + logger.error(f"Error completing authentication for user {request.user.id}: {str(e)}") + return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + # 记录无效请求数据并返回错误响应 + logger.warning(f"Invalid request data: {serializer.errors}") + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class GmailCredentialListView(APIView): + """ + API 视图,用于列出用户的所有 Gmail 凭证。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def get(self, request): + """ + 处理 GET 请求,返回用户的所有 Gmail 凭证列表。 + + Args: + request: Django REST Framework 请求对象。 + + Returns: + Response: 包含凭证列表的 JSON 响应。 + + Status Codes: + 200: 成功返回凭证列表。 + """ + # 获取用户关联的所有 Gmail 凭证 + credentials = request.user.gmail_credentials.all() + # 序列化凭证数据 + serializer = GmailCredentialSerializer(credentials, many=True, context={'request': request}) + return Response(serializer.data, status=status.HTTP_200_OK) + + +class GmailCredentialDetailView(APIView): + """ + API 视图,用于管理特定 Gmail 凭证的获取、更新和删除。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def get(self, request, pk): + """ + 处理 GET 请求,获取特定 Gmail 凭证的详细信息。 + + Args: + request: Django REST Framework 请求对象。 + pk: 凭证的主键 ID。 + + Returns: + Response: 包含凭证详细信息的 JSON 响应。 + + Status Codes: + 200: 成功返回凭证信息。 + 404: 未找到指定凭证。 + """ + # 获取用户拥有的指定凭证,未找到则返回 404 + credential = get_object_or_404(GmailCredential, pk=pk, user=request.user) + serializer = GmailCredentialSerializer(credential, context={'request': request}) + return Response(serializer.data, status=status.HTTP_200_OK) + + def patch(self, request, pk): + """ + 处理 PATCH 请求,更新特定 Gmail 凭证(如设置为默认凭证)。 + + Args: + request: Django REST Framework 请求对象,包含更新数据。 + pk: 凭证的主键 ID。 + + Returns: + Response: 包含更新后凭证数据的 JSON 响应,或错误信息。 + + Status Codes: + 200: 成功更新凭证。 + 400: 请求数据无效。 + 404: 未找到指定凭证。 + """ + # 获取用户拥有的指定凭证 + credential = get_object_or_404(GmailCredential, pk=pk, user=request.user) + serializer = GmailCredentialSerializer(credential, data=request.data, partial=True, context={'request': request}) + if serializer.is_valid(): + # 如果设置为默认凭证,清除其他凭证的默认状态 + if serializer.validated_data.get('is_default', False): + GmailCredential.objects.filter(user=request.user).exclude(id=credential.id).update(is_default=False) + serializer.save() + return Response(serializer.data, status=status.HTTP_200_OK) + # 返回无效数据错误 + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + def delete(self, request, pk): + """ + 处理 DELETE 请求,删除特定 Gmail 凭证。 + + Args: + request: Django REST Framework 请求对象。 + pk: 凭证的主键 ID。 + + Returns: + Response: 空响应,表示删除成功。 + + Status Codes: + 204: 成功删除凭证。 + 404: 未找到指定凭证。 + """ + # 获取并删除用户拥有的指定凭证 + credential = get_object_or_404(GmailCredential, pk=pk, user=request.user) + credential.delete() + return Response(status=status.HTTP_204_NO_CONTENT) + +class GmailConversationView(APIView): + """ + API视图,用于获取和保存Gmail对话。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def post(self, request): + """ + 处理POST请求,获取Gmail对话并保存到聊天历史。 + + 请求参数: + user_email: 用户Gmail邮箱 + influencer_email: 达人Gmail邮箱 + kb_id: [可选] 知识库ID,不提供则使用默认知识库 + + 返回: + conversation_id: 创建的会话ID + """ + try: + # 验证必填参数 + user_email = request.data.get('user_email') + influencer_email = request.data.get('influencer_email') + + if not user_email or not influencer_email: + return Response({ + 'code': 400, + 'message': '缺少必填参数: user_email 或 influencer_email', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 可选参数 + kb_id = request.data.get('kb_id') + + # 调用服务保存对话 + conversation_id, error = GmailService.save_conversations_to_chat( + request.user, + user_email, + influencer_email, + kb_id + ) + + if error: + return Response({ + 'code': 400, + 'message': error, + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + return Response({ + 'code': 200, + 'message': '获取Gmail对话成功', + 'data': { + 'conversation_id': conversation_id + } + }) + + except Exception as e: + logger.error(f"获取Gmail对话失败: {str(e)}") + return Response({ + 'code': 500, + 'message': f'获取Gmail对话失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def get(self, request): + """ + 处理GET请求,获取用户的Gmail对话列表。 + """ + try: + conversations = GmailConversation.objects.filter(user=request.user, is_active=True) + + data = [] + for conversation in conversations: + # 获取附件计数 + attachments_count = GmailAttachment.objects.filter( + conversation=conversation + ).count() + + data.append({ + 'id': str(conversation.id), + 'conversation_id': conversation.conversation_id, + 'user_email': conversation.user_email, + 'influencer_email': conversation.influencer_email, + 'title': conversation.title, + 'last_sync_time': conversation.last_sync_time.strftime('%Y-%m-%d %H:%M:%S'), + 'created_at': conversation.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'attachments_count': attachments_count + }) + + return Response({ + 'code': 200, + 'message': '获取对话列表成功', + 'data': data + }) + + except Exception as e: + logger.error(f"获取对话列表失败: {str(e)}") + return Response({ + 'code': 500, + 'message': f'获取对话列表失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +class GmailAttachmentListView(APIView): + """ + API视图,用于获取Gmail附件列表。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def get(self, request, conversation_id=None): + """ + 处理GET请求,获取指定对话的附件列表。 + """ + try: + if conversation_id: + # 获取指定对话的附件 + conversation = get_object_or_404(GmailConversation, conversation_id=conversation_id, user=request.user) + attachments = GmailAttachment.objects.filter(conversation=conversation) + else: + # 获取用户的所有附件 + conversations = GmailConversation.objects.filter(user=request.user, is_active=True) + attachments = GmailAttachment.objects.filter(conversation__in=conversations) + + data = [] + for attachment in attachments: + data.append({ + 'id': str(attachment.id), + 'conversation_id': attachment.conversation.conversation_id, + 'filename': attachment.filename, + 'content_type': attachment.content_type, + 'size': attachment.size, + 'sender_email': attachment.sender_email, + 'created_at': attachment.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'url': attachment.get_absolute_url() + }) + + return Response({ + 'code': 200, + 'message': '获取附件列表成功', + 'data': data + }) + + except Exception as e: + logger.error(f"获取附件列表失败: {str(e)}") + return Response({ + 'code': 500, + 'message': f'获取附件列表失败: {str(e)}', + 'data': None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class GmailPubSubView(APIView): + """ + API视图,用于设置Gmail的Pub/Sub实时通知。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def post(self, request): + """ + 处理POST请求,为用户的Gmail账户设置Pub/Sub推送通知。 + + Args: + request: Django REST Framework请求对象,包含Gmail邮箱信息。 + + Returns: + Response: 设置结果的JSON响应。 + + Status Codes: + 200: 成功设置Pub/Sub通知。 + 400: 请求数据无效。 + 404: 未找到指定Gmail凭证。 + 500: 服务器内部错误。 + """ + try: + # 获取请求参数 + email = request.data.get('email') + + if not email: + return Response({'error': '必须提供Gmail邮箱地址'}, status=status.HTTP_400_BAD_REQUEST) + + # 检查用户是否有此Gmail账户的凭证 + credential = GmailCredential.objects.filter(user=request.user, email=email).first() + if not credential: + return Response({'error': f'未找到{email}的授权信息'}, status=status.HTTP_404_NOT_FOUND) + + # 设置Pub/Sub通知 + success, error = GmailService.setup_gmail_push_notification(request.user, email) + + if success: + return Response({'message': f'已成功为{email}设置实时通知'}, status=status.HTTP_200_OK) + else: + return Response({'error': error}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + except Exception as e: + logger.error(f"设置Gmail Pub/Sub通知失败: {str(e)}") + return Response({'error': f'设置Gmail实时通知失败: {str(e)}'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def get(self, request): + """ + 处理GET请求,获取用户所有已设置Pub/Sub通知的Gmail账户。 + + 这个方法目前仅返回用户的所有Gmail凭证,未来可以扩展为返回推送通知的详细状态。 + + Args: + request: Django REST Framework请求对象。 + + Returns: + Response: 包含Gmail账户列表的JSON响应。 + + Status Codes: + 200: 成功返回账户列表。 + """ + # 获取用户所有Gmail凭证 + credentials = request.user.gmail_credentials.filter(is_valid=True) + + # 构建响应数据 + accounts = [] + for cred in credentials: + accounts.append({ + 'id': cred.id, + 'email': cred.email, + 'is_default': cred.is_default + }) + + return Response({'accounts': accounts}, status=status.HTTP_200_OK) + +class GmailNotificationStartView(APIView): + """ + API视图,用于启动Gmail Pub/Sub监听器。 + 通常由系统管理员或后台任务调用,而非普通用户。 + """ + permission_classes = [IsAuthenticated] # 可根据需要更改为更严格的权限 + + def post(self, request): + """ + 处理POST请求,启动Gmail Pub/Sub监听器。 + + Args: + request: Django REST Framework请求对象。 + + Returns: + Response: 启动结果的JSON响应。 + + Status Codes: + 200: 成功启动监听器。 + 500: 服务器内部错误。 + """ + try: + # 可选:指定要监听的用户ID + user_id = request.data.get('user_id') + + # 在后台线程中启动监听器 + thread = GmailService.start_pubsub_listener_thread(user_id) + + return Response({'message': '已成功启动Gmail实时通知监听器'}, status=status.HTTP_200_OK) + + except Exception as e: + logger.error(f"启动Gmail Pub/Sub监听器失败: {str(e)}") + return Response({'error': f'启动Gmail实时通知监听器失败: {str(e)}'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +class GmailSendEmailView(APIView): + """ + API视图,用于发送Gmail邮件(支持附件)。 + """ + permission_classes = [IsAuthenticated] # 限制访问,仅允许已认证用户 + + def post(self, request): + """ + 处理POST请求,发送Gmail邮件。 + + 请求应包含以下字段: + - email: 发件人Gmail邮箱 + - to: 收件人邮箱 + - subject: 邮件主题 + - body: 邮件正文 + - attachments: 附件文件IDs列表 (可选) + + Args: + request: Django REST Framework请求对象。 + + Returns: + Response: 发送结果的JSON响应。 + + Status Codes: + 200: 成功发送邮件。 + 400: 请求数据无效。 + 404: 未找到Gmail凭证。 + 500: 服务器内部错误。 + """ + try: + # 获取请求参数 + user_email = request.data.get('email') + to_email = request.data.get('to') + subject = request.data.get('subject') + body = request.data.get('body') + attachment_ids = request.data.get('attachments', []) + + # 验证必填字段 + if not all([user_email, to_email, subject]): + return Response({ + 'error': '缺少必要参数,请提供email、to和subject字段' + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查是否有此Gmail账户的凭证 + credential = GmailCredential.objects.filter( + user=request.user, + email=user_email, + is_valid=True + ).first() + + if not credential: + return Response({ + 'error': f'未找到{user_email}的有效授权信息' + }, status=status.HTTP_404_NOT_FOUND) + + # 处理附件 + attachments = [] + if attachment_ids and isinstance(attachment_ids, list): + for file_id in attachment_ids: + # 查找已上传的文件 + file_obj = request.FILES.get(f'file_{file_id}') + if file_obj: + # 保存临时文件 + tmp_path = os.path.join(settings.MEDIA_ROOT, 'tmp', f'{file_id}_{file_obj.name}') + os.makedirs(os.path.dirname(tmp_path), exist_ok=True) + + with open(tmp_path, 'wb+') as destination: + for chunk in file_obj.chunks(): + destination.write(chunk) + + attachments.append({ + 'path': tmp_path, + 'filename': file_obj.name + }) + else: + # 检查是否为已有的Gmail附件ID + try: + attachment = GmailAttachment.objects.get(id=file_id) + if attachment.conversation.user_id == request.user.id: + attachments.append({ + 'path': attachment.file_path, + 'filename': attachment.filename + }) + except (GmailAttachment.DoesNotExist, ValueError): + logger.warning(f"无法找到附件: {file_id}") + + # 发送邮件 + success, result = GmailService.send_email( + request.user, + user_email, + to_email, + subject, + body or '', + attachments + ) + + if success: + return Response({ + 'message': '邮件发送成功', + 'message_id': result + }, status=status.HTTP_200_OK) + else: + return Response({ + 'error': result + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + except Exception as e: + logger.error(f"发送Gmail邮件失败: {str(e)}") + return Response({ + 'error': f'发送Gmail邮件失败: {str(e)}' + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def get(self, request): + """ + 处理GET请求,获取用户可用于发送邮件的Gmail账户列表。 + + Args: + request: Django REST Framework请求对象。 + + Returns: + Response: 包含Gmail账户列表的JSON响应。 + + Status Codes: + 200: 成功返回账户列表。 + """ + # 获取用户所有可用Gmail凭证 + credentials = request.user.gmail_credentials.filter(is_valid=True) + + # 构建响应数据 + accounts = [] + for cred in credentials: + accounts.append({ + 'id': cred.id, + 'email': cred.email, + 'is_default': cred.is_default + }) + + return Response({'accounts': accounts}, status=status.HTTP_200_OK) diff --git a/apps/message/consumers.py b/apps/message/consumers.py deleted file mode 100644 index 6473dcb..0000000 --- a/apps/message/consumers.py +++ /dev/null @@ -1,72 +0,0 @@ -# apps/message/consumers.py -import json -import logging -from channels.generic.websocket import AsyncWebsocketConsumer -from channels.db import database_sync_to_async -from rest_framework.authtoken.models import Token -from urllib.parse import parse_qs - -logger = logging.getLogger(__name__) - -class NotificationConsumer(AsyncWebsocketConsumer): - async def connect(self): - """建立WebSocket连接""" - try: - # 从URL参数中获取token - query_string = self.scope.get('query_string', b'').decode() - query_params = parse_qs(query_string) - token_key = query_params.get('token', [''])[0] - - if not token_key: - logger.warning("WebSocket连接尝试,但没有提供token") - await self.close() - return - - # 验证token - self.user = await self.get_user_from_token(token_key) - if not self.user: - logger.warning(f"WebSocket连接尝试,但token无效: {token_key}") - await self.close() - return - - # 为用户创建专属房间 - self.room_name = f"notification_user_{self.user.id}" - await self.channel_layer.group_add( - self.room_name, - self.channel_name - ) - await self.accept() - logger.info(f"用户 {self.user.username} WebSocket连接成功") - - except Exception as e: - logger.error(f"WebSocket连接错误: {str(e)}") - await self.close() - - @database_sync_to_async - def get_user_from_token(self, token_key): - try: - token = Token.objects.select_related('user').get(key=token_key) - return token.user - except Token.DoesNotExist: - return None - - async def disconnect(self, close_code): - """断开WebSocket连接""" - try: - if hasattr(self, 'room_name'): - await self.channel_layer.group_discard( - self.room_name, - self.channel_name - ) - logger.info(f"用户 {self.user.username} 已断开连接,关闭代码: {close_code}") - except Exception as e: - logger.error(f"断开连接时发生错误: {str(e)}") - - async def notification(self, event): - """处理并发送通知消息""" - try: - await self.send(text_data=json.dumps(event)) - logger.info(f"已发送通知给用户 {self.user.username}") - except Exception as e: - logger.error(f"发送通知消息时发生错误: {str(e)}") - \ No newline at end of file diff --git a/apps/message/routing.py b/apps/message/routing.py deleted file mode 100644 index 0400939..0000000 --- a/apps/message/routing.py +++ /dev/null @@ -1,11 +0,0 @@ -# apps/message/routing.py -from django.urls import re_path -from apps.message.consumers import NotificationConsumer -from apps.chat.consumers import ChatStreamConsumer # 直接导入已有的ChatStreamConsumer -import logging - -websocket_urlpatterns = [ - re_path(r'^ws/notifications/$', NotificationConsumer.as_asgi()), - re_path(r'^ws/chat/stream/$', ChatStreamConsumer.as_asgi()), -] - diff --git a/apps/message/migrations/__init__.py b/apps/notification/__init__.py similarity index 100% rename from apps/message/migrations/__init__.py rename to apps/notification/__init__.py diff --git a/apps/message/admin.py b/apps/notification/admin.py similarity index 100% rename from apps/message/admin.py rename to apps/notification/admin.py diff --git a/apps/message/apps.py b/apps/notification/apps.py similarity index 80% rename from apps/message/apps.py rename to apps/notification/apps.py index 0c2c88a..de94f13 100644 --- a/apps/message/apps.py +++ b/apps/notification/apps.py @@ -3,4 +3,4 @@ from django.apps import AppConfig class MessageConfig(AppConfig): default_auto_field = 'django.db.models.BigAutoField' - name = 'apps.message' + name = 'apps.notification' diff --git a/apps/notification/consumers.py b/apps/notification/consumers.py new file mode 100644 index 0000000..8e42d9b --- /dev/null +++ b/apps/notification/consumers.py @@ -0,0 +1,61 @@ +# apps/notification/consumers.py +from channels.generic.websocket import AsyncWebsocketConsumer +import json +from channels.db import database_sync_to_async +from rest_framework.authtoken.models import Token +import logging + +logger = logging.getLogger(__name__) + +class NotificationConsumer(AsyncWebsocketConsumer): + async def connect(self): + # 获取token参数 + query_string = self.scope.get('query_string', b'').decode() + query_params = dict(param.split('=') for param in query_string.split('&') if '=' in param) + token_key = query_params.get('token', None) + + if token_key: + # 使用token获取用户 + self.user = await self.get_user_from_token(token_key) + if not self.user: + logger.error(f"Invalid token: {token_key}") + await self.close() + return + else: + # 使用scope中的用户(如果有认证) + self.user = self.scope.get('user') + if not self.user or not self.user.is_authenticated: + logger.error("No valid authentication in WebSocket connection") + await self.close() + return + + logger.info(f"WebSocket connected for user: {self.user.id}") + self.group_name = f"notification_user_{self.user.id}" + await self.channel_layer.group_add( + self.group_name, + self.channel_name + ) + await self.accept() + + @database_sync_to_async + def get_user_from_token(self, token_key): + try: + token = Token.objects.select_related('user').get(key=token_key) + return token.user + except Token.DoesNotExist: + return None + except Exception as e: + logger.error(f"Error authenticating token: {str(e)}") + return None + + async def disconnect(self, close_code): + logger.info(f"WebSocket disconnected with code: {close_code}") + if hasattr(self, 'group_name'): + await self.channel_layer.group_discard( + self.group_name, + self.channel_name + ) + + async def notification(self, event): + """处理通知事件""" + await self.send(text_data=json.dumps(event['data'])) diff --git a/apps/message/migrations/0001_initial.py b/apps/notification/migrations/0001_initial.py similarity index 100% rename from apps/message/migrations/0001_initial.py rename to apps/notification/migrations/0001_initial.py diff --git a/apps/notification/migrations/0002_rename_message_not_receive_e8d006_idx_notificatio_receive_6f29eb_idx_and_more.py b/apps/notification/migrations/0002_rename_message_not_receive_e8d006_idx_notificatio_receive_6f29eb_idx_and_more.py new file mode 100644 index 0000000..5d83eb3 --- /dev/null +++ b/apps/notification/migrations/0002_rename_message_not_receive_e8d006_idx_notificatio_receive_6f29eb_idx_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2 on 2025-05-09 03:11 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('notification', '0001_initial'), + ] + + operations = [ + migrations.RenameIndex( + model_name='notification', + new_name='notificatio_receive_6f29eb_idx', + old_name='message_not_receive_e8d006_idx', + ), + migrations.RenameIndex( + model_name='notification', + new_name='notificatio_type_83c189_idx', + old_name='message_not_type_a0b1e3_idx', + ), + ] diff --git a/apps/notification/migrations/0003_alter_notification_sender.py b/apps/notification/migrations/0003_alter_notification_sender.py new file mode 100644 index 0000000..5cd4706 --- /dev/null +++ b/apps/notification/migrations/0003_alter_notification_sender.py @@ -0,0 +1,21 @@ +# Generated by Django 5.2 on 2025-05-09 08:35 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('notification', '0002_rename_message_not_receive_e8d006_idx_notificatio_receive_6f29eb_idx_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterField( + model_name='notification', + name='sender', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='sent_notifications', to=settings.AUTH_USER_MODEL), + ), + ] diff --git a/apps/message/services/__init__.py b/apps/notification/migrations/__init__.py similarity index 100% rename from apps/message/services/__init__.py rename to apps/notification/migrations/__init__.py diff --git a/apps/message/models.py b/apps/notification/models.py similarity index 89% rename from apps/message/models.py rename to apps/notification/models.py index bf72a7d..854b61d 100644 --- a/apps/message/models.py +++ b/apps/notification/models.py @@ -1,4 +1,4 @@ -# apps/message/models.py +# apps/notification/models.py from django.db import models from django.utils import timezone import uuid @@ -18,7 +18,7 @@ class Notification(models.Model): type = models.CharField(max_length=20, choices=NOTIFICATION_TYPES) title = models.CharField(max_length=100) content = models.TextField() - sender = models.ForeignKey(User, on_delete=models.CASCADE, related_name='sent_notifications') + sender = models.ForeignKey(User, null=True, blank=True, on_delete=models.SET_NULL, related_name='sent_notifications') receiver = models.ForeignKey(User, on_delete=models.CASCADE, related_name='received_notifications') is_read = models.BooleanField(default=False) related_resource = models.CharField(max_length=100, blank=True) # 相关资源ID diff --git a/apps/notification/routing.py b/apps/notification/routing.py new file mode 100644 index 0000000..1737b4a --- /dev/null +++ b/apps/notification/routing.py @@ -0,0 +1,9 @@ +# apps/notification/routing.py +from django.urls import re_path +from apps.notification.consumers import NotificationConsumer +import logging + +websocket_urlpatterns = [ + re_path(r'^ws/notifications/$', NotificationConsumer.as_asgi()), +] + diff --git a/apps/message/serializers.py b/apps/notification/serializers.py similarity index 88% rename from apps/message/serializers.py rename to apps/notification/serializers.py index 53792a2..9aa8088 100644 --- a/apps/message/serializers.py +++ b/apps/notification/serializers.py @@ -1,6 +1,6 @@ -# apps/message/serializers.py +# apps/notification/serializers.py from rest_framework import serializers -from apps.message.models import Notification +from apps.notification.models import Notification from apps.accounts.models import User class NotificationSerializer(serializers.ModelSerializer): diff --git a/apps/notification/services/__init__.py b/apps/notification/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/notification/tests.py b/apps/notification/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/apps/notification/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/message/urls.py b/apps/notification/urls.py similarity index 73% rename from apps/message/urls.py rename to apps/notification/urls.py index d9aee95..668e98d 100644 --- a/apps/message/urls.py +++ b/apps/notification/urls.py @@ -1,7 +1,7 @@ -# apps/message/urls.py +# apps/notification/urls.py from django.urls import path, include from rest_framework.routers import DefaultRouter -from apps.message.views import NotificationViewSet +from apps.notification.views import NotificationViewSet router = DefaultRouter() router.register(r'', NotificationViewSet, basename='notification') diff --git a/apps/message/views.py b/apps/notification/views.py similarity index 92% rename from apps/message/views.py rename to apps/notification/views.py index efd6b6b..f6ea2b0 100644 --- a/apps/message/views.py +++ b/apps/notification/views.py @@ -1,10 +1,10 @@ -# apps/message/views.py +# apps/notification/views.py from rest_framework import viewsets, status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.decorators import action -from apps.message.models import Notification -from apps.message.serializers import NotificationSerializer +from apps.notification.models import Notification +from apps.notification.serializers import NotificationSerializer class NotificationViewSet(viewsets.ModelViewSet): """通知视图集""" diff --git a/apps/operation/__init__.py b/apps/operation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/operation/admin.py b/apps/operation/admin.py new file mode 100644 index 0000000..f0944c5 --- /dev/null +++ b/apps/operation/admin.py @@ -0,0 +1,36 @@ +from django.contrib import admin +from .models import OperatorAccount, PlatformAccount, Video + +@admin.register(OperatorAccount) +class OperatorAccountAdmin(admin.ModelAdmin): + list_display = ('username', 'real_name', 'email', 'phone', 'position', 'department', 'is_active', 'created_at') + list_filter = ('position', 'department', 'is_active') + search_fields = ('username', 'real_name', 'email', 'phone') + date_hierarchy = 'created_at' + readonly_fields = ('created_at', 'updated_at') + +@admin.register(PlatformAccount) +class PlatformAccountAdmin(admin.ModelAdmin): + list_display = ('account_name', 'platform_name', 'operator', 'status', 'followers_count', 'last_posting', 'created_at') + list_filter = ('platform_name', 'status') + search_fields = ('account_name', 'account_id', 'description') + date_hierarchy = 'created_at' + readonly_fields = ('created_at', 'updated_at') + + def get_queryset(self, request): + """优化查询,减少数据库查询次数""" + queryset = super().get_queryset(request) + return queryset.select_related('operator') + +@admin.register(Video) +class VideoAdmin(admin.ModelAdmin): + list_display = ('title', 'platform_account', 'status', 'views_count', 'likes_count', 'publish_time', 'created_at') + list_filter = ('status', 'created_at', 'publish_time') + search_fields = ('title', 'description', 'tags') + date_hierarchy = 'created_at' + readonly_fields = ('created_at', 'updated_at') + + def get_queryset(self, request): + """优化查询,减少数据库查询次数""" + queryset = super().get_queryset(request) + return queryset.select_related('platform_account', 'platform_account__operator') diff --git a/apps/operation/apps.py b/apps/operation/apps.py new file mode 100644 index 0000000..4ef1f73 --- /dev/null +++ b/apps/operation/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class OperationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.operation' diff --git a/apps/operation/migrations/0001_initial.py b/apps/operation/migrations/0001_initial.py new file mode 100644 index 0000000..6e9c0a7 --- /dev/null +++ b/apps/operation/migrations/0001_initial.py @@ -0,0 +1,88 @@ +# Generated by Django 5.1.5 on 2025-05-12 08:55 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='OperatorAccount', + fields=[ + ('id', models.AutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='UUID')), + ('username', models.CharField(max_length=100, unique=True, verbose_name='用户名')), + ('password', models.CharField(max_length=255, verbose_name='密码')), + ('real_name', models.CharField(max_length=50, verbose_name='真实姓名')), + ('email', models.EmailField(max_length=254, verbose_name='邮箱')), + ('phone', models.CharField(max_length=15, verbose_name='电话')), + ('position', models.CharField(choices=[('editor', '编辑'), ('planner', '策划'), ('operator', '运营'), ('admin', '管理员')], max_length=20, verbose_name='工作定位')), + ('department', models.CharField(max_length=50, verbose_name='部门')), + ('is_active', models.BooleanField(default=True, verbose_name='是否在职')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ], + options={ + 'verbose_name': '运营账号', + 'verbose_name_plural': '运营账号', + }, + ), + migrations.CreateModel( + name='PlatformAccount', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('platform_name', models.CharField(choices=[('youtube', 'YouTube'), ('tiktok', 'TikTok'), ('twitter', 'Twitter/X'), ('instagram', 'Instagram'), ('facebook', 'Facebook'), ('bilibili', 'Bilibili')], max_length=20, verbose_name='平台名称')), + ('account_name', models.CharField(max_length=100, verbose_name='账号名称')), + ('account_id', models.CharField(max_length=100, verbose_name='账号ID')), + ('status', models.CharField(choices=[('active', '正常'), ('restricted', '限流'), ('suspended', '封禁'), ('inactive', '未激活')], default='active', max_length=20, verbose_name='账号状态')), + ('followers_count', models.IntegerField(default=0, verbose_name='粉丝数')), + ('account_url', models.URLField(verbose_name='账号链接')), + ('description', models.TextField(blank=True, null=True, verbose_name='账号描述')), + ('tags', models.CharField(blank=True, help_text='用逗号分隔的标签列表', max_length=255, null=True, verbose_name='标签')), + ('profile_image', models.URLField(blank=True, null=True, verbose_name='头像URL')), + ('last_posting', models.DateTimeField(blank=True, null=True, verbose_name='最后发布时间')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('last_login', models.DateTimeField(blank=True, null=True, verbose_name='最后登录时间')), + ('operator', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='platform_accounts', to='operation.operatoraccount', verbose_name='关联运营')), + ], + options={ + 'verbose_name': '平台账号', + 'verbose_name_plural': '平台账号', + 'unique_together': {('platform_name', 'account_id')}, + }, + ), + migrations.CreateModel( + name='Video', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('title', models.CharField(max_length=200, verbose_name='视频标题')), + ('description', models.TextField(blank=True, null=True, verbose_name='视频描述')), + ('video_url', models.URLField(blank=True, null=True, verbose_name='视频地址')), + ('local_path', models.CharField(blank=True, max_length=255, null=True, verbose_name='本地路径')), + ('thumbnail_url', models.URLField(blank=True, null=True, verbose_name='缩略图地址')), + ('status', models.CharField(choices=[('draft', '草稿'), ('scheduled', '已排期'), ('published', '已发布'), ('failed', '发布失败'), ('deleted', '已删除')], default='draft', max_length=20, verbose_name='发布状态')), + ('views_count', models.IntegerField(default=0, verbose_name='播放次数')), + ('likes_count', models.IntegerField(default=0, verbose_name='点赞数')), + ('comments_count', models.IntegerField(default=0, verbose_name='评论数')), + ('shares_count', models.IntegerField(default=0, verbose_name='分享数')), + ('tags', models.CharField(blank=True, max_length=500, null=True, verbose_name='标签')), + ('publish_time', models.DateTimeField(blank=True, null=True, verbose_name='发布时间')), + ('scheduled_time', models.DateTimeField(blank=True, null=True, verbose_name='计划发布时间')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('platform_account', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='videos', to='operation.platformaccount', verbose_name='发布账号')), + ], + options={ + 'verbose_name': '视频', + 'verbose_name_plural': '视频', + }, + ), + ] diff --git a/apps/operation/migrations/__init__.py b/apps/operation/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/operation/models.py b/apps/operation/models.py new file mode 100644 index 0000000..3050fc1 --- /dev/null +++ b/apps/operation/models.py @@ -0,0 +1,126 @@ +from django.db import models +import uuid +from django.utils import timezone +from apps.knowledge_base.models import KnowledgeBase, KnowledgeBaseDocument +from apps.accounts.models import User + +# Create your models here. + +# 我们可以在这里添加额外的模型或关系,但现在使用user_management中的现有模型 + +# 从user_management迁移过来的模型 +class OperatorAccount(models.Model): + """运营账号信息表""" + + id = models.AutoField(primary_key=True) # 保留自动递增的ID字段 + uuid = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='UUID') + + POSITION_CHOICES = [ + ('editor', '编辑'), + ('planner', '策划'), + ('operator', '运营'), + ('admin', '管理员'), + ] + + username = models.CharField(max_length=100, unique=True, verbose_name='用户名') + password = models.CharField(max_length=255, verbose_name='密码') + real_name = models.CharField(max_length=50, verbose_name='真实姓名') + email = models.EmailField(verbose_name='邮箱') + phone = models.CharField(max_length=15, verbose_name='电话') + position = models.CharField(max_length=20, choices=POSITION_CHOICES, verbose_name='工作定位') + department = models.CharField(max_length=50, verbose_name='部门') + is_active = models.BooleanField(default=True, verbose_name='是否在职') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + + class Meta: + verbose_name = '运营账号' + verbose_name_plural = '运营账号' + + def __str__(self): + return f"{self.real_name} ({self.username})" + +class PlatformAccount(models.Model): + """平台账号信息表""" + + STATUS_CHOICES = [ + ('active', '正常'), + ('restricted', '限流'), + ('suspended', '封禁'), + ('inactive', '未激活'), + ] + + PLATFORM_CHOICES = [ + ('youtube', 'YouTube'), + ('tiktok', 'TikTok'), + ('twitter', 'Twitter/X'), + ('instagram', 'Instagram'), + ('facebook', 'Facebook'), + ('bilibili', 'Bilibili'), + ] + + operator = models.ForeignKey(OperatorAccount, on_delete=models.CASCADE, related_name='platform_accounts', verbose_name='关联运营') + platform_name = models.CharField(max_length=20, choices=PLATFORM_CHOICES, verbose_name='平台名称') + account_name = models.CharField(max_length=100, verbose_name='账号名称') + account_id = models.CharField(max_length=100, verbose_name='账号ID') + status = models.CharField(max_length=20, choices=STATUS_CHOICES, default='active', verbose_name='账号状态') + followers_count = models.IntegerField(default=0, verbose_name='粉丝数') + account_url = models.URLField(verbose_name='账号链接') + description = models.TextField(blank=True, null=True, verbose_name='账号描述') + + # 新增字段 + tags = models.CharField(max_length=255, blank=True, null=True, verbose_name='标签', help_text='用逗号分隔的标签列表') + profile_image = models.URLField(blank=True, null=True, verbose_name='头像URL') + last_posting = models.DateTimeField(blank=True, null=True, verbose_name='最后发布时间') + + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + last_login = models.DateTimeField(blank=True, null=True, verbose_name='最后登录时间') + + class Meta: + verbose_name = '平台账号' + verbose_name_plural = '平台账号' + unique_together = ('platform_name', 'account_id') + + def __str__(self): + return f"{self.account_name} ({self.platform_name})" + +class Video(models.Model): + """视频信息表""" + + STATUS_CHOICES = [ + ('draft', '草稿'), + ('scheduled', '已排期'), + ('published', '已发布'), + ('failed', '发布失败'), + ('deleted', '已删除'), + ] + + platform_account = models.ForeignKey(PlatformAccount, on_delete=models.CASCADE, related_name='videos', verbose_name='发布账号') + title = models.CharField(max_length=200, verbose_name='视频标题') + description = models.TextField(blank=True, null=True, verbose_name='视频描述') + video_url = models.URLField(blank=True, null=True, verbose_name='视频地址') + local_path = models.CharField(max_length=255, blank=True, null=True, verbose_name='本地路径') + thumbnail_url = models.URLField(blank=True, null=True, verbose_name='缩略图地址') + status = models.CharField(max_length=20, choices=STATUS_CHOICES, default='draft', verbose_name='发布状态') + views_count = models.IntegerField(default=0, verbose_name='播放次数') + likes_count = models.IntegerField(default=0, verbose_name='点赞数') + comments_count = models.IntegerField(default=0, verbose_name='评论数') + shares_count = models.IntegerField(default=0, verbose_name='分享数') + tags = models.CharField(max_length=500, blank=True, null=True, verbose_name='标签') + publish_time = models.DateTimeField(blank=True, null=True, verbose_name='发布时间') + scheduled_time = models.DateTimeField(blank=True, null=True, verbose_name='计划发布时间') + created_at = models.DateTimeField(auto_now_add=True, verbose_name='创建时间') + updated_at = models.DateTimeField(auto_now=True, verbose_name='更新时间') + + class Meta: + verbose_name = '视频' + verbose_name_plural = '视频' + + def __str__(self): + return self.title + + def save(self, *args, **kwargs): + if self.status == 'published' and not self.publish_time: + self.publish_time = timezone.now() + super().save(*args, **kwargs) diff --git a/apps/operation/pagination.py b/apps/operation/pagination.py new file mode 100644 index 0000000..febb3cc --- /dev/null +++ b/apps/operation/pagination.py @@ -0,0 +1,23 @@ +from rest_framework.pagination import PageNumberPagination +from rest_framework.response import Response + +class CustomPagination(PageNumberPagination): + """自定义分页器,返回格式为 {code, message, data}""" + page_size = 10 + page_size_query_param = 'page_size' + max_page_size = 100 + + def get_paginated_response(self, data): + return Response({ + "code": 200, + "message": "获取数据成功", + "data": { + "count": self.page.paginator.count, + "next": self.get_next_link(), + "previous": self.get_previous_link(), + "results": data, + "page": self.page.number, + "pages": self.page.paginator.num_pages, + "page_size": self.page_size + } + }) \ No newline at end of file diff --git a/apps/operation/serializers.py b/apps/operation/serializers.py new file mode 100644 index 0000000..a091276 --- /dev/null +++ b/apps/operation/serializers.py @@ -0,0 +1,102 @@ +from rest_framework import serializers +from .models import OperatorAccount, PlatformAccount, Video +from apps.knowledge_base.models import KnowledgeBase, KnowledgeBaseDocument +import uuid + + +class OperatorAccountSerializer(serializers.ModelSerializer): + id = serializers.UUIDField(read_only=False, required=False) # 允许前端不提供ID,但如果提供则必须是有效的UUID + + class Meta: + model = OperatorAccount + fields = ['id', 'username', 'password', 'real_name', 'email', 'phone', 'position', 'department', 'is_active', 'created_at', 'updated_at'] + read_only_fields = ['created_at', 'updated_at'] + extra_kwargs = { + 'password': {'write_only': True} + } + + def create(self, validated_data): + # 如果没有提供ID,则生成一个UUID + if 'id' not in validated_data: + validated_data['id'] = uuid.uuid4() + + password = validated_data.pop('password', None) + instance = self.Meta.model(**validated_data) + if password: + instance.password = password # 在实际应用中应该加密存储密码 + instance.save() + return instance + + +class PlatformAccountSerializer(serializers.ModelSerializer): + operator_name = serializers.CharField(source='operator.real_name', read_only=True) + + class Meta: + model = PlatformAccount + fields = ['id', 'operator', 'operator_name', 'platform_name', 'account_name', 'account_id', + 'status', 'followers_count', 'account_url', 'description', + 'tags', 'profile_image', 'last_posting', + 'created_at', 'updated_at', 'last_login'] + read_only_fields = ['id', 'created_at', 'updated_at'] + + def to_internal_value(self, data): + # 处理operator字段,可能是字符串格式的UUID + if 'operator' in data and isinstance(data['operator'], str): + try: + # 尝试获取对应的运营账号对象 + operator = OperatorAccount.objects.get(id=data['operator']) + data['operator'] = operator.id # 确保使用正确的ID格式 + except OperatorAccount.DoesNotExist: + # 如果找不到对应的运营账号,保持原值,让验证器捕获此错误 + pass + except Exception as e: + # 其他类型的错误,如ID格式不正确等 + pass + + return super().to_internal_value(data) + + +class VideoSerializer(serializers.ModelSerializer): + platform_account_name = serializers.CharField(source='platform_account.account_name', read_only=True) + platform_name = serializers.CharField(source='platform_account.platform_name', read_only=True) + + class Meta: + model = Video + fields = ['id', 'platform_account', 'platform_account_name', 'platform_name', 'title', + 'description', 'video_url', 'local_path', 'thumbnail_url', 'status', + 'views_count', 'likes_count', 'comments_count', 'shares_count', 'tags', + 'publish_time', 'scheduled_time', 'created_at', 'updated_at'] + read_only_fields = ['id', 'created_at', 'updated_at', 'views_count', 'likes_count', + 'comments_count', 'shares_count'] + + def to_internal_value(self, data): + # 处理platform_account字段,可能是字符串格式的UUID + if 'platform_account' in data and isinstance(data['platform_account'], str): + try: + # 尝试获取对应的平台账号对象 + platform_account = PlatformAccount.objects.get(id=data['platform_account']) + data['platform_account'] = platform_account.id # 确保使用正确的ID格式 + except PlatformAccount.DoesNotExist: + # 如果找不到对应的平台账号,保持原值,让验证器捕获此错误 + pass + except Exception as e: + # 其他类型的错误,如ID格式不正确等 + pass + + return super().to_internal_value(data) + + +class KnowledgeBaseSerializer(serializers.ModelSerializer): + class Meta: + model = KnowledgeBase + fields = ['id', 'user_id', 'name', 'desc', 'type', 'department', 'group', + 'external_id', 'create_time', 'update_time'] + read_only_fields = ['id', 'create_time', 'update_time'] + + +class KnowledgeBaseDocumentSerializer(serializers.ModelSerializer): + class Meta: + model = KnowledgeBaseDocument + fields = ['id', 'knowledge_base', 'document_id', 'document_name', + 'external_id', 'uploader_name', 'status', 'create_time', 'update_time'] + read_only_fields = ['id', 'create_time', 'update_time'] \ No newline at end of file diff --git a/apps/operation/tests.py b/apps/operation/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/apps/operation/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/operation/urls.py b/apps/operation/urls.py new file mode 100644 index 0000000..e6a3b82 --- /dev/null +++ b/apps/operation/urls.py @@ -0,0 +1,12 @@ +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from .views import OperatorAccountViewSet, PlatformAccountViewSet, VideoViewSet + +router = DefaultRouter() +router.register(r'operators', OperatorAccountViewSet) +router.register(r'platforms', PlatformAccountViewSet) +router.register(r'videos', VideoViewSet) + +urlpatterns = [ + path('', include(router.urls)), +] \ No newline at end of file diff --git a/apps/operation/views.py b/apps/operation/views.py new file mode 100644 index 0000000..91a7eda --- /dev/null +++ b/apps/operation/views.py @@ -0,0 +1,1095 @@ +from django.shortcuts import render +import json +import uuid +import logging +from django.db import transaction +from django.shortcuts import get_object_or_404 +from django.conf import settings +from django.utils import timezone +from rest_framework import viewsets, status +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.permissions import IsAuthenticated +from django.db.models import Q +import os + +from .models import OperatorAccount, PlatformAccount, Video +from apps.knowledge_base.models import KnowledgeBase, KnowledgeBaseDocument +from apps.accounts.models import User +from .serializers import ( + OperatorAccountSerializer, PlatformAccountSerializer, VideoSerializer, + KnowledgeBaseSerializer, KnowledgeBaseDocumentSerializer +) +from .pagination import CustomPagination + +logger = logging.getLogger(__name__) + +class OperatorAccountViewSet(viewsets.ModelViewSet): + """运营账号管理视图集""" + queryset = OperatorAccount.objects.all() + serializer_class = OperatorAccountSerializer + pagination_class = CustomPagination + + def list(self, request, *args, **kwargs): + """获取运营账号列表""" + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + + if page is not None: + serializer = self.get_serializer(page, many=True) + # 使用自定义分页器的响应 + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response({ + "code": 200, + "message": "获取运营账号列表成功", + "data": serializer.data + }) + + def retrieve(self, request, *args, **kwargs): + """获取运营账号详情""" + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response({ + "code": 200, + "message": "获取运营账号详情成功", + "data": serializer.data + }) + + def update(self, request, *args, **kwargs): + """更新运营账号信息""" + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + return Response({ + "code": 200, + "message": "更新运营账号信息成功", + "data": serializer.data + }) + + def partial_update(self, request, *args, **kwargs): + """部分更新运营账号信息""" + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def create(self, request, *args, **kwargs): + """创建运营账号并自动创建对应的私有知识库""" + with transaction.atomic(): + # 1. 创建运营账号 + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + # 2. 手动保存数据而不是使用serializer.save(),确保不传入UUID + operator_data = serializer.validated_data + operator = OperatorAccount.objects.create(**operator_data) + + # 3. 为每个运营账号创建一个私有知识库 + knowledge_base = KnowledgeBase.objects.create( + user_id=request.user.id, # 使用当前用户作为创建者 + name=f"{operator.real_name}的运营知识库", + desc=f"用于存储{operator.real_name}({operator.username})相关的运营数据", + type='private', + department=operator.department + ) + + # 4. 创建知识库文档记录 - 运营信息文档 + document_data = { + "name": f"{operator.real_name}_运营信息", + "paragraphs": [ + { + "title": "运营账号基本信息", + "content": f""" + 用户名: {operator.username} + 真实姓名: {operator.real_name} + 邮箱: {operator.email} + 电话: {operator.phone} + 职位: {operator.get_position_display()} + 部门: {operator.department} + 创建时间: {operator.created_at.strftime('%Y-%m-%d %H:%M:%S')} + uuid: {operator.uuid} + """, + "is_active": True + } + ] + } + + # 调用外部API创建文档 + document_id = self._create_document(knowledge_base.external_id, document_data) + + if document_id: + # 创建知识库文档记录 + KnowledgeBaseDocument.objects.create( + knowledge_base=knowledge_base, + document_id=document_id, + document_name=document_data["name"], + external_id=document_id, + uploader_name=request.user.username + ) + + return Response({ + "code": 200, + "message": "运营账号创建成功,并已创建对应知识库", + "data": { + "operator": self.get_serializer(operator).data, + "knowledge_base": { + "id": knowledge_base.id, + "name": knowledge_base.name, + "external_id": knowledge_base.external_id + } + } + }, status=status.HTTP_201_CREATED) + + def destroy(self, request, *args, **kwargs): + """删除运营账号并更新相关知识库状态""" + operator = self.get_object() + + # 更新知识库状态或删除关联文档 + knowledge_bases = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ) + + for kb in knowledge_bases: + # 可以选择删除知识库,或者更新知识库状态 + # 这里我们更新对应的文档状态 + documents = KnowledgeBaseDocument.objects.filter( + knowledge_base=kb, + document_name__contains=operator.real_name + ) + + for doc in documents: + doc.status = 'deleted' + doc.save() + + operator.is_active = False # 软删除 + operator.save() + + return Response({ + "code": 200, + "message": "运营账号已停用,相关知识库文档已标记为删除", + "data": None + }) + + def _create_document(self, external_id, doc_data): + """调用外部API创建文档""" + try: + if not external_id: + logger.error("创建文档失败:知识库external_id为空") + return None + + # 在实际应用中,这里需要调用外部API创建文档 + # 模拟创建文档并返回document_id + document_id = str(uuid.uuid4()) + logger.info(f"模拟创建文档成功,document_id: {document_id}") + return document_id + except Exception as e: + logger.error(f"创建文档失败: {str(e)}") + return None + + +class PlatformAccountViewSet(viewsets.ModelViewSet): + """平台账号管理视图集""" + queryset = PlatformAccount.objects.all() + serializer_class = PlatformAccountSerializer + pagination_class = CustomPagination + + def list(self, request, *args, **kwargs): + """获取平台账号列表""" + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + + if page is not None: + serializer = self.get_serializer(page, many=True) + # 使用自定义分页器的响应 + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response({ + "code": 200, + "message": "获取平台账号列表成功", + "data": serializer.data + }) + + def retrieve(self, request, *args, **kwargs): + """获取平台账号详情""" + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response({ + "code": 200, + "message": "获取平台账号详情成功", + "data": serializer.data + }) + + def update(self, request, *args, **kwargs): + """更新平台账号信息""" + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + return Response({ + "code": 200, + "message": "更新平台账号信息成功", + "data": serializer.data + }) + + def partial_update(self, request, *args, **kwargs): + """部分更新平台账号信息""" + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def create(self, request, *args, **kwargs): + """创建平台账号并记录到知识库""" + with transaction.atomic(): + # 处理operator字段,可能是字符串类型的ID + data = request.data.copy() + if 'operator' in data and isinstance(data['operator'], str): + try: + # 尝试通过ID查找运营账号 + operator_id = data['operator'] + try: + # 先尝试通过整数ID查找 + operator_id_int = int(operator_id) + operator = OperatorAccount.objects.get(id=operator_id_int) + except (ValueError, OperatorAccount.DoesNotExist): + # 如果无法转换为整数或找不到对应账号,尝试通过用户名或真实姓名查找 + operator = OperatorAccount.objects.filter( + Q(username=operator_id) | Q(real_name=operator_id) + ).first() + + if not operator: + return Response({ + "code": 404, + "message": f"未找到运营账号: {operator_id},请提供有效的ID、用户名或真实姓名", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + # 更新请求数据中的operator字段为找到的operator的ID + data['operator'] = operator.id + + except Exception as e: + return Response({ + "code": 400, + "message": f"处理运营账号ID时出错: {str(e)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建平台账号 + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) + + # 手动创建平台账号,不使用serializer.save()避免ID问题 + platform_data = serializer.validated_data + platform_account = PlatformAccount.objects.create(**platform_data) + + # 获取关联的运营账号 + operator = platform_account.operator + + # 查找对应的知识库 + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base and knowledge_base.external_id: + # 创建平台账号文档 + document_data = { + "name": f"{platform_account.account_name}_{platform_account.platform_name}_账号信息", + "paragraphs": [ + { + "title": "平台账号基本信息", + "content": f""" + 平台: {platform_account.get_platform_name_display()} + 账号名称: {platform_account.account_name} + 账号ID: {platform_account.account_id} + 账号状态: {platform_account.get_status_display()} + 粉丝数: {platform_account.followers_count} + 账号链接: {platform_account.account_url} + 账号描述: {platform_account.description or '无'} + 标签: {platform_account.tags or '无'} + 头像链接: {platform_account.profile_image or '无'} + 最后发布时间: {platform_account.last_posting.strftime('%Y-%m-%d %H:%M:%S') if platform_account.last_posting else '未发布'} + 创建时间: {platform_account.created_at.strftime('%Y-%m-%d %H:%M:%S')} + 最后登录: {platform_account.last_login.strftime('%Y-%m-%d %H:%M:%S') if platform_account.last_login else '从未登录'} + """, + "is_active": True + } + ] + } + + # 调用外部API创建文档 + document_id = self._create_document(knowledge_base.external_id, document_data) + + if document_id: + # 创建知识库文档记录 + KnowledgeBaseDocument.objects.create( + knowledge_base=knowledge_base, + document_id=document_id, + document_name=document_data["name"], + external_id=document_id, + uploader_name=request.user.username + ) + + return Response({ + "code": 200, + "message": "平台账号创建成功,并已添加到知识库", + "data": self.get_serializer(platform_account).data + }, status=status.HTTP_201_CREATED) + + def destroy(self, request, *args, **kwargs): + """删除平台账号并更新相关知识库文档""" + platform_account = self.get_object() + + # 获取关联的运营账号 + operator = platform_account.operator + + # 查找对应的知识库 + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + # 查找相关文档并标记为删除 + documents = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base + ).filter( + Q(document_name__contains=platform_account.account_name) | + Q(document_name__contains=platform_account.platform_name) + ) + + for doc in documents: + doc.status = 'deleted' + doc.save() + + # 删除平台账号 + self.perform_destroy(platform_account) + + return Response({ + "code": 200, + "message": "平台账号已删除,相关知识库文档已标记为删除", + "data": None + }) + + def _create_document(self, external_id, doc_data): + """调用外部API创建文档""" + try: + if not external_id: + logger.error("创建文档失败:知识库external_id为空") + return None + + # 在实际应用中,这里需要调用外部API创建文档 + # 模拟创建文档并返回document_id + document_id = str(uuid.uuid4()) + logger.info(f"模拟创建文档成功,document_id: {document_id}") + return document_id + except Exception as e: + logger.error(f"创建文档失败: {str(e)}") + return None + + @action(detail=True, methods=['post']) + def update_followers(self, request, pk=None): + """更新平台账号粉丝数并同步到知识库""" + platform_account = self.get_object() + followers_count = request.data.get('followers_count') + + if not followers_count: + return Response({ + "code": 400, + "message": "粉丝数不能为空", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新粉丝数 + platform_account.followers_count = followers_count + platform_account.save() + + # 同步到知识库 + operator = platform_account.operator + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + # 查找相关文档 + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + status='active' + ).filter( + Q(document_name__contains=platform_account.account_name) | + Q(document_name__contains=platform_account.platform_name) + ).first() + + if document: + # 这里应该调用外部API更新文档内容 + # 但由于我们没有实际的API,只做记录 + logger.info(f"应当更新文档 {document.document_id} 的粉丝数为 {followers_count}") + + return Response({ + "code": 200, + "message": "粉丝数更新成功", + "data": { + "id": platform_account.id, + "account_name": platform_account.account_name, + "followers_count": platform_account.followers_count + } + }) + + @action(detail=True, methods=['post']) + def update_profile(self, request, pk=None): + """更新平台账号的头像、标签和最后发布时间""" + platform_account = self.get_object() + + # 获取更新的资料数据 + profile_data = {} + + # 处理标签 + if 'tags' in request.data: + profile_data['tags'] = request.data['tags'] + + # 处理头像 + if 'profile_image' in request.data: + profile_data['profile_image'] = request.data['profile_image'] + + # 处理最后发布时间 + if 'last_posting' in request.data: + try: + # 尝试解析时间字符串 + from dateutil import parser + last_posting = parser.parse(request.data['last_posting']) + profile_data['last_posting'] = last_posting + except Exception as e: + return Response({ + "code": 400, + "message": f"最后发布时间格式错误: {str(e)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + if not profile_data: + return Response({ + "code": 400, + "message": "没有提供任何需更新的资料数据", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新平台账号资料 + for field, value in profile_data.items(): + setattr(platform_account, field, value) + platform_account.save() + + # 同步到知识库 + # 在实际应用中应该调用外部API更新文档内容 + operator = platform_account.operator + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + status='active' + ).filter( + Q(document_name__contains=platform_account.account_name) | + Q(document_name__contains=platform_account.platform_name) + ).first() + + if document: + logger.info(f"应当更新文档 {document.document_id} 的平台账号资料数据") + + return Response({ + "code": 200, + "message": "平台账号资料更新成功", + "data": self.get_serializer(platform_account).data + }) + + +class VideoViewSet(viewsets.ModelViewSet): + """视频管理视图集""" + queryset = Video.objects.all() + serializer_class = VideoSerializer + pagination_class = CustomPagination + + def list(self, request, *args, **kwargs): + """获取视频列表""" + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + + if page is not None: + serializer = self.get_serializer(page, many=True) + # 使用自定义分页器的响应 + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response({ + "code": 200, + "message": "获取视频列表成功", + "data": serializer.data + }) + + def retrieve(self, request, *args, **kwargs): + """获取视频详情""" + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response({ + "code": 200, + "message": "获取视频详情成功", + "data": serializer.data + }) + + def update(self, request, *args, **kwargs): + """更新视频信息""" + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + return Response({ + "code": 200, + "message": "更新视频信息成功", + "data": serializer.data + }) + + def partial_update(self, request, *args, **kwargs): + """部分更新视频信息""" + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + def create(self, request, *args, **kwargs): + """创建视频并记录到知识库""" + with transaction.atomic(): + # 处理platform_account字段,可能是字符串类型的ID + data = request.data.copy() + if 'platform_account' in data and isinstance(data['platform_account'], str): + try: + # 尝试通过ID查找平台账号 + platform_id = data['platform_account'] + try: + # 先尝试通过整数ID查找 + platform_id_int = int(platform_id) + platform = PlatformAccount.objects.get(id=platform_id_int) + except (ValueError, PlatformAccount.DoesNotExist): + # 如果无法转换为整数或找不到对应账号,尝试通过账号名称或账号ID查找 + platform = PlatformAccount.objects.filter( + Q(account_name=platform_id) | Q(account_id=platform_id) + ).first() + + if not platform: + return Response({ + "code": 404, + "message": f"未找到平台账号: {platform_id},请提供有效的ID、账号名称或账号ID", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + # 更新请求数据中的platform_account字段为找到的platform的ID + data['platform_account'] = platform.id + + except Exception as e: + return Response({ + "code": 400, + "message": f"处理平台账号ID时出错: {str(e)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建视频 + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) + + # 手动创建视频,不使用serializer.save()避免ID问题 + video_data = serializer.validated_data + video = Video.objects.create(**video_data) + + # 获取关联的平台账号和运营账号 + platform_account = video.platform_account + operator = platform_account.operator + + # 查找对应的知识库 + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base and knowledge_base.external_id: + # 创建视频文档 + document_data = { + "name": f"{video.title}_{platform_account.account_name}_视频信息", + "paragraphs": [ + { + "title": "视频基本信息", + "content": f""" + 标题: {video.title} + 平台: {platform_account.get_platform_name_display()} + 账号: {platform_account.account_name} + 视频ID: {video.video_id} + 发布时间: {video.publish_time.strftime('%Y-%m-%d %H:%M:%S') if video.publish_time else '未发布'} + 视频链接: {video.video_url} + 点赞数: {video.likes_count} + 评论数: {video.comments_count} + 分享数: {video.shares_count} + 观看数: {video.views_count} + 视频描述: {video.description or '无'} + """, + "is_active": True + } + ] + } + + # 调用外部API创建文档 + document_id = self._create_document(knowledge_base.external_id, document_data) + + if document_id: + # 创建知识库文档记录 + KnowledgeBaseDocument.objects.create( + knowledge_base=knowledge_base, + document_id=document_id, + document_name=document_data["name"], + external_id=document_id, + uploader_name=request.user.username + ) + + return Response({ + "code": 200, + "message": "视频创建成功,并已添加到知识库", + "data": self.get_serializer(video).data + }, status=status.HTTP_201_CREATED) + + def destroy(self, request, *args, **kwargs): + """删除视频记录并更新相关知识库文档""" + video = self.get_object() + + # 获取关联的平台账号和运营账号 + platform_account = video.platform_account + operator = platform_account.operator + + # 查找对应的知识库 + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + # 查找相关文档并标记为删除 + documents = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_name__contains=video.title + ) + + for doc in documents: + doc.status = 'deleted' + doc.save() + + # 删除视频记录 + self.perform_destroy(video) + + return Response({ + "code": 200, + "message": "视频记录已删除,相关知识库文档已标记为删除", + "data": None + }) + + def _create_document(self, external_id, doc_data): + """调用外部API创建文档""" + try: + if not external_id: + logger.error("创建文档失败:知识库external_id为空") + return None + + # 在实际应用中,这里需要调用外部API创建文档 + # 模拟创建文档并返回document_id + document_id = str(uuid.uuid4()) + logger.info(f"模拟创建文档成功,document_id: {document_id}") + return document_id + except Exception as e: + logger.error(f"创建文档失败: {str(e)}") + return None + + @action(detail=True, methods=['post']) + def update_stats(self, request, pk=None): + """更新视频统计数据并同步到知识库""" + video = self.get_object() + + # 获取更新的统计数据 + stats = {} + for field in ['views_count', 'likes_count', 'comments_count', 'shares_count']: + if field in request.data: + stats[field] = request.data[field] + + if not stats: + return Response({ + "code": 400, + "message": "没有提供任何统计数据", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新视频统计数据 + for field, value in stats.items(): + setattr(video, field, value) + video.save() + + # 同步到知识库 + # 在实际应用中应该调用外部API更新文档内容 + platform_account = video.platform_account + operator = platform_account.operator + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_name__contains=video.title, + status='active' + ).first() + + if document: + logger.info(f"应当更新文档 {document.document_id} 的视频统计数据") + + return Response({ + "code": 200, + "message": "视频统计数据更新成功", + "data": { + "id": video.id, + "title": video.title, + "views_count": video.views_count, + "likes_count": video.likes_count, + "comments_count": video.comments_count, + "shares_count": video.shares_count + } + }) + + @action(detail=True, methods=['post']) + def publish(self, request, pk=None): + """发布视频并更新状态""" + video = self.get_object() + + # 检查视频状态 + if video.status not in ['draft', 'scheduled']: + return Response({ + "code": 400, + "message": f"当前视频状态为 {video.get_status_display()},无法发布", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取视频URL + video_url = request.data.get('video_url') + if not video_url: + return Response({ + "code": 400, + "message": "未提供视频URL", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 更新视频状态和URL + video.video_url = video_url + video.status = 'published' + video.publish_time = timezone.now() + video.save() + + # 同步到知识库 + # 在实际应用中应该调用外部API更新文档内容 + platform_account = video.platform_account + operator = platform_account.operator + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_name__contains=video.title, + status='active' + ).first() + + if document: + logger.info(f"应当更新文档 {document.document_id} 的视频发布状态") + + return Response({ + "code": 200, + "message": "视频已成功发布", + "data": { + "id": video.id, + "title": video.title, + "status": video.status, + "video_url": video.video_url, + "publish_time": video.publish_time + } + }) + + @action(detail=False, methods=['post']) + def upload_video(self, request): + """上传视频文件并创建视频记录""" + try: + # 获取上传的视频文件 + video_file = request.FILES.get('video_file') + if not video_file: + return Response({ + "code": 400, + "message": "未提供视频文件", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 获取平台账号ID + platform_account_id = request.data.get('platform_account') + if not platform_account_id: + return Response({ + "code": 400, + "message": "未提供平台账号ID", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + try: + platform_account = PlatformAccount.objects.get(id=platform_account_id) + except PlatformAccount.DoesNotExist: + return Response({ + "code": 404, + "message": f"未找到ID为{platform_account_id}的平台账号", + "data": None + }, status=status.HTTP_404_NOT_FOUND) + + # 创建保存视频的目录 + import os + from django.conf import settings + + # 确保文件保存目录存在 + media_root = getattr(settings, 'MEDIA_ROOT', os.path.join(settings.BASE_DIR, 'media')) + videos_dir = os.path.join(media_root, 'videos') + account_dir = os.path.join(videos_dir, f"{platform_account.platform_name}_{platform_account.account_name}") + + if not os.path.exists(videos_dir): + os.makedirs(videos_dir) + if not os.path.exists(account_dir): + os.makedirs(account_dir) + + # 生成唯一的文件名 + import time + timestamp = int(time.time()) + file_name = f"{timestamp}_{video_file.name}" + file_path = os.path.join(account_dir, file_name) + + # 保存视频文件 + with open(file_path, 'wb+') as destination: + for chunk in video_file.chunks(): + destination.write(chunk) + + # 创建视频记录 + video_data = { + 'platform_account': platform_account, + 'title': request.data.get('title', os.path.splitext(video_file.name)[0]), + 'description': request.data.get('description', ''), + 'local_path': file_path, + 'status': 'draft', + 'tags': request.data.get('tags', '') + } + + # 如果提供了计划发布时间,则设置状态为已排期 + scheduled_time = request.data.get('scheduled_time') + if scheduled_time: + from dateutil import parser + try: + parsed_time = parser.parse(scheduled_time) + video_data['scheduled_time'] = parsed_time + video_data['status'] = 'scheduled' + except Exception as e: + return Response({ + "code": 400, + "message": f"计划发布时间格式错误: {str(e)}", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建视频记录 + video = Video.objects.create(**video_data) + + # 添加到知识库 + self._add_to_knowledge_base(video, platform_account) + + # 如果是已排期状态,创建定时任务 + if video.status == 'scheduled': + self._create_publish_task(video) + + return Response({ + "code": 200, + "message": "视频上传成功", + "data": { + "id": video.id, + "title": video.title, + "status": video.get_status_display(), + "scheduled_time": video.scheduled_time + } + }, status=status.HTTP_201_CREATED) + + except Exception as e: + logger.error(f"视频上传失败: {str(e)}") + return Response({ + "code": 500, + "message": f"视频上传失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + def _add_to_knowledge_base(self, video, platform_account): + """将视频添加到知识库""" + # 获取关联的运营账号 + operator = platform_account.operator + + # 查找对应的知识库 + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base and knowledge_base.external_id: + # 创建视频文档 + document_data = { + "name": f"{video.title}_{platform_account.account_name}_视频信息", + "paragraphs": [ + { + "title": "视频基本信息", + "content": f""" + 标题: {video.title} + 平台: {platform_account.get_platform_name_display()} + 账号: {platform_account.account_name} + 状态: {video.get_status_display()} + 本地路径: {video.local_path} + 计划发布时间: {video.scheduled_time.strftime('%Y-%m-%d %H:%M:%S') if video.scheduled_time else '未设置'} + 视频描述: {video.description or '无'} + 标签: {video.tags or '无'} + 创建时间: {video.created_at.strftime('%Y-%m-%d %H:%M:%S')} + """, + "is_active": True + } + ] + } + + # 调用外部API创建文档 + document_id = self._create_document(knowledge_base.external_id, document_data) + + if document_id: + # 创建知识库文档记录 + KnowledgeBaseDocument.objects.create( + knowledge_base=knowledge_base, + document_id=document_id, + document_name=document_data["name"], + external_id=document_id, + uploader_name="系统" + ) + + def _create_publish_task(self, video): + """创建定时发布任务""" + try: + from django_celery_beat.models import PeriodicTask, CrontabSchedule + import json + from datetime import datetime + + scheduled_time = video.scheduled_time + + # 创建定时任务 + schedule, _ = CrontabSchedule.objects.get_or_create( + minute=scheduled_time.minute, + hour=scheduled_time.hour, + day_of_month=scheduled_time.day, + month_of_year=scheduled_time.month, + ) + + # 创建周期性任务 + task_name = f"Publish_Video_{video.id}_{datetime.now().timestamp()}" + PeriodicTask.objects.create( + name=task_name, + task='user_management.tasks.publish_scheduled_video', + crontab=schedule, + args=json.dumps([video.id]), + one_off=True, # 只执行一次 + start_time=scheduled_time + ) + + logger.info(f"已创建视频 {video.id} 的定时发布任务,计划发布时间: {scheduled_time}") + + except Exception as e: + logger.error(f"创建定时发布任务失败: {str(e)}") + # 记录错误但不中断流程 + + @action(detail=True, methods=['post']) + def manual_publish(self, request, pk=None): + """手动发布视频""" + video = self.get_object() + + # 检查视频状态是否允许发布 + if video.status not in ['draft', 'scheduled']: + return Response({ + "code": 400, + "message": f"当前视频状态为 {video.get_status_display()},无法发布", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 检查视频文件是否存在 + if not video.local_path or not os.path.exists(video.local_path): + return Response({ + "code": 400, + "message": "视频文件不存在,无法发布", + "data": None + }, status=status.HTTP_400_BAD_REQUEST) + + # 自动发布 - 不依赖Celery任务 + try: + # 模拟上传到平台 + platform_account = video.platform_account + platform_name = platform_account.platform_name + + # 创建模拟的视频URL和ID + video_url = f"https://example.com/{platform_name}/{video.id}" + video_id = f"VID_{video.id}" + + # 更新视频状态 + video.status = 'published' + video.publish_time = timezone.now() + video.video_url = video_url + video.video_id = video_id + video.save() + + logger.info(f"视频 {video.id} 已手动发布") + + # 更新知识库文档 + platform_account = video.platform_account + operator = platform_account.operator + + knowledge_base = KnowledgeBase.objects.filter( + name__contains=operator.real_name, + type='private' + ).first() + + if knowledge_base: + document = KnowledgeBaseDocument.objects.filter( + knowledge_base=knowledge_base, + document_name__contains=video.title, + status='active' + ).first() + + if document: + logger.info(f"应当更新文档 {document.document_id} 的视频发布状态") + + return Response({ + "code": 200, + "message": "视频发布成功", + "data": { + "id": video.id, + "title": video.title, + "status": "published", + "video_url": video_url, + "publish_time": video.publish_time.strftime("%Y-%m-%d %H:%M:%S") + } + }) + + except Exception as e: + logger.error(f"手动发布视频失败: {str(e)}") + return Response({ + "code": 500, + "message": f"发布失败: {str(e)}", + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/apps/permissions/models.py b/apps/permissions/models.py index b210180..87e329c 100644 --- a/apps/permissions/models.py +++ b/apps/permissions/models.py @@ -6,7 +6,7 @@ import uuid import logging from apps.accounts.models import User from apps.knowledge_base.models import KnowledgeBase -from apps.message.models import Notification +from apps.notification.models import Notification logger = logging.getLogger(__name__) diff --git a/daren_project/asgi.py b/daren_project/asgi.py index c97fc16..3f7eb80 100644 --- a/daren_project/asgi.py +++ b/daren_project/asgi.py @@ -8,19 +8,27 @@ https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/ """ import os +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'daren_project.settings') + + + from django.core.asgi import get_asgi_application from channels.routing import ProtocolTypeRouter, URLRouter from channels.auth import AuthMiddlewareStack -from apps.message.routing import websocket_urlpatterns # WebSocket 路由 + # WebSocket 路由 + +django_asgi_app = get_asgi_application() + +from apps.chat.routing import websocket_urlpatterns as chat_websocket_urlpatterns # WebSocket 路由 +from apps.notification.routing import websocket_urlpatterns as notification_websocket_urlpatterns # WebSocket 路由 -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'daren_project.settings') application = ProtocolTypeRouter({ - "http": get_asgi_application(), + "http": django_asgi_app, "websocket": AuthMiddlewareStack( URLRouter( - websocket_urlpatterns + chat_websocket_urlpatterns + notification_websocket_urlpatterns ) ), }) diff --git a/daren_project/settings.py b/daren_project/settings.py index 1cb8acd..f942e5b 100644 --- a/daren_project/settings.py +++ b/daren_project/settings.py @@ -46,10 +46,12 @@ INSTALLED_APPS = [ 'apps.knowledge_base', 'apps.chat', 'apps.permissions', - 'apps.message', + 'apps.notification', 'apps.gmail', 'apps.feishu', 'apps.common', + 'apps.brands', + 'apps.operation', ] MIDDLEWARE = [ @@ -143,14 +145,20 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' # REST Framework 配置 REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES': [ - # 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.TokenAuthentication', + # 'rest_framework.authentication.SessionAuthentication', ], 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated', + 'rest_framework.permissions.AllowAny', + ], + 'DEFAULT_PARSER_CLASSES': [ + 'rest_framework.parsers.JSONParser', + 'rest_framework.parsers.FormParser', + 'rest_framework.parsers.MultiPartParser' ], } + # Channels 配置(WebSocket) ASGI_APPLICATION = 'daren_project.asgi.application' CHANNEL_LAYERS = { @@ -169,3 +177,19 @@ API_BASE_URL = 'http://81.69.223.133:48329' SILICON_CLOUD_API_KEY = 'sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf' GMAIL_WEBHOOK_URL = 'https://27b3-180-159-100-165.ngrok-free.app/api/user/gmail/webhook/' APPLICATION_ID = 'd5d11efa-ea9a-11ef-9933-0242ac120006' + + +# 全局代理设置 +# 格式为 'http://主机名:端口号',例如:'http://127.0.0.1:7890' +# 此代理将应用于所有HTTP/HTTPS请求和Gmail API请求 +# 如果代理不可用,请将此值设为None或注释掉此行 +PROXY_URL = 'http://127.0.0.1:7890' + + +# Gmail Pub/Sub相关设置 +GOOGLE_CLOUD_PROJECT_ID = 'your-project-id' # 替换为您的Google Cloud项目ID +GMAIL_PUBSUB_TOPIC = 'projects/{project_id}/topics/gmail-notifications' +GMAIL_PUBSUB_SUBSCRIPTION = 'projects/{project_id}/subscriptions/gmail-notifications-sub' + +# 设置允许使用Google Pub/Sub的应用列表 +INSTALLED_APPS += ['google.cloud.pubsub'] \ No newline at end of file diff --git a/daren_project/urls.py b/daren_project/urls.py index 2b05251..4f46ab1 100644 --- a/daren_project/urls.py +++ b/daren_project/urls.py @@ -23,7 +23,9 @@ urlpatterns = [ path('api/knowledge-bases/', include('apps.knowledge_base.urls')), path('api/chat-history/', include('apps.chat.urls')), path('api/permissions/', include('apps.permissions.urls')), - path('api/message/', include('apps.message.urls')), - # path('api/gmail/', include('apps.gmail.urls')), + path('api/notification/', include('apps.notification.urls')), + path('api/gmail/', include('apps.gmail.urls')), # path('api/feishu/', include('apps.feishu.urls')), + path('api/', include('apps.brands.urls')), + path('api/operation/', include('apps.operation.urls')), ] \ No newline at end of file diff --git a/daren_project/wsgi.py b/daren_project/wsgi.py index 05ffecc..954a56a 100644 --- a/daren_project/wsgi.py +++ b/daren_project/wsgi.py @@ -8,9 +8,26 @@ https://docs.djangoproject.com/en/5.2/howto/deployment/wsgi/ """ import os +import django -from django.core.wsgi import get_wsgi_application - +# 首先设置 Django 设置模块 os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'daren_project.settings') +django.setup() # 添加这行来初始化 Django -application = get_wsgi_application() +# 然后再导入其他模块 +from django.core.asgi import get_asgi_application +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.auth import AuthMiddlewareStack +from channels.security.websocket import AllowedHostsOriginValidator +from apps.chat.routing import websocket_urlpatterns +from apps.common.middlewares import TokenAuthMiddleware + +# 使用TokenAuthMiddleware代替AuthMiddlewareStack +application = ProtocolTypeRouter({ + "http": get_asgi_application(), + "websocket": AllowedHostsOriginValidator( + TokenAuthMiddleware( + URLRouter(websocket_urlpatterns) + ) + ), +}) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ae8a16a..449bddb 100644 Binary files a/requirements.txt and b/requirements.txt differ