Yolov/server.py

808 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# server.py
import os
from datetime import datetime
import torch
from flask import Flask, jsonify, request, render_template
from flask_socketio import SocketIO
from flask_cors import CORS
from config import get_default_config
from mandatory_model_crypto import MandatoryModelEncryptor
from task_manager import task_manager # 导入任务管理器
from global_data import gd
from log import logger
import time
import traceback
# Flask initialization
app = Flask(__name__, static_url_path='/static')
# Enable CORS on every route using flask_cors' default settings.
CORS(app)
# Socket.IO server bound to the Flask app: cross-origin Socket.IO
# connections are accepted from any origin ("*"), the 'threading' async
# mode is used, and max_http_buffer_size is set to 5 MiB (5 * 1024 * 1024).
# NOTE(review): Flask-SocketIO documents `allow_unsafe_werkzeug` as an
# argument of `socketio.run()`, not of the SocketIO constructor; confirm it
# has any effect here and pass it to run() if the Werkzeug server is used.
socketio = SocketIO(app,
                    cors_allowed_origins="*",
                    async_mode='threading',
                    allow_unsafe_werkzeug=True,
                    max_http_buffer_size=5 * 1024 * 1024)
# Set to True by initialize_once() after the task manager has been
# registered in global data, so that registration happens only once.
_initialized = False
@app.before_request
def initialize_once():
    """Register the shared task manager in global data on the first request.

    Installed as a Flask ``before_request`` hook. Once the module-level
    ``_initialized`` flag is set, later requests return immediately.
    """
    global _initialized
    if _initialized:
        return
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
        _initialized = True
# ======================= Flask routes =======================
@app.route('/')
def task_management():
    """Serve the task management page."""
    page = "task_management.html"
    return render_template(page)
@app.route('/video_player')
def video_player():
    """Serve the video player page."""
    page = "flv2.html"
    return render_template(page)
class _TaskRequestError(Exception):
    """Client-input problem found while building a task; carries the HTTP status to return."""

    def __init__(self, message, status_code=400):
        super().__init__(message)
        self.message = message
        self.status_code = status_code


def _build_model_config(index, model_data, encryption_checker):
    """Validate one entry of the request's ``models`` list and return its model config.

    Args:
        index: position of the entry in ``models`` (used in messages and in
            the default path ``model_<index>.pt``).
        model_data: the raw entry taken from the request body.
        encryption_checker: MandatoryModelEncryptor used to check a local file.

    Returns:
        dict: the model configuration; always marked ``encrypted`` and
        carrying the caller's ``encryption_key``.

    Raises:
        _TaskRequestError: (400) if the entry is not an object, has no
            ``encryption_key``, has a non-string ``path``, names a local file
            that is not properly encrypted, or has non-numeric
            ``conf_thres``/``iou_thres``/``imgsz``/``line_width`` values.
    """
    if not isinstance(model_data, dict):
        raise _TaskRequestError(f"模型 {index} 配置必须是JSON对象")
    # An encryption key is mandatory for every model.
    encryption_key = model_data.get('encryption_key')
    if not encryption_key:
        raise _TaskRequestError(
            f"模型 {index} ({model_data.get('path', 'unknown')}) 必须提供encryption_key")
    model_path = model_data.get('path', f'model_{index}.pt')
    if not isinstance(model_path, str):
        raise _TaskRequestError(f"模型 {index} path必须是字符串")
    model_name = os.path.basename(model_path).split('.')[0]
    # If a file with this name exists in the current working directory it
    # must be properly encrypted. NOTE(review): encrypt_model writes under
    # config['model_path'] instead -- confirm which directory is intended.
    local_model_path = os.path.basename(model_path)
    if os.path.exists(local_model_path) and not encryption_checker.is_properly_encrypted(local_model_path):
        raise _TaskRequestError(f"模型 {index} ({model_name}) 未正确加密")
    # Bad numeric input is a client error (400), not a server failure.
    try:
        conf_thres = float(model_data.get('conf_thres', 0.25))
        iou_thres = float(model_data.get('iou_thres', 0.45))
        # Clamp the inference size to [128, 1920].
        imgsz = max(128, min(1920, int(model_data.get('imgsz', 640))))
        line_width = int(model_data.get('line_width', 1))
    except (TypeError, ValueError, OverflowError) as e:
        raise _TaskRequestError(f"模型 {index} ({model_name}) 数值参数无效: {e}") from e
    model_config = {
        'path': model_path,
        'encryption_key': encryption_key,  # mandatory
        'encrypted': True,  # encryption is always enforced
        'tags': model_data.get('tags', {}),
        'conf_thres': conf_thres,
        'iou_thres': iou_thres,
        'imgsz': imgsz,
        'color': model_data.get('color'),
        'line_width': line_width,
        'device': model_data.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu'),
        'half': model_data.get('half', True),
        'enabled': model_data.get('enabled', True),
        'download_url': model_data.get('download_url')  # optional download URL
    }
    logger.info(f"添加加密模型 {index}: {model_name}")
    return model_config


