52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
import os
|
||
from config import BUCKET_NAME, LOCAL_DATASET_DIR
|
||
from ultralytics import YOLO
|
||
from tools import validate_and_update_yaml_fields, convert_xml_to_txt
|
||
from minio_tools import download_minio_folder
|
||
from dataset_action import update_database_status
|
||
import ray
|
||
from yolov5.train import really_training
|
||
|
||
|
||
|
||
# NOTE(review): per the Ray docs, max_calls=1 makes the worker process exit
# after a single invocation (commonly used to reclaim GPU memory), and
# num_gpus=0.2 lets several of these tasks share one GPU — confirm that this
# is the intended scheduling.
@ray.remote(max_calls=1, num_gpus=0.2)
def training(training_id, minio_folder_name, coco_path, pre_model_path, images_path, labels_path, epochs, img_size, batch_size):
    """Prepare the dataset for a training job and run the training.

    The job goes through four steps; a failure in any of the first three
    marks the job as "初始化异常" in the database and aborts:

    1. Download the dataset from MinIO unless ``images_path`` already exists.
    2. Validate the dataset YAML and point it at the real train/val folders.
    3. Convert the XML annotations to TXT label files.
    4. Run the training itself.

    Args:
        training_id: Identifier of the job; used for every status update and
            forwarded to ``really_training``.
        minio_folder_name: Folder inside ``BUCKET_NAME`` to download.
        coco_path: Path of the dataset YAML file that is validated/updated
            and then handed to the trainer.
        pre_model_path: Pre-trained model path forwarded to the trainer.
        images_path: Local image directory; its ``train`` and ``val``
            sub-directories are written into the YAML file.
        labels_path: Annotation directory passed to ``convert_xml_to_txt``.
        epochs: Forwarded to ``really_training``.
        img_size: Forwarded to ``really_training``.
        batch_size: Forwarded to ``really_training``.

    Returns:
        ``False`` when any preparation step fails, otherwise whatever
        ``really_training`` returns.
    """
    # 1. Check whether the dataset has already been downloaded.
    if os.path.isdir(images_path):
        print("数据集已存在。")
    else:
        print("数据集不存在,开始下载。")
        # exist_ok avoids the check-then-create race when several tasks
        # start on the same node at the same time.
        os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
        is_load = download_minio_folder(BUCKET_NAME, minio_folder_name, LOCAL_DATASET_DIR)
        if not is_load:
            update_database_status(training_id, "初始化异常")
            return False

    # 2. Update the YAML file so that its paths point at the real training
    #    directories; the class categories are obtained as a by-product.
    train_path = os.path.join(images_path, "train")
    val_path = os.path.join(images_path, "val")
    is_val, categories = validate_and_update_yaml_fields(coco_path, train_path, val_path)
    if not is_val:
        update_database_status(training_id, "初始化异常")
        return False

    # 3. Check that the dataset is complete (convert XML annotations to TXT).
    is_convert = convert_xml_to_txt(categories, labels_path)
    if not is_convert:
        update_database_status(training_id, "初始化异常")
        return False

    # 4. Start training.
    update_database_status(training_id, "训练中")
    # NOTE(review): if really_training raises, the status stays "训练中";
    # decide which failure status the database expects and set it here.
    result = really_training(training_id, pre_model_path, coco_path, epochs, img_size, batch_size)
    update_database_status(training_id, "训练完成")

    return result
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Nothing to run when executed directly: `training` is meant to be
    # submitted to Ray by another module.
    pass
|