# Yolov/server.py
# server.py
# --- standard library ---
import hashlib
import json
import logging
import os
import pickle
import platform
import secrets
import time
import traceback
from datetime import datetime

# --- third party ---
import torch
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS  # added: CORS(app) is called below but was never imported
from flask_socketio import SocketIO, emit, join_room, leave_room

# --- local ---
from config import get_default_config
from mandatory_model_crypto import MandatoryModelEncryptor
from mandatory_model_crypto import ModelEncryptionService, validate_models_before_task, verify_single_model_api
from model_upload_manager import get_upload_manager
from task_manager import task_manager  # task manager singleton
from global_data import gd
from log import logger
# Logging configuration: console plus a UTF-8 log file for the public server.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("public_server.log", encoding="utf-8"),
    ],
)
pub_logger = logging.getLogger("PublicServer")
# Flask initialization
app = Flask(__name__, static_url_path='/static')
CORS(app)

# Windows compatibility: only that platform gets the extra Werkzeug flag.
import platform

_sio_opts = {
    'cors_allowed_origins': "*",
    'async_mode': 'threading',
    'max_http_buffer_size': 5 * 1024 * 1024,  # 5 MiB per SocketIO message
}
if platform.system() == 'Windows':
    # NOTE(review): allow_unsafe_werkzeug is documented as an argument of
    # socketio.run(), not of the SocketIO constructor — confirm it has effect here.
    _sio_opts['allow_unsafe_werkzeug'] = True
socketio = SocketIO(app, **_sio_opts)

# One-shot guard for the before_request initializer below.
_initialized = False
@app.before_request
def initialize_once():
    """On the first request only, register the task manager in global data."""
    global _initialized
    if _initialized:
        return
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    _initialized = True
# ======================= Flask routes =======================
@app.route('/')
def task_management():
    """Serve the task-management HTML page."""
    return render_template("task_management.html")
@app.route('/video_player')
def video_player():
    """Serve the video-player (FLV) HTML page."""
    return render_template("flv2.html")