def _ensure_task_capacity():
    """Free finished tasks and verify that another task may be started.

    Stopped tasks are always cleaned up first. If the active-task count is
    still at ``task_manager.get_current_max_tasks()``, every task whose status
    is not running/starting/creating is cleaned up as well.

    Raises:
        _TaskRequestError: (503) if the limit is still reached afterwards.
    """
    # Release resources held by stopped tasks before creating a new one.
    logger.info("创建任务前清理已停止的任务...")
    cleaned_count = task_manager.cleanup_stopped_tasks()
    if cleaned_count > 0:
        logger.info(f"已清理 {cleaned_count} 个已停止的任务")
    active_tasks = task_manager.get_active_tasks_count()
    max_tasks = task_manager.get_current_max_tasks()
    if active_tasks < max_tasks:
        return
    # At the limit: force-clean every task that is not in a live state.
    logger.warning(f"当前活动任务数 {active_tasks} 已达到限制 {max_tasks},尝试强制清理")
    live_states = ('running', 'starting', 'creating')
    force_cleaned = 0
    for task in task_manager.get_all_tasks():
        if task['status'] not in live_states and task_manager.cleanup_task(task['task_id']):
            force_cleaned += 1
    if force_cleaned > 0:
        logger.info(f"强制清理了 {force_cleaned} 个非运行状态的任务")
    # Check again after the forced cleanup.
    if task_manager.get_active_tasks_count() >= max_tasks:
        raise _TaskRequestError(
            f"已达到最大并发任务数限制 ({max_tasks}),请等待其他任务完成或停止部分任务", 503)


@app.route('/api/tasks/create', methods=['POST'])
def create_task():
    """Create and start a detection task; every model must supply an encryption key.

    JSON body: ``rtmp_url`` (required), ``models`` (required, non-empty list
    of model objects -- see _build_model_config), and optional ``push_url``,
    ``taskname`` (defaults to ``task_<unix time>``) and ``AlgoId``.

    Returns:
        JSON response: 200 with ``task_id`` when the task was created and
        started; 400 for invalid input; 503 when the concurrent-task limit is
        reached; 500 when creating or starting the task fails.
    """
    try:
        config = get_default_config()
        config['socketIO'] = socketio
        # silent=True: a missing or non-JSON body yields None instead of an
        # HTTP exception that the generic handler below would turn into a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "请求体不能为空"}), 400
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求体必须是JSON对象"}), 400
        logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")
        # Validate required parameters.
        if 'rtmp_url' not in data:
            return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400
        models = data.get('models')
        if not isinstance(models, list):
            return jsonify({"status": "error", "message": "必须提供models列表"}), 400
        if not models:
            return jsonify({"status": "error", "message": "models列表不能为空"}), 400
        # Apply request values on top of the default configuration.
        config['rtmp']['url'] = data['rtmp_url']
        if data.get('push_url'):
            config['push']['url'] = data['push_url']
        if 'taskname' in data:
            config['task']['taskname'] = data['taskname']
        else:
            config['task']['taskname'] = f"task_{int(time.time())}"
        if 'AlgoId' in data:
            config['task']['aiid'] = data['AlgoId']
        # Build the per-model configuration; encryption is mandatory.
        encryption_checker = MandatoryModelEncryptor()
        config['models'] = [
            _build_model_config(i, model_data, encryption_checker)
            for i, model_data in enumerate(models)
        ]
        _ensure_task_capacity()
        # Create the task.
        logger.info(f"开始创建任务,包含 {len(config['models'])} 个加密模型...")
        try:
            task_id = task_manager.create_task(config, socketio)
            logger.info(f"任务创建成功ID: {task_id}")
        except Exception as e:
            logger.error(f"任务创建失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"任务创建失败: {str(e)}"
            }), 500
        # Start the task.
        logger.info(f"启动任务 {task_id}...")
        if task_manager.start_task(task_id):
            logger.info(f"任务启动成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": "任务创建并启动成功",
                "task_id": task_id,
                "models_count": len(config['models']),
                "encryption_required": True
            })
        logger.error(f"任务启动失败: {task_id}")
        # The task could not start; release it so it does not hold a slot.
        task_manager.cleanup_task(task_id)
        return jsonify({
            "status": "error",
            "message": "任务创建成功但启动失败",
            "task_id": task_id
        }), 500
    except _TaskRequestError as e:
        return jsonify({"status": "error", "message": e.message}), e.status_code
    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"创建任务失败: {str(e)}"
        }), 500
