1636 lines
55 KiB
Python
1636 lines
55 KiB
Python
# server.py
|
||
import hashlib
|
||
import os
|
||
import pickle
|
||
import secrets
|
||
from datetime import datetime
|
||
import json
|
||
import logging
|
||
|
||
import torch
|
||
from flask import Flask, jsonify, request, render_template
|
||
from flask_socketio import SocketIO, emit, join_room, leave_room
|
||
from config import get_default_config
|
||
from mandatory_model_crypto import MandatoryModelEncryptor
|
||
from model_upload_manager import get_upload_manager
|
||
from task_manager import task_manager # 导入任务管理器
|
||
from global_data import gd
|
||
from log import logger
|
||
import time
|
||
import traceback
|
||
from mandatory_model_crypto import ModelEncryptionService, validate_models_before_task, verify_single_model_api
|
||
|
||
# Logging: console output plus a UTF-8 log file for the public server.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("public_server.log", encoding="utf-8"),
    ],
)

pub_logger = logging.getLogger("PublicServer")

# Flask initialisation.
# `CORS` is applied to the app below but was never imported anywhere in the
# file (NameError at import time); import it from Flask-CORS explicitly.
from flask_cors import CORS  # noqa: E402

app = Flask(__name__, static_url_path='/static')
CORS(app)

# Platform-specific Socket.IO options (Windows compatibility).
import platform  # noqa: E402

_socketio_options = {
    'cors_allowed_origins': "*",
    'async_mode': 'threading',
    # Upper bound for a single Socket.IO message: 5 MiB.
    'max_http_buffer_size': 5 * 1024 * 1024,
}
if platform.system() == 'Windows':
    # NOTE(review): Flask-SocketIO documents `allow_unsafe_werkzeug` as an
    # argument of `socketio.run()`; given to the constructor it is probably
    # ignored. Kept here to preserve existing behaviour - confirm where the
    # server is actually started.
    _socketio_options['allow_unsafe_werkzeug'] = True

socketio = SocketIO(app, **_socketio_options)
|
||
|
||
_initialized = False


@app.before_request
def initialize_once():
    """Publish the shared task manager into global data on the first request.

    The hook runs before every request, but the module-level ``_initialized``
    flag makes every call after the first a no-op. There is no lock, so with
    threaded serving two early requests may both run the setup. That is
    harmless here, because the setup only stores the same object again.
    """
    global _initialized
    if _initialized:
        return
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
        _initialized = True
|
||
|
||
|
||
# ======================= Flask routes =======================
@app.route('/')
def task_management():
    """Render the task-management page (site root)."""
    page = "task_management.html"
    return render_template(page)
|
||
|
||
|
||
@app.route('/video_player')
def video_player():
    """Render the video-player page."""
    page = "flv2.html"
    return render_template(page)
|
||
|
||
|
||
# Task creation: every model must be encrypted and its key validated first.

# Request fields whose values must never be written to the logs.
_SENSITIVE_REQUEST_FIELDS = frozenset({'encryption_key', 'password'})


def _redact_for_log(value):
    """Return a copy of *value* with sensitive dict fields masked as '***'.

    Recurses through dicts and lists so per-model keys nested inside the
    ``models`` list are masked too. Non-container values are returned as-is.
    """
    if isinstance(value, dict):
        return {
            key: '***' if key in _SENSITIVE_REQUEST_FIELDS else _redact_for_log(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_for_log(item) for item in value]
    return value


def _build_model_config(index, model_data):
    """Build the runtime config for one already-validated request model entry.

    Args:
        index: position of the entry in the request's ``models`` list; used
            for the default file name ``model_<index>.enc``.
        model_data: request dict for this model; must contain
            ``encryption_key``.

    Returns:
        dict: the model config that goes into ``config['models']``.

    Raises:
        TypeError, ValueError: ``path`` is not a string, or a numeric field
            (conf_thres, iou_thres, imgsz, line_width) cannot be converted.
    """
    model_path = model_data.get('path', f'model_{index}.enc')
    if not isinstance(model_path, str):
        raise TypeError(f"模型 {index} 的path必须是字符串")
    if not model_path.startswith('encrypted_models/'):
        # Keep only the file name (drops client-supplied directories) and make
        # sure it carries the encrypted-model suffix.
        model_path = os.path.basename(model_path)
        if not model_path.endswith('.enc'):
            model_path += '.enc'
    # NOTE(review): paths that already start with 'encrypted_models/' are kept
    # verbatim, including any '..' segments - confirm the loader confines them.

    return {
        'path': model_path,
        'encryption_key': model_data['encryption_key'],
        'encrypted': True,
        'tags': model_data.get('tags', {}),
        'conf_thres': float(model_data.get('conf_thres', 0.25)),
        'iou_thres': float(model_data.get('iou_thres', 0.45)),
        # Clamp the inference image size to the range [128, 1920].
        'imgsz': max(128, min(1920, int(model_data.get('imgsz', 640)))),
        'color': model_data.get('color'),
        'line_width': int(model_data.get('line_width', 1)),
        'device': model_data.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu'),
        'half': model_data.get('half', True),
        'enabled': model_data.get('enabled', True),
        # No download_url: models are loaded from local encrypted files.
    }


@app.route('/api/tasks/create', methods=['POST'])
def create_task():
    """Create and start a task; every model must be encrypted with a valid key.

    Request JSON:
        rtmp_url (required): input stream URL.
        models (required): non-empty list of dicts, each with an
            ``encryption_key`` plus optional path/thresholds/display options.
        uavType, algoInstancesName, push_url, taskname, AlgoId (optional).

    Responses:
        400 for a malformed request or failed model/key validation,
        503 when the concurrent-task limit is reached even after cleanup,
        500 when the task cannot be created or started.
    """
    try:
        config = get_default_config()
        config['socketIO'] = socketio

        # silent=True: a missing/non-JSON body yields None (-> 400) instead of
        # a 415 exception that the broad handler below would report as 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "请求体不能为空"}), 400
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求体必须是JSON对象"}), 400

        # Encryption keys are secrets: never write them to the log.
        logger.info(f"请求参数: {json.dumps(_redact_for_log(data), indent=2, ensure_ascii=False)}")
        logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")

        # Required parameters.
        if 'rtmp_url' not in data:
            return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400
        if 'models' not in data or not isinstance(data['models'], list):
            return jsonify({"status": "error", "message": "必须提供models列表"}), 400
        if len(data['models']) == 0:
            return jsonify({"status": "error", "message": "models列表不能为空"}), 400

        # ================= Model validation before creating the task =================
        logger.info("开始创建任务前的模型验证...")

        # 1. Every model entry must be an object carrying an encryption key.
        for i, model_data in enumerate(data['models']):
            if not isinstance(model_data, dict) or 'encryption_key' not in model_data:
                return jsonify({
                    "status": "error",
                    "message": f"模型 {i} 必须提供encryption_key"
                }), 400

        # 2. Verify every model's key.
        validation_result = validate_models_before_task({'models': data['models']})

        if not validation_result['success']:
            logger.error(f"模型验证失败: {validation_result.get('error', '未知错误')}")

            # Collect per-model failure details for the client.
            error_details = [
                f"模型 {result['model_index']}: {result.get('error', '验证失败')}"
                for result in validation_result.get('validation_results', [])
                if not result.get('key_valid', False)
            ]
            error_message = validation_result.get('error', '模型验证失败')
            if error_details:
                error_message += f" | 详情: {', '.join(error_details)}"

            return jsonify({
                "status": "error",
                "message": error_message,
                "data": validation_result
            }), 400

        logger.info(f"模型验证通过: {validation_result['valid_models']}/{validation_result['total_models']} 个模型有效")
        # ================= Validation finished =================

        # Apply request overrides to the default config.
        config['rtmp']['url'] = data['rtmp_url']
        if data.get('uavType'):
            config['task']['uavType'] = data['uavType']
        if data.get('algoInstancesName'):
            config['task']['algoInstancesName'] = data['algoInstancesName']
        if data.get('push_url'):
            config['push']['url'] = data['push_url']
        if 'taskname' in data:
            config['task']['taskname'] = data['taskname']
        else:
            config['task']['taskname'] = f"task_{int(time.time())}"
        if 'AlgoId' in data:
            config['task']['aiid'] = data['AlgoId']

        # Multi-model config built from the validated entries. Bad numeric
        # fields are a client error (400), not a server error.
        try:
            config['models'] = [
                _build_model_config(i, model_data)
                for i, model_data in enumerate(data['models'])
            ]
        except (TypeError, ValueError) as e:
            return jsonify({"status": "error", "message": f"模型参数无效: {e}"}), 400
        for i, model_config in enumerate(config['models']):
            logger.info(f"添加已验证的加密模型 {i}: {model_config['path']}")

        # Release resources held by stopped tasks before creating a new one.
        logger.info("创建任务前清理已停止的任务...")
        cleaned_count = task_manager.cleanup_stopped_tasks()
        if cleaned_count > 0:
            logger.info(f"已清理 {cleaned_count} 个已停止的任务")

        # Enforce the concurrent-task limit.
        active_tasks = task_manager.get_active_tasks_count()
        max_tasks = task_manager.get_current_max_tasks()

        if active_tasks >= max_tasks:
            logger.warning(f"当前活动任务数 {active_tasks} 已达到限制 {max_tasks},尝试强制清理")

            # Force-clean every task that is not running/starting/creating.
            force_cleaned = 0
            for task in task_manager.get_all_tasks():
                if task['status'] not in ['running', 'starting', 'creating']:
                    if task_manager.cleanup_task(task['task_id']):
                        force_cleaned += 1
            if force_cleaned > 0:
                logger.info(f"强制清理了 {force_cleaned} 个非运行状态的任务")

            active_tasks = task_manager.get_active_tasks_count()
            if active_tasks >= max_tasks:
                return jsonify({
                    "status": "error",
                    "message": f"已达到最大并发任务数限制 ({max_tasks}),请等待其他任务完成或停止部分任务"
                }), 503

        # Create the task.
        logger.info(f"开始创建任务,包含 {len(config['models'])} 个已验证的加密模型...")
        try:
            task_id = task_manager.create_task(config, socketio)
            logger.info(f"任务创建成功,ID: {task_id}")
        except Exception as e:
            logger.error(f"任务创建失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"任务创建失败: {str(e)}"
            }), 500

        # Start the task.
        logger.info(f"启动任务 {task_id}...")
        if not task_manager.start_task(task_id):
            logger.error(f"任务启动失败: {task_id}")
            # Do not leave a half-created task behind.
            task_manager.cleanup_task(task_id)
            return jsonify({
                "status": "error",
                "message": "任务创建成功但启动失败",
                "task_id": task_id
            }), 500

        logger.info(f"任务启动成功: {task_id}")
        task_status = task_manager.get_task_status(task_id)
        return jsonify({
            "status": "success",
            "message": "任务创建并启动成功",
            "data": {
                "task_id": task_id,
                "models_count": len(config['models']),
                "encryption_required": True,
                "key_validated": True,
                "validation_result": validation_result,
                "task_info": task_status
            }
        })

    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"创建任务失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/system/resources', methods=['GET'])
