import concurrent.futures
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Tuple, Any, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont


# Try to load a TrueType font that can render Chinese label text with PIL.
# Candidates are tried in order and the first one that loads wins; if none
# of them can be loaded, ``font`` stays None.
FONT_SIZE = 20  # point size for label text

font = None
FONT_PATHS = [
    "simhei.ttf",  # SimHei, bare filename (cwd; Pillow may also search system font dirs)
    "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # WenQuanYi Micro Hei (Linux)
    "C:/Windows/Fonts/simhei.ttf",  # SimHei (Windows)
    "/System/Library/Fonts/PingFang.ttc",  # PingFang (macOS)
]

for path in FONT_PATHS:
    try:
        font = ImageFont.truetype(path, FONT_SIZE)
        break
    except (OSError, ValueError):
        # Missing or unreadable font file: try the next candidate. Narrow
        # exceptions so Ctrl-C / SystemExit during import are not swallowed.
        continue


@dataclass
class ModelConfig:
    """Settings for one model taking part in multi-model inference.

    Only ``model`` and ``config`` are required; every other field has a
    default. ``tags`` holds per-class drawing rules keyed by the class id as a
    string; passing ``None`` (the default) yields a fresh empty dict.
    """

    model: Any  # the model object that the inference code calls
    config: Any  # extra configuration; opaque to the code shown here
    tags: Optional[Dict[str, Dict]] = None  # class-id string -> drawing rule
    name: str = "未命名"  # display name used in logs and labels
    id: Any = None  # field name kept for compatibility (shadows the builtin)
    device: str = "cuda:0"  # device passed through to inference
    conf_thres: float = 0.25  # model-side confidence threshold
    iou_thres: float = 0.45  # model-side NMS IoU threshold
    half: bool = False  # request half-precision inference
    key_valid: Any = None  # not read by the code shown here
    model_hash: Any = None  # not read by the code shown here
    imgsz: Any = 640  # inference image size

    def __post_init__(self):
        """Normalise a missing tag mapping to its own empty dict."""
        self.tags = {} if self.tags is None else self.tags


@dataclass
class Detection:
    """Everything one model detected in one frame.

    ``boxes``, ``confidences``, ``class_ids`` and ``class_names`` are parallel
    sequences: the drawing code zips them, so entry *k* of each describes the
    same object.
    """

    model_idx: int  # position of the producing model in the caller's list
    model_name: str  # display name of the producing model
    boxes: np.ndarray  # per-object box; drawing code reads box[:4] as x1, y1, x2, y2
    confidences: np.ndarray  # per-object confidence score
    class_ids: np.ndarray  # per-object integer class id
    class_names: List[str]  # per-object readable class name
    tags: Dict[str, Dict]  # per-class drawing rules keyed by class-id string
    raw_result: Any = None  # the model's own result object, kept for re-rendering