@app.route('/api/system/resources', methods=['GET'])
def get_system_resources():
    """Report system resource usage together with task counters.

    Uses the snapshot stored in global data under ``'system_resources'``
    when one is present; otherwise samples CPU and memory with psutil.
    Memory figures are reported in MiB.
    """
    try:
        resources = gd.get_value('system_resources')
        # NOTE(review): create_task enforces task_manager.get_current_max_tasks();
        # this reports the value stored in global data (default 5) -- confirm
        # the two are meant to agree.
        max_tasks = gd.get_value('max_concurrent_tasks', 5)
        if not resources:
            # No cached snapshot: sample live.
            import psutil
            cpu_percent = psutil.cpu_percent()
            # One memory sample so percent/used/total are mutually consistent.
            memory = psutil.virtual_memory()
            resources = {
                'cpu_percent': cpu_percent,
                'memory_percent': memory.percent,
                'memory_used': memory.used // (1024 * 1024),
                'memory_total': memory.total // (1024 * 1024),
                'timestamp': datetime.now().isoformat()
            }
        # Task statistics.
        active_tasks = task_manager.get_active_tasks_count()
        total_tasks = len(task_manager.tasks)
        return jsonify({
            "status": "success",
            "data": {
                "resources": resources,
                "tasks": {
                    "active": active_tasks,
                    "total": total_tasks,
                    "max_concurrent": max_tasks
                }
            }
        })
    except Exception as e:
        logger.error(f"获取系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统资源失败: {str(e)}"
        }), 500
def _is_within_directory(path, directory):
    """Return True if *path* resolves to *directory* itself or a location beneath it.

    Both arguments are resolved with ``os.path.realpath`` (symlinks followed),
    so ``..`` segments and absolute paths cannot escape *directory*.
    """
    base = os.path.realpath(directory)
    target = os.path.realpath(path)
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        return False


@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file kept in the configured model directory.

    JSON body (all required):
        model_path: source model, relative to ``config['model_path']``.
        output_path: destination, relative to ``config['model_path']``.
        password: encryption password.
        download_url: where to fetch the source model if it is not present locally.
    Both paths must stay inside the model directory.

    Returns:
        JSON response: 200 with hashes and the output path on success; 400
        for missing/invalid parameters or a failed download; 500 when
        encryption fails.
    """
    try:
        config = get_default_config()
        # silent=True so a missing/non-JSON body is reported as missing
        # parameters (400) instead of surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')
        download_url = data.get('download_url')
        if not all([model_path, output_path, password, download_url]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        model_dir = config['model_path']
        local_path = os.path.join(model_dir, model_path)
        output_path = os.path.join(model_dir, output_path)
        # Paths come from the client and this endpoint writes files: refuse
        # anything that escapes model_dir (absolute paths, '..' segments).
        if not (_is_within_directory(local_path, model_dir)
                and _is_within_directory(output_path, model_dir)):
            return jsonify({"status": "error", "message": "模型路径必须位于模型目录内"}), 400
        # Check the resolved source path -- the file that is encrypted below.
        if not os.path.exists(local_path):
            from model_crypto import ModelManager
            # NOTE(review): ModelManager is constructed from the request body;
            # confirm that is the config it expects.
            model_d = ModelManager(data)
            down_status = model_d.download_model({"path": model_path, "download_url": download_url})
            if not down_status:
                return jsonify({"status": "error", "message": f"模型文件不存在: {model_path}"}), 400
        # Use the mandatory encryptor (imported at module level).
        encryptor = MandatoryModelEncryptor()
        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)
        if result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型加密成功: {output_path}",
                "data": {
                    "model_hash": result.get('model_hash'),
                    "key_hash": result.get('key_hash'),
                    "output_path": result.get('output_path')
                }
            })
        return jsonify({
            "status": "error",
            "message": f"模型加密失败: {result.get('error', '未知错误')}"
        }), 500
    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500
@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Check whether ``encryption_key`` decrypts an encrypted model.

    JSON body: ``model_path`` and ``encryption_key`` (both required). Only
    the file name of ``model_path`` is used; it is looked up in the
    ``encrypted_models`` directory.

    Returns:
        JSON response: 200 with ``data.valid == True`` when the key works;
        400 for missing parameters, a missing or improperly encrypted file,
        or a wrong key; 500 on unexpected errors.
    """
    try:
        # silent=True: a missing/non-JSON body becomes "missing parameters" (400).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')
        if not all([model_path, encryption_key]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        # basename() keeps the lookup inside encrypted_models/.
        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400
        encryptor = MandatoryModelEncryptor()
        # The file must be in the mandatory encrypted format.
        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400
        # Verify the key by attempting decryption.
        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)
        if verify_result['success']:
            # `or ''` also covers an explicit None hash, which cannot be sliced.
            model_hash = verify_result.get('model_hash') or ''
            return jsonify({
                "status": "success",
                "message": "密钥验证成功",
                "data": {
                    "valid": True,
                    "model_hash": model_hash[:16],
                    "model_size": verify_result.get('original_size', 0)
                }
            })
        return jsonify({
            "status": "error",
            "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
            "data": {
                "valid": False,
                "error": verify_result.get('error', '未知错误')
            }
        }), 400
    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500
