88 lines
4.0 KiB
Python
88 lines
4.0 KiB
Python
from flask import Flask, request, jsonify, send_from_directory, abort
|
|
from flask_cors import CORS
|
|
import ray
|
|
import os
|
|
import zipfile
|
|
import shutil
|
|
from config import BUCKET_NAME, LOCAL_DATASET_DIR, LOCAL_PRE_MODEL_DIR
|
|
from minio_tools import delete_temp_dir, download_minio_folder
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from train_api import training
|
|
from tools import validate_and_update_yaml_fields
|
|
from ray_manage import RayTaskManage
|
|
# Initialise Ray at import time. ignore_reinit_error suppresses the error Ray
# would otherwise raise if ray.init() is called a second time.
ray.init(ignore_reinit_error=True)

app = Flask(__name__)
# Allow requests from every origin and enable credential support.
# NOTE(review): a wildcard origin combined with supports_credentials is very
# permissive; consider restricting "origins" to the known front-end hosts.
CORS(app, resources={r"/*": {"origins": "*", "supports_credentials": True}})

executor = ThreadPoolExecutor(max_workers=10)
ray_manager = RayTaskManage()

# Path of the image folder. The default is the original hard-coded location;
# set the IMAGE_FOLDER environment variable to point at your own folder
# instead of editing this file.
IMAGE_FOLDER = os.environ.get('IMAGE_FOLDER', '/Users/liuzizhen/Projects/AIData/')
app.config['UPLOAD_FOLDER'] = IMAGE_FOLDER
|
|
|
|
def _is_within_directory(root, candidate):
    """Return True when *candidate* lies strictly inside *root*.

    The comparison is lexical (``os.path.abspath``), so ``..`` segments and
    absolute components are normalised but symlinks are not resolved.
    """
    root = os.path.abspath(root)
    candidate = os.path.abspath(candidate)
    # os.path.join(root, '') appends exactly one trailing separator, which
    # prevents "/data/sets-evil" from matching the root "/data/sets".
    return candidate != root and candidate.startswith(os.path.join(root, ''))


@app.route('/train_detect', methods=['POST'])
def train_model():
    """Validate a training request and hand it to Ray.

    Expects a JSON object with the keys ``id``, ``user``, ``datasetName``,
    ``taskType`` (``'Detect'``, ``'Segment'`` or ``'Classification'``),
    ``epoch``, ``batchSize``, ``imageSize`` and ``preModelName``.

    Returns:
        ``{'status': 'training started'}`` once the task has been submitted.
        Malformed requests get ``{'status': 'error', 'message': ...}`` with
        HTTP 400. An unknown task type keeps its original response (same
        payload, default status code) so existing clients are unaffected.
    """
    error_body = {'status': 'error', 'message': '请求参数错误'}

    # 1. Read and validate the request parameters.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify(error_body), 400

    training_id = data.get('id')
    user = data.get('user')
    dataset_name = data.get('datasetName')
    pre_model_name = data.get('preModelName')
    task_type = data.get('taskType')

    if training_id is None:
        return jsonify(error_body), 400
    # These values are concatenated into paths below, so they must be
    # non-empty strings (a missing key used to raise TypeError -> HTTP 500).
    for name in (user, dataset_name, pre_model_name):
        if not isinstance(name, str) or not name:
            return jsonify(error_body), 400

    try:
        epochs = int(data.get('epoch'))
        batch_size = int(data.get('batchSize'))
        img_size = int(data.get('imageSize'))
    except (TypeError, ValueError):
        return jsonify(error_body), 400

    minio_folder_name = user + '/' + dataset_name + '/'
    pre_model_path = os.path.join(LOCAL_PRE_MODEL_DIR, pre_model_name + '.pt')

    # 2. Resolve the dataset paths from the task type.
    task_folders = {
        'Detect': 'Detection',
        'Segment': 'Segment',
        'Classification': 'Classification',
    }
    # Non-string values (e.g. a list) are unhashable; treat them as unknown.
    task_folder = task_folders.get(task_type) if isinstance(task_type, str) else None
    if task_folder is None:
        return jsonify({'status': 'error', 'message': '任务类型错误'})

    dataset_root = os.path.join(LOCAL_DATASET_DIR, user, dataset_name, task_folder)
    # Reject names that would escape the configured directories
    # (for example "../.." or an absolute path supplied by the client).
    if not (_is_within_directory(LOCAL_DATASET_DIR, dataset_root)
            and _is_within_directory(LOCAL_PRE_MODEL_DIR, pre_model_path)):
        return jsonify(error_body), 400

    images_path = os.path.join(dataset_root, 'images')
    # The Classification branch previously used the misspelled directory
    # 'lables'; all task types now share the 'labels' directory name.
    labels_path = os.path.join(dataset_root, 'labels')
    coco_path = os.path.join(dataset_root, 'coco8.yaml')

    # 3. Hand the job to Ray and register the returned handle.
    future = training.remote(training_id, minio_folder_name, coco_path, pre_model_path,
                             images_path, labels_path, epochs, img_size, batch_size)
    ray_manager.submit_training_task(training_id, future)
    # NOTE(review): submit_training_task presumably records the task already;
    # the explicit assignment is kept to preserve the original behaviour.
    # NOTE(review): the stop endpoint looks tasks up by int(training_id) -
    # confirm that 'id' always arrives as an integer.
    ray_manager.tasks[training_id] = future
    return jsonify({'status': 'training started'})