def get_system_resources():
    """Report host resource usage together with task counters.

    Uses the cached ``system_resources`` snapshot from global data when one
    exists; otherwise samples CPU and memory with psutil (memory in MiB).
    """
    try:
        resources = gd.get_value('system_resources')
        # NOTE(review): create_task enforces task_manager.get_current_max_tasks();
        # this reports gd 'max_concurrent_tasks' (default 5). Confirm they agree.
        max_tasks = gd.get_value('max_concurrent_tasks', 5)

        if not resources:
            # No cached snapshot: sample live values.
            import psutil

            # Take one memory snapshot so that percent, used and total all
            # describe the same moment.
            memory = psutil.virtual_memory()
            resources = {
                # Per psutil docs, a non-blocking cpu_percent() compares against
                # the previous call; the very first call returns 0.0.
                'cpu_percent': psutil.cpu_percent(),
                'memory_percent': memory.percent,
                'memory_used': memory.used // (1024 * 1024),
                'memory_total': memory.total // (1024 * 1024),
                'timestamp': datetime.now().isoformat()
            }

        active_tasks = task_manager.get_active_tasks_count()
        total_tasks = len(task_manager.tasks)

        return jsonify({
            "status": "success",
            "data": {
                "resources": resources,
                "tasks": {
                    "active": active_tasks,
                    "total": total_tasks,
                    "max_concurrent": max_tasks
                }
            }
        })
    except Exception as e:
        logger.error(f"获取系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统资源失败: {str(e)}"
        }), 500
|
||
|
||
|
||
def _is_within_directory(base_dir, path):
    """Return True when *path* resolves to *base_dir* or somewhere below it.

    Both paths go through ``os.path.realpath``, so '..' segments and symlinks
    cannot escape. Paths on a different drive (Windows) make
    ``os.path.commonpath`` raise ValueError and count as outside.
    """
    base = os.path.realpath(base_dir)
    target = os.path.realpath(path)
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        return False


