2025-12-16 10:08:12 +08:00
|
|
|
|
import numpy as np
|
2026-02-05 10:39:03 +08:00
|
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
2025-12-16 10:08:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-02-05 10:39:03 +08:00
|
|
|
|
def multi_model_inference(
    models: List[dict],
    frame: np.ndarray,
    confidence_threshold: float = 0.25,
    parallel: bool = True,
    use_plot_for_single: bool = True
) -> Tuple[np.ndarray, List[dict]]:
    """
    Run several loaded models on one frame, optionally in parallel.

    Args:
        models: Loaded models. Each element is a dict holding the model object
            under 'model'; optional keys 'conf_thres', 'iou_thres', 'imgsz',
            'device', 'name', 'id' and 'tags' override the defaults.
        frame: Video frame (BGR). It is not modified.
        confidence_threshold: Global confidence threshold, used for any model
            dict without its own 'conf_thres'.
        parallel: Whether to run the models concurrently in a thread pool
            (only applies when there is more than one model).
        use_plot_for_single: When there is exactly one detection result, draw
            it with the result's own ``plot()`` method instead of
            ``plot_custom_results``.

    Returns:
        (annotated frame, list of detection dicts). The list has one entry per
        model, in the same order as ``models``, so a model keeps the same
        position -- and therefore the same colour in ``plot_custom_results``
        -- on every frame.
    """
    if len(models) == 0:
        return frame, []

    original_frame = frame.copy()

    def failed_detection(model_info: dict) -> dict:
        """Build the detection record for a model whose inference raised."""
        model_id = model_info.get('id', 0)
        return {
            'model_idx': model_id,
            'model_name': model_info.get('name', f"model_{model_id}"),
            'boxes': np.array([]),
            'confidences': np.array([]),
            'class_ids': np.array([]),
            'class_names': [],
            'tags': model_info.get('tags', []),
            'raw_result': None,
            'success': False
        }

    def inference_single_model(model_info: dict, img: np.ndarray) -> dict:
        """Run one model on ``img`` and package its detections as a dict."""
        try:
            model = model_info['model']
            # Per-model settings win; otherwise fall back to the defaults.
            conf_thres = model_info.get('conf_thres', confidence_threshold)
            iou_thres = model_info.get('iou_thres', 0.45)
            imgsz = model_info.get('imgsz', 640)
            device = model_info.get('device', 'cpu')

            results = model.predict(
                source=img,
                imgsz=imgsz,
                conf=conf_thres,
                iou=iou_thres,
                verbose=False,
                device=device,
                max_det=300
            )

            result = results[0] if len(results) > 0 else None
            model_name = model_info.get('name', f"model_{model_info.get('id', 'unknown')}")
            model_id = model_info.get('id', 0)

            # Compare with None explicitly: result objects may define __len__,
            # which would make plain truthiness depend on the detection count.
            if result is not None and result.boxes is not None and len(result.boxes):
                class_ids = result.boxes.cls.cpu().numpy().astype(int)
                return {
                    'model_idx': model_id,
                    'model_name': model_name,
                    'boxes': result.boxes.xyxy.cpu().numpy(),
                    'confidences': result.boxes.conf.cpu().numpy(),
                    'class_ids': class_ids,
                    'class_names': [result.names[int(cls_id)] for cls_id in class_ids],
                    'tags': model_info.get('tags', []),
                    'raw_result': result,
                    'success': True
                }

            # Inference ran but produced no boxes; reported as success=False.
            return {
                'model_idx': model_id,
                'model_name': model_name,
                'boxes': np.array([]),
                'confidences': np.array([]),
                'class_ids': np.array([]),
                'class_names': [],
                'tags': model_info.get('tags', []),
                'raw_result': result,
                'success': False
            }
        except Exception as e:
            print(f"模型 {model_info.get('name', 'unknown')} 推理失败: {str(e)}")
            return failed_detection(model_info)

    # Run inference
    detections = []
    if parallel and len(models) > 1:
        max_workers = min(len(models), 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(inference_single_model, model_info, original_frame)
                for model_info in models
            ]
            # Collect in submission order (not completion order) so that the
            # detections line up with `models` on every frame.
            for model_info, future in zip(models, futures):
                try:
                    detections.append(future.result())
                except Exception as e:
                    model_name = model_info.get('name', 'unknown')
                    print(f"并行推理任务失败 - 模型 {model_name}: {str(e)}")
                    # Keep one entry per model even when a worker fails.
                    detections.append(failed_detection(model_info))
    else:
        detections = [
            inference_single_model(model_info, original_frame)
            for model_info in models
        ]

    # Draw the results
    if len(detections) == 1 and use_plot_for_single:
        detection = detections[0]
        plotted_frame = original_frame.copy()
        if detection['success'] and detection['raw_result'] is not None:
            try:
                # Ultralytics Results.plot() draws on the given BGR numpy image
                # and returns a BGR numpy array, so no colour conversion is
                # needed. A copy is passed because the annotator may draw on
                # the array in place, which would dirty the fallback frame.
                plotted_frame = detection['raw_result'].plot(
                    img=original_frame.copy(),
                    conf=True,
                    labels=True,
                    boxes=True,
                    line_width=2
                )
            except Exception as e:
                print(f"使用YOLO plot()绘制失败: {str(e)}")
                # With a single model, skip custom drawing and return the clean frame.
                plotted_frame = original_frame.copy()
    elif len(detections) > 0:
        # Several models: use the custom multi-colour drawing.
        plotted_frame = plot_custom_results(original_frame, detections)
    else:
        plotted_frame = original_frame.copy()

    return plotted_frame, detections
