# server.py
import hashlib
import os
import pickle
import platform
import secrets
from datetime import datetime
import json
import logging

import torch
from flask import Flask, jsonify, request, render_template
from flask_socketio import SocketIO, emit, join_room, leave_room

from config import get_default_config
from mandatory_model_crypto import MandatoryModelEncryptor
from model_upload_manager import get_upload_manager
from task_manager import task_manager  # shared task manager instance
from global_data import gd
from log import logger
import time
import traceback
from mandatory_model_crypto import ModelEncryptionService, validate_models_before_task, verify_single_model_api

try:
    # ``CORS(app)`` below used to be called without this import, which raised
    # NameError as soon as the module was loaded.
    from flask_cors import CORS
except ImportError:
    CORS = None

# Logging: console plus a UTF-8 log file for the public-server section.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("public_server.log", encoding="utf-8"),
    ],
)
pub_logger = logging.getLogger("PublicServer")

# Flask initialisation
app = Flask(__name__, static_url_path='/static')
if CORS is not None:
    CORS(app)
else:
    logger.warning("flask_cors is not installed; HTTP routes are served without CORS headers")

# Socket.IO options shared by every platform.
_socketio_options = {
    'cors_allowed_origins': "*",
    'async_mode': 'threading',
}
if platform.system() == 'Windows':
    # Windows compatibility. In the visible source, max_http_buffer_size was
    # commented out for Windows and the call's closing parenthesis sat inside
    # that comment; the option stays disabled on Windows here.
    # NOTE(review): allow_unsafe_werkzeug is documented as a socketio.run()
    # argument; confirm whether passing it to the constructor has any effect.
    _socketio_options['allow_unsafe_werkzeug'] = True
else:
    _socketio_options['max_http_buffer_size'] = 5 * 1024 * 1024
socketio = SocketIO(app, **_socketio_options)

_initialized = False


@app.before_request
def initialize_once():
    """Register the task manager in ``gd`` before the first request is handled."""
    global _initialized
    if not _initialized:
        with app.app_context():
            gd.set_value('task_manager', task_manager)
            logger.info("任务管理器初始化完成")
        _initialized = True


# ======================= Flask routes =======================

@app.route('/')
def task_management():
    """Render the task management page."""
    return render_template("task_management.html")


@app.route('/video_player')
def video_player():
    """Render the video player page."""
    return render_template("flv2.html")


def _redact_keys(payload):
    """Return a copy of *payload* with every ``encryption_key`` value masked.

    Used only for logging, so secrets never reach the log files.
    """
    if isinstance(payload, dict):
        return {
            key: ('***' if key == 'encryption_key' else _redact_keys(value))
            for key, value in payload.items()
        }
    if isinstance(payload, list):
        return [_redact_keys(item) for item in payload]
    return payload


def _resolve_under(base_dir, relative_path):
    """Join *relative_path* onto *base_dir*; return ``None`` if the result escapes it.

    The joined (not canonicalised) path is returned so messages look as before.
    """
    joined = os.path.join(base_dir, relative_path)
    base = os.path.realpath(base_dir)
    try:
        if os.path.commonpath([base, os.path.realpath(joined)]) != base:
            return None
    except ValueError:  # e.g. paths on different drives on Windows
        return None
    return joined


def _build_model_config(index, model_data, default_device):
    """Translate one validated client model entry into the task model config.

    Args:
        index: position of the model in the request, used for the default file name.
        model_data: the client's model dict (already known to contain ``encryption_key``).
        default_device: device used when the entry does not specify one.

    Raises:
        TypeError, ValueError, AttributeError: if a field has an unusable type/value.
    """
    # Only the encrypted file name is kept (models are loaded locally), unless
    # the path already starts with ``encrypted_models/``.
    model_path = model_data.get('path', f'model_{index}.enc')
    if not model_path.startswith('encrypted_models/'):
        model_path = os.path.basename(model_path)
        if not model_path.endswith('.enc'):
            model_path += '.enc'

    return {
        'path': model_path,
        'encryption_key': model_data['encryption_key'],
        'encrypted': True,
        'tags': model_data.get('tags', {}),
        'conf_thres': float(model_data.get('conf_thres', 0.25)),
        'iou_thres': float(model_data.get('iou_thres', 0.45)),
        # Inference size is clamped into [128, 1920].
        'imgsz': max(128, min(1920, int(model_data.get('imgsz', 640)))),
        'color': model_data.get('color'),
        'line_width': int(model_data.get('line_width', 1)),
        'device': model_data.get('device', default_device),
        'half': model_data.get('half', True),
        'enabled': model_data.get('enabled', True),
        # No download_url any more: models are loaded from local storage.
    }


def _reclaim_task_slots():
    """Clean up finished tasks before a new task is created.

    Stopped tasks are always removed. If the concurrency limit is still
    reached, every task whose status is not running/starting/creating is
    cleaned up as well.

    Returns:
        tuple: ``(active_tasks, max_tasks)`` measured after cleanup.
    """
    logger.info("创建任务前清理已停止的任务...")
    cleaned_count = task_manager.cleanup_stopped_tasks()
    if cleaned_count > 0:
        logger.info(f"已清理 {cleaned_count} 个已停止的任务")

    active_tasks = task_manager.get_active_tasks_count()
    max_tasks = task_manager.get_current_max_tasks()
    if active_tasks < max_tasks:
        return active_tasks, max_tasks

    logger.warning(f"当前活动任务数 {active_tasks} 已达到限制 {max_tasks},尝试强制清理")
    force_cleaned = 0
    for task in task_manager.get_all_tasks():
        if task['status'] not in ('running', 'starting', 'creating'):
            if task_manager.cleanup_task(task['task_id']):
                force_cleaned += 1
    if force_cleaned > 0:
        logger.info(f"强制清理了 {force_cleaned} 个非运行状态的任务")

    return task_manager.get_active_tasks_count(), max_tasks