@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file under ``config['model_path']`` (mandatory encryption).

    Request JSON:
        model_path: source file, relative to the configured model directory.
        output_path: destination file, relative to the same directory.
        password: encryption password.
        download_url: optional. It is accepted for backward compatibility but
            not used by this endpoint.

    Both paths must stay inside the model directory. Absolute paths and '..'
    escapes are rejected with 400, so a client cannot read or overwrite
    arbitrary files on the server.
    """
    try:
        config = get_default_config()
        # silent=True: a missing/non-JSON body becomes "missing parameters" (400).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')

        if not all([model_path, output_path, password]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        if not isinstance(model_path, str) or not isinstance(output_path, str):
            return jsonify({"status": "error", "message": "模型路径必须是字符串"}), 400

        model_dir = config['model_path']
        local_path = os.path.join(model_dir, model_path)
        output_path = os.path.join(model_dir, output_path)
        if not (_is_within_directory(model_dir, local_path)
                and _is_within_directory(model_dir, output_path)):
            return jsonify({"status": "error", "message": "模型路径非法"}), 400

        # Mandatory encryptor (module-level import).
        encryptor = MandatoryModelEncryptor()
        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型加密成功: {output_path}",
                "data": {
                    "model_hash": result.get('model_hash'),
                    "key_hash": result.get('key_hash'),
                    "output_path": result.get('output_path')
                }
            })
        return jsonify({
            "status": "error",
            "message": f"模型加密失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Check that an encryption key decrypts a stored encrypted model.

    Request JSON:
        model_path: model file name. Only its basename is used, and it is
            looked up under ``encrypted_models/``.
        encryption_key: key to verify.

    Responds 400 for missing parameters, a missing file, a file that is not
    properly encrypted, or a wrong key. On success it returns the first 16
    characters of the model hash and the original size.
    """
    try:
        # silent=True: a missing/non-JSON body becomes "missing parameters" (400).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')

        if not all([model_path, encryption_key]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        if not isinstance(model_path, str):
            return jsonify({"status": "error", "message": "模型路径必须是字符串"}), 400

        # basename() strips client-supplied directories.
        # NOTE(review): directory is hard-coded here, while list_encrypted_models
        # reads config['upload']['encrypted_models_dir'] - confirm they agree.
        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400

        encryptor = MandatoryModelEncryptor()

        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400

        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)

        if verify_result['success']:
            return jsonify({
                "status": "success",
                "message": "密钥验证成功",
                "data": {
                    "valid": True,
                    # `or ''` also covers an explicit None hash.
                    "model_hash": (verify_result.get('model_hash') or '')[:16],
                    "model_size": verify_result.get('original_size', 0)
                }
            })
        return jsonify({
            "status": "error",
            "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
            "data": {
                "valid": False,
                "error": verify_result.get('error', '未知错误')
            }
        }), 400

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate a fresh encryption key and return it with its hashes.

    Any failure is logged and reported as a 500 JSON error.
    """
    try:
        generated = MandatoryModelEncryptor().generate_secure_key()
        key_payload = {
            field: generated[field]
            for field in ('key', 'key_hash', 'short_hash')
        }
        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": key_payload
        })
    except Exception as err:
        logger.error(f"生成加密密钥失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(err)}"
        }), 500
|
||
|
||
|
||
# Stop-task route
@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop one task.

    Optional request JSON:
        force: forwarded as ``task_manager.stop_task(..., force=force)``;
            defaults to False.

    A ``timeout`` field, if sent, is ignored: ``task_manager.stop_task`` is
    called here without one. An empty or non-JSON body is treated as "no
    options".
    """
    try:
        logger.info(f"接收到停止任务请求: {task_id}")

        # silent=True: tolerate empty/non-JSON POSTs instead of raising 415.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        force = data.get('force', False)

        if task_manager.stop_task(task_id, force=force):
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}",
                "task_id": task_id,
                "stopped": True
            })

        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}",
            "task_id": task_id,
            "stopped": False
        }), 500

    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}",
            "task_id": task_id
        }), 500
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return one task's status, config, models, stats and creation time.

    Multi-model tasks only. Responds 404 when the task manager has no
    status for *task_id*.
    """
    status = task_manager.get_task_status(task_id)
    if not status:
        return jsonify({
            "status": "error",
            "message": f"任务不存在: {task_id}"
        }), 404

    # Same keys, same order as before: id/status/config, models, stats, created_at.
    enhanced_status = {field: status[field] for field in ('task_id', 'status', 'config')}
    enhanced_status['models'] = status.get('models', [])
    enhanced_status['stats'] = status['stats']
    enhanced_status['created_at'] = status['created_at']

    return jsonify({
        "status": "success",
        "data": enhanced_status
    })
|
||
|
||
|
||
def _default_performance():
    """Return fresh zeroed performance metrics for tasks that report none."""
    return {
        'fps': 0,
        'avg_process_time': 0,
        'latency': 0,
        'last_fps': 0
    }


def _summarize_task(task):
    """Project one ``task_manager.get_all_tasks()`` record onto the list view.

    ``algoInstancesName``/``uavType``/``push_url``/``enable_push`` are read
    with defaults. create_task only overrides the first two when the request
    supplies them, so a task lacking one must not break the whole listing.
    """
    task_config = task['config']
    return {
        'task_id': task['task_id'],
        'status': task['status'],
        'algoInstancesName': task_config.get('algoInstancesName', ''),
        'uavType': task_config.get('uavType', ''),
        'config': {
            'rtmp_url': task_config['rtmp_url'],
            'taskname': task_config['taskname'],
            'push_url': task_config.get('push_url', ''),
            'enable_push': task_config.get('enable_push', False)
        },
        'models': task.get('models', []),
        'stats': task['stats'],
        'performance': task.get('performance', _default_performance()),
        'created_at': task['created_at']
    }


@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List all tasks (multi-model only) with totals and model count."""
    tasks = task_manager.get_all_tasks()
    enhanced_tasks = [_summarize_task(task) for task in tasks]

    return jsonify({
        "status": "success",
        "data": {
            "tasks": enhanced_tasks,
            "total": len(tasks),
            "active": task_manager.get_active_tasks_count(),
            "models_count": sum(len(t['models']) for t in enhanced_tasks)
        }
    })
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release the resources held by one task; 500 when cleanup fails."""
    if not task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "error",
            "message": f"清理任务失败: {task_id}"
        }), 500
    return jsonify({
        "status": "success",
        "message": f"任务资源已清理: {task_id}"
    })
|
||
|
||
|
||
@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Clean up every task and report how many were cleaned."""
    count = task_manager.cleanup_all_tasks()
    payload = {
        "status": "success",
        "message": f"已清理 {count} 个任务",
        "cleaned_count": count
    }
    return jsonify(payload)
|
||
|
||
|
||
@app.route('/api/tasks/cleanup_stopped', methods=['POST'])
def cleanup_stopped_tasks():
    """Clean up every stopped task and report how many were cleaned."""
    count = task_manager.cleanup_stopped_tasks()
    payload = {
        "status": "success",
        "message": f"已清理 {count} 个已停止的任务",
        "cleaned_count": count
    }
    return jsonify(payload)
|
||
|
||
|
||
def _collect_gpu_info():
    """Describe available GPUs, or return None when none can be reported.

    Prefers GPUtil. If GPUtil cannot be imported, it falls back to torch,
    which gives name and memory only. The torch ``memory_used`` value is
    ``torch.cuda.memory_allocated``, i.e. tensor memory held by this process.
    """
    try:
        import GPUtil
        return [
            {
                "id": gpu.id,
                "name": gpu.name,
                "load": gpu.load * 100,
                "memory_used": gpu.memoryUsed,
                "memory_total": gpu.memoryTotal,
                "temperature": gpu.temperature
            }
            for gpu in GPUtil.getGPUs()
        ]
    except ImportError:
        if not torch.cuda.is_available():
            return None
        return [
            {
                "id": idx,
                "name": torch.cuda.get_device_name(idx),
                "memory_used": torch.cuda.memory_allocated(idx) / 1024 ** 2,
                "memory_total": torch.cuda.get_device_properties(idx).total_memory / 1024 ** 2
            }
            for idx in range(torch.cuda.device_count())
        ]


