808 lines
28 KiB
Python
808 lines
28 KiB
Python
# server.py
|
||
import os
|
||
from datetime import datetime
|
||
|
||
import torch
|
||
from flask import Flask, jsonify, request, render_template
|
||
from flask_socketio import SocketIO
|
||
from flask_cors import CORS
|
||
from config import get_default_config
|
||
from mandatory_model_crypto import MandatoryModelEncryptor
|
||
from task_manager import task_manager # 导入任务管理器
|
||
from global_data import gd
|
||
from log import logger
|
||
import time
|
||
import traceback
|
||
|
||
# Flask application setup
app = Flask(__name__, static_url_path='/static')
CORS(app)

# Socket.IO server options.
# NOTE(review): `allow_unsafe_werkzeug` is documented as an argument of
# `socketio.run()`, not of the SocketIO constructor — confirm it has any effect here.
# `cors_allowed_origins="*"` accepts Socket.IO connections from any origin.
_socketio_options = {
    'cors_allowed_origins': "*",
    'async_mode': 'threading',
    'allow_unsafe_werkzeug': True,
    'max_http_buffer_size': 5 * 1024 * 1024,  # 5 MiB
}
socketio = SocketIO(app, **_socketio_options)

# Flipped to True by initialize_once() once the task manager is registered.
_initialized = False
|
||
|
||
|
||
@app.before_request
def initialize_once():
    """Register the shared task manager in global data on the first request.

    The hook runs before every request, but the module-level ``_initialized``
    flag makes the registration happen only once per process.
    """
    global _initialized
    if _initialized:
        return
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
        _initialized = True
|
||
|
||
|
||
# ======================= Flask路由 =======================
|
||
@app.route('/')
def task_management():
    """Serve the task management page (site index)."""
    page = "task_management.html"
    return render_template(page)
|
||
|
||
|
||
@app.route('/video_player')
def video_player():
    """Serve the video player page."""
    page = "flv2.html"
    return render_template(page)
|
||
|
||
|
||
def _parse_model_entry(index, model_data, encryption_checker):
    """Validate one element of the request's ``models`` list and build its config.

    Returns ``(model_config, None)`` on success, or ``(None, message)`` when the
    entry must be rejected with HTTP 400.
    """
    if not isinstance(model_data, dict):
        return None, f"模型 {index} 的配置必须是JSON对象"

    # An encryption key is mandatory for every model.
    encryption_key = model_data.get('encryption_key')
    if not encryption_key:
        return None, f"模型 {index} ({model_data.get('path', 'unknown')}) 必须提供encryption_key"

    model_path = model_data.get('path', f'model_{index}.pt')
    if not isinstance(model_path, str) or not model_path:
        return None, f"模型 {index} 的path无效"
    model_name = os.path.basename(model_path).split('.')[0]

    # If a file with this name exists locally it must already be encrypted.
    # NOTE(review): this checks only the basename relative to the current working
    # directory (the previous single-argument os.path.join() was a no-op).
    # Confirm whether a model directory was intended here.
    local_model_path = os.path.basename(model_path)
    if os.path.exists(local_model_path) and not encryption_checker.is_properly_encrypted(local_model_path):
        return None, f"模型 {index} ({model_name}) 未正确加密"

    try:
        model_config = {
            'path': model_path,
            'encryption_key': encryption_key,  # mandatory
            'encrypted': True,  # encryption is enforced
            'tags': model_data.get('tags', {}),
            'conf_thres': float(model_data.get('conf_thres', 0.25)),
            'iou_thres': float(model_data.get('iou_thres', 0.45)),
            # Inference size is clamped into [128, 1920].
            'imgsz': max(128, min(1920, int(model_data.get('imgsz', 640)))),
            'color': model_data.get('color'),
            'line_width': int(model_data.get('line_width', 1)),
            'device': model_data.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu'),
            'half': model_data.get('half', True),
            'enabled': model_data.get('enabled', True),
            'download_url': model_data.get('download_url')  # optional download location
        }
    except (TypeError, ValueError) as e:
        # Non-numeric thresholds/sizes are a client error, not a server error.
        return None, f"模型 {index} ({model_name}) 参数格式错误: {e}"

    logger.info(f"添加加密模型 {index}: {model_name}")
    return model_config, None