@app.route('/api/tasks/create', methods=['POST'])
def create_task():
    """Create and start a task; every model must be encrypted and its key verified.

    JSON body fields read here: ``rtmp_url`` (required) and ``models``
    (required, non-empty list of dicts, each with ``encryption_key``).
    Optional fields: ``taskname``, ``uavType``, ``algoInstancesName``,
    ``push_url``, ``AlgoId``.

    Returns:
        200 with task details on success; 400 for invalid input or failed key
        validation; 503 when the concurrency limit is reached; 500 otherwise.
    """
    try:
        config = get_default_config()
        config['socketIO'] = socketio

        # get_json(silent=True) returns None for a missing or non-JSON body
        # instead of raising (which the broad except used to turn into a 500).
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "请求体不能为空"}), 400
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求体必须是JSON对象"}), 400

        # Encryption keys are masked before the request is logged.
        logger.info(f"请求参数: {json.dumps(_redact_keys(data), indent=2, ensure_ascii=False)}")
        logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")

        # Required parameters
        if 'rtmp_url' not in data:
            return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400

        if 'models' not in data or not isinstance(data['models'], list):
            return jsonify({"status": "error", "message": "必须提供models列表"}), 400

        if len(data['models']) == 0:
            return jsonify({"status": "error", "message": "models列表不能为空"}), 400

        # ================= Model validation before the task is created =================
        logger.info("开始创建任务前的模型验证...")

        # 1. Every model entry must be a dict that carries an encryption key.
        #    A non-dict entry (e.g. a string) would otherwise make the ``in``
        #    test a substring check.
        for i, model_data in enumerate(data['models']):
            if not isinstance(model_data, dict) or 'encryption_key' not in model_data:
                return jsonify({
                    "status": "error",
                    "message": f"模型 {i} 必须提供encryption_key"
                }), 400

        # 2. Verify every model's key.
        validation_result = validate_models_before_task({'models': data['models']})

        if not validation_result['success']:
            logger.error(f"模型验证失败: {validation_result.get('error', '未知错误')}")

            error_details = [
                f"模型 {result['model_index']}: {result.get('error', '验证失败')}"
                for result in validation_result.get('validation_results', [])
                if not result.get('key_valid', False)
            ]

            error_message = validation_result.get('error', '模型验证失败')
            if error_details:
                error_message += f" | 详情: {', '.join(error_details)}"

            return jsonify({
                "status": "error",
                "message": error_message,
                "data": validation_result
            }), 400

        logger.info(f"模型验证通过: {validation_result['valid_models']}/{validation_result['total_models']} 个模型有效")
        # ================= Validation finished =================

        # Task-level configuration
        config['rtmp']['url'] = data['rtmp_url']

        if data.get('uavType'):
            config['task']['uavType'] = data['uavType']

        if data.get('algoInstancesName'):
            config['task']['algoInstancesName'] = data['algoInstancesName']

        if data.get('push_url'):
            config['push']['url'] = data['push_url']

        if 'taskname' in data:
            config['task']['taskname'] = data['taskname']
        else:
            config['task']['taskname'] = f"task_{int(time.time())}"

        if 'AlgoId' in data:
            config['task']['aiid'] = data['AlgoId']

        # Model configuration: all keys have been validated above.
        default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        config['models'] = []
        for i, model_data in enumerate(data['models']):
            try:
                model_config = _build_model_config(i, model_data, default_device)
            except (TypeError, ValueError, AttributeError) as e:
                return jsonify({
                    "status": "error",
                    "message": f"模型 {i} 参数无效: {str(e)}"
                }), 400
            config['models'].append(model_config)
            logger.info(f"添加已验证的加密模型 {i}: {model_config['path']}")

        # Free resources and enforce the concurrency limit.
        active_tasks, max_tasks = _reclaim_task_slots()
        if active_tasks >= max_tasks:
            return jsonify({
                "status": "error",
                "message": f"已达到最大并发任务数限制 ({max_tasks}),请等待其他任务完成或停止部分任务"
            }), 503

        # Create the task
        logger.info(f"开始创建任务,包含 {len(config['models'])} 个已验证的加密模型...")
        try:
            task_id = task_manager.create_task(config, socketio)
            logger.info(f"任务创建成功,ID: {task_id}")
        except Exception as e:
            logger.error(f"任务创建失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"任务创建失败: {str(e)}"
            }), 500

        # Start the task
        logger.info(f"启动任务 {task_id}...")
        if not task_manager.start_task(task_id):
            logger.error(f"任务启动失败: {task_id}")
            # Clean up a task that was created but could not be started.
            task_manager.cleanup_task(task_id)
            return jsonify({
                "status": "error",
                "message": "任务创建成功但启动失败",
                "task_id": task_id
            }), 500

        logger.info(f"任务启动成功: {task_id}")
        task_status = task_manager.get_task_status(task_id)
        return jsonify({
            "status": "success",
            "message": "任务创建并启动成功",
            "data": {
                "task_id": task_id,
                "models_count": len(config['models']),
                "encryption_required": True,
                "key_validated": True,
                "validation_result": validation_result,
                "task_info": task_status
            }
        })

    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"创建任务失败: {str(e)}"
        }), 500