# server.py - create_task route
@app.route('/api/tasks/create', methods=['POST'])
def create_task():
    """Create a new task - mandatory model encryption and key validation.

    Expects a JSON body with at least `rtmp_url` and a non-empty `models`
    list; every model entry must carry an `encryption_key`.  All keys are
    validated BEFORE any task resources are allocated; capacity is reclaimed
    (stopped tasks cleaned up) before creation, and a task whose startup
    fails is torn down again.
    """
    try:
        config = get_default_config()
        config['socketIO'] = socketio
        if not request.json:
            return jsonify({"status": "error", "message": "请求体不能为空"}), 400
        data = request.json
        import json
        logger.info(f"请求参数: {json.dumps(data, indent=2, ensure_ascii=False)}")
        logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")
        # Validate required parameters
        if 'rtmp_url' not in data:
            return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400
        if 'models' not in data or not isinstance(data['models'], list):
            return jsonify({"status": "error", "message": "必须提供models列表"}), 400
        # The models list must not be empty
        if len(data['models']) == 0:
            return jsonify({"status": "error", "message": "models列表不能为空"}), 400
        # ================= Key step: model validation before task creation =================
        logger.info("开始创建任务前的模型验证...")
        # 1. Every model entry must carry an encryption key
        for i, model_data in enumerate(data['models']):
            if 'encryption_key' not in model_data:
                return jsonify({
                    "status": "error",
                    "message": f"模型 {i} 必须提供encryption_key"
                }), 400
        # 2. Validate every model's key
        task_config = {
            'models': data['models']
        }
        validation_result = validate_models_before_task(task_config)
        if not validation_result['success']:
            logger.error(f"模型验证失败: {validation_result.get('error', '未知错误')}")
            # Collect per-model failure details for the error message
            error_details = []
            for result in validation_result.get('validation_results', []):
                if not result.get('key_valid', False):
                    error_details.append(f"模型 {result['model_index']}: {result.get('error', '验证失败')}")
            error_message = validation_result.get('error', '模型验证失败')
            if error_details:
                error_message += f" | 详情: {', '.join(error_details)}"
            return jsonify({
                "status": "error",
                "message": error_message,
                "data": validation_result
            }), 400
        logger.info(f"模型验证通过: {validation_result['valid_models']}/{validation_result['total_models']} 个模型有效")
        # ================= Validation finished =================
        # Update configuration from request fields (only when present/truthy)
        config['rtmp']['url'] = data['rtmp_url']
        if 'uavType' in data and data['uavType']:
            config['task']['uavType'] = data['uavType']
        if 'algoInstancesName' in data and data['algoInstancesName']:
            config['task']['algoInstancesName'] = data['algoInstancesName']
        if 'push_url' in data and data['push_url']:
            config['push']['url'] = data['push_url']
        if 'taskname' in data:
            config['task']['taskname'] = data['taskname']
        else:
            # Fallback task name derived from the current epoch second
            config['task']['taskname'] = f"task_{int(time.time())}"
        if 'AlgoId' in data:
            config['task']['aiid'] = data['AlgoId']
        # Build the multi-model configuration from the already-validated models
        config['models'] = []
        for i, model_data in enumerate(data['models']):
            # Keys have already been validated at this point
            encryption_key = model_data['encryption_key']
            # Use the encrypted model file name (relative path, forced .enc suffix)
            model_path = model_data.get('path', f'model_{i}.enc')
            if not model_path.startswith('encrypted_models/'):
                model_filename = os.path.basename(model_path)
                if not model_filename.endswith('.enc'):
                    model_filename += '.enc'
                model_path = f"{model_filename}"
            # Assemble the per-model configuration; imgsz clamped to [128, 1920]
            model_config = {
                'path': model_path,
                'encryption_key': encryption_key,
                'encrypted': True,
                'tags': model_data.get('tags', {}),
                'conf_thres': float(model_data.get('conf_thres', 0.25)),
                'iou_thres': float(model_data.get('iou_thres', 0.45)),
                'imgsz': max(128, min(1920, int(model_data.get('imgsz', 640)))),
                'color': model_data.get('color'),
                'line_width': int(model_data.get('line_width', 1)),
                'device': model_data.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu'),
                'half': model_data.get('half', True),
                'enabled': model_data.get('enabled', True),
                # NOTE: download_url is no longer needed - models load locally
            }
            config['models'].append(model_config)
            logger.info(f"添加已验证的加密模型 {i}: {model_path}")
        # Clean up stopped tasks before creating, to free resources
        logger.info("创建任务前清理已停止的任务...")
        cleaned_count = task_manager.cleanup_stopped_tasks()
        if cleaned_count > 0:
            logger.info(f"已清理 {cleaned_count} 个已停止的任务")
        # Check system capacity against the concurrent-task limit
        active_tasks = task_manager.get_active_tasks_count()
        max_tasks = task_manager.get_current_max_tasks()
        if active_tasks >= max_tasks:
            # At the limit: try to forcibly reclaim resources first
            logger.warning(f"当前活动任务数 {active_tasks} 已达到限制 {max_tasks},尝试强制清理")
            # Clean up every task not in a running-like state
            all_tasks = task_manager.get_all_tasks()
            force_cleaned = 0
            for task in all_tasks:
                if task['status'] not in ['running', 'starting', 'creating']:
                    if task_manager.cleanup_task(task['task_id']):
                        force_cleaned += 1
            if force_cleaned > 0:
                logger.info(f"强制清理了 {force_cleaned} 个非运行状态的任务")
            # Re-check after the forced cleanup; still full -> 503
            active_tasks = task_manager.get_active_tasks_count()
            if active_tasks >= max_tasks:
                return jsonify({
                    "status": "error",
                    "message": f"已达到最大并发任务数限制 ({max_tasks}),请等待其他任务完成或停止部分任务"
                }), 503
        # Create the task
        logger.info(f"开始创建任务,包含 {len(config['models'])} 个已验证的加密模型...")
        try:
            task_id = task_manager.create_task(config, socketio)
            logger.info(f"任务创建成功ID: {task_id}")
        except Exception as e:
            logger.error(f"任务创建失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"任务创建失败: {str(e)}"
            }), 500
        # Start the task
        logger.info(f"启动任务 {task_id}...")
        success = task_manager.start_task(task_id)
        if success:
            logger.info(f"任务启动成功: {task_id}")
            # Fetch detailed task info for the response payload
            task_status = task_manager.get_task_status(task_id)
            return jsonify({
                "status": "success",
                "message": "任务创建并启动成功",
                "data": {
                    "task_id": task_id,
                    "models_count": len(config['models']),
                    "encryption_required": True,
                    "key_validated": True,
                    "validation_result": validation_result,
                    "task_info": task_status
                }
            })
        else:
            logger.error(f"任务启动失败: {task_id}")
            # Startup failed: tear the freshly created task down again
            task_manager.cleanup_task(task_id)
            return jsonify({
                "status": "error",
                "message": "任务创建成功但启动失败",
                "task_id": task_id
            }), 500
    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"创建任务失败: {str(e)}"
        }), 500
@app.route('/api/system/resources', methods=['GET'])
def get_system_resources():
    """Report system resource usage (cached or live) plus task statistics."""
    try:
        resources = gd.get_value('system_resources')
        max_tasks = gd.get_value('max_concurrent_tasks', 5)
        if not resources:
            # No cached snapshot available: sample live values via psutil.
            import psutil
            vm = psutil.virtual_memory()
            resources = {
                'cpu_percent': psutil.cpu_percent(),
                'memory_percent': vm.percent,
                'memory_used': vm.used // (1024 * 1024),
                'memory_total': vm.total // (1024 * 1024),
                'timestamp': datetime.now().isoformat(),
            }
        payload = {
            "resources": resources,
            "tasks": {
                "active": task_manager.get_active_tasks_count(),
                "total": len(task_manager.tasks),
                "max_concurrent": max_tasks,
            },
        }
        return jsonify({"status": "success", "data": payload})
    except Exception as e:
        logger.error(f"获取系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统资源失败: {str(e)}"
        }), 500