@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Report CPU/memory/disk usage, task counters and, if possible, GPUs."""
    try:
        import psutil

        system_info = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks),
            "max_concurrent_tasks": task_manager.get_current_max_tasks()
        }

        gpus = _collect_gpu_info()
        if gpus is not None:
            system_info["gpus"] = gpus

        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as err:
        logger.error(f"获取系统状态失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(err)}"
        }), 500
|
||
|
||
|
||
# WebSocket events - messages are addressed per task ID
@socketio.on('connect')
def handle_connect():
    """Log every newly connected Socket.IO client by session id."""
    client_sid = request.sid
    logger.info(f"Socket客户端已连接: {client_sid}")
|
||
|
||
|
||
@socketio.on('disconnect')
def handle_disconnect():
    """Log every Socket.IO client disconnect by session id."""
    client_sid = request.sid
    logger.info(f"Socket客户端断开: {client_sid}")
|
||
|
||
|
||
@socketio.on('subscribe_task')
def handle_subscribe_task(data=None):
    """Acknowledge a client's request to follow one task's messages.

    Args:
        data: event payload; expected to be an object with ``task_id``.
            A missing or non-object payload is treated as "no task_id".

    Returns:
        dict: Socket.IO ack. It is ``{"status": "subscribed", "task_id": ...}``
        on success and an error dict otherwise. The subscription is only
        logged; no room is joined here.
    """
    # Clients may emit this event with no payload or a non-object payload;
    # avoid AttributeError/TypeError and answer with the normal error ack.
    task_id = data.get('task_id') if isinstance(data, dict) else None
    if not task_id:
        return {"status": "error", "message": "需要提供task_id"}

    logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
    return {"status": "subscribed", "task_id": task_id}
|
||
|
||
|
||
@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """Read (GET) or update (POST) the resource limits.

    GET returns the registered resource monitor's ``resource_limits``. When no
    monitor is in global data, it returns
    ``get_default_config()['resource_limits']``.

    POST merges a non-empty JSON object into both the monitor's and the task
    manager's ``resource_limits``. It responds 400 when no monitor is
    registered or the body is not a non-empty JSON object.
    """
    # Flask adds HEAD automatically to GET routes; answer it like GET so the
    # view never falls through without returning a response.
    if request.method in ('GET', 'HEAD'):
        resource_monitor = gd.get_value('resource_monitor')
        if resource_monitor:
            return jsonify({
                "status": "success",
                "data": resource_monitor.resource_limits
            })
        config = get_default_config()
        return jsonify({
            "status": "success",
            "message": "资源监控器未初始化",
            "data": config['resource_limits']
        })

    # POST: update the limits.
    try:
        # silent=True: a non-JSON body becomes None (-> 400) instead of a 415
        # exception reported as 500; only a dict can be merged via update().
        data = request.get_json(silent=True)
        resource_monitor = gd.get_value('resource_monitor', None)

        if not (resource_monitor and isinstance(data, dict) and data):
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化或请求数据无效"
            }), 400

        resource_monitor.resource_limits.update(data)
        # Keep the task manager's copy of the limits in sync.
        task_manager.resource_limits.update(data)

        logger.info(f"更新资源限制: {data}")
        return jsonify({
            "status": "success",
            "message": "资源限制更新成功"
        })
    except Exception as e:
        logger.error(f"更新资源限制失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"更新失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """List the model configs of one task.

    Responds 404 when the task manager has no status for *task_id*. Each
    model's ``name`` is its file name up to the first '.'.
    """
    try:
        if not task_manager.get_task_status(task_id):
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        models_info = []
        if task_id in task_manager.tasks:
            task_config = task_manager.tasks[task_id].get('config', {})
            models_info = [
                {
                    'id': position,
                    'name': os.path.basename(entry.get('path', '')).split('.')[0],
                    'path': entry.get('path'),
                    'conf_thres': entry.get('conf_thres'),
                    'tags': entry.get('tags', {}),
                    'color': entry.get('color'),
                    'enabled': entry.get('enabled', True)
                }
                for position, entry in enumerate(task_config.get('models', []))
            ]

        return jsonify({
            "status": "success",
            "data": {
                "task_id": task_id,
                "models": models_info
            }
        })
    except Exception as err:
        logger.error(f"获取任务模型失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(err)}"
        }), 500
|
||
|
||
|
||
# Stream-status routes
@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Combine a task's status with its push-stream info; 404 if unknown."""
    try:
        from task_stream_manager import task_stream_manager

        task_status = task_manager.get_task_status(task_id)
        if not task_status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        all_streams = task_stream_manager.get_all_task_streams_info()
        task_config = task_status['config']
        merged = {
            "task_id": task_id,
            "task_status": task_status['status'],
            "stream_enabled": task_config.get('enable_push', False),
            "stream_info": all_streams.get(task_id, {}),
            "push_url": task_config.get('push_url', '')
        }
        return jsonify({
            "status": "success",
            "data": merged
        })
    except Exception as err:
        logger.error(f"获取任务推流状态失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(err)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Return push-stream info for all tasks with total and running counts."""
    try:
        from task_stream_manager import task_stream_manager

        streams = task_stream_manager.get_all_task_streams_info()
        running = len([info for info in streams.values() if info.get('running', False)])
        return jsonify({
            "status": "success",
            "data": {
                "total_streams": len(streams),
                "active_streams": running,
                "streams": streams
            }
        })
    except Exception as err:
        logger.error(f"获取所有推流信息失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(err)}"
        }), 500
|
||
|
||
|
||
# Upload manager singleton, created lazily by init_upload_manager().
upload_manager = None


def init_upload_manager(config):
    """Return the process-wide upload manager, creating it on first use.

    Args:
        config: full application config. Only ``config['upload']`` is read,
            and only when the manager has not been created yet.

    Returns:
        The object produced by ``get_upload_manager(config['upload'])``.
    """
    global upload_manager
    # Explicit None check (PEP 8): a manager object that happens to be falsy
    # must not be rebuilt on every call.
    if upload_manager is None:
        upload_manager = get_upload_manager(config['upload'])
    return upload_manager
