# Python source, 544 lines, 20 KiB (file-viewer header left over from extraction)
import cv2
|
||
import torch
|
||
import numpy as np
|
||
import concurrent.futures
|
||
from typing import List, Dict, Tuple, Any, Optional, Union
|
||
from dataclasses import dataclass
|
||
from collections import defaultdict
|
||
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
# Try to load a font that can render Chinese text with PIL.
# `font` stays None when none of the candidates below can be loaded.
font = None
FONT_PATHS = [
    "simhei.ttf",  # SimHei
    "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # WenQuanYi Micro Hei
    "C:/Windows/Fonts/simhei.ttf",  # SimHei on Windows
    "/System/Library/Fonts/PingFang.ttc",  # PingFang on macOS
]

for path in FONT_PATHS:
    try:
        font = ImageFont.truetype(path, 20)
        break
    except Exception:
        # Best effort: a candidate that cannot be loaded is skipped.
        # Unlike a bare `except:`, this still lets KeyboardInterrupt and
        # SystemExit propagate.
        continue
|
||
|
||
|
||
@dataclass
class ModelConfig:
    """Settings describing one detection model and how to run it."""

    model: Any
    config: Any  # presumably the model's own configuration parameters
    tags: Optional[Dict[str, Dict]] = None
    name: str = "未命名"
    id: Any = None
    device: str = "cuda:0"  # device the model runs on; "cuda:0" is the first GPU
    conf_thres: float = 0.25  # confidence threshold
    iou_thres: float = 0.45  # IoU (bounding-box overlap) threshold
    half: bool = False  # whether to run half-precision inference
    key_valid: Any = None  # key verification result
    model_hash: Any = None  # model hash value
    imgsz: Any = 640  # input image size

    def __post_init__(self):
        """Substitute an empty dict when no tag configuration was given."""
        self.tags = {} if self.tags is None else self.tags
|
||
|
||
|
||
@dataclass
class Detection:
    """Everything one model detected in a single frame."""

    model_idx: int  # index of the model
    model_name: str  # name of the model
    boxes: np.ndarray  # bounding-box coordinates
    confidences: np.ndarray  # confidence scores
    class_ids: np.ndarray  # class indices
    class_names: List[str]  # class names
    # Per-class label configuration: colour, confidence threshold, whether
    # the class is drawn. It controls how the detections are visualised.
    tags: Dict[str, Dict]
    raw_result: Any = None  # the original result object returned by the model