@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file (encryption is mandatory)."""
    try:
        config = get_default_config()
        data = request.json
        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')
        download_url = data.get('download_url')
        if not all([model_path, output_path, password, download_url]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        # Resolve both paths relative to the configured model directory
        local_path = os.path.join(config['model_path'], model_path)
        output_path = os.path.join(config['model_path'], output_path)
        # Run the mandatory encryptor
        from mandatory_model_crypto import MandatoryModelEncryptor
        encryptor = MandatoryModelEncryptor()
        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)
        if not result['success']:
            return jsonify({
                "status": "error",
                "message": f"模型加密失败: {result.get('error', '未知错误')}"
            }), 500
        return jsonify({
            "status": "success",
            "message": f"模型加密成功: {output_path}",
            "data": {
                "model_hash": result.get('model_hash'),
                "key_hash": result.get('key_hash'),
                "output_path": result.get('output_path')
            }
        })
    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500
@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Verify that an encryption key matches an encrypted model file."""
    try:
        data = request.json
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')
        if not all([model_path, encryption_key]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        # The model must already exist under encrypted_models/
        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400
        from mandatory_model_crypto import MandatoryModelEncryptor
        encryptor = MandatoryModelEncryptor()
        # Reject files that are not properly encrypted
        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400
        # Attempt a key-verification decrypt
        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)
        if not verify_result['success']:
            return jsonify({
                "status": "error",
                "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
                "data": {
                    "valid": False,
                    "error": verify_result.get('error', '未知错误')
                }
            }), 400
        return jsonify({
            "status": "success",
            "message": "密钥验证成功",
            "data": {
                "valid": True,
                "model_hash": verify_result.get('model_hash', '')[:16],
                "model_size": verify_result.get('original_size', 0)
            }
        })
    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500
@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate and return a fresh secure encryption key."""
    try:
        from mandatory_model_crypto import MandatoryModelEncryptor
        key_info = MandatoryModelEncryptor().generate_secure_key()
        payload = {
            "key": key_info['key'],
            "key_hash": key_info['key_hash'],
            "short_hash": key_info['short_hash'],
        }
        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": payload
        })
    except Exception as e:
        logger.error(f"生成加密密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(e)}"
        }), 500
# server.py - stop-task route
@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop the given task.

    JSON body (optional): {"force": bool} — force-kill instead of graceful stop.
    Returns 200 with stopped=True on success, 500 otherwise.
    """
    try:
        logger.info(f"接收到停止任务请求: {task_id}")
        data = request.json or {}
        force = data.get('force', False)
        # NOTE(review): a 'timeout' field used to be read here but was never
        # passed on to task_manager.stop_task(); removed the dead variable —
        # wire a timeout through once the task manager supports one.
        success = task_manager.stop_task(task_id, force=force)
        if success:
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}",
                "task_id": task_id,
                "stopped": True
            })
        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}",
            "task_id": task_id,
            "stopped": False
        }), 500
    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}",
            "task_id": task_id
        }), 500
@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return a task's status, enriched with its model list (multi-model only)."""
    status = task_manager.get_task_status(task_id)
    if not status:
        return jsonify({
            "status": "error",
            "message": f"任务不存在: {task_id}"
        }), 404
    # Enriched payload: same core fields plus the per-model details
    enhanced_status = {
        'task_id': status['task_id'],
        'status': status['status'],
        'config': status['config'],
        'models': status.get('models', []),
        'stats': status['stats'],
        'created_at': status['created_at'],
    }
    return jsonify({"status": "success", "data": enhanced_status})
@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List all tasks (multi-model only) with enriched summary information.

    Fixes: removed a leftover debug print and an unused local json import,
    and made the algoInstancesName/uavType/rtmp_url/taskname lookups tolerant
    of configs that lack them (a missing key previously raised KeyError and
    turned the whole listing request into a 500).
    """
    tasks = task_manager.get_all_tasks()
    enhanced_tasks = []
    for task in tasks:
        cfg = task['config']
        enhanced_tasks.append({
            'task_id': task['task_id'],
            'status': task['status'],
            'algoInstancesName': cfg.get('algoInstancesName', ''),
            'uavType': cfg.get('uavType', ''),
            'config': {
                'rtmp_url': cfg.get('rtmp_url', ''),
                'taskname': cfg.get('taskname', ''),
                'push_url': cfg.get('push_url', ''),
                'enable_push': cfg.get('enable_push', False)
            },
            'models': task.get('models', []),  # raw per-model list
            'stats': task['stats'],
            'performance': task.get('performance', {  # default perf metrics
                'fps': 0,
                'avg_process_time': 0,
                'latency': 0,
                'last_fps': 0
            }),
            'created_at': task['created_at']
        })
    return jsonify({
        "status": "success",
        "data": {
            "tasks": enhanced_tasks,
            "total": len(tasks),
            "active": task_manager.get_active_tasks_count(),
            "models_count": sum(len(t.get('models', [])) for t in enhanced_tasks)
        }
    })
@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release all resources held by a single task."""
    if task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "success",
            "message": f"任务资源已清理: {task_id}"
        })
    return jsonify({
        "status": "error",
        "message": f"清理任务失败: {task_id}"
    }), 500
@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Clean up every task and report how many were removed."""
    count = task_manager.cleanup_all_tasks()
    return jsonify({
        "status": "success",
        "message": f"已清理 {count} 个任务",
        "cleaned_count": count
    })
@app.route('/api/tasks/cleanup_stopped', methods=['POST'])
def cleanup_stopped_tasks():
    """Clean up every stopped task and report how many were removed."""
    count = task_manager.cleanup_stopped_tasks()
    return jsonify({
        "status": "success",
        "message": f"已清理 {count} 个已停止的任务",
        "cleaned_count": count
    })
@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Return overall system status: CPU/memory/disk, task counts, and GPUs.

    GPU details come from GPUtil when installed; otherwise torch's CUDA API
    is used as a fallback (load/temperature are then unavailable).
    """
    try:
        import psutil
        system_info = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks),
            "max_concurrent_tasks": task_manager.get_current_max_tasks()
        }
        # GPU info, if available
        try:
            import GPUtil
            gpus = GPUtil.getGPUs()
            gpu_info = []
            for gpu in gpus:
                gpu_info.append({
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load * 100,  # fraction -> percent
                    "memory_used": gpu.memoryUsed,
                    "memory_total": gpu.memoryTotal,
                    "temperature": gpu.temperature
                })
            system_info["gpus"] = gpu_info
        except ImportError:
            # GPUtil not installed: fall back to torch for GPU info
            if torch.cuda.is_available():
                gpu_info = []
                for i in range(torch.cuda.device_count()):
                    gpu_info.append({
                        "id": i,
                        "name": torch.cuda.get_device_name(i),
                        # bytes -> MiB
                        "memory_used": torch.cuda.memory_allocated(i) / 1024 ** 2,
                        "memory_total": torch.cuda.get_device_properties(i).total_memory / 1024 ** 2
                    })
                system_info["gpus"] = gpu_info
        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as e:
        logger.error(f"获取系统状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(e)}"
        }), 500