|
|
|
|
@app.route('/delete_temp_dir', methods=['POST', 'OPTIONS'])
def delete_dataset_dir():
    """Delete the local copy of a user's dataset.

    Expects a JSON body of the form
    ``{'params': {'user': ..., 'datasetName': ...}}`` and removes
    ``LOCAL_DATASET_DIR/<user>/<datasetName>`` through ``delete_temp_dir``.

    Returns:
        ``{'success': True, ...}`` with HTTP 200 when the directory was
        deleted, otherwise ``{'success': False, ...}`` with HTTP 400.
        ``OPTIONS`` requests are answered with ``{'status': 'ok'}``.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'}), 200

    invalid = {'success': False, 'message': '参数错误'}

    data = request.get_json(silent=True)
    params = data.get('params') if isinstance(data, dict) else None
    if not isinstance(params, dict):
        return jsonify(invalid), 400

    user = params.get('user')
    dataset_name = params.get('datasetName')
    # Empty strings would make the target collapse to the dataset root (or a
    # user's whole directory), so both values must be non-empty strings.
    for name in (user, dataset_name):
        if not isinstance(name, str) or not name:
            return jsonify(invalid), 400

    target = os.path.join(LOCAL_DATASET_DIR, user, dataset_name)

    # Refuse anything that does not resolve strictly inside the dataset root,
    # e.g. names containing ".." or an absolute path. The check is lexical;
    # symlinks are not resolved.
    root = os.path.abspath(LOCAL_DATASET_DIR)
    resolved = os.path.abspath(target)
    if resolved == root or not resolved.startswith(os.path.join(root, '')):
        return jsonify(invalid), 400

    is_deleted = delete_temp_dir(target)
    if is_deleted:
        return jsonify({'success': is_deleted, 'message': '删除成功'}), 200
    return jsonify({'success': is_deleted, 'message': '删除失败'}), 400
|
|
|
|
@app.route('/stop_training_ray_task', methods=['GET', 'POST'])
def stop_training_ray_task():
    """Cancel the Ray training task identified by ``training_id``.

    ``training_id`` is read from the JSON body when one is present and
    otherwise from the query string (``?training_id=...``), because many HTTP
    clients cannot attach a body to a GET request. POST is accepted in
    addition to the original GET.

    Returns:
        ``{'success': True, 'message': ...}`` after ``cancel_task`` has been
        called, or ``{'success': False, ...}`` with HTTP 400 when
        ``training_id`` is missing or not an integer.
    """
    data = request.get_json(silent=True)
    raw_id = data.get('training_id') if isinstance(data, dict) else None
    if raw_id is None:
        raw_id = request.args.get('training_id')

    try:
        training_id = int(raw_id)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'training_id 参数错误'}), 400

    ray_manager.cancel_task(training_id)
    return jsonify({'success': True, 'message': f"任务{training_id}已停止训练"})
|
|
|
|
|
|
if __name__ == '__main__':
    # Listen on every available network interface, port 5000.
    app.run(host='0.0.0.0', port=5000)
|