@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate a new encryption key and return it with its full and short hashes."""
    try:
        from mandatory_model_crypto import MandatoryModelEncryptor
        generated = MandatoryModelEncryptor().generate_secure_key()
        key_fields = {name: generated[name] for name in ('key', 'key_hash', 'short_hash')}
        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": key_fields,
        })
    except Exception as e:
        logger.error(f"生成加密密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(e)}"
        }), 500
# Task stop route (revised version of the former stop_detection route)
@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop the given task.

    Optional JSON body: ``force`` (default False) is forwarded to
    ``task_manager.stop_task``. A missing, non-JSON or non-object body is
    treated as empty.

    Returns:
        JSON response: 200 with ``stopped: True`` on success, 500 otherwise.
    """
    try:
        logger.info(f"接收到停止任务请求: {task_id}")
        # silent=True: `request.json` raises for bodies without a JSON content
        # type, so a plain POST with no body would otherwise end up as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        force = data.get('force', False)
        # NOTE(review): clients may send 'timeout', but task_manager.stop_task
        # is not given one, so it is ignored.
        success = task_manager.stop_task(task_id, force=force)
        if success:
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}",
                "task_id": task_id,
                "stopped": True
            })
        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}",
            "task_id": task_id,
            "stopped": False
        }), 500
    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}",
            "task_id": task_id
        }), 500
@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return one task's status record, including its model list (multi-model tasks only)."""
    record = task_manager.get_task_status(task_id)
    if not record:
        return jsonify({
            "status": "error",
            "message": f"任务不存在: {task_id}"
        }), 404
    # Expose a fixed subset of the record; 'models' falls back to an empty list.
    payload = {
        'task_id': record['task_id'],
        'status': record['status'],
        'config': record['config'],
        'models': record.get('models', []),
        'stats': record['stats'],
        'created_at': record['created_at'],
    }
    return jsonify({"status": "success", "data": payload})
@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List every task with a trimmed configuration (multi-model tasks only)."""

    def summarize(record):
        # Keep only the config fields the listing exposes.
        return {
            'task_id': record['task_id'],
            'status': record['status'],
            'config': {
                'rtmp_url': record['config']['rtmp_url'],
                'taskname': record['config']['taskname'],
                'push_url': record['config'].get('push_url', ''),
                'enable_push': record['config'].get('enable_push', False),
            },
            'models': record.get('models', []),
            'stats': record['stats'],
            'created_at': record['created_at'],
        }

    records = task_manager.get_all_tasks()
    summaries = [summarize(record) for record in records]
    return jsonify({
        "status": "success",
        "data": {
            "tasks": summaries,
            "total": len(records),
            "active": task_manager.get_active_tasks_count(),
            "models_count": sum(len(item.get('models', [])) for item in summaries),
        }
    })
@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release the resources held by one task; 500 if the task manager reports failure."""
    if not task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "error",
            "message": f"清理任务失败: {task_id}"
        }), 500
    return jsonify({
        "status": "success",
        "message": f"任务资源已清理: {task_id}"
    })