@app.route('/api/system/resources', methods=['GET'])
def get_system_resources():
    """Return resource usage (cached in ``gd`` or sampled now) plus task counts."""
    try:
        resources = gd.get_value('system_resources')
        # NOTE(review): this endpoint reads the limit from gd (default 5) while
        # others use task_manager.get_current_max_tasks(); confirm which is authoritative.
        max_tasks = gd.get_value('max_concurrent_tasks', 5)

        if not resources:
            import psutil

            # Sample memory once so used/total/percent come from the same snapshot.
            memory = psutil.virtual_memory()
            resources = {
                'cpu_percent': psutil.cpu_percent(),
                'memory_percent': memory.percent,
                'memory_used': memory.used // (1024 * 1024),
                'memory_total': memory.total // (1024 * 1024),
                'timestamp': datetime.now().isoformat()
            }

        return jsonify({
            "status": "success",
            "data": {
                "resources": resources,
                "tasks": {
                    "active": task_manager.get_active_tasks_count(),
                    "total": len(task_manager.tasks),
                    "max_concurrent": max_tasks
                }
            }
        })
    except Exception as e:
        logger.error(f"获取系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统资源失败: {str(e)}"
        }), 500


@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file located under ``config['model_path']``.

    JSON body: ``model_path``, ``output_path``, ``password``, ``download_url``
    (all required; ``download_url`` is part of the contract but not used here).
    Both paths must stay inside ``config['model_path']``.
    """
    try:
        config = get_default_config()
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')
        download_url = data.get('download_url')

        if not all([model_path, output_path, password, download_url]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # Client-supplied paths must not escape the model directory.
        local_path = _resolve_under(config['model_path'], model_path)
        output_path = _resolve_under(config['model_path'], output_path)
        if local_path is None or output_path is None:
            return jsonify({"status": "error", "message": "模型路径非法"}), 400

        encryptor = MandatoryModelEncryptor()
        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型加密成功: {output_path}",
                "data": {
                    "model_hash": result.get('model_hash'),
                    "key_hash": result.get('key_hash'),
                    "output_path": result.get('output_path')
                }
            })
        return jsonify({
            "status": "error",
            "message": f"模型加密失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500


@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Check that ``encryption_key`` decrypts the named file in ``encrypted_models/``.

    Only the basename of ``model_path`` is used, so the lookup stays in that directory.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')

        if not all([model_path, encryption_key]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400

        encryptor = MandatoryModelEncryptor()

        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400

        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)

        if verify_result['success']:
            return jsonify({
                "status": "success",
                "message": "密钥验证成功",
                "data": {
                    "valid": True,
                    # ``or ''`` guards against an explicit None hash.
                    "model_hash": (verify_result.get('model_hash') or '')[:16],
                    "model_size": verify_result.get('original_size', 0)
                }
            })
        return jsonify({
            "status": "error",
            "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
            "data": {
                "valid": False,
                "error": verify_result.get('error', '未知错误')
            }
        }), 400

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500


@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate a new model encryption key and return it with its hashes."""
    try:
        key_info = MandatoryModelEncryptor().generate_secure_key()
        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": {
                "key": key_info['key'],
                "key_hash": key_info['key_hash'],
                "short_hash": key_info['short_hash'],
            }
        })
    except Exception as e:
        logger.error(f"生成加密密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(e)}"
        }), 500


@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop one task.

    Optional JSON body: ``force`` (bool, default False). A ``timeout`` field
    was read previously but never passed on to ``task_manager.stop_task``; it
    is still accepted and ignored.
    """
    try:
        logger.info(f"接收到停止任务请求: {task_id}")

        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        force = data.get('force', False)

        if task_manager.stop_task(task_id, force=force):
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}",
                "task_id": task_id,
                "stopped": True
            })

        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}",
            "task_id": task_id,
            "stopped": False
        }), 500

    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}",
            "task_id": task_id
        }), 500


@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return one task's status (multi-model tasks only), or 404 if unknown."""
    status = task_manager.get_task_status(task_id)
    if not status:
        return jsonify({
            "status": "error",
            "message": f"任务不存在: {task_id}"
        }), 404

    enhanced_status = {
        'task_id': status['task_id'],
        'status': status['status'],
        'config': status['config'],
        'models': status.get('models', []),
        'stats': status['stats'],
        'created_at': status['created_at']
    }
    return jsonify({
        "status": "success",
        "data": enhanced_status
    })


@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List all tasks (multi-model only) with a flattened config summary.

    Missing config fields default to empty values instead of failing the
    whole listing with a KeyError.
    """
    tasks = task_manager.get_all_tasks()
    logger.debug(f"获取到的任务列表: {len(tasks)} 个任务")

    enhanced_tasks = []
    for task in tasks:
        task_config = task.get('config') or {}
        enhanced_tasks.append({
            'task_id': task['task_id'],
            'status': task['status'],
            'algoInstancesName': task_config.get('algoInstancesName', ''),
            'uavType': task_config.get('uavType', ''),
            'config': {
                'rtmp_url': task_config.get('rtmp_url', ''),
                'taskname': task_config.get('taskname', ''),
                'push_url': task_config.get('push_url', ''),
                'enable_push': task_config.get('enable_push', False)
            },
            'models': task.get('models', []),
            'stats': task.get('stats', {}),
            # Performance metrics; a fresh default dict per task.
            'performance': task.get('performance', {
                'fps': 0,
                'avg_process_time': 0,
                'latency': 0,
                'last_fps': 0
            }),
            'created_at': task.get('created_at')
        })

    return jsonify({
        "status": "success",
        "data": {
            "tasks": enhanced_tasks,
            "total": len(tasks),
            "active": task_manager.get_active_tasks_count(),
            "models_count": sum(len(t.get('models', [])) for t in enhanced_tasks)
        }
    })


@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release one task's resources."""
    if task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "success",
            "message": f"任务资源已清理: {task_id}"
        })
    return jsonify({
        "status": "error",
        "message": f"清理任务失败: {task_id}"
    }), 500