def _ensure_task_capacity():
    """Free resources so a new task can be admitted.

    Cleans up stopped tasks first. If the active-task count is still at the limit,
    it also cleans up every task whose status is not running/starting/creating.

    Returns ``(has_capacity, max_tasks)``.
    """
    logger.info("创建任务前清理已停止的任务...")
    cleaned_count = task_manager.cleanup_stopped_tasks()
    if cleaned_count > 0:
        logger.info(f"已清理 {cleaned_count} 个已停止的任务")

    active_tasks = task_manager.get_active_tasks_count()
    max_tasks = task_manager.get_current_max_tasks()
    if active_tasks < max_tasks:
        return True, max_tasks

    logger.warning(f"当前活动任务数 {active_tasks} 已达到限制 {max_tasks},尝试强制清理")
    force_cleaned = 0
    for task in task_manager.get_all_tasks():
        if task['status'] not in ('running', 'starting', 'creating'):
            if task_manager.cleanup_task(task['task_id']):
                force_cleaned += 1
    if force_cleaned > 0:
        logger.info(f"强制清理了 {force_cleaned} 个非运行状态的任务")

    # Re-check after the forced cleanup.
    return task_manager.get_active_tasks_count() < max_tasks, max_tasks


@app.route('/api/tasks/create', methods=['POST'])
def create_task():
    """Create and start a detection task; every model must carry an encryption key.

    Expects a JSON object with ``rtmp_url`` (required), ``models`` (required,
    non-empty list of model objects) and optional ``push_url``, ``taskname`` and
    ``AlgoId``.

    Responses:
        200 on success; 400 for invalid input; 503 when the concurrent task
        limit is still reached after cleanup; 500 when creation/start fails.
    """
    try:
        config = get_default_config()
        config['socketIO'] = socketio

        # silent=True returns None for a missing/malformed body or a non-JSON
        # content type instead of raising; otherwise the generic handler below
        # would report a client mistake as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "请求体不能为空"}), 400
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "请求体必须是JSON对象"}), 400

        logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")

        # Required parameters
        if not data.get('rtmp_url'):
            return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400

        models = data.get('models')
        if not isinstance(models, list):
            return jsonify({"status": "error", "message": "必须提供models列表"}), 400
        if not models:
            return jsonify({"status": "error", "message": "models列表不能为空"}), 400

        # Apply request values to the default config.
        config['rtmp']['url'] = data['rtmp_url']
        if data.get('push_url'):
            config['push']['url'] = data['push_url']
        if 'taskname' in data:
            config['task']['taskname'] = data['taskname']
        else:
            config['task']['taskname'] = f"task_{int(time.time())}"
        if 'AlgoId' in data:
            config['task']['aiid'] = data['AlgoId']

        # Multi-model config; every model must pass the encryption checks.
        config['models'] = []
        encryption_checker = MandatoryModelEncryptor()
        for i, model_data in enumerate(models):
            model_config, error = _parse_model_entry(i, model_data, encryption_checker)
            if model_config is None:
                return jsonify({"status": "error", "message": error}), 400
            config['models'].append(model_config)

        has_capacity, max_tasks = _ensure_task_capacity()
        if not has_capacity:
            return jsonify({
                "status": "error",
                "message": f"已达到最大并发任务数限制 ({max_tasks}),请等待其他任务完成或停止部分任务"
            }), 503

        logger.info(f"开始创建任务,包含 {len(config['models'])} 个加密模型...")
        try:
            task_id = task_manager.create_task(config, socketio)
            logger.info(f"任务创建成功,ID: {task_id}")
        except Exception as e:
            logger.error(f"任务创建失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"任务创建失败: {str(e)}"
            }), 500

        logger.info(f"启动任务 {task_id}...")
        if not task_manager.start_task(task_id):
            logger.error(f"任务启动失败: {task_id}")
            # Roll back: release the resources of a task that never started.
            task_manager.cleanup_task(task_id)
            return jsonify({
                "status": "error",
                "message": "任务创建成功但启动失败",
                "task_id": task_id
            }), 500

        logger.info(f"任务启动成功: {task_id}")
        return jsonify({
            "status": "success",
            "message": "任务创建并启动成功",
            "task_id": task_id,
            "models_count": len(config['models']),
            "encryption_required": True
        })

    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"创建任务失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/system/resources', methods=['GET'])