# WebSocket events - delivered per task id
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO client connection (request.sid identifies it)."""
    logger.info(f"Socket客户端已连接: {request.sid}")
@socketio.on('disconnect')
def handle_disconnect():
    """Log each Socket.IO client disconnection."""
    logger.info(f"Socket客户端断开: {request.sid}")
@socketio.on('subscribe_task')
def handle_subscribe_task(data):
    """Subscribe the calling client to a specific task's WebSocket messages."""
    task_id = data.get('task_id')
    if not task_id:
        return {"status": "error", "message": "需要提供task_id"}
    # The subscription is currently only recorded as a log line.
    logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
    return {"status": "subscribed", "task_id": task_id}
@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """GET the current resource limits, or POST updates to them."""
    if request.method == 'GET':
        monitor = gd.get_value('resource_monitor')
        if monitor:
            return jsonify({
                "status": "success",
                "data": monitor.resource_limits
            })
        # Monitor not running yet: fall back to the default configuration.
        config = get_default_config()
        return jsonify({
            "status": "success",
            "message": "资源监控器未初始化",
            "data": config['resource_limits']
        })
    # POST: apply new limits to both the monitor and the task manager.
    try:
        data = request.json
        monitor = gd.get_value('resource_monitor', None)
        if not (monitor and data):
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化或请求数据无效"
            }), 400
        monitor.resource_limits.update(data)
        task_manager.resource_limits.update(data)
        logger.info(f"更新资源限制: {data}")
        return jsonify({
            "status": "success",
            "message": "资源限制更新成功"
        })
    except Exception as e:
        logger.error(f"更新资源限制失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"更新失败: {str(e)}"
        }), 500