|
2025-12-11 13:41:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-02-05 10:39:03 +08:00
|
|
|
|
def plot_custom_results(frame: np.ndarray, detections: List[dict]) -> np.ndarray:
    """
    Draw the detections of several models onto a copy of ``frame``.

    Each detection dict is drawn in its own colour, chosen by its position in
    ``detections``. Boxes are read from the 'boxes', 'confidences',
    'class_ids' and 'class_names' keys. If a dict has no 'boxes' key, they are
    extracted from ``detection['raw_result'].boxes`` instead. Detections that
    are not successful or have no boxes are skipped.

    Args:
        frame: Image to draw on (BGR); it is not modified.
        detections: Detection dicts, each with at least 'success' and
            'model_name'.

    Returns:
        Annotated copy of ``frame``.
    """
    result_frame = frame.copy()

    # Predefined colour palette, in BGR order as used by OpenCV drawing calls
    colors = [
        (0, 255, 0),      # green
        (255, 0, 0),      # blue
        (0, 0, 255),      # red
        (255, 255, 0),    # cyan
        (255, 0, 255),    # magenta
        (0, 255, 255),    # yellow
        (128, 0, 128),    # purple
        (0, 128, 128),    # olive
    ]
    font_scale = 0.5
    thickness = 2

    def detection_arrays(detection: dict):
        """Return (boxes, confidences, class_ids, class_names) for one detection."""
        if 'boxes' in detection:
            return (detection['boxes'], detection['confidences'],
                    detection['class_ids'], detection['class_names'])
        # Fallback for records that only carry the raw result object.
        raw = detection.get('raw_result')
        raw_boxes = getattr(raw, 'boxes', None)
        if raw_boxes is None or len(raw_boxes) == 0:
            return [], [], [], []
        class_ids = raw_boxes.cls.cpu().numpy().astype(int)
        return (raw_boxes.xyxy.cpu().numpy(), raw_boxes.conf.cpu().numpy(),
                class_ids, [raw.names[int(cls_id)] for cls_id in class_ids])

    for idx, detection in enumerate(detections):
        if not detection.get('success'):
            continue
        boxes, confidences, class_ids, class_names = detection_arrays(detection)
        if len(boxes) == 0:
            continue

        color = colors[idx % len(colors)]
        model_name = detection['model_name']
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Draw every box of this model
        for box, conf, _cls_id, cls_name in zip(boxes, confidences, class_ids, class_names):
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(result_frame, (x1, y1), (x2, y2), color, 2)

            label = f"{model_name}:{cls_name} {conf:.2f}"
            (text_width, text_height), baseline = cv2.getTextSize(
                label, font, font_scale, thickness
            )

            label_height = text_height + baseline + 5
            top = max(0, y1)
            if top >= label_height:
                # Enough room: put the label on a strip just above the box.
                label_y1, label_y2 = top - label_height, top
            else:
                # Box touches the top edge: put the label just inside the box
                # so it still gets a full-height background.
                label_y1, label_y2 = top, top + label_height

            cv2.rectangle(
                result_frame,
                (x1, label_y1),
                (x1 + text_width, label_y2),
                color,
                -1
            )
            # The text baseline sits 5 px above the bottom of the label strip.
            cv2.putText(
                result_frame,
                label,
                (x1, label_y2 - 5),
                font,
                font_scale,
                (255, 255, 255),
                thickness
            )

    return result_frame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 简化版本,用于快速集成