def get_system_resources():
    """Report system resource usage together with task counters.

    Uses the snapshot stored under ``system_resources`` in global data. If that
    value is missing or empty, it samples CPU and memory directly with psutil.
    """
    try:
        resources = gd.get_value('system_resources')
        max_tasks = gd.get_value('max_concurrent_tasks', 5)

        if not resources:
            # Sample resources on demand.
            import psutil

            # Take a single virtual-memory snapshot so percent/used/total agree
            # with each other, rather than sampling once per field.
            memory = psutil.virtual_memory()
            resources = {
                # NOTE: cpu_percent() without an interval compares against the
                # previous call; psutil documents that the first call returns 0.0.
                'cpu_percent': psutil.cpu_percent(),
                'memory_percent': memory.percent,
                'memory_used': memory.used // (1024 * 1024),  # MiB
                'memory_total': memory.total // (1024 * 1024),  # MiB
                'timestamp': datetime.now().isoformat()
            }

        # Task counters
        active_tasks = task_manager.get_active_tasks_count()
        total_tasks = len(task_manager.tasks)

        return jsonify({
            "status": "success",
            "data": {
                "resources": resources,
                "tasks": {
                    "active": active_tasks,
                    "total": total_tasks,
                    "max_concurrent": max_tasks
                }
            }
        })
    except Exception as e:
        logger.error(f"获取系统资源失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统资源失败: {str(e)}"
        }), 500
|
||
|
||
|
||
def _is_within_directory(base_dir, path):
    """Return True if *path* resolves to a location inside *base_dir*.

    Both paths go through os.path.realpath, so ``..`` components, absolute
    paths and symlinks cannot be used to escape *base_dir*.
    """
    base = os.path.realpath(base_dir)
    target = os.path.realpath(path)
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # commonpath raises ValueError e.g. for paths on different Windows drives.
        return False