@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """Return the model configurations of a task."""
    try:
        task = task_manager.get_task_status(task_id)
        if not task:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404
        models_info = []
        if task_id in task_manager.tasks:
            cfg_models = task_manager.tasks[task_id].get('config', {}).get('models', [])
            for idx, mc in enumerate(cfg_models):
                models_info.append({
                    'id': idx,
                    # Model name = file name without extension
                    'name': os.path.basename(mc.get('path', '')).split('.')[0],
                    'path': mc.get('path'),
                    'conf_thres': mc.get('conf_thres'),
                    'tags': mc.get('tags', {}),
                    'color': mc.get('color'),
                    'enabled': mc.get('enabled', True)
                })
        return jsonify({
            "status": "success",
            "data": {
                "task_id": task_id,
                "models": models_info
            }
        })
    except Exception as e:
        logger.error(f"获取任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
# Task stream status routes
@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Return a task's push-stream status merged with its task status."""
    try:
        from task_stream_manager import task_stream_manager
        task_status = task_manager.get_task_status(task_id)
        if not task_status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404
        # Pull the stream info for this task only (empty dict when absent)
        stream_info = task_stream_manager.get_all_task_streams_info().get(task_id, {})
        merged = {
            "task_id": task_id,
            "task_status": task_status['status'],
            "stream_enabled": task_status['config'].get('enable_push', False),
            "stream_info": stream_info,
            "push_url": task_status['config'].get('push_url', '')
        }
        return jsonify({"status": "success", "data": merged})
    except Exception as e:
        logger.error(f"获取任务推流状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Summarize push-stream information across every task."""
    try:
        from task_stream_manager import task_stream_manager
        streams_info = task_stream_manager.get_all_task_streams_info()
        running = sum(1 for info in streams_info.values() if info.get('running', False))
        return jsonify({
            "status": "success",
            "data": {
                "total_streams": len(streams_info),
                "active_streams": running,
                "streams": streams_info
            }
        })
    except Exception as e:
        logger.error(f"获取所有推流信息失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
# Lazily-created singleton upload manager.
upload_manager = None


def init_upload_manager(config):
    """Return the process-wide upload manager, creating it on first use."""
    global upload_manager
    if not upload_manager:
        upload_manager = get_upload_manager(config['upload'])
    return upload_manager
# Upload routes
@app.route('/api/models/upload/start', methods=['POST'])
def start_model_upload():
    """Open a chunked upload session for a model file."""
    try:
        data = request.json
        if not data:
            return jsonify({"status": "error", "message": "请求数据不能为空"}), 400
        filename = data.get('filename')
        total_size = data.get('total_size')
        encryption_key = data.get('encryption_key')  # optional but recommended
        if not all([filename, total_size]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        config = get_default_config()
        upload_mgr = init_upload_manager(config)
        # Chunk size comes from config, defaulting to 5 MiB
        chunk_size = config['upload'].get('chunk_size', 5 * 1024 * 1024)
        result = upload_mgr.create_upload_session(
            filename=filename,
            total_size=total_size,
            chunk_size=chunk_size,
            encryption_key=encryption_key
        )
        if not result['success']:
            return jsonify({
                "status": "error",
                "message": f"创建上传会话失败: {result.get('error', '未知错误')}"
            }), 500
        return jsonify({
            "status": "success",
            "message": "上传会话创建成功",
            "data": {
                "session_id": result['session_id'],
                "total_chunks": result['total_chunks'],
                "chunk_size": chunk_size
            }
        })
    except Exception as e:
        logger.error(f"开始上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"开始上传失败: {str(e)}"
        }), 500
@app.route('/api/models/upload/chunk', methods=['POST'])
def upload_model_chunk():
    """Receive one chunk of an open model-upload session."""
    try:
        # Form fields
        session_id = request.form.get('session_id')
        chunk_index = int(request.form.get('chunk_index', 0))
        if not session_id:
            return jsonify({"status": "error", "message": "缺少session_id"}), 400
        # File part
        if 'chunk' not in request.files:
            return jsonify({"status": "error", "message": "未找到文件分片"}), 400
        chunk_data = request.files['chunk'].read()
        upload_mgr = init_upload_manager(get_default_config())
        result = upload_mgr.upload_chunk(session_id, chunk_index, chunk_data)
        if not result['success']:
            return jsonify({
                "status": "error",
                "message": f"分片上传失败: {result.get('error', '未知错误')}"
            }), 500
        return jsonify({
            "status": "success",
            "message": "分片上传成功",
            "data": {
                "progress": result['progress'],
                "received_chunks": result['received_chunks'],
                "total_chunks": result['total_chunks'],
                "chunk_index": chunk_index
            }
        })
    except Exception as e:
        logger.error(f"上传分片失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"上传分片失败: {str(e)}"
        }), 500
@app.route('/api/models/upload/status/<session_id>', methods=['GET'])
def get_upload_status(session_id):
    """Report the progress of an upload session."""
    try:
        upload_mgr = init_upload_manager(get_default_config())
        result = upload_mgr.get_upload_status(session_id)
        if result['success']:
            return jsonify({"status": "success", "data": result['data']})
        return jsonify({
            "status": "error",
            "message": result.get('error', '获取状态失败')
        }), 404
    except Exception as e:
        logger.error(f"获取上传状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取上传状态失败: {str(e)}"
        }), 500
@app.route('/api/models/upload/cancel/<session_id>', methods=['POST'])
def cancel_upload(session_id):
    """Cancel an in-progress upload.

    NOTE(review): cancellation is not actually implemented yet — the manager
    is fetched but upload_mgr.cancel_upload() is still commented out, so this
    endpoint always reports success without doing anything.
    """
    try:
        # Cancellation logic placeholder
        config = get_default_config()
        upload_mgr = init_upload_manager(config)
        # A cancel feature still needs to be added to the upload manager:
        # upload_mgr.cancel_upload(session_id)
        return jsonify({
            "status": "success",
            "message": f"上传已取消: {session_id}"
        })
    except Exception as e:
        logger.error(f"取消上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"取消上传失败: {str(e)}"
        }), 500
@app.route('/api/models/list', methods=['GET'])
def list_encrypted_models():
    """List every *.enc file in the encrypted-models directory."""
    try:
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']
        if not os.path.exists(encrypted_dir):
            # Directory missing means "no models", not an error.
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })
        models = []
        for name in os.listdir(encrypted_dir):
            if not name.endswith('.enc'):
                continue
            full = os.path.join(encrypted_dir, name)
            st = os.stat(full)
            models.append({
                'filename': name,
                'path': full,
                'size': st.st_size,
                'modified': st.st_mtime,
                'encrypted': True
            })
        return jsonify({
            "status": "success",
            "data": {
                "models": models,
                "total": len(models),
                "encrypted_dir": encrypted_dir
            }
        })
    except Exception as e:
        logger.error(f"列出模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(e)}"
        }), 500
# Upload page route
@app.route('/model_upload')
def model_upload_page():
    """Serve the model-upload HTML page."""
    return render_template("model_upload.html")
# server.py - 添加以下路由
from mandatory_model_crypto import ModelEncryptionService, validate_models_before_task, verify_single_model_api
@app.route('/api/models/process/start', methods=['POST'])
def start_model_processing():
    """Start the model-processing flow: obtain a key, then upload the model.

    The client either supplies its own encryption key (validated for
    strength) or asks the server to generate one.  An upload token valid
    for one hour is returned and recorded in the shared session dict.
    """
    try:
        data = request.json
        # Option 1: the client supplies its own key
        client_key = data.get('encryption_key')
        # Option 2: the server generates a key
        generate_new = data.get('generate_key', False)
        response_data = {}
        if generate_new and not client_key:
            # Server generates a fresh key
            key_info = ModelEncryptionService.generate_secure_key()
            response_data['key_info'] = {
                'key': key_info['key'],
                'key_hash': key_info['key_hash'],
                'short_hash': key_info['short_hash'],
                'generated_by': 'server'
            }
        elif client_key:
            # Validate the client-supplied key's strength
            key_valid, key_msg = ModelEncryptionService.validate_key_strength(client_key)
            if not key_valid:
                return jsonify({
                    "status": "error",
                    "message": f"密钥强度不足: {key_msg}"
                }), 400
            response_data['key_info'] = {
                'key': client_key,
                'key_hash': hashlib.sha256(client_key.encode()).hexdigest(),
                'short_hash': hashlib.sha256(client_key.encode()).hexdigest()[:16],
                'generated_by': 'client'
            }
        else:
            # Neither a key nor a generation request was provided
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥或选择生成新密钥"
            }), 400
        # Issue an upload token
        upload_token = secrets.token_urlsafe(32)
        response_data['upload_token'] = upload_token
        response_data['token_expires'] = time.time() + 3600  # valid for 1 hour
        # Store the upload session (simplified; production should use a database)
        upload_sessions = gd.get_or_create_dict('upload_sessions')
        upload_sessions[upload_token] = {
            'key_info': response_data['key_info'],
            'created_at': time.time(),
            'status': 'pending'
        }
        return jsonify({
            "status": "success",
            "message": "模型处理流程已启动",
            "data": response_data
        })
    except Exception as e:
        logger.error(f"启动模型处理流程失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"启动流程失败: {str(e)}"
        }), 500
