import datetime
import gc
import json
import math
import os
import queue
import threading
import time
import traceback
from pathlib import Path

import cv2
import paho.mqtt.client as mqtt
import requests
import torch
from ultralytics import YOLO

from _minio import MinioUploader
from ffmpegStreamer import FFmpegStreamer
from log import logger
import global_data as gd
from mapping_cn import class_mapping_cn as cmc


class DetectionThread(threading.Thread):
    """Run YOLO detection on an RTMP stream in a background (daemon) thread.

    For every processed frame the thread:
      * renders the detections onto the frame (``process_results``),
      * optionally pushes the annotated frame to an ``FFmpegStreamer``,
      * emits a summary via ``config['socketIO']`` at most once per second,
      * queues frames containing detections (at most one per
        ``upload_interval`` seconds) for a separate upload worker, which
        stores them in MinIO and POSTs the result to ``task.res_api``.

    When MQTT is enabled, the latest message received on the configured topic
    is attached to each result-API payload as ``drone_info``.
    """

    # Model used when the configured model can neither be found nor downloaded.
    DEFAULT_MODEL_PATH = r"models/yolov8n.pt"
    # Number of processed frames aggregated into one performance report.
    PERF_REPORT_WINDOW = 50

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.daemon = True
        self.running = True
        self.stop_event = threading.Event()

        # Handles created in initialize_resources().
        self.model = None
        self.cap = None
        self.streamer = None

        # Counters and timing state of the main loop.
        self.frame_count = 0
        self.frame_skip_counter = 0
        self.frame_skip = 0
        self.reconnect_attempts = 0
        self.last_frame_time = time.time()
        self.last_log_time = time.time()
        self.last_status_update = 0
        self.fps = 0  # EMA of the measured loop rate
        self.last_fps = 0  # fps snapshot sent over the websocket
        self.stream_fps = 30  # FPS reported by the source; used for the push streamer
        self.detections_count = 0
        self.target_latency = 0.05
        self.max_processing_time = 0.033
        self.last_processing_time = 0
        self.prev_frame = None
        self.prev_results = None
        self.processing_times = []
        self.avg_process_time = 0.033
        self.original_width = 0
        self.original_height = 0

        # RTMP source settings.
        self.rtmp_url = config['rtmp']['url']
        self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
        self.reconnect_delay = config['rtmp']['reconnect_delay']
        self.buffer_size = config['rtmp']['buffer_size']
        self.timeout_ms = config['rtmp']['timeout_ms']

        # Task metadata.
        self.taskname = config['task']['taskname']
        self.taskid = config['task']['taskid']
        self.tag = config['task']['tag']
        self.aiid = config['task']['aiid']
        self.res_api = config['task']['res_api']
        # Timeout (seconds) for HTTP calls; without one a stalled server
        # blocks the calling thread forever.
        self.http_timeout = config['task'].get('http_timeout', 10)

        self.imgsz = config['predict']['imgsz']
        self.minio_uploader = MinioUploader(config['minio'])

        # Upload queue and worker thread.
        self.upload_queue = queue.Queue(maxsize=50)
        self.upload_thread = None
        self.upload_active = False
        self.upload_interval = 2
        self.last_upload_time = 0

        # MQTT configuration.
        self.mqtt_config = config.get('mqtt', {})
        self.mqtt_enabled = self.mqtt_config.get('enable', False)
        self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
        self.mqtt_client = None
        self.mqtt_connected = False
        self.latest_drone_data = None
        self.mqtt_data_lock = threading.Lock()

    # ------------------------------------------------------------------ MQTT

    def on_mqtt_connect(self, client, userdata, flags, rc, *args):
        """paho on_connect callback: subscribe to the topic when rc == 0."""
        if rc == 0:
            # Subscribing here also re-subscribes after paho's automatic reconnects.
            client.subscribe(self.mqtt_topic)
            self.mqtt_connected = True
        else:
            logger.error(f"MQTT连接失败,错误码: {rc}")
            self.mqtt_connected = False

    def on_mqtt_disconnect(self, client, userdata, rc, *args):
        """paho on_disconnect callback: mark the connection as down."""
        self.mqtt_connected = False
        if rc != 0:
            logger.warning(f"MQTT连接断开,错误码: {rc}")

    def on_mqtt_message(self, client, userdata, msg):
        """paho on_message callback: store the JSON payload as the latest drone data."""
        try:
            drone_data = json.loads(msg.payload.decode())
            with self.mqtt_data_lock:
                self.latest_drone_data = drone_data
            logger.debug(f"收到MQTT消息: {drone_data}")
        except Exception as e:
            logger.error(f"解析MQTT消息失败: {str(e)}")

    def _create_mqtt_client(self):
        """Construct a paho client that works on both paho-mqtt 1.x and 2.x."""
        client_id = self.mqtt_config.get('client_id', 'yolo_detection')
        callback_api = getattr(mqtt, 'CallbackAPIVersion', None)
        if callback_api is not None:
            # paho-mqtt >= 2.0 rejects Client() without a callback API version;
            # VERSION1 keeps the callback signatures used by this class.
            return mqtt.Client(callback_api.VERSION1, client_id=client_id)
        return mqtt.Client(client_id=client_id)

    def start_mqtt_client(self):
        """Connect to the broker and start paho's network loop thread.

        Returns:
            True when the client was started, False when MQTT is disabled or
            the connection attempt raised.
        """
        if not self.mqtt_enabled:
            logger.info("MQTT功能未启用")
            return False
        try:
            logger.info("启动MQTT客户端...")
            self.mqtt_client = self._create_mqtt_client()
            self.mqtt_client.on_connect = self.on_mqtt_connect
            self.mqtt_client.on_disconnect = self.on_mqtt_disconnect
            self.mqtt_client.on_message = self.on_mqtt_message
            if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
                self.mqtt_client.username_pw_set(
                    self.mqtt_config['username'],
                    self.mqtt_config['password'],
                )
            self.mqtt_client.connect(
                self.mqtt_config['broker'],
                self.mqtt_config.get('port', 1883),
                self.mqtt_config.get('keepalive', 60),
            )
            self.mqtt_client.loop_start()
            logger.info("MQTT客户端已启动")
            return True
        except Exception as e:
            logger.error(f"启动MQTT客户端失败: {str(e)}")
            self.mqtt_client = None
            return False

    def stop_mqtt_client(self):
        """Stop the network loop and disconnect; always drops the client handle."""
        if self.mqtt_client:
            try:
                self.mqtt_client.loop_stop()
                self.mqtt_client.disconnect()
                logger.info("MQTT客户端已停止")
            except Exception as e:
                logger.error(f"停止MQTT客户端失败: {str(e)}")
            finally:
                self.mqtt_client = None
                self.mqtt_connected = False

    # ------------------------------------------------------------- main loop

    def run(self):
        """Thread entry point: initialise resources, then detect until stopped."""
        try:
            logger.info("启动优化检测线程")
            # initialize_resources() also warms the model up, so no second
            # warm-up is needed here.
            self.initialize_resources()
            logger.info("资源初始化完成")

            if self.mqtt_enabled:
                self.start_mqtt_client()

            self.upload_active = True
            self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
            self.upload_thread.start()
            logger.info("图片上传线程已启动")

            while self.running and not self.stop_event.is_set():
                self._process_next_frame()

            logger.info("检测线程主循环结束")
        except Exception as e:
            logger.error(f"检测线程异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            # A failure during cleanup must not prevent clearing the global flag.
            try:
                self.cleanup()
            except Exception as e:
                logger.error(f"资源清理异常: {e}")
            logger.info("检测线程已安全停止")
            gd.set_value('detection_active', False)

    def _process_next_frame(self):
        """Run one iteration of the main loop: read, maybe infer, dispatch."""
        start_time = time.perf_counter()

        if self.cap is None:
            self.handle_reconnect()
            return

        ret, frame = self.cap.read()
        logger.debug(f'读取帧结果: {ret}')
        if not ret:
            self.handle_reconnect()
            return

        frame = self._normalize_frame_size(frame)

        current_time = time.time()
        self._update_fps(current_time)

        # Lowest-latency policy: if reading alone already blew the latency
        # budget the frame is stale, so drop it rather than add lag.
        if time.perf_counter() - start_time > self.target_latency:
            self.frame_skip_counter += 1
            self.frame_count += 1
            return

        # Under low throughput only every third frame is processed.
        if self.fps < 15 and self.frame_skip < 2:
            self.frame_skip += 1
            self.frame_count += 1
            return
        self.frame_skip = 0

        results = self._infer(frame)
        self.prev_results = results
        self.prev_frame = frame.copy()  # copy so later in-place drawing cannot alias it

        annotated_frame, detection_data = self.process_results(frame, results)
        self.detections_count = len(detection_data)

        # Queue an upload only when something was detected and the interval elapsed.
        if detection_data and current_time - self.last_upload_time >= self.upload_interval:
            self._enqueue_upload(annotated_frame, detection_data, current_time)

        if self.config['push']['enable_push'] and self.streamer:
            try:
                self.streamer.add_frame(annotated_frame)
            except queue.Full:
                logger.warning("推流队列已满,丢弃帧")

        if current_time - self.last_log_time >= 1:
            self.send_to_websocket(detection_data)
            self.last_log_time = current_time

        self.frame_count += 1
        self.reconnect_attempts = 0

        self._record_processing_time(time.perf_counter() - start_time)

    def _normalize_frame_size(self, frame):
        """Resize *frame* to the stream's reported resolution when it differs."""
        if self.original_width <= 0 or self.original_height <= 0:
            # The stream did not report a size; resizing to 0x0 would raise.
            return frame
        if frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
            logger.warning(
                f"帧分辨率不匹配: 预期 {self.original_width}x{self.original_height}, 实际 {frame.shape[1]}x{frame.shape[0]}")
            frame = cv2.resize(frame, (self.original_width, self.original_height))
        return frame

    def _update_fps(self, current_time):
        """Update the fps EMA and refresh the reported fps snapshot every 0.5 s."""
        time_diff = current_time - self.last_frame_time
        if time_diff > 0:
            self.fps = 0.9 * self.fps + 0.1 / time_diff
        if current_time - self.last_status_update > 0.5:
            self.last_fps = self.fps
            self.last_status_update = current_time
        self.last_frame_time = current_time

    def _infer(self, frame):
        """Run the model on one frame using the configured predict settings."""
        predict_cfg = self.config['predict']
        return self.model(
            frame,
            stream=False,
            verbose=False,
            conf=predict_cfg['conf_thres'],
            iou=predict_cfg['iou_thres'],
            imgsz=self.imgsz,
            device=predict_cfg['device'],
            half=predict_cfg.get('half', False),
        )

    def _enqueue_upload(self, annotated_frame, detection_data, current_time):
        """Queue an annotated frame for the upload worker without blocking.

        A blocking ``put`` on the bounded queue would stall detection whenever
        MinIO or the result API is slow, so a full queue drops this upload.
        """
        try:
            timestamp = int(current_time * 1000)
            filename = f"DJI_{timestamp}.jpg"
            self.upload_queue.put_nowait({
                "image": annotated_frame.copy(),
                "filename": filename,
                "detection_data": detection_data,
                "timestamp": current_time,
            })
        except queue.Full:
            logger.warning("上传队列已满,丢弃本次上传")
        except Exception as e:
            logger.error(f"添加上传任务失败: {e}")
            return
        # Advance on a full queue too, so the warning is rate-limited by upload_interval.
        self.last_upload_time = current_time

    def _record_processing_time(self, elapsed):
        """Track per-frame latency and periodically log and adapt ``frame_skip``."""
        self.processing_times.append(elapsed)
        self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1

        # Report once per PERF_REPORT_WINDOW processed frames. Keying on
        # frame_count would silently miss reports whenever that frame was dropped.
        if len(self.processing_times) < self.PERF_REPORT_WINDOW:
            return
        avg_time = sum(self.processing_times) / len(self.processing_times)
        self.processing_times = []
        fps_estimate = 1.0 / avg_time if avg_time > 0 else float('inf')
        logger.info(f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {fps_estimate:.1f}")

        # Adjust the configured frame_skip; a missing key is treated as 0
        # instead of raising KeyError and killing the thread.
        predict_cfg = self.config['predict']
        current_skip = predict_cfg.get('frame_skip', 0)
        if avg_time > 0.15:
            predict_cfg['frame_skip'] = min(4, current_skip + 1)
            logger.info(f"增加跳帧至 {predict_cfg['frame_skip']}")
        elif avg_time < 0.05 and current_skip > 0:
            predict_cfg['frame_skip'] = max(0, current_skip - 1)
            logger.info(f"减少跳帧至 {predict_cfg['frame_skip']}")

    # ---------------------------------------------------------------- upload

    def _upload_worker(self):
        """Upload-thread loop: drain the queue until stopped and empty, or a None sentinel."""
        logger.info("上传工作线程启动")
        output_dir = "output_frames"
        os.makedirs(output_dir, exist_ok=True)

        while self.upload_active or not self.upload_queue.empty():
            try:
                task = self.upload_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            if task is None:
                break
            try:
                self._upload_one(task, output_dir)
            except Exception as e:
                logger.error(f"上传任务处理异常: {e}")

        logger.info("上传工作线程已停止")

    def _upload_one(self, task, output_dir):
        """Write one queued frame to disk, upload it to MinIO, then notify the result API."""
        start_time = time.time()
        filename = task["filename"]
        detection_data = task["detection_data"]
        filepath = os.path.join(output_dir, filename)

        if not cv2.imwrite(filepath, task["image"]):
            logger.error(f"写入临时图片失败: {filepath}")
            return

        try:
            foldername = self.config['task']['taskname']
            object_path = f"{foldername}/{filename}"
            self.minio_uploader.upload_file(filepath, object_path)

            payload = {
                "taskid": self.taskid,
                "path": object_path,
                "tag": detection_data,
                "aiid": self.aiid,
            }
            # Attach the latest MQTT (drone) data when available.
            if self.mqtt_enabled and self.mqtt_connected:
                with self.mqtt_data_lock:
                    drone_data = self.latest_drone_data
                if drone_data:
                    payload["drone_info"] = drone_data
                    logger.info("添加无人机数据到API调用")

            headers = {"Content-Type": "application/json"}
            response = requests.post(
                self.res_api, json=payload, headers=headers, timeout=self.http_timeout
            )
            if response.status_code == 200:
                logger.info(f"已上传帧至 MinIO: {object_path} | 耗时: {time.time() - start_time:.2f}s")
            else:
                logger.warning(f"API调用失败: {response.status_code} - {response.text}")
        except Exception as e:
            logger.error(f"上传/API调用失败: {e}")
        finally:
            # The local file is only a staging copy for the upload.
            try:
                os.remove(filepath)
            except OSError:
                pass

    # --------------------------------------------------------------- results

    def process_results(self, frame, results):
        """Render detections and collect those above their class's reliability.

        Class info is looked up in ``self.tag`` first and then in the built-in
        ``class_mapping_cn``. Detections whose class is in neither mapping are
        skipped, where previously they crashed the thread with a TypeError.

        Returns:
            (annotated_frame, detection_data). ``annotated_frame`` has the same
            shape as ``frame``. ``detection_data`` is a list of dicts with keys
            class_id, class_name, confidence, box and reliability.
        """
        result = results[0]
        predict_cfg = self.config['predict']
        annotated_frame = result.plot(
            img=frame,
            line_width=predict_cfg['line_width'],
            font=predict_cfg['font'],
            conf=True,
            labels=True,
            probs=True,
        )
        if annotated_frame.shape != frame.shape:
            logger.warning(f"渲染帧分辨率不匹配: 输入 {frame.shape}, 输出 {annotated_frame.shape}")
            annotated_frame = cv2.resize(annotated_frame, (frame.shape[1], frame.shape[0]))

        tag_mapping = self.tag or {}
        detection_data = []
        for box in result.boxes:
            class_id = int(box.cls)
            confidence = float(box.conf)
            class_info = tag_mapping.get(str(class_id), None)
            if class_info is None:
                class_info = cmc.get(str(class_id), None)
            if class_info is None:
                logger.debug(f"未知类别ID, 跳过: {class_id}")
                continue
            if confidence < class_info['reliability']:
                continue
            detection_data.append({
                'class_id': class_id,
                'class_name': class_info['name'],
                'confidence': confidence,
                'box': box.xyxy[0].tolist(),
                'reliability': class_info['reliability'],
            })
        return annotated_frame, detection_data

    # ------------------------------------------------------- initialisation

    def initialize_resources(self):
        """Load (downloading if needed) and warm up the model, open RTMP, start the streamer."""
        model_path = self.config['model']['path']
        _model_path = self._ensure_model_file(model_path)

        predict_cfg = self.config['predict']
        device = predict_cfg['device']
        logger.info(f"加载模型...{_model_path}")
        self.model = YOLO(_model_path).to(device)
        if predict_cfg.get('half', False) and 'cuda' in device:
            self.model = self.model.half()
            logger.info("启用半精度推理")

        logger.info("模型预热...")
        self._warmup_model()
        logger.info(f"模型加载成功 | 设备: {device.upper()}")

        # Open the RTMP input.
        self.original_width, self.original_height, fps = self.initialize_rtmp()
        self.fps = fps  # seeds the measured-fps EMA
        self.stream_fps = fps  # kept separately for (re)creating the push streamer

        self.start_streamer()

    def _warmup_model(self, iterations=5):
        """Run a few dummy predictions so the first real frame is not slow."""
        predict_cfg = self.config['predict']
        device = predict_cfg['device']
        dummy_input = torch.zeros(1, 3, self.imgsz, self.imgsz).to(device)
        if predict_cfg.get('half', False) and 'cuda' in device:
            dummy_input = dummy_input.half()
        for _ in range(iterations):
            self.model.predict(dummy_input)

    def _ensure_model_file(self, model_path):
        """Return a loadable model path, downloading ``models/<model_path>`` if missing.

        The download goes to a ``.part`` file that is only renamed into place
        once complete, so an interrupted download never leaves a corrupt model
        behind. Falls back to ``model_path`` itself if it exists, otherwise to
        DEFAULT_MODEL_PATH.
        """
        local_path = r"models/" + model_path
        target_path = os.path.join('models', model_path)
        if Path(target_path).exists():
            return local_path

        logger.info(r"模型不存在,开始下载模型")
        download_url = self.config['task']['api'] + model_path
        tmp_path = target_path + '.part'
        try:
            # dirname is at least "models"; this also creates nested sub-directories.
            os.makedirs(os.path.dirname(target_path) or '.', exist_ok=True)
            with requests.get(download_url, stream=True, timeout=self.http_timeout) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024 * 1024
                downloaded = 0
                progress = 0
                with open(tmp_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        downloaded += len(data)
                        f.write(data)
                        new_progress = int(downloaded * 100 / total_size) if total_size > 0 else 0
                        if new_progress != progress and (new_progress % 10 == 0 or downloaded == total_size):
                            logger.info(f"下载进度: {new_progress}% ({downloaded}/{total_size} 字节)")
                            progress = new_progress
            if total_size > 0 and downloaded != total_size:
                raise IOError(f"incomplete download: {downloaded}/{total_size} bytes")
            os.replace(tmp_path, target_path)
            logger.info(f"模型下载成功: {target_path} ({downloaded} 字节)")
            return local_path
        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass

        if Path(target_path).exists():
            return local_path
        if Path(model_path).exists():
            return model_path
        logger.error("没有可用的模型文件,使用默认模型")
        return self.DEFAULT_MODEL_PATH

    def _capture_open_params(self):
        """Build the VideoCapture open() params (HW decode, open timeout).

        These must be passed to open(): VideoCapture.set() on an unopened
        capture is ignored.
        """
        params = []
        if self.config['rtmp'].get('gpu_decode', False):
            if hasattr(cv2, 'CAP_PROP_HW_ACCELERATION') and hasattr(cv2, 'VIDEO_ACCELERATION_ANY'):
                params += [cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY]
                logger.info("启用硬件加速解码")
            else:
                logger.warning("硬件解码不可用,使用软件解码")
        if self.timeout_ms and hasattr(cv2, 'CAP_PROP_OPEN_TIMEOUT_MSEC'):
            params += [cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, int(self.timeout_ms)]
        return params

    def initialize_rtmp(self):
        """Open the RTMP stream via FFmpeg.

        Returns:
            (width, height, fps) of the stream. fps falls back to 30 when the
            backend reports 0, NaN or an implausible value.

        Raises:
            IOError: the stream could not be opened.
        """
        logger.info(f"连接RTMP: {self.rtmp_url}")
        self.cap = cv2.VideoCapture()

        params = self._capture_open_params()
        opened = False
        if params:
            try:
                opened = self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG, params)
            except (TypeError, cv2.error) as e:
                # OpenCV builds without the (filename, api, params) overload.
                logger.warning(f"带参数打开视频流失败,回退到默认方式: {e}")
                opened = self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG)
        else:
            opened = self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG)
        if not opened:
            logger.error(f"连接RTMP失败: {self.rtmp_url}")
            raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")

        # Only takes effect once the capture is open (and only on backends supporting it).
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)

        frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS)
        if not fps or not math.isfinite(fps) or fps <= 0 or fps > 240:
            fps = 30
        logger.info(f"视频属性: {frame_width}x{frame_height} @ {fps}fps")
        return frame_width, frame_height, fps

    def handle_reconnect(self):
        """Release the stream, back off, and try to reopen it.

        Allows up to ``max_reconnect_attempts`` consecutive attempts (the
        counter is reset after every processed frame) and then sets
        ``running = False``. The back-off wait is interrupted by stop().
        """
        self.reconnect_attempts += 1
        if self.reconnect_attempts > self.max_reconnect_attempts:
            logger.error("达到最大重连次数")
            self.running = False
            return

        delay = min(10, self.reconnect_attempts * self.reconnect_delay)
        logger.warning(f"流中断,{delay}秒后重连 ({self.reconnect_attempts}/{self.max_reconnect_attempts})")

        # Stop pushing and release the broken capture before reconnecting.
        self.stop_streamer()
        if self.cap:
            self.cap.release()
            self.cap = None

        if self.stop_event.wait(delay):
            return

        try:
            width, height, fps = self.initialize_rtmp()
            self.original_width = width
            self.original_height = height
            self.stream_fps = fps
            self.prev_frame = None
            self.prev_results = None
            self.start_streamer()
            logger.info("重连成功")
        except Exception as e:
            logger.error(f"重连异常: {str(e)}")

    # -------------------------------------------------------------- outputs

    def send_to_websocket(self, detection_data):
        """Emit a compact 'detection_results' event via config['socketIO']."""
        try:
            time_str = datetime.datetime.now().strftime("%H:%M:%S")
            simplified_data = [{
                'class_id': d['class_id'],
                'class_name': d['class_name'],
                'confidence': round(d['confidence'], 2),
                'box': [round(float(c), 1) for c in d['box']],
            } for d in detection_data]
            self.config['socketIO'].emit('detection_results', {
                'detections': simplified_data,
                'timestamp': time.time_ns() // 1000000,
                'fps': round(self.last_fps, 1),
                'frame_count': self.frame_count,
                'taskid': self.taskid,
                'time_str': time_str,
            })
        except Exception as e:
            logger.error(f"WebSocket发送错误: {str(e)}")

    def start_streamer(self):
        """Start a new FFmpegStreamer when pushing is enabled.

        Uses the source stream's FPS rather than the measured loop-rate EMA.
        """
        if self.config['push']['enable_push']:
            try:
                self.streamer = FFmpegStreamer(self.config, self.stream_fps, self.original_width, self.original_height)
                self.streamer.start()
                logger.info("推流线程已启动")
            except Exception as e:
                logger.error(f"启动推流线程失败: {str(e)}")

    def stop_streamer(self):
        """Stop the streamer, if any; the handle is dropped even if stop() raises."""
        if self.streamer:
            try:
                self.streamer.stop()
                logger.info("推流线程已停止")
            except Exception as e:
                logger.error(f"停止推流线程失败: {str(e)}")
            finally:
                self.streamer = None

    def stop(self):
        """Ask the detection thread to stop (also interrupts a reconnect back-off)."""
        self.stop_event.set()
        self.running = False

    def cleanup(self):
        """Release MQTT, upload worker, capture, streamer, model and GPU cache.

        Each step is guarded so that a failure in one does not skip the rest.
        """
        logger.info("清理资源...")

        if self.mqtt_enabled:
            self.stop_mqtt_client()

        # Stop the upload worker: clear the flag, send a sentinel, wait briefly.
        self.upload_active = False
        if self.upload_thread and self.upload_thread.is_alive():
            try:
                self.upload_queue.put(None, timeout=0.5)
            except queue.Full:
                pass
            self.upload_thread.join(2.0)
            if self.upload_thread.is_alive():
                logger.warning("上传线程未能正常停止")
        logger.info("上传资源清理完成")

        if self.cap:
            try:
                self.cap.release()
            except Exception as e:
                logger.warning(f"释放视频流失败: {e}")
            self.cap = None
            logger.info("视频流已释放")

        self.stop_streamer()

        if self.model is not None:
            try:
                if hasattr(self.model, 'predictor'):
                    del self.model.predictor
                if hasattr(self.model, 'model'):
                    del self.model.model
            except Exception as e:
                logger.warning(f"释放模型失败: {e}")
            self.model = None
            logger.info("模型已释放")

        # Collect first so tensors kept alive only by reference cycles are freed
        # before the CUDA caching allocator is asked to release memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("资源清理完成")