|
|
|
|
|
|
def multi_model_inference_simple(
    models: List[dict],
    frame: np.ndarray,
    confidence_threshold: float = 0.25
) -> Tuple[np.ndarray, List[dict]]:
    """
    Simplified multi-model inference: serial, with plot() for a single model.

    Args:
        models: Loaded models; each element is a dict holding the model object
            under 'model', with optional 'conf_thres', 'iou_thres', 'imgsz',
            'device', 'name', 'id' and 'tags' entries.
        frame: Video frame (BGR). It is not modified.
        confidence_threshold: Confidence threshold used for any model dict
            without its own 'conf_thres'.

    Returns:
        (annotated frame, list of detection dicts, one per model in order).
        Each dict has 'model_idx', 'model_name', 'raw_result' and 'success',
        plus 'boxes', 'confidences', 'class_ids', 'class_names' and 'tags'.
        The box keys are what plot_custom_results reads when several models
        are drawn; without them it raised KeyError.
    """
    if len(models) == 0:
        return frame, []

    original_frame = frame.copy()
    detections = []

    # Serial inference
    for model_info in models:
        try:
            model = model_info['model']
            results = model.predict(
                source=original_frame,
                imgsz=model_info.get('imgsz', 640),
                conf=model_info.get('conf_thres', confidence_threshold),
                iou=model_info.get('iou_thres', 0.45),
                verbose=False,
                device=model_info.get('device', 'cpu'),
                max_det=300
            )

            result = results[0] if len(results) > 0 else None
            has_boxes = (
                result is not None
                and result.boxes is not None
                and len(result.boxes) > 0
            )
            if has_boxes:
                class_ids = result.boxes.cls.cpu().numpy().astype(int)
                boxes = result.boxes.xyxy.cpu().numpy()
                confidences = result.boxes.conf.cpu().numpy()
                class_names = [result.names[int(cls_id)] for cls_id in class_ids]
            else:
                boxes = np.array([])
                confidences = np.array([])
                class_ids = np.array([])
                class_names = []

            detections.append({
                'model_idx': model_info.get('id', 0),
                'model_name': model_info.get('name', f"model_{model_info.get('id', 'unknown')}"),
                'boxes': boxes,
                'confidences': confidences,
                'class_ids': class_ids,
                'class_names': class_names,
                'tags': model_info.get('tags', []),
                'raw_result': result,
                'success': has_boxes
            })
        except Exception as e:
            print(f"模型推理失败: {str(e)}")
            detections.append({
                'model_idx': model_info.get('id', 0),
                'model_name': model_info.get('name', 'unknown'),
                'boxes': np.array([]),
                'confidences': np.array([]),
                'class_ids': np.array([]),
                'class_names': [],
                'tags': model_info.get('tags', []),
                'raw_result': None,
                'success': False
            })

    # Draw the results
    plotted_frame = original_frame.copy()
    if len(models) == 1:
        # Single model: use the result's own plot()
        detection = detections[0]
        if detection['success'] and detection['raw_result'] is not None:
            try:
                # Ultralytics Results.plot() returns a BGR numpy array for a BGR
                # numpy input, so no colour conversion is applied. A copy is
                # passed because the annotator may draw on the array in place.
                plotted_frame = detection['raw_result'].plot(
                    img=original_frame.copy(),
                    conf=True,
                    labels=True,
                    boxes=True
                )
            except Exception as e:
                print(f"使用YOLO plot()绘制失败: {str(e)}")
                plotted_frame = original_frame.copy()
    else:
        # Several models: use the custom multi-colour drawing
        plotted_frame = plot_custom_results(original_frame, detections)

    return plotted_frame, detections
|