|
||
|
||
|
||
# Model upload routes
@app.route('/api/models/upload/start', methods=['POST'])
def start_model_upload():
    """Open a chunked upload session for a model file.

    Request JSON:
        filename (required): name of the file being uploaded.
        total_size (required): file size in bytes; must be a positive integer.
        encryption_key (optional, recommended).

    Returns the session id, total chunk count and chunk size
    (``config['upload']['chunk_size']``, default 5 MiB).
    """
    try:
        # silent=True: a missing/non-JSON body becomes None (-> 400), not 415/500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "请求数据不能为空"}), 400
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求数据必须是JSON对象"}), 400

        filename = data.get('filename')
        total_size = data.get('total_size')
        encryption_key = data.get('encryption_key')  # optional, but recommended

        if not all([filename, total_size]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # The chunk count is derived from total_size, so it must be a real
        # positive byte count rather than an arbitrary JSON value.
        try:
            total_size = int(total_size)
        except (TypeError, ValueError):
            total_size = 0
        if total_size <= 0:
            return jsonify({"status": "error", "message": "total_size必须是正整数"}), 400

        config = get_default_config()
        upload_mgr = init_upload_manager(config)
        chunk_size = config['upload'].get('chunk_size', 5 * 1024 * 1024)

        result = upload_mgr.create_upload_session(
            filename=filename,
            total_size=total_size,
            chunk_size=chunk_size,
            encryption_key=encryption_key
        )

        if result['success']:
            return jsonify({
                "status": "success",
                "message": "上传会话创建成功",
                "data": {
                    "session_id": result['session_id'],
                    "total_chunks": result['total_chunks'],
                    "chunk_size": chunk_size
                }
            })
        return jsonify({
            "status": "error",
            "message": f"创建上传会话失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"开始上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"开始上传失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/upload/chunk', methods=['POST'])
def upload_model_chunk():
    """Receive one chunk of an upload session.

    Form fields:
        session_id (required): session returned by /api/models/upload/start.
        chunk_index: non-negative integer, default 0.
        chunk (file part, required): the chunk bytes.
    """
    try:
        session_id = request.form.get('session_id')
        # A malformed or negative index is a client error (400), not a 500.
        try:
            chunk_index = int(request.form.get('chunk_index', 0))
        except (TypeError, ValueError):
            chunk_index = -1
        if chunk_index < 0:
            return jsonify({"status": "error", "message": "chunk_index必须是非负整数"}), 400

        if not session_id:
            return jsonify({"status": "error", "message": "缺少session_id"}), 400

        if 'chunk' not in request.files:
            return jsonify({"status": "error", "message": "未找到文件分片"}), 400
        chunk_data = request.files['chunk'].read()

        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        result = upload_mgr.upload_chunk(session_id, chunk_index, chunk_data)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": "分片上传成功",
                "data": {
                    "progress": result['progress'],
                    "received_chunks": result['received_chunks'],
                    "total_chunks": result['total_chunks'],
                    "chunk_index": chunk_index
                }
            })
        return jsonify({
            "status": "error",
            "message": f"分片上传失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"上传分片失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"上传分片失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/upload/status/<session_id>', methods=['GET'])
def get_upload_status(session_id):
    """Return an upload session's status; 404 when the manager reports failure."""
    try:
        manager = init_upload_manager(get_default_config())
        outcome = manager.get_upload_status(session_id)

        if not outcome['success']:
            return jsonify({
                "status": "error",
                "message": outcome.get('error', '获取状态失败')
            }), 404
        return jsonify({
            "status": "success",
            "data": outcome['data']
        })
    except Exception as err:
        logger.error(f"获取上传状态失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"获取上传状态失败: {str(err)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/upload/cancel/<session_id>', methods=['POST'])
def cancel_upload(session_id):
    """Cancel an upload session via the upload manager's ``cancel_upload``.

    Responds 501 when the upload manager provides no ``cancel_upload``
    method, so clients are never told an upload was cancelled when nothing
    happened.
    """
    try:
        config = get_default_config()
        upload_mgr = init_upload_manager(config)

        cancel = getattr(upload_mgr, 'cancel_upload', None)
        if not callable(cancel):
            return jsonify({
                "status": "error",
                "message": f"上传管理器不支持取消上传: {session_id}"
            }), 501

        result = cancel(session_id)
        # The other upload-manager calls in this file return
        # {'success': bool, 'error': str}; presumably cancel_upload does too
        # (TODO confirm). Non-dict results are treated as success.
        if isinstance(result, dict) and not result.get('success', True):
            return jsonify({
                "status": "error",
                "message": f"取消上传失败: {result.get('error', '未知错误')}"
            }), 400

        return jsonify({
            "status": "success",
            "message": f"上传已取消: {session_id}"
        })

    except Exception as e:
        logger.error(f"取消上传失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"取消上传失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/list', methods=['GET'])
def list_encrypted_models():
    """List the ``.enc`` files in ``config['upload']['encrypted_models_dir']``.

    A missing directory yields an empty list (without ``encrypted_dir``).
    Each entry carries filename, joined path, size in bytes and mtime.
    """
    try:
        encrypted_dir = get_default_config()['upload']['encrypted_models_dir']

        if not os.path.exists(encrypted_dir):
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })

        def describe(entry_name):
            entry_path = os.path.join(encrypted_dir, entry_name)
            info = os.stat(entry_path)
            return {
                'filename': entry_name,
                'path': entry_path,
                'size': info.st_size,
                'modified': info.st_mtime,
                'encrypted': True
            }

        found = [
            describe(entry_name)
            for entry_name in os.listdir(encrypted_dir)
            if entry_name.endswith('.enc')
        ]

        return jsonify({
            "status": "success",
            "data": {
                "models": found,
                "total": len(found),
                "encrypted_dir": encrypted_dir
            }
        })
    except Exception as err:
        logger.error(f"列出模型失败: {str(err)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(err)}"
        }), 500
|
||
|
||
|
||
# Upload page route
@app.route('/model_upload')
def model_upload_page():
    """Render the model-upload page."""
    page = "model_upload.html"
    return render_template(page)
|
||
|
||
|
||
# server.py - 添加以下路由
|
||
|
||
from mandatory_model_crypto import ModelEncryptionService, validate_models_before_task, verify_single_model_api
|
||
|
||
|
||
# Lifetime of an upload token issued by start_model_processing, in seconds.
_UPLOAD_TOKEN_TTL_SECONDS = 3600


@app.route('/api/models/process/start', methods=['POST'])
def start_model_processing():
    """Start the model-processing flow: settle on a key, then issue an upload token.

    Request JSON (one of):
        encryption_key: client-supplied key; must pass the strength check.
        generate_key: truthy to let the server generate a key when the
            client supplies none.

    Returns key info (including the key itself), an upload token, and its
    expiry timestamp. The session is stored in global data under
    ``upload_sessions``.
    """
    try:
        # silent=True: a missing/non-JSON body falls through to the
        # "provide a key" 400 below instead of a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        # Option 1: client supplies its own key.
        client_key = data.get('encryption_key')
        # Option 2: server generates a key.
        generate_new = data.get('generate_key', False)

        if client_key and not isinstance(client_key, str):
            return jsonify({
                "status": "error",
                "message": "加密密钥必须是字符串"
            }), 400

        response_data = {}

        if generate_new and not client_key:
            key_info = ModelEncryptionService.generate_secure_key()
            response_data['key_info'] = {
                'key': key_info['key'],
                'key_hash': key_info['key_hash'],
                'short_hash': key_info['short_hash'],
                'generated_by': 'server'
            }
        elif client_key:
            key_valid, key_msg = ModelEncryptionService.validate_key_strength(client_key)
            if not key_valid:
                return jsonify({
                    "status": "error",
                    "message": f"密钥强度不足: {key_msg}"
                }), 400

            key_hash = hashlib.sha256(client_key.encode()).hexdigest()
            response_data['key_info'] = {
                'key': client_key,
                'key_hash': key_hash,
                'short_hash': key_hash[:16],
                'generated_by': 'client'
            }
        else:
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥或选择生成新密钥"
            }), 400

        # Issue an upload token.
        now = time.time()
        upload_token = secrets.token_urlsafe(32)
        token_expires = now + _UPLOAD_TOKEN_TTL_SECONDS
        response_data['upload_token'] = upload_token
        response_data['token_expires'] = token_expires

        # Simplified in-memory session store (production should use a DB).
        # NOTE(review): this keeps the plaintext key in process memory, and
        # nothing visible here evicts expired sessions - confirm elsewhere.
        upload_sessions = gd.get_or_create_dict('upload_sessions')
        upload_sessions[upload_token] = {
            'key_info': response_data['key_info'],
            'created_at': now,
            'expires_at': token_expires,
            'status': 'pending'
        }

        return jsonify({
            "status": "success",
            "message": "模型处理流程已启动",
            "data": response_data
        })

    except Exception as e:
        logger.error(f"启动模型处理流程失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"启动流程失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/process/verify_key', methods=['POST'])
def verify_encryption_key():
    """Check an encryption key's strength and optionally test it on a model.

    Request JSON:
        encryption_key (required): string key to check.
        model_path (optional): when given, ``verify_single_model_api`` is run
            and its result is returned under ``model_verification``.

    Responds 400 for a missing, non-string or weak key.
    """
    try:
        # silent=True: a missing/non-JSON body becomes "please provide a key" (400).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        encryption_key = data.get('encryption_key')
        model_path = data.get('model_path')  # optional, for a concrete model

        if not encryption_key:
            return jsonify({
                "status": "error",
                "message": "请提供加密密钥"
            }), 400
        if not isinstance(encryption_key, str):
            return jsonify({
                "status": "error",
                "message": "加密密钥必须是字符串"
            }), 400

        key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
        if not key_valid:
            return jsonify({
                "status": "error",
                "message": key_msg,
                "data": {
                    "valid": False,
                    "strength": "weak"
                }
            }), 400

        response_data = {
            "valid": True,
            "strength": "strong",
            "key_hash": hashlib.sha256(encryption_key.encode()).hexdigest()[:16]
        }

        # NOTE(review): the response stays "success" whatever the model
        # verification result is; clients must inspect model_verification.
        if model_path:
            response_data['model_verification'] = verify_single_model_api(model_path, encryption_key)

        return jsonify({
            "status": "success",
            "message": "密钥验证成功",
            "data": response_data
        })

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/process/validate_task', methods=['POST'])
def validate_task_models():
    """Validate the models of a task before the task is created.

    Request JSON:
        models (list, optional): model descriptors passed to
            ``validate_models_before_task`` as ``{'models': [...]}``.

    Returns 200 with the validation result when it reports ``success``;
    400 when the body is missing, empty or not a JSON object, or validation
    fails; 500 on unexpected errors.
    """
    try:
        # get_json(silent=True) yields None instead of raising for a
        # missing/invalid body; a non-object body would crash on .get().
        data = request.get_json(silent=True)

        if not isinstance(data, dict) or not data:
            return jsonify({
                "status": "error",
                "message": "请求数据不能为空"
            }), 400

        # Extract the task configuration.
        task_config = {
            'models': data.get('models', [])
        }

        # Validate all models.
        validation_result = validate_models_before_task(task_config)

        if validation_result['success']:
            # .get() so a result without counts cannot turn success into a 500.
            valid_models = validation_result.get('valid_models', 0)
            total_models = validation_result.get('total_models', 0)
            return jsonify({
                "status": "success",
                "message": f"模型验证通过 ({valid_models}/{total_models})",
                "data": validation_result
            })

        return jsonify({
            "status": "error",
            "message": f"模型验证失败: {validation_result.get('error', '未知错误')}",
            "data": validation_result
        }), 400

    except Exception as e:
        logger.error(f"验证任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/encrypted/list_available', methods=['GET'])
def list_available_encrypted_models():
    """List ``*.enc`` model files in the configured encrypted-models directory.

    Used by task creation to offer model choices. Each entry carries file
    name, relative path, size, mtime, and, when the file can be unpickled as
    a dict, the first 16 chars of ``model_hash``, ``version`` and
    ``original_size``. Unreadable files are still listed, with an ``error``
    field. Entries are sorted newest first. A missing directory yields an
    empty list; unexpected failures return 500.
    """
    try:
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']

        # isdir (not exists): a plain file at this path would make listdir raise.
        if not os.path.isdir(encrypted_dir):
            return jsonify({
                "status": "success",
                "data": {
                    "models": [],
                    "total": 0
                }
            })

        models = []
        for filename in os.listdir(encrypted_dir):
            if not filename.endswith('.enc'):
                continue

            filepath = os.path.join(encrypted_dir, filename)
            # A file can disappear between listdir() and stat(); skip it
            # instead of failing the whole listing.
            try:
                stats = os.stat(filepath)
            except OSError as e:
                logger.warning(f"读取模型信息失败 {filename}: {str(e)}")
                continue

            base_info = {
                'filename': filename,
                'path': f"encrypted_models/{filename}",  # relative path
                'size': stats.st_size,
                'modified': stats.st_mtime,
                'encrypted': True,
            }

            # Read basic model metadata (the key is not verified here).
            try:
                # SECURITY NOTE(review): pickle.load runs arbitrary code embedded
                # in the file. If .enc files can come from uploads, a malicious
                # upload executes on the server here. Prefer a non-pickle
                # metadata format (e.g. a JSON header or sidecar). This also
                # loads the whole payload just to read a few fields.
                with open(filepath, 'rb') as f:
                    encrypted_data = pickle.load(f)

                if isinstance(encrypted_data, dict):
                    # `or ''` guards against an explicit None hash.
                    model_hash = (encrypted_data.get('model_hash') or '')[:16]
                    version = encrypted_data.get('version', 'unknown')
                    original_size = encrypted_data.get('original_size', 0)
                else:
                    model_hash, version, original_size = '', '', 0

                models.append({
                    **base_info,
                    'model_hash': model_hash,
                    'version': version,
                    'original_size': original_size
                })

            except Exception as e:
                logger.warning(f"读取模型信息失败 {filename}: {str(e)}")
                models.append({
                    **base_info,
                    'error': '无法读取模型信息'
                })

        # Newest first.
        models.sort(key=lambda x: x['modified'], reverse=True)

        return jsonify({
            "status": "success",
            "data": {
                "models": models,
                "total": len(models),
                "directory": encrypted_dir
            }
        })

    except Exception as e:
        logger.error(f"列出可用模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"列出模型失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/process/test_decrypt', methods=['POST'])
def test_model_decryption():
    """Test-decrypt an encrypted model file without loading it as a YOLO model.

    Request JSON:
        model_path (str): a relative path (only its base name is used, inside
            the configured ``encrypted_models_dir``) or an absolute path,
            which must resolve inside that same directory.
        encryption_key (str): key passed to
            ``MandatoryModelValidator.decrypt_and_verify``.

    Returns 200 with hash prefix, size, timing and path on success; 400 for
    missing/invalid input, a path outside the directory, or a failed
    decryption; 404 when the file does not exist; 500 on unexpected errors.
    """
    try:
        # Treat a missing/non-JSON body as an empty request -> 400 below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')

        # os.path functions below raise TypeError on non-string paths.
        if not all([model_path, encryption_key]) or not isinstance(model_path, str):
            return jsonify({
                "status": "error",
                "message": "请提供模型路径和加密密钥"
            }), 400

        # Build the full path.
        config = get_default_config()
        encrypted_dir = config['upload']['encrypted_models_dir']

        if not os.path.isabs(model_path):
            # Only the file name is honoured for relative paths.
            full_model_path = os.path.join(encrypted_dir, os.path.basename(model_path))
        else:
            full_model_path = model_path

        # The file is handed to the decrypt routine (which presumably unpickles
        # it; .enc files are read with pickle elsewhere in this file), so an
        # absolute path outside encrypted_dir must not be accepted.
        base_dir = os.path.realpath(encrypted_dir)
        resolved_path = os.path.realpath(full_model_path)
        try:
            inside_dir = os.path.commonpath([base_dir, resolved_path]) == base_dir
        except ValueError:
            # Raised e.g. for paths on different Windows drives.
            inside_dir = False
        if not inside_dir:
            return jsonify({
                "status": "error",
                "message": "模型路径必须位于加密模型目录内"
            }), 400

        # isfile (not exists) so a directory is reported as missing, not decrypted.
        if not os.path.isfile(full_model_path):
            return jsonify({
                "status": "error",
                "message": f"模型文件不存在: {full_model_path}"
            }), 404

        # Test decryption.
        from mandatory_model_crypto import MandatoryModelValidator
        validator = MandatoryModelValidator()

        start_time = time.time()
        decrypt_result = validator.decrypt_and_verify(full_model_path, encryption_key)
        elapsed_time = time.time() - start_time

        if decrypt_result['success']:
            return jsonify({
                "status": "success",
                "message": "解密测试成功",
                "data": {
                    'success': True,
                    # `or ''` guards against an explicit None hash.
                    'model_hash': (decrypt_result.get('model_hash') or '')[:16],
                    'model_size': decrypt_result.get('original_size', 0),
                    'decryption_time': elapsed_time,
                    'file_path': full_model_path
                }
            })

        return jsonify({
            "status": "error",
            "message": f"解密测试失败: {decrypt_result.get('error', '未知错误')}",
            "data": {
                'success': False,
                'error': decrypt_result.get('error', '未知错误'),
                'decryption_time': elapsed_time
            }
        }), 400

    except Exception as e:
        logger.error(f"测试解密失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"测试失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# server.py - route serving the task creation page
@app.route('/task_create')
def task_create_page():
    """Render and return the task creation HTML page."""
    page_template = "task_create.html"
    return render_template(page_template)
|
||
|
||
|
||
# server.py - resource check endpoint
def _query_gpu_memory_mb(device_index):
    """Return ``(used_mb, total_mb)`` for one CUDA device.

    Prefers ``torch.cuda.mem_get_info``, which reports device-wide free/total
    memory, so memory held by other processes and by PyTorch's caching
    allocator counts as used. Falls back to ``torch.cuda.memory_allocated``
    (only tensors allocated by this process) when ``mem_get_info`` is
    unavailable in this torch build or raises ``RuntimeError``.

    NOTE(review): ``mem_get_info`` may need a CUDA context on the device and
    create one if this process has none there yet — confirm that is acceptable.
    """
    mem_get_info = getattr(torch.cuda, 'mem_get_info', None)
    if mem_get_info is not None:
        try:
            free_bytes, total_bytes = mem_get_info(device_index)
            return (total_bytes - free_bytes) / 1024 ** 2, total_bytes / 1024 ** 2
        except RuntimeError:
            # Deliberate fallback to the process-local figure below.
            pass
    used_mb = torch.cuda.memory_allocated(device_index) / 1024 ** 2
    total_mb = torch.cuda.get_device_properties(device_index).total_memory / 1024 ** 2
    return used_mb, total_mb


@app.route('/api/system/check_resources', methods=['GET'])
def check_system_resources():
    """Report whether system resources allow creating another task.

    ``resources_available`` becomes false when CPU usage exceeds 80%, memory
    usage exceeds 85%, or the active task count has reached the task
    manager's limit. GPU memory above 90% only adds a warning. GPU figures
    are in MB. Returns 500 on unexpected errors (e.g. psutil not installed).
    """
    try:
        import psutil  # imported here rather than at module level

        # Host resources.
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory_info = psutil.virtual_memory()
        memory_percent = memory_info.percent

        # GPU resources.
        gpu_info = []
        gpu_available = torch.cuda.is_available()
        if gpu_available:
            for i in range(torch.cuda.device_count()):
                gpu_memory_used, gpu_memory_total = _query_gpu_memory_mb(i)
                gpu_memory_percent = (gpu_memory_used / gpu_memory_total) * 100 if gpu_memory_total > 0 else 0
                gpu_info.append({
                    'id': i,
                    'name': torch.cuda.get_device_name(i),
                    'memory_used': gpu_memory_used,
                    'memory_total': gpu_memory_total,
                    'memory_percent': gpu_memory_percent
                })

        # Current task load.
        active_tasks = task_manager.get_active_tasks_count()
        max_tasks = task_manager.get_current_max_tasks()

        # Thresholds (percent).
        cpu_threshold = 80
        memory_threshold = 85
        gpu_memory_threshold = 90

        resources_ok = True
        warnings = []

        if cpu_percent > cpu_threshold:
            resources_ok = False
            warnings.append(f"CPU使用率过高: {cpu_percent:.1f}% > {cpu_threshold}%")

        if memory_percent > memory_threshold:
            resources_ok = False
            warnings.append(f"内存使用率过高: {memory_percent:.1f}% > {memory_threshold}%")

        if gpu_available and gpu_info:
            max_gpu_memory = max(gpu['memory_percent'] for gpu in gpu_info)
            if max_gpu_memory > gpu_memory_threshold:
                # High GPU memory is only a warning, not a blocker.
                warnings.append(f"GPU内存使用率过高: {max_gpu_memory:.1f}% > {gpu_memory_threshold}%")

        if active_tasks >= max_tasks:
            resources_ok = False
            warnings.append(f"任务数达到上限: {active_tasks}/{max_tasks}")

        return jsonify({
            "status": "success",
            "data": {
                "resources_available": resources_ok,
                "active_tasks": active_tasks,
                "max_tasks": max_tasks,
                "slots_available": max(0, max_tasks - active_tasks),
                "cpu_percent": cpu_percent,
                "memory_percent": memory_percent,
                "memory_used": memory_info.used / 1024 ** 2,  # MB
                "memory_total": memory_info.total / 1024 ** 2,  # MB
                "gpu_available": gpu_available,
                "gpu_info": gpu_info,
                "warnings": warnings,
                "thresholds": {
                    "cpu": cpu_threshold,
                    "memory": memory_threshold,
                    "gpu_memory": gpu_memory_threshold
                },
                "timestamp": time.time()
            }
        })

    except Exception as e:
        logger.error(f"检查系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"检查资源失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# Initialization function; can be called from the main program.
def init_app():
    """Register the task manager in global data and return the Flask app.

    Also sets the module-level ``_initialized`` flag, the same flag the
    ``initialize_once`` before_request hook checks. Without it, that hook
    would repeat the registration and log line on the first request.
    """
    global _initialized
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    _initialized = True
    return app
|
||
|
||
# ======================= Public server features =======================

# State of the internal system: whether its socket is connected, and the
# latest status message received for each device id.
internal_system_status = dict(is_connected=False, devices={})

# Connection registry: logical client id -> Socket.IO session id.
connected_clients = dict()
|
||
|
||
|
||
# WebSocket events - public server
@socketio.on('connect', namespace='/ws/internal')
def handle_internal_connect():
    """Register the connecting session as the internal system and mark it online.

    NOTE(review): no authentication is visible here; any peer that opens the
    /ws/internal namespace becomes "internal_system". Confirm access is
    restricted elsewhere (proxy, network policy).
    """
    session_id = request.sid
    connected_clients["internal_system"] = session_id
    internal_system_status["is_connected"] = True
    pub_logger.info(f"Internal system connected: {session_id}")
|
||
|
||
|
||
@socketio.on('disconnect', namespace='/ws/internal')
def handle_internal_disconnect():
    """Clear the internal-system registration when its session disconnects.

    If a newer internal connection has already replaced this session in
    ``connected_clients``, this (stale) session's disconnect leaves that
    registration and the connected flag untouched. Otherwise the entry is
    removed and the system is marked disconnected.
    """
    client_id = "internal_system"
    registered_sid = connected_clients.get(client_id)

    if registered_sid is not None and registered_sid != request.sid:
        pub_logger.info(f"Stale internal system session disconnected: {request.sid}")
        return

    connected_clients.pop(client_id, None)
    internal_system_status["is_connected"] = False
    pub_logger.info("Internal system disconnected")
|
||
|
||
|
||
@socketio.on('message', namespace='/ws/internal')
def handle_internal_message(data):
    """Handle a message from the internal system.

    ``data`` is normally a JSON string. An already-decoded dict (sent by
    Socket.IO peers that emit objects) is accepted as well.

    - ``status_update`` with a truthy ``device_id``: cache the message as that
      device's status and broadcast it (JSON-encoded) on ``/ws/client``.
    - ``command_response``: forward it (JSON-encoded) on ``/ws/client``.
    - Other types are ignored.

    Invalid JSON and non-object payloads are logged and dropped.
    """
    if isinstance(data, dict):
        message = data
    else:
        try:
            message = json.loads(data)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and bad UTF-8 in bytes;
            # TypeError covers payloads json.loads cannot take (e.g. None).
            pub_logger.error(f"Invalid JSON from internal system: {data}")
            return

    # A JSON array/scalar would otherwise crash on .get() below.
    if not isinstance(message, dict):
        pub_logger.error(f"Unexpected message format from internal system: {data}")
        return

    message_type = message.get("type")
    if message_type == "status_update":
        # Update the device status.
        device_id = message.get("device_id")
        if device_id:
            internal_system_status["devices"][device_id] = message
            # Broadcast the status update to clients.
            socketio.emit('message', json.dumps(message), namespace='/ws/client')
            pub_logger.info(f"Status updated for device {device_id}: {message}")
    elif message_type == "command_response":
        # Forward the command response to clients.
        socketio.emit('message', json.dumps(message), namespace='/ws/client')
        pub_logger.info(f"Command response: {message}")
|
||
|
||
|
||
@socketio.on('connect', namespace='/ws/client')
def handle_client_connect():
    """Register a client session and replay cached device statuses to it.

    The optional ``client_id`` query parameter names the connection; when it
    is absent or empty, the session id is used. The reserved id
    "internal_system" is refused and replaced by the session id. Otherwise a
    client could overwrite the internal system's entry in
    ``connected_clients`` and redirect forwarded commands.
    """
    client_id = request.args.get('client_id') or request.sid
    if client_id == "internal_system":
        pub_logger.warning(
            f"Client {request.sid} requested reserved id 'internal_system'; using session id instead"
        )
        client_id = request.sid

    connected_clients[client_id] = request.sid
    pub_logger.info(f"Client connected: {client_id} ({request.sid})")

    # Send the current device statuses to the new client. Iterate a snapshot:
    # the server runs in threading mode, so other handlers may update the
    # device dict concurrently.
    for status in list(internal_system_status["devices"].values()):
        socketio.emit('message', json.dumps(status), namespace='/ws/client', room=request.sid)
|
||
|
||
|
||
@socketio.on('disconnect', namespace='/ws/client')
def handle_client_disconnect():
    """Remove the disconnecting client's entry from ``connected_clients``.

    The logical client id is found by matching the session id. A snapshot of
    the registry is scanned because other handlers may mutate it
    concurrently (threading mode), and ``pop`` tolerates the entry already
    being gone.
    """
    client_id = next(
        (cid for cid, sid in list(connected_clients.items()) if sid == request.sid),
        None,
    )
    if client_id:
        connected_clients.pop(client_id, None)
        pub_logger.info(f"Client disconnected: {client_id}")
|
||
|
||
|
||
@socketio.on('message', namespace='/ws/client')
def handle_client_message(data):
    """Forward a client command to the internal system.

    ``data`` is normally a JSON string, which is forwarded unchanged. An
    already-decoded dict is JSON-encoded before forwarding. If the internal
    system is not connected, an error ``command_response`` is sent back to
    the sender only. Invalid JSON is logged and dropped.
    """
    if isinstance(data, dict):
        command = data
        payload = json.dumps(data)
    else:
        try:
            command = json.loads(data)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and bad UTF-8 in bytes;
            # TypeError covers payloads json.loads cannot take (e.g. None).
            pub_logger.error(f"Invalid JSON from client: {data}")
            return
        payload = data

    # Read the sid once: a concurrent disconnect could otherwise remove the
    # entry between a membership test and the lookup.
    internal_sid = connected_clients.get("internal_system")
    if internal_system_status["is_connected"] and internal_sid is not None:
        socketio.emit('message', payload, namespace='/ws/internal', room=internal_sid)
        pub_logger.info(f"Forwarded command from client: {command}")
    else:
        # Internal system not connected: report the error to this client.
        error_response = {
            "type": "command_response",
            "status": "error",
            "message": "Internal system not connected",
            "command": command
        }
        socketio.emit('message', json.dumps(error_response), namespace='/ws/client', room=request.sid)
|
||
|
||
|
||
# API routes - public server
@app.route('/api/devices')
def get_devices():
    """Return the latest cached status message of every known device."""
    device_statuses = list(internal_system_status["devices"].values())
    payload = dict(code=200, msg="查询成功", data=device_statuses)
    return jsonify(payload)
|
||
|
||
|
||
@app.route('/api/command', methods=['POST'])
def send_command():
    """Forward a JSON command from an HTTP client to the internal system.

    Following this API's convention, the HTTP status is always 200 and the
    outcome is in the body's ``code``: 200 when the command was emitted,
    400 when the body is missing or not valid JSON, 503 when the internal
    system is not connected, 500 on unexpected errors.
    """
    not_connected = {
        "code": 503,
        "msg": "内部系统未连接",
        "data": None
    }

    if not internal_system_status["is_connected"]:
        return jsonify(not_connected)

    try:
        # get_json(silent=True) yields None for a missing/invalid body; without
        # this check a literal `null` command would be sent to the internal system.
        command = request.get_json(silent=True)
        if command is None:
            return jsonify({
                "code": 400,
                "msg": "请求数据不能为空",
                "data": None
            })

        # Read the sid once so a concurrent disconnect cannot remove the entry
        # between a membership test and the lookup.
        internal_sid = connected_clients.get("internal_system")
        if internal_sid is None:
            return jsonify(not_connected)

        socketio.emit('message', json.dumps(command), namespace='/ws/internal', room=internal_sid)
        pub_logger.info(f"Command sent: {command}")
        return jsonify({
            "code": 200,
            "msg": "命令已发送",
            "data": command
        })

    except Exception as e:
        pub_logger.error(f"Error sending command: {str(e)}")
        return jsonify({
            "code": 500,
            "msg": f"发送命令失败: {str(e)}",
            "data": None
        })