@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file located under the configured model directory.

    Expects JSON with ``model_path`` and ``output_path`` (both resolved relative
    to ``config['model_path']``), ``password`` and ``download_url``. If the source
    model is missing, ModelManager.download_model is asked to fetch it first.
    """
    try:
        config = get_default_config()
        # silent=True: a missing or malformed body is answered with 400 below
        # instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')
        download_url = data.get('download_url')
        if not all([model_path, output_path, password, download_url]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400
        if not isinstance(model_path, str) or not isinstance(output_path, str):
            return jsonify({"status": "error", "message": "模型路径必须是字符串"}), 400

        model_dir = config['model_path']
        local_path = os.path.join(model_dir, model_path)
        output_path = os.path.join(model_dir, output_path)

        # Both paths come from the client. Reject anything that resolves outside
        # the model directory; otherwise a request could read or overwrite
        # arbitrary files on the server.
        if not (_is_within_directory(model_dir, local_path)
                and _is_within_directory(model_dir, output_path)):
            return jsonify({"status": "error", "message": "模型路径必须位于模型目录内"}), 400

        # Check the path that is actually encrypted below (local_path), not the
        # raw client value.
        if not os.path.exists(local_path):
            from model_crypto import ModelManager

            model_d = ModelManager(data)
            # NOTE(review): download_model receives the client-relative path; confirm
            # it writes into config['model_path'] so local_path exists afterwards.
            down_status = model_d.download_model({"path": model_path, "download_url": download_url})
            if not down_status:
                return jsonify({"status": "error", "message": f"模型文件不存在: {model_path}"}), 400

        # Mandatory-encryption encryptor (imported at module level).
        encryptor = MandatoryModelEncryptor()
        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型加密成功: {output_path}",
                "data": {
                    "model_hash": result.get('model_hash'),
                    "key_hash": result.get('key_hash'),
                    "output_path": result.get('output_path')
                }
            })
        return jsonify({
            "status": "error",
            "message": f"模型加密失败: {result.get('error', '未知错误')}"
        }), 500

    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Check whether ``encryption_key`` decrypts the given encrypted model.

    Only the basename of ``model_path`` is used, and it is looked up inside the
    ``encrypted_models`` directory, so directory components sent by the client
    are ignored.
    """
    try:
        # silent=True plus the dict check: a missing/non-object body becomes a
        # 400 "missing parameters" instead of an AttributeError and a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_path = data.get('model_path')
        encryption_key = data.get('encryption_key')

        if not all([model_path, encryption_key]) or not isinstance(model_path, str):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # The model file must exist.
        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400

        # Mandatory-encryption encryptor (imported at module level).
        encryptor = MandatoryModelEncryptor()

        # The file must be in the expected encrypted format.
        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400

        # Verify the key by attempting decryption.
        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)

        if verify_result['success']:
            # `or ''` also covers an explicit None hash, which cannot be sliced.
            model_hash = verify_result.get('model_hash') or ''
            return jsonify({
                "status": "success",
                "message": "密钥验证成功",
                "data": {
                    "valid": True,
                    "model_hash": model_hash[:16],
                    "model_size": verify_result.get('original_size', 0)
                }
            })
        return jsonify({
            "status": "error",
            "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
            "data": {
                "valid": False,
                "error": verify_result.get('error', '未知错误')
            }
        }), 400

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate a new model-encryption key and return it with its hashes.

    The key comes from MandatoryModelEncryptor.generate_secure_key(). The
    response exposes its ``key``, ``key_hash`` and ``short_hash`` fields.
    """
    try:
        key_info = MandatoryModelEncryptor().generate_secure_key()
        key_payload = {
            "key": key_info['key'],
            "key_hash": key_info['key_hash'],
            "short_hash": key_info['short_hash'],
        }
        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": key_payload
        })
    except Exception as e:
        logger.error(f"生成加密密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# server.py 中的 stop_detection 路由修改
|
||
|
||
@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop the given task.

    Accepts an optional JSON body ``{"force": bool}``; the body may be omitted.
    Returns 200 with ``stopped: True`` on success, 500 otherwise.
    """
    try:
        logger.info(f"接收到停止任务请求: {task_id}")

        # silent=True: a bare POST (no JSON body / content type) is a valid stop
        # request. request.json would raise in recent Werkzeug versions, and the
        # handler below would turn that into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        force = data.get('force', False)
        # NOTE(review): clients may also send ``timeout`` (documented default 10 s),
        # but task_manager.stop_task is only called with ``force``. Confirm whether
        # the timeout should be forwarded.

        success = task_manager.stop_task(task_id, force=force)

        if success:
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}",
                "task_id": task_id,
                "stopped": True
            })

        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}",
            "task_id": task_id,
            "stopped": False
        }), 500
    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}",
            "task_id": task_id
        }), 500
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return one task's status, config, models, stats and creation time.

    Only multi-model tasks are supported. Responds with 404 when the task
    manager does not know *task_id*.
    """
    status = task_manager.get_task_status(task_id)
    if not status:
        return jsonify({
            "status": "error",
            "message": f"任务不存在: {task_id}"
        }), 404

    # Copy the relevant fields in their original order; models default to [].
    enhanced_status = {key: status[key] for key in ('task_id', 'status', 'config')}
    enhanced_status['models'] = status.get('models', [])
    enhanced_status['stats'] = status['stats']
    enhanced_status['created_at'] = status['created_at']

    return jsonify({"status": "success", "data": enhanced_status})
|
||
|
||
|
||
@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List every task with a trimmed config view and aggregate counters.

    Only multi-model tasks are supported. The ``config`` of each task is reduced
    to ``rtmp_url``, ``taskname``, ``push_url`` and ``enable_push``.
    """
    tasks = task_manager.get_all_tasks()

    def _summarize(item):
        return {
            'task_id': item['task_id'],
            'status': item['status'],
            'config': {
                'rtmp_url': item['config']['rtmp_url'],
                'taskname': item['config']['taskname'],
                'push_url': item['config'].get('push_url', ''),
                'enable_push': item['config'].get('enable_push', False)
            },
            'models': item.get('models', []),
            'stats': item['stats'],
            'created_at': item['created_at']
        }

    summaries = [_summarize(item) for item in tasks]
    model_total = sum(len(summary['models']) for summary in summaries)

    return jsonify({
        "status": "success",
        "data": {
            "tasks": summaries,
            "total": len(tasks),
            "active": task_manager.get_active_tasks_count(),
            "models_count": model_total
        }
    })
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release the resources held by one task; 500 if the task manager reports failure."""
    if not task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "error",
            "message": f"清理任务失败: {task_id}"
        }), 500
    return jsonify({
        "status": "success",
        "message": f"任务资源已清理: {task_id}"
    })
|
||
|
||
|
||
@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Clean up every task and report how many were cleaned."""
    count = task_manager.cleanup_all_tasks()
    payload = {
        "status": "success",
        "message": f"已清理 {count} 个任务",
        "cleaned_count": count
    }
    return jsonify(payload)
