# detectionThread.py
import datetime
import gc
import hashlib
import json
import logging
import os
import queue
import tempfile
import threading
import time
import traceback

import cv2
import paho.mqtt.client as mqtt
import requests
import torch
from ultralytics import YOLO

from _minio import MinioUploader
from log import logger
from global_data import gd
from detection_render import multi_model_inference
from mandatory_model_crypto import MandatoryModelEncryptor
# detectionThread.py - revised ModelManager class
class ModelManager:
    """Manages local YOLO model loading for plain and encrypted model files.

    Loaded models are cached in-process (keyed by local path plus a
    fingerprint of the encryption key) so the same model is never
    decrypted or loaded twice.
    """

    def __init__(self, config):
        """Prepare model directories and the thread-safe model cache."""
        self.config = config
        # Plain models live in "models"; encrypted ones in a configurable dir.
        self.models_dir = "models"
        self.encrypted_models_dir = config.get('upload', {}).get('encrypted_models_dir', 'encrypted_models')

        # Make sure both directories exist.
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # Model cache (avoids repeated decryption); guarded by a lock
        # because load_model may be called from multiple threads.
        self.model_cache = {}
        self.cache_lock = threading.Lock()

    def _resolve_local_path(self, model_path, encrypted):
        """Map a configured model path to the local file that should be loaded."""
        if not encrypted:
            # Plain models load from the plain-model directory.
            return os.path.join(self.models_dir, os.path.basename(model_path))
        if model_path.startswith('encrypted_models/'):
            # Relative path inside the encrypted-model directory.
            return os.path.join(self.encrypted_models_dir, os.path.basename(model_path))
        if os.path.isabs(model_path):
            # Absolute path used as-is.
            return model_path
        # Otherwise look it up in the encrypted directory, forcing a
        # ".enc" suffix.
        model_filename = os.path.basename(model_path)
        if not model_filename.endswith('.enc'):
            model_filename += '.enc'
        return os.path.join(self.encrypted_models_dir, model_filename)

    def load_model(self, model_config, require_verification=False):
        """Load a single model from local disk (decrypting when needed).

        Returns (model, verification_result); model is None on failure and
        verification_result then carries {'success': False, 'error': ...}.
        """
        try:
            model_path = model_config['path']
            encrypted = model_config.get('encrypted', False)
            encryption_key = model_config.get('encryption_key')

            local_path = self._resolve_local_path(model_path, encrypted)

            # The model file must already exist locally.
            if not os.path.exists(local_path):
                logger.error(f"模型文件不存在: {local_path}")
                return None, {'success': False, 'error': f'模型文件不存在: {local_path}'}

            # Cache key mixes the path with a fingerprint of the key so the
            # same file decrypted with a different key is not reused.
            cache_key = f"{local_path}_{hashlib.md5(encryption_key.encode()).hexdigest()[:8]}" if encryption_key else local_path

            with self.cache_lock:
                if cache_key in self.model_cache:
                    logger.info(f"使用缓存的模型: {model_path}")
                    cached_info = self.model_cache[cache_key]
                    return cached_info['model'], cached_info.get('verification_result', {'success': True})

            verification_result = None
            model = None

            if encrypted and encryption_key:
                # Decrypt, verify, then load via a temporary plaintext file.
                try:
                    from mandatory_model_crypto import MandatoryModelValidator
                    validator = MandatoryModelValidator()

                    # Decrypt the model into memory.
                    decrypt_result = validator.decrypt_and_verify(local_path, encryption_key)

                    if not decrypt_result['success']:
                        logger.error(f"解密模型失败: {model_path} - {decrypt_result.get('error', '未知错误')}")
                        return None, decrypt_result

                    verification_result = {
                        'success': True,
                        'model_hash': decrypt_result.get('model_hash', ''),
                        'original_size': decrypt_result.get('original_size', 0)
                    }

                    # YOLO loads from a file path, so spill the decrypted
                    # bytes to a temp file.  The unlink runs in a finally
                    # block so the plaintext file is removed even when
                    # YOLO() raises (the original leaked it in that case).
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                        tmp.write(decrypt_result['decrypted_data'])
                        temp_path = tmp.name
                    try:
                        model = YOLO(temp_path)
                    finally:
                        try:
                            os.unlink(temp_path)
                        except Exception as e:
                            logger.warning(f"清理临时文件失败: {str(e)}")

                    logger.info(f"加密模型解密加载成功: {model_path}")

                except ImportError:
                    logger.error("mandatory_model_crypto模块未找到,无法处理加密模型")
                    return None, {'success': False, 'error': '加密模块未找到'}
                except Exception as e:
                    logger.error(f"加密模型处理失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}
            elif encrypted and not encryption_key:
                # Encrypted model but no key supplied.
                logger.error(f"加密模型但未提供密钥: {model_path}")
                return None, {'success': False, 'error': '加密模型需要密钥'}
            else:
                # Plain (unencrypted) model.
                try:
                    model = YOLO(local_path)
                    logger.info(f"普通模型加载成功: {local_path}")
                    verification_result = {'success': True}
                except Exception as e:
                    logger.error(f"加载普通模型失败: {str(e)}")
                    return None, {'success': False, 'error': str(e)}

            if model is None:
                return None, verification_result or {'success': False, 'error': '模型加载失败'}

            # Move the model to its configured device (default: first CUDA
            # device when available, else CPU).
            device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)

            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'verification_result': verification_result,
                    'device': device,
                    # Recorded so get_cache_info() can report total bytes
                    # (the original never stored it, so the total was 0).
                    'original_size': (verification_result or {}).get('original_size', 0),
                    'cached_at': time.time()
                }

            logger.info(f"模型加载成功: {model_path} -> {device}")
            return model, verification_result or {'success': True}

        except Exception as e:
            logger.error(f"加载模型失败 {model_config.get('path')}: {str(e)}")
            logger.error(traceback.format_exc())
            return None, {'success': False, 'error': str(e)}

    def clear_cache(self):
        """Drop every cached model."""
        with self.cache_lock:
            self.model_cache.clear()
            logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return a snapshot of the cache: entry count, keys and total bytes."""
        with self.cache_lock:
            return {
                'cache_size': len(self.model_cache),
                'cached_models': list(self.model_cache.keys()),
                'total_size': sum(info.get('original_size', 0) for info in self.model_cache.values())
            }
class DetectionThread(threading.Thread):
    """Multi-model detection thread (optimized version).

    Reads frames from an RTMP input, runs every configured YOLO model on
    them, optionally re-streams annotated frames, and publishes results
    via an upload queue, MQTT and WebSocket workers.
    """

    def __init__(self, config):
        """Initialize all task state from the `config` dict.

        Only in-memory setup happens here (plus Windows environment
        detection); streams, models and worker threads start later.
        """
        super().__init__()
        self.config = config
        self.task_id = None  # assigned externally; distinct from config taskid
        self.initialized = False
        self.running = True
        self._cleaning_up = False
        self._force_stop = False
        self._should_stop = threading.Event()  # cooperative stop signal

        # Multi-model support
        self.models = []  # loaded models plus their per-model settings
        self.model_manager = ModelManager(config)
        self.key_verification_results = {}  # per-model key verification outcome

        # RTMP input configuration
        self.cap = None
        self.rtmp_url = config['rtmp']['url']
        self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
        self.reconnect_delay = config['rtmp']['reconnect_delay']
        self.buffer_size = config['rtmp']['buffer_size']
        self.timeout_ms = config['rtmp']['timeout_ms']

        # Stream stability monitoring
        self.stream_stable = True
        self.consecutive_read_failures = 0
        self.max_consecutive_failures = 20  # max consecutive read failures
        self.retry_delay = 0.5  # delay (s) after a failed read
        self.stream_check_interval = 10  # stability check period (s)
        self.last_stream_check = time.time()
        self.total_read_failures = 0
        self.stream_recovery_attempts = 0

        # Task identity
        self.taskname = config['task']['taskname']
        self.taskid = config['task']['taskid']
        self.aiid = config['task']['aiid']

        # Output (push) stream management
        self.enable_push = config['push']['enable_push']
        self.push_config = config['push']
        self.task_streamer = None  # per-task output streamer
        self.streamer_initialized = False
        self.last_push_time = 0
        self.push_error_count = 0
        self.max_push_errors = 5

        # Performance monitoring
        self.frame_count = 0
        self.reconnect_attempts = 0
        self.last_frame_time = time.time()
        self.fps = 0
        self.stop_event = threading.Event()
        self.last_log_time = time.time()
        self.daemon = True
        self.target_latency = 0.05
        self.processing_times = []
        self.avg_process_time = 0.033
        self.last_status_update = 0
        self.last_fps = 0
        self.original_width = 0
        self.original_height = 0

        # Upload configuration
        # self.minio_uploader = MinioUploader(config['minio'])
        self.upload_queue = queue.Queue(maxsize=50)
        self.upload_thread = None
        self.upload_active = False
        self.upload_interval = 2  # min seconds between queued uploads
        self.last_upload_time = 0
        self.res_api = config['task']['res_api']

        # WebSocket send queue and worker management
        self.websocket_queue = queue.Queue(maxsize=100)
        self.websocket_thread = None
        self.websocket_active = False

        # MQTT configuration
        self.mqtt_config = config.get('mqtt', {})
        self.mqtt_enabled = self.mqtt_config.get('enable', False)
        self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
        self.mqtt_client = None
        self.mqtt_connected = False
        self.latest_drone_data = None  # last drone telemetry seen on MQTT
        self.mqtt_data_lock = threading.Lock()

        # Platform-specific push-stream manager
        if self._is_windows():
            from windows_utils import detect_and_configure_windows
            from task_stream_manager_windows import windows_task_stream_manager

            # Detect and configure the Windows environment
            self.windows_config = detect_and_configure_windows()
            self.stream_manager = windows_task_stream_manager
            logger.info(f"Windows系统检测完成: {self.windows_config.get('status', 'unknown')}")
        else:
            # Linux/Mac: use the default stream manager
            from task_stream_manager import task_stream_manager
            self.stream_manager = task_stream_manager
        # Push statistics
        self.stream_stats = {
            'total_frames_pushed': 0,
            'failed_pushes': 0,
            'last_push_time': 0,
            'push_success_rate': 1.0,
            'ffmpeg_restarts': 0
        }
        # Original task status, so recovery knows what state to restore
        self.original_status = 'running'
        # Current task status
        self._current_status = 'initializing'
        # Key verification records
        # NOTE(review): duplicates the assignment near the top of __init__.
        self.key_verification_results = {}
def check_push_health(self):
|
||
"""检查推流健康状态,如果长时间失败则更新任务状态"""
|
||
# 检查推流失败次数和时间
|
||
current_time = time.time()
|
||
|
||
# 如果推流失败次数过多,或长时间没有成功推流,则认为推流健康状态不佳
|
||
if (self.push_error_count > self.max_push_errors or
|
||
(current_time - self.last_push_time > 60 and self.last_push_time > 0)): # 60秒内无成功推流
|
||
return False
|
||
return True
|
||
|
||
# 绘制结果
|
||
logger.info(f"检测线程初始化完成: {self.taskname}")
|
||
|
||
def _is_windows(self):
|
||
"""检查是否是Windows系统"""
|
||
import os
|
||
return os.name == 'nt' or os.name == 'win32'
|
||
|
||
    def load_models(self):
        """Load every enabled model listed under config['models'] from local disk.

        Populates self.models (list of per-model info dicts) and
        self.key_verification_results (per-index verification outcome).
        Returns True when at least one model loaded successfully.
        """
        try:
            models_config = self.config.get('models', [])
            if not models_config or not isinstance(models_config, list):
                logger.error("未找到有效的models配置列表")
                return False

            logger.info(f"开始从本地加载 {len(models_config)} 个模型")

            loaded_models = []
            key_verification_results = {}

            for i, model_config in enumerate(models_config):
                # Skip models explicitly disabled in config (default: enabled).
                if not model_config.get('enabled', True):
                    logger.info(f"跳过未启用的模型 {i}")
                    continue

                model_path = model_config.get('path', 'unknown')
                model_name = os.path.basename(model_path).split('.')[0]

                # Load the model from local disk (decrypting if necessary).
                logger.info(f"加载模型 {i}: {model_name}")
                model, verification_result = self.model_manager.load_model(
                    model_config,
                    require_verification=True  # always verify the key
                )

                # Record the verification outcome, including failures.
                key_verification_results[i] = verification_result or {'success': False, 'error': '未知错误'}

                if model is None:
                    logger.error(f"加载模型 {i} 失败: {model_name}")
                    continue

                # Per-model labels and inference settings, with defaults.
                tags = model_config.get('tags', {})
                model_info = {
                    'model': model,
                    'config': model_config,
                    'tags': tags,
                    'name': model_name,
                    'id': i,
                    'device': model_config.get('device', 'cpu'),
                    'imgsz': model_config.get('imgsz', 640),
                    'conf_thres': model_config.get('conf_thres', 0.25),
                    'iou_thres': model_config.get('iou_thres', 0.45),
                    'half': model_config.get('half', False),
                    'key_valid': verification_result.get('success', False) if verification_result else False,
                    'model_hash': verification_result.get('model_hash', 'unknown') if verification_result else 'unknown'
                }

                loaded_models.append(model_info)
                logger.info(
                    f"模型加载成功: {model_name}, 设备: {model_info['device']}, 密钥验证: {model_info['key_valid']}")

            # Every model failed: log a per-model error report and bail out.
            if len(loaded_models) == 0:
                logger.error("所有模型加载失败")

                error_report = []
                for idx, result in key_verification_results.items():
                    if not result.get('success', False):
                        error_report.append(f"模型 {idx}: {result.get('error', '未知错误')}")

                if error_report:
                    logger.error("模型加载失败详情:")
                    for error in error_report:
                        logger.error(f"  {error}")

                return False

            self.models = loaded_models
            self.key_verification_results = key_verification_results

            logger.info(f"成功加载 {len(self.models)}/{len(models_config)} 个模型")

            # Key-verification statistics for the log.
            valid_keys = sum(1 for result in key_verification_results.values()
                             if result.get('success', False))
            logger.info(f"密钥验证统计: {valid_keys}个有效, {len(key_verification_results) - valid_keys}个无效")

            return True

        except Exception as e:
            logger.error(f"加载模型异常: {str(e)}")
            logger.error(traceback.format_exc())
            return False
def get_key_verification_summary(self):
|
||
"""获取密钥验证摘要"""
|
||
summary = {
|
||
'total_models': len(self.key_verification_results),
|
||
'valid_keys': 0,
|
||
'invalid_keys': 0,
|
||
'loaded_models': len(self.models),
|
||
'details': {}
|
||
}
|
||
|
||
for idx, result in self.key_verification_results.items():
|
||
if result.get('success', False):
|
||
summary['valid_keys'] += 1
|
||
else:
|
||
summary['invalid_keys'] += 1
|
||
|
||
summary['details'][f'model_{idx}'] = {
|
||
'success': result.get('success', False),
|
||
'error': result.get('error', ''),
|
||
'model_hash': result.get('model_hash', '')
|
||
}
|
||
|
||
return summary
|
||
|
||
    def initialize_rtmp(self):
        """Open the RTMP input stream and read its basic properties.

        Sets self.cap, the original frame size and fps, and resets the
        stream stability counters.  Raises on connection failure so the
        caller can drive reconnect logic.
        """
        try:
            logger.info(f"连接RTMP: {self.rtmp_url}")
            self.cap = cv2.VideoCapture()

            # Low-latency capture settings.
            self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)
            self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
            self.cap.set(cv2.CAP_PROP_FPS, 30)

            # Optional hardware-accelerated decode.
            if self.config['rtmp'].get('gpu_decode', False):
                try:
                    self.cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                    logger.info("启用硬件加速解码")
                except:
                    logger.warning("硬件解码不可用,使用软件解码")

            # Connect via the FFmpeg backend.
            if not self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG):
                logger.error(f"连接RTMP失败: {self.rtmp_url}")
                raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")

            # Query stream properties (fps falls back to 30 when unknown).
            self.original_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.original_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
            self.fps = fps

            # Fresh connection: reset stability tracking.
            self.consecutive_read_failures = 0
            self.total_read_failures = 0
            self.stream_stable = True

            logger.info(f"视频属性: {self.original_width}x{self.original_height} @ {fps}fps")
            return True

        except Exception as e:
            logger.error(f"初始化RTMP失败: {str(e)}")
            raise
    def initialize_task_streamer(self):
        """Create this task's output (push) streamer via the stream manager.

        Returns True when pushing is disabled or the streamer started; on
        Windows a conservative fallback configuration is attempted before
        giving up.
        """
        if not self.enable_push:
            logger.info(f"任务 {self.task_id} 推流功能未启用")
            return True

        try:
            logger.info(f"初始化任务推流器: {self.task_id}")

            # Windows pre-flight: warn early when the RTMP server is
            # unreachable (push would otherwise fail later, less visibly).
            if self._is_windows():
                from windows_utils import WindowsSystemUtils
                push_url = self.push_config.get('url', '')
                if push_url:
                    accessibility = WindowsSystemUtils.check_rtmp_server_accessibility(push_url)
                    if not accessibility.get('accessible', False):
                        logger.error(f"RTMP服务器不可达: {accessibility.get('error', 'Unknown error')}")
                        logger.warning("推流可能失败,请检查网络和服务器状态")

            # Delegate streamer creation to the platform stream manager.
            streamer = self.stream_manager.create_streamer_for_task(
                self.task_id,
                self.config,  # pass the full task config
                self.fps,
                self.original_width,
                self.original_height
            )

            if streamer:
                logger.info(f"任务 {self.task_id} 推流器初始化成功")
                self.streamer_initialized = True
                self.task_streamer = streamer

                # Remember when the streamer came up (for diagnostics).
                self.stream_stats['streamer_start_time'] = time.time()

                return True
            else:
                logger.error(f"任务 {self.task_id} 推流器初始化失败")

                # On Windows, retry once with the fallback configuration.
                if self._is_windows():
                    logger.info("尝试Windows备用推流配置...")
                    return self._initialize_windows_fallback_streamer()

                return False

        except Exception as e:
            logger.error(f"初始化任务推流器异常 {self.task_id}: {str(e)}", exc_info=True)

            # Windows gets one more chance via the fallback path.
            if self._is_windows():
                return self._initialize_windows_fallback_streamer()

            return False
    def _initialize_windows_fallback_streamer(self):
        """Retry streamer creation with a conservative software-only config.

        Used on Windows when the primary configuration fails: forces
        libx264 / ultrafast / zerolatency with GPU acceleration disabled.
        Returns True when the fallback streamer starts.
        """
        try:
            logger.info(f"使用Windows备用推流配置: {self.task_id}")

            # Shallow-copy the config and replace the push section wholesale
            # with safe software-encoding settings.
            fallback_config = self.config.copy()
            fallback_config['push'] = {
                'enable_push': True,
                'url': self.push_config['url'],
                'video_codec': 'libx264',
                'preset': 'ultrafast',
                'tune': 'zerolatency',
                'format': 'flv',
                'pixel_format': 'bgr24',
                'gpu_acceleration': False,
                'bitrate': '1000k',
                'bufsize': '2000k',
                'framerate': self.fps,
                'extra_args': [
                    '-max_delay', '0',
                    '-flags', '+global_header',
                    '-rtbufsize', '50M'
                ]
            }

            # Re-create the streamer with the simplified configuration.
            streamer = self.stream_manager.create_streamer_for_task(
                self.task_id,
                fallback_config,
                self.fps,
                self.original_width,
                self.original_height
            )

            if streamer:
                logger.info(f"Windows备用推流器初始化成功: {self.task_id}")
                self.streamer_initialized = True
                self.task_streamer = streamer
                return True
            else:
                logger.error(f"Windows备用推流器也失败: {self.task_id}")
                return False

        except Exception as e:
            logger.error(f"Windows备用推流器初始化异常: {str(e)}")
            return False
    def push_frame_to_task_streamer(self, frame):
        """Push one annotated frame to the task's output stream.

        Tracks push statistics; after max_push_errors consecutive failures
        the task is marked degraded and a streamer recovery is attempted.
        Returns True when the frame was pushed.
        """
        if not self.enable_push or not self.streamer_initialized:
            return False

        try:
            # The stream manager owns the FFmpeg process.
            success = self.stream_manager.push_frame(self.task_id, frame)

            self.stream_stats['total_frames_pushed'] += 1

            if success:
                self.last_push_time = time.time()
                self.stream_stats['last_push_time'] = time.time()
                self.push_error_count = 0

                # Lifetime success rate of this streamer.
                total = self.stream_stats['total_frames_pushed']
                failed = self.stream_stats['failed_pushes']
                self.stream_stats['push_success_rate'] = (total - failed) / max(total, 1)

                # Log progress every 100 frames.
                if total % 100 == 0:
                    logger.info(
                        f"任务 {self.task_id} 已成功推流 {total} 帧,成功率: {self.stream_stats['push_success_rate']:.2%}")
            else:
                self.push_error_count += 1
                self.stream_stats['failed_pushes'] += 1

                logger.warning(f"任务 {self.task_id} 推流失败 ({self.push_error_count}/{self.max_push_errors})")

                # Extra diagnostics on Windows after a few failures.
                if self._is_windows() and self.push_error_count >= 3:
                    self._diagnose_windows_streaming_issue()

                # Too many consecutive failures: degrade, notify, recover.
                if self.push_error_count >= self.max_push_errors:
                    error_message = f"任务 {self.task_id} 推流连续失败,尝试恢复"
                    logger.error(error_message)
                    # Notify listeners over WebSocket.
                    self.send_error_to_websocket('push_error', error_message)
                    # Mark the task as degraded while recovery runs.
                    self.update_task_status('degraded')
                    self.recover_task_streamer()

            return success

        except Exception as e:
            logger.error(f"推流异常 {self.task_id}: {str(e)}")
            self.push_error_count += 1
            self.stream_stats['failed_pushes'] += 1
            return False
    def _diagnose_windows_streaming_issue(self):
        """Log diagnostic information about a failing Windows push stream.

        Dumps streamer/FFmpeg state from the stream manager and warns when
        CPU or memory usage is above 90%.  Diagnostics never raise.
        """
        try:
            logger.info("诊断Windows推流问题...")

            # Streamer/FFmpeg state as seen by the stream manager.
            stream_info = self.stream_manager.get_task_stream_info(self.task_id)
            if stream_info:
                logger.info(f"推流状态: {stream_info.get('status', 'unknown')}")
                logger.info(f"FFmpeg进程: {'运行中' if stream_info.get('process_alive') else '已停止'}")
                logger.info(f"最近输出: {stream_info.get('last_ffmpeg_output', '无')}")

                # Show the tail of the FFmpeg output for context.
                output_lines = stream_info.get('output_lines', [])
                if output_lines:
                    logger.info("最近FFmpeg输出:")
                    for line in output_lines[-5:]:
                        logger.info(f"  {line}")

            # Resource pressure is a common cause of push failures.
            from windows_utils import WindowsSystemUtils
            resources = WindowsSystemUtils.get_system_resources()

            if resources['cpu_percent'] > 90:
                logger.warning(f"CPU使用率过高: {resources['cpu_percent']}%")
            if resources['memory_percent'] > 90:
                logger.warning(f"内存使用率过高: {resources['memory_percent']}%")

        except Exception as e:
            logger.error(f"诊断推流问题异常: {str(e)}")
    def recover_task_streamer(self):
        """Tear down and re-create the push streamer after repeated failures.

        After 3 failed restarts the task is marked as errored and the whole
        detection thread is asked to stop.  Returns True on success.
        """
        try:
            logger.info(f"任务 {self.task_id} 尝试恢复推流器")

            # Count restarts so we can give up after a limit.
            self.stream_stats['ffmpeg_restarts'] += 1

            # Full teardown, then a pause before re-initializing
            # (Windows FFmpeg shutdown needs the extra time).
            self.cleanup_task_streamer()
            time.sleep(2)

            success = self.initialize_task_streamer()
            if success:
                self.push_error_count = 0
                # Restore the task's original status when it was running.
                if hasattr(self, 'original_status') and self.original_status == 'running':
                    self.update_task_status('running')
                logger.info(f"任务 {self.task_id} 推流器恢复成功 (第{self.stream_stats['ffmpeg_restarts']}次重启)")
            else:
                logger.error(f"任务 {self.task_id} 推流器恢复失败")
                # Recovery failed and the restart budget is exhausted:
                # flag the task as errored and stop the whole thread.
                if self.stream_stats['ffmpeg_restarts'] >= 3:  # assumed max of 3 restarts
                    error_message = f"任务 {self.task_id} 推流恢复失败,已达到最大重启次数"
                    logger.error(error_message)
                    # Notify listeners over WebSocket.
                    self.send_error_to_websocket('push_error', error_message)
                    # Mark the task as errored.
                    self.update_task_status('error')
                    # Also stop the entire detection thread.
                    self._force_stop = True
                    self._should_stop.set()
                    self.stop_event.set()
                    self.running = False

            return success

        except Exception as e:
            logger.error(f"恢复推流器异常 {self.task_id}: {str(e)}")
            return False
    def cleanup_task_streamer(self):
        """Stop and release this task's push streamer, if any.

        Always clears streamer_initialized, even when stopping fails.
        """
        if self.task_streamer:
            try:
                self.task_streamer.stop()
                self.stream_manager.stop_task_streamer(self.task_id)
                self.task_streamer = None
                logger.info(f"任务 {self.task_id} 推流器已清理")
            except Exception as e:
                logger.error(f"停止推流器失败: {str(e)}")
        self.streamer_initialized = False
    def warmup_models(self):
        """Run a couple of dummy inferences on every loaded model.

        Pays the one-off model/GPU initialization cost up front so the
        first real frame is not slow.  Warmup failures are logged but
        never fatal.
        """
        logger.info("预热所有模型...")
        for model_info in self.models:
            model = model_info['model']
            model_config = model_info['config']

            try:
                # Dummy input matching the model's configured size/device.
                imgsz = model_config.get('imgsz', 640)
                dummy_input = torch.zeros(1, 3, imgsz, imgsz)
                device = model_config.get('device', 'cpu')
                dummy_input = dummy_input.to(device)

                if model_config.get('half', False) and 'cuda' in device:
                    dummy_input = dummy_input.half()

                # Two silent warm-up passes.
                with torch.no_grad():
                    for _ in range(2):
                        model.predict(dummy_input)

                logger.info(f"模型 {model_info['name']} 预热完成")
            except Exception as e:
                logger.warning(f"模型 {model_info['name']} 预热失败: {str(e)}")

        logger.info("所有模型预热完成")
    def _multi_model_inference(self, frame):
        """Run every loaded model on `frame` (each with its own labels and
        confidence settings).

        Thin wrapper around detection_render.multi_model_inference; returns
        the annotated frame and the detection results it produces.
        """
        frame_drawn, detections = multi_model_inference(self.models, frame)
        return frame_drawn, detections
def should_skip_frame(self, start_time):
|
||
"""判断是否应该跳过当前帧"""
|
||
processing_time = time.perf_counter() - start_time
|
||
|
||
# 基于处理时间判断
|
||
if processing_time > self.target_latency:
|
||
return True
|
||
|
||
# 基于FPS判断
|
||
min_fps = 15
|
||
if self.fps < min_fps:
|
||
return True
|
||
|
||
return False
|
||
|
||
def handle_upload(self, annotated_frame, all_detections, current_time):
|
||
"""处理上传逻辑"""
|
||
# 合并所有检测结果
|
||
all_detections_list = []
|
||
for model_detections in all_detections.values():
|
||
all_detections_list.extend(model_detections)
|
||
|
||
if len(all_detections_list) > 0 and (current_time - self.last_upload_time >= self.upload_interval):
|
||
try:
|
||
timestamp = int(current_time * 1000)
|
||
filename = f"DJI_{timestamp}.jpg"
|
||
self.upload_queue.put({
|
||
"image": annotated_frame.copy(),
|
||
"filename": filename,
|
||
"detection_data": all_detections_list,
|
||
"timestamp": current_time
|
||
}, block=False)
|
||
self.last_upload_time = current_time
|
||
except queue.Full:
|
||
logger.warning("上传队列已满,跳过上传")
|
||
except Exception as e:
|
||
logger.error(f"添加上传任务失败: {e}")
|
||
|
||
    def _upload_worker(self):
        """Background worker: save queued frames and publish their results.

        Drains self.upload_queue until upload_active is cleared AND the
        queue is empty (a None task also terminates the loop).  Each frame
        is written as a JPEG, the detection payload is published over MQTT,
        and the local file is removed afterwards.
        """
        logger.info("上传工作线程启动")
        output_dir = "output_frames"
        os.makedirs(output_dir, exist_ok=True)

        while self.upload_active or not self.upload_queue.empty():
            try:
                task = self.upload_queue.get(timeout=1.0)
                if task is None:  # sentinel: shut down
                    break

                start_time = time.time()
                image = task["image"]
                filename = task["filename"]
                detection_data = task["detection_data"]
                filepath = os.path.join(output_dir, filename)

                # JPEG quality 85: size/quality trade-off.
                cv2.imwrite(filepath, image, [cv2.IMWRITE_JPEG_QUALITY, 85])

                try:
                    foldername = self.config['task']['taskname']
                    object_path = f"{foldername}/(unknown)"
                    # self.minio_uploader.upload_file(filepath, object_path)

                    payload = {
                        "taskid": self.taskid,
                        "path": object_path,
                        "tag": detection_data,
                        "aiid": self.aiid,
                    }

                    # Attach the latest drone telemetry, if any.
                    if self.mqtt_enabled and self.mqtt_connected:
                        with self.mqtt_data_lock:
                            if self.latest_drone_data:
                                payload["drone_info"] = self.latest_drone_data

                    # Publish the result over MQTT.
                    if self.mqtt_client and self.mqtt_connected:
                        try:
                            # Per-task result topic.
                            mqtt_topic = f'ai/task/{self.taskid}/aiachievement'

                            # Publish the JSON payload.
                            result = self.mqtt_client.publish(mqtt_topic, json.dumps(payload, ensure_ascii=False))

                            if result.rc == mqtt.MQTT_ERR_SUCCESS:
                                logger.debug(f'已上传帧至 MinIO: {object_path} | MQTT消息已发送到主题: {mqtt_topic} | 耗时: {time.time() - start_time:.2f}s')
                            else:
                                logger.warning(f'MQTT消息发送失败: {mqtt.error_string(result.rc)}')
                        except Exception as e:
                            logger.error(f'MQTT消息发送异常: {str(e)}')
                    else:
                        logger.warning('MQTT客户端未连接,跳过消息发送')
                except requests.exceptions.Timeout:
                    # NOTE(review): no requests call remains in this block
                    # (the MinIO/API upload is commented out), so this
                    # handler currently cannot fire — confirm intent.
                    logger.warning(f"API调用超时: {self.res_api}")
                except Exception as e:
                    logger.error(f"上传/API调用失败: {e}")
                finally:
                    # Best-effort removal of the local JPEG.
                    try:
                        os.remove(filepath)
                    except:
                        pass

                # Mark the queue task done.
                self.upload_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"上传任务处理异常: {e}")

        logger.info("上传工作线程已停止")
def on_mqtt_connect(self, client, userdata, flags, rc):
|
||
"""MQTT连接回调"""
|
||
if rc == 0:
|
||
client.subscribe(self.mqtt_topic)
|
||
self.mqtt_connected = True
|
||
# logger.debug("MQTT连接状态正常")
|
||
else:
|
||
logger.error(f"MQTT连接失败,错误码: {rc}")
|
||
self.mqtt_connected = False
|
||
|
||
def on_mqtt_message(self, client, userdata, msg):
|
||
"""MQTT消息回调"""
|
||
try:
|
||
drone_data = json.loads(msg.payload.decode())
|
||
with self.mqtt_data_lock:
|
||
self.latest_drone_data = drone_data
|
||
logger.debug(f"收到MQTT消息: {drone_data}")
|
||
except Exception as e:
|
||
logger.error(f"解析MQTT消息失败: {str(e)}")
|
||
|
||
    def start_mqtt_client(self):
        """Create, connect and start the paho-mqtt client loop.

        Returns False when MQTT is disabled or connecting fails.  The
        topic subscription happens in on_mqtt_connect once the broker
        accepts the connection.
        """
        if not self.mqtt_enabled:
            logger.info("MQTT功能未启用")
            return False
        try:
            logger.info("启动MQTT客户端...")
            self.mqtt_client = mqtt.Client(client_id=self.mqtt_config.get('client_id', 'yolo_detection'))
            self.mqtt_client.on_connect = self.on_mqtt_connect
            self.mqtt_client.on_message = self.on_mqtt_message

            # Optional broker credentials.
            if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
                self.mqtt_client.username_pw_set(
                    self.mqtt_config['username'],
                    self.mqtt_config['password']
                )

            self.mqtt_client.connect(
                self.mqtt_config['broker'],
                self.mqtt_config.get('port', 1883),
                self.mqtt_config.get('keepalive', 60)
            )

            # Runs the network loop in a background thread.
            self.mqtt_client.loop_start()
            logger.info("MQTT客户端已启动")
            return True
        except Exception as e:
            logger.error(f"启动MQTT客户端失败: {str(e)}")
            return False
    def _websocket_worker(self):
        """Background worker: emit queued detection results over Socket.IO.

        Drains self.websocket_queue until websocket_active is cleared AND
        the queue is empty (a None task also terminates the loop).  Each
        payload carries detections, FPS, frame counters and stream
        stability info.
        """
        logger.info("WebSocket发送工作线程启动")

        while self.websocket_active or not self.websocket_queue.empty():
            try:
                task = self.websocket_queue.get(timeout=1.0)
                if task is None:  # sentinel: shut down
                    break

                try:
                    # Build the outgoing payload.
                    now = datetime.datetime.now()
                    time_str = now.strftime("%H:%M:%S")
                    model_detections = task

                    # Reshape the per-model detection structures for sending.
                    all_detections_send = []
                    for det in model_detections:
                        det_detection = {
                            'count': len(det['boxes']),
                            'detections': {
                                'class_id': det['class_ids'],
                                'class_name': det['class_names'],
                                'box': det['boxes'],
                                'conf': det['confidences'],
                            },
                        }
                        all_detections_send.append(det_detection)

                    # Stream health snapshot for the client UI.
                    stream_info = {
                        'stable': self.stream_stable,
                        'consecutive_failures': self.consecutive_read_failures,
                        'total_failures': self.total_read_failures,
                        'recovery_attempts': self.stream_recovery_attempts
                    }

                    self.config['socketIO'].emit('detection_results', {
                        'task_id': getattr(self, 'task_id', 'unknown'),
                        'detections': all_detections_send,
                        'timestamp': time.time_ns() // 1000000,  # ms since epoch
                        'fps': round(self.last_fps, 1),
                        'frame_count': self.frame_count,
                        'taskname': self.taskname,
                        'time_str': time_str,
                        'models_count': len(self.models),
                        'stream_info': stream_info
                    })
                except Exception as e:
                    logger.error(f"WebSocket发送错误: {str(e)}")
                finally:
                    # Mark the queue task done.
                    self.websocket_queue.task_done()
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"WebSocket工作线程异常: {e}")

        logger.info("WebSocket发送工作线程已停止")
    def stop_mqtt_client(self):
        """Stop the MQTT network loop and disconnect.

        Always resets mqtt_client/mqtt_connected, even when stopping fails.
        """
        if self.mqtt_client:
            try:
                self.mqtt_client.loop_stop()
                self.mqtt_client.disconnect()
                logger.info("MQTT客户端已停止")
            except Exception as e:
                logger.error(f"停止MQTT客户端失败: {str(e)}")
            finally:
                self.mqtt_client = None
                self.mqtt_connected = False
    def handle_reconnect(self):
        """Reconnect the RTMP input after a stream failure.

        Uses a growing delay capped at 30 s, checks the stop signal
        throughout so shutdown is never delayed, and stops the thread once
        max_reconnect_attempts is reached.
        """
        if self.stop_event.is_set() or not self.running:
            logger.info("收到停止信号,跳过重连")
            return

        self.reconnect_attempts += 1
        if self.reconnect_attempts >= self.max_reconnect_attempts:
            logger.error("达到最大重连次数")
            self.running = False
            return

        # Backoff grows with the attempt count, capped at 30 seconds.
        delay = min(30, self.reconnect_attempts * self.reconnect_delay * 2)  # max delay 30 s
        logger.warning(f"流中断,{delay}秒后重连 (第{self.reconnect_attempts}/{self.max_reconnect_attempts}次重连)")

        # Release the capture before reconnecting.
        if self.cap:
            try:
                self.cap.release()
                self.cap = None
            except:
                pass

        # Sleep in small slices so a stop request interrupts the wait.
        start_time = time.time()
        while time.time() - start_time < delay:
            if self.stop_event.is_set() or not self.running:
                logger.info("收到停止信号,取消重连")
                return
            time.sleep(0.5)  # longer slice keeps CPU usage low

        # Fresh connection attempt: reset stability tracking.
        self.consecutive_read_failures = 0
        self.stream_stable = True
        self.stream_recovery_attempts += 1

        # Reconnect.
        try:
            if self.stop_event.is_set() or not self.running:
                logger.info("收到停止信号,跳过重新连接")
                return

            logger.info("尝试重新连接RTMP...")
            # NOTE(review): initialize_rtmp raises on failure rather than
            # returning False, so the raise below mostly documents intent;
            # either way the except clause handles the failure.
            if not self.initialize_rtmp():
                raise IOError("RTMP重连失败")

            logger.info("RTMP重连成功")
            self.reconnect_attempts = 0  # reset the attempt counter
            self.stream_recovery_attempts = 0  # reset the recovery counter

        except Exception as e:
            logger.error(f"重连异常: {str(e)}")
def send_to_websocket(self, all_detections):
|
||
"""发送检测结果到WebSocket队列"""
|
||
try:
|
||
# 将检测结果放入队列,由单独线程处理发送
|
||
self.websocket_queue.put(all_detections, block=False)
|
||
except queue.Full:
|
||
logger.warning("WebSocket队列已满,跳过发送")
|
||
except Exception as e:
|
||
logger.error(f"WebSocket队列操作错误: {str(e)}")
|
||
|
||
def send_error_to_websocket(self, error_type, error_message):
|
||
"""发送错误消息到WebSocket"""
|
||
try:
|
||
now = datetime.datetime.now()
|
||
time_str = now.strftime("%H:%M:%S")
|
||
|
||
error_data = {
|
||
'task_id': getattr(self, 'task_id', 'unknown'),
|
||
'error_type': error_type,
|
||
'error_message': error_message,
|
||
'timestamp': time.time_ns() // 1000000,
|
||
'time_str': time_str,
|
||
'status': self._current_status if hasattr(self, '_current_status') else 'unknown'
|
||
}
|
||
|
||
if hasattr(self.config, 'socketIO') or 'socketIO' in self.config:
|
||
socketIO = self.config.get('socketIO') or getattr(self.config, 'socketIO')
|
||
if socketIO:
|
||
socketIO.emit('task_error', error_data)
|
||
logger.info(f"已发送错误消息到WebSocket: {error_type} - {error_message}")
|
||
except Exception as e:
|
||
logger.error(f"发送错误消息到WebSocket失败: {str(e)}")
|
||
|
||
def send_log_to_websocket(self, log_level, log_message):
|
||
"""发送日志消息到WebSocket"""
|
||
try:
|
||
now = datetime.datetime.now()
|
||
time_str = now.strftime("%H:%M:%S")
|
||
|
||
log_data = {
|
||
'task_id': getattr(self, 'task_id', 'unknown'),
|
||
'log_level': log_level,
|
||
'log_message': log_message,
|
||
'timestamp': time.time_ns() // 1000000,
|
||
'time_str': time_str,
|
||
'status': self._current_status if hasattr(self, '_current_status') else 'unknown'
|
||
}
|
||
|
||
if hasattr(self.config, 'socketIO') or 'socketIO' in self.config:
|
||
socketIO = self.config.get('socketIO') or getattr(self.config, 'socketIO')
|
||
if socketIO:
|
||
socketIO.emit('task_log', log_data)
|
||
except Exception as e:
|
||
logger.error(f"发送日志消息到WebSocket失败: {str(e)}")
|
||
|
||
def check_stream_stability(self):
|
||
"""检查流稳定性"""
|
||
current_time = time.time()
|
||
if current_time - self.last_stream_check >= self.stream_check_interval:
|
||
if self.consecutive_read_failures > 0:
|
||
logger.warning(f"流稳定性警告: 连续读取失败 {self.consecutive_read_failures} 次")
|
||
self.stream_stable = False
|
||
else:
|
||
self.stream_stable = True
|
||
|
||
self.last_stream_check = current_time
|
||
return True
|
||
return False
|
||
|
||
def handle_frame_read_failure(self):
|
||
"""处理帧读取失败"""
|
||
self.consecutive_read_failures += 1
|
||
self.total_read_failures += 1
|
||
|
||
# 检查流稳定性
|
||
self.check_stream_stability()
|
||
|
||
# 根据连续失败次数采取不同策略
|
||
if self.consecutive_read_failures == 1:
|
||
# 第一次失败,短暂等待后重试
|
||
logger.debug("帧读取失败,等待0.5秒后重试")
|
||
time.sleep(self.retry_delay)
|
||
return False # 不立即重连
|
||
|
||
elif self.consecutive_read_failures <= 3:
|
||
# 2-3次失败,中等等待后重试
|
||
logger.warning(f"连续 {self.consecutive_read_failures} 次帧读取失败,等待1秒后重试")
|
||
time.sleep(1)
|
||
return False # 不立即重连
|
||
|
||
elif self.consecutive_read_failures <= self.max_consecutive_failures:
|
||
# 4-5次失败,较长时间等待后重试
|
||
logger.error(f"连续 {self.consecutive_read_failures} 次帧读取失败,等待2秒后重试")
|
||
time.sleep(2)
|
||
return False # 不立即重连
|
||
|
||
else:
|
||
# 超过最大失败次数,需要重连
|
||
error_message = f"连续 {self.consecutive_read_failures} 次帧读取失败,触发重连"
|
||
logger.error(error_message)
|
||
# 发送错误消息到WebSocket
|
||
self.send_error_to_websocket('stream_error', error_message)
|
||
# 更新任务状态为degraded
|
||
self.update_task_status('degraded')
|
||
return True # 需要重连
|
||
|
||
    def run(self):
        """Main detection-thread loop.

        Startup sequence: load models -> open the RTMP source -> optional
        per-task streamer -> health monitor -> upload worker -> WebSocket
        worker -> model warm-up.  Then loop: read a frame, run multi-model
        inference, push/report results, and apply reconnect/backoff logic
        on read failures.  Cleanup and the final 'stopped' status always
        happen in ``finally``.
        """
        try:
            logger.info(f"启动检测线程,任务ID: {getattr(self, 'task_id', 'unknown')}")
            # Windows-specific configuration sanity check.
            if self._is_windows() and hasattr(self, 'windows_config'):
                if self.windows_config and 'error' in self.windows_config:
                    logger.warning(f"Windows配置问题: {self.windows_config['error']}")
            # 1. Load all configured models.
            log_message = "开始加载模型..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            if not self.load_models():
                error_message = "加密模型加载失败,线程退出"
                logger.error(error_message)
                self.send_error_to_websocket('model_error', error_message)
                self.update_task_status('failed')
                return

            key_summary = self.get_key_verification_summary()
            log_message = f"加密密钥验证结果: {key_summary['loaded_models']}/{key_summary['total_models']} 个模型加载成功"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 2. Open the RTMP video stream.
            log_message = "开始连接视频流..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            if not self.initialize_rtmp():
                error_message = "RTMP连接失败,线程退出"
                logger.error(error_message)
                self.send_error_to_websocket('stream_error', error_message)
                self.update_task_status('failed')
                return

            log_message = f"视频流连接成功: {self.original_width}x{self.original_height} @ {self.fps}fps"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 3. Initialise the per-task output streamer (optional).
            if self.enable_push:
                log_message = f"开始初始化任务推流器: {self.task_id}"
                logger.info(log_message)
                self.send_log_to_websocket('info', log_message)

                if not self.initialize_task_streamer():
                    # Push failure is non-fatal: keep detecting, stop pushing.
                    error_message = f"任务 {self.task_id} 推流器初始化失败,继续运行但不推流"
                    logger.warning(error_message)
                    self.send_error_to_websocket('push_error', error_message)
                    self.enable_push = False
                else:
                    log_message = f"任务 {self.task_id} 推流器初始化成功"
                    logger.info(log_message)
                    self.send_log_to_websocket('info', log_message)
            # 4. Start the stream manager's health monitor (best-effort).
            if self.enable_push and hasattr(self.stream_manager, 'start_health_monitor'):
                try:
                    self.stream_manager.start_health_monitor()
                    log_message = "推流健康监控已启动"
                    logger.info(log_message)
                    self.send_log_to_websocket('info', log_message)
                except Exception as e:
                    error_message = f"启动健康监控失败: {e}"
                    logger.warning(error_message)
                    self.send_error_to_websocket('push_error', error_message)

            # 5. Start MQTT (currently disabled).
            # if self.mqtt_enabled:
            #     log_message = "开始启动MQTT客户端..."
            #     logger.info(log_message)
            #     self.send_log_to_websocket('info', log_message)
            #
            #     if self.start_mqtt_client():
            #         log_message = "MQTT客户端已启动"
            #         logger.info(log_message)
            #         self.send_log_to_websocket('info', log_message)
            #     else:
            #         error_message = "启动MQTT客户端失败"
            #         logger.error(error_message)
            #         self.send_error_to_websocket('mqtt_error', error_message)

            # 6. Start the image-upload worker thread.
            log_message = "开始启动图片上传线程..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            self.upload_active = True
            self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
            self.upload_thread.start()

            log_message = "图片上传线程已启动"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 7. Start the WebSocket sender worker thread.
            log_message = "开始启动WebSocket发送线程..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            self.websocket_active = True
            self.websocket_thread = threading.Thread(target=self._websocket_worker, daemon=True)
            self.websocket_thread.start()

            log_message = "WebSocket发送线程已启动"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 8. Warm up all models (first inference is slow otherwise).
            log_message = "开始预热所有模型..."
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            self.warmup_models()

            log_message = "所有模型预热完成"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 9. Mark the task as running.
            self.initialized = True
            self.update_task_status('running')
            self.original_status = 'running'  # remember the baseline status

            log_message = "资源初始化完成,任务开始运行"
            logger.info(log_message)
            self.send_log_to_websocket('info', log_message)

            # 10. Main frame-processing loop.
            while self.running and not self.stop_event.is_set():
                # Check every stop flag before doing any work.
                if self._force_stop or self._should_stop.is_set():
                    logger.info("收到停止信号,退出循环")
                    break

                if self.stop_event.is_set():
                    logger.info("收到停止事件信号,退出循环")
                    break

                start_time = time.perf_counter()

                try:
                    # Read one frame from the capture.
                    ret, frame = self.cap.read()
                    if not ret:
                        if self.stop_event.is_set() or not self.running or self._force_stop:
                            logger.info("收到停止信号,不再尝试重连")
                            break

                        # Apply the retry/backoff policy for read failures.
                        need_reconnect = self.handle_frame_read_failure()

                        if need_reconnect:
                            # Escalated to a full reconnect.
                            self.handle_reconnect()
                        continue

                    # Frame read succeeded: reset the failure counters.
                    if self.consecutive_read_failures > 0:
                        logger.info(f"恢复帧读取成功,之前连续失败 {self.consecutive_read_failures} 次")
                        self.consecutive_read_failures = 0
                        self.stream_stable = True

                    # Normalise resolution so downstream stages stay stable.
                    if frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
                        frame = cv2.resize(frame, (self.original_width, self.original_height))

                    # Exponentially-smoothed FPS estimate.
                    current_time = time.time()
                    time_diff = current_time - self.last_frame_time
                    if time_diff > 0:
                        self.fps = 0.9 * self.fps + 0.1 / time_diff
                        if current_time - self.last_status_update > 0.5:
                            self.last_fps = self.fps
                            self.last_status_update = current_time
                    self.last_frame_time = current_time

                    # Dynamic frame skipping (more conservative on Windows).
                    if self.should_skip_frame(start_time):
                        self.frame_count += 1
                        continue
                    # Multi-model inference on the current frame.
                    start = time.time()
                    annotated_frame, model_detections = self._multi_model_inference(frame)
                    # annotated_frame = frame
                    # model_detections = []
                    # logger.info(f'startTime:{start},endTime:{time.time()},时间差:{time.time() - start}')
                    # Push the annotated frame (failures must not stop detection).
                    if self.enable_push:
                        if not self.push_frame_to_task_streamer(annotated_frame):
                            pass

                    # Once per second: report detections and check push health.
                    if current_time - self.last_log_time >= 1:
                        self.send_to_websocket(model_detections)
                        # self.handle_upload(annotated_frame, model_detections, current_time)

                        if self.enable_push:
                            if not self.check_push_health():
                                # Unhealthy push: degrade only if not already
                                # in an error/degraded state (currently no-op).
                                current_status = getattr(self, '_current_status', 'running')
                                if current_status != 'error' and current_status != 'degraded':
                                    # self.update_task_status('degraded')
                                    pass

                        self.last_log_time = current_time

                    self.frame_count += 1
                    self.reconnect_attempts = 0

                    # Per-frame performance bookkeeping.
                    elapsed = time.perf_counter() - start_time
                    self.processing_times.append(elapsed)
                    self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1

                    # Every 50 frames: performance + stability + push report.
                    if self.frame_count % 50 == 0:
                        avg_time = sum(self.processing_times) / len(
                            self.processing_times) if self.processing_times else 0
                        logger.info(
                            f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {(1.0 / avg_time if avg_time > 0 else 0):.2f}")

                        if self.total_read_failures > 0:
                            logger.info(
                                f"流稳定性: 总失败次数 {self.total_read_failures}, 当前连续失败 {self.consecutive_read_failures}")

                        if self.enable_push:
                            logger.info(f"推流统计: {self.stream_stats['total_frames_pushed']}帧, "
                                        f"成功率: {self.stream_stats['push_success_rate']:.2%}, "
                                        f"重启: {self.stream_stats.get('ffmpeg_restarts', 0)}次")
                        self.processing_times = []

                except Exception as e:
                    # Per-frame errors are logged and the loop continues,
                    # unless a stop was requested meanwhile.
                    logger.error(f"处理帧时发生异常: {str(e)}")
                    if self.stop_event.is_set() or not self.running or self._force_stop:
                        logger.info("收到停止信号,退出异常处理")
                        break
                    continue

            logger.info("检测线程主循环结束")

        except Exception as e:
            error_message = f"检测线程异常: {str(e)}"
            logger.error(error_message)
            logger.error(traceback.format_exc())
            # Report the fatal error over WebSocket and mark the task failed.
            self.send_error_to_websocket('thread_error', error_message)
            self.update_task_status('error')
        finally:
            self.cleanup()
            logger.info("检测线程已安全停止")
            self.update_task_status('stopped')
            # NOTE(review): 'return' inside 'finally' suppresses any
            # in-flight exception (here all are already caught above);
            # the False value itself is ignored by Thread.run callers.
            return False
    def stop(self):
        """Signal the detection thread to stop and unblock any waits.

        Raises every stop flag, force-releases the video capture (to break
        a blocking ``cap.read()``), tears down the per-task streamer and
        MQTT client, and nudges the upload worker with a ``None`` sentinel.
        Full resource teardown happens later in ``cleanup()``.
        """
        logger.info(f"收到停止请求,任务ID: {getattr(self, 'task_id', 'unknown')}")

        # Raise the hard-stop flags first so loops exit as soon as possible.
        self._force_stop = True
        self._should_stop.set()  # set the internal stop event

        # Signal the shared stop event.
        self.stop_event.set()

        # Clear the running flag.
        self.running = False

        # Force-release the capture to interrupt a blocking read.
        if self.cap:
            try:
                self.cap.release()
                logger.info("已强制释放视频流")
            except Exception as e:
                logger.error(f"释放视频流失败: {str(e)}")

        # Tear down the per-task streamer.
        self.cleanup_task_streamer()

        # Stop the MQTT client if it was enabled.
        if self.mqtt_enabled:
            self.stop_mqtt_client()

        # Wake the upload worker with a sentinel so it can exit its loop.
        self.upload_active = False
        if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
            try:
                self.upload_queue.put(None, timeout=0.5)
            except:
                pass

        logger.info(f"停止信号已发送,任务ID: {getattr(self, 'task_id', 'unknown')}")
    def cleanup(self):
        """Release every resource owned by the detection thread.

        Safe to call more than once: a ``_cleaning_up`` flag guards against
        re-entrant cleanup.  Order matters: workers are stopped (sentinel +
        join) before their queues are drained, models are dropped before
        the GPU cache is flushed.  Each step is individually wrapped so one
        failure cannot abort the rest of the teardown.
        """
        logger.info(f"开始清理资源,任务ID: {getattr(self, 'task_id', 'unknown')}")

        # Re-entrancy guard.
        if hasattr(self, '_cleaning_up') and self._cleaning_up:
            logger.warning("资源已经在清理中,跳过重复清理")
            return

        self._cleaning_up = True

        try:
            # Stop the MQTT client first (network-facing).
            if self.mqtt_enabled:
                self.stop_mqtt_client()

            # Stop the upload worker: signal via a None sentinel, then join.
            self.upload_active = False
            if hasattr(self, 'upload_thread') and self.upload_thread and self.upload_thread.is_alive():
                logger.info("停止上传线程...")
                try:
                    if hasattr(self, 'upload_queue'):
                        try:
                            self.upload_queue.put(None, timeout=0.5)
                        except queue.Full:
                            # Queue jammed: drain it, then retry the sentinel.
                            try:
                                while not self.upload_queue.empty():
                                    self.upload_queue.get_nowait()
                            except:
                                pass
                            self.upload_queue.put(None, block=False)
                        except:
                            pass

                    self.upload_thread.join(3.0)
                    if self.upload_thread.is_alive():
                        logger.warning("上传线程未在3秒内停止")
                except Exception as e:
                    logger.error(f"停止上传线程异常: {str(e)}")

            # Stop the WebSocket sender thread the same way.
            self.websocket_active = False
            if hasattr(self, 'websocket_thread') and self.websocket_thread and self.websocket_thread.is_alive():
                logger.info("停止WebSocket发送线程...")
                try:
                    if hasattr(self, 'websocket_queue'):
                        try:
                            self.websocket_queue.put(None, timeout=0.5)
                        except queue.Full:
                            # Queue jammed: drain it, then retry the sentinel.
                            try:
                                while not self.websocket_queue.empty():
                                    self.websocket_queue.get_nowait()
                            except:
                                pass
                            self.websocket_queue.put(None, block=False)
                        except:
                            pass

                    self.websocket_thread.join(3.0)
                    if self.websocket_thread.is_alive():
                        logger.warning("WebSocket发送线程未在3秒内停止")
                except Exception as e:
                    logger.error(f"停止WebSocket发送线程异常: {str(e)}")
                finally:
                    self.websocket_thread = None

            # Tear down the per-task streamer.
            self.cleanup_task_streamer()

            # Release the video capture.
            if hasattr(self, 'cap') and self.cap:
                logger.info("释放视频流...")
                try:
                    self.cap.release()
                except Exception as e:
                    logger.error(f"释放视频流异常: {str(e)}")
                finally:
                    self.cap = None

            # Drop every loaded model and its cached sub-objects.
            logger.info(f"释放所有模型,共 {len(self.models)} 个")
            for i, model_info in enumerate(self.models):
                try:
                    model = model_info['model']
                    model_name = model_info['name']

                    # Break internal references so GC can reclaim the weights.
                    if hasattr(model, 'predictor'):
                        try:
                            del model.predictor
                        except:
                            pass
                    if hasattr(model, 'model'):
                        try:
                            del model.model
                        except:
                            pass

                    # Drop the local reference.
                    del model
                    logger.info(f"模型 {model_name} 已释放")
                except Exception as e:
                    logger.error(f"释放模型 {i} 异常: {str(e)}")

            self.models = []

            # Flush the CUDA allocator cache and run a GC pass.
            logger.info("清理GPU缓存...")
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                logger.info("GPU缓存已清理")
            except Exception as e:
                logger.error(f"清理GPU缓存异常: {str(e)}")

            # Drain leftover queue items and reset stats buffers.
            logger.info("清理其他资源...")
            try:
                if hasattr(self, 'upload_queue'):
                    try:
                        while not self.upload_queue.empty():
                            try:
                                self.upload_queue.get_nowait()
                            except:
                                break
                    except:
                        pass

                self.processing_times = []

                logger.info("其他资源已清理")
            except Exception as e:
                logger.error(f"清理其他资源异常: {str(e)}")

            logger.info("资源清理完成")

        except Exception as e:
            logger.error(f"资源清理过程中发生异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            self._cleaning_up = False  # allow future cleanup attempts
def update_task_status(self, status):
|
||
"""更新任务状态"""
|
||
if hasattr(self, 'task_id') and self.task_id:
|
||
try:
|
||
from task_manager import task_manager
|
||
if hasattr(task_manager, 'update_task_status'):
|
||
task_manager.update_task_status(self.task_id, status)
|
||
else:
|
||
# 备用方法:直接更新全局数据
|
||
tasks_dict = gd.get_or_create_dict('tasks')
|
||
if self.task_id in tasks_dict:
|
||
tasks_dict[self.task_id]['status'] = status
|
||
logger.debug(f"更新任务 {self.task_id} 状态为: {status}")
|
||
|
||
# 更新本地状态跟踪
|
||
self._current_status = status
|
||
|
||
except Exception as e:
|
||
logger.warning(f"更新任务状态失败: {str(e)}")
|