# apps/chat/models.py
|
||
from django.db import models
|
||
from django.utils import timezone
|
||
import uuid
|
||
from itertools import count
|
||
from apps.accounts.models import User
|
||
from apps.knowledge_base.models import KnowledgeBase
|
||
|
||
class ChatHistory(models.Model):
    """A single chat message; messages are grouped into conversations by ``conversation_id``.

    Messages are soft-deleted via ``is_deleted`` rather than removed, and the
    query helpers on this class exclude soft-deleted rows.
    """

    ROLE_CHOICES = [
        ('user', '用户'),
        ('assistant', 'AI助手'),
        ('system', '系统')
    ]

    user = models.ForeignKey(User, on_delete=models.CASCADE)
    # Association with the primary knowledge base is kept even for
    # conversations that span a combination of knowledge bases.
    knowledge_base = models.ForeignKey(KnowledgeBase, on_delete=models.CASCADE)
    # Identifies the conversation; for knowledge-base-combination chats it is
    # derived from the combination (see get_conversations_by_knowledge_bases).
    conversation_id = models.CharField(max_length=100, db_index=True)
    # Conversation title.
    title = models.CharField(max_length=100, null=True, blank=True, default='New chat', help_text="对话标题")
    parent_id = models.CharField(max_length=100, null=True, blank=True)
    role = models.CharField(max_length=20, choices=ROLE_CHOICES)
    content = models.TextField()
    tokens = models.IntegerField(default=0, help_text="消息token数")
    # Extended metadata, used to store the knowledge-base combination info
    # (expected shape is described in help_text).
    metadata = models.JSONField(default=dict, blank=True, help_text="""
    {
        'model_id': 'xxx',
        'dataset_id_list': ['id1', 'id2', ...],
        'dataset_external_id_list': ['ext1', 'ext2', ...],
        'primary_knowledge_base': 'id1'
    }
    """)
    created_at = models.DateTimeField(auto_now_add=True)
    is_deleted = models.BooleanField(default=False)

    class Meta:
        ordering = ['created_at']
        indexes = [
            models.Index(fields=['conversation_id', 'created_at']),
            models.Index(fields=['user', 'created_at']),
            # Supports conversation lookups that also filter on is_deleted
            # (knowledge-base-combination queries).
            models.Index(fields=['conversation_id', 'is_deleted']),
        ]

    def __str__(self):
        return f"{self.user.username} - {self.knowledge_base.name} - {self.created_at}"

    @classmethod
    def get_conversation(cls, conversation_id):
        """Return all non-deleted messages of a conversation, oldest first."""
        return cls.objects.filter(
            conversation_id=conversation_id,
            is_deleted=False
        ).order_by('created_at')

    @classmethod
    def get_conversations_by_knowledge_bases(cls, dataset_ids, user):
        """Return ``user``'s non-deleted messages for a knowledge-base combination.

        The conversation id is a UUIDv5 (DNS namespace) over the sorted,
        '-'-joined dataset ids, so the same set of ids maps to the same
        conversation regardless of the order they are passed in.

        Args:
            dataset_ids: Iterable of knowledge-base ids. Each id is converted
                with ``str()`` first, so UUID or int ids are accepted as well as
                strings (string input produces the same id as before).
            user: The user whose messages are returned.

        Returns:
            A QuerySet ordered by ``created_at`` ascending.
        """
        # Normalise to str before sorting: str.join rejects non-str items, and
        # sorting the string forms keeps the result identical to passing strings.
        sorted_kb_ids = sorted(str(kb_id) for kb_id in dataset_ids)
        conversation_id = str(uuid.uuid5(
            uuid.NAMESPACE_DNS,
            '-'.join(sorted_kb_ids)
        ))

        return cls.objects.filter(
            conversation_id=conversation_id,
            user=user,
            is_deleted=False
        ).order_by('created_at')

    @classmethod
    def get_knowledge_base_combinations(cls, user):
        """Summarise each of ``user``'s conversations, most recently active first.

        Returns:
            A values() QuerySet of dicts with keys ``conversation_id``,
            ``last_message`` (latest ``created_at``), ``message_count`` and
            ``metadata``, ordered by ``last_message`` descending.
        """
        return cls.objects.filter(
            user=user,
            is_deleted=False
        ).values('conversation_id').annotate(
            # Aggregates must be Django expressions; the builtin max() and
            # itertools.count() are not and make annotate() fail.
            last_message=models.Max('created_at'),
            message_count=models.Count('id')
        ).values(
            'conversation_id',
            'last_message',
            'message_count',
            # NOTE(review): a plain field selected after annotate() is likely
            # added to GROUP BY by Django, so messages of one conversation with
            # differing metadata may produce separate rows — confirm intended.
            'metadata'
        ).order_by('-last_message')

    def get_knowledge_bases(self):
        """Return the knowledge bases associated with this message.

        Uses ``metadata['dataset_id_list']`` when present; otherwise falls back
        to the message's own ``knowledge_base``.
        """
        if self.metadata and 'dataset_id_list' in self.metadata:
            return KnowledgeBase.objects.filter(
                id__in=self.metadata['dataset_id_list']
            )
        # Read the FK column directly instead of loading the related object
        # just to obtain its primary key.
        return KnowledgeBase.objects.filter(id=self.knowledge_base_id)

    def soft_delete(self):
        """Mark the message as deleted and save it (the row is kept)."""
        self.is_deleted = True
        self.save()

    def to_dict(self):
        """Serialise the message, including its associated knowledge bases, to a dict."""
        return {
            'id': str(self.id),
            'conversation_id': self.conversation_id,
            'role': self.role,
            'content': self.content,
            # created_at is populated on first save (auto_now_add), so an
            # unsaved instance still has None here.
            'created_at': self.created_at.strftime('%Y-%m-%d %H:%M:%S') if self.created_at else None,
            'metadata': self.metadata,
            'knowledge_bases': [
                {
                    'id': str(kb.id),
                    'name': kb.name,
                    'type': kb.type
                } for kb in self.get_knowledge_bases()
            ]
        }