|
||
|
||
|
||
@app.route('/api/tasks/cleanup_stopped', methods=['POST'])
def cleanup_stopped_tasks():
    """Clean up all stopped tasks and report how many were cleaned."""
    count = task_manager.cleanup_stopped_tasks()
    payload = {
        "status": "success",
        "message": f"已清理 {count} 个已停止的任务",
        "cleaned_count": count
    }
    return jsonify(payload)
|
||
|
||
|
||
def _collect_gpu_info():
    """Best-effort GPU information for the system status endpoint.

    Uses GPUtil when it is installed and works, otherwise falls back to
    torch.cuda. Returns a list of per-GPU dicts, or None when neither source
    is available (in which case no ``gpus`` key is reported).
    """
    try:
        import GPUtil
    except ImportError:
        GPUtil = None

    if GPUtil is not None:
        try:
            return [
                {
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load * 100,
                    "memory_used": gpu.memoryUsed,
                    "memory_total": gpu.memoryTotal,
                    "temperature": gpu.temperature
                }
                for gpu in GPUtil.getGPUs()
            ]
        except Exception as e:
            # GPU telemetry is optional: a GPUtil runtime failure must not fail
            # the whole status request, so fall back to torch below.
            logger.warning(f"GPUtil获取GPU信息失败: {str(e)}")

    if torch.cuda.is_available():
        return [
            {
                "id": i,
                "name": torch.cuda.get_device_name(i),
                # memory_allocated() counts only memory occupied by tensors, so it
                # is typically lower than what nvidia-smi reports.
                "memory_used": torch.cuda.memory_allocated(i) / 1024 ** 2,
                "memory_total": torch.cuda.get_device_properties(i).total_memory / 1024 ** 2
            }
            for i in range(torch.cuda.device_count())
        ]
    return None


