# TrainingPlatform_Django/training/views.py
# Snapshot: 194 lines, 7.7 KiB, Python; last changed 2025-06-03 16:10:11 +08:00.
from django.shortcuts import render
from django.views.decorators.csrf import csrf_exempt
from rest_framework.decorators import api_view
from django.http import JsonResponse
from rest_framework.response import Response
from .models import TrainingModel, TrainingEpochModel
import requests
from .config import FLASK_API
# Create your views here.
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
def create_training_task(request):
    """Create a training task record and dispatch it to the Flask training service.

    Expects a POST with form fields ``user``, ``projectName``, ``preModelName``,
    ``datasetName``, ``taskType``, ``epochs``, ``batchSize`` and ``imageSize``.
    A ``TrainingModel`` row is saved with status "初始化" and model size "0MB".
    Its id and parameters are then forwarded as JSON to
    ``{FLASK_API}/train_detect``.

    Returns:
        JsonResponse with ``success`` as the *string* 'True'/'False'; the
        string form is kept for client compatibility. Status codes:
          * 200: the record was saved and the Flask request was delivered.
          * 500: the record could not be saved. Flask is not called in this
            case, so no task is dispatched without a database id.
          * 502: the Flask service could not be reached.
          * 405: the request was not a POST.
    """
    if request.method != 'POST':
        return JsonResponse({'success': 'False', 'error': 'Invalid request method'}, status=405)

    user = request.POST.get('user')
    project_name = request.POST.get('projectName')
    pre_model_name = request.POST.get('preModelName')
    dataset_name = request.POST.get('datasetName')
    task_type = request.POST.get('taskType')
    epoch = request.POST.get('epochs')
    batch_size = request.POST.get('batchSize')
    image_size = request.POST.get('imageSize')

    training_task = TrainingModel(
        user=user,
        project_name=project_name,
        pre_model_name=pre_model_name,
        dataset_name=dataset_name,
        task_type=task_type,
        epoch=epoch,
        batch_size=batch_size,
        image_size=image_size,
        status="初始化",
        model_size="0MB",
    )
    try:
        training_task.save()
    except Exception:
        # Without a saved row there is no valid id to hand to the trainer,
        # so stop here instead of dispatching a task for id -1.
        return JsonResponse({'success': 'False', 'error': '训练任务提交失败'}, status=500)

    # Connectivity guard: the 5 s connect timeout makes an unreachable Flask
    # service fail fast instead of hanging this request. No read timeout is
    # set because how long /train_detect takes to answer is not visible here.
    try:
        requests.post(
            f'{FLASK_API}/train_detect',
            json={
                'id': training_task.id,
                'user': user,
                'projectName': project_name,
                'preModelName': pre_model_name,
                'datasetName': dataset_name,
                'taskType': task_type,
                'epoch': epoch,
                'batchSize': batch_size,
                'imageSize': image_size,
            },
            timeout=(5, None),
        )
    except requests.RequestException:
        return JsonResponse({'success': 'False', 'error': '训练任务提交失败'}, status=502)

    return JsonResponse({'success': 'True', 'message': '训练任务提交成功'})
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
def get_project_training_tasks(request):
    """List every training task a user has under one project.

    Query parameters: ``user`` and ``projectName``.

    Responds with ``{'success': 'True', 'tasks': [...]}``, where each task is
    the dict of field values produced by ``QuerySet.values()``. Any method
    other than GET gets HTTP 405.
    """
    if request.method != 'GET':
        return JsonResponse({'success': 'False', 'error': 'Invalid request method'}, status=405)

    query = request.GET
    owner = query.get('user')
    project = query.get('projectName')
    matching = TrainingModel.objects.filter(user=owner, project_name=project)
    rows = list(matching.values())
    return JsonResponse({'success': 'True', 'tasks': rows})
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
def delete_project_training_task(request):
    """Stop and delete one training task together with its epoch records.

    Query parameters: ``id`` (the ``TrainingModel`` id), ``user`` and
    ``projectName``.

    Ownership: the task is only touched when all three parameters match an
    existing row. A request therefore cannot stop, or wipe the epoch history
    of, a task that belongs to another user or project.

    Order of operations:
      1. Ask the Flask service to stop the Ray process for this task
         (``/stop_training_ray_task``).
      2. Delete the epoch rows and the task row in one transaction, so a
         failure cannot leave a half-deleted task.
    If the Flask service cannot be reached, nothing is deleted and a failure
    message is returned, so the front end gets an answer instead of hanging.

    Returns JSON ``{'success': bool, 'message': str}``.
    """
    from django.db import transaction

    if request.method != 'GET':
        return JsonResponse({'success': False, 'message': '请求异常'})

    task_id = request.GET.get('id')
    user = request.GET.get('user')
    project_name = request.GET.get('projectName')

    tasks = TrainingModel.objects.filter(id=task_id, user=user, project_name=project_name)
    if not tasks.exists():
        return JsonResponse({'success': False, 'message': '删除失败'})

    # Connectivity guard: fail fast when Flask is unreachable instead of
    # leaving the front-end click unanswered.
    try:
        requests.get(
            f'{FLASK_API}/stop_training_ray_task',
            json={'training_id': task_id},
            timeout=(5, None),
        )
    except requests.RequestException:
        return JsonResponse({'success': False, 'message': '删除失败'})

    with transaction.atomic():
        # Ownership was verified above, so filtering epochs by id is safe here.
        TrainingEpochModel.objects.filter(training_model__id=task_id).delete()
        task_deleted, _ = tasks.delete()

    if task_deleted > 0:
        return JsonResponse({'success': True, 'message': '删除成功'})
    return JsonResponse({'success': False, 'message': '删除失败'})
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
def update_training_status(request):
    """Set the ``status`` field of one ``TrainingModel`` row.

    Query parameters: ``id`` (the ``TrainingModel`` id) and ``status`` (the
    new status value).

    Returns JSON ``{'success': bool, 'message': str}``:
      * HTTP 400 when either parameter is missing. Without this check a
        missing ``status`` would be written to the row as NULL.
      * HTTP 405 for non-GET requests. The original fell through and
        returned ``None``, which is not a valid Django response.
    """
    if request.method != 'GET':
        return JsonResponse({'success': False, 'message': 'Invalid request method'}, status=405)

    training_id = request.GET.get('id')
    new_status = request.GET.get('status')
    if not training_id or new_status is None:
        return JsonResponse({'success': False, 'message': '缺少 id 或 status'}, status=400)

    # update() returns the number of rows matched; 0 means no such id.
    updated_count = TrainingModel.objects.filter(id=training_id).update(status=new_status)
    if updated_count > 0:
        return JsonResponse({'success': True, 'message': '更新成功'})
    return JsonResponse({'success': False, 'message': '未找到记录或更新失败'})
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
@api_view(['POST'])
def insert_training_epoch_status(request):
    """Store the metrics of one finished epoch for a training task.

    Reads ``training_id``, ``epoch_number``, ``map_50``, ``map_95``,
    ``precision`` and ``recall`` from the DRF-parsed ``request.data``. The
    values are saved as a new ``TrainingEpochModel`` linked to the
    ``TrainingModel`` with that id.

    Returns JSON:
      * a success message when the row is saved;
      * HTTP 404 when no ``TrainingModel`` has the given id;
      * HTTP 500 carrying the exception text for any other failure.
    """
    payload = request.data
    parent_id = payload.get('training_id')
    fields = {
        'epoch_number': payload.get('epoch_number'),
        'mAP50': payload.get('map_50'),
        'mAP95': payload.get('map_95'),
        'precision': payload.get('precision'),
        'recall': payload.get('recall'),
    }

    try:
        parent = TrainingModel.objects.get(id=parent_id)
        epoch_row = TrainingEpochModel(training_model=parent, **fields)
        epoch_row.save()
        return JsonResponse({'success': True, 'message': '训练轮次数据插入成功'})
    except TrainingModel.DoesNotExist:
        return JsonResponse({'success': False, 'message': '指定的 TrainingModel 不存在'}, status=404)
    except Exception as err:
        return JsonResponse({'success': False, 'message': str(err)}, status=500)
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
@api_view(['GET'])
def get_training_epoch_data(request):
    """Return every recorded epoch of one training task, ordered by epoch number.

    Query parameter: ``training_id`` (the ``TrainingModel`` id).

    Responds with ``{'success': True, 'epochs': [...]}``. Each entry holds
    ``id``, ``epoch_number``, ``mAP50``, ``mAP95``, ``precision``,
    ``recall`` and ``create_time``. Errors:
      * HTTP 400 when ``training_id`` is missing;
      * HTTP 500 carrying the exception text on any query failure.
    """
    training_id = request.GET.get('training_id')
    if not training_id:
        return JsonResponse({'success': False, 'message': '缺少 training_id'}, status=400)
    try:
        epochs = (
            TrainingEpochModel.objects
            .filter(training_model__id=training_id)
            .order_by('epoch_number')
        )
        # Materialize inside the try so query errors are caught here.
        # (The debug print of the full result set was removed.)
        epoch_data = list(
            epochs.values('id', 'epoch_number', 'mAP50', 'mAP95', 'precision', 'recall', 'create_time')
        )
        return JsonResponse({'success': True, 'epochs': epoch_data})
    except Exception as e:
        return JsonResponse({'success': False, 'message': str(e)}, status=500)
