194 lines
7.7 KiB
Python
194 lines
7.7 KiB
Python
|
from django.shortcuts import render
|
||
|
from django.views.decorators.csrf import csrf_exempt
|
||
|
from rest_framework.decorators import api_view
|
||
|
from django.http import JsonResponse
|
||
|
from rest_framework.response import Response
|
||
|
from .models import TrainingModel, TrainingEpochModel
|
||
|
import requests
|
||
|
from .config import FLASK_API
|
||
|
# Create your views here.
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
def create_training_task(request):
|
||
|
if request.method == 'POST':
|
||
|
user = request.POST.get('user')
|
||
|
project_name = request.POST.get('projectName')
|
||
|
pre_model_name = request.POST.get('preModelName')
|
||
|
dataset_name = request.POST.get('datasetName')
|
||
|
task_type = request.POST.get('taskType')
|
||
|
epoch = request.POST.get('epochs')
|
||
|
batch_size = request.POST.get('batchSize')
|
||
|
image_size = request.POST.get('imageSize')
|
||
|
id = -1
|
||
|
|
||
|
|
||
|
training_task = TrainingModel(
|
||
|
user=user,
|
||
|
project_name=project_name,
|
||
|
pre_model_name=pre_model_name,
|
||
|
dataset_name=dataset_name,
|
||
|
task_type=task_type,
|
||
|
epoch=epoch,
|
||
|
batch_size=batch_size,
|
||
|
image_size=image_size,
|
||
|
status="初始化",
|
||
|
model_size="0MB"
|
||
|
)
|
||
|
|
||
|
try:
|
||
|
training_task.save()
|
||
|
id = training_task.id
|
||
|
success = True
|
||
|
except Exception as e:
|
||
|
success = False
|
||
|
error_message = "训练任务提交失败"
|
||
|
|
||
|
# 调用Flask创建训练任务
|
||
|
####### 记得做一个连通性测试,防止卡死
|
||
|
response = requests.post(f'{FLASK_API}/train_detect', json={
|
||
|
'id': id,
|
||
|
'user': user,
|
||
|
'projectName': project_name,
|
||
|
'preModelName': pre_model_name,
|
||
|
'datasetName': dataset_name,
|
||
|
'taskType': task_type,
|
||
|
'epoch': epoch,
|
||
|
'batchSize': batch_size,
|
||
|
'imageSize': image_size
|
||
|
})
|
||
|
|
||
|
|
||
|
if success:
|
||
|
return JsonResponse({'success': 'True', 'message': '训练任务提交成功'})
|
||
|
else:
|
||
|
return JsonResponse({'success': 'False', 'error': error_message}, status=500)
|
||
|
else:
|
||
|
return JsonResponse({'success': 'False', 'error': 'Invalid request method'}, status=405)
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
def get_project_training_tasks(request):
|
||
|
if request.method == 'GET':
|
||
|
user = request.GET.get('user')
|
||
|
project_name = request.GET.get('projectName')
|
||
|
tasks = TrainingModel.objects.filter(user=user, project_name=project_name)
|
||
|
|
||
|
return JsonResponse({'success': 'True', 'tasks': list(tasks.values())})
|
||
|
else:
|
||
|
return JsonResponse({'success': 'False', 'error': 'Invalid request method'}, status=405)
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
def delete_project_training_task(request):
|
||
|
if request.method == 'GET':
|
||
|
id = request.GET.get('id')
|
||
|
user = request.GET.get('user')
|
||
|
project_name = request.GET.get('projectName')
|
||
|
|
||
|
# 调用Flask停止当前的ray进程
|
||
|
####### 记得做一个连通性测试,防止卡死(也就是前台点击了没有反应)
|
||
|
response = requests.get(f'{FLASK_API}/stop_training_ray_task', json={
|
||
|
'training_id': id
|
||
|
})
|
||
|
|
||
|
# 删除与训练任务相关的所有训练记录
|
||
|
TrainingEpochModel.objects.filter(training_model__id=id).delete()
|
||
|
|
||
|
task_deleted, _ = TrainingModel.objects.filter(id=id, user=user, project_name=project_name).delete()
|
||
|
|
||
|
if task_deleted > 0:
|
||
|
return JsonResponse({'success': True, 'message': '删除成功'})
|
||
|
else:
|
||
|
return JsonResponse({'success': False, 'message': '删除失败'})
|
||
|
else:
|
||
|
return JsonResponse({'success': False, 'message': '请求异常'})
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
def update_training_status(request):
|
||
|
if request.method == 'GET':
|
||
|
training_id = request.GET.get('id')
|
||
|
# 假设您还想更新 status 字段
|
||
|
new_status = request.GET.get('status') # 从请求中获取新的状态
|
||
|
|
||
|
# 更新 TrainingModel 中的记录
|
||
|
updated_count = TrainingModel.objects.filter(id=training_id).update(status=new_status)
|
||
|
|
||
|
if updated_count > 0:
|
||
|
return JsonResponse({'success': True, 'message': '更新成功'})
|
||
|
else:
|
||
|
return JsonResponse({'success': False, 'message': '未找到记录或更新失败'})
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
@api_view(['POST'])
|
||
|
def insert_training_epoch_status(request):
|
||
|
id = request.data.get('training_id') # 获取 TrainingModel 的 ID
|
||
|
epoch_number = request.data.get('epoch_number')
|
||
|
map_50 = request.data.get('map_50')
|
||
|
map_95 = request.data.get('map_95')
|
||
|
precision = request.data.get('precision')
|
||
|
recall = request.data.get('recall')
|
||
|
|
||
|
# 查找对应的 TrainingModel 实例
|
||
|
try:
|
||
|
training_model = TrainingModel.objects.get(id=id)
|
||
|
|
||
|
# 创建新的 TrainingEpochModel 实例
|
||
|
new_epoch = TrainingEpochModel(
|
||
|
training_model=training_model, # 关联到 TrainingModel
|
||
|
epoch_number=epoch_number,
|
||
|
mAP50=map_50,
|
||
|
mAP95=map_95,
|
||
|
precision=precision,
|
||
|
recall=recall
|
||
|
)
|
||
|
|
||
|
# 保存到数据库
|
||
|
new_epoch.save()
|
||
|
return JsonResponse({'success': True, 'message': '训练轮次数据插入成功'})
|
||
|
|
||
|
except TrainingModel.DoesNotExist:
|
||
|
return JsonResponse({'success': False, 'message': '指定的 TrainingModel 不存在'}, status=404)
|
||
|
except Exception as e:
|
||
|
return JsonResponse({'success': False, 'message': str(e)}, status=500)
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
@api_view(['GET'])
|
||
|
def get_training_epoch_data(request):
|
||
|
training_id = request.GET.get('training_id') # 获取请求中的 training_id
|
||
|
|
||
|
if not training_id:
|
||
|
return JsonResponse({'success': False, 'message': '缺少 training_id'}, status=400)
|
||
|
|
||
|
try:
|
||
|
# 查找所有与 training_id 相关的 TrainingEpochModel 实例,并按 epoch_number 排序
|
||
|
epochs = TrainingEpochModel.objects.filter(training_model__id=training_id).order_by('epoch_number')
|
||
|
|
||
|
# 将查询结果转换为字典列表
|
||
|
epoch_data = list(epochs.values('id', 'epoch_number', 'mAP50', 'mAP95', 'precision', 'recall', 'create_time'))
|
||
|
print(epoch_data)
|
||
|
|
||
|
return JsonResponse({'success': True, 'epochs': epoch_data})
|
||
|
except Exception as e:
|
||
|
return JsonResponse({'success': False, 'message': str(e)}, status=500)
|
||
|
|
||
|
@csrf_exempt # 仅在开发时使用,生产环境中请使用更安全的方式
|
||
|
@api_view(['GET'])
|
||
|
def get_curr_epoch_data(request):
|
||
|
training_id = request.GET.get('training_id')
|
||
|
print(training_id)
|
||
|
|
||
|
# 获取当前 training_id 最新的 epoch 数据
|
||
|
latest_epoch_data = TrainingEpochModel.get_latest_epoch(training_id)
|
||
|
|
||
|
if latest_epoch_data:
|
||
|
# 将数据序列化为字典形式
|
||
|
data = {
|
||
|
'id': latest_epoch_data.id,
|
||
|
'epoch_number': latest_epoch_data.epoch_number,
|
||
|
'mAP50': latest_epoch_data.mAP50,
|
||
|
'mAP95': latest_epoch_data.mAP95,
|
||
|
'precision': latest_epoch_data.precision,
|
||
|
'recall': latest_epoch_data.recall,
|
||
|
'create_time': latest_epoch_data.create_time,
|
||
|
}
|
||
|
return Response({'success': True, 'data': data})
|
||
|
else:
|
||
|
return Response({'success': False, 'error': 'No data found for the given training_id.'}, status=404)
|
||
|
|