@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Clean up every task."""
    cleaned_count = task_manager.cleanup_all_tasks()
    return jsonify({
        "status": "success",
        "message": f"已清理 {cleaned_count} 个任务",
        "cleaned_count": cleaned_count
    })


@app.route('/api/tasks/cleanup_stopped', methods=['POST'])
def cleanup_stopped_tasks():
    """Clean up every stopped task."""
    cleaned_count = task_manager.cleanup_stopped_tasks()
    return jsonify({
        "status": "success",
        "message": f"已清理 {cleaned_count} 个已停止的任务",
        "cleaned_count": cleaned_count
    })


@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Return CPU/memory/disk usage, task counts and GPU info when available."""
    try:
        import psutil

        system_info = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks),
            "max_concurrent_tasks": task_manager.get_current_max_tasks()
        }

        # GPU info: prefer GPUtil, fall back to torch if GPUtil is not installed.
        try:
            import GPUtil
            system_info["gpus"] = [
                {
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load * 100,
                    "memory_used": gpu.memoryUsed,
                    "memory_total": gpu.memoryTotal,
                    "temperature": gpu.temperature
                }
                for gpu in GPUtil.getGPUs()
            ]
        except ImportError:
            if torch.cuda.is_available():
                system_info["gpus"] = [
                    {
                        "id": i,
                        "name": torch.cuda.get_device_name(i),
                        "memory_used": torch.cuda.memory_allocated(i) / 1024 ** 2,
                        "memory_total": torch.cuda.get_device_properties(i).total_memory / 1024 ** 2
                    }
                    for i in range(torch.cuda.device_count())
                ]

        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as e:
        logger.error(f"获取系统状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(e)}"
        }), 500

# WebSocket events - sent per task ID
@socketio.on('connect')
def handle_connect():
    """Log a connection on the default Socket.IO namespace."""
    logger.info(f"Socket客户端已连接: {request.sid}")


@socketio.on('disconnect')
def handle_disconnect():
    """Log a disconnection on the default Socket.IO namespace."""
    logger.info(f"Socket客户端断开: {request.sid}")


@socketio.on('subscribe_task')
def handle_subscribe_task(data):
    """Acknowledge a subscription to one task's messages.

    Only logs the request; no room is joined here. Non-dict payloads are
    treated as missing ``task_id`` instead of raising AttributeError.
    """
    task_id = data.get('task_id') if isinstance(data, dict) else None
    if task_id:
        logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
        return {"status": "subscribed", "task_id": task_id}
    return {"status": "error", "message": "需要提供task_id"}


@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """Get (GET) or update (POST with a JSON object) the resource limits."""
    if request.method == 'GET':
        resource_monitor = gd.get_value('resource_monitor')
        if resource_monitor:
            return jsonify({
                "status": "success",
                "data": resource_monitor.resource_limits
            })
        config = get_default_config()
        return jsonify({
            "status": "success",
            "message": "资源监控器未初始化",
            "data": config['resource_limits']
        })

    # POST: update limits on both the monitor and the task manager.
    try:
        data = request.get_json(silent=True)
        resource_monitor = gd.get_value('resource_monitor', None)

        # Only a non-empty JSON object can be merged into the limit dicts.
        if not (resource_monitor and isinstance(data, dict) and data):
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化或请求数据无效"
            }), 400

        resource_monitor.resource_limits.update(data)
        task_manager.resource_limits.update(data)
        logger.info(f"更新资源限制: {data}")
        return jsonify({
            "status": "success",
            "message": "资源限制更新成功"
        })
    except Exception as e:
        logger.error(f"更新资源限制失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"更新失败: {str(e)}"
        }), 500


