# Yolov/detection_render.py
# (Repository web-viewer metadata: 539 lines, 20 KiB, Python.)
import cv2
2025-12-16 10:08:12 +08:00
import torch
import numpy as np
import concurrent.futures
from typing import List, Dict, Tuple, Any, Optional, Union
from dataclasses import dataclass
2025-12-11 13:41:07 +08:00
from collections import defaultdict
2025-12-16 10:08:12 +08:00
from PIL import Image, ImageDraw, ImageFont
2025-12-11 13:41:07 +08:00
2025-12-16 10:08:12 +08:00
# Try to load a TrueType font that contains CJK glyphs so PIL can render
# Chinese labels. `font` stays None when no candidate can be loaded.
font = None
FONT_PATHS = [
    "simhei.ttf",  # SimHei, resolved through PIL's own font lookup
    "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # WenQuanYi Micro Hei
    "C:/Windows/Fonts/simhei.ttf",  # SimHei on Windows
    "/System/Library/Fonts/PingFang.ttc",  # PingFang on macOS
]
for path in FONT_PATHS:
    try:
        font = ImageFont.truetype(path, 20)
        break
    except Exception:
        # Font loading is best-effort: a missing/unreadable file just moves
        # on to the next candidate. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        continue
@dataclass
class ModelConfig:
    """Settings bundle describing one detection model.

    Instances can be created directly or from a plain dict via
    ``ModelConfig(**cfg)``; ``model`` and ``config`` are the only required
    fields.
    """

    # The loaded model object that gets invoked for inference.
    model: Any
    # Opaque extra configuration; not read by the code visible in this module.
    config: Any
    # Per-class drawing rules keyed by the class id as a string.
    tags: Optional[Dict[str, Dict]] = None
    # Display name used in log output and in drawn labels.
    name: str = "未命名"
    # Caller-defined identifier; not read by the code visible in this module.
    id: Any = None
    # Inference parameters forwarded to the YOLOv8 call.
    device: str = "cuda:0"
    conf_thres: float = 0.25
    iou_thres: float = 0.45
    half: bool = False
    # Bookkeeping fields; not read by the code visible in this module.
    key_valid: Any = None
    model_hash: Any = None
    # Inference image size forwarded to the YOLOv8 call.
    imgsz: Any = 640

    def __post_init__(self):
        """Replace a missing ``tags`` value with a fresh empty dict."""
        self.tags = {} if self.tags is None else self.tags
@dataclass
class Detection:
    """Detections produced by one model for one frame."""

    # Position of the producing model in the configuration list.
    model_idx: int
    # Display name of the producing model.
    model_name: str
    # Per-detection arrays; entry i of each one describes the same box.
    boxes: np.ndarray
    confidences: np.ndarray
    class_ids: np.ndarray
    class_names: List[str]
    # Drawing rules copied from the model's configuration.
    tags: Dict[str, Dict]
    # Original result object returned by the model, when available.
    raw_result: Any = None