@app.route('/api/models/process/verify_key', methods=['POST'])
def verify_encryption_key():
    """Validate an encryption key's strength and, optionally, a model match."""
    try:
        data = request.json
        encryption_key = data.get('encryption_key')
        model_path = data.get('model_path')  # optional concrete model
        if not encryption_key:
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥"
            }), 400
        # Strength check first
        key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
        if not key_valid:
            return jsonify({
                "status": "error",
                "message": key_msg,
                "data": {
                    "valid": False,
                    "strength": "weak"
                }
            }), 400
        response_data = {
            "valid": True,
            "strength": "strong",
            "key_hash": hashlib.sha256(encryption_key.encode()).hexdigest()[:16]
        }
        # Try a real decrypt when a specific model was named
        if model_path:
            response_data['model_verification'] = verify_single_model_api(model_path, encryption_key)
        return jsonify({
            "status": "success",
            "message": "密钥验证成功",
            "data": response_data
        })
    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500
@app.route('/api/models/process/validate_task', methods=['POST'])
def validate_task_models():
    """Validate every model referenced by a prospective task configuration."""
    try:
        data = request.json
        if not data:
            return jsonify({
                "status": "error",
                "message": "请求数据不能为空"
            }), 400
        # Only the models list matters for validation
        validation_result = validate_models_before_task({'models': data.get('models', [])})
        if validation_result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型验证通过 ({validation_result['valid_models']}/{validation_result['total_models']})",
                "data": validation_result
            })
        return jsonify({
            "status": "error",
            "message": f"模型验证失败: {validation_result.get('error', '未知错误')}",
            "data": validation_result
        }), 400
    except Exception as e:
        logger.error(f"验证任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500
@app.route('/api/models/encrypted/list_available', methods=['GET'])
def list_available_encrypted_models():
    """List the encrypted models available for task creation, newest first.

    Reads basic metadata out of each *.enc file WITHOUT validating any key.
    Fixes: the relative 'path' fields and the warning log interpolated a
    garbled placeholder instead of the actual file name.
    """
    try:
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']
        if not os.path.exists(encrypted_dir):
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })
        models = []
        for filename in os.listdir(encrypted_dir):
            if not filename.endswith('.enc'):
                continue
            filepath = os.path.join(encrypted_dir, filename)
            stats = os.stat(filepath)
            try:
                # SECURITY NOTE: pickle.load can execute arbitrary code; only
                # trusted uploads should ever reach encrypted_models/.
                with open(filepath, 'rb') as f:
                    encrypted_data = pickle.load(f)
                is_dict = isinstance(encrypted_data, dict)
                models.append({
                    'filename': filename,
                    'path': f"encrypted_models/{filename}",  # relative path
                    'size': stats.st_size,
                    'modified': stats.st_mtime,
                    'encrypted': True,
                    'model_hash': encrypted_data.get('model_hash', '')[:16] if is_dict else '',
                    'version': encrypted_data.get('version', 'unknown') if is_dict else '',
                    'original_size': encrypted_data.get('original_size', 0) if is_dict else 0
                })
            except Exception as e:
                # Unreadable payload: still list the file, flag the error
                logger.warning(f"读取模型信息失败 {filename}: {str(e)}")
                models.append({
                    'filename': filename,
                    'path': f"encrypted_models/{filename}",
                    'size': stats.st_size,
                    'modified': stats.st_mtime,
                    'encrypted': True,
                    'error': '无法读取模型信息'
                })
        # Sort by modification time, newest first
        models.sort(key=lambda x: x['modified'], reverse=True)
        return jsonify({
            "status": "success",
            "data": {
                "models": models,
                "total": len(models),
                "directory": encrypted_dir
            }
        })
    except Exception as e:
        logger.error(f"列出可用模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(e)}"
        }), 500
