TrainingPlatform_Flask/train_api.py
2025-06-03 16:14:17 +08:00

52 lines
1.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from config import BUCKET_NAME, LOCAL_DATASET_DIR
from ultralytics import YOLO
from tools import validate_and_update_yaml_fields, convert_xml_to_txt
from minio_tools import download_minio_folder
from dataset_action import update_database_status
import ray
from yolov5.train import really_training
# Job status values written via update_database_status.
_STATUS_INIT_ERROR = "初始化异常"  # initialization (download/YAML/label conversion) failed
_STATUS_TRAINING = "训练中"  # training in progress
_STATUS_DONE = "训练完成"  # training finished
# NOTE(review): new status value for a crash inside really_training; confirm the
# database / UI accept it (previously such a job stayed at _STATUS_TRAINING forever).
_STATUS_TRAIN_ERROR = "训练异常"


def _prepare_dataset(minio_folder_name, coco_path, images_path, labels_path):
    """Download (if needed), configure and convert one dataset.

    Returns True when every preparation step reports success, False as soon
    as one of them reports failure. Exceptions from the helpers propagate.
    """
    # 1. Download the dataset unless its image directory is already present.
    # NOTE(review): only the presence of images_path is checked; if an earlier
    # download left a partial images_path behind, it would be reused as-is.
    if os.path.isdir(images_path):
        print("数据集已存在。")
    else:
        print("数据集不存在,开始下载。")
        # exist_ok=True avoids a FileExistsError race when several Ray tasks
        # create the shared download root at the same time.
        os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
        if not download_minio_folder(BUCKET_NAME, minio_folder_name, LOCAL_DATASET_DIR):
            return False

    # 2. Rewrite the YAML's train/val paths to the local split directories;
    # the same call also returns the category list needed in step 3.
    train_path = os.path.join(images_path, "train")
    val_path = os.path.join(images_path, "val")
    is_val, categories = validate_and_update_yaml_fields(coco_path, train_path, val_path)
    if not is_val:
        return False

    # 3. Check the annotations and convert the XML label files to TXT.
    return bool(convert_xml_to_txt(categories, labels_path))


# max_calls=1: Ray retires the worker process after a single call (presumably so
# GPU memory held by training is released). num_gpus=0.2 lets Ray schedule up to
# five of these tasks per GPU.
@ray.remote(max_calls=1, num_gpus=0.2)
def training(training_id, minio_folder_name, coco_path, pre_model_path, images_path, labels_path, epochs, img_size, batch_size):
    """Prepare a dataset and run one training job, recording status in the DB.

    Args:
        training_id: Job identifier passed to ``update_database_status`` and
            ``really_training``.
        minio_folder_name: Folder in ``BUCKET_NAME`` to download when the
            dataset is not yet available locally.
        coco_path: Path of the dataset YAML; its train/val paths are rewritten
            and it is passed on to ``really_training``.
        pre_model_path: Pretrained weights forwarded to ``really_training``.
        images_path: Local image root; its ``train`` and ``val`` subdirectories
            are written into the YAML.
        labels_path: Label directory passed to ``convert_xml_to_txt``.
        epochs, img_size, batch_size: Forwarded to ``really_training``.

    Returns:
        False if a preparation step reports failure (status set to
        "初始化异常"); otherwise whatever ``really_training`` returns.

    Raises:
        Any exception from preparation or training is re-raised unchanged
        after the job status is updated ("初始化异常" or "训练异常"), so the
        job is never left in a stale state.
    """
    # Steps 1-3: make the dataset ready for training.
    try:
        prepared = _prepare_dataset(minio_folder_name, coco_path, images_path, labels_path)
    except Exception:
        update_database_status(training_id, _STATUS_INIT_ERROR)
        raise
    if not prepared:
        update_database_status(training_id, _STATUS_INIT_ERROR)
        return False

    # Step 4: train.
    update_database_status(training_id, _STATUS_TRAINING)
    try:
        result = really_training(training_id, pre_model_path, coco_path, epochs, img_size, batch_size)
    except Exception:
        update_database_status(training_id, _STATUS_TRAIN_ERROR)
        raise
    # NOTE(review): the job is marked complete whatever `result` is; confirm
    # whether really_training signals failure through its return value.
    update_database_status(training_id, _STATUS_DONE)
    return result
if __name__ == "__main__":
    # Intentionally empty: this module only defines the Ray task, which is
    # launched by other code rather than by running this file directly.
    ...