class DetectionVisualizer:
    """Render detection boxes and labels onto BGR frames.

    Two back ends are available: PIL (able to render Chinese labels when a
    TrueType font was loaded at import time) and plain OpenCV. Every colour
    handled by this class -- ``MODEL_COLORS`` and the per-tag ``color``
    entries -- is interpreted as a BGR tuple, so both back ends produce the
    same colours for the same configuration.
    """

    # Default per-model colours, in OpenCV's BGR channel order.
    MODEL_COLORS = [
        (0, 255, 0),  # green
        (255, 0, 0),  # blue
        (0, 0, 255),  # red
        (255, 255, 0),  # cyan
        (255, 0, 255),  # magenta
        (0, 255, 255),  # yellow
        (128, 0, 128),  # dark purple
        (0, 128, 128),  # olive
    ]

    def __init__(self, use_pil: bool = True):
        """Create a visualizer.

        Args:
            use_pil: Draw with PIL (supports Chinese text) instead of OpenCV.
        """
        self.use_pil = use_pil
        # Module-level font loaded at import time; None when none was found.
        self.font = font

    def should_draw_detection(self, class_id: int, confidence: float,
                              tags: Dict[str, Dict]) -> Tuple[bool, Optional[Tuple[int, int, int]]]:
        """Decide whether one detection should be drawn.

        Only classes that are explicitly configured in ``tags`` are drawn.

        Args:
            class_id: Class id of the detection.
            confidence: Confidence score of the detection.
            tags: Per-class rules keyed by ``str(class_id)``. Recognised keys
                are ``select`` (default True), ``reliability`` (minimum
                confidence, default 0) and ``color`` (at least 3 components).

        Returns:
            ``(should_draw, color)``; ``color`` is the configured colour, or
            None when the class has none or must not be drawn.
        """
        # An empty configuration means nothing is selected for drawing.
        if not tags:
            return False, None

        tag_config = tags.get(str(class_id))
        if not tag_config:
            return False, None

        if not tag_config.get('select', True):
            return False, None

        if confidence < tag_config.get('reliability', 0):
            return False, None

        color = tag_config.get('color')
        if color and isinstance(color, (list, tuple)) and len(color) >= 3:
            return True, tuple(color[:3])
        return True, None

    @staticmethod
    def _bgr_to_rgb(color):
        """Return the first three components of a BGR colour in RGB order."""
        blue, green, red = color[:3]
        return (red, green, blue)

    def _iter_drawable(self, detections, confidence_threshold):
        """Yield ``(model_name, (x1, y1, x2, y2), conf, cls_name, bgr_color)``
        for every detection that passes the global threshold and tag rules."""
        palette = self.MODEL_COLORS
        for det in detections:
            model_color = palette[det.model_idx % len(palette)]
            for box, conf, cls_id, cls_name in zip(
                    det.boxes, det.confidences, det.class_ids, det.class_names):
                if conf < confidence_threshold:
                    continue
                wanted, custom_color = self.should_draw_detection(cls_id, conf, det.tags)
                if not wanted:
                    continue
                corners = tuple(map(int, box[:4]))
                yield det.model_name, corners, conf, cls_name, custom_color or model_color

    def _draw_pil_label(self, draw, label, x1, y1, rgb) -> bool:
        """Draw a filled label with the loaded font; return False on failure."""
        if not self.font:
            return False
        try:
            anchor = (x1, y1 - 25)
            draw.rectangle(draw.textbbox(anchor, label, font=self.font), fill=rgb)
            draw.text(anchor, label, fill=(255, 255, 255), font=self.font)
        except Exception:
            # Font rendering problems must not abort the frame; the caller
            # falls back to an OpenCV label for this box.
            return False
        return True

    def draw_with_pil(self, frame: np.ndarray, detections: List["Detection"],
                      confidence_threshold: float) -> np.ndarray:
        """Draw detections with PIL so that Chinese labels can be rendered.

        The input frame is not modified. Labels that cannot be drawn with
        the PIL font are drawn with OpenCV on the returned image instead.

        Args:
            frame: Source image in BGR order.
            detections: Detections to draw.
            confidence_threshold: Global minimum confidence.

        Returns:
            A new annotated image in BGR order.
        """
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        deferred_labels = []

        for _name, (x1, y1, x2, y2), conf, cls_name, bgr in self._iter_drawable(
                detections, confidence_threshold):
            label = f"{cls_name} {conf:.2f}"
            # The PIL canvas is RGB while configured colours are BGR.
            rgb = self._bgr_to_rgb(bgr)
            draw.rectangle([x1, y1, x2, y2], outline=rgb, width=2)
            if not self._draw_pil_label(draw, label, x1, y1, rgb):
                deferred_labels.append((label, x1, y1, bgr))

        annotated = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        # Fallback labels are drawn on the final image (not on `frame`), so
        # they actually appear in the result.
        for label, x1, y1, bgr in deferred_labels:
            cv2.putText(annotated, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, bgr, 2)
        return annotated

    def draw_with_opencv(self, frame: np.ndarray, detections: List["Detection"],
                         confidence_threshold: float) -> np.ndarray:
        """Draw detections with OpenCV only.

        Args:
            frame: Source image in BGR order; it is copied, not modified.
            detections: Detections to draw.
            confidence_threshold: Global minimum confidence.

        Returns:
            A new annotated image in BGR order.
        """
        frame_drawn = frame.copy()
        for model_name, (x1, y1, x2, y2), conf, cls_name, color in self._iter_drawable(
                detections, confidence_threshold):
            cv2.rectangle(frame_drawn, (x1, y1), (x2, y2), color, 2)

            label = f"{model_name}: {cls_name} {conf:.2f}"
            (text_width, text_height), _baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
            )
            # Filled background sized to the text, then white text on top.
            cv2.rectangle(frame_drawn,
                          (x1, y1 - text_height - 10),
                          (x1 + text_width, y1),
                          color, -1)
            cv2.putText(frame_drawn, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        return frame_drawn

    def draw(self, frame: np.ndarray, detections: List["Detection"],
             confidence_threshold: float) -> np.ndarray:
        """Draw detections with the configured back end.

        When PIL drawing raises, the OpenCV back end is used instead.
        """
        if self.use_pil:
            try:
                return self.draw_with_pil(frame, detections, confidence_threshold)
            except Exception as e:
                print(f"PIL绘制失败使用OpenCV: {e}")
        return self.draw_with_opencv(frame, detections, confidence_threshold)
2025-12-11 13:41:07 +08:00
2025-12-16 10:08:12 +08:00
class YOLOModelWrapper:
    """Adapters that run a YOLO model and normalise its output.

    Every ``infer_*`` method returns the 5-tuple
    ``(boxes, confidences, class_ids, class_names, raw_result)``. When
    nothing usable is returned by the model, the arrays are empty,
    ``class_names`` is ``[]`` and ``raw_result`` is None.
    """

    @staticmethod
    def _empty():
        """Return the 5-tuple used when there is nothing to report."""
        return np.array([]), np.array([]), np.array([]), [], None

    @staticmethod
    def _resolve_names(names, class_ids):
        """Map class ids to display names.

        ``names`` may be a mapping (id -> name), a sequence indexed by id, or
        None; ids that cannot be resolved are rendered as their string form.
        """
        resolved = []
        for cid in class_ids:
            cid = int(cid)
            if hasattr(names, 'get'):
                resolved.append(names.get(cid, str(cid)))
            elif isinstance(names, (list, tuple)) and 0 <= cid < len(names):
                resolved.append(names[cid])
            else:
                resolved.append(str(cid))
        return resolved

    @staticmethod
    def infer_yolov8(model: Any, frame: np.ndarray, conf_thres: float,
                     iou_thres: float, imgsz: int, half: bool, device: str):
        """Run a YOLOv8-style model on one frame.

        Args:
            model: Callable model accepting the keyword arguments used below.
            frame: Input image.
            conf_thres: Confidence threshold passed as ``conf``.
            iou_thres: IoU threshold passed as ``iou``.
            imgsz: Inference image size.
            half: Whether to request half precision.
            device: Device string passed to the model.

        Returns:
            ``(boxes, confidences, class_ids, class_names, result)`` where
            ``result`` is the first element of the model output.
        """
        with torch.no_grad():
            results = model(frame, conf=conf_thres, verbose=False,
                            iou=iou_thres, imgsz=imgsz, half=half, device=device)

        if not results:
            return YOLOModelWrapper._empty()

        result = results[0]
        # `boxes` can be missing or None; either way there is nothing to read.
        detected = getattr(result, 'boxes', None)
        if detected is None:
            return YOLOModelWrapper._empty()

        boxes = detected.xyxy.cpu().numpy()
        confidences = detected.conf.cpu().numpy()
        class_ids = detected.cls.cpu().numpy().astype(int)
        class_names = YOLOModelWrapper._resolve_names(
            getattr(result, 'names', None), class_ids)

        return boxes, confidences, class_ids, class_names, result

    @staticmethod
    def infer_yolov5(model: Any, frame: np.ndarray):
        """Run a YOLOv5-style model on one frame.

        The model output is expected to expose ``xyxy``, whose first element
        holds rows of ``x1, y1, x2, y2, confidence, class``.

        Returns:
            ``(boxes, confidences, class_ids, class_names, results)``.
        """
        # Gradients are never needed here; matches infer_yolov8.
        with torch.no_grad():
            results = model(frame)

        if not hasattr(results, 'xyxy'):
            return YOLOModelWrapper._empty()

        detections = results.xyxy[0].cpu().numpy()
        if len(detections) == 0:
            return YOLOModelWrapper._empty()

        boxes = detections[:, :4]
        confidences = detections[:, 4]
        class_ids = detections[:, 5].astype(int)
        # `names` may be a dict or a list depending on the model version.
        class_names = YOLOModelWrapper._resolve_names(
            getattr(results, 'names', None), class_ids)
        return boxes, confidences, class_ids, class_names, results
def multi_model_inference(
        models_config: List[Union[Dict, "ModelConfig"]],
        frame: np.ndarray,
        confidence_threshold: float = 0.25,
        parallel: bool = True,
        use_pil: bool = True,
        use_plot_for_single: bool = True
) -> Tuple[np.ndarray, List["Detection"]]:
    """Run several models on one frame and draw their detections.

    Args:
        models_config: Model configurations, as ``ModelConfig`` objects or as
            dicts of ``ModelConfig`` keyword arguments.
        frame: Input frame in BGR order; it is not modified.
        confidence_threshold: Global minimum confidence for drawing.
        parallel: Run the models in a thread pool when there is more than one.
        use_pil: Draw with PIL (supports Chinese labels) instead of OpenCV.
        use_plot_for_single: With exactly one model, return the image produced
            by the raw result's ``plot()`` method when it has one. In that
            case ``confidence_threshold`` and the tag rules are not applied.

    Returns:
        ``(annotated_frame, detections)``. ``detections`` always contains one
        entry per configured model, in configuration order; a model whose
        inference failed contributes an empty entry.
    """
    configs = [ModelConfig(**cfg) if isinstance(cfg, dict) else cfg
               for cfg in models_config]

    def empty_detection(model_cfg, model_idx):
        """Build the placeholder entry used when a model produced nothing."""
        return Detection(
            model_idx=model_idx,
            model_name=model_cfg.name,
            boxes=np.array([]),
            confidences=np.array([]),
            class_ids=np.array([]),
            class_names=[],
            tags=model_cfg.tags,
            raw_result=None
        )

    def single_model_inference(model_cfg, model_idx):
        """Run one model; failures are logged and yield an empty entry."""
        try:
            model = model_cfg.model
            if hasattr(model, 'predict'):  # YOLOv8-style model
                boxes, confidences, class_ids, class_names, raw_result = \
                    YOLOModelWrapper.infer_yolov8(
                        model, frame, model_cfg.conf_thres, model_cfg.iou_thres,
                        model_cfg.imgsz, model_cfg.half, model_cfg.device
                    )
            elif callable(model):  # YOLOv5-style model
                boxes, confidences, class_ids, class_names, raw_result = \
                    YOLOModelWrapper.infer_yolov5(model, frame)
            else:
                raise ValueError(f"不支持的模型类型: {type(model)}")
        except Exception as e:
            print(f"模型 {model_idx} ({model_cfg.name}) 推理失败: {e}")
            return empty_detection(model_cfg, model_idx)

        return Detection(
            model_idx=model_idx,
            model_name=model_cfg.name,
            boxes=boxes,
            confidences=confidences,
            class_ids=class_ids,
            class_names=class_names,
            tags=model_cfg.tags,
            raw_result=raw_result
        )

    if parallel and len(configs) > 1:
        detections = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(configs)) as executor:
            futures = [executor.submit(single_model_inference, cfg, idx)
                       for idx, cfg in enumerate(configs)]
            # Results are collected in submission order, so the list is
            # already ordered by model index. No per-result timeout is used:
            # a running thread cannot be cancelled, and leaving the `with`
            # block waits for every worker anyway.
            for idx, future in enumerate(futures):
                try:
                    detections.append(future.result())
                except Exception as e:
                    print(f"模型推理异常: {e}")
                    # Keep one entry per model even on unexpected failure.
                    detections.append(empty_detection(configs[idx], idx))
    else:
        detections = [single_model_inference(cfg, idx)
                      for idx, cfg in enumerate(configs)]

    # Single-model shortcut: let the raw result draw itself.
    if len(configs) == 1 and use_plot_for_single:
        single_detection = detections[0]
        raw_result = single_detection.raw_result
        if raw_result is not None:
            if hasattr(raw_result, 'plot'):
                try:
                    frame_drawn = raw_result.plot()
                    if not isinstance(frame_drawn, np.ndarray):
                        # A non-array result (e.g. a PIL image) is assumed to
                        # be RGB and is converted to the BGR order used here.
                        frame_drawn = np.array(frame_drawn)
                        if frame_drawn.ndim == 3 and frame_drawn.shape[2] == 3:
                            frame_drawn = cv2.cvtColor(frame_drawn, cv2.COLOR_RGB2BGR)
                    # An ndarray from Ultralytics `Results.plot()` is already
                    # BGR, so it is returned without a channel swap.
                    print(f"使用 {single_detection.model_name} 的 plot() 方法绘制结果")
                    return frame_drawn, detections
                except Exception as e:
                    print(f"使用 plot() 方法绘制失败: {e},回退到自定义绘制")
            else:
                print(f"模型 {single_detection.model_name} 的结果对象没有 plot() 方法,使用自定义绘制")

    # Custom drawing on a copy so the caller's frame stays untouched.
    visualizer = DetectionVisualizer(use_pil=use_pil)
    frame_drawn = visualizer.draw(frame.copy(), detections, confidence_threshold)
    return frame_drawn, detections