@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """Return the model configuration of one task (404 if the task is unknown)."""
    try:
        task = task_manager.get_task_status(task_id)
        if not task:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        models_info = []
        if task_id in task_manager.tasks:
            task_info = task_manager.tasks[task_id]
            models_config = task_info.get('config', {}).get('models', [])

            for i, model_config in enumerate(models_config):
                path = model_config.get('path') or ''
                models_info.append({
                    'id': i,
                    # Display name: file name up to the first dot.
                    'name': os.path.basename(path).split('.')[0],
                    'path': model_config.get('path'),
                    'conf_thres': model_config.get('conf_thres'),
                    'tags': model_config.get('tags', {}),
                    'color': model_config.get('color'),
                    'enabled': model_config.get('enabled', True)
                })

        return jsonify({
            "status": "success",
            "data": {
                "task_id": task_id,
                "models": models_info
            }
        })

    except Exception as e:
        logger.error(f"获取任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500


@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Return one task's push-stream status merged with its task status."""
    try:
        from task_stream_manager import task_stream_manager

        task_status = task_manager.get_task_status(task_id)
        if not task_status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        stream_info = task_stream_manager.get_all_task_streams_info().get(task_id, {})
        task_config = task_status.get('config') or {}

        result = {
            "task_id": task_id,
            "task_status": task_status['status'],
            "stream_enabled": task_config.get('enable_push', False),
            "stream_info": stream_info,
            "push_url": task_config.get('push_url', '')
        }

        return jsonify({
            "status": "success",
            "data": result
        })

    except Exception as e:
        logger.error(f"获取任务推流状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500


@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Return push-stream info for every task."""
    try:
        from task_stream_manager import task_stream_manager

        streams_info = task_stream_manager.get_all_task_streams_info()

        return jsonify({
            "status": "success",
            "data": {
                "total_streams": len(streams_info),
                "active_streams": sum(1 for info in streams_info.values() if info.get('running', False)),
                "streams": streams_info
            }
        })

    except Exception as e:
        logger.error(f"获取所有推流信息失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500


# Upload manager, created lazily on first use.
upload_manager = None


def init_upload_manager(config):
    """Create the upload manager from ``config['upload']`` once and return it."""
    global upload_manager
    if not upload_manager:
        upload_manager = get_upload_manager(config['upload'])
    return upload_manager


@app.route('/api/models/upload/start', methods=['POST'])
def start_model_upload():
    """Open a chunked upload session.

    JSON body: ``filename`` and ``total_size`` (positive integer) are required;
    ``encryption_key`` is optional but recommended.
    """
    try:
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求数据不能为空"}), 400

        filename = data.get('filename')
        total_size = data.get('total_size')
        encryption_key = data.get('encryption_key')

        if not all([filename, total_size]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # Reject non-numeric or non-positive sizes with 400 rather than a later 500.
        try:
            total_size = int(total_size)
        except (TypeError, ValueError):
            return jsonify({"status": "error", "message": "total_size必须是正整数"}), 400
        if total_size <= 0:
            return jsonify({"status": "error", "message": "total_size必须是正整数"}), 400

        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        chunk_size = config['upload'].get('chunk_size', 5 * 1024 * 1024)

        result = upload_mgr.create_upload_session(
            filename=filename,
            total_size=total_size,
            chunk_size=chunk_size,
            encryption_key=encryption_key
        )

        if result['success']:
            return jsonify({
                "status": "success",
                "message": "上传会话创建成功",
                "data": {
                    "session_id": result['session_id'],
                    "total_chunks": result['total_chunks'],
                    "chunk_size": chunk_size
                }
            })
        return jsonify({
            "status": "error",
            "message": f"创建上传会话失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"开始上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"开始上传失败: {str(e)}"
        }), 500


@app.route('/api/models/upload/chunk', methods=['POST'])
def upload_model_chunk():
    """Receive one chunk (multipart form: ``session_id``, ``chunk_index``, file ``chunk``)."""
    try:
        session_id = request.form.get('session_id')
        try:
            chunk_index = int(request.form.get('chunk_index', 0))
        except (TypeError, ValueError):
            return jsonify({"status": "error", "message": "chunk_index必须是整数"}), 400

        if not session_id:
            return jsonify({"status": "error", "message": "缺少session_id"}), 400

        if 'chunk' not in request.files:
            return jsonify({"status": "error", "message": "未找到文件分片"}), 400

        chunk_data = request.files['chunk'].read()

        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        result = upload_mgr.upload_chunk(session_id, chunk_index, chunk_data)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": "分片上传成功",
                "data": {
                    "progress": result['progress'],
                    "received_chunks": result['received_chunks'],
                    "total_chunks": result['total_chunks'],
                    "chunk_index": chunk_index
                }
            })
        return jsonify({
            "status": "error",
            "message": f"分片上传失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"上传分片失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"上传分片失败: {str(e)}"
        }), 500


@app.route('/api/models/upload/status/<session_id>', methods=['GET'])
def get_upload_status(session_id):
    """Return the status of one upload session (404 if the manager reports failure)."""
    try:
        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        result = upload_mgr.get_upload_status(session_id)

        if result['success']:
            return jsonify({
                "status": "success",
                "data": result['data']
            })
        return jsonify({
            "status": "error",
            "message": result.get('error', '获取状态失败')
        }), 404

    except Exception as e:
        logger.error(f"获取上传状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取上传状态失败: {str(e)}"
        }), 500


@app.route('/api/models/upload/cancel/<session_id>', methods=['POST'])
def cancel_upload(session_id):
    """Cancel an upload session.

    Previously this always reported success without cancelling anything. It
    now calls ``cancel_upload`` on the upload manager if that method exists,
    and answers 501 otherwise.
    """
    try:
        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        cancel = getattr(upload_mgr, 'cancel_upload', None)
        if cancel is None:
            return jsonify({
                "status": "error",
                "message": f"上传管理器不支持取消上传: {session_id}"
            }), 501

        result = cancel(session_id)
        # NOTE(review): the return type of cancel_upload is not visible here;
        # a dict with success=False is treated as failure.
        if isinstance(result, dict) and not result.get('success', True):
            return jsonify({
                "status": "error",
                "message": f"取消上传失败: {result.get('error', '未知错误')}"
            }), 400

        return jsonify({
            "status": "success",
            "message": f"上传已取消: {session_id}"
        })

    except Exception as e:
        logger.error(f"取消上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"取消上传失败: {str(e)}"
        }), 500


@app.route('/api/models/list', methods=['GET'])
def list_encrypted_models():
    """List every ``.enc`` file in the configured encrypted-model directory."""
    try:
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']

        if not os.path.exists(encrypted_dir):
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })

        models = []
        for filename in os.listdir(encrypted_dir):
            if not filename.endswith('.enc'):
                continue
            filepath = os.path.join(encrypted_dir, filename)
            stats = os.stat(filepath)
            models.append({
                'filename': filename,
                'path': filepath,
                'size': stats.st_size,
                'modified': stats.st_mtime,
                'encrypted': True
            })

        return jsonify({
            "status": "success",
            "data": {
                "models": models,
                "total": len(models),
                "encrypted_dir": encrypted_dir
            }
        })

    except Exception as e:
        logger.error(f"列出模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(e)}"
        }), 500


@app.route('/model_upload')
def model_upload_page():
    """Render the model upload page."""
    return render_template("model_upload.html")


@app.route('/api/models/process/start', methods=['POST'])
def start_model_processing():
    """Start the model workflow: obtain a key (client-supplied or generated), issue an upload token.

    JSON body: ``encryption_key`` (client key, strength-checked) or
    ``generate_key: true`` (server generates one). A client key takes precedence.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        client_key = data.get('encryption_key')
        generate_new = data.get('generate_key', False)

        response_data = {}

        if generate_new and not client_key:
            key_info = ModelEncryptionService.generate_secure_key()
            response_data['key_info'] = {
                'key': key_info['key'],
                'key_hash': key_info['key_hash'],
                'short_hash': key_info['short_hash'],
                'generated_by': 'server'
            }
        elif client_key:
            key_valid, key_msg = ModelEncryptionService.validate_key_strength(client_key)
            if not key_valid:
                return jsonify({
                    "status": "error",
                    "message": f"密钥强度不足: {key_msg}"
                }), 400

            key_hash = hashlib.sha256(client_key.encode()).hexdigest()
            response_data['key_info'] = {
                'key': client_key,
                'key_hash': key_hash,
                'short_hash': key_hash[:16],
                'generated_by': 'client'
            }
        else:
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥或选择生成新密钥"
            }), 400

        # Upload token valid for one hour.
        upload_token = secrets.token_urlsafe(32)
        response_data['upload_token'] = upload_token
        response_data['token_expires'] = time.time() + 3600

        # In-memory session store (a database would be needed for production).
        upload_sessions = gd.get_or_create_dict('upload_sessions')
        upload_sessions[upload_token] = {
            'key_info': response_data['key_info'],
            'created_at': time.time(),
            'status': 'pending'
        }

        return jsonify({
            "status": "success",
            "message": "模型处理流程已启动",
            "data": response_data
        })

    except Exception as e:
        logger.error(f"启动模型处理流程失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"启动流程失败: {str(e)}"
        }), 500