|
||
|
||
|
||
class DetectionVisualizer:
    """Draw detection boxes and labels onto a frame.

    Every colour handled by this class (``MODEL_COLORS`` and the ``color``
    entries of a tag configuration) is interpreted in OpenCV's BGR order,
    whichever drawing backend is used, so the PIL and OpenCV paths render
    the same colours.
    """

    # Default colour per model index (BGR); reused cyclically.
    MODEL_COLORS = [
        (0, 255, 0),  # green
        (255, 0, 0),  # blue
        (0, 0, 255),  # red
        (255, 255, 0),  # cyan
        (255, 0, 255),  # magenta
        (0, 255, 255),  # yellow
        (128, 0, 128),  # dark purple
        (0, 128, 128),  # olive
    ]

    def __init__(self, use_pil: bool = True):
        """
        Create a visualizer.

        Args:
            use_pil: draw with PIL (needed for Chinese labels) instead of
                OpenCV.
        """
        self.use_pil = use_pil
        # Module-level font loaded at import time; None when unavailable.
        self.font = font

    @staticmethod
    def _bgr_to_rgb(color):
        """Return the first three channels of ``color`` in reversed order."""
        return (color[2], color[1], color[0])

    def should_draw_detection(self, class_id: int, confidence: float,
                              tags: Dict[str, Dict]) -> Tuple[bool, Optional[Tuple[int, int, int]]]:
        """
        Decide whether a detection should be drawn.

        Args:
            class_id: class id; looked up in ``tags`` as ``str(class_id)``.
            confidence: confidence of the detection.
            tags: per-class tag configuration.

        Returns:
            ``(draw, color)``. ``color`` is the first three entries of the
            tag's ``color`` list/tuple, or None when the tag has no usable
            colour or the detection is not drawn.
        """
        # Only configured classes are drawn, so an empty configuration
        # draws nothing.
        if not tags:
            return False, None

        tag_config = tags.get(str(class_id))
        if not tag_config:
            return False, None

        # A tag can be switched off explicitly.
        if not tag_config.get('select', True):
            return False, None

        # Per-class confidence threshold.
        if confidence < tag_config.get('reliability', 0):
            return False, None

        color = tag_config.get('color')
        if color and isinstance(color, (list, tuple)) and len(color) >= 3:
            return True, tuple(color[:3])

        return True, None

    def _iter_drawable(self, detections, confidence_threshold):
        """Yield ``(det, (x1, y1, x2, y2), conf, cls_name, color_bgr)`` for
        every detection that passes the global and per-tag filters."""
        for det in detections:
            for box, conf, cls_id, cls_name in zip(
                    det.boxes, det.confidences, det.class_ids, det.class_names):
                if conf < confidence_threshold:
                    continue

                should_draw, custom_color = self.should_draw_detection(
                    cls_id, conf, det.tags)
                if not should_draw:
                    continue

                # Tag colour wins; otherwise fall back to the model colour.
                color = custom_color or self.MODEL_COLORS[
                    det.model_idx % len(self.MODEL_COLORS)]
                x1, y1, x2, y2 = map(int, box[:4])
                yield det, (x1, y1, x2, y2), conf, cls_name, color

    def draw_with_pil(self, frame: np.ndarray, detections: List["Detection"],
                      confidence_threshold: float) -> np.ndarray:
        """Draw the detections with PIL (supports Chinese labels).

        ``frame`` is a BGR image and is not modified; a new BGR image is
        returned. Labels that PIL cannot render (no font loaded, or the
        font call fails) are drawn with OpenCV on the returned image.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        draw = ImageDraw.Draw(pil_image)

        # (label, origin, color_bgr) entries to be drawn with OpenCV once
        # the image is back in BGR. Drawing them on the input frame would
        # be lost, because the returned image is built from ``pil_image``.
        opencv_labels = []

        for _det, (x1, y1, x2, y2), conf, cls_name, color_bgr in \
                self._iter_drawable(detections, confidence_threshold):
            # The PIL image is RGB, so BGR colours must be reversed here.
            color_rgb = self._bgr_to_rgb(color_bgr)
            label = f"{cls_name} {conf:.2f}"

            draw.rectangle([x1, y1, x2, y2], outline=color_rgb, width=2)

            if self.font is not None:
                try:
                    text_bbox = draw.textbbox((x1, y1 - 25), label, font=self.font)
                    draw.rectangle(text_bbox, fill=color_rgb)
                    draw.text((x1, y1 - 25), label, fill=(255, 255, 255), font=self.font)
                    continue
                except Exception:
                    # Font rendering failed: use the OpenCV label instead.
                    pass

            opencv_labels.append((label, (x1, y1 - 10), color_bgr))

        result = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        for label, origin, color_bgr in opencv_labels:
            cv2.putText(result, label, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color_bgr, 2)
        return result

    def draw_with_opencv(self, frame: np.ndarray, detections: List["Detection"],
                         confidence_threshold: float) -> np.ndarray:
        """Draw the detections with OpenCV on a copy of ``frame``."""
        frame_drawn = frame.copy()

        for det, (x1, y1, x2, y2), conf, cls_name, color in \
                self._iter_drawable(detections, confidence_threshold):
            cv2.rectangle(frame_drawn, (x1, y1), (x2, y2), color, 2)

            label = f"{det.model_name}: {cls_name} {conf:.2f}"
            (text_width, text_height), _baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
            )

            # Filled background behind the label, then the label in white.
            cv2.rectangle(frame_drawn,
                          (x1, y1 - text_height - 10),
                          (x1 + text_width, y1),
                          color, -1)
            cv2.putText(frame_drawn, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        return frame_drawn

    def draw(self, frame: np.ndarray, detections: List["Detection"],
             confidence_threshold: float) -> np.ndarray:
        """Draw the detections, using PIL when enabled and OpenCV otherwise.

        If the PIL path raises, the error is printed and the OpenCV path is
        used instead.
        """
        if self.use_pil:
            try:
                return self.draw_with_pil(frame, detections, confidence_threshold)
            except Exception as e:
                print(f"PIL绘制失败,使用OpenCV: {e}")
        return self.draw_with_opencv(frame, detections, confidence_threshold)
|
||
|
||
|
||
class YOLOModelWrapper:
    """Run a YOLO model and normalise its output.

    Both adapters return the tuple
    ``(boxes, confidences, class_ids, class_names, raw_result)``. When the
    model yields nothing usable, the arrays are empty, ``class_names`` is
    ``[]`` and ``raw_result`` is None.
    """

    @staticmethod
    def _empty_result():
        """Return the tuple used when a model produced no detections."""
        return np.array([]), np.array([]), np.array([]), [], None

    @staticmethod
    def _class_name(names: Any, class_id: Any) -> str:
        """Look up a class name, falling back to ``str(class_id)``.

        ``names`` may be a mapping (anything with ``.get``) or a
        list/tuple indexed by class id; anything else yields the fallback.
        """
        class_id = int(class_id)
        if hasattr(names, 'get'):
            return names.get(class_id, str(class_id))
        if isinstance(names, (list, tuple)) and 0 <= class_id < len(names):
            return names[class_id]
        return str(class_id)

    @staticmethod
    def infer_yolov8(model: Any, frame: np.ndarray, conf_thres: float,
                     iou_thres: float, imgsz: int, half: bool, device: str):
        """Run a YOLOv8-style model on ``frame``.

        The model is called with ``conf``, ``iou``, ``imgsz``, ``half`` and
        ``device`` keyword arguments, and only the first returned result is
        used.
        """
        with torch.no_grad():
            results = model(frame, conf=conf_thres, verbose=False,
                            iou=iou_thres, imgsz=imgsz, half=half, device=device)

        if not results:
            return YOLOModelWrapper._empty_result()

        result = results[0]
        # `boxes` may be missing or None; both mean there is nothing to
        # extract.
        det_boxes = getattr(result, 'boxes', None)
        if det_boxes is None:
            return YOLOModelWrapper._empty_result()

        boxes = det_boxes.xyxy.cpu().numpy()
        confidences = det_boxes.conf.cpu().numpy()
        class_ids = det_boxes.cls.cpu().numpy().astype(int)
        names = getattr(result, 'names', None)
        class_names = [YOLOModelWrapper._class_name(names, cid) for cid in class_ids]

        return boxes, confidences, class_ids, class_names, result

    @staticmethod
    def infer_yolov5(model: Any, frame: np.ndarray):
        """Run a YOLOv5-style model on ``frame``.

        The result object is expected to expose ``xyxy``; row 0 is read as
        ``[x1, y1, x2, y2, confidence, class_id]`` per detection.
        """
        # Same gradient-free inference as infer_yolov8.
        with torch.no_grad():
            results = model(frame)

        if not hasattr(results, 'xyxy'):
            return YOLOModelWrapper._empty_result()

        detections = results.xyxy[0].cpu().numpy()
        if len(detections) == 0:
            return YOLOModelWrapper._empty_result()

        boxes = detections[:, :4]
        confidences = detections[:, 4]
        class_ids = detections[:, 5].astype(int)

        # `names` can be a list or a dict; _class_name handles both and
        # falls back to the numeric id when it is absent.
        names = getattr(results, 'names', None)
        class_names = [YOLOModelWrapper._class_name(names, cid) for cid in class_ids]

        return boxes, confidences, class_ids, class_names, results
|
||
|
||
|
||
def _empty_detection(model_idx: int, model_cfg: "ModelConfig") -> "Detection":
    """Build a Detection without any boxes for a model that yielded no result."""
    return Detection(
        model_idx=model_idx,
        model_name=model_cfg.name,
        boxes=np.array([]),
        confidences=np.array([]),
        class_ids=np.array([]),
        class_names=[],
        tags=model_cfg.tags,
        raw_result=None
    )


def multi_model_inference(
        models_config: List[Union[Dict, ModelConfig]],
        frame: np.ndarray,
        confidence_threshold: float = 0.25,
        parallel: bool = True,
        use_pil: bool = True,
        use_plot_for_single: bool = True
) -> Tuple[np.ndarray, List[Detection]]:
    """
    Run several models on one frame and draw their detections.

    Args:
        models_config: model configurations; dict entries are converted
            with ``ModelConfig(**cfg)``.
        frame: video frame (BGR).
        confidence_threshold: global confidence threshold used when drawing.
        parallel: run the models in a thread pool when there is more than
            one of them.
        use_pil: draw with PIL instead of OpenCV.
        use_plot_for_single: with exactly one model, draw with the raw
            result's ``plot()`` method when it has one. That path applies
            neither ``confidence_threshold`` nor the tag configuration.

    Returns:
        ``(drawn frame, detections)`` with exactly one Detection per
        configured model, ordered by model index. A model whose inference
        fails contributes an empty Detection.
    """
    configs = [ModelConfig(**cfg) if isinstance(cfg, dict) else cfg
               for cfg in models_config]

    def single_model_inference(model_cfg: ModelConfig, model_idx: int) -> Detection:
        """Run one model; failures are reported and turned into an empty Detection."""
        try:
            model = model_cfg.model

            # Dispatch on the model's interface.
            if hasattr(model, 'predict'):  # YOLOv8
                boxes, confidences, class_ids, class_names, raw_result = YOLOModelWrapper.infer_yolov8(
                    model, frame, model_cfg.conf_thres, model_cfg.iou_thres,
                    model_cfg.imgsz, model_cfg.half, model_cfg.device
                )
            elif callable(model):  # YOLOv5
                boxes, confidences, class_ids, class_names, raw_result = YOLOModelWrapper.infer_yolov5(model, frame)
            else:
                raise ValueError(f"不支持的模型类型: {type(model)}")

            return Detection(
                model_idx=model_idx,
                model_name=model_cfg.name,
                boxes=boxes,
                confidences=confidences,
                class_ids=class_ids,
                class_names=class_names,
                tags=model_cfg.tags,
                raw_result=raw_result
            )

        except Exception as e:
            print(f"模型 {model_idx} ({model_cfg.name}) 推理失败: {e}")
            return _empty_detection(model_idx, model_cfg)

    if parallel and len(configs) > 1:
        detections = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(configs)) as executor:
            future_to_idx = {
                executor.submit(single_model_inference, cfg, idx): idx
                for idx, cfg in enumerate(configs)
            }

            # as_completed only yields finished futures, so result() never
            # blocks here and a timeout on it could never fire. No
            # per-model time limit is enforced.
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    detections.append(future.result())
                except Exception as e:
                    print(f"模型推理异常: {e}")
                    # Keep one entry per model so the list stays aligned
                    # with `configs`.
                    detections.append(_empty_detection(idx, configs[idx]))
    else:
        detections = [single_model_inference(cfg, idx)
                      for idx, cfg in enumerate(configs)]

    # Threads finish in arbitrary order; restore model order.
    detections.sort(key=lambda det: det.model_idx)

    # Single model: optionally let the raw result draw itself.
    if len(configs) == 1 and use_plot_for_single:
        single_detection = detections[0]
        raw_result = single_detection.raw_result
        if raw_result is not None:
            try:
                if hasattr(raw_result, 'plot'):
                    frame_drawn = raw_result.plot()
                    # Ultralytics' Results.plot() returns a BGR ndarray,
                    # which is already the format this function returns,
                    # so it must not be channel-swapped. Only a non-array
                    # return value is converted, as before.
                    if not isinstance(frame_drawn, np.ndarray):
                        frame_drawn = np.array(frame_drawn)
                        if len(frame_drawn.shape) == 3 and frame_drawn.shape[2] == 3:
                            frame_drawn = cv2.cvtColor(frame_drawn, cv2.COLOR_RGB2BGR)
                    return frame_drawn, detections
                else:
                    print(f"模型 {single_detection.model_name} 的结果对象没有 plot() 方法,使用自定义绘制")
            except Exception as e:
                print(f"使用 plot() 方法绘制失败: {e},回退到自定义绘制")

    # Custom drawing on a copy so the caller's frame is left untouched.
    visualizer = DetectionVisualizer(use_pil=use_pil)
    frame_drawn = visualizer.draw(frame.copy(), detections, confidence_threshold)

    return frame_drawn, detections
|
||
|
||
|
||
# Backward-compatible entry point
def multi_model_inference_legacy(_models: List[Dict], frame: np.ndarray,
                                 confidence_threshold: float = 0.25,
                                 parallel: bool = True,
                                 use_plot_for_single: bool = True) -> Tuple[np.ndarray, List[Dict]]:
    """Call ``multi_model_inference`` and return its detections as plain dicts.

    Each dict carries the keys ``model_idx``, ``model_name``, ``boxes``,
    ``confidences``, ``class_ids``, ``class_names`` and ``tags``; the raw
    result object is not included.
    """
    annotated, results = multi_model_inference(
        _models, frame, confidence_threshold, parallel,
        use_plot_for_single=use_plot_for_single,
    )

    # Attributes copied into the old dict format, in their original order.
    exported = (
        'model_idx',
        'model_name',
        'boxes',
        'confidences',
        'class_ids',
        'class_names',
        'tags',
    )
    legacy_results = [
        {key: getattr(item, key) for key in exported}
        for item in results
    ]

    return annotated, legacy_results
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: runs one model and then two models on a sample
    # image and writes the annotated results to the uploads/ directory.
    import os

    from ultralytics import YOLO

    print('加载模型中')
    model_paths = [
        r"F:\PyModelScope\Yolov\models\yolov8m.pt",
        r"F:\PyModelScope\Yolov\models\car.pt"
    ]

    # Resolve the device once so model placement and the inference
    # configuration agree (a hard-coded 'cuda:0' fails on CPU-only hosts).
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print("预热模型...")
    model_list = []
    for i, path in enumerate(model_paths):
        model = YOLO(path)
        model.to(device)
        model.eval()

        # Build the configuration for this model.
        model_config = ModelConfig(
            model=model,
            name=f'model_{i}',
            device=device,
            conf_thres=0.45,
            iou_thres=0.45,
            half=False,
            imgsz=1920,
            tags={
                "0": {"name": "汽车", "reliability": 0.4, "select": True, "color": [0, 255, 0]},
                "1": {"name": "行人", "reliability": 0.3, "select": True, "color": [255, 0, 0]},
                "2": {"name": "自行车", "reliability": 0.5, "select": False, "color": [0, 0, 255]}
            },
            config=None
        )
        model_list.append(model_config)

    print("模型预热完成")
    image_path = r"F:\PyModelScope\Yolov\images\444.png"
    frame = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file is unreadable.
    if frame is None:
        raise SystemExit(f"无法读取图像: {image_path}")

    # cv2.imwrite does not create missing directories.
    os.makedirs("uploads", exist_ok=True)

    # Single model, drawn with the result's plot() method.
    print("\n=== 测试单个模型 (使用plot绘制) ===")
    single_model_list = [model_list[0]]  # first model only
    frame_drawn_single, detections_single = multi_model_inference(
        single_model_list, frame, use_plot_for_single=True
    )
    print(f"单个模型检测结果: {len(detections_single[0].boxes)} 个目标")
    cv2.imwrite("uploads/result_single_plot.jpg", frame_drawn_single)
    print("结果已保存到 uploads/result_single_plot.jpg")

    # Single model, forced through the custom drawing path.
    print("\n=== 测试单个模型 (强制使用自定义绘制) ===")
    frame_drawn_single_custom, detections_single_custom = multi_model_inference(
        single_model_list, frame, use_plot_for_single=False
    )
    print(f"单个模型自定义绘制结果: {len(detections_single_custom[0].boxes)} 个目标")
    cv2.imwrite("uploads/result_single_custom.jpg", frame_drawn_single_custom)
    print("结果已保存到 uploads/result_single_custom.jpg")

    # Several models: custom drawing is used even with use_plot_for_single=True.
    print("\n=== 测试多个模型 (使用自定义绘制) ===")
    frame_drawn_multi, detections_multi = multi_model_inference(
        model_list, frame, use_plot_for_single=True
    )
    print("多个模型检测结果:")
    visualizer = DetectionVisualizer()  # reused for every draw/no-draw check
    for det in detections_multi:
        print(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
        for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
                                               det.class_ids, det.class_names):
            print(f"box:{box},conf:{conf},cls_id:{cls_id},cls_name:{cls_name}")
            if conf >= 0.25:  # global threshold
                should_draw, color = visualizer.should_draw_detection(
                    cls_id, conf, det.tags)
                status = "绘制" if should_draw else "不绘制"
                print(f" {cls_name} (置信度: {conf:.2f}): {status}")

    cv2.imwrite("uploads/result_multi.jpg", frame_drawn_multi)
    print("结果已保存到 uploads/result_multi.jpg")