@csrf_exempt  # Development only; use a safer CSRF strategy in production.
@api_view(['GET'])
def get_curr_epoch_data(request):
    """Return the most recent epoch record of one training task.

    Query parameter: ``training_id``.

    The lookup is delegated to ``TrainingEpochModel.get_latest_epoch``. It is
    presumably a model/class helper that returns the newest epoch instance,
    or a falsy value when none exists (NOTE(review): confirm in models.py).

    Responds with DRF ``Response`` objects:
      * ``{'success': True, 'data': {...}}`` when a record is found;
      * HTTP 404 when no record is found;
      * HTTP 400 when ``training_id`` is missing, matching
        ``get_training_epoch_data``.
    """
    training_id = request.GET.get('training_id')
    if not training_id:
        return Response({'success': False, 'error': '缺少 training_id'}, status=400)

    latest_epoch_data = TrainingEpochModel.get_latest_epoch(training_id)
    if not latest_epoch_data:
        return Response({'success': False, 'error': 'No data found for the given training_id.'}, status=404)

    # Serialize the model instance field by field into a plain dict.
    data = {
        'id': latest_epoch_data.id,
        'epoch_number': latest_epoch_data.epoch_number,
        'mAP50': latest_epoch_data.mAP50,
        'mAP95': latest_epoch_data.mAP95,
        'precision': latest_epoch_data.precision,
        'recall': latest_epoch_data.recall,
        'create_time': latest_epoch_data.create_time,
    }
    return Response({'success': True, 'data': data})