@app.route('/api/models/process/verify_key', methods=['POST'])
def verify_encryption_key():
    """Check key strength and, if ``model_path`` is given, try decrypting that model."""
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        encryption_key = data.get('encryption_key')
        model_path = data.get('model_path')  # optional

        if not encryption_key:
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥"
            }), 400

        key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
        if not key_valid:
            return jsonify({
                "status": "error",
                "message": key_msg,
                "data": {
                    "valid": False,
                    "strength": "weak"
                }
            }), 400

        response_data = {
            "valid": True,
            "strength": "strong",
            "key_hash": hashlib.sha256(encryption_key.encode()).hexdigest()[:16]
        }

        if model_path:
            response_data['model_verification'] = verify_single_model_api(model_path, encryption_key)

        return jsonify({
            "status": "success",
            "message": "密钥验证成功",
            "data": response_data
        })

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500


@app.route('/api/models/process/validate_task', methods=['POST'])
def validate_task_models():
    """Run the pre-task model validation on ``models`` from the JSON body."""
    try:
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({
                "status": "error",
                "message": "请求数据不能为空"
            }), 400

        validation_result = validate_models_before_task({'models': data.get('models', [])})

        if validation_result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型验证通过 ({validation_result['valid_models']}/{validation_result['total_models']})",
                "data": validation_result
            })
        return jsonify({
            "status": "error",
            "message": f"模型验证失败: {validation_result.get('error', '未知错误')}",
            "data": validation_result
        }), 400

    except Exception as e:
        logger.error(f"验证任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500


@app.route('/api/models/encrypted/list_available', methods=['GET'])
def list_available_encrypted_models():
    """List encrypted models with their stored metadata, newest first (for task creation)."""
    try:
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']

        if not os.path.exists(encrypted_dir):
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })

        models = []
        for filename in os.listdir(encrypted_dir):
            if not filename.endswith('.enc'):
                continue
            filepath = os.path.join(encrypted_dir, filename)
            stats = os.stat(filepath)

            try:
                # SECURITY(review): pickle.load runs arbitrary code embedded in
                # the file, and files in this directory can come from client
                # uploads. Metadata should be stored in a non-executable format
                # (e.g. a JSON header) before this is exposed to untrusted uploaders.
                with open(filepath, 'rb') as f:
                    encrypted_data = pickle.load(f)
                meta = encrypted_data if isinstance(encrypted_data, dict) else None
                models.append({
                    'filename': filename,
                    'path': f"encrypted_models/{filename}",  # relative path
                    'size': stats.st_size,
                    'modified': stats.st_mtime,
                    'encrypted': True,
                    'model_hash': meta.get('model_hash', '')[:16] if meta is not None else '',
                    'version': meta.get('version', 'unknown') if meta is not None else '',
                    'original_size': meta.get('original_size', 0) if meta is not None else 0
                })
            except Exception as e:
                logger.warning(f"读取模型信息失败 {filename}: {str(e)}")
                models.append({
                    'filename': filename,
                    'path': f"encrypted_models/{filename}",
                    'size': stats.st_size,
                    'modified': stats.st_mtime,
                    'encrypted': True,
                    'error': '无法读取模型信息'
                })

        models.sort(key=lambda m: m['modified'], reverse=True)

        return jsonify({
            "status": "success",
            "data": {
                "models": models,
                "total": len(models),
                "directory": encrypted_dir
            }
        })

    except Exception as e:
        logger.error(f"列出可用模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(e)}"
        }), 500