# Backward-compatible entry point.
def multi_model_inference_legacy(_models: List[Dict], frame: np.ndarray,
                                 confidence_threshold: float = 0.25,
                                 parallel: bool = True,
                                 use_plot_for_single: bool = True) -> Tuple[np.ndarray, List[Dict]]:
    """Run ``multi_model_inference`` and return detections as plain dicts.

    Each returned dict carries the keys ``model_idx``, ``model_name``,
    ``boxes``, ``confidences``, ``class_ids``, ``class_names`` and ``tags``;
    the raw result object is not included.
    """
    annotated, results = multi_model_inference(
        _models, frame, confidence_threshold, parallel,
        use_plot_for_single=use_plot_for_single
    )

    exported_fields = (
        'model_idx',
        'model_name',
        'boxes',
        'confidences',
        'class_ids',
        'class_names',
        'tags',
    )
    legacy_results = [
        {field: getattr(item, field) for field in exported_fields}
        for item in results
    ]
    return annotated, legacy_results
if __name__ == '__main__':
    # Manual demo: runs one and two models on a sample image and saves the
    # annotated results under `uploads/`.
    import os

    from ultralytics import YOLO

    print('加载模型中')
    model_paths = [
        r"F:\PyModelScope\Yolov\models\yolov8m.pt",
        r"F:\PyModelScope\Yolov\models\car.pt"
    ]
    # Resolve the device once so the model and its config always agree.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print("预热模型...")
    model_list = []
    for i, path in enumerate(model_paths):
        model = YOLO(path)
        model.to(device)
        model.eval()
        model_config = ModelConfig(
            model=model,
            name=f'model_{i}',
            device=device,
            conf_thres=0.45,
            iou_thres=0.45,
            half=False,
            imgsz=1920,
            tags={
                "0": {"name": "汽车", "reliability": 0.4, "select": True, "color": [0, 255, 0]},
                "1": {"name": "行人", "reliability": 0.3, "select": True, "color": [255, 0, 0]},
                "2": {"name": "自行车", "reliability": 0.5, "select": False, "color": [0, 0, 255]}
            },
            config=None
        )
        model_list.append(model_config)

    print("模型预热完成")
    image_path = r"F:\PyModelScope\Yolov\images\444.png"
    frame = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file is unreadable.
    if frame is None:
        raise SystemExit(f"无法读取图像: {image_path}")

    # cv2.imwrite does not create missing directories.
    os.makedirs("uploads", exist_ok=True)

    # Single model, drawn by the raw result's plot() method.
    print("\n=== 测试单个模型 (使用plot绘制) ===")
    single_model_list = [model_list[0]]
    frame_drawn_single, detections_single = multi_model_inference(
        single_model_list, frame, use_plot_for_single=True
    )
    print(f"单个模型检测结果: {len(detections_single[0].boxes)} 个目标")
    cv2.imwrite("uploads/result_single_plot.jpg", frame_drawn_single)
    print("结果已保存到 uploads/result_single_plot.jpg")

    # Single model, forced through the custom drawing path.
    print("\n=== 测试单个模型 (强制使用自定义绘制) ===")
    frame_drawn_single_custom, detections_single_custom = multi_model_inference(
        single_model_list, frame, use_plot_for_single=False
    )
    print(f"单个模型自定义绘制结果: {len(detections_single_custom[0].boxes)} 个目标")
    cv2.imwrite("uploads/result_single_custom.jpg", frame_drawn_single_custom)
    print("结果已保存到 uploads/result_single_custom.jpg")

    # Several models: custom drawing is used even with use_plot_for_single.
    print("\n=== 测试多个模型 (使用自定义绘制) ===")
    frame_drawn_multi, detections_multi = multi_model_inference(
        model_list, frame, use_plot_for_single=True
    )
    print(f"多个模型检测结果:")
    visualizer = DetectionVisualizer()
    for det in detections_multi:
        print(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
        for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                               det.class_ids, det.class_names):
            print(f"box:{box},conf:{conf},cls_id:{cls_id},cls_name:{cls_name}")
            if conf >= 0.25:  # global threshold
                should_draw, color = visualizer.should_draw_detection(
                    cls_id, conf, det.tags)
                status = "绘制" if should_draw else "不绘制"
                print(f" {cls_name} (置信度: {conf:.2f}): {status}")
    cv2.imwrite("uploads/result_multi.jpg", frame_drawn_multi)
    print("结果已保存到 uploads/result_multi.jpg")