class DetectionVisualizer:
    """Draw detection results onto BGR frames.

    A box is drawn only when its class is explicitly configured (and enabled)
    in the producing model's ``tags`` mapping; see ``should_draw_detection``.
    Every colour handled by this class is in OpenCV BGR channel order.
    """

    # Fallback colour per model index (BGR order), used when a tag has no colour.
    MODEL_COLORS = [
        (0, 255, 0),  # green
        (255, 0, 0),  # blue
        (0, 0, 255),  # red
        (255, 255, 0),  # cyan
        (255, 0, 255),  # magenta
        (0, 255, 255),  # yellow
        (128, 0, 128),  # dark purple
        (0, 128, 128),  # olive
    ]

    def __init__(self, use_pil: bool = True):
        """Create a visualizer.

        Args:
            use_pil: draw with PIL (needed to render Chinese label text).
                When False, or when PIL drawing raises, OpenCV is used.
        """
        self.use_pil = use_pil
        # Module-level font loaded at import time; None if no candidate loaded.
        self.font = font

    @staticmethod
    def _bgr_to_rgb(color) -> Tuple[int, ...]:
        """Return ``color`` with its channel order reversed (BGR <-> RGB)."""
        return tuple(color[::-1])

    def should_draw_detection(self, class_id: int, confidence: float,
                              tags: Dict[str, Dict]) -> Tuple[bool, Optional[Tuple[int, int, int]]]:
        """Decide whether a detection is drawn, and with which colour.

        Args:
            class_id: class id of the detection (looked up as ``str(class_id)``).
            confidence: detection confidence.
            tags: per-class rules; a rule may hold ``select`` (default True),
                ``reliability`` (minimum confidence, default 0) and ``color``
                (BGR, at least three components).

        Returns:
            ``(draw?, colour)`` where colour is an int BGR 3-tuple or None
            (meaning: use the model's fallback colour).
        """
        # Only explicitly configured classes are drawn; no config -> draw nothing.
        if not tags:
            return False, None

        tag_config = tags.get(str(class_id))
        if not tag_config:
            return False, None

        if not tag_config.get('select', True):
            return False, None

        # A missing/None/malformed reliability means "no per-class threshold"
        # rather than a TypeError in the comparison below.
        try:
            reliability = float(tag_config.get('reliability') or 0)
        except (TypeError, ValueError):
            reliability = 0.0
        if confidence < reliability:
            return False, None

        color = tag_config.get('color')
        if isinstance(color, (list, tuple)) and len(color) >= 3:
            try:
                # cv2/PIL need plain ints; JSON configs may carry floats.
                return True, tuple(int(c) for c in color[:3])
            except (TypeError, ValueError):
                pass  # malformed colour: fall back to the model colour
        return True, None

    def draw_with_pil(self, frame: np.ndarray, detections: List["Detection"],
                      confidence_threshold: float) -> np.ndarray:
        """Draw detections with PIL so that Chinese class names render.

        Args:
            frame: BGR image; not modified.
            detections: per-model detection results.
            confidence_threshold: global minimum confidence.

        Returns:
            A new BGR image with boxes and labels drawn.
        """
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        # Labels PIL could not render. They are drawn with OpenCV on the final
        # BGR image so they actually appear in the returned frame.
        opencv_labels = []

        for det in detections:
            model_color = self.MODEL_COLORS[det.model_idx % len(self.MODEL_COLORS)]
            for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                                   det.class_ids, det.class_names):
                if conf < confidence_threshold:
                    continue
                should_draw, custom_color = self.should_draw_detection(cls_id, conf, det.tags)
                if not should_draw:
                    continue

                color = custom_color or model_color  # BGR
                rgb_color = self._bgr_to_rgb(color)  # PIL draws on an RGB image
                x1, y1, x2, y2 = map(int, box[:4])
                label = f"{cls_name} {conf:.2f}"

                draw.rectangle([x1, y1, x2, y2], outline=rgb_color, width=2)

                if self.font is not None:
                    # Label above the box, clamped so it stays inside the image.
                    text_origin = (x1, max(y1 - 25, 0))
                    try:
                        text_bbox = draw.textbbox(text_origin, label, font=self.font)
                        draw.rectangle(text_bbox, fill=rgb_color)
                        draw.text(text_origin, label, fill=(255, 255, 255), font=self.font)
                        continue
                    except Exception:
                        pass  # font cannot render this label; use OpenCV below
                opencv_labels.append((label, (x1, y1 - 10), color))

        frame_drawn = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        for label, origin, color in opencv_labels:
            cv2.putText(frame_drawn, label, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        return frame_drawn

    def draw_with_opencv(self, frame: np.ndarray, detections: List["Detection"],
                         confidence_threshold: float) -> np.ndarray:
        """Draw detections with OpenCV (Hershey fonts, no Chinese support).

        Args:
            frame: BGR image; not modified (a copy is drawn on).
            detections: per-model detection results.
            confidence_threshold: global minimum confidence.

        Returns:
            A new BGR image with boxes and ``"model: class conf"`` labels.
        """
        frame_drawn = frame.copy()

        for det in detections:
            model_color = self.MODEL_COLORS[det.model_idx % len(self.MODEL_COLORS)]
            for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                                   det.class_ids, det.class_names):
                if conf < confidence_threshold:
                    continue
                should_draw, custom_color = self.should_draw_detection(cls_id, conf, det.tags)
                if not should_draw:
                    continue

                color = custom_color or model_color
                x1, y1, x2, y2 = map(int, box[:4])
                cv2.rectangle(frame_drawn, (x1, y1), (x2, y2), color, 2)

                label = f"{det.model_name}: {cls_name} {conf:.2f}"
                (text_width, text_height), _baseline = cv2.getTextSize(
                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
                )
                # The label sits above the box; when the box touches the top
                # edge, push it down just enough to stay visible.
                label_bottom = max(y1, text_height + 10)
                cv2.rectangle(frame_drawn,
                              (x1, label_bottom - text_height - 10),
                              (x1 + text_width, label_bottom),
                              color, -1)
                cv2.putText(frame_drawn, label, (x1, label_bottom - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        return frame_drawn

    def draw(self, frame: np.ndarray, detections: List["Detection"],
             confidence_threshold: float) -> np.ndarray:
        """Draw detections, preferring PIL and falling back to OpenCV on error."""
        if self.use_pil:
            try:
                return self.draw_with_pil(frame, detections, confidence_threshold)
            except Exception as e:
                print(f"PIL绘制失败,使用OpenCV: {e}")
                return self.draw_with_opencv(frame, detections, confidence_threshold)
        else:
            return self.draw_with_opencv(frame, detections, confidence_threshold)


class YOLOModelWrapper:
    """Run YOLO models and normalise their output.

    Both ``infer_*`` methods return the 5-tuple
    ``(boxes, confidences, class_ids, class_names, raw_result)``: parallel
    per-object arrays/list plus the model's own result object. When there is
    nothing to report they return empty arrays, an empty list and ``None``.
    """

    @staticmethod
    def _empty_result():
        """Return the "no detections" 5-tuple (fresh objects on every call)."""
        return np.array([]), np.array([]), np.array([]), [], None

    @staticmethod
    def _class_name(names: Any, class_id: Any) -> str:
        """Map a class id to its readable name.

        ``names`` may be a dict ``{id: name}`` or a list/tuple indexed by id.
        For anything else, or an unknown id, the id itself is returned as text.
        """
        cid = int(class_id)
        if isinstance(names, dict):
            return names.get(cid, str(cid))
        if isinstance(names, (list, tuple)) and 0 <= cid < len(names):
            return names[cid]
        return str(cid)

    @staticmethod
    def infer_yolov8(model: Any, frame: np.ndarray, conf_thres: float,
                     iou_thres: float, imgsz: int, half: bool, device: str):
        """Run an ultralytics-style model (called with keyword options) on ``frame``."""
        with torch.no_grad():
            results = model(frame, conf=conf_thres, verbose=False,
                            iou=iou_thres, imgsz=imgsz, half=half, device=device)

        if not results:
            return YOLOModelWrapper._empty_result()

        result = results[0]
        # ``boxes`` can be absent or None (e.g. models without box output).
        det_boxes = getattr(result, 'boxes', None)
        if det_boxes is None:
            return YOLOModelWrapper._empty_result()

        boxes = det_boxes.xyxy.cpu().numpy()
        confidences = det_boxes.conf.cpu().numpy()
        class_ids = det_boxes.cls.cpu().numpy().astype(int)
        names = getattr(result, 'names', None)
        class_names = [YOLOModelWrapper._class_name(names, cid) for cid in class_ids]

        return boxes, confidences, class_ids, class_names, result

    @staticmethod
    def infer_yolov5(model: Any, frame: np.ndarray):
        """Run a YOLOv5-style callable whose result exposes ``xyxy`` rows.

        Each row of ``results.xyxy[0]`` is read as x1, y1, x2, y2, conf, cls.
        """
        with torch.no_grad():
            results = model(frame)

        if not hasattr(results, 'xyxy'):
            return YOLOModelWrapper._empty_result()

        detections = results.xyxy[0].cpu().numpy()
        if len(detections) == 0:
            return YOLOModelWrapper._empty_result()

        boxes = detections[:, :4]
        confidences = detections[:, 4]
        class_ids = detections[:, 5].astype(int)
        # ``names`` may be a list (older YOLOv5) or a dict; handle both.
        names = getattr(results, 'names', None)
        class_names = [YOLOModelWrapper._class_name(names, cid) for cid in class_ids]

        return boxes, confidences, class_ids, class_names, results


def _empty_detection(model_idx: int, model_cfg: ModelConfig) -> Detection:
    """Build an empty Detection for a model that produced no usable result."""
    return Detection(
        model_idx=model_idx,
        model_name=model_cfg.name,
        boxes=np.array([]),
        confidences=np.array([]),
        class_ids=np.array([]),
        class_names=[],
        tags=model_cfg.tags,
        raw_result=None,
    )


def _single_model_inference(model_cfg: ModelConfig, model_idx: int,
                            frame: np.ndarray) -> Detection:
    """Run one configured model on ``frame``.

    Any exception is logged and turned into an empty Detection, so one
    failing model does not abort the others.
    """
    try:
        model = model_cfg.model
        if hasattr(model, 'predict'):  # YOLOv8 (ultralytics API)
            outputs = YOLOModelWrapper.infer_yolov8(
                model, frame, model_cfg.conf_thres, model_cfg.iou_thres,
                model_cfg.imgsz, model_cfg.half, model_cfg.device
            )
        elif hasattr(model, '__call__'):  # YOLOv5 (plain callable)
            outputs = YOLOModelWrapper.infer_yolov5(model, frame)
        else:
            raise ValueError(f"不支持的模型类型: {type(model)}")
    except Exception as e:
        print(f"模型 {model_idx} ({model_cfg.name}) 推理失败: {e}")
        return _empty_detection(model_idx, model_cfg)
    else:
        boxes, confidences, class_ids, class_names, raw_result = outputs
        return Detection(
            model_idx=model_idx,
            model_name=model_cfg.name,
            boxes=boxes,
            confidences=confidences,
            class_ids=class_ids,
            class_names=class_names,
            tags=model_cfg.tags,
            raw_result=raw_result,
        )


def _plot_single_result(detection: Detection) -> Optional[np.ndarray]:
    """Render ``detection.raw_result`` with its own ``plot()`` method.

    Returns None (after logging why) when there is nothing to plot or plotting
    fails, so the caller can fall back to DetectionVisualizer.
    """
    raw = detection.raw_result
    if raw is None:
        return None
    if not hasattr(raw, 'plot'):
        print(f"模型 {detection.model_name} 的结果对象没有 plot() 方法,使用自定义绘制")
        return None
    try:
        # Ultralytics documents Results.plot() as returning an annotated BGR
        # numpy array (same channel order as ``frame``), so no colour
        # conversion is applied here.
        frame_drawn = raw.plot()
        if not isinstance(frame_drawn, np.ndarray):
            frame_drawn = np.array(frame_drawn)
    except Exception as e:
        print(f"使用 plot() 方法绘制失败: {e},回退到自定义绘制")
        return None
    print(f"使用 {detection.model_name} 的 plot() 方法绘制结果")
    return frame_drawn


def multi_model_inference(
        models_config: List[Union[Dict, ModelConfig]],
        frame: np.ndarray,
        confidence_threshold: float = 0.25,
        parallel: bool = True,
        use_pil: bool = True,
        use_plot_for_single: bool = True,
) -> Tuple[np.ndarray, List[Detection]]:
    """Run several detection models on one frame and draw the combined result.

    Args:
        models_config: ModelConfig objects, or dicts of ModelConfig fields.
        frame: video frame in BGR order.
        confidence_threshold: global minimum confidence for custom drawing.
        parallel: run models in a thread pool when there is more than one.
        use_pil: draw with PIL (Chinese labels) in the custom drawing path.
        use_plot_for_single: with exactly one model, render via the result
            object's own ``plot()`` when available.

    Returns:
        ``(drawn frame, detections)``. ``detections`` holds one Detection per
        entry of ``models_config``, in the same order; a failed model
        contributes an empty Detection.
    """
    configs = [ModelConfig(**cfg) if isinstance(cfg, dict) else cfg
               for cfg in models_config]

    if parallel and len(configs) > 1:
        # One worker per model; leaving the ``with`` block waits for all of them.
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(configs)) as executor:
            futures = [executor.submit(_single_model_inference, cfg, idx, frame)
                       for idx, cfg in enumerate(configs)]
        detections = []
        for idx, future in enumerate(futures):
            try:
                detections.append(future.result())
            except Exception as e:
                # _single_model_inference handles its own errors; this guards
                # anything that still escapes, and keeps one entry per model.
                print(f"模型推理异常: {e}")
                detections.append(_empty_detection(idx, configs[idx]))
    else:
        detections = [_single_model_inference(cfg, idx, frame)
                      for idx, cfg in enumerate(configs)]

    if len(configs) == 1 and use_plot_for_single:
        # NOTE: plot() is called without ``tags`` or ``confidence_threshold``,
        # so this path does not apply the per-class filtering used below.
        plotted = _plot_single_result(detections[0])
        if plotted is not None:
            return plotted, detections

    visualizer = DetectionVisualizer(use_pil=use_pil)
    frame_drawn = visualizer.draw(frame.copy(), detections, confidence_threshold)
    return frame_drawn, detections


# Backward-compatible entry point for the old dict-based interface.
def multi_model_inference_legacy(_models: List[Dict], frame: np.ndarray,
                                 confidence_threshold: float = 0.25,
                                 parallel: bool = True,
                                 use_plot_for_single: bool = True) -> Tuple[np.ndarray, List[Dict]]:
    """Call multi_model_inference and return its detections as plain dicts.

    Each dict carries every Detection field except ``raw_result``.
    """
    legacy_keys = ('model_idx', 'model_name', 'boxes', 'confidences',
                   'class_ids', 'class_names', 'tags')

    frame_drawn, detections = multi_model_inference(
        _models,
        frame,
        confidence_threshold=confidence_threshold,
        parallel=parallel,
        use_plot_for_single=use_plot_for_single,
    )

    old_detections = [{key: getattr(det, key) for key in legacy_keys}
                      for det in detections]
    return frame_drawn, old_detections


if __name__ == '__main__':
    from ultralytics import YOLO

    # Demo: run one model and then two models on a single image, saving drawings.
    print('加载模型中')
    model_paths = [
        r"F:\PyModelScope\Yolov\models\yolov8m.pt",
        r"F:\PyModelScope\Yolov\models\car.pt",
    ]
    # Use the same device for loading and inference so CPU-only machines work.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    output_dir = "uploads"

    print("预热模型...")
    model_list = []
    for i, path in enumerate(model_paths):
        model = YOLO(path)
        model.to(device)
        model.eval()

        model_list.append(ModelConfig(
            model=model,
            name=f'model_{i}',
            device=device,
            conf_thres=0.45,
            iou_thres=0.45,
            half=False,
            imgsz=1920,
            tags={
                "0": {"name": "汽车", "reliability": 0.4, "select": True, "color": [0, 255, 0]},
                "1": {"name": "行人", "reliability": 0.3, "select": True, "color": [255, 0, 0]},
                "2": {"name": "自行车", "reliability": 0.5, "select": False, "color": [0, 0, 255]},
            },
            config=None,
        ))
    print("模型预热完成")

    image_path = r"F:\PyModelScope\Yolov\images\444.png"
    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise SystemExit(f"无法读取图像: {image_path}")

    # cv2.imwrite does not create directories; it just returns False.
    os.makedirs(output_dir, exist_ok=True)

    def save_result(filename, image):
        """Write ``image`` into ``output_dir`` and report the outcome."""
        out_path = f"{output_dir}/{filename}"
        if cv2.imwrite(out_path, image):
            print(f"结果已保存到 {out_path}")
        else:
            print(f"保存失败: {out_path}")

    # Single model, rendered with the result's own plot().
    print("\n=== 测试单个模型 (使用plot绘制) ===")
    single_model_list = [model_list[0]]
    frame_drawn_single, detections_single = multi_model_inference(
        single_model_list, frame, use_plot_for_single=True
    )
    print(f"单个模型检测结果: {len(detections_single[0].boxes)} 个目标")
    save_result("result_single_plot.jpg", frame_drawn_single)

    # Single model, forced through the custom (tag-filtered) drawing.
    print("\n=== 测试单个模型 (强制使用自定义绘制) ===")
    frame_drawn_single_custom, detections_single_custom = multi_model_inference(
        single_model_list, frame, use_plot_for_single=False
    )
    print(f"单个模型自定义绘制结果: {len(detections_single_custom[0].boxes)} 个目标")
    save_result("result_single_custom.jpg", frame_drawn_single_custom)

    # Several models: use_plot_for_single only applies with exactly one model.
    print("\n=== 测试多个模型 (使用自定义绘制) ===")
    frame_drawn_multi, detections_multi = multi_model_inference(
        model_list, frame, use_plot_for_single=True
    )
    print("多个模型检测结果:")
    visualizer = DetectionVisualizer()
    for det in detections_multi:
        print(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
        for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                               det.class_ids, det.class_names):
            print(f"box:{box},conf:{conf},cls_id:{cls_id},cls_name:{cls_name}")
            if conf >= 0.25:  # global threshold (the default confidence_threshold)
                should_draw, _ = visualizer.should_draw_detection(cls_id, conf, det.tags)
                status = "绘制" if should_draw else "不绘制"
                print(f" {cls_name} (置信度: {conf:.2f}): {status}")

    save_result("result_multi.jpg", frame_drawn_multi)