Yolov/detectionThread.py

1432 lines
57 KiB
Python
Raw Normal View History

2025-11-26 13:55:04 +08:00
import datetime
import gc
2025-12-12 16:04:22 +08:00
import hashlib
2025-11-26 13:55:04 +08:00
import json
2025-12-13 16:13:12 +08:00
import logging
2025-11-26 13:55:04 +08:00
import os
import queue
2025-12-12 16:04:22 +08:00
import tempfile
2025-11-26 13:55:04 +08:00
import threading
import time
import traceback
import cv2
import paho.mqtt.client as mqtt
import requests
import torch
from ultralytics import YOLO
from _minio import MinioUploader
from log import logger
2025-12-11 13:41:07 +08:00
from global_data import gd
2025-12-16 10:08:12 +08:00
from detection_render import multi_model_inference
2025-12-12 16:04:22 +08:00
from mandatory_model_crypto import MandatoryModelEncryptor
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
# detectionThread.py - 修改 ModelManager 类
2025-12-11 13:41:07 +08:00
class ModelManager:
    """Manage detection models: resolve local paths, decrypt encrypted
    weights, load them with YOLO, and cache loaded models so the same
    (path, key) pair is never decrypted twice.
    """

    def __init__(self, config):
        self.config = config
        # Plain-model directory and directory holding encrypted (.enc) weights.
        self.models_dir = "models"
        self.encrypted_models_dir = config.get('upload', {}).get('encrypted_models_dir', 'encrypted_models')
        # Make sure both directories exist.
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)
        # Cache of loaded models keyed by local path (+ key fingerprint),
        # guarded by a lock because load_model may run from several threads.
        self.model_cache = {}
        self.cache_lock = threading.Lock()

    def load_model(self, model_config, require_verification=False):
        """Load a single model described by ``model_config``.

        Returns a tuple ``(model, verification_result)``. ``model`` is None on
        failure; ``verification_result`` is a dict with at least a 'success'
        key (plus 'error' on failure, 'model_hash'/'original_size' for
        successfully decrypted models).
        """
        try:
            model_path = model_config['path']
            encrypted = model_config.get('encrypted', False)
            encryption_key = model_config.get('encryption_key')
            # --- Resolve the local file path --------------------------------
            if encrypted:
                if model_path.startswith('encrypted_models/'):
                    # Relative path inside the encrypted-models directory.
                    local_path = os.path.join(self.encrypted_models_dir, os.path.basename(model_path))
                elif os.path.isabs(model_path):
                    local_path = model_path
                else:
                    # Fall back to looking the file up in the encrypted dir,
                    # appending the .enc suffix when it is missing.
                    model_filename = os.path.basename(model_path)
                    if not model_filename.endswith('.enc'):
                        model_filename += '.enc'
                    local_path = os.path.join(self.encrypted_models_dir, model_filename)
            else:
                # Plain models live in the plain-model directory.
                local_path = os.path.join(self.models_dir, os.path.basename(model_path))

            if not os.path.exists(local_path):
                logger.error(f"模型文件不存在: {local_path}")
                return None, {'success': False, 'error': f'模型文件不存在: {local_path}'}

            # --- Cache lookup (key includes a fingerprint of the key so the
            # same file with different keys is cached separately) -----------
            cache_key = f"{local_path}_{hashlib.md5(encryption_key.encode()).hexdigest()[:8]}" if encryption_key else local_path
            with self.cache_lock:
                if cache_key in self.model_cache:
                    logger.info(f"使用缓存的模型: {model_path}")
                    cached_info = self.model_cache[cache_key]
                    return cached_info['model'], cached_info.get('verification_result', {'success': True})

            verification_result = None
            model = None
            if encrypted and encryption_key:
                # Decrypt to a temporary file, load it, then delete it.
                try:
                    from mandatory_model_crypto import MandatoryModelValidator
                    validator = MandatoryModelValidator()
                    decrypt_result = validator.decrypt_and_verify(local_path, encryption_key)
                    if not decrypt_result['success']:
                        logger.error(f"解密模型失败: {model_path} - {decrypt_result.get('error', '未知错误')}")
                        return None, decrypt_result
                    verification_result = {
                        'success': True,
                        'model_hash': decrypt_result.get('model_hash', ''),
                        'original_size': decrypt_result.get('original_size', 0)
                    }
                    decrypted_data = decrypt_result['decrypted_data']
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                        tmp.write(decrypted_data)
                        temp_path = tmp.name
                    # BUG FIX: wrap the load in try/finally so the decrypted
                    # temp file is removed even when YOLO() raises (it used
                    # to leak the plaintext weights on disk).
                    try:
                        model = YOLO(temp_path)
                    finally:
                        try:
                            os.unlink(temp_path)
                        except Exception as e:
                            logger.warning(f"清理临时文件失败: {str(e)}")
                    logger.info(f"加密模型解密加载成功: {model_path}")
                except ImportError:
                    logger.error("mandatory_model_crypto模块未找到无法处理加密模型")
                    return None, {'success': False, 'error': '加密模块未找到'}
                except Exception as e:
                    logger.error(f"加密模型处理失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}
            elif encrypted and not encryption_key:
                # Encrypted model but no key supplied: refuse to load.
                logger.error(f"加密模型但未提供密钥: {model_path}")
                return None, {'success': False, 'error': '加密模型需要密钥'}
            else:
                # Plain (unencrypted) model.
                try:
                    model = YOLO(local_path)
                    logger.info(f"普通模型加载成功: {local_path}")
                    verification_result = {'success': True}
                except Exception as e:
                    logger.error(f"加载普通模型失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}

            if model is None:
                return None, verification_result or {'success': False, 'error': '模型加载失败'}

            # Move to the configured device; default to CUDA when present.
            device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            # Optional FP16 inference (only meaningful on CUDA devices).
            if model_config.get('half', False) and 'cuda' in device:
                model = model.half()
                logger.info(f"启用半精度推理: {model_path}")

            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'verification_result': verification_result,
                    'device': device,
                    # BUG FIX: store the decrypted size at the top level so
                    # get_cache_info() can report a meaningful total (it
                    # previously always summed to 0).
                    'original_size': (verification_result or {}).get('original_size', 0),
                    'cached_at': time.time()
                }

            logger.info(f"模型加载成功: {model_path} -> {device}")
            return model, verification_result or {'success': True}
        except Exception as e:
            logger.error(f"加载模型失败 {model_config.get('path')}: {str(e)}")
            logger.error(traceback.format_exc())
            return None, {'success': False, 'error': str(e)}

    def clear_cache(self):
        """Drop every cached model (they will be reloaded on next use)."""
        with self.cache_lock:
            self.model_cache.clear()
            logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return a snapshot of the cache: entry count, keys and total bytes."""
        with self.cache_lock:
            total_size = 0
            for info in self.model_cache.values():
                # BUG FIX: older cache entries lacked a top-level
                # 'original_size', so fall back to the verification result.
                size = info.get('original_size')
                if not size:
                    size = (info.get('verification_result') or {}).get('original_size', 0)
                total_size += size
            return {
                'cache_size': len(self.model_cache),
                'cached_models': list(self.model_cache.keys()),
                'total_size': total_size
            }
2025-12-11 13:41:07 +08:00
2025-11-26 13:55:04 +08:00
class DetectionThread(threading.Thread):
2025-12-11 13:41:07 +08:00
"""多模型检测线程 - 优化版本"""
2025-11-26 13:55:04 +08:00
def __init__(self, config):
    """Build the detection thread from a task config dict.

    Expected config keys (as read below): 'rtmp', 'task', 'push', 'minio',
    and optionally 'mqtt', 'models', 'upload'. Only state is set up here;
    models, RTMP and streamers are initialized lazily in run().
    """
    super().__init__()
    self.config = config
    self.task_id = None
    self.initialized = False
    self.running = True
    self._cleaning_up = False
    self._force_stop = False
    self._should_stop = threading.Event()  # cooperative stop signal
    # --- Multi-model support ---
    self.models = []  # list of per-model info dicts filled by load_models()
    self.model_manager = ModelManager(config)
    self.key_verification_results = {}  # per-model key verification results
    # --- RTMP input configuration ---
    self.cap = None
    self.rtmp_url = config['rtmp']['url']
    self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
    self.reconnect_delay = config['rtmp']['reconnect_delay']
    self.buffer_size = config['rtmp']['buffer_size']
    self.timeout_ms = config['rtmp']['timeout_ms']
    # --- Stream stability monitoring ---
    self.stream_stable = True
    self.consecutive_read_failures = 0
    self.max_consecutive_failures = 20  # reconnect after this many in a row
    self.retry_delay = 0.5  # seconds to wait after a single read failure
    self.stream_check_interval = 10  # seconds between stability checks
    self.last_stream_check = time.time()
    self.total_read_failures = 0
    self.stream_recovery_attempts = 0
    # --- Task identity ---
    self.taskname = config['task']['taskname']
    self.taskid = config['task']['taskid']
    self.aiid = config['task']['aiid']
    # --- Output push (re-streaming) management ---
    self.enable_push = config['push']['enable_push']
    self.push_config = config['push']
    self.task_streamer = None  # per-task streamer created later
    self.streamer_initialized = False
    self.last_push_time = 0
    self.push_error_count = 0
    self.max_push_errors = 5
    # --- Performance monitoring ---
    self.frame_count = 0
    self.reconnect_attempts = 0
    self.last_frame_time = time.time()
    self.fps = 0
    self.stop_event = threading.Event()
    self.last_log_time = time.time()
    self.daemon = True
    self.target_latency = 0.05  # seconds; frames slower than this are skipped
    self.processing_times = []
    self.avg_process_time = 0.033
    self.last_status_update = 0
    self.last_fps = 0
    self.original_width = 0
    self.original_height = 0
    # --- Upload pipeline (MinIO + result API) ---
    self.minio_uploader = MinioUploader(config['minio'])
    self.upload_queue = queue.Queue(maxsize=50)
    self.upload_thread = None
    self.upload_active = False
    self.upload_interval = 2  # min seconds between queued uploads
    self.last_upload_time = 0
    self.res_api = config['task']['res_api']
    # --- MQTT configuration ---
    self.mqtt_config = config.get('mqtt', {})
    self.mqtt_enabled = self.mqtt_config.get('enable', False)
    self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
    self.mqtt_client = None
    self.mqtt_connected = False
    self.latest_drone_data = None  # most recent drone telemetry from MQTT
    self.mqtt_data_lock = threading.Lock()
    # --- Platform-specific stream manager selection ---
    if self._is_windows():
        # Imported lazily so non-Windows hosts never need these modules.
        from windows_utils import detect_and_configure_windows
        from task_stream_manager_windows import windows_task_stream_manager
        self.windows_config = detect_and_configure_windows()
        self.stream_manager = windows_task_stream_manager
        logger.info(f"Windows系统检测完成: {self.windows_config.get('status', 'unknown')}")
    else:
        # Linux/Mac use the default manager.
        from task_stream_manager import task_stream_manager
        self.stream_manager = task_stream_manager
    # --- Push statistics ---
    self.stream_stats = {
        'total_frames_pushed': 0,
        'failed_pushes': 0,
        'last_push_time': 0,
        'push_success_rate': 1.0,
        'ffmpeg_restarts': 0
    }
    # Remember the pre-failure status so recovery can restore it.
    self.original_status = 'running'
    # Current task status as reported upstream.
    self._current_status = 'initializing'
    # NOTE(review): duplicate of the assignment above — harmless but redundant.
    self.key_verification_results = {}
def check_push_health(self):
    """Return True while frame pushing looks healthy.

    Unhealthy means either too many consecutive push errors, or no
    successful push for over 60 seconds (once at least one has happened).
    """
    now = time.time()
    too_many_errors = self.push_error_count > self.max_push_errors
    stalled = self.last_push_time > 0 and (now - self.last_push_time) > 60
    return not (too_many_errors or stalled)
def _is_windows(self):
"""检查是否是Windows系统"""
import os
return os.name == 'nt' or os.name == 'win32'
def load_models(self):
    """Load every enabled model from config['models'] via the ModelManager.

    Populates self.models (list of per-model info dicts) and
    self.key_verification_results (index -> verification dict).
    Returns True when at least one model loaded, False otherwise.
    """
    try:
        models_config = self.config.get('models', [])
        if not models_config or not isinstance(models_config, list):
            logger.error("未找到有效的models配置列表")
            return False
        logger.info(f"开始从本地加载 {len(models_config)} 个模型")
        loaded_models = []
        key_verification_results = {}
        for i, model_config in enumerate(models_config):
            # Honour the per-model 'enabled' flag (defaults to True).
            if not model_config.get('enabled', True):
                logger.info(f"跳过未启用的模型 {i}")
                continue
            model_path = model_config.get('path', 'unknown')
            model_name = os.path.basename(model_path).split('.')[0]
            # Load from local disk (decrypts when the model is encrypted).
            logger.info(f"加载模型 {i}: {model_name}")
            model, verification_result = self.model_manager.load_model(
                model_config,
                require_verification=True  # always verify the key
            )
            # Record the verification outcome even on failure.
            key_verification_results[i] = verification_result or {'success': False, 'error': '未知错误'}
            if model is None:
                logger.error(f"加载模型 {i} 失败: {model_name}")
                continue
            # Per-model label mapping.
            tags = model_config.get('tags', {})
            # Bundle everything inference needs for this model.
            model_info = {
                'model': model,
                'config': model_config,
                'tags': tags,
                'name': model_name,
                'id': i,
                'device': model_config.get('device', 'cpu'),
                'imgsz': model_config.get('imgsz', 640),
                'conf_thres': model_config.get('conf_thres', 0.25),
                'iou_thres': model_config.get('iou_thres', 0.45),
                'half': model_config.get('half', False),
                'key_valid': verification_result.get('success', False) if verification_result else False,
                'model_hash': verification_result.get('model_hash', 'unknown') if verification_result else 'unknown'
            }
            loaded_models.append(model_info)
            logger.info(
                f"模型加载成功: {model_name}, 设备: {model_info['device']}, 密钥验证: {model_info['key_valid']}")
        # Fail only when nothing loaded; report per-model errors.
        if len(loaded_models) == 0:
            logger.error("所有模型加载失败")
            error_report = []
            for idx, result in key_verification_results.items():
                if not result.get('success', False):
                    error_report.append(f"模型 {idx}: {result.get('error', '未知错误')}")
            if error_report:
                logger.error("模型加载失败详情:")
                for error in error_report:
                    logger.error(f" {error}")
            return False
        self.models = loaded_models
        self.key_verification_results = key_verification_results
        logger.info(f"成功加载 {len(self.models)}/{len(models_config)} 个模型")
        # Summary of key verification outcomes.
        valid_keys = sum(1 for result in key_verification_results.values()
                         if result.get('success', False))
        logger.info(f"密钥验证统计: {valid_keys}个有效, {len(key_verification_results) - valid_keys}个无效")
        return True
    except Exception as e:
        logger.error(f"加载模型异常: {str(e)}")
        logger.error(traceback.format_exc())
        return False
def get_key_verification_summary(self):
    """Summarize key-verification outcomes recorded by load_models().

    Returns a dict with totals, valid/invalid counts, the number of
    actually-loaded models, and a per-model detail map.
    """
    details = {}
    valid = 0
    invalid = 0
    for idx, result in self.key_verification_results.items():
        ok = result.get('success', False)
        if ok:
            valid += 1
        else:
            invalid += 1
        details[f'model_{idx}'] = {
            'success': ok,
            'error': result.get('error', ''),
            'model_hash': result.get('model_hash', '')
        }
    return {
        'total_models': len(self.key_verification_results),
        'valid_keys': valid,
        'invalid_keys': invalid,
        'loaded_models': len(self.models),
        'details': details
    }
def initialize_rtmp(self):
    """Open the RTMP input stream and read its properties.

    Sets self.cap, self.original_width/height and self.fps, and resets the
    stream-stability counters. Returns True on success; raises on failure
    (the exception is also logged).
    """
    try:
        logger.info(f"连接RTMP: {self.rtmp_url}")
        self.cap = cv2.VideoCapture()
        # Low-latency capture settings.
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)
        self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
        self.cap.set(cv2.CAP_PROP_FPS, 30)
        # Optional hardware-accelerated decode.
        if self.config['rtmp'].get('gpu_decode', False):
            try:
                self.cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                logger.info("启用硬件加速解码")
            # BUG FIX: was a bare `except:` which also swallowed
            # SystemExit/KeyboardInterrupt; narrowed to Exception.
            except Exception:
                logger.warning("硬件解码不可用,使用软件解码")
        # Open via the FFmpeg backend.
        if not self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG):
            logger.error(f"连接RTMP失败: {self.rtmp_url}")
            raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")
        # Read actual stream properties (fps may be 0 -> default to 30).
        self.original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
        self.fps = fps
        # Fresh connection: reset stability bookkeeping.
        self.consecutive_read_failures = 0
        self.total_read_failures = 0
        self.stream_stable = True
        logger.info(f"视频属性: {self.original_width}x{self.original_height} @ {fps}fps")
        return True
    except Exception as e:
        logger.error(f"初始化RTMP失败: {str(e)}")
        raise
def initialize_task_streamer(self):
    """Create the per-task output streamer via the platform stream manager.

    Returns True on success (or when pushing is disabled), False when the
    streamer could not be created. On Windows a simplified fallback
    configuration is attempted before giving up.
    """
    if not self.enable_push:
        logger.info(f"任务 {self.task_id} 推流功能未启用")
        return True
    try:
        logger.info(f"初始化任务推流器: {self.task_id}")
        # Windows pre-flight: verify the RTMP push server is reachable.
        if self._is_windows():
            from windows_utils import WindowsSystemUtils
            push_url = self.push_config.get('url', '')
            if push_url:
                accessibility = WindowsSystemUtils.check_rtmp_server_accessibility(push_url)
                if not accessibility.get('accessible', False):
                    logger.error(f"RTMP服务器不可达: {accessibility.get('error', 'Unknown error')}")
                    logger.warning("推流可能失败,请检查网络和服务器状态")
        # Delegate actual streamer creation to the stream manager.
        streamer = self.stream_manager.create_streamer_for_task(
            self.task_id,
            self.config,  # full config is passed through
            self.fps,
            self.original_width,
            self.original_height
        )
        if streamer:
            logger.info(f"任务 {self.task_id} 推流器初始化成功")
            self.streamer_initialized = True
            self.task_streamer = streamer
            # Record when the streamer came up (used for diagnostics).
            self.stream_stats['streamer_start_time'] = time.time()
            return True
        else:
            logger.error(f"任务 {self.task_id} 推流器初始化失败")
            # On Windows, retry once with a conservative software config.
            if self._is_windows():
                logger.info("尝试Windows备用推流配置...")
                return self._initialize_windows_fallback_streamer()
            return False
    except Exception as e:
        logger.error(f"初始化任务推流器异常 {self.task_id}: {str(e)}", exc_info=True)
        # Windows gets one more chance with the fallback configuration.
        if self._is_windows():
            return self._initialize_windows_fallback_streamer()
        return False
def _initialize_windows_fallback_streamer(self):
    """Retry streamer creation on Windows with a conservative configuration.

    Uses software x264 with zero-latency tuning and hardware acceleration
    disabled. Returns True when the fallback streamer started.
    """
    try:
        logger.info(f"使用Windows备用推流配置: {self.task_id}")
        # NOTE: shallow copy — only the 'push' sub-dict is replaced below.
        fallback_config = self.config.copy()
        fallback_config['push'] = {
            'enable_push': True,
            'url': self.push_config['url'],
            'video_codec': 'libx264',      # software encoder only
            'preset': 'ultrafast',
            'tune': 'zerolatency',
            'format': 'flv',
            'pixel_format': 'bgr24',
            'gpu_acceleration': False,     # explicitly avoid HW accel
            'bitrate': '1000k',
            'bufsize': '2000k',
            'framerate': self.fps,
            'extra_args': [
                '-max_delay', '0',
                '-flags', '+global_header',
                '-rtbufsize', '50M'
            ]
        }
        # Re-create the streamer with the simplified configuration.
        streamer = self.stream_manager.create_streamer_for_task(
            self.task_id,
            fallback_config,
            self.fps,
            self.original_width,
            self.original_height
        )
        if streamer:
            logger.info(f"Windows备用推流器初始化成功: {self.task_id}")
            self.streamer_initialized = True
            self.task_streamer = streamer
            return True
        else:
            logger.error(f"Windows备用推流器也失败: {self.task_id}")
            return False
    except Exception as e:
        logger.error(f"Windows备用推流器初始化异常: {str(e)}")
        return False
2025-12-11 13:41:07 +08:00
def push_frame_to_task_streamer(self, frame):
    """Push one annotated frame to this task's output stream.

    Updates push statistics, triggers Windows diagnostics after 3
    consecutive failures, and attempts streamer recovery (and a 'degraded'
    status update) once max_push_errors is reached. Returns the push
    success flag (False when pushing is disabled/uninitialized).
    """
    if not self.enable_push or not self.streamer_initialized:
        return False
    try:
        # Delegate the actual write to the platform stream manager.
        success = self.stream_manager.push_frame(self.task_id, frame)
        self.stream_stats['total_frames_pushed'] += 1
        if success:
            self.last_push_time = time.time()
            self.stream_stats['last_push_time'] = time.time()
            # A success resets the consecutive-error counter.
            self.push_error_count = 0
            # Recompute the running success rate.
            total = self.stream_stats['total_frames_pushed']
            failed = self.stream_stats['failed_pushes']
            self.stream_stats['push_success_rate'] = (total - failed) / max(total, 1)
            # Log progress every 100 frames.
            if total % 100 == 0:
                logger.info(
                    f"任务 {self.task_id} 已成功推流 {total} 帧,成功率: {self.stream_stats['push_success_rate']:.2%}")
        else:
            self.push_error_count += 1
            self.stream_stats['failed_pushes'] += 1
            logger.warning(f"任务 {self.task_id} 推流失败 ({self.push_error_count}/{self.max_push_errors})")
            # Extra diagnostics on Windows after a few failures.
            if self._is_windows() and self.push_error_count >= 3:
                self._diagnose_windows_streaming_issue()
            # Too many consecutive failures: degrade and try to recover.
            if self.push_error_count >= self.max_push_errors:
                logger.error(f"任务 {self.task_id} 推流连续失败,尝试恢复")
                self.update_task_status('degraded')
                self.recover_task_streamer()
        return success
    except Exception as e:
        logger.error(f"推流异常 {self.task_id}: {str(e)}")
        self.push_error_count += 1
        self.stream_stats['failed_pushes'] += 1
        return False
def _diagnose_windows_streaming_issue(self):
    """Log diagnostics (ffmpeg state, recent output, CPU/RAM pressure) to
    help debug repeated push failures on Windows. Best-effort: any error
    here is logged, never raised."""
    try:
        logger.info("诊断Windows推流问题...")
        # Current streamer/ffmpeg status as seen by the manager.
        stream_info = self.stream_manager.get_task_stream_info(self.task_id)
        if stream_info:
            logger.info(f"推流状态: {stream_info.get('status', 'unknown')}")
            logger.info(f"FFmpeg进程: {'运行中' if stream_info.get('process_alive') else '已停止'}")
            logger.info(f"最近输出: {stream_info.get('last_ffmpeg_output', '')}")
            # Show the tail of ffmpeg's output for context.
            output_lines = stream_info.get('output_lines', [])
            if output_lines:
                logger.info("最近FFmpeg输出:")
                for line in output_lines[-5:]:
                    logger.info(f" {line}")
        # System pressure is a common cause of dropped pushes.
        from windows_utils import WindowsSystemUtils
        resources = WindowsSystemUtils.get_system_resources()
        if resources['cpu_percent'] > 90:
            logger.warning(f"CPU使用率过高: {resources['cpu_percent']}%")
        if resources['memory_percent'] > 90:
            logger.warning(f"内存使用率过高: {resources['memory_percent']}%")
    except Exception as e:
        logger.error(f"诊断推流问题异常: {str(e)}")
def recover_task_streamer(self):
    """Tear down and re-create the task streamer after repeated push failures.

    Increments the ffmpeg restart counter; after 3 failed recoveries the
    task is marked 'error' and the whole detection thread is flagged to
    stop. Returns True when the streamer came back up.
    """
    try:
        logger.info(f"任务 {self.task_id} 尝试恢复推流器")
        self.stream_stats['ffmpeg_restarts'] += 1
        # Full teardown, then give ffmpeg time to release resources
        # (Windows in particular needs the longer pause).
        self.cleanup_task_streamer()
        time.sleep(2)
        success = self.initialize_task_streamer()
        if success:
            self.push_error_count = 0
            # Restore the pre-failure task status when it was 'running'.
            if hasattr(self, 'original_status') and self.original_status == 'running':
                self.update_task_status('running')
            logger.info(f"任务 {self.task_id} 推流器恢复成功 (第{self.stream_stats['ffmpeg_restarts']}次重启)")
        else:
            logger.error(f"任务 {self.task_id} 推流器恢复失败")
            # Too many restarts: fail the task and stop the thread.
            if self.stream_stats['ffmpeg_restarts'] >= 3:
                self.update_task_status('error')
                self._force_stop = True
                self._should_stop.set()
                self.stop_event.set()
                self.running = False
        return success
    except Exception as e:
        logger.error(f"恢复推流器异常 {self.task_id}: {str(e)}")
        return False
def cleanup_task_streamer(self):
    """Stop and release this task's output streamer.

    Always clears self.task_streamer and self.streamer_initialized, even
    when stopping raises, so no stale streamer reference survives.
    """
    if self.task_streamer:
        try:
            self.task_streamer.stop()
            self.stream_manager.stop_task_streamer(self.task_id)
            logger.info(f"任务 {self.task_id} 推流器已清理")
        except Exception as e:
            logger.error(f"停止推流器失败: {str(e)}")
        finally:
            # BUG FIX: previously the reference was only cleared on the
            # success path, leaving a dead streamer behind after a failure.
            self.task_streamer = None
    self.streamer_initialized = False
def warmup_models(self):
    """Run two dummy inferences per model so first real frames are fast.

    Per-model failures are logged as warnings and do not abort warmup of
    the remaining models.
    """
    logger.info("预热所有模型...")
    for model_info in self.models:
        model = model_info['model']
        model_config = model_info['config']
        try:
            # Zero tensor matching the model's configured input size.
            imgsz = model_config.get('imgsz', 640)
            dummy_input = torch.zeros(1, 3, imgsz, imgsz)
            device = model_config.get('device', 'cpu')
            dummy_input = dummy_input.to(device)
            # Match FP16 when half-precision inference is enabled.
            if model_config.get('half', False) and 'cuda' in device:
                dummy_input = dummy_input.half()
            # Two no-grad passes to trigger lazy CUDA/kernel initialization.
            with torch.no_grad():
                for _ in range(2):
                    model.predict(dummy_input)
            logger.info(f"模型 {model_info['name']} 预热完成")
        except Exception as e:
            logger.warning(f"模型 {model_info['name']} 预热失败: {str(e)}")
    logger.info("所有模型预热完成")
2025-12-16 10:08:12 +08:00
def _multi_model_inference(self, frame):
    """Run every loaded model on `frame` (each with its own labels and
    confidence settings); delegates to detection_render.multi_model_inference.

    Returns the (annotated_frame, detections) pair produced by the helper.
    """
    return multi_model_inference(self.models, frame)
2025-11-26 13:55:04 +08:00
2025-12-11 13:41:07 +08:00
def should_skip_frame(self, start_time):
    """Decide whether the current frame should be dropped.

    Skip when this frame's processing already exceeded the latency budget,
    or when the stream FPS has fallen below 15.
    """
    elapsed = time.perf_counter() - start_time
    if elapsed > self.target_latency:
        return True
    return self.fps < 15
def handle_upload(self, annotated_frame, all_detections, current_time):
    """Queue an annotated frame for background upload.

    Does nothing when there are no detections or when the last upload was
    less than self.upload_interval seconds ago. The frame is copied so the
    caller may keep reusing its buffer. A full queue is logged and skipped.
    """
    # Flatten the per-model detection lists into one list.
    merged = [det for model_dets in all_detections.values() for det in model_dets]
    if not merged or (current_time - self.last_upload_time) < self.upload_interval:
        return
    try:
        stamp_ms = int(current_time * 1000)
        self.upload_queue.put({
            "image": annotated_frame.copy(),
            "filename": f"DJI_{stamp_ms}.jpg",
            "detection_data": merged,
            "timestamp": current_time
        }, block=False)
        self.last_upload_time = current_time
    except queue.Full:
        logger.warning("上传队列已满,跳过上传")
    except Exception as e:
        logger.error(f"添加上传任务失败: {e}")
2025-11-26 13:55:04 +08:00
def _upload_worker(self):
    """Background worker: save queued frames, upload them to MinIO and
    publish the result over MQTT.

    Drains self.upload_queue while self.upload_active is set (and until the
    queue is empty afterwards). A `None` item acts as a shutdown sentinel.
    The local JPEG is always removed after the upload attempt.
    """
    logger.info("上传工作线程启动")
    output_dir = "output_frames"
    os.makedirs(output_dir, exist_ok=True)
    # Keep draining until deactivated AND the queue is empty.
    while self.upload_active or not self.upload_queue.empty():
        try:
            task = self.upload_queue.get(timeout=1.0)
            if task is None:  # shutdown sentinel
                break
            start_time = time.time()
            image = task["image"]
            filename = task["filename"]
            detection_data = task["detection_data"]
            filepath = os.path.join(output_dir, filename)
            # Save with moderate JPEG quality to keep uploads small.
            cv2.imwrite(filepath, image, [cv2.IMWRITE_JPEG_QUALITY, 85])
            try:
                foldername = self.config['task']['taskname']
                # BUG FIX: use the per-frame filename in the object path —
                # the previous literal placeholder made every upload
                # overwrite the same MinIO object.
                object_path = f"{foldername}/{filename}"
                self.minio_uploader.upload_file(filepath, object_path)
                payload = {
                    "taskid": self.taskid,
                    "path": object_path,
                    "tag": detection_data,
                    "aiid": self.aiid,
                }
                # Attach the latest drone telemetry when available.
                if self.mqtt_enabled and self.mqtt_connected:
                    with self.mqtt_data_lock:
                        if self.latest_drone_data:
                            payload["drone_info"] = self.latest_drone_data
                # Publish the achievement message over MQTT.
                if self.mqtt_client and self.mqtt_connected:
                    try:
                        mqtt_topic = f'ai/task/{self.taskid}/achievement'
                        result = self.mqtt_client.publish(mqtt_topic, json.dumps(payload, ensure_ascii=False))
                        if result.rc == mqtt.MQTT_ERR_SUCCESS:
                            logger.debug(f'已上传帧至 MinIO: {object_path} | MQTT消息已发送到主题: {mqtt_topic} | 耗时: {time.time() - start_time:.2f}s')
                        else:
                            logger.warning(f'MQTT消息发送失败: {mqtt.error_string(result.rc)}')
                    except Exception as e:
                        logger.error(f'MQTT消息发送异常: {str(e)}')
                else:
                    logger.warning('MQTT客户端未连接跳过消息发送')
            except requests.exceptions.Timeout:
                logger.warning(f"API调用超时: {self.res_api}")
            except Exception as e:
                logger.error(f"上传/API调用失败: {e}")
            finally:
                # Best-effort cleanup of the local JPEG.
                try:
                    os.remove(filepath)
                except OSError:
                    pass
            self.upload_queue.task_done()
        except queue.Empty:
            continue
        except Exception as e:
            logger.error(f"上传任务处理异常: {e}")
    logger.info("上传工作线程已停止")
2025-12-11 13:41:07 +08:00
def on_mqtt_connect(self, client, userdata, flags, rc):
    """paho-mqtt on_connect callback: subscribe and track connection state.

    rc == 0 means the broker accepted the connection; any other value is
    logged as an error and marks the client as disconnected.
    """
    if rc != 0:
        logger.error(f"MQTT连接失败错误码: {rc}")
        self.mqtt_connected = False
        return
    client.subscribe(self.mqtt_topic)
    self.mqtt_connected = True
2025-12-11 13:41:07 +08:00
def on_mqtt_message(self, client, userdata, msg):
    """paho-mqtt on_message callback: cache the latest drone telemetry.

    The payload is expected to be UTF-8 JSON; it is stored under the data
    lock so the upload worker can read it safely. Parse errors are logged.
    """
    try:
        telemetry = json.loads(msg.payload.decode())
        with self.mqtt_data_lock:
            self.latest_drone_data = telemetry
        logger.debug(f"收到MQTT消息: {telemetry}")
    except Exception as e:
        logger.error(f"解析MQTT消息失败: {str(e)}")
2025-12-11 13:41:07 +08:00
def start_mqtt_client(self):
    """Create, connect and start the paho MQTT client on its own loop thread.

    Returns True when the client was started, False when MQTT is disabled
    or the connection attempt failed.
    """
    if not self.mqtt_enabled:
        logger.info("MQTT功能未启用")
        return False
    try:
        logger.info("启动MQTT客户端...")
        self.mqtt_client = mqtt.Client(client_id=self.mqtt_config.get('client_id', 'yolo_detection'))
        self.mqtt_client.on_connect = self.on_mqtt_connect
        self.mqtt_client.on_message = self.on_mqtt_message
        # Optional username/password authentication.
        if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
            self.mqtt_client.username_pw_set(
                self.mqtt_config['username'],
                self.mqtt_config['password']
            )
        self.mqtt_client.connect(
            self.mqtt_config['broker'],
            self.mqtt_config.get('port', 1883),
            self.mqtt_config.get('keepalive', 60)
        )
        # Background network loop; callbacks fire on its thread.
        self.mqtt_client.loop_start()
        logger.info("MQTT客户端已启动")
        return True
    except Exception as e:
        logger.error(f"启动MQTT客户端失败: {str(e)}")
        return False
2025-11-26 13:55:04 +08:00
2025-12-11 13:41:07 +08:00
def stop_mqtt_client(self):
    """Stop the MQTT loop and disconnect; always clears client state."""
    if not self.mqtt_client:
        return
    try:
        self.mqtt_client.loop_stop()
        self.mqtt_client.disconnect()
        logger.info("MQTT客户端已停止")
    except Exception as e:
        logger.error(f"停止MQTT客户端失败: {str(e)}")
    finally:
        # Drop the reference regardless of how shutdown went.
        self.mqtt_client = None
        self.mqtt_connected = False
2025-11-26 13:55:04 +08:00
def handle_reconnect(self):
    """Reconnect the RTMP input with capped linear backoff.

    Gives up (and stops the thread) after max_reconnect_attempts. The
    backoff wait is interruptible: the stop event is polled twice a second.
    """
    if self.stop_event.is_set() or not self.running:
        logger.info("收到停止信号,跳过重连")
        return
    self.reconnect_attempts += 1
    if self.reconnect_attempts >= self.max_reconnect_attempts:
        logger.error("达到最大重连次数")
        self.running = False
        return
    # Backoff grows with the attempt count, capped at 30 seconds.
    delay = min(30, self.reconnect_attempts * self.reconnect_delay * 2)
    logger.warning(f"流中断,{delay}秒后重连 (第{self.reconnect_attempts}/{self.max_reconnect_attempts}次重连)")
    # Release the old capture before reconnecting.
    if self.cap:
        try:
            self.cap.release()
            self.cap = None
        except:
            pass
    # Sleep in half-second slices so a stop request is honoured quickly.
    start_time = time.time()
    while time.time() - start_time < delay:
        if self.stop_event.is_set() or not self.running:
            logger.info("收到停止信号,取消重连")
            return
        time.sleep(0.5)
    # Reset stream-stability bookkeeping before retrying.
    self.consecutive_read_failures = 0
    self.stream_stable = True
    self.stream_recovery_attempts += 1
    # Attempt the reconnection itself.
    try:
        if self.stop_event.is_set() or not self.running:
            logger.info("收到停止信号,跳过重新连接")
            return
        logger.info("尝试重新连接RTMP...")
        if not self.initialize_rtmp():
            raise IOError("RTMP重连失败")
        logger.info("RTMP重连成功")
        self.reconnect_attempts = 0  # success resets the attempt counter
        self.stream_recovery_attempts = 0
    except Exception as e:
        logger.error(f"重连异常: {str(e)}")
2025-12-11 13:41:07 +08:00
def send_to_websocket(self, all_detections):
    """Serialize per-model detection results and emit them over Socket.IO.

    `all_detections` is an iterable of per-model result objects exposing
    boxes / confidences / class_ids / class_names / model_name (as produced
    by multi_model_inference). Errors are logged, never raised.
    """
    try:
        now = datetime.datetime.now()
        time_str = now.strftime("%H:%M:%S")
        # Convert each model's result object into plain dicts.
        all_detections_send = []
        for det in all_detections:
            # BUG FIX: was a bare print() left over from debugging; route
            # through the logger so output respects level/handlers.
            logger.debug(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
            det_detections_res = []
            for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                                   det.class_ids, det.class_names):
                det_detections_res.append(
                    {
                        'class_id': cls_id,
                        'class_name': cls_name,
                        'box': box,
                        'conf': conf,
                    }
                )
            det_detection = {
                'count': len(det.boxes),
                'detections': det_detections_res,
            }
            all_detections_send.append(det_detection)
        # Stream-health snapshot for the UI.
        stream_info = {
            'stable': self.stream_stable,
            'consecutive_failures': self.consecutive_read_failures,
            'total_failures': self.total_read_failures,
            'recovery_attempts': self.stream_recovery_attempts
        }
        self.config['socketIO'].emit('detection_results', {
            'task_id': getattr(self, 'task_id', 'unknown'),
            'detections': all_detections_send,
            'timestamp': time.time_ns() // 1000000,  # epoch milliseconds
            'fps': round(self.last_fps, 1),
            'frame_count': self.frame_count,
            'taskname': self.taskname,
            'time_str': time_str,
            'models_count': len(self.models),
            'stream_info': stream_info
        })
    except Exception as e:
        logger.error(f"WebSocket发送错误: {str(e)}")
2025-12-11 13:41:07 +08:00
def check_stream_stability(self):
    """Refresh self.stream_stable at most once per check interval.

    Returns True when a check ran this call, False when it was skipped
    because the interval has not elapsed yet.
    """
    now = time.time()
    if now - self.last_stream_check < self.stream_check_interval:
        return False
    if self.consecutive_read_failures > 0:
        logger.warning(f"流稳定性警告: 连续读取失败 {self.consecutive_read_failures}")
        self.stream_stable = False
    else:
        self.stream_stable = True
    self.last_stream_check = now
    return True
def handle_frame_read_failure(self):
    """Escalating backoff for cap.read() failures.

    Returns True when the caller should trigger a full RTMP reconnect,
    False when it should simply retry after the sleep performed here.
    """
    self.consecutive_read_failures += 1
    self.total_read_failures += 1
    # Refresh the stream-stability flag (rate-limited internally).
    self.check_stream_stability()
    # Escalate the wait with the consecutive failure count.
    if self.consecutive_read_failures == 1:
        # First failure: short wait, then retry.
        logger.debug("帧读取失败等待0.5秒后重试")
        time.sleep(self.retry_delay)
        return False  # no reconnect yet
    elif self.consecutive_read_failures <= 3:
        # 2-3 failures: medium wait.
        logger.warning(f"连续 {self.consecutive_read_failures} 次帧读取失败等待1秒后重试")
        time.sleep(1)
        return False  # no reconnect yet
    elif self.consecutive_read_failures <= self.max_consecutive_failures:
        # Up to the configured maximum: longer wait.
        logger.error(f"连续 {self.consecutive_read_failures} 次帧读取失败等待2秒后重试")
        time.sleep(2)
        return False  # no reconnect yet
    else:
        # Exceeded the maximum: ask the caller to reconnect.
        logger.error(f"连续 {self.consecutive_read_failures} 次帧读取失败,触发重连")
        return True  # reconnect needed
    def run(self):
        """Main loop of the detection thread.

        Startup sequence: load models, open the RTMP source, optionally
        initialize the per-task streamer / MQTT client / upload worker,
        warm up the models, then loop: read a frame, run multi-model
        inference, push the annotated frame, and report results over
        WebSocket (throttled to ~1 Hz) until a stop signal or fatal error
        arrives. ``cleanup()`` always runs in the ``finally`` block.
        """
        try:
            logger.info(f"启动检测线程任务ID: {getattr(self, 'task_id', 'unknown')}")
            # Windows-specific configuration sanity check.
            if self._is_windows() and hasattr(self, 'windows_config'):
                if self.windows_config and 'error' in self.windows_config:
                    logger.warning(f"Windows配置问题: {self.windows_config['error']}")
            # 1. Load all configured (encrypted) models; abort on failure.
            if not self.load_models():
                logger.error("加密模型加载失败,线程退出")
                self.update_task_status('failed')
                return
            key_summary = self.get_key_verification_summary()
            logger.info(
                f"加密密钥验证结果: {key_summary['loaded_models']}/{key_summary['total_models']} 个模型加载成功")
            # 2. Initialize the RTMP connection; abort on failure.
            if not self.initialize_rtmp():
                logger.error("RTMP连接失败线程退出")
                self.update_task_status('failed')
                return
            # 3. Initialize the per-task streamer (pushing is optional:
            #    on failure we keep detecting without pushing).
            if self.enable_push:
                if not self.initialize_task_streamer():
                    logger.warning(f"任务 {self.task_id} 推流器初始化失败,继续运行但不推流")
                    self.enable_push = False
            # 4. Start the stream manager's health monitor when available.
            if self.enable_push and hasattr(self.stream_manager, 'start_health_monitor'):
                try:
                    self.stream_manager.start_health_monitor()
                    logger.info("推流健康监控已启动")
                except Exception as e:
                    logger.warning(f"启动健康监控失败: {e}")

            # 5. Start the MQTT client (when enabled).
            if self.mqtt_enabled:
                self.start_mqtt_client()
            # 6. Start the background image-upload worker thread.
            self.upload_active = True
            self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
            self.upload_thread.start()
            logger.info("图片上传线程已启动")
            # 7. Warm up all models so the first real frame is fast.
            self.warmup_models()
            # 8. Mark the task as running.
            self.initialized = True
            self.update_task_status('running')
            self.original_status = 'running'  # remember the baseline status

            logger.info("资源初始化完成")
            # 9. Main frame-processing loop.
            while self.running and not self.stop_event.is_set():
                # Honor every stop flag as early as possible.
                if self._force_stop or self._should_stop.is_set():
                    logger.info("收到停止信号,退出循环")
                    break
                if self.stop_event.is_set():
                    logger.info("收到停止事件信号,退出循环")
                    break
                start_time = time.perf_counter()
                try:
                    # Read one frame from the capture source.
                    ret, frame = self.cap.read()
                    if not ret:
                        if self.stop_event.is_set() or not self.running or self._force_stop:
                            logger.info("收到停止信号,不再尝试重连")
                            break
                        # Failed read: back off and decide whether to reconnect.
                        need_reconnect = self.handle_frame_read_failure()
                        if need_reconnect:
                            # Streak too long - reconnect the stream.
                            self.handle_reconnect()
                        continue
                    # Successful read: reset the consecutive-failure streak.
                    if self.consecutive_read_failures > 0:
                        logger.info(f"恢复帧读取成功,之前连续失败 {self.consecutive_read_failures} 次")
                        self.consecutive_read_failures = 0
                        self.stream_stable = True
                    # Normalize the frame to the expected resolution.
                    if frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
                        frame = cv2.resize(frame, (self.original_width, self.original_height))
                    # Exponentially-smoothed FPS estimate; published to
                    # last_fps at most twice per second.
                    current_time = time.time()
                    time_diff = current_time - self.last_frame_time
                    if time_diff > 0:
                        self.fps = 0.9 * self.fps + 0.1 / time_diff
                        if current_time - self.last_status_update > 0.5:
                            self.last_fps = self.fps
                            self.last_status_update = current_time
                    self.last_frame_time = current_time
                    # Dynamic frame skipping (more conservative on Windows).
                    if self.should_skip_frame(start_time):
                        self.frame_count += 1
                        continue
                    # Multi-model inference on the current frame.
                    start = time.time()
                    annotated_frame, model_detections = self._multi_model_inference(frame)
                    logger.info(f'startTime:{start},endTime:{time.time()},时间差:{time.time() - start}')
                    # Push the annotated frame; push failures must not
                    # interrupt detection.
                    if self.enable_push:
                        if not self.push_frame_to_task_streamer(annotated_frame):
                            # Push failed - keep processing regardless.
                            pass
                    if current_time - self.last_log_time >= 1:
                        # Report detections over WebSocket (~1 Hz throttle).
                        self.send_to_websocket(model_detections)
                        # Upload handling is currently disabled:
                        # self.handle_upload(annotated_frame, model_detections, current_time)
                        # Check push health and degrade the task status
                        # unless it is already error/degraded.
                        if self.enable_push:
                            if not self.check_push_health():
                                current_status = getattr(self, '_current_status', 'running')
                                if current_status != 'error' and current_status != 'degraded':
                                    self.update_task_status('degraded')
                        self.last_log_time = current_time
                    self.frame_count += 1
                    self.reconnect_attempts = 0
                    # Performance bookkeeping (EMA of per-frame latency).
                    elapsed = time.perf_counter() - start_time
                    self.processing_times.append(elapsed)
                    self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1
                    # Every 50 frames: performance / stability / push report.
                    if self.frame_count % 50 == 0:
                        avg_time = sum(self.processing_times) / len(
                            self.processing_times) if self.processing_times else 0
                        logger.info(
                            f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {(1.0 / avg_time if avg_time > 0 else 0):.2f}")
                        # Stream-stability report.
                        if self.total_read_failures > 0:
                            logger.info(
                                f"流稳定性: 总失败次数 {self.total_read_failures}, 当前连续失败 {self.consecutive_read_failures} 次")
                        # Push statistics.
                        if self.enable_push:
                            logger.info(f"推流统计: {self.stream_stats['total_frames_pushed']}帧, "
                                        f"成功率: {self.stream_stats['push_success_rate']:.2%}, "
                                        f"重启: {self.stream_stats.get('ffmpeg_restarts', 0)} 次")
                        self.processing_times = []
                except Exception as e:
                    logger.error(f"处理帧时发生异常: {str(e)}")
                    if self.stop_event.is_set() or not self.running or self._force_stop:
                        logger.info("收到停止信号,退出异常处理")
                        break
                    # Transient per-frame error: keep the loop alive.
                    continue
            logger.info("检测线程主循环结束")
        except Exception as e:
            logger.error(f"检测线程异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            # Always release resources and mark the task stopped.
            self.cleanup()
            logger.info("检测线程已安全停止")
            self.update_task_status('stopped')
        return False
def stop(self):
"""停止检测线程"""
2025-12-11 13:41:07 +08:00
logger.info(f"收到停止请求任务ID: {getattr(self, 'task_id', 'unknown')}")
2025-12-11 15:08:28 +08:00
# 首先设置强制停止标志
2025-12-11 13:41:07 +08:00
self._force_stop = True
2025-12-11 15:08:28 +08:00
self._should_stop.set() # 设置停止事件
# 发送停止事件
self.stop_event.set()
# 设置运行标志为False
self.running = False
# 强制释放视频流以解除阻塞
if self.cap:
try:
# 强制中断视频读取
self.cap.release()
logger.info("已强制释放视频流")
except Exception as e:
logger.error(f"释放视频流失败: {str(e)}")
2025-12-11 13:41:07 +08:00
2025-12-11 15:08:28 +08:00
# 清理任务推流器
self.cleanup_task_streamer()
2025-11-26 13:55:04 +08:00
# 停止MQTT客户端
if self.mqtt_enabled:
self.stop_mqtt_client()
# 停止上传线程
self.upload_active = False
2025-12-11 15:08:28 +08:00
if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
# 发送停止信号到上传队列
2025-11-26 13:55:04 +08:00
try:
2025-12-11 15:08:28 +08:00
self.upload_queue.put(None, timeout=0.5)
2025-11-26 13:55:04 +08:00
except:
pass
2025-12-11 13:41:07 +08:00
logger.info(f"停止信号已发送任务ID: {getattr(self, 'task_id', 'unknown')}")
    def cleanup(self):
        """Release every resource owned by the detection thread.

        Guarded against re-entry via the ``_cleaning_up`` flag. In order:
        stops the MQTT client, stops the upload worker (sending a ``None``
        sentinel, draining the queue if it is full), tears down the task
        streamer, releases the capture device, releases every loaded
        model, clears the GPU cache, then drains the upload queue. Every
        step is individually exception-guarded so one failure cannot
        prevent the remaining resources from being released.
        """
        logger.info(f"开始清理资源任务ID: {getattr(self, 'task_id', 'unknown')}")
        # Re-entrancy guard: skip when a cleanup is already in progress.
        if hasattr(self, '_cleaning_up') and self._cleaning_up:
            logger.warning("资源已经在清理中,跳过重复清理")
            return
        self._cleaning_up = True
        try:
            # Stop the MQTT client.
            if self.mqtt_enabled:
                self.stop_mqtt_client()
            # Stop the upload worker thread.
            self.upload_active = False
            if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
                logger.info("停止上传线程...")
                try:
                    if hasattr(self, 'upload_queue'):
                        try:
                            # None is the worker's stop sentinel.
                            self.upload_queue.put(None, timeout=0.5)
                        except queue.Full:
                            # Queue full: drain it, then retry the sentinel.
                            try:
                                while not self.upload_queue.empty():
                                    self.upload_queue.get_nowait()
                            except:
                                pass
                            self.upload_queue.put(None, block=False)
                        except:
                            pass
                    # Give the worker up to 3 seconds to exit.
                    self.upload_thread.join(3.0)
                    if self.upload_thread.is_alive():
                        logger.warning("上传线程未在3秒内停止")
                except Exception as e:
                    logger.error(f"停止上传线程异常: {str(e)}")
                finally:
                    self.upload_thread = None
            # Tear down the per-task streamer.
            self.cleanup_task_streamer()
            # Release the video capture device.
            if hasattr(self, 'cap') and self.cap:
                logger.info("释放视频流...")
                try:
                    self.cap.release()
                except Exception as e:
                    logger.error(f"释放视频流异常: {str(e)}")
                finally:
                    self.cap = None
            # Release all loaded models.
            logger.info(f"释放所有模型,共 {len(self.models)} 个")
            for i, model_info in enumerate(self.models):
                try:
                    model = model_info['model']
                    model_name = model_info['name']
                    # Drop the model's cached predictor/weights first so
                    # their GPU memory can be reclaimed.
                    if hasattr(model, 'predictor'):
                        try:
                            del model.predictor
                        except:
                            pass
                    if hasattr(model, 'model'):
                        try:
                            del model.model
                        except:
                            pass
                    # Drop the local reference to the model itself.
                    del model
                    logger.info(f"模型 {model_name} 已释放")
                except Exception as e:
                    logger.error(f"释放模型 {i} 异常: {str(e)}")

            self.models = []
            # Clear the GPU cache and force a GC pass.
            logger.info("清理GPU缓存...")
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                logger.info("GPU缓存已清理")
            except Exception as e:
                logger.error(f"清理GPU缓存异常: {str(e)}")
            # Drain any leftover items from the upload queue.
            logger.info("清理其他资源...")
            try:
                if hasattr(self, 'upload_queue'):
                    try:
                        while not self.upload_queue.empty():
                            try:
                                self.upload_queue.get_nowait()
                            except:
                                break
                    except:
                        pass
                self.processing_times = []
                logger.info("其他资源已清理")
            except Exception as e:
                logger.error(f"清理其他资源异常: {str(e)}")

            logger.info("资源清理完成")

        except Exception as e:
            logger.error(f"资源清理过程中发生异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            self._cleaning_up = False  # reset the re-entrancy guard
def update_task_status(self, status):
"""更新任务状态"""
if hasattr(self, 'task_id') and self.task_id:
try:
from task_manager import task_manager
if hasattr(task_manager, 'update_task_status'):
task_manager.update_task_status(self.task_id, status)
else:
# 备用方法:直接更新全局数据
tasks_dict = gd.get_or_create_dict('tasks')
if self.task_id in tasks_dict:
tasks_dict[self.task_id]['status'] = status
logger.debug(f"更新任务 {self.task_id} 状态为: {status}")
# 更新本地状态跟踪
self._current_status = status
2025-12-11 13:41:07 +08:00
except Exception as e:
2025-12-13 16:13:12 +08:00
logger.warning(f"更新任务状态失败: {str(e)}")