Yolov/detectionThread.py

1618 lines
66 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import datetime
import gc
import hashlib
import json
import logging
import os
import queue
import tempfile
import threading
import time
import traceback
import cv2
import paho.mqtt.client as mqtt
import requests
import torch
from ultralytics import YOLO
from _minio import MinioUploader
from log import logger
from global_data import gd
from detection_render import multi_model_inference
from mandatory_model_crypto import MandatoryModelEncryptor
# detectionThread.py - 修改 ModelManager 类
class ModelManager:
    """Model manager supporting multiple plain and encrypted YOLO models.

    Models are loaded from local directories. Loaded (and decrypted)
    models are cached in memory, keyed by path plus a digest of the
    encryption key, so the same encrypted model is never decrypted twice.
    """

    def __init__(self, config):
        self.config = config
        self.models_dir = "models"
        self.encrypted_models_dir = config.get('upload', {}).get('encrypted_models_dir', 'encrypted_models')
        # Make sure both model directories exist
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)
        # Load cache: avoids re-decrypting / re-loading the same model
        self.model_cache = {}
        self.cache_lock = threading.Lock()

    def load_model(self, model_config, require_verification=False):
        """Load a single model described by *model_config* from local disk.

        Returns ``(model, verification_result)``; ``model`` is ``None`` on
        failure and ``verification_result`` then carries the error details.
        ``require_verification`` is kept for interface compatibility; key
        verification always happens for encrypted models with a key.
        """
        try:
            model_path = model_config['path']
            encrypted = model_config.get('encrypted', False)
            encryption_key = model_config.get('encryption_key')
            # Resolve the on-disk location of the model file
            if encrypted:
                if model_path.startswith('encrypted_models/'):
                    # Relative path inside the encrypted-models directory
                    local_path = os.path.join(self.encrypted_models_dir, os.path.basename(model_path))
                elif os.path.isabs(model_path):
                    # Absolute path used as-is
                    local_path = model_path
                else:
                    # Look for the (possibly .enc-suffixed) file in the encrypted dir
                    model_filename = os.path.basename(model_path)
                    if not model_filename.endswith('.enc'):
                        model_filename += '.enc'
                    local_path = os.path.join(self.encrypted_models_dir, model_filename)
            else:
                # Plain models live in the regular models directory
                local_path = os.path.join(self.models_dir, os.path.basename(model_path))
            if not os.path.exists(local_path):
                logger.error(f"模型文件不存在: {local_path}")
                return None, {'success': False, 'error': f'模型文件不存在: {local_path}'}
            # Cache key includes a digest of the encryption key so a wrong
            # key never hits another key's cache entry
            cache_key = f"{local_path}_{hashlib.md5(encryption_key.encode()).hexdigest()[:8]}" if encryption_key else local_path
            with self.cache_lock:
                if cache_key in self.model_cache:
                    logger.info(f"使用缓存的模型: {model_path}")
                    cached_info = self.model_cache[cache_key]
                    return cached_info['model'], cached_info.get('verification_result', {'success': True})
            verification_result = None
            model = None
            if encrypted and encryption_key:
                try:
                    from mandatory_model_crypto import MandatoryModelValidator
                    validator = MandatoryModelValidator()
                    # Decrypt (and hash-verify) the model into memory
                    decrypt_result = validator.decrypt_and_verify(local_path, encryption_key)
                    if not decrypt_result['success']:
                        logger.error(f"解密模型失败: {model_path} - {decrypt_result.get('error', '未知错误')}")
                        return None, decrypt_result
                    verification_result = {
                        'success': True,
                        'model_hash': decrypt_result.get('model_hash', ''),
                        'original_size': decrypt_result.get('original_size', 0)
                    }
                    decrypted_data = decrypt_result['decrypted_data']
                    # YOLO() wants a file path, so write the plaintext to a
                    # temporary file. The finally block guarantees removal
                    # even when YOLO() raises (previously the file leaked).
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                        tmp.write(decrypted_data)
                        temp_path = tmp.name
                    try:
                        model = YOLO(temp_path)
                    finally:
                        try:
                            os.unlink(temp_path)
                        except Exception as e:
                            logger.warning(f"清理临时文件失败: {str(e)}")
                    logger.info(f"加密模型解密加载成功: {model_path}")
                except ImportError:
                    logger.error("mandatory_model_crypto模块未找到无法处理加密模型")
                    return None, {'success': False, 'error': '加密模块未找到'}
                except Exception as e:
                    logger.error(f"加密模型处理失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}
            elif encrypted and not encryption_key:
                # Encrypted model but no key supplied
                logger.error(f"加密模型但未提供密钥: {model_path}")
                return None, {'success': False, 'error': '加密模型需要密钥'}
            else:
                # Plain (unencrypted) model
                try:
                    model = YOLO(local_path)
                    logger.info(f"普通模型加载成功: {local_path}")
                    verification_result = {'success': True}
                except Exception as e:
                    logger.error(f"加载普通模型失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}
            if model is None:
                return None, verification_result or {'success': False, 'error': '模型加载失败'}
            # Move the model onto the configured (or best available) device
            device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            # Cache the loaded model together with its verification outcome
            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'verification_result': verification_result,
                    'device': device,
                    'cached_at': time.time()
                }
            logger.info(f"模型加载成功: {model_path} -> {device}")
            return model, verification_result or {'success': True}
        except Exception as e:
            logger.error(f"加载模型失败 {model_config.get('path')}: {str(e)}")
            logger.error(traceback.format_exc())
            return None, {'success': False, 'error': str(e)}

    def clear_cache(self):
        """Drop every cached model."""
        with self.cache_lock:
            self.model_cache.clear()
            logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return cache statistics: entry count, keys and total model bytes.

        The original summed ``info.get('original_size', 0)`` — a key cache
        entries never contain, so the total was always 0. The size is
        stored inside each entry's ``verification_result``.
        """
        with self.cache_lock:
            return {
                'cache_size': len(self.model_cache),
                'cached_models': list(self.model_cache.keys()),
                'total_size': sum(
                    (info.get('verification_result') or {}).get('original_size', 0)
                    for info in self.model_cache.values()
                )
            }
class DetectionThread(threading.Thread):
"""多模型检测线程 - 优化版本"""
def __init__(self, config):
    """Create a detection thread for one task.

    *config* is the full task configuration dict (keys: 'rtmp', 'task',
    'push', 'models', optional 'mqtt', plus a 'socketIO' handle). Only
    state is stored here; models, the RTMP capture and the pusher are
    opened later in run().
    """
    super().__init__()
    self.config = config
    self.task_id = None
    self.initialized = False
    self.running = True
    self._cleaning_up = False
    self._force_stop = False
    self._should_stop = threading.Event()  # cooperative stop signal
    # Multi-model support
    self.models = []  # loaded models together with their per-model settings
    self.model_manager = ModelManager(config)
    self.key_verification_results = {}  # per-model key verification outcome
    # RTMP source configuration
    self.cap = None
    self.rtmp_url = config['rtmp']['url']
    self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
    self.reconnect_delay = config['rtmp']['reconnect_delay']
    self.buffer_size = config['rtmp']['buffer_size']
    self.timeout_ms = config['rtmp']['timeout_ms']
    # Stream stability monitoring
    self.stream_stable = True
    self.consecutive_read_failures = 0
    self.max_consecutive_failures = 20  # max consecutive read failures before reconnect
    self.retry_delay = 0.5  # delay (s) after a failed frame read
    self.stream_check_interval = 10  # stability check interval (s)
    self.last_stream_check = time.time()
    self.total_read_failures = 0
    self.stream_recovery_attempts = 0
    # Task information
    self.taskname = config['task']['taskname']
    self.taskid = config['task']['taskid']
    self.aiid = config['task']['aiid']
    # Push (re-streaming) management
    self.enable_push = config['push']['enable_push']
    self.push_config = config['push']
    self.task_streamer = None  # per-task streamer instance
    self.streamer_initialized = False
    self.last_push_time = 0
    self.push_error_count = 0
    self.max_push_errors = 5
    # Performance monitoring
    self.frame_count = 0
    self.reconnect_attempts = 0
    self.last_frame_time = time.time()
    self.fps = 0
    self.stop_event = threading.Event()
    self.last_log_time = time.time()
    self.daemon = True
    self.target_latency = 0.05  # per-frame latency budget (s)
    self.processing_times = []
    self.avg_process_time = 0.033
    self.last_status_update = 0
    self.last_fps = 0
    self.original_width = 0
    self.original_height = 0
    # Upload configuration
    # self.minio_uploader = MinioUploader(config['minio'])
    self.upload_queue = queue.Queue(maxsize=50)
    self.upload_thread = None
    self.upload_active = False
    self.upload_interval = 2  # minimum seconds between uploads
    self.last_upload_time = 0
    self.res_api = config['task']['res_api']
    # WebSocket send queue and worker-thread handles
    self.websocket_queue = queue.Queue(maxsize=100)
    self.websocket_thread = None
    self.websocket_active = False
    # MQTT configuration
    self.mqtt_config = config.get('mqtt', {})
    self.mqtt_enabled = self.mqtt_config.get('enable', False)
    self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
    self.mqtt_client = None
    self.mqtt_connected = False
    self.latest_drone_data = None
    self.mqtt_data_lock = threading.Lock()
    # Platform-specific stream-manager selection
    if self._is_windows():
        from windows_utils import detect_and_configure_windows
        from task_stream_manager_windows import windows_task_stream_manager
        # Detect and configure Windows specifics
        self.windows_config = detect_and_configure_windows()
        self.stream_manager = windows_task_stream_manager
        logger.info(f"Windows系统检测完成: {self.windows_config.get('status', 'unknown')}")
    else:
        # Linux/Mac use the default manager
        from task_stream_manager import task_stream_manager
        self.stream_manager = task_stream_manager
    # Push statistics
    self.stream_stats = {
        'total_frames_pushed': 0,
        'failed_pushes': 0,
        'last_push_time': 0,
        'push_success_rate': 1.0,
        'ffmpeg_restarts': 0
    }
    # Status the task should be restored to after a successful recovery
    self.original_status = 'running'
    # Current task status
    self._current_status = 'initializing'
    # NOTE(review): duplicate of the assignment further up — kept as-is
    self.key_verification_results = {}
def check_push_health(self):
"""检查推流健康状态,如果长时间失败则更新任务状态"""
# 检查推流失败次数和时间
current_time = time.time()
# 如果推流失败次数过多,或长时间没有成功推流,则认为推流健康状态不佳
if (self.push_error_count > self.max_push_errors or
(current_time - self.last_push_time > 60 and self.last_push_time > 0)): # 60秒内无成功推流
return False
return True
# 绘制结果
logger.info(f"检测线程初始化完成: {self.taskname}")
def _is_windows(self):
"""检查是否是Windows系统"""
import os
return os.name == 'nt' or os.name == 'win32'
def load_models(self):
    """Load every enabled model from the config's 'models' list.

    Fills ``self.models`` (loaded models plus inference settings) and
    ``self.key_verification_results`` (per-index verification outcome).
    Returns True when at least one model loaded successfully.
    """
    try:
        models_config = self.config.get('models', [])
        if not models_config or not isinstance(models_config, list):
            logger.error("未找到有效的models配置列表")
            return False
        logger.info(f"开始从本地加载 {len(models_config)} 个模型")
        loaded_models = []
        key_verification_results = {}
        for i, model_config in enumerate(models_config):
            # Skip models explicitly disabled in config
            if not model_config.get('enabled', True):
                logger.info(f"跳过未启用的模型 {i}")
                continue
            model_path = model_config.get('path', 'unknown')
            model_name = os.path.basename(model_path).split('.')[0]
            # Load the model from local disk (decrypting when needed)
            logger.info(f"加载模型 {i}: {model_name}")
            model, verification_result = self.model_manager.load_model(
                model_config,
                require_verification=True  # always verify the key
            )
            # Record the verification outcome, even on failure
            key_verification_results[i] = verification_result or {'success': False, 'error': '未知错误'}
            if model is None:
                logger.error(f"加载模型 {i} 失败: {model_name}")
                continue
            # Per-model label mapping
            tags = model_config.get('tags', {})
            # Collect the model together with its inference settings
            model_info = {
                'model': model,
                'config': model_config,
                'tags': tags,
                'name': model_name,
                'id': i,
                'device': model_config.get('device', 'cpu'),
                'imgsz': model_config.get('imgsz', 640),
                'conf_thres': model_config.get('conf_thres', 0.25),
                'iou_thres': model_config.get('iou_thres', 0.45),
                'half': model_config.get('half', False),
                'key_valid': verification_result.get('success', False) if verification_result else False,
                'model_hash': verification_result.get('model_hash', 'unknown') if verification_result else 'unknown'
            }
            loaded_models.append(model_info)
            logger.info(
                f"模型加载成功: {model_name}, 设备: {model_info['device']}, 密钥验证: {model_info['key_valid']}")
        # Fail hard when nothing loaded, with a per-model error report
        if len(loaded_models) == 0:
            logger.error("所有模型加载失败")
            error_report = []
            for idx, result in key_verification_results.items():
                if not result.get('success', False):
                    error_report.append(f"模型 {idx}: {result.get('error', '未知错误')}")
            if error_report:
                logger.error("模型加载失败详情:")
                for error in error_report:
                    logger.error(f" {error}")
            return False
        self.models = loaded_models
        self.key_verification_results = key_verification_results
        logger.info(f"成功加载 {len(self.models)}/{len(models_config)} 个模型")
        # Key-verification statistics
        valid_keys = sum(1 for result in key_verification_results.values()
                         if result.get('success', False))
        logger.info(f"密钥验证统计: {valid_keys}个有效, {len(key_verification_results) - valid_keys}个无效")
        return True
    except Exception as e:
        logger.error(f"加载模型异常: {str(e)}")
        logger.error(traceback.format_exc())
        return False
def get_key_verification_summary(self):
"""获取密钥验证摘要"""
summary = {
'total_models': len(self.key_verification_results),
'valid_keys': 0,
'invalid_keys': 0,
'loaded_models': len(self.models),
'details': {}
}
for idx, result in self.key_verification_results.items():
if result.get('success', False):
summary['valid_keys'] += 1
else:
summary['invalid_keys'] += 1
summary['details'][f'model_{idx}'] = {
'success': result.get('success', False),
'error': result.get('error', ''),
'model_hash': result.get('model_hash', '')
}
return summary
def initialize_rtmp(self):
    """Open the RTMP source with cv2.VideoCapture and read stream properties.

    Returns True on success; raises IOError (or the underlying cv2
    exception) on failure so the caller can decide whether to retry.
    """
    try:
        logger.info(f"连接RTMP: {self.rtmp_url}")
        self.cap = cv2.VideoCapture()
        # Tuning knobs. NOTE(review): these set() calls happen before
        # open(); some OpenCV/FFmpeg builds ignore properties set on an
        # unopened capture — confirm on the target build.
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)
        self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
        self.cap.set(cv2.CAP_PROP_FPS, 30)
        # Optional hardware-accelerated decode
        if self.config['rtmp'].get('gpu_decode', False):
            try:
                self.cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                logger.info("启用硬件加速解码")
            except:
                logger.warning("硬件解码不可用,使用软件解码")
        # Open through the FFmpeg backend
        if not self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG):
            logger.error(f"连接RTMP失败: {self.rtmp_url}")
            raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")
        # Read back the actual stream properties
        self.original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 when unknown/0
        self.fps = fps
        # Reset stability counters after a (re)connect
        self.consecutive_read_failures = 0
        self.total_read_failures = 0
        self.stream_stable = True
        logger.info(f"视频属性: {self.original_width}x{self.original_height} @ {fps}fps")
        return True
    except Exception as e:
        logger.error(f"初始化RTMP失败: {str(e)}")
        raise
def initialize_task_streamer(self):
    """Create the per-task output streamer via the platform stream manager.

    Returns True when pushing is disabled (nothing to do) or when the
    streamer came up; on Windows a failed attempt falls back to the
    conservative software-only configuration.
    """
    if not self.enable_push:
        logger.info(f"任务 {self.task_id} 推流功能未启用")
        return True
    try:
        logger.info(f"初始化任务推流器: {self.task_id}")
        # Windows: pre-flight reachability check of the RTMP server
        if self._is_windows():
            from windows_utils import WindowsSystemUtils
            push_url = self.push_config.get('url', '')
            if push_url:
                accessibility = WindowsSystemUtils.check_rtmp_server_accessibility(push_url)
                if not accessibility.get('accessible', False):
                    logger.error(f"RTMP服务器不可达: {accessibility.get('error', 'Unknown error')}")
                    logger.warning("推流可能失败,请检查网络和服务器状态")
        # Delegate streamer creation to the platform stream manager
        streamer = self.stream_manager.create_streamer_for_task(
            self.task_id,
            self.config,  # pass the full config through
            self.fps,
            self.original_width,
            self.original_height
        )
        if streamer:
            logger.info(f"任务 {self.task_id} 推流器初始化成功")
            self.streamer_initialized = True
            self.task_streamer = streamer
            # Remember when the streamer came up
            self.stream_stats['streamer_start_time'] = time.time()
            return True
        else:
            logger.error(f"任务 {self.task_id} 推流器初始化失败")
            # On Windows retry once with the software-only fallback profile
            if self._is_windows():
                logger.info("尝试Windows备用推流配置...")
                return self._initialize_windows_fallback_streamer()
            return False
    except Exception as e:
        logger.error(f"初始化任务推流器异常 {self.task_id}: {str(e)}", exc_info=True)
        # Extra Windows-only recovery path
        if self._is_windows():
            return self._initialize_windows_fallback_streamer()
        return False
def _initialize_windows_fallback_streamer(self):
    """Retry streamer creation with a conservative, software-only profile.

    Used on Windows after the primary configuration failed: forces
    libx264 with low-latency flags and disables GPU acceleration.
    """
    try:
        logger.info(f"使用Windows备用推流配置: {self.task_id}")
        # Software x264, zero-latency tuning, no hardware acceleration.
        conservative_push = {
            'enable_push': True,
            'url': self.push_config['url'],
            'video_codec': 'libx264',
            'preset': 'ultrafast',
            'tune': 'zerolatency',
            'format': 'flv',
            'pixel_format': 'bgr24',
            'gpu_acceleration': False,
            'bitrate': '1000k',
            'bufsize': '2000k',
            'framerate': self.fps,
            'extra_args': [
                '-max_delay', '0',
                '-flags', '+global_header',
                '-rtbufsize', '50M',
            ],
        }
        fallback_config = self.config.copy()
        fallback_config['push'] = conservative_push
        # Retry creation with the simplified configuration
        streamer = self.stream_manager.create_streamer_for_task(
            self.task_id,
            fallback_config,
            self.fps,
            self.original_width,
            self.original_height,
        )
        if not streamer:
            logger.error(f"Windows备用推流器也失败: {self.task_id}")
            return False
        logger.info(f"Windows备用推流器初始化成功: {self.task_id}")
        self.streamer_initialized = True
        self.task_streamer = streamer
        return True
    except Exception as e:
        logger.error(f"Windows备用推流器初始化异常: {str(e)}")
        return False
def push_frame_to_task_streamer(self, frame):
"""推送帧到任务推流器Windows优化版"""
if not self.enable_push or not self.streamer_initialized:
return False
try:
# 使用推流管理器推送帧
success = self.stream_manager.push_frame(self.task_id, frame)
# 更新统计信息
self.stream_stats['total_frames_pushed'] += 1
if success:
self.last_push_time = time.time()
self.stream_stats['last_push_time'] = time.time()
self.push_error_count = 0
# 更新成功率
total = self.stream_stats['total_frames_pushed']
failed = self.stream_stats['failed_pushes']
self.stream_stats['push_success_rate'] = (total - failed) / max(total, 1)
# 每100帧记录一次成功
if total % 100 == 0:
logger.info(
f"任务 {self.task_id} 已成功推流 {total} 帧,成功率: {self.stream_stats['push_success_rate']:.2%}")
else:
self.push_error_count += 1
self.stream_stats['failed_pushes'] += 1
logger.warning(f"任务 {self.task_id} 推流失败 ({self.push_error_count}/{self.max_push_errors})")
# Windows上的额外诊断
if self._is_windows() and self.push_error_count >= 3:
self._diagnose_windows_streaming_issue()
# 连续失败处理
if self.push_error_count >= self.max_push_errors:
error_message = f"任务 {self.task_id} 推流连续失败,尝试恢复"
logger.error(error_message)
# 发送错误消息到WebSocket
self.send_error_to_websocket('push_error', error_message)
# 将任务状态更新为降级状态
self.update_task_status('degraded')
self.recover_task_streamer()
return success
except Exception as e:
logger.error(f"推流异常 {self.task_id}: {str(e)}")
self.push_error_count += 1
self.stream_stats['failed_pushes'] += 1
return False
def _diagnose_windows_streaming_issue(self):
    """Log FFmpeg/stream state and system load to debug Windows push failures."""
    try:
        logger.info("诊断Windows推流问题...")
        # Dump the stream manager's view of this task
        info = self.stream_manager.get_task_stream_info(self.task_id)
        if info:
            logger.info(f"推流状态: {info.get('status', 'unknown')}")
            logger.info(f"FFmpeg进程: {'运行中' if info.get('process_alive') else '已停止'}")
            logger.info(f"最近输出: {info.get('last_ffmpeg_output', '')}")
            # Tail of the captured FFmpeg output, last 5 lines
            tail = info.get('output_lines', [])
            if tail:
                logger.info("最近FFmpeg输出:")
                for line in tail[-5:]:
                    logger.info(f" {line}")
        # Flag resource exhaustion that commonly breaks pushing
        from windows_utils import WindowsSystemUtils
        resources = WindowsSystemUtils.get_system_resources()
        if resources['cpu_percent'] > 90:
            logger.warning(f"CPU使用率过高: {resources['cpu_percent']}%")
        if resources['memory_percent'] > 90:
            logger.warning(f"内存使用率过高: {resources['memory_percent']}%")
    except Exception as e:
        logger.error(f"诊断推流问题异常: {str(e)}")
def recover_task_streamer(self):
    """Tear down and re-create the streamer after repeated push failures.

    Gives up once three FFmpeg restarts have been attempted: the task is
    marked 'error' and the whole detection thread is asked to stop.
    Returns True when the streamer came back up.
    """
    try:
        logger.info(f"任务 {self.task_id} 尝试恢复推流器")
        # Count every restart attempt (Windows included)
        self.stream_stats['ffmpeg_restarts'] += 1
        # Full teardown, then re-init; Windows needs the longer settle time
        self.cleanup_task_streamer()
        time.sleep(2)
        success = self.initialize_task_streamer()
        if success:
            self.push_error_count = 0
            # Restore the pre-failure task status
            if hasattr(self, 'original_status') and self.original_status == 'running':
                self.update_task_status('running')
            logger.info(f"任务 {self.task_id} 推流器恢复成功 (第{self.stream_stats['ffmpeg_restarts']}次重启)")
        else:
            logger.error(f"任务 {self.task_id} 推流器恢复失败")
            # Give up after the maximum number of restarts (3)
            if self.stream_stats['ffmpeg_restarts'] >= 3:
                error_message = f"任务 {self.task_id} 推流恢复失败,已达到最大重启次数"
                logger.error(error_message)
                # Report the failure to the frontend
                self.send_error_to_websocket('push_error', error_message)
                self.update_task_status('error')
                # Also stop the whole detection thread
                self._force_stop = True
                self._should_stop.set()
                self.stop_event.set()
                self.running = False
        return success
    except Exception as e:
        logger.error(f"恢复推流器异常 {self.task_id}: {str(e)}")
        return False
def cleanup_task_streamer(self):
    """Stop and release this task's streamer.

    The manager-side teardown and the local reference clear now run in a
    ``finally`` so they happen even when ``stop()`` raises; previously an
    exception in ``stop()`` skipped ``stop_task_streamer`` and left
    ``task_streamer`` set. ``streamer_initialized`` is always reset.
    """
    if self.task_streamer:
        try:
            try:
                self.task_streamer.stop()
            finally:
                self.stream_manager.stop_task_streamer(self.task_id)
                self.task_streamer = None
            logger.info(f"任务 {self.task_id} 推流器已清理")
        except Exception as e:
            logger.error(f"停止推流器失败: {str(e)}")
    self.streamer_initialized = False
def warmup_models(self):
    """Run dummy inferences so the first real frame is not slowed by lazy init."""
    logger.info("预热所有模型...")
    for entry in self.models:
        net = entry['model']
        cfg = entry['config']
        try:
            # Build a zero tensor matching the configured input size/device
            size = cfg.get('imgsz', 640)
            target = cfg.get('device', 'cpu')
            sample = torch.zeros(1, 3, size, size).to(target)
            # Half precision only applies on CUDA devices
            if cfg.get('half', False) and 'cuda' in target:
                sample = sample.half()
            # Two warm-up passes without gradients
            with torch.no_grad():
                for _ in range(2):
                    net.predict(sample)
            logger.info(f"模型 {entry['name']} 预热完成")
        except Exception as e:
            logger.warning(f"模型 {entry['name']} 预热失败: {str(e)}")
    logger.info("所有模型预热完成")
def _multi_model_inference(self, frame):
    """Run every loaded model on *frame*; returns (annotated_frame, detections)."""
    return multi_model_inference(self.models, frame)
def should_skip_frame(self, start_time):
"""判断是否应该跳过当前帧"""
processing_time = time.perf_counter() - start_time
# 基于处理时间判断
if processing_time > self.target_latency:
return True
# 基于FPS判断
min_fps = 15
if self.fps < min_fps:
return True
return False
def handle_upload(self, annotated_frame, all_detections, current_time):
"""处理上传逻辑"""
# 合并所有检测结果
all_detections_list = []
for model_detections in all_detections.values():
all_detections_list.extend(model_detections)
if len(all_detections_list) > 0 and (current_time - self.last_upload_time >= self.upload_interval):
try:
timestamp = int(current_time * 1000)
filename = f"DJI_{timestamp}.jpg"
self.upload_queue.put({
"image": annotated_frame.copy(),
"filename": filename,
"detection_data": all_detections_list,
"timestamp": current_time
}, block=False)
self.last_upload_time = current_time
except queue.Full:
logger.warning("上传队列已满,跳过上传")
except Exception as e:
logger.error(f"添加上传任务失败: {e}")
def _upload_worker(self):
    """Worker thread: save queued frames to disk and publish results.

    Drains ``upload_queue`` while uploading is active (or items remain);
    a ``None`` item is the shutdown sentinel. Each item is written as a
    JPEG, announced over MQTT, then the local file is deleted.
    """
    logger.info("上传工作线程启动")
    output_dir = "output_frames"
    os.makedirs(output_dir, exist_ok=True)
    while self.upload_active or not self.upload_queue.empty():
        try:
            task = self.upload_queue.get(timeout=1.0)
            if task is None:
                break
            start_time = time.time()
            image = task["image"]
            filename = task["filename"]
            detection_data = task["detection_data"]
            filepath = os.path.join(output_dir, filename)
            # JPEG quality 85: size/quality trade-off
            cv2.imwrite(filepath, image, [cv2.IMWRITE_JPEG_QUALITY, 85])
            try:
                foldername = self.config['task']['taskname']
                # NOTE(review): "(unknown)" looks like a scraping/placeholder
                # artifact — presumably this should embed the filename; confirm
                # against the original repository before relying on this path.
                object_path = f"{foldername}/(unknown)"
                # self.minio_uploader.upload_file(filepath, object_path)
                payload = {
                    "taskid": self.taskid,
                    "path": object_path,
                    "tag": detection_data,
                    "aiid": self.aiid,
                }
                # Attach the latest drone telemetry, when available
                if self.mqtt_enabled and self.mqtt_connected:
                    with self.mqtt_data_lock:
                        if self.latest_drone_data:
                            payload["drone_info"] = self.latest_drone_data
                # Publish the result over MQTT
                if self.mqtt_client and self.mqtt_connected:
                    try:
                        # Per-task achievement topic
                        mqtt_topic = f'ai/task/{self.taskid}/aiachievement'
                        result = self.mqtt_client.publish(mqtt_topic, json.dumps(payload, ensure_ascii=False))
                        if result.rc == mqtt.MQTT_ERR_SUCCESS:
                            logger.debug(f'已上传帧至 MinIO: {object_path} | MQTT消息已发送到主题: {mqtt_topic} | 耗时: {time.time() - start_time:.2f}s')
                        else:
                            logger.warning(f'MQTT消息发送失败: {mqtt.error_string(result.rc)}')
                    except Exception as e:
                        logger.error(f'MQTT消息发送异常: {str(e)}')
                else:
                    logger.warning('MQTT客户端未连接跳过消息发送')
            except requests.exceptions.Timeout:
                logger.warning(f"API调用超时: {self.res_api}")
            except Exception as e:
                logger.error(f"上传/API调用失败: {e}")
            finally:
                # Always remove the local file, even after errors
                try:
                    os.remove(filepath)
                except:
                    pass
            # Mark the queue item as processed
            self.upload_queue.task_done()
        except queue.Empty:
            continue
        except Exception as e:
            logger.error(f"上传任务处理异常: {e}")
    logger.info("上传工作线程已停止")
def on_mqtt_connect(self, client, userdata, flags, rc):
"""MQTT连接回调"""
if rc == 0:
client.subscribe(self.mqtt_topic)
self.mqtt_connected = True
# logger.debug("MQTT连接状态正常")
else:
logger.error(f"MQTT连接失败错误码: {rc}")
self.mqtt_connected = False
def on_mqtt_message(self, client, userdata, msg):
    """Paho on_message callback: cache the latest drone telemetry payload."""
    try:
        payload = json.loads(msg.payload.decode())
        # Store under the lock so readers never see a partial update
        with self.mqtt_data_lock:
            self.latest_drone_data = payload
        logger.debug(f"收到MQTT消息: {payload}")
    except Exception as e:
        logger.error(f"解析MQTT消息失败: {str(e)}")
def start_mqtt_client(self):
    """Connect to the MQTT broker and start the paho background loop.

    Returns True when the client started; False when MQTT is disabled
    or the connection attempt raised.
    """
    if not self.mqtt_enabled:
        logger.info("MQTT功能未启用")
        return False
    try:
        logger.info("启动MQTT客户端...")
        self.mqtt_client = mqtt.Client(client_id=self.mqtt_config.get('client_id', 'yolo_detection'))
        self.mqtt_client.on_connect = self.on_mqtt_connect
        self.mqtt_client.on_message = self.on_mqtt_message
        # Optional broker authentication
        if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
            self.mqtt_client.username_pw_set(
                self.mqtt_config['username'],
                self.mqtt_config['password']
            )
        self.mqtt_client.connect(
            self.mqtt_config['broker'],
            self.mqtt_config.get('port', 1883),
            self.mqtt_config.get('keepalive', 60)
        )
        # Background network loop: handles reconnects and callback dispatch
        self.mqtt_client.loop_start()
        logger.info("MQTT客户端已启动")
        return True
    except Exception as e:
        logger.error(f"启动MQTT客户端失败: {str(e)}")
        return False
def _websocket_worker(self):
    """Worker thread draining ``websocket_queue`` and emitting results.

    Keeps running while the worker is active or the queue still holds
    items; a ``None`` item is the shutdown sentinel. Each item is the
    per-model detection list produced by the inference loop.
    """
    logger.info("WebSocket发送工作线程启动")
    while self.websocket_active or not self.websocket_queue.empty():
        try:
            task = self.websocket_queue.get(timeout=1.0)
            if task is None:
                break
            try:
                now = datetime.datetime.now()
                time_str = now.strftime("%H:%M:%S")
                model_detections = task
                # Flatten per-model detections into the wire format
                all_detections_send = []
                for det in model_detections:
                    det_detection = {
                        'count': len(det['boxes']),
                        'detections': {
                            'class_id': det['class_ids'],
                            'class_name': det['class_names'],
                            'box': det['boxes'],
                            'conf': det['confidences'],
                        },
                    }
                    all_detections_send.append(det_detection)
                # Attach stream-stability telemetry
                stream_info = {
                    'stable': self.stream_stable,
                    'consecutive_failures': self.consecutive_read_failures,
                    'total_failures': self.total_read_failures,
                    'recovery_attempts': self.stream_recovery_attempts
                }
                self.config['socketIO'].emit('detection_results', {
                    'task_id': getattr(self, 'task_id', 'unknown'),
                    'detections': all_detections_send,
                    'timestamp': time.time_ns() // 1000000,  # epoch milliseconds
                    'fps': round(self.last_fps, 1),
                    'frame_count': self.frame_count,
                    'taskname': self.taskname,
                    'time_str': time_str,
                    'models_count': len(self.models),
                    'stream_info': stream_info
                })
            except Exception as e:
                logger.error(f"WebSocket发送错误: {str(e)}")
            finally:
                # Mark the queue item as processed
                self.websocket_queue.task_done()
        except queue.Empty:
            continue
        except Exception as e:
            logger.error(f"WebSocket工作线程异常: {e}")
    logger.info("WebSocket发送工作线程已停止")
def stop_mqtt_client(self):
    """Stop the paho network loop and drop the client reference."""
    if not self.mqtt_client:
        return
    try:
        self.mqtt_client.loop_stop()
        self.mqtt_client.disconnect()
        logger.info("MQTT客户端已停止")
    except Exception as e:
        logger.error(f"停止MQTT客户端失败: {str(e)}")
    finally:
        # Clear state even when disconnect raised
        self.mqtt_client = None
        self.mqtt_connected = False
def handle_reconnect(self):
    """Reconnect the RTMP source with capped backoff.

    Honors stop signals at every wait point; sets ``running = False``
    after ``max_reconnect_attempts`` so the main loop exits.
    """
    if self.stop_event.is_set() or not self.running:
        logger.info("收到停止信号,跳过重连")
        return
    self.reconnect_attempts += 1
    if self.reconnect_attempts >= self.max_reconnect_attempts:
        logger.error("达到最大重连次数")
        self.running = False
        return
    # Linear backoff (attempts * delay * 2), capped at 30 seconds
    delay = min(30, self.reconnect_attempts * self.reconnect_delay * 2)
    logger.warning(f"流中断,{delay}秒后重连 (第{self.reconnect_attempts}/{self.max_reconnect_attempts}次重连)")
    # Release the old capture before reconnecting
    if self.cap:
        try:
            self.cap.release()
            self.cap = None
        except:
            pass
    # Sleep in small slices so a stop signal can interrupt the wait
    start_time = time.time()
    while time.time() - start_time < delay:
        if self.stop_event.is_set() or not self.running:
            logger.info("收到停止信号,取消重连")
            return
        time.sleep(0.5)  # coarse polling keeps CPU usage low
    # Reset stability counters for the fresh connection
    self.consecutive_read_failures = 0
    self.stream_stable = True
    self.stream_recovery_attempts += 1
    # Reconnect
    try:
        if self.stop_event.is_set() or not self.running:
            logger.info("收到停止信号,跳过重新连接")
            return
        logger.info("尝试重新连接RTMP...")
        if not self.initialize_rtmp():
            raise IOError("RTMP重连失败")
        logger.info("RTMP重连成功")
        self.reconnect_attempts = 0  # reset attempt counter on success
        self.stream_recovery_attempts = 0  # reset recovery counter on success
    except Exception as e:
        logger.error(f"重连异常: {str(e)}")
def send_to_websocket(self, all_detections):
"""发送检测结果到WebSocket队列"""
try:
# 将检测结果放入队列,由单独线程处理发送
self.websocket_queue.put(all_detections, block=False)
except queue.Full:
logger.warning("WebSocket队列已满跳过发送")
except Exception as e:
logger.error(f"WebSocket队列操作错误: {str(e)}")
def send_error_to_websocket(self, error_type, error_message):
    """Emit a 'task_error' event with task context over Socket.IO, if available."""
    try:
        payload = {
            'task_id': getattr(self, 'task_id', 'unknown'),
            'error_type': error_type,
            'error_message': error_message,
            'timestamp': time.time_ns() // 1000000,
            'time_str': datetime.datetime.now().strftime("%H:%M:%S"),
            'status': self._current_status if hasattr(self, '_current_status') else 'unknown',
        }
        # config may be a dict or an object carrying the socketIO handle
        if hasattr(self.config, 'socketIO') or 'socketIO' in self.config:
            sio = self.config.get('socketIO') or getattr(self.config, 'socketIO')
            if sio:
                sio.emit('task_error', payload)
                logger.info(f"已发送错误消息到WebSocket: {error_type} - {error_message}")
    except Exception as e:
        logger.error(f"发送错误消息到WebSocket失败: {str(e)}")
def send_log_to_websocket(self, log_level, log_message):
"""发送日志消息到WebSocket"""
try:
now = datetime.datetime.now()
time_str = now.strftime("%H:%M:%S")
log_data = {
'task_id': getattr(self, 'task_id', 'unknown'),
'log_level': log_level,
'log_message': log_message,
'timestamp': time.time_ns() // 1000000,
'time_str': time_str,
'status': self._current_status if hasattr(self, '_current_status') else 'unknown'
}
if hasattr(self.config, 'socketIO') or 'socketIO' in self.config:
socketIO = self.config.get('socketIO') or getattr(self.config, 'socketIO')
if socketIO:
socketIO.emit('task_log', log_data)
except Exception as e:
logger.error(f"发送日志消息到WebSocket失败: {str(e)}")
def check_stream_stability(self):
"""检查流稳定性"""
current_time = time.time()
if current_time - self.last_stream_check >= self.stream_check_interval:
if self.consecutive_read_failures > 0:
logger.warning(f"流稳定性警告: 连续读取失败 {self.consecutive_read_failures}")
self.stream_stable = False
else:
self.stream_stable = True
self.last_stream_check = current_time
return True
return False
def handle_frame_read_failure(self):
    """Escalating response to a failed frame read.

    Sleeps progressively longer as consecutive failures accumulate.
    Returns True when the caller should reconnect the stream, False
    when it should simply retry after the wait performed here.
    """
    self.consecutive_read_failures += 1
    self.total_read_failures += 1
    # Refresh the stability flag if a periodic check is due
    self.check_stream_stability()
    # Escalate based on the consecutive-failure count
    if self.consecutive_read_failures == 1:
        # First failure: short wait, then retry
        logger.debug("帧读取失败等待0.5秒后重试")
        time.sleep(self.retry_delay)
        return False  # no reconnect yet
    elif self.consecutive_read_failures <= 3:
        # Failures 2-3: medium wait, then retry
        logger.warning(f"连续 {self.consecutive_read_failures} 次帧读取失败等待1秒后重试")
        time.sleep(1)
        return False  # no reconnect yet
    elif self.consecutive_read_failures <= self.max_consecutive_failures:
        # Failures 4 up to max_consecutive_failures (20, despite the
        # original "4-5" comment): long wait, still no reconnect
        logger.error(f"连续 {self.consecutive_read_failures} 次帧读取失败等待2秒后重试")
        time.sleep(2)
        return False  # no reconnect yet
    else:
        # Beyond the limit: report, degrade the task, request a reconnect
        error_message = f"连续 {self.consecutive_read_failures} 次帧读取失败,触发重连"
        logger.error(error_message)
        # Report the failure to the frontend
        self.send_error_to_websocket('stream_error', error_message)
        # Mark the task as degraded until the stream recovers
        self.update_task_status('degraded')
        return True  # reconnect needed
    def run(self):
        """Main loop of the detection thread.

        Startup sequence: load the configured (possibly encrypted) models,
        open the RTMP video stream, initialise the per-task output streamer,
        start the upload and WebSocket worker threads, warm up the models,
        then mark the task as running.  The main loop reads frames, performs
        multi-model inference, pushes annotated frames, and periodically
        reports detections and performance statistics, with a retry/reconnect
        strategy for read failures.  Cleanup always runs in the finally block.
        """
        try:
            logger.info(f"启动检测线程任务ID: {getattr(self, 'task_id', 'unknown')}")
            # Windows-specific check: surface any configuration problem early.
            if self._is_windows() and hasattr(self, 'windows_config'):
                if self.windows_config and 'error' in self.windows_config:
                    logger.warning(f"Windows配置问题: {self.windows_config['error']}")
            # 1. Load all configured models.
            log_message = "开始加载模型..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            if not self.load_models():
                error_message = "加密模型加载失败,线程退出"
                logger.error(error_message)
                self.send_error_to_websocket('model_error', error_message)
                self.update_task_status('failed')
                return
            key_summary = self.get_key_verification_summary()
            log_message = f"加密密钥验证结果: {key_summary['loaded_models']}/{key_summary['total_models']} 个模型加载成功"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 2. Open the RTMP video stream.
            log_message = "开始连接视频流..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            if not self.initialize_rtmp():
                error_message = "RTMP连接失败线程退出"
                logger.error(error_message)
                self.send_error_to_websocket('stream_error', error_message)
                self.update_task_status('failed')
                return
            log_message = f"视频流连接成功: {self.original_width}x{self.original_height} @ {self.fps}fps"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 3. Initialise the per-task output streamer (optional; on failure
            #    the thread keeps running with pushing disabled).
            if self.enable_push:
                log_message = f"开始初始化任务推流器: {self.task_id}"
                logger.info(log_message)
                self.send_log_to_websocket('info', log_message)
                if not self.initialize_task_streamer():
                    error_message = f"任务 {self.task_id} 推流器初始化失败,继续运行但不推流"
                    logger.warning(error_message)
                    self.send_error_to_websocket('push_error', error_message)
                    self.enable_push = False
                else:
                    log_message = f"任务 {self.task_id} 推流器初始化成功"
                    logger.info(log_message)
                    self.send_log_to_websocket('info', log_message)
            # 4. Start the stream manager's health monitor, if available.
            if self.enable_push and hasattr(self.stream_manager, 'start_health_monitor'):
                try:
                    self.stream_manager.start_health_monitor()
                    log_message = "推流健康监控已启动"
                    logger.info(log_message)
                    self.send_log_to_websocket('info', log_message)
                except Exception as e:
                    error_message = f"启动健康监控失败: {e}"
                    logger.warning(error_message)
                    self.send_error_to_websocket('push_error', error_message)
            # 5. Start MQTT if enabled (currently disabled).
            # if self.mqtt_enabled:
            #     log_message = "开始启动MQTT客户端..."
            #     logger.info(log_message)
            #     self.send_log_to_websocket('info', log_message)
            #
            #     if self.start_mqtt_client():
            #         log_message = "MQTT客户端已启动"
            #         logger.info(log_message)
            #         self.send_log_to_websocket('info', log_message)
            #     else:
            #         error_message = "启动MQTT客户端失败"
            #         logger.error(error_message)
            #         self.send_error_to_websocket('mqtt_error', error_message)
            # 6. Start the image upload worker thread.
            log_message = "开始启动图片上传线程..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            self.upload_active = True
            self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
            self.upload_thread.start()
            log_message = "图片上传线程已启动"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 7. Start the WebSocket sender worker thread.
            log_message = "开始启动WebSocket发送线程..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            self.websocket_active = True
            self.websocket_thread = threading.Thread(target=self._websocket_worker, daemon=True)
            self.websocket_thread.start()
            log_message = "WebSocket发送线程已启动"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 8. Warm up all models so the first real frame is not slow.
            log_message = "开始预热所有模型..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            self.warmup_models()
            log_message = "所有模型预热完成"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 9. Mark initialisation complete and publish the running status.
            self.initialized = True
            self.update_task_status('running')
            self.original_status = 'running'  # remember the baseline status as running
            log_message = "资源初始化完成,任务开始运行"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)
            # 10. Main processing loop.
            while self.running and not self.stop_event.is_set():
                # Check stop flags raised by stop().
                if self._force_stop or self._should_stop.is_set():
                    logger.info("收到停止信号,退出循环")
                    break
                if self.stop_event.is_set():
                    logger.info("收到停止事件信号,退出循环")
                    break
                start_time = time.perf_counter()
                try:
                    # Read the next frame from the capture.
                    ret, frame = self.cap.read()
                    if not ret:
                        if self.stop_event.is_set() or not self.running or self._force_stop:
                            logger.info("收到停止信号,不再尝试重连")
                            break
                        # Failed read: apply the retry/backoff strategy.
                        need_reconnect = self.handle_frame_read_failure()
                        if need_reconnect:
                            # Failure budget exhausted — reconnect the stream.
                            self.handle_reconnect()
                        continue
                    # Read succeeded: reset the consecutive-failure counter.
                    if self.consecutive_read_failures > 0:
                        logger.info(f"恢复帧读取成功,之前连续失败 {self.consecutive_read_failures} 次")
                        self.consecutive_read_failures = 0
                        self.stream_stable = True
                    # Normalise the frame size if the source resolution drifted.
                    if frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
                        frame = cv2.resize(frame, (self.original_width, self.original_height))
                    # Update the exponentially-smoothed FPS estimate.
                    current_time = time.time()
                    time_diff = current_time - self.last_frame_time
                    if time_diff > 0:
                        self.fps = 0.9 * self.fps + 0.1 / time_diff
                        if current_time - self.last_status_update > 0.5:
                            self.last_fps = self.fps
                            self.last_status_update = current_time
                    self.last_frame_time = current_time
                    # Dynamic frame skipping (more conservative on Windows).
                    if self.should_skip_frame(start_time):
                        self.frame_count += 1
                        continue
                    # Multi-model inference on the current frame.
                    start = time.time()
                    annotated_frame, model_detections = self._multi_model_inference(frame)
                    # annotated_frame = frame
                    # model_detections = []
                    # logger.info(f'startTime:{start},endTime:{time.time()},时间差:{time.time() - start}')
                    # Push the annotated frame to the task streamer.
                    if self.enable_push:
                        if not self.push_frame_to_task_streamer(annotated_frame):
                            # Push failed; keep processing so detection is unaffected.
                            pass
                    if current_time - self.last_log_time >= 1:
                        # # WebSocket send
                        self.send_to_websocket(model_detections)
                        # # Upload handling
                        # self.handle_upload(annotated_frame, model_detections, current_time)
                        # Check push health once per second.
                        if self.enable_push:
                            if not self.check_push_health():
                                # Unhealthy push and not already error/degraded:
                                # downgrade (currently a no-op, kept commented).
                                current_status = getattr(self, '_current_status', 'running')
                                if current_status != 'error' and current_status != 'degraded':
                                    # self.update_task_status('degraded')
                                    pass
                        self.last_log_time = current_time
                    self.frame_count += 1
                    self.reconnect_attempts = 0
                    # Performance accounting for this iteration.
                    elapsed = time.perf_counter() - start_time
                    self.processing_times.append(elapsed)
                    self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1
                    # Every 50 frames, log a performance report (incl. push stats).
                    if self.frame_count % 50 == 0:
                        avg_time = sum(self.processing_times) / len(
                            self.processing_times) if self.processing_times else 0
                        logger.info(
                            f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {(1.0 / avg_time if avg_time > 0 else 0):.2f}")
                        # Stream stability report.
                        if self.total_read_failures > 0:
                            logger.info(
                                f"流稳定性: 总失败次数 {self.total_read_failures}, 当前连续失败 {self.consecutive_read_failures} 次")
                        # Push statistics.
                        if self.enable_push:
                            logger.info(f"推流统计: {self.stream_stats['total_frames_pushed']}帧, "
                                        f"成功率: {self.stream_stats['push_success_rate']:.2%}, "
                                        f"重启: {self.stream_stats.get('ffmpeg_restarts', 0)}次")
                        self.processing_times = []
                except Exception as e:
                    logger.error(f"处理帧时发生异常: {str(e)}")
                    if self.stop_event.is_set() or not self.running or self._force_stop:
                        logger.info("收到停止信号,退出异常处理")
                        break
                    continue
            logger.info("检测线程主循环结束")
        except Exception as e:
            error_message = f"检测线程异常: {str(e)}"
            logger.error(error_message)
            logger.error(traceback.format_exc())
            # Surface the failure to WebSocket listeners.
            self.send_error_to_websocket('thread_error', error_message)
            # Mark the task as errored.
            self.update_task_status('error')
        finally:
            self.cleanup()
            logger.info("检测线程已安全停止")
            self.update_task_status('stopped')
            # NOTE(review): `return` inside `finally` silently swallows any
            # in-flight BaseException (e.g. KeyboardInterrupt/SystemExit) —
            # confirm this is intentional.
            return False
def stop(self):
"""停止检测线程"""
logger.info(f"收到停止请求任务ID: {getattr(self, 'task_id', 'unknown')}")
# 首先设置强制停止标志
self._force_stop = True
self._should_stop.set() # 设置停止事件
# 发送停止事件
self.stop_event.set()
# 设置运行标志为False
self.running = False
# 强制释放视频流以解除阻塞
if self.cap:
try:
# 强制中断视频读取
self.cap.release()
logger.info("已强制释放视频流")
except Exception as e:
logger.error(f"释放视频流失败: {str(e)}")
# 清理任务推流器
self.cleanup_task_streamer()
# 停止MQTT客户端
if self.mqtt_enabled:
self.stop_mqtt_client()
# 停止上传线程
self.upload_active = False
if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
# 发送停止信号到上传队列
try:
self.upload_queue.put(None, timeout=0.5)
except:
pass
logger.info(f"停止信号已发送任务ID: {getattr(self, 'task_id', 'unknown')}")
    def cleanup(self):
        """Best-effort release of every resource held by the thread.

        Re-entrancy is guarded by the `_cleaning_up` flag.  Each teardown
        step (MQTT, upload worker, WebSocket worker, task streamer, video
        capture, models, GPU cache, queues) is wrapped in its own try/except
        so one failure never prevents the remaining steps from running.
        """
        logger.info(f"开始清理资源任务ID: {getattr(self, 'task_id', 'unknown')}")
        # Guard against re-entrant cleanup (e.g. stop() racing with run()'s finally).
        if hasattr(self, '_cleaning_up') and self._cleaning_up:
            logger.warning("资源已经在清理中,跳过重复清理")
            return
        self._cleaning_up = True
        try:
            # Stop the MQTT client.
            if self.mqtt_enabled:
                self.stop_mqtt_client()
            # Stop the upload worker: send a None sentinel, draining the
            # queue first if it is full, then join with a 3s timeout.
            self.upload_active = False
            if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
                logger.info("停止上传线程...")
                try:
                    if hasattr(self, 'upload_queue'):
                        try:
                            self.upload_queue.put(None, timeout=0.5)
                        except queue.Full:
                            # Queue full: drain it so the sentinel fits.
                            try:
                                while not self.upload_queue.empty():
                                    self.upload_queue.get_nowait()
                            except:
                                pass
                            self.upload_queue.put(None, block=False)
                        except:
                            pass
                    self.upload_thread.join(3.0)
                    if self.upload_thread.is_alive():
                        logger.warning("上传线程未在3秒内停止")
                except Exception as e:
                    logger.error(f"停止上传线程异常: {str(e)}")
            # Stop the WebSocket sender worker (same sentinel + join pattern).
            self.websocket_active = False
            if hasattr(self, 'websocket_thread') and self.websocket_thread and self.websocket_thread.is_alive():
                logger.info("停止WebSocket发送线程...")
                try:
                    if hasattr(self, 'websocket_queue'):
                        try:
                            self.websocket_queue.put(None, timeout=0.5)
                        except queue.Full:
                            try:
                                while not self.websocket_queue.empty():
                                    self.websocket_queue.get_nowait()
                            except:
                                pass
                            self.websocket_queue.put(None, block=False)
                        except:
                            pass
                    self.websocket_thread.join(3.0)
                    if self.websocket_thread.is_alive():
                        logger.warning("WebSocket发送线程未在3秒内停止")
                except Exception as e:
                    logger.error(f"停止WebSocket发送线程异常: {str(e)}")
                finally:
                    self.websocket_thread = None
            # Tear down the per-task output streamer.
            self.cleanup_task_streamer()
            # Release the video capture.
            if hasattr(self, 'cap') and self.cap:
                logger.info("释放视频流...")
                try:
                    self.cap.release()
                except Exception as e:
                    logger.error(f"释放视频流异常: {str(e)}")
                finally:
                    self.cap = None
            # Release every loaded model, dropping predictor/model attributes
            # first so their tensors can be garbage-collected.
            logger.info(f"释放所有模型,共 {len(self.models)} 个")
            for i, model_info in enumerate(self.models):
                try:
                    model = model_info['model']
                    model_name = model_info['name']
                    # Drop cached predictor state held by the model.
                    if hasattr(model, 'predictor'):
                        try:
                            del model.predictor
                        except:
                            pass
                    if hasattr(model, 'model'):
                        try:
                            del model.model
                        except:
                            pass
                    # Drop the model reference itself.
                    del model
                    logger.info(f"模型 {model_name} 已释放")
                except Exception as e:
                    logger.error(f"释放模型 {i} 异常: {str(e)}")
            self.models = []
            # Free cached GPU memory and run a full GC pass.
            logger.info("清理GPU缓存...")
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                logger.info("GPU缓存已清理")
            except Exception as e:
                logger.error(f"清理GPU缓存异常: {str(e)}")
            # Drain remaining queue items and reset bookkeeping.
            logger.info("清理其他资源...")
            try:
                if hasattr(self, 'upload_queue'):
                    try:
                        while not self.upload_queue.empty():
                            try:
                                self.upload_queue.get_nowait()
                            except:
                                break
                    except:
                        pass
                self.processing_times = []
                logger.info("其他资源已清理")
            except Exception as e:
                logger.error(f"清理其他资源异常: {str(e)}")
            logger.info("资源清理完成")
        except Exception as e:
            logger.error(f"资源清理过程中发生异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            self._cleaning_up = False  # allow a future cleanup pass
def update_task_status(self, status):
"""更新任务状态"""
if hasattr(self, 'task_id') and self.task_id:
try:
from task_manager import task_manager
if hasattr(task_manager, 'update_task_status'):
task_manager.update_task_status(self.task_id, status)
else:
# 备用方法:直接更新全局数据
tasks_dict = gd.get_or_create_dict('tasks')
if self.task_id in tasks_dict:
tasks_dict[self.task_id]['status'] = status
logger.debug(f"更新任务 {self.task_id} 状态为: {status}")
# 更新本地状态跟踪
self._current_status = status
except Exception as e:
logger.warning(f"更新任务状态失败: {str(e)}")