Compare commits
3 Commits
1b66ec5dbc
...
4dd77023a2
| Author | SHA1 | Date |
|---|---|---|
|
|
4dd77023a2 | |
|
|
8349890942 | |
|
|
9754cd45eb |
|
|
@ -13,3 +13,5 @@ yolo_detection.log
|
|||
*.enc
|
||||
*.jpg
|
||||
*.pyc
|
||||
encrypted_models
|
||||
temp_uploads
|
||||
Binary file not shown.
|
|
@ -19,7 +19,7 @@ from ultralytics import YOLO
|
|||
from _minio import MinioUploader
|
||||
from log import logger
|
||||
from global_data import gd
|
||||
from detection_render import OptimizedDetectionRenderer
|
||||
from detection_render import multi_model_inference
|
||||
from mandatory_model_crypto import MandatoryModelEncryptor
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,6 @@ class DetectionThread(threading.Thread):
|
|||
self.reconnect_attempts = 0
|
||||
self.last_frame_time = time.time()
|
||||
self.fps = 0
|
||||
self.detections_count = 0
|
||||
self.stop_event = threading.Event()
|
||||
self.last_log_time = time.time()
|
||||
self.daemon = True
|
||||
|
|
@ -298,7 +297,6 @@ class DetectionThread(threading.Thread):
|
|||
self.key_verification_results = {}
|
||||
|
||||
# 绘制结果
|
||||
self.renderer = OptimizedDetectionRenderer()
|
||||
logger.info(f"检测线程初始化完成: {self.taskname}")
|
||||
|
||||
def _is_windows(self):
|
||||
|
|
@ -344,7 +342,6 @@ class DetectionThread(threading.Thread):
|
|||
|
||||
# 准备标签
|
||||
tags = model_config.get('tags', {})
|
||||
|
||||
# 存储模型信息
|
||||
model_info = {
|
||||
'model': model,
|
||||
|
|
@ -715,96 +712,10 @@ class DetectionThread(threading.Thread):
|
|||
|
||||
logger.info("所有模型预热完成")
|
||||
|
||||
def multi_model_inference(self, frame):
|
||||
def _multi_model_inference(self, frame):
|
||||
"""多模型推理(每个模型独立标签和置信度)"""
|
||||
all_detections = {}
|
||||
annotated_frame = frame.copy()
|
||||
model_config_list = {}
|
||||
model_detections_list = []
|
||||
for model_info in self.models:
|
||||
model = model_info['model']
|
||||
model_config = model_info['config']
|
||||
model_tags = model_info['tags'] # 模型独立标签
|
||||
model_name = model_info['name']
|
||||
model_id = model_info['id']
|
||||
model_config_list[model_id] = model_config
|
||||
try:
|
||||
# 使用模型特定的参数
|
||||
conf_thres = model_config.get('conf_thres', 0.25)
|
||||
iou_thres = model_config.get('iou_thres', 0.45)
|
||||
imgsz = model_config.get('imgsz', 640)
|
||||
device = model_config.get('device', 'cpu')
|
||||
half = model_config.get('half', False) and 'cuda' in device
|
||||
|
||||
with torch.no_grad():
|
||||
results = model.predict(
|
||||
frame,
|
||||
stream=False,
|
||||
verbose=False,
|
||||
conf=conf_thres,
|
||||
iou=iou_thres,
|
||||
imgsz=imgsz,
|
||||
device=device,
|
||||
half=half
|
||||
)
|
||||
|
||||
# 处理结果
|
||||
result = results[0]
|
||||
detections = []
|
||||
|
||||
if result.boxes is not None:
|
||||
for box in result.boxes:
|
||||
class_id = int(box.cls)
|
||||
confidence = float(box.conf)
|
||||
|
||||
# 从模型特定标签中获取类别信息
|
||||
class_key = str(class_id)
|
||||
class_info = model_tags.get(class_key, None)
|
||||
|
||||
if class_info is None:
|
||||
continue
|
||||
|
||||
select = class_info.get('select', False)
|
||||
if not select:
|
||||
continue
|
||||
|
||||
# 使用标签中的reliability作为置信度阈值
|
||||
tag_reliability = class_info.get('reliability', conf_thres)
|
||||
if confidence < tag_reliability:
|
||||
continue
|
||||
|
||||
if class_info is None:
|
||||
continue
|
||||
color = class_info.get('color',None)
|
||||
if color is None:
|
||||
color = [255, 255, 0]
|
||||
detection_info = {
|
||||
'model_id': model_id,
|
||||
'model_name': model_name,
|
||||
'class_id': class_id,
|
||||
'class_name': class_info.get('name', f'class_{class_id}'),
|
||||
'confidence': confidence,
|
||||
'box': box.xyxy[0].tolist(),
|
||||
'reliability': tag_reliability,
|
||||
'color': color
|
||||
}
|
||||
|
||||
detections.append(detection_info)
|
||||
model_detections_list.append(detection_info)
|
||||
|
||||
all_detections[model_name] = detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型 {model_name} 推理失败: {str(e)}")
|
||||
all_detections[model_name] = []
|
||||
_frame = self.renderer.draw_all_detections(
|
||||
annotated_frame,
|
||||
model_detections_list,
|
||||
model_config_list,
|
||||
enable_nms=True, # 启用NMS去重
|
||||
show_model_name=True
|
||||
)
|
||||
return _frame, all_detections
|
||||
frame_drawn, detections = multi_model_inference(self.models, frame)
|
||||
return frame_drawn, detections
|
||||
|
||||
def should_skip_frame(self, start_time):
|
||||
"""判断是否应该跳过当前帧"""
|
||||
|
|
@ -1035,22 +946,25 @@ class DetectionThread(threading.Thread):
|
|||
time_str = now.strftime("%H:%M:%S")
|
||||
|
||||
# 合并所有模型的检测结果
|
||||
combined_detections = []
|
||||
for model_name, detections in all_detections.items():
|
||||
combined_detections.extend(detections)
|
||||
|
||||
# 简化数据
|
||||
simplified_data = []
|
||||
for d in combined_detections:
|
||||
simplified_data.append({
|
||||
'model_id': d['model_id'],
|
||||
'model_name': d['model_name'],
|
||||
'class_id': d['class_id'],
|
||||
'class_name': d['class_name'],
|
||||
'confidence': round(d['confidence'], 2),
|
||||
'box': [round(float(c), 1) for c in d['box']],
|
||||
'color': d['color']
|
||||
})
|
||||
all_detections_send = []
|
||||
for det in all_detections:
|
||||
print(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
|
||||
det_detections_res = []
|
||||
for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
|
||||
det.class_ids, det.class_names):
|
||||
det_detections_res.append(
|
||||
{
|
||||
'class_id': cls_id,
|
||||
'class_name': cls_name,
|
||||
'box': box,
|
||||
'conf': conf,
|
||||
}
|
||||
)
|
||||
det_detection = {
|
||||
'count': len(det.boxes),
|
||||
'detections': det_detections_res,
|
||||
}
|
||||
all_detections_send.append(det_detection)
|
||||
|
||||
# 添加流稳定性信息
|
||||
stream_info = {
|
||||
|
|
@ -1062,7 +976,7 @@ class DetectionThread(threading.Thread):
|
|||
|
||||
self.config['socketIO'].emit('detection_results', {
|
||||
'task_id': getattr(self, 'task_id', 'unknown'),
|
||||
'detections': simplified_data,
|
||||
'detections': all_detections_send,
|
||||
'timestamp': time.time_ns() // 1000000,
|
||||
'fps': round(self.last_fps, 1),
|
||||
'frame_count': self.frame_count,
|
||||
|
|
@ -1227,23 +1141,21 @@ class DetectionThread(threading.Thread):
|
|||
if self.should_skip_frame(start_time):
|
||||
self.frame_count += 1
|
||||
continue
|
||||
|
||||
# 多模型推理
|
||||
annotated_frame, all_detections = self.multi_model_inference(frame)
|
||||
self.detections_count = sum(len(dets) for dets in all_detections.values())
|
||||
|
||||
annotated_frame, model_detections = self._multi_model_inference(frame)
|
||||
# 推流处理(Windows优化)
|
||||
if self.enable_push:
|
||||
if not self.push_frame_to_task_streamer(annotated_frame):
|
||||
# 推流失败,但继续处理,避免影响检测
|
||||
pass
|
||||
|
||||
# 上传处理
|
||||
self.handle_upload(annotated_frame, all_detections, current_time)
|
||||
|
||||
# WebSocket发送
|
||||
if current_time - self.last_log_time >= 1:
|
||||
self.send_to_websocket(all_detections)
|
||||
# # WebSocket发送
|
||||
start = time.time()
|
||||
self.send_to_websocket(model_detections)
|
||||
logger.info(f'startTime:{start},endTime:{time.time()},时间差:{time.time()-start}')
|
||||
# # 上传处理
|
||||
# self.handle_upload(annotated_frame, model_detections, current_time)
|
||||
self.last_log_time = current_time
|
||||
|
||||
self.frame_count += 1
|
||||
|
|
|
|||
|
|
@ -1,312 +1,539 @@
|
|||
import logging
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import concurrent.futures
|
||||
from typing import List, Dict, Tuple, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
class OptimizedDetectionRenderer:
|
||||
def __init__(self, font_path=None):
|
||||
self.font_cache = {}
|
||||
self.font_path = font_path
|
||||
self.drawn_labels = set() # 记录已经绘制过的标签
|
||||
self.detection_cache = defaultdict(list) # 缓存同一位置的检测结果
|
||||
|
||||
def _get_font(self, size):
|
||||
"""获取字体对象,带缓存"""
|
||||
if size in self.font_cache:
|
||||
return self.font_cache[size]
|
||||
|
||||
font_paths = []
|
||||
if self.font_path and os.path.exists(self.font_path):
|
||||
font_paths.append(self.font_path)
|
||||
|
||||
# 添加常用字体路径
|
||||
font_paths.extend([
|
||||
"simhei.ttf",
|
||||
"msyh.ttc",
|
||||
"C:/Windows/Fonts/simhei.ttf",
|
||||
"C:/Windows/Fonts/msyh.ttc",
|
||||
"C:/Windows/Fonts/Deng.ttf",
|
||||
"C:/Windows/Fonts/simsun.ttc",
|
||||
"/System/Library/Fonts/PingFang.ttc", # macOS
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" # Linux
|
||||
])
|
||||
|
||||
# 尝试使用PIL绘制中文
|
||||
font = None
|
||||
for path in font_paths:
|
||||
if os.path.exists(path):
|
||||
FONT_PATHS = [
|
||||
"simhei.ttf", # 黑体
|
||||
"/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", # 文泉驿微米黑
|
||||
"C:/Windows/Fonts/simhei.ttf", # Windows黑体
|
||||
"/System/Library/Fonts/PingFang.ttc" # macOS苹方
|
||||
]
|
||||
|
||||
for path in FONT_PATHS:
|
||||
try:
|
||||
font = ImageFont.truetype(path, size, encoding="utf-8")
|
||||
print(f"加载字体: {path}")
|
||||
font = ImageFont.truetype(path, 20)
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if font is None:
|
||||
try:
|
||||
font = ImageFont.load_default()
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
self.font_cache[size] = font
|
||||
return font
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""模型配置数据类"""
|
||||
model: Any
|
||||
config: Any
|
||||
tags: Optional[Dict[str, Dict]] = None
|
||||
name: str = "未命名"
|
||||
id: Any = None
|
||||
device: str = "cuda:0"
|
||||
conf_thres: float = 0.25
|
||||
iou_thres: float = 0.45
|
||||
half: bool = False
|
||||
key_valid: Any = None
|
||||
model_hash: Any = None
|
||||
imgsz: Any = 640
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
"""计算两个边界框的IoU"""
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
|
||||
if x1 >= x2 or y1 >= y2:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2 - x1) * (y2 - y1)
|
||||
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||
|
||||
return intersection / (area1 + area2 - intersection)
|
||||
|
||||
def filter_duplicate_detections(self, detections, iou_threshold=0.3):
|
||||
"""过滤重复检测结果,保留置信度最高的"""
|
||||
filtered_detections = []
|
||||
|
||||
# 按置信度降序排序
|
||||
print(detections)
|
||||
sorted_detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
|
||||
|
||||
for det in sorted_detections:
|
||||
is_duplicate = False
|
||||
|
||||
for kept_det in filtered_detections:
|
||||
iou = self.compute_iou(det['box'], kept_det['box'])
|
||||
|
||||
# 如果IoU超过阈值,认为是同一目标
|
||||
if iou > iou_threshold:
|
||||
is_duplicate = True
|
||||
|
||||
# 如果是相同类别,用置信度高的
|
||||
if det['class_name'] == kept_det['class_name']:
|
||||
# 如果当前检测置信度更高,替换
|
||||
if det['confidence'] > kept_det['confidence']:
|
||||
filtered_detections.remove(kept_det)
|
||||
filtered_detections.append(det)
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
filtered_detections.append(det)
|
||||
|
||||
return filtered_detections
|
||||
|
||||
def draw_text_with_background(self, img, text, position, font_size=20,
|
||||
text_color=(0, 255, 0), bg_color=None, padding=5):
|
||||
"""绘制带背景的文本(自动调整背景大小)"""
|
||||
try:
|
||||
# 转换为PIL图像
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(img_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 获取字体
|
||||
font = self._get_font(font_size)
|
||||
|
||||
# 计算文本尺寸
|
||||
if hasattr(font, 'getbbox'): # 新版本PIL
|
||||
bbox = font.getbbox(text)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
else: # 旧版本PIL
|
||||
text_width, text_height = font.getsize(text)
|
||||
|
||||
# 计算背景位置
|
||||
x, y = position
|
||||
bg_x1 = x - padding
|
||||
bg_y1 = y - padding
|
||||
bg_x2 = x + text_width + padding
|
||||
bg_y2 = y + text_height + padding
|
||||
|
||||
# 如果提供了背景颜色,绘制背景
|
||||
if bg_color:
|
||||
# 将BGR转换为RGB
|
||||
rgb_bg_color = bg_color[::-1]
|
||||
draw.rectangle([bg_x1, bg_y1, bg_x2, bg_y2], fill=rgb_bg_color)
|
||||
|
||||
# 绘制文本
|
||||
rgb_text_color = text_color[::-1]
|
||||
draw.text(position, text, font=font, fill=rgb_text_color)
|
||||
|
||||
# 转换回OpenCV格式
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
except Exception as e:
|
||||
print(f"文本渲染失败: {e},使用后备方案")
|
||||
# 后备方案
|
||||
if bg_color:
|
||||
cv2.rectangle(img, (position[0] - padding, position[1] - padding),
|
||||
(position[0] + len(text) * 10 + padding, position[1] + font_size + padding),
|
||||
bg_color, cv2.FILLED)
|
||||
cv2.putText(img, text, position, cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_size / 30, text_color, 2)
|
||||
return img
|
||||
|
||||
def draw_detection(self, frame, detection_info, model_config,
|
||||
show_model_name=True, show_confidence=True):
|
||||
"""在帧上绘制检测结果"""
|
||||
x1, y1, x2, y2 = map(int, detection_info['box'])
|
||||
color = tuple(detection_info['color'])
|
||||
line_width = model_config.get('line_width', 2)
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, line_width)
|
||||
|
||||
# 准备标签文本
|
||||
label_parts = []
|
||||
|
||||
if show_model_name and 'model_name' in detection_info:
|
||||
label_parts.append(detection_info['model_name'])
|
||||
|
||||
if 'class_name' in detection_info:
|
||||
label_parts.append(detection_info['class_name'])
|
||||
|
||||
if show_confidence and 'confidence' in detection_info:
|
||||
label_parts.append(f"{detection_info['confidence']:.2f}")
|
||||
|
||||
label = ":".join(label_parts)
|
||||
|
||||
# 生成标签的唯一标识
|
||||
label_key = f"{detection_info['class_name']}_{x1}_{y1}"
|
||||
|
||||
# 检查是否已经绘制过相同类别的标签(在一定区域内)
|
||||
label_drawn = False
|
||||
for drawn_label in self.drawn_labels:
|
||||
drawn_class, drawn_x, drawn_y = drawn_label.split('_')
|
||||
drawn_x, drawn_y = int(drawn_x), int(drawn_y)
|
||||
|
||||
# 计算距离,如果很近且类别相同,认为已经绘制过
|
||||
distance = np.sqrt((x1 - drawn_x) ** 2 + (y1 - drawn_y) ** 2)
|
||||
if distance < 50 and detection_info['class_name'] == drawn_class:
|
||||
label_drawn = True
|
||||
break
|
||||
|
||||
if not label_drawn:
|
||||
# 计算标签位置(放在框的上方,如果上方空间不够则放在下方)
|
||||
label_y = y1 - 20
|
||||
if label_y < 20: # 如果上方空间不够
|
||||
label_y = y2 + 20
|
||||
|
||||
# 绘制带背景的标签
|
||||
frame = self.draw_text_with_background(
|
||||
frame, label,
|
||||
(x1, label_y),
|
||||
font_size=model_config.get('font_size', 20),
|
||||
text_color=(255, 255, 255),
|
||||
bg_color=color,
|
||||
padding=3
|
||||
)
|
||||
|
||||
# 记录已绘制的标签
|
||||
self.drawn_labels.add(label_key)
|
||||
|
||||
return frame
|
||||
|
||||
def draw_all_detections(self, frame, all_detections, model_configs,
|
||||
enable_nms=True, show_model_name=True):
|
||||
"""绘制所有检测结果(主入口函数)"""
|
||||
# 重置已绘制标签记录
|
||||
self.drawn_labels.clear()
|
||||
logging.info(f'帧绘制:{model_configs}')
|
||||
if not all_detections:
|
||||
return frame
|
||||
|
||||
# 如果需要,过滤重复检测
|
||||
if enable_nms:
|
||||
filtered_detections = self.filter_duplicate_detections(all_detections)
|
||||
else:
|
||||
filtered_detections = all_detections
|
||||
|
||||
# 按置信度排序,先绘制置信度低的,再绘制置信度高的
|
||||
sorted_detections = sorted(filtered_detections, key=lambda x: x.get('confidence', 0))
|
||||
|
||||
# 绘制每个检测结果
|
||||
for detection in sorted_detections:
|
||||
model_id = detection.get('model_id')
|
||||
model_config = model_configs.get(model_id, {})
|
||||
frame = self.draw_detection(frame, detection, model_config, show_model_name)
|
||||
|
||||
return frame
|
||||
|
||||
def put_text_simple(self, img, text, position, font_size=20, color=(0, 255, 0)):
|
||||
"""简化版文本绘制函数"""
|
||||
return self.draw_text_with_background(
|
||||
img, text, position, font_size, color, None, 0
|
||||
)
|
||||
def __post_init__(self):
|
||||
"""初始化后处理"""
|
||||
if self.tags is None:
|
||||
self.tags = {}
|
||||
|
||||
|
||||
# 使用示例
|
||||
def main():
|
||||
# 初始化渲染器
|
||||
renderer = OptimizedDetectionRenderer()
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""检测结果数据类"""
|
||||
model_idx: int
|
||||
model_name: str
|
||||
boxes: np.ndarray
|
||||
confidences: np.ndarray
|
||||
class_ids: np.ndarray
|
||||
class_names: List[str]
|
||||
tags: Dict[str, Dict]
|
||||
raw_result: Any = None # 新增:保存原始结果对象
|
||||
|
||||
# 模拟多模型检测结果
|
||||
detections = [
|
||||
{
|
||||
'model_id': 'yolov8n',
|
||||
'model_name': 'YOLOv8',
|
||||
'class_id': 0,
|
||||
'class_name': 'person',
|
||||
'confidence': 0.85,
|
||||
'box': [100, 100, 200, 300],
|
||||
'reliability': 0.9,
|
||||
'color': (0, 255, 0)
|
||||
},
|
||||
{
|
||||
'model_id': 'yolov8s',
|
||||
'model_name': 'YOLOv8s',
|
||||
'class_id': 0,
|
||||
'class_name': 'person',
|
||||
'confidence': 0.75,
|
||||
'box': [110, 110, 210, 310], # 与第一个重叠
|
||||
'reliability': 0.8,
|
||||
'color': (0, 0, 255)
|
||||
},
|
||||
{
|
||||
'model_id': 'yolov8n',
|
||||
'model_name': 'YOLOv8',
|
||||
'class_id': 2,
|
||||
'class_name': 'car',
|
||||
'confidence': 0.95,
|
||||
'box': [300, 150, 450, 250],
|
||||
'reliability': 0.95,
|
||||
'color': (255, 0, 0)
|
||||
}
|
||||
|
||||
class DetectionVisualizer:
|
||||
"""检测结果可视化器"""
|
||||
|
||||
# 预定义模型颜色
|
||||
MODEL_COLORS = [
|
||||
(0, 255, 0), # 绿色
|
||||
(255, 0, 0), # 蓝色
|
||||
(0, 0, 255), # 红色
|
||||
(255, 255, 0), # 青色
|
||||
(255, 0, 255), # 紫色
|
||||
(0, 255, 255), # 黄色
|
||||
(128, 0, 128), # 深紫色
|
||||
(0, 128, 128), # 橄榄色
|
||||
]
|
||||
|
||||
# 模型配置
|
||||
model_configs = {
|
||||
'yolov8n': {'line_width': 2, 'font_size': 20},
|
||||
'yolov8s': {'line_width': 2, 'font_size': 18}
|
||||
}
|
||||
def __init__(self, use_pil: bool = True):
|
||||
"""
|
||||
初始化可视化器
|
||||
|
||||
# 读取测试图像
|
||||
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
Args:
|
||||
use_pil: 是否使用PIL绘制(支持中文)
|
||||
"""
|
||||
self.use_pil = use_pil
|
||||
self.font = font
|
||||
|
||||
# 绘制所有检测结果
|
||||
frame = renderer.draw_all_detections(
|
||||
frame,
|
||||
detections,
|
||||
model_configs,
|
||||
enable_nms=True, # 启用NMS去重
|
||||
show_model_name=True
|
||||
def should_draw_detection(self, class_id: int, confidence: float,
|
||||
tags: Dict[str, Dict]) -> Tuple[bool, Optional[Tuple[int, int, int]]]:
|
||||
"""
|
||||
判断是否应该绘制检测框
|
||||
|
||||
Args:
|
||||
class_id: 类别ID
|
||||
confidence: 置信度
|
||||
tags: 标签配置字典
|
||||
|
||||
Returns:
|
||||
(是否绘制, 颜色)
|
||||
"""
|
||||
class_id_str = str(class_id)
|
||||
|
||||
# 如果标签配置为空,默认不绘制任何标签(因为需求是只绘制配置的标签)
|
||||
if not tags:
|
||||
return False, None
|
||||
|
||||
# 如果标签不在配置中,不绘制
|
||||
tag_config = tags.get(class_id_str)
|
||||
if not tag_config:
|
||||
return False, None
|
||||
|
||||
# 检查select标记
|
||||
if not tag_config.get('select', True):
|
||||
return False, None
|
||||
|
||||
# 检查置信度阈值
|
||||
reliability = tag_config.get('reliability', 0)
|
||||
if confidence < reliability:
|
||||
return False, None
|
||||
|
||||
# 获取自定义颜色
|
||||
color = tag_config.get('color')
|
||||
if color and isinstance(color, (list, tuple)) and len(color) >= 3:
|
||||
return True, tuple(color[:3])
|
||||
|
||||
return True, None
|
||||
|
||||
def draw_with_pil(self, frame: np.ndarray, detections: List[Detection],
|
||||
confidence_threshold: float) -> np.ndarray:
|
||||
"""使用PIL绘制检测结果(支持中文)"""
|
||||
# 转换到PIL格式
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
for det in detections:
|
||||
model_idx = det.model_idx
|
||||
boxes = det.boxes
|
||||
confidences = det.confidences
|
||||
class_names = det.class_names
|
||||
class_ids = det.class_ids
|
||||
tags = det.tags
|
||||
|
||||
for i, (box, conf, cls_id, cls_name) in enumerate(
|
||||
zip(boxes, confidences, class_ids, class_names)):
|
||||
|
||||
if conf < confidence_threshold:
|
||||
continue
|
||||
|
||||
should_draw, custom_color = self.should_draw_detection(
|
||||
cls_id, conf, tags)
|
||||
if not should_draw:
|
||||
continue
|
||||
|
||||
# 使用自定义颜色或模型颜色
|
||||
color = custom_color or self.MODEL_COLORS[model_idx % len(self.MODEL_COLORS)]
|
||||
|
||||
x1, y1, x2, y2 = map(int, box[:4])
|
||||
|
||||
# 标签文本
|
||||
label = f"{cls_name} {conf:.2f}"
|
||||
|
||||
# 绘制矩形框
|
||||
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
|
||||
|
||||
# 绘制标签背景
|
||||
if self.font:
|
||||
try:
|
||||
text_bbox = draw.textbbox((x1, y1 - 25), label, font=self.font)
|
||||
draw.rectangle(text_bbox, fill=color)
|
||||
draw.text((x1, y1 - 25), label, fill=(255, 255, 255), font=self.font)
|
||||
except:
|
||||
# 字体失败,回退到OpenCV
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
cv2.putText(frame, label, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
else:
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
||||
cv2.putText(frame, label, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
|
||||
# 转换回OpenCV格式
|
||||
return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
def draw_with_opencv(self, frame: np.ndarray, detections: List[Detection],
|
||||
confidence_threshold: float) -> np.ndarray:
|
||||
"""使用OpenCV绘制检测结果"""
|
||||
frame_drawn = frame.copy()
|
||||
|
||||
for det in detections:
|
||||
model_idx = det.model_idx
|
||||
model_name = det.model_name
|
||||
boxes = det.boxes
|
||||
confidences = det.confidences
|
||||
class_names = det.class_names
|
||||
class_ids = det.class_ids
|
||||
tags = det.tags
|
||||
|
||||
for i, (box, conf, cls_id, cls_name) in enumerate(
|
||||
zip(boxes, confidences, class_ids, class_names)):
|
||||
|
||||
if conf < confidence_threshold:
|
||||
continue
|
||||
|
||||
should_draw, custom_color = self.should_draw_detection(
|
||||
cls_id, conf, tags)
|
||||
if not should_draw:
|
||||
continue
|
||||
|
||||
# 使用自定义颜色或模型颜色
|
||||
color = custom_color or self.MODEL_COLORS[model_idx % len(self.MODEL_COLORS)]
|
||||
|
||||
x1, y1, x2, y2 = map(int, box[:4])
|
||||
|
||||
# 绘制矩形框
|
||||
cv2.rectangle(frame_drawn, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
# 绘制标签
|
||||
label = f"{model_name}: {cls_name} {conf:.2f}"
|
||||
|
||||
# 计算文本大小
|
||||
(text_width, text_height), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
|
||||
)
|
||||
|
||||
# 显示结果
|
||||
cv2.imshow('Detections', frame)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
# 绘制标签背景
|
||||
cv2.rectangle(frame_drawn,
|
||||
(x1, y1 - text_height - 10),
|
||||
(x1 + text_width, y1),
|
||||
color, -1)
|
||||
|
||||
# 绘制文本
|
||||
cv2.putText(frame_drawn, label, (x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
||||
|
||||
return frame_drawn
|
||||
|
||||
def draw(self, frame: np.ndarray, detections: List[Detection],
|
||||
confidence_threshold: float) -> np.ndarray:
|
||||
"""绘制检测结果"""
|
||||
if self.use_pil:
|
||||
try:
|
||||
return self.draw_with_pil(frame, detections, confidence_threshold)
|
||||
except Exception as e:
|
||||
print(f"PIL绘制失败,使用OpenCV: {e}")
|
||||
return self.draw_with_opencv(frame, detections, confidence_threshold)
|
||||
else:
|
||||
return self.draw_with_opencv(frame, detections, confidence_threshold)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
class YOLOModelWrapper:
|
||||
"""YOLO模型包装器"""
|
||||
|
||||
@staticmethod
|
||||
def infer_yolov8(model: Any, frame: np.ndarray, conf_thres: float,
|
||||
iou_thres: float, imgsz: int, half: bool, device: str):
|
||||
"""YOLOv8推理"""
|
||||
with torch.no_grad():
|
||||
results = model(frame, conf=conf_thres, verbose=False,
|
||||
iou=iou_thres, imgsz=imgsz, half=half, device=device)
|
||||
|
||||
if not results or len(results) == 0:
|
||||
return np.array([]), np.array([]), np.array([]), [], None
|
||||
|
||||
result = results[0]
|
||||
if not hasattr(result, 'boxes'):
|
||||
return np.array([]), np.array([]), np.array([]), [], None
|
||||
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
confidences = result.boxes.conf.cpu().numpy()
|
||||
class_ids = result.boxes.cls.cpu().numpy().astype(int)
|
||||
class_names = [result.names.get(cid, str(cid)) for cid in class_ids]
|
||||
|
||||
return boxes, confidences, class_ids, class_names, result
|
||||
|
||||
@staticmethod
|
||||
def infer_yolov5(model: Any, frame: np.ndarray):
|
||||
"""YOLOv5推理"""
|
||||
results = model(frame)
|
||||
|
||||
if not hasattr(results, 'xyxy'):
|
||||
return np.array([]), np.array([]), np.array([]), [], None
|
||||
|
||||
detections = results.xyxy[0].cpu().numpy()
|
||||
if len(detections) == 0:
|
||||
return np.array([]), np.array([]), np.array([]), [], None
|
||||
|
||||
boxes = detections[:, :4]
|
||||
confidences = detections[:, 4]
|
||||
class_ids = detections[:, 5].astype(int)
|
||||
|
||||
if hasattr(results, 'names'):
|
||||
class_names = [results.names.get(cid, str(cid)) for cid in class_ids]
|
||||
else:
|
||||
class_names = [str(cid) for cid in class_ids]
|
||||
|
||||
return boxes, confidences, class_ids, class_names, results
|
||||
|
||||
|
||||
def multi_model_inference(
|
||||
models_config: List[Union[Dict, ModelConfig]],
|
||||
frame: np.ndarray,
|
||||
confidence_threshold: float = 0.25,
|
||||
parallel: bool = True,
|
||||
use_pil: bool = True,
|
||||
use_plot_for_single: bool = True # 新增参数:是否对单个模型使用plot绘制
|
||||
) -> Tuple[np.ndarray, List[Detection]]:
|
||||
"""
|
||||
多模型并行推理
|
||||
|
||||
Args:
|
||||
models_config: 模型配置列表
|
||||
frame: 视频帧 (BGR格式)
|
||||
confidence_threshold: 全局置信度阈值
|
||||
parallel: 是否并行推理
|
||||
use_pil: 是否使用PIL绘制
|
||||
use_plot_for_single: 当只有一个模型时,是否使用result.plot()绘制
|
||||
|
||||
Returns:
|
||||
(绘制完成的帧, 检测结果列表)
|
||||
"""
|
||||
# 转换为ModelConfig对象
|
||||
configs = []
|
||||
for cfg in models_config:
|
||||
if isinstance(cfg, dict):
|
||||
configs.append(ModelConfig(**cfg))
|
||||
else:
|
||||
configs.append(cfg)
|
||||
|
||||
def single_model_inference(model_cfg: ModelConfig, model_idx: int) -> Detection:
|
||||
"""单个模型的推理函数"""
|
||||
try:
|
||||
model = model_cfg.model
|
||||
|
||||
# 根据模型类型进行推理
|
||||
if hasattr(model, 'predict'): # YOLOv8
|
||||
boxes, confidences, class_ids, class_names, raw_result = YOLOModelWrapper.infer_yolov8(
|
||||
model, frame, model_cfg.conf_thres, model_cfg.iou_thres,
|
||||
model_cfg.imgsz, model_cfg.half, model_cfg.device
|
||||
)
|
||||
elif hasattr(model, '__call__'): # YOLOv5
|
||||
boxes, confidences, class_ids, class_names, raw_result = YOLOModelWrapper.infer_yolov5(model, frame)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {type(model)}")
|
||||
|
||||
return Detection(
|
||||
model_idx=model_idx,
|
||||
model_name=model_cfg.name,
|
||||
boxes=boxes,
|
||||
confidences=confidences,
|
||||
class_ids=class_ids,
|
||||
class_names=class_names,
|
||||
tags=model_cfg.tags,
|
||||
raw_result=raw_result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"模型 {model_idx} ({model_cfg.name}) 推理失败: {e}")
|
||||
return Detection(
|
||||
model_idx=model_idx,
|
||||
model_name=model_cfg.name,
|
||||
boxes=np.array([]),
|
||||
confidences=np.array([]),
|
||||
class_ids=np.array([]),
|
||||
class_names=[],
|
||||
tags=model_cfg.tags,
|
||||
raw_result=None
|
||||
)
|
||||
|
||||
# 并行推理
|
||||
if parallel and len(configs) > 1:
|
||||
detections = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(configs)) as executor:
|
||||
# 提交任务
|
||||
future_to_idx = {
|
||||
executor.submit(single_model_inference, cfg, idx): idx
|
||||
for idx, cfg in enumerate(configs)
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
for future in concurrent.futures.as_completed(future_to_idx):
|
||||
try:
|
||||
detection = future.result(timeout=2.0)
|
||||
detections.append(detection)
|
||||
except concurrent.futures.TimeoutError:
|
||||
idx = future_to_idx[future]
|
||||
print(f"模型 {idx} 推理超时")
|
||||
# 创建空的检测结果
|
||||
detections.append(Detection(
|
||||
model_idx=idx,
|
||||
model_name=configs[idx].name,
|
||||
boxes=np.array([]),
|
||||
confidences=np.array([]),
|
||||
class_ids=np.array([]),
|
||||
class_names=[],
|
||||
tags=configs[idx].tags,
|
||||
raw_result=None
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"模型推理异常: {e}")
|
||||
else:
|
||||
# 顺序推理
|
||||
detections = [single_model_inference(cfg, idx)
|
||||
for idx, cfg in enumerate(configs)]
|
||||
|
||||
# 按照模型索引排序
|
||||
detections.sort(key=lambda x: x.model_idx)
|
||||
|
||||
# 新增:如果只有一个模型且配置使用plot绘制,则使用result.plot()
|
||||
if len(configs) == 1 and use_plot_for_single:
|
||||
single_detection = detections[0]
|
||||
if single_detection.raw_result is not None:
|
||||
try:
|
||||
# 检查是否有plot方法
|
||||
if hasattr(single_detection.raw_result, 'plot'):
|
||||
# 使用plot方法绘制结果
|
||||
frame_drawn = single_detection.raw_result.plot()
|
||||
# 确保返回的是numpy数组
|
||||
if not isinstance(frame_drawn, np.ndarray):
|
||||
frame_drawn = np.array(frame_drawn)
|
||||
# plot方法通常返回RGB图像,转换为BGR
|
||||
if len(frame_drawn.shape) == 3 and frame_drawn.shape[2] == 3:
|
||||
frame_drawn = cv2.cvtColor(frame_drawn, cv2.COLOR_RGB2BGR)
|
||||
print(f"使用 {single_detection.model_name} 的 plot() 方法绘制结果")
|
||||
return frame_drawn, detections
|
||||
else:
|
||||
print(f"模型 {single_detection.model_name} 的结果对象没有 plot() 方法,使用自定义绘制")
|
||||
except Exception as e:
|
||||
print(f"使用 plot() 方法绘制失败: {e},回退到自定义绘制")
|
||||
|
||||
# 绘制结果
|
||||
visualizer = DetectionVisualizer(use_pil=use_pil)
|
||||
frame_drawn = visualizer.draw(frame.copy(), detections, confidence_threshold)
|
||||
|
||||
return frame_drawn, detections
|
||||
|
||||
|
||||
# 兼容旧接口
|
||||
def multi_model_inference_legacy(_models: List[Dict], frame: np.ndarray,
|
||||
confidence_threshold: float = 0.25,
|
||||
parallel: bool = True,
|
||||
use_plot_for_single: bool = True) -> Tuple[np.ndarray, List[Dict]]:
|
||||
"""旧接口兼容函数"""
|
||||
frame_drawn, detections = multi_model_inference(
|
||||
_models, frame, confidence_threshold, parallel, use_plot_for_single=use_plot_for_single
|
||||
)
|
||||
|
||||
# 转换为旧格式
|
||||
old_detections = []
|
||||
for det in detections:
|
||||
old_detections.append({
|
||||
'model_idx': det.model_idx,
|
||||
'model_name': det.model_name,
|
||||
'boxes': det.boxes,
|
||||
'confidences': det.confidences,
|
||||
'class_ids': det.class_ids,
|
||||
'class_names': det.class_names,
|
||||
'tags': det.tags
|
||||
})
|
||||
|
||||
return frame_drawn, old_detections
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from ultralytics import YOLO
|
||||
|
||||
print('加载模型中')
|
||||
model_paths = [
|
||||
r"F:\PyModelScope\Yolov\models\yolov8m.pt",
|
||||
r"F:\PyModelScope\Yolov\models\car.pt"
|
||||
]
|
||||
|
||||
print("预热模型...")
|
||||
model_list = []
|
||||
for i, path in enumerate(model_paths):
|
||||
model = YOLO(path)
|
||||
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
model.eval()
|
||||
|
||||
# 创建模型配置
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
name=f'model_{i}',
|
||||
device='cuda:0',
|
||||
conf_thres=0.45,
|
||||
iou_thres=0.45,
|
||||
half=False,
|
||||
imgsz=1920,
|
||||
tags={
|
||||
"0": {"name": "汽车", "reliability": 0.4, "select": True, "color": [0, 255, 0]},
|
||||
"1": {"name": "行人", "reliability": 0.3, "select": True, "color": [255, 0, 0]},
|
||||
"2": {"name": "自行车", "reliability": 0.5, "select": False, "color": [0, 0, 255]}
|
||||
},
|
||||
config=None
|
||||
)
|
||||
model_list.append(model_config)
|
||||
|
||||
print("模型预热完成")
|
||||
image_path = r"F:\PyModelScope\Yolov\images\444.png"
|
||||
frame = cv2.imread(image_path)
|
||||
|
||||
# 测试单个模型的情况 - 使用plot绘制
|
||||
print("\n=== 测试单个模型 (使用plot绘制) ===")
|
||||
single_model_list = [model_list[0]] # 只使用第一个模型
|
||||
frame_drawn_single, detections_single = multi_model_inference(
|
||||
single_model_list, frame, use_plot_for_single=True
|
||||
)
|
||||
print(f"单个模型检测结果: {len(detections_single[0].boxes)} 个目标")
|
||||
cv2.imwrite("uploads/result_single_plot.jpg", frame_drawn_single)
|
||||
print("结果已保存到 uploads/result_single_plot.jpg")
|
||||
|
||||
# 测试单个模型的情况 - 强制使用自定义绘制
|
||||
print("\n=== 测试单个模型 (强制使用自定义绘制) ===")
|
||||
frame_drawn_single_custom, detections_single_custom = multi_model_inference(
|
||||
single_model_list, frame, use_plot_for_single=False
|
||||
)
|
||||
print(f"单个模型自定义绘制结果: {len(detections_single_custom[0].boxes)} 个目标")
|
||||
cv2.imwrite("uploads/result_single_custom.jpg", frame_drawn_single_custom)
|
||||
print("结果已保存到 uploads/result_single_custom.jpg")
|
||||
|
||||
# 测试多个模型的情况
|
||||
print("\n=== 测试多个模型 (使用自定义绘制) ===")
|
||||
frame_drawn_multi, detections_multi = multi_model_inference(
|
||||
model_list, frame, use_plot_for_single=True # 即使设为True,多个模型也会使用自定义绘制
|
||||
)
|
||||
print(f"多个模型检测结果:")
|
||||
for det in detections_multi:
|
||||
print(f" 模型 {det.model_name}: 检测到 {len(det.boxes)} 个目标")
|
||||
for box, conf, cls_id, cls_name in zip(det.boxes, det.confidences,
|
||||
det.class_ids, det.class_names):
|
||||
print(f"box:{box},conf:{conf},cls_id:{cls_id},cls_name:{cls_name}")
|
||||
if conf >= 0.25: # 全局阈值
|
||||
should_draw, color = DetectionVisualizer().should_draw_detection(
|
||||
cls_id, conf, det.tags)
|
||||
status = "绘制" if should_draw else "不绘制"
|
||||
print(f" {cls_name} (置信度: {conf:.2f}): {status}")
|
||||
|
||||
cv2.imwrite("uploads/result_multi.jpg", frame_drawn_multi)
|
||||
print("结果已保存到 uploads/result_multi.jpg")
|
||||
Binary file not shown.
2
log.py
2
log.py
|
|
@ -9,7 +9,7 @@ import sys
|
|||
|
||||
def setup_logger():
|
||||
"""优化日志系统 - 减少磁盘I/O"""
|
||||
logger = logging.getLogger("YOLOv8 Optimized")
|
||||
logger = logging.getLogger("YOLOv Optimized")
|
||||
logger.setLevel(logging.DEBUG) # 改为DEBUG级别以查看更多信息
|
||||
|
||||
# 清除现有处理器
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 273 KiB |
|
|
@ -3,7 +3,7 @@
|
|||
<head>
|
||||
<title>模型文件上传</title>
|
||||
<meta charset="UTF-8">
|
||||
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/axios/0.21.1/axios.min.js"></script>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
|
|
|
|||
8644
yolo_detection.log
8644
yolo_detection.log
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue