# Multi-model detection inference and result-plotting helpers.
import numpy as np
|
||
from typing import List, Tuple
|
||
import cv2
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
|
||
def _detection_record(model_info: dict, result=None) -> dict:
    """Build the per-model detection dict handed back to callers.

    Args:
        model_info: Model configuration dict; 'id', 'name' and 'tags' are read.
        result: First element of the list returned by ``model.predict`` or
            None when inference failed or returned nothing.

    Returns:
        Dict with keys 'model_idx', 'model_name', 'boxes', 'confidences',
        'class_ids', 'class_names', 'tags', 'raw_result' and 'success'.
        'success' is True only when ``result`` carries at least one box, so a
        clean run with zero detections is reported as ``success=False``.
    """
    has_boxes = (
        result is not None
        and result.boxes is not None
        and len(result.boxes) > 0
    )

    if has_boxes:
        # NOTE(review): assumes Ultralytics-style tensors exposing
        # .cpu().numpy() -- confirm for other model backends.
        boxes = result.boxes.xyxy.cpu().numpy()
        confidences = result.boxes.conf.cpu().numpy()
        class_ids = result.boxes.cls.cpu().numpy().astype(int)
        class_names = [result.names[int(cls_id)] for cls_id in class_ids]
    else:
        boxes = np.array([])
        confidences = np.array([])
        class_ids = np.array([])
        class_names = []

    return {
        'model_idx': model_info.get('id', 0),
        'model_name': model_info.get(
            'name', f"model_{model_info.get('id', 'unknown')}"
        ),
        'boxes': boxes,
        'confidences': confidences,
        'class_ids': class_ids,
        'class_names': class_names,
        'tags': model_info.get('tags', []),
        'raw_result': result,
        'success': has_boxes,
    }


def _run_single_model(model_info: dict, img: np.ndarray, default_conf: float) -> dict:
    """Run one model on ``img`` and return its detection record; never raises."""
    try:
        model = model_info['model']
        # Per-model settings win over the defaults.
        results = model.predict(
            source=img,
            imgsz=model_info.get('imgsz', 640),
            conf=model_info.get('conf_thres', default_conf),
            iou=model_info.get('iou_thres', 0.45),
            verbose=False,
            device=model_info.get('device', 'cpu'),
            max_det=300,
        )
        result = results[0] if len(results) > 0 else None
        return _detection_record(model_info, result)
    except Exception as e:
        # Best effort: one broken model must not abort the other models.
        print(f"模型 {model_info.get('name', 'unknown')} 推理失败: {str(e)}")
        return _detection_record(model_info)


def multi_model_inference(
    models: List[dict],
    frame: np.ndarray,
    confidence_threshold: float = 0.25,
    parallel: bool = True,
    use_plot_for_single: bool = True
) -> Tuple[np.ndarray, List[dict]]:
    """Run several detection models on one frame and draw their results.

    Args:
        models: Loaded models. Each element is a dict with a required 'model'
            entry (an object with a ``predict`` method) and optional
            'conf_thres', 'iou_thres', 'imgsz', 'device', 'name', 'id' and
            'tags' entries.
        frame: Video frame (BGR). It is never modified.
        confidence_threshold: Confidence threshold used for models that do
            not define their own 'conf_thres'.
        parallel: Run the models in a thread pool (at most 4 workers) when
            more than one model is given.
        use_plot_for_single: With exactly one model, draw with the result
            object's own ``plot()`` instead of the custom drawing routine.

    Returns:
        ``(annotated_frame, detections)``. ``detections`` holds exactly one
        record per model, in the same order as ``models``. With an empty
        ``models`` list the input frame itself is returned together with an
        empty list.
    """
    if len(models) == 0:
        return frame, []

    # Private snapshot so later changes by the caller cannot affect inference.
    original_frame = frame.copy()

    if parallel and len(models) > 1:
        max_workers = min(len(models), 4)
        detections = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    _run_single_model, model_info, original_frame, confidence_threshold
                )
                for model_info in models
            ]
            # Collect in submission order rather than completion order: the
            # drawing colour is chosen by list position, so a stable order
            # keeps every model's colour stable from frame to frame.
            for model_info, future in zip(models, futures):
                try:
                    detections.append(future.result())
                except Exception as e:
                    model_name = model_info.get('name', 'unknown')
                    print(f"并行推理任务失败 - 模型 {model_name}: {str(e)}")
                    # Keep one record per model even when the task failed.
                    detections.append(_detection_record(model_info))
    else:
        detections = [
            _run_single_model(model_info, original_frame, confidence_threshold)
            for model_info in models
        ]

    if len(detections) == 1 and use_plot_for_single:
        detection = detections[0]
        if detection['success'] and detection['raw_result'] is not None:
            try:
                # Ultralytics documents the array returned by plot() as BGR,
                # so it is returned as is; converting it again would swap the
                # red and blue channels of the whole frame.
                annotated = detection['raw_result'].plot(
                    img=original_frame,
                    conf=True,
                    labels=True,
                    boxes=True,
                    line_width=2,
                )
                return annotated, detections
            except Exception as e:
                print(f"使用YOLO plot()绘制失败: {str(e)}")
        # Nothing detected or plot() failed: hand back an unannotated copy.
        return original_frame.copy(), detections

    if any(detection['success'] for detection in detections):
        return plot_custom_results(original_frame, detections), detections

    # No model produced a box, so there is nothing to draw.
    return original_frame.copy(), detections
|
||
|
||
|
||
def _draw_labeled_box(canvas: np.ndarray, box, label: str, color: Tuple[int, int, int]) -> None:
    """Draw one bounding box and its caption onto ``canvas`` in place."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2

    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

    (text_width, text_height), baseline = cv2.getTextSize(
        label, font, font_scale, thickness
    )
    caption_height = text_height + baseline + 5

    # The caption normally sits directly above the box. When the box is too
    # close to the top edge there is no room for it, so it is moved just
    # inside the box; otherwise the background would be clipped and the text
    # would be left floating over the image.
    if y1 >= caption_height:
        caption_top = y1 - caption_height
    else:
        caption_top = max(0, y1)

    if text_width > 0:
        cv2.rectangle(
            canvas,
            (x1, caption_top),
            (x1 + text_width, caption_top + caption_height),
            color,
            -1,
        )

    cv2.putText(
        canvas,
        label,
        (x1, caption_top + text_height + baseline),
        font,
        font_scale,
        (255, 255, 255),
        thickness,
    )


def plot_custom_results(frame: np.ndarray, detections: List[dict]) -> np.ndarray:
    """Draw the boxes of several models onto a copy of ``frame``.

    Args:
        frame: Image to annotate; it is copied, never modified.
        detections: Per-model detection dicts. A record is drawn only when
            its 'success' entry is truthy and its 'boxes' entry is non-empty;
            such a record must also provide 'model_name', 'confidences' and
            'class_names'. Records that lack 'boxes' or 'success' are skipped.

    Returns:
        The annotated copy. Each model is drawn in the colour selected by its
        position in ``detections`` (cycling through 8 colours), and every box
        is captioned ``"<model_name>:<class_name> <confidence>"``.
    """
    result_frame = frame.copy()

    # Colour tuples in the channel order of the frame (BGR for OpenCV frames).
    colors = (
        (0, 255, 0),     # green
        (255, 0, 0),     # blue
        (0, 0, 255),     # red
        (255, 255, 0),   # cyan
        (255, 0, 255),   # magenta
        (0, 255, 255),   # yellow
        (128, 0, 128),   # purple
        (0, 128, 128),   # olive
    )

    for idx, detection in enumerate(detections):
        boxes = detection.get('boxes')
        if not detection.get('success') or boxes is None or len(boxes) == 0:
            continue

        color = colors[idx % len(colors)]
        model_name = detection['model_name']

        for box, conf, cls_name in zip(
            boxes, detection['confidences'], detection['class_names']
        ):
            label = f"{model_name}:{cls_name} {conf:.2f}"
            _draw_labeled_box(result_frame, box, label, color)

    return result_frame
|
||
|
||
|
||
# Simplified variant for quick integration
|
||
def multi_model_inference_simple(
    models: List[dict],
    frame: np.ndarray,
    confidence_threshold: float = 0.25
) -> Tuple[np.ndarray, List[dict]]:
    """Run the models one after another on ``frame`` and draw the results.

    Convenience wrapper around ``multi_model_inference`` with
    ``parallel=False`` and ``use_plot_for_single=True``. Delegating keeps the
    detection records complete ('boxes', 'confidences', 'class_ids',
    'class_names' and 'tags' in addition to 'model_idx', 'model_name',
    'raw_result' and 'success'), which the custom drawing routine used for
    several models requires.

    Args:
        models: Loaded models, one dict per model with at least a 'model'
            entry.
        frame: Video frame; it is never modified.
        confidence_threshold: Confidence threshold for models that do not
            define their own 'conf_thres'.

    Returns:
        ``(annotated_frame, detections)`` with one detection record per
        model, in the order of ``models``. With an empty ``models`` list the
        input frame itself is returned together with an empty list.
    """
    if len(models) == 0:
        return frame, []

    return multi_model_inference(
        models,
        frame,
        confidence_threshold=confidence_threshold,
        parallel=False,
        use_plot_for_single=True,
    )