Yolov/detectionThread.py

636 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import datetime
import gc
import json
import os
import queue
import threading
import time
import traceback
from pathlib import Path
import cv2
import paho.mqtt.client as mqtt
import requests
import torch
from ultralytics import YOLO
from _minio import MinioUploader
from ffmpegStreamer import FFmpegStreamer
from log import logger
import global_data as gd
from mapping_cn import class_mapping_cn as cmc
class DetectionThread(threading.Thread):
def __init__(self, config):
super().__init__()
self.config = config
self.running = True
self.model = None
self.cap = None
self.streamer = None
self.frame_count = 0
self.frame_skip_counter = 0
self.reconnect_attempts = 0
self.last_frame_time = time.time()
self.fps = 0
self.detections_count = 0
self.stop_event = threading.Event()
self.rtmp_url = config['rtmp']['url']
self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
self.reconnect_delay = config['rtmp']['reconnect_delay']
self.buffer_size = config['rtmp']['buffer_size']
self.timeout_ms = config['rtmp']['timeout_ms']
self.taskname = config['task']['taskname']
self.taskid = config['task']['taskid']
self.tag = config['task']['tag']
self.aiid = config['task']['aiid']
self.last_log_time = time.time()
self.daemon = True
self.frame_skip = 0
self.target_latency = 0.05
self.max_processing_time = 0.033
self.last_processing_time = 0
self.prev_frame = None
self.prev_results = None
self.processing_times = []
self.avg_process_time = 0.033
self.last_status_update = 0
self.last_fps = 0
self.original_width = 0
self.original_height = 0
self.imgsz = config['predict']['imgsz']
self.minio_uploader = MinioUploader(config['minio'])
# 上传队列和线程
self.upload_queue = queue.Queue(maxsize=50)
self.upload_thread = None
self.upload_active = False
self.upload_interval = 2
self.last_upload_time = 0
self.res_api = config['task']['res_api']
# MQTT配置
self.mqtt_config = config.get('mqtt', {})
self.mqtt_enabled = self.mqtt_config.get('enable', False)
self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
self.mqtt_client = None
self.mqtt_connected = False
self.latest_drone_data = None
self.mqtt_data_lock = threading.Lock()
def on_mqtt_connect(self, client, userdata, flags, rc):
if rc == 0:
client.subscribe(self.mqtt_topic)
self.mqtt_connected = True
# logger.info("MQTT连接成功")
# logger.info(f"已订阅主题: {self.mqtt_topic}")
else:
logger.error(f"MQTT连接失败错误码: {rc}")
self.mqtt_connected = False
def on_mqtt_message(self, client, userdata, msg):
try:
drone_data = json.loads(msg.payload.decode())
with self.mqtt_data_lock:
self.latest_drone_data = drone_data
logger.debug(f"收到MQTT消息: {drone_data}")
except Exception as e:
logger.error(f"解析MQTT消息失败: {str(e)}")
def start_mqtt_client(self):
if not self.mqtt_enabled:
logger.info("MQTT功能未启用")
return False
try:
logger.info("启动MQTT客户端...")
self.mqtt_client = mqtt.Client(client_id=self.mqtt_config.get('client_id', 'yolo_detection'))
self.mqtt_client.on_connect = self.on_mqtt_connect
self.mqtt_client.on_message = self.on_mqtt_message
if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
self.mqtt_client.username_pw_set(
self.mqtt_config['username'],
self.mqtt_config['password']
)
self.mqtt_client.connect(
self.mqtt_config['broker'],
self.mqtt_config.get('port', 1883),
self.mqtt_config.get('keepalive', 60)
)
self.mqtt_client.loop_start()
logger.info("MQTT客户端已启动")
return True
except Exception as e:
logger.error(f"启动MQTT客户端失败: {str(e)}")
return False
def stop_mqtt_client(self):
if self.mqtt_client:
try:
self.mqtt_client.loop_stop()
self.mqtt_client.disconnect()
logger.info("MQTT客户端已停止")
except Exception as e:
logger.error(f"停止MQTT客户端失败: {str(e)}")
finally:
self.mqtt_client = None
self.mqtt_connected = False
def run(self):
global detection_active
try:
logger.info("启动优化检测线程")
self.initialize_resources()
logger.info("资源初始化完成")
if self.mqtt_enabled:
self.start_mqtt_client()
self.upload_active = True
self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
self.upload_thread.start()
logger.info("图片上传线程已启动")
# 预热模型
logger.info("预热模型...")
dummy_input = torch.zeros(1, 3, self.imgsz, self.imgsz)
dummy_input = dummy_input.to(self.config['predict']['device'])
if self.config['predict'].get('half', False) and 'cuda' in self.config['predict']['device']:
dummy_input = dummy_input.half()
for _ in range(5):
self.model.predict(dummy_input)
logger.info("模型预热完成")
while self.running and not self.stop_event.is_set():
start_time = time.perf_counter()
# 优化帧读取
ret, frame = self.cap.read()
logger.debug(f'读取帧结果: {ret}')
# 处理重连
if not ret:
self.handle_reconnect()
continue
# 确保帧分辨率一致
if frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
logger.warning(
f"帧分辨率不匹配: 预期 {self.original_width}x{self.original_height}, 实际 {frame.shape[1]}x{frame.shape[0]}")
frame = cv2.resize(frame, (self.original_width, self.original_height))
# 计算FPS
current_time = time.time()
time_diff = current_time - self.last_frame_time
if time_diff > 0:
self.fps = 0.9 * self.fps + 0.1 / time_diff
if current_time - self.last_status_update > 0.5:
self.last_fps = self.fps
self.last_status_update = current_time
self.last_frame_time = current_time
# 最低延迟处理逻辑
processing_time = time.perf_counter() - start_time
if processing_time > self.target_latency:
self.frame_skip_counter += 1
self.frame_count += 1
continue
# 性能优化:跳过帧处理
if self.fps < 15 and self.frame_skip < 2:
self.frame_skip += 1
self.frame_count += 1
continue
self.frame_skip = 0
# 使用帧差分算法减少处理量
if self.prev_frame is not None:
# 确保前一帧与当前帧分辨率一致
if self.prev_frame.shape != frame.shape:
logger.warning(f"前一帧分辨率不匹配: 预期 {frame.shape}, 实际 {self.prev_frame.shape}")
self.prev_frame = cv2.resize(self.prev_frame, (frame.shape[1], frame.shape[0]))
# 使用优化的推理
results = self.model(
frame,
stream=False,
verbose=False,
conf=self.config['predict']['conf_thres'],
iou=self.config['predict']['iou_thres'],
imgsz=self.imgsz,
device=self.config['predict']['device'],
half=self.config['predict'].get('half', False)
)
self.prev_results = results
else:
results = self.model(
frame,
stream=False,
verbose=False,
conf=self.config['predict']['conf_thres'],
iou=self.config['predict']['iou_thres'],
imgsz=self.imgsz,
device=self.config['predict']['device'],
half=self.config['predict'].get('half', False)
)
self.prev_frame = frame.copy() # 使用副本避免引用问题
# 处理结果
annotated_frame, detection_data = self.process_results(frame, results)
self.detections_count = len(detection_data)
# 仅在检测到目标且满足时间间隔时添加上传任务
if len(detection_data) > 0 and (current_time - self.last_upload_time >= self.upload_interval):
try:
timestamp = int(current_time * 1000)
filename = f"DJI_{timestamp}.jpg"
self.upload_queue.put({
"image": annotated_frame.copy(),
"filename": filename,
"detection_data": detection_data,
"timestamp": current_time
})
self.last_upload_time = current_time
except Exception as e:
logger.error(f"添加上传任务失败: {e}")
# 推流处理
if self.config['push']['enable_push'] and self.streamer:
try:
self.streamer.add_frame(annotated_frame)
except queue.Full:
logger.warning("推流队列已满,丢弃帧")
# WebSocket发送优化
if current_time - self.last_log_time >= 1:
self.send_to_websocket(detection_data)
self.last_log_time = current_time
self.frame_count += 1
self.reconnect_attempts = 0
# 性能监控
elapsed = time.perf_counter() - start_time
self.processing_times.append(elapsed)
self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1
# 每50帧输出一次性能报告
if self.frame_count % 50 == 0:
avg_time = sum(self.processing_times) / len(self.processing_times)
logger.info(f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {1.0 / avg_time:.1f}")
self.processing_times = []
# 动态调整跳帧参数
if avg_time > 0.15:
self.config['predict']['frame_skip'] = min(4, self.config['predict']['frame_skip'] + 1)
logger.info(f"增加跳帧至 {self.config['predict']['frame_skip']}")
elif avg_time < 0.05 and self.config['predict']['frame_skip'] > 0:
self.config['predict']['frame_skip'] = max(0, self.config['predict']['frame_skip'] - 1)
logger.info(f"减少跳帧至 {self.config['predict']['frame_skip']}")
logger.info("检测线程主循环结束")
except Exception as e:
logger.error(f"检测线程异常: {str(e)}")
logger.error(traceback.format_exc())
finally:
self.cleanup()
logger.info("检测线程已安全停止")
gd.set_value('detection_active', False)
def _upload_worker(self):
"""独立的图片上传工作线程"""
logger.info("上传工作线程启动")
output_dir = "output_frames"
os.makedirs(output_dir, exist_ok=True)
while self.upload_active or not self.upload_queue.empty():
try:
task = self.upload_queue.get(timeout=1.0)
if task is None:
break
start_time = time.time()
image = task["image"]
filename = task["filename"]
detection_data = task["detection_data"]
filepath = os.path.join(output_dir, filename)
cv2.imwrite(filepath, image)
try:
foldername = self.config['task']['taskname']
object_path = f"{foldername}/{filename}"
self.minio_uploader.upload_file(filepath, object_path)
payload = {
"taskid": self.taskid,
"path": object_path,
"tag": detection_data,
"aiid": self.aiid,
}
# 添加MQTT数据
if self.mqtt_enabled and self.mqtt_connected:
with self.mqtt_data_lock:
if self.latest_drone_data:
payload["drone_info"] = self.latest_drone_data
logger.info("添加无人机数据到API调用")
headers = {"Content-Type": "application/json"}
response = requests.post(self.res_api, json=payload, headers=headers)
if response.status_code == 200:
logger.info(f"已上传帧至 MinIO: {object_path} | 耗时: {time.time() - start_time:.2f}s")
else:
logger.warning(f"API调用失败: {response.status_code} - {response.text}")
except Exception as e:
logger.error(f"上传/API调用失败: {e}")
finally:
try:
os.remove(filepath)
except:
pass
except queue.Empty:
continue
except Exception as e:
logger.error(f"上传任务处理异常: {e}")
logger.info("上传工作线程已停止")
def process_results(self, frame, results):
result = results[0]
annotated_frame = []
# for result in results:
annotated_frame = result.plot(
img=frame,
line_width=self.config['predict']['line_width'],
# font_size=self.config['predict']['font_size'],
font= self.config['predict']['font'],
conf=True,
labels=True,
probs=True
)
if annotated_frame.shape != frame.shape:
logger.warning(f"渲染帧分辨率不匹配: 输入 {frame.shape}, 输出 {annotated_frame.shape}")
annotated_frame = cv2.resize(annotated_frame, (frame.shape[1], frame.shape[0]))
# break
detection_data = []
# for result in results:
# print("result",result)
boxes = result.boxes
for box in boxes:
class_id = int(box.cls)
confidence = float(box.conf)
class_name_obj = self.tag.get(str(class_id), None)
if class_name_obj is None:
class_name_obj = cmc.get(str(class_id), None)
if float(box.conf) < class_name_obj['reliability']:
continue
detection_data.append({
'class_id': class_id,
'class_name': class_name_obj['name'],
'confidence': confidence,
'box': box.xyxy[0].tolist(),
'reliability': class_name_obj['reliability']
})
return annotated_frame, detection_data
def initialize_resources(self):
"""初始化所有资源"""
# 加载模型
model_path = self.config['model']['path']
download_url = self.config['task']['api'] + model_path
_model_path = r"models/" + model_path
if not Path(os.path.join('models', model_path)).exists():
logger.info(r"模型不存在,开始下载模型")
try:
model_dir = model_path.split("\\")[0]
if ".pt" in model_dir:
model_dir = ""
model_save_path = r"models/" + model_dir
model_dir = os.path.dirname(model_save_path)
if model_dir != '' and not os.path.exists(model_save_path):
os.makedirs(model_save_path)
response = requests.get(download_url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
block_size = 1024 * 1024
downloaded = 0
progress = 0
with open(os.path.join('models', model_path), 'wb') as f:
for data in response.iter_content(block_size):
downloaded += len(data)
f.write(data)
new_progress = int(downloaded * 100 / total_size) if total_size > 0 else 0
if new_progress != progress and (new_progress % 10 == 0 or downloaded == total_size):
logger.info(f"下载进度: {new_progress}% ({downloaded}/{total_size} 字节)")
progress = new_progress
logger.info(f"模型下载成功: {model_save_path} ({downloaded} 字节)")
except Exception as e:
logger.error(f"下载模型失败: {str(e)}")
if not Path(model_path).exists():
logger.error("没有可用的模型文件,使用默认模型")
_model_path = r"models/yolov8n.pt"
logger.info(f"加载模型...{_model_path}")
self.model = YOLO(_model_path).to(self.config['predict']['device'])
if self.config['predict'].get('half', False) and 'cuda' in self.config['predict']['device']:
self.model = self.model.half()
logger.info("启用半精度推理")
logger.info("模型预热...")
dummy_input = torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.config['predict']['device'])
if self.config['predict'].get('half', False) and 'cuda' in self.config['predict']['device']:
dummy_input = dummy_input.half()
for _ in range(5):
self.model.predict(dummy_input)
logger.info(f"模型加载成功 | 设备: {self.config['predict']['device'].upper()}")
# 初始化RTMP连接
self.original_width, self.original_height, fps = self.initialize_rtmp()
self.fps = fps # 保存fps用于后续创建推流线程
# 初始化推流线程
self.start_streamer()
def initialize_rtmp(self):
"""初始化RTMP连接 - 使用更可靠的方法"""
logger.info(f"连接RTMP: {self.rtmp_url}")
self.cap = cv2.VideoCapture()
# 设置优化参数
self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)
self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
self.cap.set(cv2.CAP_PROP_FPS, 30)
# 启用硬件加速解码
if self.config['rtmp'].get('gpu_decode', False):
try:
self.cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
logger.info("启用硬件加速解码")
except:
logger.warning("硬件解码不可用,使用软件解码")
# 尝试连接
if not self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG):
logger.error(f"连接RTMP失败: {self.rtmp_url}")
raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")
# 获取视频属性
frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
logger.info(f"视频属性: {frame_width}x{frame_height} @ {fps}fps")
return frame_width, frame_height, fps
def handle_reconnect(self):
"""优化重连逻辑 - 更高效的恢复策略"""
self.reconnect_attempts += 1
if self.reconnect_attempts >= self.max_reconnect_attempts:
logger.error("达到最大重连次数")
self.running = False
return
delay = min(10, self.reconnect_attempts * self.reconnect_delay)
logger.warning(f"流中断,{delay}秒后重连 ({self.reconnect_attempts}/{self.max_reconnect_attempts})")
# 清理推流数据并停止推流线程
self.stop_streamer()
# 释放视频资源
if self.cap:
self.cap.release()
self.cap = None
time.sleep(delay)
# 重新连接
try:
width, height, fps = self.initialize_rtmp()
self.original_width = width
self.original_height = height
self.prev_frame = None
self.prev_results = None
# 重新创建推流线程
self.start_streamer()
logger.info("重连成功")
except Exception as e:
logger.error(f"重连异常: {str(e)}")
def send_to_websocket(self, detection_data):
"""高效发送WebSocket数据"""
try:
now = datetime.datetime.now()
time_str = now.strftime("%H:%M:%S")
simplified_data = [{
'class_id': d['class_id'],
'class_name': d['class_name'],
'confidence': round(d['confidence'], 2),
'box': [round(float(c), 1) for c in d['box']],
} for d in detection_data]
self.config['socketIO'].emit('detection_results', {
'detections': simplified_data,
'timestamp': time.time_ns() // 1000000,
'fps': round(self.last_fps, 1),
'frame_count': self.frame_count,
'taskid': self.taskid,
'time_str': time_str
})
except Exception as e:
logger.error(f"WebSocket发送错误: {str(e)}")
def start_streamer(self):
"""启动新的推流线程"""
if self.config['push']['enable_push']:
try:
self.streamer = FFmpegStreamer(self.config, self.fps, self.original_width, self.original_height)
self.streamer.start()
logger.info("推流线程已启动")
except Exception as e:
logger.error(f"启动推流线程失败: {str(e)}")
def stop_streamer(self):
"""停止推流线程并清理资源"""
if self.streamer:
try:
self.streamer.stop()
self.streamer = None
logger.info("推流线程已停止")
except Exception as e:
logger.error(f"停止推流线程失败: {str(e)}")
def stop(self):
"""停止检测线程"""
self.stop_event.set()
self.running = False
def cleanup(self):
"""资源清理 - 保证资源释放"""
logger.info("清理资源...")
# 停止MQTT客户端
if self.mqtt_enabled:
self.stop_mqtt_client()
# 停止上传线程
self.upload_active = False
if self.upload_thread and self.upload_thread.is_alive():
try:
self.upload_queue.put(None, timeout=0.5)
except:
pass
self.upload_thread.join(2.0)
if self.upload_thread.is_alive():
logger.warning("上传线程未能正常停止")
logger.info("上传资源清理完成")
# 释放流
if self.cap:
try:
self.cap.release()
except:
pass
self.cap = None
logger.info("视频流已释放")
# 停止推流
if self.streamer:
self.streamer.stop()
self.streamer = None
logger.info("推流已停止")
# 释放模型
if self.model:
try:
if hasattr(self.model, 'predictor'):
del self.model.predictor
if hasattr(self.model, 'model'):
del self.model.model
del self.model
except:
pass
self.model = None
logger.info("模型已释放")
# 清空缓存
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.info("资源清理完成")