312 lines
10 KiB
Python
312 lines
10 KiB
Python
import logging
|
||
|
||
import numpy as np
|
||
import cv2
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
import os
|
||
from collections import defaultdict
|
||
|
||
|
||
class OptimizedDetectionRenderer:
    """Draw object-detection results from one or more models onto BGR frames.

    Each detection is a dict. The methods below read these keys:

    * ``'box'``        -- ``[x1, y1, x2, y2]`` corner coordinates (required).
    * ``'color'``      -- BGR colour used for the box and label (required
      by :meth:`draw_detection`).
    * ``'confidence'`` -- score used for ranking and shown in the label.
    * ``'class_name'`` / ``'model_name'`` -- optional label parts.
    * ``'model_id'``   -- key used to look up per-model drawing settings.
    """

    # Font files tried, in order, after the user-supplied ``font_path``.
    _FALLBACK_FONT_PATHS = (
        "simhei.ttf",
        "msyh.ttc",
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/Deng.ttf",
        "C:/Windows/Fonts/simsun.ttc",
        "/System/Library/Fonts/PingFang.ttc",  # macOS
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
    )

    # Two labels of the same class whose anchor points are closer than this
    # many pixels are considered the same label and drawn only once.
    _LABEL_MERGE_DISTANCE = 50

    def __init__(self, font_path=None):
        """Create a renderer.

        Args:
            font_path: Optional path to a TrueType font that is tried before
                the built-in list of fallback fonts.
        """
        self.font_cache = {}
        self.font_path = font_path
        # (class_name, x, y) anchors of the labels drawn for the current frame.
        self.drawn_labels = set()
        # Per-position detection cache; not used by any method in this class.
        self.detection_cache = defaultdict(list)

    def _get_font(self, size):
        """Return a font object for ``size``, loading it at most once.

        The first existing, loadable candidate file wins. If none can be
        loaded, Pillow's default font is returned (and cached) instead.
        """
        cached = self.font_cache.get(size)
        if cached is not None:
            return cached

        candidates = []
        if self.font_path:
            candidates.append(self.font_path)
        candidates.extend(self._FALLBACK_FONT_PATHS)

        font = None
        for path in candidates:
            if not os.path.exists(path):
                continue
            try:
                font = ImageFont.truetype(path, size, encoding="utf-8")
            except (OSError, ValueError) as exc:
                # Unreadable file or invalid size: try the next candidate.
                logging.debug("Could not load font %s: %s", path, exc)
                continue
            logging.info("加载字体: %s", path)
            break

        if font is None:
            font = ImageFont.load_default()

        self.font_cache[size] = font
        return font

    def compute_iou(self, box1, box2):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns ``0.0`` when the boxes do not overlap (touching edges count
        as no overlap).
        """
        left = max(box1[0], box2[0])
        top = max(box1[1], box2[1])
        right = min(box1[2], box2[2])
        bottom = min(box1[3], box2[3])

        if left >= right or top >= bottom:
            return 0.0

        intersection = (right - left) * (bottom - top)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

        # The intersection is strictly positive here, so the union is too.
        return intersection / (area1 + area2 - intersection)

    def filter_duplicate_detections(self, detections, iou_threshold=0.3):
        """Suppress overlapping detections, keeping the most confident one.

        Detections are visited in order of descending confidence; one is kept
        only if its IoU with every already-kept detection is at most
        ``iou_threshold``. Overlap is checked regardless of class name.

        Args:
            detections: Iterable of detection dicts (needs ``'box'``; a
                missing ``'confidence'`` is treated as 0).
            iou_threshold: IoU above which two boxes count as the same object.

        Returns:
            A new list of the kept detections, highest confidence first. The
            input is not modified.
        """
        ranked = sorted(
            detections,
            key=lambda det: det.get('confidence', 0),
            reverse=True,
        )

        kept = []
        for det in ranked:
            # Everything already kept has confidence >= det's, so an
            # overlapping candidate is always the one to drop.
            overlaps_kept = any(
                self.compute_iou(det['box'], other['box']) > iou_threshold
                for other in kept
            )
            if not overlaps_kept:
                kept.append(det)

        return kept

    def draw_text_with_background(self, img, text, position, font_size=20,
                                  text_color=(0, 255, 0), bg_color=None, padding=5):
        """Draw ``text`` on a BGR image, optionally on a filled background.

        Rendering goes through Pillow so that non-ASCII text can use a
        TrueType font. If that fails for any reason, the text is drawn with
        ``cv2.putText`` instead.

        Args:
            img: BGR image array.
            text: String to render.
            position: ``(x, y)`` top-left corner of the text box.
            font_size: Font size passed to the font loader.
            text_color: BGR text colour.
            bg_color: BGR background colour, or ``None`` for no background.
            padding: Pixels of background added around the text.

        Returns:
            The annotated BGR image. On the Pillow path this is a new array;
            on the fallback path ``img`` is modified in place and returned.
        """
        try:
            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)
            font = self._get_font(font_size)

            if hasattr(font, 'getbbox'):  # newer Pillow
                left, top, right, bottom = font.getbbox(text)
            else:  # older Pillow
                left, top = 0, 0
                right, bottom = font.getsize(text)
            text_width = right - left
            text_height = bottom - top

            x, y = position
            if bg_color:
                # Colours arrive as BGR; Pillow expects RGB.
                draw.rectangle(
                    [x - padding, y - padding,
                     x + text_width + padding, y + text_height + padding],
                    fill=tuple(bg_color[::-1]),
                )

            # getbbox() is relative to the drawing origin, so shift the
            # origin by its left/top offset to make the glyphs start exactly
            # at ``position`` and stay inside the background rectangle.
            draw.text((x - left, y - top), text, font=font,
                      fill=tuple(text_color[::-1]))

            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

        except Exception as e:
            # Best-effort rendering: never let a label failure break a frame.
            logging.warning("文本渲染失败: %s,使用后备方案", e)
            fx, fy = int(position[0]), int(position[1])
            if bg_color:
                cv2.rectangle(
                    img,
                    (fx - padding, fy - padding),
                    (fx + len(text) * 10 + padding, fy + font_size + padding),
                    bg_color, cv2.FILLED,
                )
            # cv2.putText takes the bottom-left corner of the text, so move
            # the origin down by the text height to land inside the box.
            cv2.putText(img, text, (fx, fy + font_size),
                        cv2.FONT_HERSHEY_SIMPLEX, font_size / 30, text_color, 2)
            return img

    def draw_detection(self, frame, detection_info, model_config,
                       show_model_name=True, show_confidence=True):
        """Draw one detection (box plus label) on ``frame``.

        The label is skipped when a label of the same class has already been
        drawn within ``_LABEL_MERGE_DISTANCE`` pixels during the current
        frame; the bounding box is always drawn.

        Args:
            frame: BGR image array.
            detection_info: Detection dict; ``'box'`` and ``'color'`` are
                required.
            model_config: Dict with optional ``'line_width'`` (default 2) and
                ``'font_size'`` (default 20).
            show_model_name: Include ``'model_name'`` in the label if present.
            show_confidence: Include ``'confidence'`` in the label if present.

        Returns:
            The annotated frame.
        """
        x1, y1, x2, y2 = map(int, detection_info['box'])
        color = tuple(detection_info['color'])

        cv2.rectangle(frame, (x1, y1), (x2, y2), color,
                      model_config.get('line_width', 2))

        class_name = detection_info.get('class_name')

        label_parts = []
        if show_model_name and 'model_name' in detection_info:
            label_parts.append(str(detection_info['model_name']))
        if class_name is not None:
            label_parts.append(str(class_name))
        if show_confidence and 'confidence' in detection_info:
            label_parts.append(f"{detection_info['confidence']:.2f}")
        label = ":".join(label_parts)

        # Anchors are stored as tuples rather than "name_x_y" strings so that
        # class names containing underscores cannot break the lookup.
        max_dist_sq = self._LABEL_MERGE_DISTANCE ** 2
        already_drawn = any(
            drawn_class == class_name
            and (x1 - drawn_x) ** 2 + (y1 - drawn_y) ** 2 < max_dist_sq
            for drawn_class, drawn_x, drawn_y in self.drawn_labels
        )

        if not already_drawn:
            # Prefer the space above the box; fall back to below it when the
            # box is too close to the top edge.
            label_y = y1 - 20
            if label_y < 20:
                label_y = y2 + 20

            frame = self.draw_text_with_background(
                frame, label,
                (x1, label_y),
                font_size=model_config.get('font_size', 20),
                text_color=(255, 255, 255),
                bg_color=color,
                padding=3,
            )
            self.drawn_labels.add((class_name, x1, y1))

        return frame

    def draw_all_detections(self, frame, all_detections, model_configs,
                            enable_nms=True, show_model_name=True):
        """Draw every detection on ``frame`` (main entry point).

        Args:
            frame: BGR image array.
            all_detections: List of detection dicts from all models.
            model_configs: Mapping of ``model_id`` to that model's drawing
                settings; unknown ids get default settings.
            enable_nms: Drop overlapping detections first.
            show_model_name: Include the model name in each label.

        Returns:
            The annotated frame (the input frame if there is nothing to draw).
        """
        # Label de-duplication is per frame.
        self.drawn_labels.clear()
        logging.info('帧绘制:%s', model_configs)
        if not all_detections:
            return frame

        model_configs = model_configs or {}

        if enable_nms:
            detections = self.filter_duplicate_detections(all_detections)
        else:
            detections = all_detections

        # Draw low-confidence results first so confident ones end up on top.
        for detection in sorted(detections, key=lambda det: det.get('confidence', 0)):
            model_config = model_configs.get(detection.get('model_id'), {})
            frame = self.draw_detection(frame, detection, model_config, show_model_name)

        return frame

    def put_text_simple(self, img, text, position, font_size=20, color=(0, 255, 0)):
        """Draw ``text`` at ``position`` without background or padding."""
        return self.draw_text_with_background(
            img, text, position, font_size, color, None, 0
        )
|
||
|
||
|
||
# 使用示例
|
||
def main():
    """Demo: render three hand-written detections on a blank frame and show it."""
    # Blank 640x480 test image standing in for a real frame.
    canvas = np.zeros((480, 640, 3), dtype=np.uint8)

    # Per-model drawing settings, keyed by model id.
    per_model_settings = {
        'yolov8n': {'line_width': 2, 'font_size': 20},
        'yolov8s': {'line_width': 2, 'font_size': 18},
    }

    # Simulated results from two models.
    sample_detections = [
        dict(
            model_id='yolov8n',
            model_name='YOLOv8',
            class_id=0,
            class_name='person',
            confidence=0.85,
            box=[100, 100, 200, 300],
            reliability=0.9,
            color=(0, 255, 0),
        ),
        dict(
            model_id='yolov8s',
            model_name='YOLOv8s',
            class_id=0,
            class_name='person',
            confidence=0.75,
            box=[110, 110, 210, 310],  # overlaps the first detection
            reliability=0.8,
            color=(0, 0, 255),
        ),
        dict(
            model_id='yolov8n',
            model_name='YOLOv8',
            class_id=2,
            class_name='car',
            confidence=0.95,
            box=[300, 150, 450, 250],
            reliability=0.95,
            color=(255, 0, 0),
        ),
    ]

    painter = OptimizedDetectionRenderer()
    annotated = painter.draw_all_detections(
        canvas,
        sample_detections,
        per_model_settings,
        enable_nms=True,  # de-duplicate overlapping boxes
        show_model_name=True,
    )

    # Show the result until a key is pressed.
    cv2.imshow('Detections', annotated)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
|
||
|
||
|
||
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()