@app.route('/api/models/process/test_decrypt', methods=['POST'])
def test_model_decryption():
    """Test-decrypt a model (does NOT actually load the YOLO model).

    Resolves relative paths inside the encrypted-models directory, times the
    decryption, and reports hash/size on success or the error on failure.
    """
    try:
        data = request.json
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')
        if not all([model_path, encryption_key]):
            return jsonify({
                "status": "error",
                "message": "请提供模型路径和加密密钥"
            }), 400
        # Build the full path (relative names resolve inside the upload dir)
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']
        if not os.path.isabs(model_path):
            model_filename = os.path.basename(model_path)
            full_model_path = os.path.join(encrypted_dir, model_filename)
        else:
            full_model_path = model_path
        # Ensure the file exists before attempting decryption
        if not os.path.exists(full_model_path):
            return jsonify({
                "status": "error",
                "message": f"模型文件不存在: {full_model_path}"
            }), 404
        # Timed decryption attempt
        from mandatory_model_crypto import MandatoryModelValidator
        validator = MandatoryModelValidator()
        start_time = time.time()
        decrypt_result = validator.decrypt_and_verify(full_model_path, encryption_key)
        elapsed_time = time.time() - start_time
        if decrypt_result['success']:
            return jsonify({
                "status": "success",
                "message": "解密测试成功",
                "data": {
                    'success': True,
                    'model_hash': decrypt_result.get('model_hash', '')[:16],
                    'model_size': decrypt_result.get('original_size', 0),
                    'decryption_time': elapsed_time,
                    'file_path': full_model_path
                }
            })
        else:
            return jsonify({
                "status": "error",
                "message": f"解密测试失败: {decrypt_result.get('error', '未知错误')}",
                "data": {
                    'success': False,
                    'error': decrypt_result.get('error', '未知错误'),
                    'decryption_time': elapsed_time
                }
            }), 400
    except Exception as e:
        logger.error(f"测试解密失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"测试失败: {str(e)}"
        }), 500
# Task-creation page route
@app.route('/task_create')
def task_create_page():
    """Serve the task-creation HTML page."""
    return render_template("task_create.html")
# server.py - 添加资源检查接口
@app.route('/api/system/check_resources', methods=['GET'])
def check_system_resources():
    """Report whether the host has enough free resources for a new task.

    Samples CPU, RAM and (if CUDA is available) per-GPU memory usage,
    compares them against fixed thresholds, and combines that with the
    task manager's active/max task counts. High GPU memory only produces
    a warning; CPU, RAM and the task-slot limit are hard failures.
    """
    try:
        import psutil
        import torch

        # Snapshot host CPU / memory usage.
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory_info = psutil.virtual_memory()
        memory_percent = memory_info.percent

        # Per-GPU memory usage, if CUDA is available.
        gpu_available = torch.cuda.is_available()
        gpu_info = []
        if gpu_available:
            for idx in range(torch.cuda.device_count()):
                used_mb = torch.cuda.memory_allocated(idx) / 1024 ** 2
                total_mb = torch.cuda.get_device_properties(idx).total_memory / 1024 ** 2
                percent = (used_mb / total_mb) * 100 if total_mb > 0 else 0
                gpu_info.append({
                    'id': idx,
                    'name': torch.cuda.get_device_name(idx),
                    'memory_used': used_mb,
                    'memory_total': total_mb,
                    'memory_percent': percent
                })

        # Current task load from the task manager.
        active_tasks = task_manager.get_active_tasks_count()
        max_tasks = task_manager.get_current_max_tasks()

        # Utilisation thresholds (percent).
        cpu_threshold = 80
        memory_threshold = 85
        gpu_memory_threshold = 90

        resources_ok = True
        warnings = []
        if cpu_percent > cpu_threshold:
            resources_ok = False
            warnings.append(f"CPU使用率过高: {cpu_percent:.1f}% > {cpu_threshold}%")
        if memory_percent > memory_threshold:
            resources_ok = False
            warnings.append(f"内存使用率过高: {memory_percent:.1f}% > {memory_threshold}%")
        if gpu_available and gpu_info:
            # High GPU memory is a warning only, not a hard failure.
            max_gpu_memory = max(gpu['memory_percent'] for gpu in gpu_info)
            if max_gpu_memory > gpu_memory_threshold:
                warnings.append(f"GPU内存使用率过高: {max_gpu_memory:.1f}% > {gpu_memory_threshold}%")
        if active_tasks >= max_tasks:
            resources_ok = False
            warnings.append(f"任务数达到上限: {active_tasks}/{max_tasks}")

        payload = {
            "resources_available": resources_ok,
            "active_tasks": active_tasks,
            "max_tasks": max_tasks,
            "slots_available": max(0, max_tasks - active_tasks),
            "cpu_percent": cpu_percent,
            "memory_percent": memory_percent,
            "memory_used": memory_info.used / 1024 ** 2,    # MB
            "memory_total": memory_info.total / 1024 ** 2,  # MB
            "gpu_available": gpu_available,
            "gpu_info": gpu_info,
            "warnings": warnings,
            "thresholds": {
                "cpu": cpu_threshold,
                "memory": memory_threshold,
                "gpu_memory": gpu_memory_threshold
            },
            "timestamp": time.time()
        }
        return jsonify({"status": "success", "data": payload})
    except Exception as e:
        logger.error(f"检查系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"检查资源失败: {str(e)}"
        }), 500
# 初始化函数,可在主程序中调用
def init_app():
    """Register shared services on the app and return the Flask instance.

    Intended for use by an external entry point (e.g. a WSGI runner)
    instead of relying on the before-request lazy initialisation.
    """
    with app.app_context():
        # Publish the singleton task manager through the global-data registry.
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    return app
# ======================= Public-server functionality =======================
# Last-known state of the internal system: a connection flag plus the most
# recent status message received per device (keyed by device_id).
internal_system_status = {
    "is_connected": False,
    "devices": {}
}
# Session registry: maps a logical client id -> Socket.IO session id
# (request.sid) for every connected websocket peer, including the single
# internal system under the key "internal_system".
connected_clients = {}
# WebSocket事件 - 公共服务器
@socketio.on('connect', namespace='/ws/internal')
def handle_internal_connect():
    """Register the internal system's websocket session and mark it online."""
    internal_key = "internal_system"
    connected_clients[internal_key] = request.sid
    internal_system_status["is_connected"] = True
    pub_logger.info(f"Internal system connected: {request.sid}")
