TrainingPlatform_Flask/train_api.py

52 lines
1.9 KiB
Python
Raw Normal View History

2025-06-03 16:14:17 +08:00
import os
from config import BUCKET_NAME, LOCAL_DATASET_DIR
from ultralytics import YOLO
from tools import validate_and_update_yaml_fields, convert_xml_to_txt
from minio_tools import download_minio_folder
from dataset_action import update_database_status
import ray
from yolov5.train import really_training
# Ray task options: each invocation reserves 0.2 of a GPU, and max_calls=1
# limits a worker process to a single execution of this task before it exits
# (see the Ray `ray.remote` documentation).
@ray.remote(max_calls=1, num_gpus=0.2)
def training(training_id, minio_folder_name, coco_path, pre_model_path, images_path, labels_path, epochs, img_size, batch_size):
    """Prepare a dataset and run one training job as a Ray remote task.

    The task runs four steps in order and records progress through
    ``update_database_status``:

    1. Download the dataset from MinIO if ``images_path`` is not already a
       local directory.
    2. Rewrite the dataset YAML so it points at the real train/val paths,
       and read the category list from it.
    3. Convert the XML label files under ``labels_path`` to TXT.
    4. Run the actual training via ``really_training``.

    Args:
        training_id: Identifier passed to ``update_database_status`` and
            ``really_training``.
        minio_folder_name: Folder inside ``BUCKET_NAME`` to download.
        coco_path: Path of the dataset YAML file that is validated/updated
            and later handed to ``really_training``.
        pre_model_path: Pre-trained model path forwarded to
            ``really_training``.
        images_path: Local images directory; its ``train`` and ``val``
            sub-directories are written into the YAML.
        labels_path: Local labels directory passed to ``convert_xml_to_txt``.
        epochs: Forwarded to ``really_training``.
        img_size: Forwarded to ``really_training``.
        batch_size: Forwarded to ``really_training``.

    Returns:
        ``False`` when any preparation step fails (the status is then set to
        "初始化异常"); otherwise whatever ``really_training`` returns.
    """
    # 1. Check whether the dataset has already been downloaded.
    # os.path.isdir() is already False for a missing path, so a separate
    # exists() check is unnecessary.
    if os.path.isdir(images_path):
        print("数据集已存在。")
    else:
        print("数据集不存在,开始下载。")
        # exist_ok=True avoids the check-then-create race: several of these
        # tasks can run concurrently and would otherwise hit FileExistsError.
        os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
        is_load = download_minio_folder(BUCKET_NAME, minio_folder_name, LOCAL_DATASET_DIR)
        if not is_load:
            update_database_status(training_id, "初始化异常")
            return False

    # 2. Update the YAML file, replacing its paths with the actual training
    #    paths. The category list is obtained from the same call.
    train_path = os.path.join(images_path, "train")
    val_path = os.path.join(images_path, "val")
    is_val, categories = validate_and_update_yaml_fields(coco_path, train_path, val_path)
    if not is_val:
        update_database_status(training_id, "初始化异常")
        return False

    # 3. Check that the dataset and images are complete, and convert the XML
    #    label files to TXT.
    is_convert = convert_xml_to_txt(categories, labels_path)
    if not is_convert:
        update_database_status(training_id, "初始化异常")
        return False

    # 4. Start training.
    update_database_status(training_id, "训练中")
    # NOTE(review): if really_training raises, the status stays at "训练中",
    # and the result is not inspected before "训练完成" is written. Confirm
    # whether a dedicated failure status should be recorded here.
    result = really_training(training_id, pre_model_path, coco_path, epochs, img_size, batch_size)
    update_database_status(training_id, "训练完成")
    return result
# Script entry point: intentionally does nothing. `training` is defined as a
# Ray remote task, so this module is not expected to do any work when
# executed directly.
if __name__ == "__main__":
    ...