@app.route('/api/models/process/test_decrypt', methods=['POST'])
def test_model_decryption():
    """Try decrypting a model without loading it into YOLO.

    Only the basename of ``model_path`` is used, so the file is always looked
    up inside the encrypted-model directory. Absolute paths used to be accepted
    verbatim, which let clients probe arbitrary files.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')

        if not all([model_path, encryption_key]):
            return jsonify({
                "status": "error",
                "message": "请提供模型路径和加密密钥"
            }), 400

        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']
        full_model_path = os.path.join(encrypted_dir, os.path.basename(model_path))

        if not os.path.exists(full_model_path):
            return jsonify({
                "status": "error",
                "message": f"模型文件不存在: {full_model_path}"
            }), 404

        from mandatory_model_crypto import MandatoryModelValidator
        validator = MandatoryModelValidator()

        start_time = time.time()
        decrypt_result = validator.decrypt_and_verify(full_model_path, encryption_key)
        elapsed_time = time.time() - start_time

        if decrypt_result['success']:
            return jsonify({
                "status": "success",
                "message": "解密测试成功",
                "data": {
                    'success': True,
                    'model_hash': (decrypt_result.get('model_hash') or '')[:16],
                    'model_size': decrypt_result.get('original_size', 0),
                    'decryption_time': elapsed_time,
                    'file_path': full_model_path
                }
            })
        return jsonify({
            "status": "error",
            "message": f"解密测试失败: {decrypt_result.get('error', '未知错误')}",
            "data": {
                'success': False,
                'error': decrypt_result.get('error', '未知错误'),
                'decryption_time': elapsed_time
            }
        }), 400

    except Exception as e:
        logger.error(f"测试解密失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"测试失败: {str(e)}"
        }), 500


@app.route('/task_create')
def task_create_page():
    """Render the task creation page."""
    return render_template("task_create.html")


@app.route('/api/system/check_resources', methods=['GET'])
def check_system_resources():
    """Report whether CPU, memory and task slots allow creating a new task.

    High GPU memory only adds a warning; it does not block task creation.
    """
    try:
        import psutil

        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory_info = psutil.virtual_memory()
        memory_percent = memory_info.percent

        gpu_available = torch.cuda.is_available()
        gpu_info = []
        if gpu_available:
            for i in range(torch.cuda.device_count()):
                gpu_memory_used = torch.cuda.memory_allocated(i) / 1024 ** 2  # MB
                gpu_memory_total = torch.cuda.get_device_properties(i).total_memory / 1024 ** 2  # MB
                gpu_memory_percent = (gpu_memory_used / gpu_memory_total) * 100 if gpu_memory_total > 0 else 0
                gpu_info.append({
                    'id': i,
                    'name': torch.cuda.get_device_name(i),
                    'memory_used': gpu_memory_used,
                    'memory_total': gpu_memory_total,
                    'memory_percent': gpu_memory_percent
                })

        active_tasks = task_manager.get_active_tasks_count()
        max_tasks = task_manager.get_current_max_tasks()

        # Usage thresholds, in percent.
        cpu_threshold = 80
        memory_threshold = 85
        gpu_memory_threshold = 90

        resources_ok = True
        warnings = []

        if cpu_percent > cpu_threshold:
            resources_ok = False
            warnings.append(f"CPU使用率过高: {cpu_percent:.1f}% > {cpu_threshold}%")

        if memory_percent > memory_threshold:
            resources_ok = False
            warnings.append(f"内存使用率过高: {memory_percent:.1f}% > {memory_threshold}%")

        if gpu_available and gpu_info:
            max_gpu_memory = max(gpu['memory_percent'] for gpu in gpu_info)
            if max_gpu_memory > gpu_memory_threshold:
                # Warning only; not treated as fatal.
                warnings.append(f"GPU内存使用率过高: {max_gpu_memory:.1f}% > {gpu_memory_threshold}%")

        if active_tasks >= max_tasks:
            resources_ok = False
            warnings.append(f"任务数达到上限: {active_tasks}/{max_tasks}")

        return jsonify({
            "status": "success",
            "data": {
                "resources_available": resources_ok,
                "active_tasks": active_tasks,
                "max_tasks": max_tasks,
                "slots_available": max(0, max_tasks - active_tasks),
                "cpu_percent": cpu_percent,
                "memory_percent": memory_percent,
                "memory_used": memory_info.used / 1024 ** 2,  # MB
                "memory_total": memory_info.total / 1024 ** 2,  # MB
                "gpu_available": gpu_available,
                "gpu_info": gpu_info,
                "warnings": warnings,
                "thresholds": {
                    "cpu": cpu_threshold,
                    "memory": memory_threshold,
                    "gpu_memory": gpu_memory_threshold
                },
                "timestamp": time.time()
            }
        })

    except Exception as e:
        logger.error(f"检查系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"检查资源失败: {str(e)}"
        }), 500


def init_app():
    """Register the task manager in ``gd`` and return the Flask app (for the main program)."""
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    return app

# ======================= Public server features =======================
# Internal system status (public-server relay state)

# State of the internal system as seen by this public server:
#   "is_connected": True while a /ws/internal Socket.IO connection is registered.
#   "devices": latest "status_update" message per device_id, replayed to new clients.
internal_system_status = {
    "is_connected": False,
    "devices": {}
}

# Connection registry: client_id -> Socket.IO sid.
# The reserved key below holds the /ws/internal connection; every other key is
# a /ws/client connection (its ?client_id=... query value, or its sid).
connected_clients = {}

# Registry key reserved for the internal system. Clients must never be allowed
# to claim it, otherwise they would overwrite the internal system's sid and
# commands would be routed to the wrong socket.
_INTERNAL_CLIENT_ID = "internal_system"


def _load_ws_payload(data):
    """Decode the payload of a Socket.IO 'message' event.

    Peers may send either JSON text (str/bytes) or an already-decoded
    object; both forms are accepted.

    Args:
        data: the raw event argument received by the handler.

    Returns:
        The decoded JSON value (dict, list, str, number, bool or None).

    Raises:
        ValueError: if text is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or ``data`` has an unsupported type.
    """
    if isinstance(data, (dict, list)):
        return data
    if not isinstance(data, (str, bytes, bytearray)):
        raise ValueError(f"unsupported message payload type: {type(data).__name__}")
    return json.loads(data)


# WebSocket events - public server
# NOTE(review): the /ws/internal namespace performs no authentication here, so
# any peer that connects to it is treated as the internal system. Consider
# requiring a shared secret/token in handle_internal_connect.


@socketio.on('connect', namespace='/ws/internal')
def handle_internal_connect():
    """Register the internal system's connection and mark it as connected.

    A new connection replaces any previously registered one.
    """
    connected_clients[_INTERNAL_CLIENT_ID] = request.sid
    internal_system_status["is_connected"] = True
    pub_logger.info(f"Internal system connected: {request.sid}")


@socketio.on('disconnect', namespace='/ws/internal')
def handle_internal_disconnect():
    """Unregister the internal system when its current connection closes."""
    registered_sid = connected_clients.get(_INTERNAL_CLIENT_ID)
    if registered_sid is not None and registered_sid != request.sid:
        # A newer internal connection has already replaced this one (e.g. a
        # fast reconnect); dropping the registration now would orphan it.
        pub_logger.info(f"Stale internal system connection closed: {request.sid}")
        return
    connected_clients.pop(_INTERNAL_CLIENT_ID, None)
    internal_system_status["is_connected"] = False
    pub_logger.info("Internal system disconnected")


@socketio.on('message', namespace='/ws/internal')
def handle_internal_message(data):
    """Handle a message from the internal system and relay it to clients.

    - ``{"type": "status_update", "device_id": ...}``: cache it as the
      device's latest status and broadcast it on /ws/client.
    - ``{"type": "command_response", ...}``: broadcast it on /ws/client.

    Other message types are ignored.
    """
    try:
        message = _load_ws_payload(data)
    except ValueError:
        pub_logger.error(f"Invalid JSON from internal system: {data}")
        return
    if not isinstance(message, dict):
        # Only a JSON object carries a "type" field to dispatch on.
        pub_logger.error(f"Unexpected message from internal system: {data}")
        return

    message_type = message.get("type")
    if message_type == "status_update":
        device_id = message.get("device_id")
        if device_id:
            internal_system_status["devices"][device_id] = message
            # Broadcast the status update to every client.
            socketio.emit('message', json.dumps(message), namespace='/ws/client')
            pub_logger.info(f"Status updated for device {device_id}: {message}")
    elif message_type == "command_response":
        # Relay the command response to every client.
        socketio.emit('message', json.dumps(message), namespace='/ws/client')
        pub_logger.info(f"Command response: {message}")


@socketio.on('connect', namespace='/ws/client')
def handle_client_connect():
    """Register a client and replay every cached device status to it.

    The client id comes from the ``client_id`` query parameter, falling back
    to the Socket.IO sid when it is missing, empty, or the reserved id of the
    internal system.
    """
    client_id = request.args.get('client_id') or request.sid
    if client_id == _INTERNAL_CLIENT_ID:
        pub_logger.warning(f"Client {request.sid} tried to register with reserved id '{client_id}'")
        client_id = request.sid
    connected_clients[client_id] = request.sid
    pub_logger.info(f"Client connected: {client_id} ({request.sid})")

    # Iterate a snapshot: with async_mode='threading' the internal message
    # handler may add devices while this loop runs.
    for status in list(internal_system_status["devices"].values()):
        socketio.emit('message', json.dumps(status), namespace='/ws/client', room=request.sid)


@socketio.on('disconnect', namespace='/ws/client')
def handle_client_disconnect():
    """Remove the disconnecting client from the registry, if registered."""
    # Reverse lookup sid -> client_id over a snapshot of the registry; the
    # internal system's entry is never a candidate.
    client_id = next(
        (cid for cid, sid in list(connected_clients.items())
         if sid == request.sid and cid != _INTERNAL_CLIENT_ID),
        None,
    )
    if client_id is not None:
        connected_clients.pop(client_id, None)
        pub_logger.info(f"Client disconnected: {client_id}")


@socketio.on('message', namespace='/ws/client')
def handle_client_message(data):
    """Forward a client's command to the internal system.

    If the internal system is not connected, reply to the sender with a
    ``command_response`` error instead.
    """
    try:
        command = _load_ws_payload(data)
    except ValueError:
        pub_logger.error(f"Invalid JSON from client: {data}")
        return

    internal_sid = connected_clients.get(_INTERNAL_CLIENT_ID)
    if internal_system_status["is_connected"] and internal_sid is not None:
        # Forward JSON text verbatim; re-encode only payloads that arrived
        # already decoded (or as bytes).
        payload = data if isinstance(data, str) else json.dumps(command)
        socketio.emit('message', payload, namespace='/ws/internal', room=internal_sid)
        pub_logger.info(f"Forwarded command from client: {command}")
    else:
        error_response = {
            "type": "command_response",
            "status": "error",
            "message": "Internal system not connected",
            "command": command
        }
        socketio.emit('message', json.dumps(error_response), namespace='/ws/client', room=request.sid)


# API routes - public server


@app.route('/api/devices')
def get_devices():
    """Return the latest cached status message of every known device."""
    return jsonify({
        "code": 200,
        "msg": "查询成功",
        "data": list(internal_system_status["devices"].values())
    })


@app.route('/api/command', methods=['POST'])
def send_command():
    """Forward the JSON request body as a command to the internal system.

    Responses follow this API's envelope convention: HTTP 200 with the outcome
    in ``code`` (200 sent, 400 body is not valid JSON, 503 internal system not
    connected, 500 forwarding failed).
    """
    internal_sid = connected_clients.get(_INTERNAL_CLIENT_ID)
    if not internal_system_status["is_connected"] or internal_sid is None:
        return jsonify({
            "code": 503,
            "msg": "内部系统未连接",
            "data": None
        })

    # silent=True: a missing/invalid JSON body yields None instead of raising
    # an HTTP error (previously misreported as a 500 "send failed").
    command = request.get_json(silent=True)
    if command is None:
        return jsonify({
            "code": 400,
            "msg": "请求体必须是有效的JSON",
            "data": None
        })

    try:
        socketio.emit('message', json.dumps(command), namespace='/ws/internal', room=internal_sid)
    except Exception as e:
        pub_logger.error(f"Error sending command: {str(e)}")
        return jsonify({
            "code": 500,
            "msg": f"发送命令失败: {str(e)}",
            "data": None
        })

    pub_logger.info(f"Command sent: {command}")
    return jsonify({
        "code": 200,
        "msg": "命令已发送",
        "data": command
    })