@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Clean up every task and report how many were removed."""
    removed = task_manager.cleanup_all_tasks()
    body = {
        "status": "success",
        "message": f"已清理 {removed} 个任务",
        "cleaned_count": removed,
    }
    return jsonify(body)
@app.route('/api/tasks/cleanup_stopped', methods=['POST'])
def cleanup_stopped_tasks():
    """Clean up all stopped tasks and report how many were removed."""
    removed = task_manager.cleanup_stopped_tasks()
    body = {
        "status": "success",
        "message": f"已清理 {removed} 个已停止的任务",
        "cleaned_count": removed,
    }
    return jsonify(body)
def _collect_gpu_info():
    """Return a list of per-GPU stats, or None when no GPU information source is usable.

    Prefers GPUtil. Falls back to torch.cuda when GPUtil is not installed or
    fails. GPU data is optional, so failures are logged and never propagate.
    """
    try:
        import GPUtil
        return [
            {
                "id": gpu.id,
                "name": gpu.name,
                "load": gpu.load * 100,
                "memory_used": gpu.memoryUsed,
                "memory_total": gpu.memoryTotal,
                "temperature": gpu.temperature
            }
            for gpu in GPUtil.getGPUs()
        ]
    except ImportError:
        pass  # GPUtil not installed; try torch below
    except Exception as e:
        logger.warning(f"通过GPUtil获取GPU信息失败: {str(e)}")
    if not torch.cuda.is_available():
        return None
    try:
        # memory_allocated() counts memory held by this process's tensors (MiB here).
        return [
            {
                "id": i,
                "name": torch.cuda.get_device_name(i),
                "memory_used": torch.cuda.memory_allocated(i) / 1024 ** 2,
                "memory_total": torch.cuda.get_device_properties(i).total_memory / 1024 ** 2
            }
            for i in range(torch.cuda.device_count())
        ]
    except Exception as e:
        logger.warning(f"通过torch获取GPU信息失败: {str(e)}")
        return None


@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Return CPU/memory/disk usage, task counters and, when available, GPU stats.

    The ``gpus`` key is present only if GPUtil or torch.cuda could report GPUs;
    a GPU probing failure does not fail the whole request.
    """
    try:
        import psutil
        system_info = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks),
            "max_concurrent_tasks": task_manager.get_current_max_tasks()
        }
        gpu_info = _collect_gpu_info()
        if gpu_info is not None:
            system_info["gpus"] = gpu_info
        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as e:
        logger.error(f"获取系统状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(e)}"
        }), 500