@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Report CPU/memory/disk usage, task counters and (if available) GPU info."""
    try:
        import psutil

        system_info = {
            # NOTE: psutil documents that the first cpu_percent() call without an
            # interval returns 0.0.
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks),
            "max_concurrent_tasks": task_manager.get_current_max_tasks()
        }

        gpu_info = _collect_gpu_info()
        if gpu_info is not None:
            system_info["gpus"] = gpu_info

        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as e:
        logger.error(f"获取系统状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# WebSocket事件 - 按任务ID发送
|
||
@socketio.on('connect')
def handle_connect():
    """Log each newly connected Socket.IO client by its session id."""
    sid = request.sid
    logger.info(f"Socket客户端已连接: {sid}")
|
||
|
||
|
||
@socketio.on('disconnect')
def handle_disconnect():
    """Log each Socket.IO client disconnect by its session id."""
    sid = request.sid
    logger.info(f"Socket客户端断开: {sid}")
|
||
|
||
|
||
@socketio.on('subscribe_task')
def handle_subscribe_task(data=None):
    """Acknowledge a client's subscription to one task's WebSocket messages.

    Expects a payload like ``{"task_id": ...}``. The returned dict is the event
    acknowledgement sent back to the client. No subscription state is stored yet.
    """
    # Clients can emit any JSON value, or no payload at all (hence the default);
    # only a dict can carry task_id.
    task_id = data.get('task_id') if isinstance(data, dict) else None
    if not task_id:
        return {"status": "error", "message": "需要提供task_id"}

    # TODO: record the client/task subscription here if per-task routing is needed.
    logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
    return {"status": "subscribed", "task_id": task_id}
|
||
|
||
|
||
@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """Get (GET) or merge-update (POST) the resource limits.

    GET returns ``resource_monitor.resource_limits`` from global data.
    POST expects a non-empty JSON object. Its keys are merged into both the
    resource monitor's and the task manager's ``resource_limits`` dicts.
    NOTE(review): the keys and values are not validated.
    """
    if request.method == 'GET':
        resource_monitor = gd.get_value('resource_monitor')
        if not resource_monitor:
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化"
            }), 500
        return jsonify({
            "status": "success",
            "data": resource_monitor.resource_limits
        })

    # POST: the route only allows GET and POST.
    try:
        # silent=True plus the dict check: a malformed or non-object body is a
        # client error (400). request.json would raise and surface as a 500, and a
        # JSON list/string would make dict.update() fail.
        data = request.get_json(silent=True)
        resource_monitor = gd.get_value('resource_monitor')

        if not (resource_monitor and isinstance(data, dict) and data):
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化或请求数据无效"
            }), 400

        resource_monitor.resource_limits.update(data)
        # Keep the task manager's copy of the limits in sync.
        task_manager.resource_limits.update(data)

        logger.info(f"更新资源限制: {data}")
        return jsonify({
            "status": "success",
            "message": "资源限制更新成功"
        })
    except Exception as e:
        logger.error(f"更新资源限制失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"更新失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """List the model configurations of one task.

    Responds with 404 if the task manager does not know *task_id*. Model entries
    are read from ``task_manager.tasks[task_id]['config']['models']``.
    """
    try:
        if not task_manager.get_task_status(task_id):
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        models_config = []
        if task_id in task_manager.tasks:
            task_entry = task_manager.tasks[task_id]
            models_config = task_entry.get('config', {}).get('models', [])

        models_info = [
            {
                'id': idx,
                # Display name = file basename up to the first dot.
                'name': os.path.basename(entry.get('path', '')).split('.')[0],
                'path': entry.get('path'),
                'conf_thres': entry.get('conf_thres'),
                'tags': entry.get('tags', {}),
                'color': entry.get('color'),
                'enabled': entry.get('enabled', True)
            }
            for idx, entry in enumerate(models_config)
        ]

        return jsonify({
            "status": "success",
            "data": {"task_id": task_id, "models": models_info}
        })
    except Exception as e:
        logger.error(f"获取任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# server.py 添加以下路由
|
||
|
||
@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Combine a task's status with its push-stream info from task_stream_manager."""
    try:
        from task_stream_manager import task_stream_manager

        status = task_manager.get_task_status(task_id)
        if not status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        all_streams = task_stream_manager.get_all_task_streams_info()
        stream_info = all_streams.get(task_id, {})
        task_config = status['config']

        merged = {
            "task_id": task_id,
            "task_status": status['status'],
            "stream_enabled": task_config.get('enable_push', False),
            "stream_info": stream_info,
            "push_url": task_config.get('push_url', '')
        }
        return jsonify({"status": "success", "data": merged})
    except Exception as e:
        logger.error(f"获取任务推流状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/tasks/<task_id>/stream/restart', methods=['POST'])
def restart_task_stream(task_id):
    """Restart the push stream of an existing task.

    Responds with 404 for an unknown task and 500 when the restart reports failure.
    """
    try:
        from task_stream_manager import task_stream_manager

        if task_id not in task_manager.tasks:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        # NOTE: relies on a private method of task_stream_manager.
        restarted = task_stream_manager._restart_task_streamer(task_id)
        if not restarted:
            return jsonify({
                "status": "error",
                "message": f"重启失败: {task_id}"
            }), 500
        return jsonify({
            "status": "success",
            "message": f"任务推流重启成功: {task_id}"
        })
    except Exception as e:
        logger.error(f"重启任务推流失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"重启失败: {str(e)}"
        }), 500
|
||
|
||
|
||
@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Return push-stream info for all tasks plus total/running counts."""
    try:
        from task_stream_manager import task_stream_manager

        streams = task_stream_manager.get_all_task_streams_info()
        running = [info for info in streams.values() if info.get('running', False)]

        summary = {
            "total_streams": len(streams),
            "active_streams": len(running),
            "streams": streams
        }
        return jsonify({"status": "success", "data": summary})
    except Exception as e:
        logger.error(f"获取所有推流信息失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||
|
||
|
||
# 初始化函数,可在主程序中调用
|
||
def init_app():
    """Register the shared task manager in global data and return the Flask app.

    Intended to be called from the main program. It also sets the module-level
    ``_initialized`` flag, so the before_request hook ``initialize_once`` does not
    repeat the same registration (and log line) on the first request.
    """
    global _initialized
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    _initialized = True
    return app