from flask import Flask, request, jsonify, send_from_directory, abort
from flask_cors import CORS
import ray
import os
import zipfile
import shutil
from config import BUCKET_NAME, LOCAL_DATASET_DIR, LOCAL_PRE_MODEL_DIR
from minio_tools import delete_temp_dir, download_minio_folder
from concurrent.futures import ThreadPoolExecutor
from train_api import training
from tools import validate_and_update_yaml_fields
from ray_manage import RayTaskManage

ray.init(ignore_reinit_error=True)  # Initialise Ray; tolerate an existing initialisation.

app = Flask(__name__)
# Allow requests from every origin and allow credentials.
CORS(app, resources={r"/*": {"origins": "*", "supports_credentials": True}})

executor = ThreadPoolExecutor(max_workers=10)
ray_manager = RayTaskManage()

# Image folder path stored as app.config['UPLOAD_FOLDER']. It can be overridden
# with the IMAGE_FOLDER environment variable; the default is the original
# developer-machine path, so behaviour is unchanged when the variable is unset.
IMAGE_FOLDER = os.environ.get('IMAGE_FOLDER', '/Users/liuzizhen/Projects/AIData/')
app.config['UPLOAD_FOLDER'] = IMAGE_FOLDER

# Maps the request's taskType to the dataset sub-directory that holds its data.
_TASK_SUBDIRS = {
    'Detect': 'Detection',
    'Segment': 'Segment',
    'Classification': 'Classification',
}


def _is_plain_name(value):
    """Return True if *value* is safe to use as a single path component.

    Rejects non-strings, '', '.', '..', NUL bytes and anything containing a
    path separator. This stops a client-supplied name from escaping (or
    pointing at) the directory it is joined onto. Without this check an
    absolute path or '..' would let /delete_temp_dir remove arbitrary
    directories.
    """
    if not isinstance(value, str) or value in ('', '.', '..') or '\x00' in value:
        return False
    return not any(sep in value for sep in (os.sep, os.altsep) if sep)


def _is_inside(base, path):
    """Return True if *path* resolves to a location strictly inside *base*."""
    base_real = os.path.realpath(base)
    path_real = os.path.realpath(path)
    try:
        common = os.path.commonpath([base_real, path_real])
    except ValueError:  # e.g. paths on different drives (Windows)
        return False
    return common == base_real and path_real != base_real


def _train_error(message):
    """Build the JSON 400 response used for invalid /train_detect input."""
    return jsonify({'status': 'error', 'message': message}), 400


@app.route('/train_detect', methods=['POST'])
def train_model():
    """Submit an asynchronous Ray training job for a user's dataset.

    Expects a JSON object with the keys id, user, datasetName, taskType
    ('Detect' | 'Segment' | 'Classification'), epoch, batchSize, imageSize
    and preModelName.

    Returns:
        {'status': 'training started'} once the Ray task is submitted.
        An unknown taskType returns the original error JSON with HTTP 200.
        Missing or malformed fields return an error JSON with HTTP 400.
    """
    # 1. Read and validate the request parameters.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _train_error('请求体必须是JSON对象')

    task_type = data.get('taskType')
    task_subdir = _TASK_SUBDIRS.get(task_type) if isinstance(task_type, str) else None
    if task_subdir is None:
        # Kept as HTTP 200 to match the original response for this case.
        return jsonify({'status': 'error', 'message': '任务类型错误'})

    user = data.get('user')
    dataset_name = data.get('datasetName')
    if not (_is_plain_name(user) and _is_plain_name(dataset_name)):
        return _train_error('参数非法: user/datasetName')

    try:
        epochs = int(data.get('epoch'))
        batch_size = int(data.get('batchSize'))
        img_size = int(data.get('imageSize'))
    except (TypeError, ValueError, OverflowError):
        return _train_error('参数非法: epoch/batchSize/imageSize')

    pre_model_name = data.get('preModelName')
    if not isinstance(pre_model_name, str) or not pre_model_name:
        return _train_error('参数非法: preModelName')
    pre_model_path = os.path.join(LOCAL_PRE_MODEL_DIR, pre_model_name + '.pt')
    if not _is_inside(LOCAL_PRE_MODEL_DIR, pre_model_path):
        return _train_error('参数非法: preModelName')

    training_id = data.get('id')
    minio_folder_name = user + '/' + dataset_name + '/'

    # 2. Derive the local dataset paths from the task type.
    task_dir = os.path.join(LOCAL_DATASET_DIR, user, dataset_name, task_subdir)
    images_path = os.path.join(task_dir, 'images')
    # Every task type uses 'labels'; the Classification branch previously used
    # the misspelled 'lables'.
    labels_path = os.path.join(task_dir, 'labels')
    coco_path = os.path.join(task_dir, 'coco8.yaml')

    # 3. Hand the training to Ray and register the task with ray_manager
    #    (which /stop_training_ray_task uses to cancel it).
    future = training.remote(training_id, minio_folder_name, coco_path, pre_model_path,
                             images_path, labels_path, epochs, img_size, batch_size)
    ray_manager.submit_training_task(training_id, future)
    # NOTE(review): possibly redundant with submit_training_task — confirm in RayTaskManage.
    ray_manager.tasks[training_id] = future
    # NOTE(review): /stop_training_ray_task looks tasks up by int(training_id); if
    # clients send 'id' as a string here, the keys will not match — confirm.
    return jsonify({'status': 'training started'})


@app.route('/delete_temp_dir', methods=['POST', 'OPTIONS'])
def delete_dataset_dir():
    """Delete the local directory LOCAL_DATASET_DIR/<user>/<datasetName>.

    Expects JSON of the form {'params': {'user': ..., 'datasetName': ...}}.
    Responds 200 on success and 400 on failure or invalid parameters.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'}), 200

    data = request.get_json(silent=True)
    params = data.get('params') if isinstance(data, dict) else None
    if not isinstance(params, dict):
        return jsonify({'success': False, 'message': '缺少参数: params'}), 400

    user = params.get('user')
    dataset_name = params.get('datasetName')
    # Guard against path traversal: both values must be plain directory names.
    if not (_is_plain_name(user) and _is_plain_name(dataset_name)):
        return jsonify({'success': False, 'message': '参数非法: user/datasetName'}), 400

    is_deleted = delete_temp_dir(os.path.join(LOCAL_DATASET_DIR, user, dataset_name))
    if is_deleted:
        return jsonify({'success': is_deleted, 'message': '删除成功'}), 200
    return jsonify({'success': is_deleted, 'message': '删除失败'}), 400


@app.route('/stop_training_ray_task', methods=['GET'])
def stop_training_ray_task():
    """Cancel the Ray training task identified by training_id.

    training_id is read from the JSON body if one is sent. Otherwise it is read
    from the query string, because GET requests usually carry no body. The
    value is converted with int(); a missing or non-integer value returns 400.
    """
    body = request.get_json(silent=True)
    raw_id = body.get('training_id') if isinstance(body, dict) else None
    if raw_id is None:
        raw_id = request.args.get('training_id')
    try:
        training_id = int(raw_id)
    except (TypeError, ValueError, OverflowError):
        return jsonify({'success': False, 'message': '参数非法: training_id'}), 400

    ray_manager.cancel_task(training_id)
    return jsonify({'success': True, 'message': f"任务{training_id}已停止训练"})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)  # Listen on all available network interfaces