@socketio.on('disconnect', namespace='/ws/internal')
def handle_internal_disconnect():
    """Mark the internal system offline and drop its session mapping."""
    # pop() with a default removes the entry if present, without a KeyError.
    connected_clients.pop("internal_system", None)
    internal_system_status["is_connected"] = False
    pub_logger.info("Internal system disconnected")
@socketio.on('message', namespace='/ws/internal')
def handle_internal_message(data):
    """Relay messages from the internal system to the browser clients.

    "status_update" messages are cached per device and broadcast;
    "command_response" messages are forwarded as-is. Anything that is
    not valid JSON is logged and dropped.
    """
    try:
        message = json.loads(data)
    except json.JSONDecodeError:
        pub_logger.error(f"Invalid JSON from internal system: {data}")
        return
    msg_type = message.get("type")
    if msg_type == "status_update":
        device_id = message.get("device_id")
        if device_id:
            # Remember the latest status per device, then fan it out.
            internal_system_status["devices"][device_id] = message
            socketio.emit('message', json.dumps(message), namespace='/ws/client')
            pub_logger.info(f"Status updated for device {device_id}: {message}")
    elif msg_type == "command_response":
        # Forward command responses straight through to the clients.
        socketio.emit('message', json.dumps(message), namespace='/ws/client')
        pub_logger.info(f"Command response: {message}")
@socketio.on('connect', namespace='/ws/client')
def handle_client_connect():
    """Register a browser client and replay known device statuses to it.

    The client may identify itself via a `client_id` query parameter;
    otherwise its Socket.IO session id is used as the logical id.
    """
    client_id = request.args.get('client_id', request.sid)
    connected_clients[client_id] = request.sid
    pub_logger.info(f"Client connected: {client_id} ({request.sid})")
    # Bring the new client up to date: send it the latest cached status of
    # every device. Only the values are needed, so iterate .values()
    # instead of .items() (the key was previously fetched and ignored).
    for status in internal_system_status["devices"].values():
        socketio.emit('message', json.dumps(status), namespace='/ws/client', room=request.sid)
@socketio.on('disconnect', namespace='/ws/client')
def handle_client_disconnect():
    """Remove the disconnecting client's entry from the session registry."""
    # Reverse-lookup the logical client id by session id. The original loop
    # used `id` as its variable, shadowing the builtin; `next()` over a
    # generator avoids that and reads as a single lookup.
    client_id = next(
        (cid for cid, sid in connected_clients.items() if sid == request.sid),
        None,
    )
    if client_id:
        del connected_clients[client_id]
        pub_logger.info(f"Client disconnected: {client_id}")
@socketio.on('message', namespace='/ws/client')
def handle_client_message(data):
    """Forward a client's command to the internal system, or report failure.

    The raw payload is passed through untouched; if the internal system is
    not connected an error "command_response" is sent back to the caller.
    Non-JSON payloads are logged and dropped.
    """
    try:
        command = json.loads(data)
    except json.JSONDecodeError:
        pub_logger.error(f"Invalid JSON from client: {data}")
        return
    internal_sid = connected_clients.get("internal_system")
    if internal_system_status["is_connected"] and internal_sid is not None:
        # Relay the original payload directly to the internal system.
        socketio.emit('message', data, namespace='/ws/internal', room=internal_sid)
        pub_logger.info(f"Forwarded command from client: {command}")
    else:
        # Internal system offline: answer the sender with an error response.
        error_response = {
            "type": "command_response",
            "status": "error",
            "message": "Internal system not connected",
            "command": command
        }
        socketio.emit('message', json.dumps(error_response), namespace='/ws/client', room=request.sid)
# API路由 - 公共服务器
@app.route('/api/devices')
def get_devices():
    """Return the latest cached status of every known device."""
    devices = list(internal_system_status["devices"].values())
    return jsonify({"code": 200, "msg": "查询成功", "data": devices})
@app.route('/api/command', methods=['POST'])
def send_command():
    """Forward a JSON command from an HTTP client to the internal system.

    Responds with a JSON envelope: code 200 on success, 400 for a missing
    or malformed JSON body, 503 when the internal system is not connected,
    500 on unexpected errors. (HTTP status is always 200; the `code` field
    carries the outcome, per this API's existing convention.)
    """
    if not internal_system_status["is_connected"]:
        return jsonify({
            "code": 503,
            "msg": "内部系统未连接",
            "data": None
        })
    try:
        # silent=True: a missing/malformed body yields None instead of
        # raising inside Flask (which previously surfaced as code 500) or
        # forwarding a literal "null" command to the internal system.
        command = request.get_json(silent=True)
        if command is None:
            return jsonify({
                "code": 400,
                "msg": "无效的命令数据",
                "data": None
            })
        if "internal_system" in connected_clients:
            socketio.emit('message', json.dumps(command), namespace='/ws/internal',
                          room=connected_clients["internal_system"])
            pub_logger.info(f"Command sent: {command}")
            return jsonify({
                "code": 200,
                "msg": "命令已发送",
                "data": command
            })
        # The connection flag and session registry can disagree briefly
        # around a disconnect; report the system as unavailable.
        return jsonify({
            "code": 503,
            "msg": "内部系统未连接",
            "data": None
        })
    except Exception as e:
        pub_logger.error(f"Error sending command: {str(e)}")
        return jsonify({
            "code": 500,
            "msg": f"发送命令失败: {str(e)}",
            "data": None
        })