# WebSocket events - messages are addressed per task ID
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO connection by its session id."""
    session_id = request.sid
    logger.info(f"Socket客户端已连接: {session_id}")
@socketio.on('disconnect')
def handle_disconnect():
    """Log each Socket.IO disconnection by its session id."""
    session_id = request.sid
    logger.info(f"Socket客户端断开: {session_id}")
@socketio.on('subscribe_task')
def handle_subscribe_task(data=None):
    """Acknowledge a client's subscription to one task's WebSocket messages.

    Args:
        data: event payload; expected to be an object with a ``task_id`` key.
            Defaults to None so an event emitted without a payload is handled.

    Returns:
        The acknowledgement sent back to the client:
        ``{"status": "subscribed", "task_id": ...}``, or an error dict when
        no task_id is supplied.
    """
    # Payloads come straight from the client; anything that is not an object
    # (None, a bare string, a list) is treated as "no task_id" instead of raising.
    task_id = data.get('task_id') if isinstance(data, dict) else None
    if not task_id:
        return {"status": "error", "message": "需要提供task_id"}
    # Subscription bookkeeping is not implemented yet; the request is only logged.
    logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
    return {"status": "subscribed", "task_id": task_id}
@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """Read (GET) or merge-update (POST) the resource limits.

    GET returns ``resource_monitor.resource_limits`` from global data (500 if
    no resource monitor is registered). POST merges the JSON object body into
    both the resource monitor's and the task manager's ``resource_limits``
    (400 if the monitor is missing or the body is not a non-empty object).
    """
    if request.method != 'POST':
        # GET. Flask also routes HEAD here for GET routes; handling every
        # non-POST method in this branch guarantees a valid response for it.
        resource_monitor = gd.get_value('resource_monitor')
        if resource_monitor:
            return jsonify({
                "status": "success",
                "data": resource_monitor.resource_limits
            })
        return jsonify({
            "status": "error",
            "message": "资源监控器未初始化"
        }), 500
    try:
        # silent=True: a non-JSON body becomes None and is rejected with a 400 below.
        data = request.get_json(silent=True)
        resource_monitor = gd.get_value('resource_monitor')
        if resource_monitor and data and isinstance(data, dict):
            # Update the monitor's limits and the task manager's copy.
            resource_monitor.resource_limits.update(data)
            task_manager.resource_limits.update(data)
            logger.info(f"更新资源限制: {data}")
            return jsonify({
                "status": "success",
                "message": "资源限制更新成功"
            })
        return jsonify({
            "status": "error",
            "message": "资源监控器未初始化或请求数据无效"
        }), 400
    except Exception as e:
        logger.error(f"更新资源限制失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"更新失败: {str(e)}"
        }), 500
@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """List the model configurations stored for a task.

    The model list is read from ``task_manager.tasks[task_id]['config']['models']``;
    each entry is reduced to id, name (file name without extension), path,
    conf_thres, tags, color and enabled.
    """
    try:
        if not task_manager.get_task_status(task_id):
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404
        models_info = []
        if task_id in task_manager.tasks:
            stored_config = task_manager.tasks[task_id].get('config', {})
            models_info = [
                {
                    'id': idx,
                    'name': os.path.basename(entry.get('path', '')).split('.')[0],
                    'path': entry.get('path'),
                    'conf_thres': entry.get('conf_thres'),
                    'tags': entry.get('tags', {}),
                    'color': entry.get('color'),
                    'enabled': entry.get('enabled', True),
                }
                for idx, entry in enumerate(stored_config.get('models', []))
            ]
        return jsonify({
            "status": "success",
            "data": {
                "task_id": task_id,
                "models": models_info
            }
        })
    except Exception as e:
        logger.error(f"获取任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
# Per-task stream (push) routes
@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Report a task's status together with its stream-push information."""
    try:
        from task_stream_manager import task_stream_manager
        task_status = task_manager.get_task_status(task_id)
        if not task_status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404
        # Streamer info for this task; empty dict if it has no streamer.
        all_streams = task_stream_manager.get_all_task_streams_info()
        stream_info = all_streams.get(task_id, {})
        combined = {
            "task_id": task_id,
            "task_status": task_status['status'],
            "stream_enabled": task_status['config'].get('enable_push', False),
            "stream_info": stream_info,
            "push_url": task_status['config'].get('push_url', ''),
        }
        return jsonify({"status": "success", "data": combined})
    except Exception as e:
        logger.error(f"获取任务推流状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
@app.route('/api/tasks/<task_id>/stream/restart', methods=['POST'])
def restart_task_stream(task_id):
    """Restart the stream pusher of an existing task."""
    try:
        from task_stream_manager import task_stream_manager
        if task_id not in task_manager.tasks:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404
        # NOTE(review): relies on the stream manager's private _restart_task_streamer.
        restarted = task_stream_manager._restart_task_streamer(task_id)
        if not restarted:
            return jsonify({
                "status": "error",
                "message": f"重启失败: {task_id}"
            }), 500
        return jsonify({
            "status": "success",
            "message": f"任务推流重启成功: {task_id}"
        })
    except Exception as e:
        logger.error(f"重启任务推流失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"重启失败: {str(e)}"
        }), 500
@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Return stream information for all tasks plus total/running counts."""
    try:
        from task_stream_manager import task_stream_manager
        streams = task_stream_manager.get_all_task_streams_info()
        running = len([info for info in streams.values() if info.get('running', False)])
        summary = {
            "total_streams": len(streams),
            "active_streams": running,
            "streams": streams,
        }
        return jsonify({"status": "success", "data": summary})
    except Exception as e:
        logger.error(f"获取所有推流信息失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
# Initialization entry point; may be called from the main program.
def init_app():
    """Register the task manager in global data and return the Flask app.

    Also sets the module-level ``_initialized`` flag so the ``before_request``
    hook does not repeat the registration (and its log line) on the first
    request.

    Returns:
        The module-level Flask ``app``.
    """
    global _initialized
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    _initialized = True
    return app