'11'
parent
0b51193b11
commit
a97071369c
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
82
config.py
82
config.py
|
|
@ -1,65 +1,57 @@
|
|||
# config.py
|
||||
import torch
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def generate_model_key():
|
||||
"""生成模型加密密钥"""
|
||||
return Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# config.py
|
||||
def get_default_config():
|
||||
return {
|
||||
'rtmp': {
|
||||
# 'url': "rtmp://123.132.248.154:6009/live/14",
|
||||
'url': "rtmp://localhost:1935/live/14",
|
||||
'max_reconnect_attempts': 20, # 增加重连次数
|
||||
'reconnect_delay': 1, # 减少初始延迟
|
||||
'max_reconnect_attempts': 20,
|
||||
'reconnect_delay': 1,
|
||||
'buffer_size': 1,
|
||||
'timeout_ms': 5000,
|
||||
'gpu_decode': True # 启用硬件解码
|
||||
'gpu_decode': True
|
||||
},
|
||||
'push': {
|
||||
'enable_push': True,
|
||||
# 'url': 'rtmp://123.132.248.154:6009/live/11',
|
||||
'url': 'rtmp://localhost:1935/live/13',
|
||||
'format': 'flv',
|
||||
'video_codec': 'h264_nvenc' if torch.cuda.is_available() else 'libx264', # 使用硬件编码
|
||||
'video_codec': 'libx264',
|
||||
'pixel_format': 'bgr24',
|
||||
'preset': 'p1' if torch.cuda.is_available() else 'ultrafast', # NVIDIA专用预设
|
||||
'framerate': 30,
|
||||
'gpu_acceleration': True, # 启用硬件加速
|
||||
'tune': 'll', # 低延迟模式
|
||||
'zerolatency': 1, # 零延迟
|
||||
'delay': 0, # 无延迟
|
||||
'rc': 'cbr_ld_hq', # 恒定码率低延迟高质量
|
||||
'bufsize': '500k' # 减少缓冲区大小
|
||||
},
|
||||
'model': {
|
||||
'path': 'yolo11x.pt', # 保留大模型
|
||||
'download_url': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x.pt'
|
||||
},
|
||||
'predict': {
|
||||
'conf_thres': 0.25, # 提高置信度阈值
|
||||
'iou_thres': 0.45,
|
||||
'imgsz': 1280, # 初始推理尺寸
|
||||
'line_width': 1,
|
||||
'font_size': 18,
|
||||
'font':"E:\\tsgz\\fonts\\PingFangSC-Medium.ttf",
|
||||
'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
|
||||
'frame_skip': 1, # 初始跳帧值
|
||||
'half': True # 启用半精度推理
|
||||
'preset': 'veryfast',
|
||||
'framerate': 25,
|
||||
'gpu_acceleration': False,
|
||||
'tune': 'zerolatency',
|
||||
'crf': 28,
|
||||
'bitrate': '1500k',
|
||||
'bufsize': '3000k'
|
||||
},
|
||||
'models': [ # 只保留多模型配置
|
||||
# 默认模型配置可以在创建任务时被覆盖
|
||||
],
|
||||
'task': {
|
||||
'taskname': '', # 文件夹名称
|
||||
'taskname': '',
|
||||
'taskid': '',
|
||||
'tag': {},
|
||||
'aiid': '',
|
||||
'res_api': 'http://123.132.248.154:6033/api/DaHuaAi/AddImg',
|
||||
'api': 'http://123.132.248.154:6033/'
|
||||
},
|
||||
'mqtt': {
|
||||
'enable': True, # 是否启用MQTT
|
||||
'broker': '175.27.168.120', # MQTT代理地址
|
||||
'port': 6011, # MQTT端口
|
||||
'topic': 'thing/product/1581F8HGX254V00A0BUY/osd', # 订阅的主题
|
||||
'client_id': 'yolo_detection_client', # 客户端ID------自己生成个
|
||||
'username': 'sdhc', # 用户名(可选)
|
||||
'password': None, # 密码(可选)
|
||||
'keepalive': 60 # 保活时间
|
||||
'enable': True,
|
||||
'broker': '175.27.168.120',
|
||||
'port': 6011,
|
||||
'topic': 'thing/product/1581F8HGX254V00A0BUY/osd',
|
||||
'client_id': 'yolo_detection_client',
|
||||
'username': 'sdhc',
|
||||
'password': None,
|
||||
'keepalive': 60
|
||||
},
|
||||
'minio': {
|
||||
"UseSSL": False,
|
||||
|
|
@ -67,5 +59,15 @@ def get_default_config():
|
|||
"AccessKey": "minioadmin",
|
||||
"SecretKey": "minioadmin",
|
||||
"BucketName": "test"
|
||||
}
|
||||
},
|
||||
'resource_limits': {
|
||||
'max_cpu_percent': 80,
|
||||
'max_memory_percent': 80,
|
||||
'max_gpu_memory_percent': 80,
|
||||
'max_concurrent_tasks': 5,
|
||||
'min_concurrent_tasks': 1,
|
||||
'check_interval': 5,
|
||||
'adjust_threshold': 5
|
||||
},
|
||||
'model_path': 'models'
|
||||
}
|
||||
|
|
|
|||
1735
detectionThread.py
1735
detectionThread.py
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,310 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class OptimizedDetectionRenderer:
|
||||
def __init__(self, font_path=None):
|
||||
self.font_cache = {}
|
||||
self.font_path = font_path
|
||||
self.drawn_labels = set() # 记录已经绘制过的标签
|
||||
self.detection_cache = defaultdict(list) # 缓存同一位置的检测结果
|
||||
|
||||
def _get_font(self, size):
|
||||
"""获取字体对象,带缓存"""
|
||||
if size in self.font_cache:
|
||||
return self.font_cache[size]
|
||||
|
||||
font_paths = []
|
||||
if self.font_path and os.path.exists(self.font_path):
|
||||
font_paths.append(self.font_path)
|
||||
|
||||
# 添加常用字体路径
|
||||
font_paths.extend([
|
||||
"simhei.ttf",
|
||||
"msyh.ttc",
|
||||
"C:/Windows/Fonts/simhei.ttf",
|
||||
"C:/Windows/Fonts/msyh.ttc",
|
||||
"C:/Windows/Fonts/Deng.ttf",
|
||||
"C:/Windows/Fonts/simsun.ttc",
|
||||
"/System/Library/Fonts/PingFang.ttc", # macOS
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" # Linux
|
||||
])
|
||||
|
||||
font = None
|
||||
for path in font_paths:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
font = ImageFont.truetype(path, size, encoding="utf-8")
|
||||
print(f"加载字体: {path}")
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if font is None:
|
||||
try:
|
||||
font = ImageFont.load_default()
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
self.font_cache[size] = font
|
||||
return font
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
"""计算两个边界框的IoU"""
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
|
||||
if x1 >= x2 or y1 >= y2:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2 - x1) * (y2 - y1)
|
||||
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||
|
||||
return intersection / (area1 + area2 - intersection)
|
||||
|
||||
def filter_duplicate_detections(self, detections, iou_threshold=0.3):
|
||||
"""过滤重复检测结果,保留置信度最高的"""
|
||||
filtered_detections = []
|
||||
|
||||
# 按置信度降序排序
|
||||
print(detections)
|
||||
sorted_detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
|
||||
|
||||
for det in sorted_detections:
|
||||
is_duplicate = False
|
||||
|
||||
for kept_det in filtered_detections:
|
||||
iou = self.compute_iou(det['box'], kept_det['box'])
|
||||
|
||||
# 如果IoU超过阈值,认为是同一目标
|
||||
if iou > iou_threshold:
|
||||
is_duplicate = True
|
||||
|
||||
# 如果是相同类别,用置信度高的
|
||||
if det['class_name'] == kept_det['class_name']:
|
||||
# 如果当前检测置信度更高,替换
|
||||
if det['confidence'] > kept_det['confidence']:
|
||||
filtered_detections.remove(kept_det)
|
||||
filtered_detections.append(det)
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
filtered_detections.append(det)
|
||||
|
||||
return filtered_detections
|
||||
|
||||
def draw_text_with_background(self, img, text, position, font_size=20,
|
||||
text_color=(0, 255, 0), bg_color=None, padding=5):
|
||||
"""绘制带背景的文本(自动调整背景大小)"""
|
||||
try:
|
||||
# 转换为PIL图像
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(img_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 获取字体
|
||||
font = self._get_font(font_size)
|
||||
|
||||
# 计算文本尺寸
|
||||
if hasattr(font, 'getbbox'): # 新版本PIL
|
||||
bbox = font.getbbox(text)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
else: # 旧版本PIL
|
||||
text_width, text_height = font.getsize(text)
|
||||
|
||||
# 计算背景位置
|
||||
x, y = position
|
||||
bg_x1 = x - padding
|
||||
bg_y1 = y - padding
|
||||
bg_x2 = x + text_width + padding
|
||||
bg_y2 = y + text_height + padding
|
||||
|
||||
# 如果提供了背景颜色,绘制背景
|
||||
if bg_color:
|
||||
# 将BGR转换为RGB
|
||||
rgb_bg_color = bg_color[::-1]
|
||||
draw.rectangle([bg_x1, bg_y1, bg_x2, bg_y2], fill=rgb_bg_color)
|
||||
|
||||
# 绘制文本
|
||||
rgb_text_color = text_color[::-1]
|
||||
draw.text(position, text, font=font, fill=rgb_text_color)
|
||||
|
||||
# 转换回OpenCV格式
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
except Exception as e:
|
||||
print(f"文本渲染失败: {e},使用后备方案")
|
||||
# 后备方案
|
||||
if bg_color:
|
||||
cv2.rectangle(img, (position[0] - padding, position[1] - padding),
|
||||
(position[0] + len(text) * 10 + padding, position[1] + font_size + padding),
|
||||
bg_color, cv2.FILLED)
|
||||
cv2.putText(img, text, position, cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_size / 30, text_color, 2)
|
||||
return img
|
||||
|
||||
def draw_detection(self, frame, detection_info, model_config,
|
||||
show_model_name=True, show_confidence=True):
|
||||
"""在帧上绘制检测结果"""
|
||||
x1, y1, x2, y2 = map(int, detection_info['box'])
|
||||
color = tuple(detection_info['color'])
|
||||
line_width = model_config.get('line_width', 2)
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, line_width)
|
||||
|
||||
# 准备标签文本
|
||||
label_parts = []
|
||||
|
||||
if show_model_name and 'model_name' in detection_info:
|
||||
label_parts.append(detection_info['model_name'])
|
||||
|
||||
if 'class_name' in detection_info:
|
||||
label_parts.append(detection_info['class_name'])
|
||||
|
||||
if show_confidence and 'confidence' in detection_info:
|
||||
label_parts.append(f"{detection_info['confidence']:.2f}")
|
||||
|
||||
label = ":".join(label_parts)
|
||||
|
||||
# 生成标签的唯一标识
|
||||
label_key = f"{detection_info['class_name']}_{x1}_{y1}"
|
||||
|
||||
# 检查是否已经绘制过相同类别的标签(在一定区域内)
|
||||
label_drawn = False
|
||||
for drawn_label in self.drawn_labels:
|
||||
drawn_class, drawn_x, drawn_y = drawn_label.split('_')
|
||||
drawn_x, drawn_y = int(drawn_x), int(drawn_y)
|
||||
|
||||
# 计算距离,如果很近且类别相同,认为已经绘制过
|
||||
distance = np.sqrt((x1 - drawn_x) ** 2 + (y1 - drawn_y) ** 2)
|
||||
if distance < 50 and detection_info['class_name'] == drawn_class:
|
||||
label_drawn = True
|
||||
break
|
||||
|
||||
if not label_drawn:
|
||||
# 计算标签位置(放在框的上方,如果上方空间不够则放在下方)
|
||||
label_y = y1 - 20
|
||||
if label_y < 20: # 如果上方空间不够
|
||||
label_y = y2 + 20
|
||||
|
||||
# 绘制带背景的标签
|
||||
frame = self.draw_text_with_background(
|
||||
frame, label,
|
||||
(x1, label_y),
|
||||
font_size=model_config.get('font_size', 20),
|
||||
text_color=(255, 255, 255),
|
||||
bg_color=color,
|
||||
padding=3
|
||||
)
|
||||
|
||||
# 记录已绘制的标签
|
||||
self.drawn_labels.add(label_key)
|
||||
|
||||
return frame
|
||||
|
||||
def draw_all_detections(self, frame, all_detections, model_configs,
|
||||
enable_nms=True, show_model_name=True):
|
||||
"""绘制所有检测结果(主入口函数)"""
|
||||
# 重置已绘制标签记录
|
||||
self.drawn_labels.clear()
|
||||
|
||||
if not all_detections:
|
||||
return frame
|
||||
|
||||
# 如果需要,过滤重复检测
|
||||
if enable_nms:
|
||||
filtered_detections = self.filter_duplicate_detections(all_detections)
|
||||
else:
|
||||
filtered_detections = all_detections
|
||||
|
||||
# 按置信度排序,先绘制置信度低的,再绘制置信度高的
|
||||
sorted_detections = sorted(filtered_detections, key=lambda x: x.get('confidence', 0))
|
||||
|
||||
# 绘制每个检测结果
|
||||
for detection in sorted_detections:
|
||||
model_id = detection.get('model_id')
|
||||
model_config = model_configs.get(model_id, {})
|
||||
frame = self.draw_detection(frame, detection, model_config, show_model_name)
|
||||
|
||||
return frame
|
||||
|
||||
def put_text_simple(self, img, text, position, font_size=20, color=(0, 255, 0)):
|
||||
"""简化版文本绘制函数"""
|
||||
return self.draw_text_with_background(
|
||||
img, text, position, font_size, color, None, 0
|
||||
)
|
||||
|
||||
|
||||
# 使用示例
|
||||
def main():
|
||||
# 初始化渲染器
|
||||
renderer = OptimizedDetectionRenderer()
|
||||
|
||||
# 模拟多模型检测结果
|
||||
detections = [
|
||||
{
|
||||
'model_id': 'yolov8n',
|
||||
'model_name': 'YOLOv8',
|
||||
'class_id': 0,
|
||||
'class_name': 'person',
|
||||
'confidence': 0.85,
|
||||
'box': [100, 100, 200, 300],
|
||||
'reliability': 0.9,
|
||||
'color': (0, 255, 0)
|
||||
},
|
||||
{
|
||||
'model_id': 'yolov8s',
|
||||
'model_name': 'YOLOv8s',
|
||||
'class_id': 0,
|
||||
'class_name': 'person',
|
||||
'confidence': 0.75,
|
||||
'box': [110, 110, 210, 310], # 与第一个重叠
|
||||
'reliability': 0.8,
|
||||
'color': (0, 0, 255)
|
||||
},
|
||||
{
|
||||
'model_id': 'yolov8n',
|
||||
'model_name': 'YOLOv8',
|
||||
'class_id': 2,
|
||||
'class_name': 'car',
|
||||
'confidence': 0.95,
|
||||
'box': [300, 150, 450, 250],
|
||||
'reliability': 0.95,
|
||||
'color': (255, 0, 0)
|
||||
}
|
||||
]
|
||||
|
||||
# 模型配置
|
||||
model_configs = {
|
||||
'yolov8n': {'line_width': 2, 'font_size': 20},
|
||||
'yolov8s': {'line_width': 2, 'font_size': 18}
|
||||
}
|
||||
|
||||
# 读取测试图像
|
||||
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
# 绘制所有检测结果
|
||||
frame = renderer.draw_all_detections(
|
||||
frame,
|
||||
detections,
|
||||
model_configs,
|
||||
enable_nms=True, # 启用NMS去重
|
||||
show_model_name=True
|
||||
)
|
||||
|
||||
# 显示结果
|
||||
cv2.imshow('Detections', frame)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 39 KiB |
|
|
@ -0,0 +1,8 @@
|
|||
from mandatory_model_crypto import MandatoryModelEncryptor
|
||||
|
||||
encryptor = MandatoryModelEncryptor()
|
||||
result = encryptor.encrypt_model(
|
||||
model_path="E:\DC\Yolov\models\yolov8n.pt",
|
||||
output_path="models/yolov8n_encrypted.pt",
|
||||
password="12345678"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,16 +1,51 @@
|
|||
# globalvar.py
|
||||
|
||||
def _init():
|
||||
global _global_dict
|
||||
_global_dict = {}
|
||||
# global_data.py
|
||||
import threading
|
||||
|
||||
|
||||
def set_value(name, value):
|
||||
_global_dict[name] = value
|
||||
class GlobalData:
|
||||
"""全局数据管理类,支持多任务"""
|
||||
|
||||
def __init__(self):
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
"""初始化全局字典和锁"""
|
||||
self._global_dict = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set_value(self, name, value):
|
||||
"""设置全局值"""
|
||||
with self._lock:
|
||||
self._global_dict[name] = value
|
||||
|
||||
def get_value(self, name, defValue=None):
|
||||
"""获取全局值"""
|
||||
with self._lock:
|
||||
try:
|
||||
return self._global_dict[name]
|
||||
except KeyError:
|
||||
return defValue
|
||||
|
||||
def get_or_create_dict(self, name):
|
||||
"""获取或创建字典"""
|
||||
with self._lock:
|
||||
if name not in self._global_dict:
|
||||
self._global_dict[name] = {}
|
||||
return self._global_dict[name]
|
||||
|
||||
def remove_value(self, name):
|
||||
"""移除全局值"""
|
||||
with self._lock:
|
||||
if name in self._global_dict:
|
||||
del self._global_dict[name]
|
||||
|
||||
def clear_all_tasks(self):
|
||||
"""清除所有任务"""
|
||||
with self._lock:
|
||||
for key in list(self._global_dict.keys()):
|
||||
if key.startswith('task_'):
|
||||
del self._global_dict[key]
|
||||
|
||||
|
||||
def get_value(name, defValue=None):
|
||||
try:
|
||||
return _global_dict[name]
|
||||
except KeyError:
|
||||
return defValue
|
||||
# 创建全局实例
|
||||
gd = GlobalData()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,149 @@
|
|||
# key_manager.py
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from log import logger
|
||||
from global_data import gd
|
||||
|
||||
|
||||
class EncryptionKeyManager:
|
||||
"""加密密钥管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.keys_store = {} # task_id -> model_keys
|
||||
self.key_history = {} # 密钥使用历史
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def register_task_keys(self, task_id, model_configs):
|
||||
"""注册任务的加密密钥"""
|
||||
with self.lock:
|
||||
if task_id not in self.keys_store:
|
||||
self.keys_store[task_id] = {}
|
||||
|
||||
key_info = {
|
||||
'task_id': task_id,
|
||||
'models': [],
|
||||
'registered_at': datetime.now().isoformat(),
|
||||
'key_count': 0
|
||||
}
|
||||
|
||||
for i, model_cfg in enumerate(model_configs):
|
||||
encryption_key = model_cfg.get('encryption_key')
|
||||
if encryption_key:
|
||||
# 计算密钥哈希(不存储原始密钥)
|
||||
key_hash = hashlib.sha256(encryption_key.encode()).hexdigest()
|
||||
|
||||
model_key_info = {
|
||||
'model_index': i,
|
||||
'model_path': model_cfg.get('path', 'unknown'),
|
||||
'key_hash': key_hash,
|
||||
'short_hash': key_hash[:16],
|
||||
'key_provided': True
|
||||
}
|
||||
|
||||
self.keys_store[task_id][f'model_{i}'] = model_key_info
|
||||
key_info['models'].append(model_key_info)
|
||||
key_info['key_count'] += 1
|
||||
|
||||
# 记录密钥使用历史
|
||||
history_key = f"{task_id}_{key_hash[:8]}"
|
||||
self.key_history[history_key] = {
|
||||
'task_id': task_id,
|
||||
'model_index': i,
|
||||
'key_hash': key_hash,
|
||||
'used_at': datetime.now().isoformat(),
|
||||
'model_path': model_cfg.get('path', 'unknown')
|
||||
}
|
||||
|
||||
logger.info(f"注册任务 {task_id} 的 {key_info['key_count']} 个加密密钥")
|
||||
return key_info
|
||||
|
||||
def validate_task_keys(self, task_id):
|
||||
"""验证任务的加密密钥"""
|
||||
with self.lock:
|
||||
if task_id not in self.keys_store:
|
||||
return {
|
||||
'valid': False,
|
||||
'error': '任务未注册密钥',
|
||||
'key_count': 0
|
||||
}
|
||||
|
||||
key_info = self.keys_store[task_id]
|
||||
valid_keys = len(key_info)
|
||||
|
||||
return {
|
||||
'valid': True,
|
||||
'key_count': valid_keys,
|
||||
'models': list(key_info.keys()),
|
||||
'last_updated': self.get_last_key_update(task_id)
|
||||
}
|
||||
|
||||
def get_last_key_update(self, task_id):
|
||||
"""获取密钥最后更新时间"""
|
||||
if task_id not in self.keys_store:
|
||||
return None
|
||||
|
||||
# 从历史记录中查找
|
||||
for history in self.key_history.values():
|
||||
if history['task_id'] == task_id:
|
||||
return history['used_at']
|
||||
|
||||
return None
|
||||
|
||||
def cleanup_task_keys(self, task_id):
|
||||
"""清理任务的加密密钥"""
|
||||
with self.lock:
|
||||
if task_id in self.keys_store:
|
||||
key_count = len(self.keys_store[task_id])
|
||||
del self.keys_store[task_id]
|
||||
logger.info(f"清理任务 {task_id} 的 {key_count} 个加密密钥")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_key_statistics(self):
|
||||
"""获取密钥统计信息"""
|
||||
with self.lock:
|
||||
total_tasks = len(self.keys_store)
|
||||
total_keys = sum(len(keys) for keys in self.keys_store.values())
|
||||
total_history = len(self.key_history)
|
||||
|
||||
return {
|
||||
'total_tasks': total_tasks,
|
||||
'total_keys': total_keys,
|
||||
'total_history': total_history,
|
||||
'active_tasks': list(self.keys_store.keys())
|
||||
}
|
||||
|
||||
def verify_key_for_model(self, task_id, model_index, provided_key):
|
||||
"""验证特定模型的密钥"""
|
||||
with self.lock:
|
||||
if task_id not in self.keys_store:
|
||||
return {'valid': False, 'error': '任务未注册'}
|
||||
|
||||
model_key = f'model_{model_index}'
|
||||
if model_key not in self.keys_store[task_id]:
|
||||
return {'valid': False, 'error': '模型未注册密钥'}
|
||||
|
||||
# 计算提供的密钥哈希
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
stored_hash = self.keys_store[task_id][model_key]['key_hash']
|
||||
|
||||
if provided_hash == stored_hash:
|
||||
return {
|
||||
'valid': True,
|
||||
'key_hash': stored_hash,
|
||||
'short_hash': stored_hash[:16]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'valid': False,
|
||||
'error': '密钥不匹配',
|
||||
'provided_hash': provided_hash[:16],
|
||||
'stored_hash': stored_hash[:16]
|
||||
}
|
||||
|
||||
|
||||
# 全局密钥管理器实例
|
||||
key_manager = EncryptionKeyManager()
|
||||
29
log.py
29
log.py
|
|
@ -1,34 +1,42 @@
|
|||
# log.py - 增强日志输出
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import sys
|
||||
|
||||
|
||||
# 配置日志处理
|
||||
# 在 log.py 中添加更详细的日志格式
|
||||
|
||||
def setup_logger():
|
||||
"""优化日志系统 - 减少磁盘I/O"""
|
||||
logger = logging.getLogger("YOLOv8 Optimized")
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.setLevel(logging.DEBUG) # 改为DEBUG级别以查看更多信息
|
||||
|
||||
# 清除现有处理器
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
# 优化控制台日志格式
|
||||
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s [%(levelname)s] [%(module)s:%(lineno)d] %(message)s'
|
||||
)
|
||||
|
||||
# 控制台处理器 - 设置较低的延迟
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# 文件处理器 - 限制日志大小和使用异步写入
|
||||
log_file = 'yolo_detection.log'
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=5 * 1024 * 1024, # 减小为5MB
|
||||
backupCount=3,
|
||||
maxBytes=10 * 1024 * 1024, # 10MB
|
||||
backupCount=5,
|
||||
encoding='utf-8',
|
||||
delay=False # 禁用延迟打开文件
|
||||
delay=False
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
|
@ -38,6 +46,13 @@ def setup_logger():
|
|||
logging.getLogger("engineio").setLevel(logging.WARNING)
|
||||
logging.getLogger("socketio").setLevel(logging.WARNING)
|
||||
|
||||
# 设置其他库的日志级别
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
logging.getLogger("requests").setLevel(logging.WARNING)
|
||||
|
||||
# 设置任务管理器的日志级别
|
||||
logging.getLogger("task_manager").setLevel(logging.DEBUG)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
|
|
|
|||
59
main.py
59
main.py
|
|
@ -1,34 +1,59 @@
|
|||
# main.py
|
||||
# 主程序入口
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import torch
|
||||
from log import logger
|
||||
|
||||
import global_data as gd
|
||||
from global_data import gd
|
||||
from server import socketio, app
|
||||
from task_manager import task_manager
|
||||
|
||||
# 初始化全局数据
|
||||
gd._init()
|
||||
|
||||
gd.set_value('detection_thread', None)
|
||||
gd.set_value('detection_active', False)
|
||||
gd.set_value('stop_event', threading.Event())
|
||||
gd.set_value('mqtt_client', None)
|
||||
# 设置全局变量
|
||||
gd.set_value('task_manager', task_manager) # 直接在这里初始化
|
||||
gd.set_value('latest_drone_data', None)
|
||||
gd.set_value('mqtt_data_lock', threading.Lock())
|
||||
|
||||
# 初始化任务管理器
|
||||
logger.info("任务管理器初始化...")
|
||||
logger.info(f"任务管理器实例: {task_manager}")
|
||||
|
||||
# main.py 修改部分
|
||||
|
||||
# main.py 修改部分
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("启动优化版YOLOv8服务")
|
||||
# 优化2: 使用最新版本YOLO和PyTorch特性
|
||||
logger.info("启动多任务版YOLOv8服务")
|
||||
logger.info(f"PyTorch版本: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
|
||||
|
||||
# 初始化任务推流管理器
|
||||
try:
|
||||
from task_stream_manager import task_stream_manager
|
||||
|
||||
task_stream_manager.start_health_monitor()
|
||||
logger.info("任务推流管理器初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化推流管理器失败: {str(e)}")
|
||||
|
||||
|
||||
# 退出服务
|
||||
def graceful_exit(signum, frame):
|
||||
logger.info("收到退出信号,停止服务...")
|
||||
detection_active = gd.get_value('detection_active')
|
||||
detection_thread = gd.get_value('detection_thread')
|
||||
if detection_active and detection_thread:
|
||||
detection_thread.stop()
|
||||
detection_thread.join(5.0)
|
||||
logger.info("收到退出信号,停止所有服务...")
|
||||
|
||||
# 停止所有任务推流
|
||||
try:
|
||||
from task_stream_manager import task_stream_manager
|
||||
task_stream_manager.cleanup_all()
|
||||
except:
|
||||
pass
|
||||
|
||||
# 停止所有任务
|
||||
if task_manager:
|
||||
task_manager.cleanup_all_tasks()
|
||||
|
||||
logger.info("服务已退出")
|
||||
os._exit(0)
|
||||
|
||||
|
|
@ -36,10 +61,10 @@ if __name__ == '__main__':
|
|||
signal.signal(signal.SIGINT, graceful_exit)
|
||||
signal.signal(signal.SIGTERM, graceful_exit)
|
||||
|
||||
# 启动服务 - 禁用开发服务器调试
|
||||
# 启动服务
|
||||
socketio.run(app,
|
||||
host='0.0.0.0',
|
||||
port=9309,
|
||||
debug=True, # 禁用调试模式
|
||||
debug=True,
|
||||
use_reloader=False,
|
||||
allow_unsafe_werkzeug=True)
|
||||
allow_unsafe_werkzeug=True)
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
# mandatory_model_crypto.py
|
||||
import os
|
||||
import tempfile
|
||||
import hashlib
|
||||
import pickle
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
import base64
|
||||
from log import logger
|
||||
|
||||
|
||||
class MandatoryModelEncryptor:
|
||||
"""强制模型加密器 - 所有模型必须加密"""
|
||||
|
||||
@staticmethod
|
||||
def encrypt_model(model_path, output_path, password, require_encryption=True):
|
||||
"""加密模型文件 - 强制模式"""
|
||||
try:
|
||||
# 读取模型文件
|
||||
with open(model_path, 'rb') as f:
|
||||
model_data = f.read()
|
||||
|
||||
# 计算模型哈希
|
||||
model_hash = hashlib.sha256(model_data).hexdigest()
|
||||
|
||||
# 生成加密密钥
|
||||
salt = os.urandom(16)
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||
fernet = Fernet(key)
|
||||
|
||||
# 加密数据
|
||||
encrypted_data = fernet.encrypt(model_data)
|
||||
|
||||
# 保存加密数据
|
||||
encrypted_payload = {
|
||||
'salt': salt,
|
||||
'data': encrypted_data,
|
||||
'model_hash': model_hash,
|
||||
'original_size': len(model_data),
|
||||
'encrypted': True,
|
||||
'version': '1.0'
|
||||
}
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
pickle.dump(encrypted_payload, f)
|
||||
|
||||
logger.info(f"模型强制加密成功: {model_path} -> {output_path}")
|
||||
logger.info(f"模型哈希: {model_hash[:16]}...")
|
||||
|
||||
# 返回密钥信息(用于验证)
|
||||
return {
|
||||
'success': True,
|
||||
'model_hash': model_hash,
|
||||
'key_hash': hashlib.sha256(key).hexdigest()[:16],
|
||||
'output_path': output_path
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型强制加密失败: {str(e)}")
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
@staticmethod
|
||||
def decrypt_model(encrypted_path, password, verify_key=True):
|
||||
"""解密模型文件 - 带密钥验证"""
|
||||
try:
|
||||
if not os.path.exists(encrypted_path):
|
||||
return {'success': False, 'error': '加密模型文件不存在'}
|
||||
|
||||
# 读取加密文件
|
||||
with open(encrypted_path, 'rb') as f:
|
||||
encrypted_payload = pickle.load(f)
|
||||
|
||||
# 验证加密格式
|
||||
if not encrypted_payload.get('encrypted', False):
|
||||
return {'success': False, 'error': '模型未加密'}
|
||||
|
||||
salt = encrypted_payload['salt']
|
||||
encrypted_data = encrypted_payload['data']
|
||||
expected_hash = encrypted_payload.get('model_hash', '')
|
||||
|
||||
# 生成解密密钥
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||
|
||||
# 验证密钥(可选)
|
||||
if verify_key:
|
||||
key_hash = hashlib.sha256(key).hexdigest()[:16]
|
||||
logger.debug(f"解密密钥哈希: {key_hash}")
|
||||
|
||||
fernet = Fernet(key)
|
||||
|
||||
# 解密数据
|
||||
decrypted_data = fernet.decrypt(encrypted_data)
|
||||
|
||||
# 验证模型哈希
|
||||
actual_hash = hashlib.sha256(decrypted_data).hexdigest()
|
||||
if expected_hash and actual_hash != expected_hash:
|
||||
return {
|
||||
'success': False,
|
||||
'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
|
||||
}
|
||||
|
||||
# 保存到临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
|
||||
tmp.write(decrypted_data)
|
||||
temp_path = tmp.name
|
||||
|
||||
logger.info(f"模型解密验证成功: {encrypted_path}")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'temp_path': temp_path,
|
||||
'model_hash': actual_hash,
|
||||
'original_size': len(decrypted_data)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "InvalidToken" in error_msg or "Invalid signature" in error_msg:
|
||||
return {'success': False, 'error': '解密密钥错误'}
|
||||
return {'success': False, 'error': f'解密失败: {error_msg}'}
|
||||
|
||||
@staticmethod
|
||||
def is_properly_encrypted(model_path):
|
||||
"""检查模型是否被正确加密"""
|
||||
try:
|
||||
with open(model_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
# 检查必要的加密字段
|
||||
required_fields = ['salt', 'data', 'encrypted', 'model_hash']
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
return False
|
||||
|
||||
return data.get('encrypted', False) is True
|
||||
|
||||
except:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def generate_secure_key():
|
||||
"""生成安全的加密密钥"""
|
||||
# 生成随机密钥
|
||||
key = Fernet.generate_key()
|
||||
|
||||
# 生成密钥指纹
|
||||
key_hash = hashlib.sha256(key).hexdigest()
|
||||
|
||||
return {
|
||||
'key': key.decode('utf-8'),
|
||||
'key_hash': key_hash,
|
||||
'short_hash': key_hash[:16]
|
||||
}
|
||||
|
||||
|
||||
class MandatoryModelManager:
|
||||
"""强制加密模型管理器"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.models_dir = "encrypted_models"
|
||||
os.makedirs(self.models_dir, exist_ok=True)
|
||||
|
||||
# 加载加密器
|
||||
self.encryptor = MandatoryModelEncryptor()
|
||||
|
||||
# 模型缓存
|
||||
self.model_cache = {}
|
||||
|
||||
def load_encrypted_model(self, model_config):
|
||||
"""加载加密模型 - 必须提供密钥"""
|
||||
try:
|
||||
model_path = model_config['path']
|
||||
encryption_key = model_config.get('encryption_key')
|
||||
|
||||
# 必须提供密钥
|
||||
if not encryption_key:
|
||||
raise ValueError(f"模型 {model_path} 必须提供加密密钥")
|
||||
|
||||
# 构建本地路径
|
||||
local_path = os.path.join(self.models_dir, os.path.basename(model_path))
|
||||
|
||||
# 检查本地文件是否存在
|
||||
if not os.path.exists(local_path):
|
||||
# 尝试下载(如果提供下载地址)
|
||||
if not self.download_encrypted_model(model_config, local_path):
|
||||
raise FileNotFoundError(f"加密模型文件不存在且无法下载: {local_path}")
|
||||
|
||||
# 验证是否为正确加密的模型
|
||||
if not self.encryptor.is_properly_encrypted(local_path):
|
||||
raise ValueError(f"模型文件未正确加密: {local_path}")
|
||||
|
||||
# 解密模型
|
||||
decrypt_result = self.encryptor.decrypt_model(local_path, encryption_key)
|
||||
|
||||
if not decrypt_result['success']:
|
||||
raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")
|
||||
|
||||
temp_path = decrypt_result['temp_path']
|
||||
|
||||
# 加载YOLO模型
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(temp_path)
|
||||
|
||||
# 应用设备配置
|
||||
device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
model = model.to(device)
|
||||
|
||||
# 应用半精度配置
|
||||
if model_config.get('half', False) and 'cuda' in device:
|
||||
model = model.half()
|
||||
logger.info(f"启用半精度推理: {model_path}")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 记录模型信息
|
||||
model_hash = decrypt_result.get('model_hash', 'unknown')[:16]
|
||||
logger.info(f"加密模型加载成功: {model_path} -> {device} [哈希: {model_hash}...]")
|
||||
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载加密模型失败 {model_config.get('path')}: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def download_encrypted_model(self, model_config, save_path):
|
||||
"""下载加密模型文件"""
|
||||
try:
|
||||
download_url = model_config.get('download_url')
|
||||
|
||||
if not download_url:
|
||||
logger.error(f"加密模型无下载地址: {model_config['path']}")
|
||||
return False
|
||||
|
||||
logger.info(f"下载加密模型: {download_url} -> {save_path}")
|
||||
|
||||
response = requests.get(download_url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
downloaded = 0
|
||||
|
||||
with open(save_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
downloaded += len(chunk)
|
||||
f.write(chunk)
|
||||
|
||||
if total_size > 0:
|
||||
progress = (downloaded * 100) // total_size
|
||||
if progress % 25 == 0:
|
||||
logger.info(f"下载进度: {progress}%")
|
||||
|
||||
logger.info(f"加密模型下载完成: {save_path} ({downloaded} 字节)")
|
||||
|
||||
# 验证下载的文件是否正确加密
|
||||
if not self.encryptor.is_properly_encrypted(save_path):
|
||||
logger.error(f"下载的文件不是正确加密的模型: {save_path}")
|
||||
os.remove(save_path)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载加密模型失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def encrypt_existing_model(self, model_path, output_path, password):
|
||||
"""加密现有模型文件"""
|
||||
return self.encryptor.encrypt_model(model_path, output_path, password)
|
||||
|
||||
def verify_model_key(self, model_path, encryption_key):
|
||||
"""验证模型密钥是否正确"""
|
||||
try:
|
||||
if not os.path.exists(model_path):
|
||||
return {'valid': False, 'error': '模型文件不存在'}
|
||||
|
||||
if not self.encryptor.is_properly_encrypted(model_path):
|
||||
return {'valid': False, 'error': '模型文件未正确加密'}
|
||||
|
||||
# 尝试解密(不保存文件)
|
||||
result = self.encryptor.decrypt_model(model_path, encryption_key)
|
||||
|
||||
if result['success']:
|
||||
# 清理临时文件
|
||||
if 'temp_path' in result and os.path.exists(result['temp_path']):
|
||||
try:
|
||||
os.unlink(result['temp_path'])
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
'valid': True,
|
||||
'model_hash': result.get('model_hash', '')[:16],
|
||||
'original_size': result.get('original_size', 0)
|
||||
}
|
||||
else:
|
||||
return {'valid': False, 'error': result.get('error', '解密失败')}
|
||||
|
||||
except Exception as e:
|
||||
return {'valid': False, 'error': str(e)}
|
||||
217
mapping_cn.py
217
mapping_cn.py
|
|
@ -1,25 +1,194 @@
|
|||
# 优化3: 使用更紧凑的类别映射
|
||||
class_mapping_cn = dict(
|
||||
{
|
||||
'0': {
|
||||
"name": "汽车",
|
||||
"reliability": 0.5
|
||||
},
|
||||
'1': {
|
||||
"name": "卡车",
|
||||
"reliability": 0.5
|
||||
},
|
||||
'2': {
|
||||
"name": "公交车",
|
||||
"reliability": 0.5
|
||||
},
|
||||
'3': {
|
||||
"name": "商用车",
|
||||
"reliability": 0.5
|
||||
},
|
||||
'4': {
|
||||
"name": "大货车",
|
||||
"reliability": 0.5
|
||||
}
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
|
||||
|
||||
class ChineseTextRenderer:
|
||||
def __init__(self, font_path=None):
|
||||
self.font_cache = {}
|
||||
self.font_path = font_path
|
||||
|
||||
def put_text(self, img, text, position, font_size=20, color=(0, 255, 0)):
|
||||
"""安全的中文文本绘制函数"""
|
||||
# 输入验证
|
||||
if img is None:
|
||||
raise ValueError("输入图像不能为None")
|
||||
|
||||
if not isinstance(img, np.ndarray):
|
||||
raise TypeError(f"输入图像必须是numpy数组,实际是{type(img)}")
|
||||
|
||||
if len(img.shape) != 3 or img.shape[2] != 3:
|
||||
# 尝试转换
|
||||
if len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif len(img.shape) == 3 and img.shape[2] == 4:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
else:
|
||||
raise ValueError(f"不支持的图像格式,shape={img.shape}")
|
||||
|
||||
try:
|
||||
# BGR转RGB
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(img_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 获取字体
|
||||
font = self._get_font(font_size)
|
||||
|
||||
# 颜色转换:BGR到RGB
|
||||
rgb_color = color[::-1]
|
||||
|
||||
# 绘制文本
|
||||
draw.text(position, text, font=font, fill=rgb_color)
|
||||
|
||||
# RGB转BGR
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
except Exception as e:
|
||||
print(f"中文渲染失败: {e},使用英文后备")
|
||||
# 后备方案:使用OpenCV绘制英文
|
||||
cv2.putText(img, text, position, cv2.FONT_HERSHEY_SIMPLEX,
|
||||
font_size / 30, color, 2)
|
||||
return img
|
||||
|
||||
def _get_font(self, size):
|
||||
"""获取字体对象,带缓存"""
|
||||
if size in self.font_cache:
|
||||
return self.font_cache[size]
|
||||
|
||||
# 查找字体文件
|
||||
font_paths = []
|
||||
if self.font_path and os.path.exists(self.font_path):
|
||||
font_paths.append(self.font_path)
|
||||
|
||||
# 添加常用字体路径
|
||||
font_paths.extend([
|
||||
"simhei.ttf",
|
||||
"msyh.ttc",
|
||||
"C:/Windows/Fonts/simhei.ttf",
|
||||
"C:/Windows/Fonts/msyh.ttc",
|
||||
"C:/Windows/Fonts/Deng.ttf",
|
||||
"C:/Windows/Fonts/simsun.ttc",
|
||||
])
|
||||
|
||||
# 遍历所有可能的字体路径
|
||||
font = None
|
||||
for path in font_paths:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
font = ImageFont.truetype(path, size, encoding="utf-8")
|
||||
print(f"加载字体: {path}")
|
||||
break
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
# 如果找不到字体,尝试创建默认字体
|
||||
if font is None:
|
||||
try:
|
||||
font = ImageFont.load_default()
|
||||
# 调整默认字体大小
|
||||
font = ImageFont.truetype("arial.ttf", size)
|
||||
except:
|
||||
# 最后手段:使用PIL的默认字体
|
||||
font = ImageFont.load_default()
|
||||
|
||||
self.font_cache[size] = font
|
||||
return font
|
||||
|
||||
|
||||
def ensure_image_valid(img, default_size=(640, 480)):
|
||||
"""确保图像有效,如果无效则创建默认图像"""
|
||||
if img is None or not isinstance(img, np.ndarray) or img.size == 0:
|
||||
print("创建默认图像...")
|
||||
img = np.zeros((default_size[0], default_size[1], 3), dtype=np.uint8)
|
||||
img.fill(50)
|
||||
# 添加一些参考线
|
||||
h, w = img.shape[:2]
|
||||
cv2.line(img, (0, 0), (w, h), (0, 255, 0), 1)
|
||||
cv2.line(img, (w, 0), (0, h), (0, 255, 0), 1)
|
||||
cv2.circle(img, (w // 2, h // 2), min(w, h) // 4, (255, 0, 0), 2)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# 使用示例
|
||||
def run_yolo_with_chinese():
|
||||
# 初始化中文渲染器
|
||||
renderer = ChineseTextRenderer()
|
||||
|
||||
# 模拟YOLO检测结果
|
||||
# 注意:这里假设你已经有检测结果
|
||||
detections = [
|
||||
{"bbox": [100, 100, 200, 300], "conf": 0.95, "class": "person"},
|
||||
{"bbox": [300, 150, 450, 350], "conf": 0.88, "class": "car"},
|
||||
]
|
||||
|
||||
# 中英文类别映射
|
||||
class_map = {
|
||||
"person": "人",
|
||||
"car": "汽车",
|
||||
"bicycle": "自行车",
|
||||
"dog": "狗",
|
||||
"cat": "猫",
|
||||
"chair": "椅子",
|
||||
"bottle": "瓶子"
|
||||
}
|
||||
)
|
||||
|
||||
# 读取图像或创建默认图像
|
||||
img_path = "test.jpg"
|
||||
if os.path.exists(img_path):
|
||||
img = cv2.imread(img_path)
|
||||
else:
|
||||
# 用户选择图像文件
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
|
||||
root = tk.Tk()
|
||||
root.withdraw()
|
||||
img_path = filedialog.askopenfilename(
|
||||
title="选择图像文件",
|
||||
filetypes=[("Image files", "*.jpg *.jpeg *.png *.bmp *.tiff")]
|
||||
)
|
||||
|
||||
if img_path:
|
||||
img = cv2.imread(img_path)
|
||||
else:
|
||||
print("未选择文件,创建测试图像")
|
||||
img = None
|
||||
|
||||
# 确保图像有效
|
||||
img = ensure_image_valid(img)
|
||||
|
||||
# 绘制检测结果
|
||||
for det in detections:
|
||||
x1, y1, x2, y2 = det["bbox"]
|
||||
conf = det["conf"]
|
||||
cls_en = det["class"]
|
||||
cls_cn = class_map.get(cls_en, cls_en)
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# 绘制中文标签
|
||||
label = f"{cls_cn}: {conf:.2f}"
|
||||
img = renderer.put_text(img, label, (x1, max(y1 - 20, 10)),
|
||||
font_size=15, color=(0, 255, 0))
|
||||
|
||||
# 添加标题
|
||||
img = renderer.put_text(img, "YOLO检测结果", (10, 30),
|
||||
font_size=25, color=(255, 255, 0))
|
||||
|
||||
# 显示结果
|
||||
cv2.imshow("YOLO Detection with Chinese", img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 保存结果
|
||||
output_path = "detection_result.jpg"
|
||||
cv2.imwrite(output_path, img)
|
||||
print(f"结果已保存: {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_yolo_with_chinese()
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
# model_crypto.py
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import requests
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
import base64
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from log import logger
|
||||
|
||||
|
||||
class ModelEncryptor:
|
||||
"""模型加密/解密器"""
|
||||
|
||||
@staticmethod
|
||||
def generate_key(password: str, salt: bytes = None):
|
||||
"""生成加密密钥"""
|
||||
if salt is None:
|
||||
salt = os.urandom(16)
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
||||
return key, salt
|
||||
|
||||
@staticmethod
|
||||
def encrypt_model(model_path: str, output_path: str, password: str):
|
||||
"""加密模型文件"""
|
||||
try:
|
||||
# 读取模型文件
|
||||
with open(model_path, 'rb') as f:
|
||||
model_data = f.read()
|
||||
|
||||
# 生成密钥
|
||||
key, salt = ModelEncryptor.generate_key(password)
|
||||
fernet = Fernet(key)
|
||||
|
||||
# 加密数据
|
||||
encrypted_data = fernet.encrypt(model_data)
|
||||
|
||||
# 保存加密数据(包含salt)
|
||||
encrypted_payload = {
|
||||
'salt': salt,
|
||||
'data': encrypted_data,
|
||||
'original_size': len(model_data)
|
||||
}
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
pickle.dump(encrypted_payload, f)
|
||||
|
||||
logger.info(f"模型加密成功: {model_path} -> {output_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型加密失败: {str(e)}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def decrypt_model(encrypted_path: str, password: str):
|
||||
"""解密模型到内存"""
|
||||
try:
|
||||
# 读取加密文件
|
||||
with open(encrypted_path, 'rb') as f:
|
||||
encrypted_payload = pickle.load(f)
|
||||
|
||||
salt = encrypted_payload['salt']
|
||||
encrypted_data = encrypted_payload['data']
|
||||
|
||||
# 生成密钥
|
||||
key, _ = ModelEncryptor.generate_key(password, salt)
|
||||
fernet = Fernet(key)
|
||||
|
||||
# 解密数据
|
||||
decrypted_data = fernet.decrypt(encrypted_data)
|
||||
|
||||
# 保存到临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
|
||||
tmp.write(decrypted_data)
|
||||
temp_path = tmp.name
|
||||
|
||||
logger.info(f"模型解密成功: {encrypted_path}")
|
||||
return temp_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型解密失败: {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_encrypted(model_path: str):
|
||||
"""检查模型是否加密"""
|
||||
try:
|
||||
with open(model_path, 'rb') as f:
|
||||
# 尝试读取加密格式
|
||||
data = pickle.load(f)
|
||||
return isinstance(data, dict) and 'salt' in data and 'data' in data
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器,支持加密模型加载"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.models_dir = "models"
|
||||
os.makedirs(self.models_dir, exist_ok=True)
|
||||
|
||||
def load_model(self, model_config):
|
||||
"""加载模型(支持加密)"""
|
||||
model_path = model_config['path']
|
||||
encrypted = model_config.get('encrypted', False)
|
||||
encryption_key = model_config.get('encryption_key')
|
||||
|
||||
local_path = os.path.join(self.models_dir, os.path.basename(model_path))
|
||||
|
||||
# 下载模型(如果不存在)
|
||||
if not os.path.exists(local_path):
|
||||
if not self.download_model(model_config):
|
||||
return None
|
||||
|
||||
# 如果是加密模型,需要解密
|
||||
if encrypted and encryption_key:
|
||||
if ModelEncryptor.is_encrypted(local_path):
|
||||
decrypted_path = ModelEncryptor.decrypt_model(local_path, encryption_key)
|
||||
if decrypted_path:
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(decrypted_path).to(model_config['device'])
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.unlink(decrypted_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"加载解密模型失败: {str(e)}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"模型未加密或密钥错误: {local_path}")
|
||||
return None
|
||||
else:
|
||||
# 普通模型加载
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(local_path).to(model_config['device'])
|
||||
|
||||
# 应用配置
|
||||
if model_config.get('half', False) and 'cuda' in model_config['device']:
|
||||
model = model.half()
|
||||
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"加载模型失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def download_model(self, model_config):
|
||||
"""下载模型"""
|
||||
try:
|
||||
model_path = model_config['path']
|
||||
download_url = model_config.get('download_url')
|
||||
|
||||
if not download_url:
|
||||
logger.error(f"模型无下载地址: {model_path}")
|
||||
return False
|
||||
|
||||
local_path = os.path.join(self.models_dir, os.path.basename(model_path))
|
||||
|
||||
logger.info(f"下载模型: {download_url} -> {local_path}")
|
||||
|
||||
response = requests.get(download_url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
downloaded = 0
|
||||
|
||||
with open(local_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
f.write(chunk)
|
||||
|
||||
if total_size > 0:
|
||||
progress = downloaded * 100 // total_size
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"下载进度: {progress}%")
|
||||
|
||||
logger.info(f"模型下载完成: {local_path} ({downloaded} 字节)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载模型失败: {str(e)}")
|
||||
return False
|
||||
BIN
models/car.pt
BIN
models/car.pt
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 273 KiB |
|
|
@ -0,0 +1,134 @@
|
|||
# platform_utils.py
|
||||
"""
|
||||
跨平台工具模块,处理不同操作系统的兼容性问题
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import select
|
||||
from log import logger
|
||||
|
||||
|
||||
def is_windows():
|
||||
"""检查是否是Windows系统"""
|
||||
return os.name == 'nt' or sys.platform.startswith('win')
|
||||
|
||||
|
||||
def is_linux():
|
||||
"""检查是否是Linux系统"""
|
||||
return os.name == 'posix' and sys.platform.startswith('linux')
|
||||
|
||||
|
||||
def is_macos():
|
||||
"""检查是否是macOS系统"""
|
||||
return os.name == 'posix' and sys.platform.startswith('darwin')
|
||||
|
||||
|
||||
def set_non_blocking(fd):
|
||||
"""设置文件描述符为非阻塞模式(跨平台)"""
|
||||
try:
|
||||
if is_windows():
|
||||
# Windows系统:使用msvcrt或io模块
|
||||
import msvcrt
|
||||
import io
|
||||
if hasattr(fd, 'fileno'):
|
||||
# 对于文件对象
|
||||
handle = msvcrt.get_osfhandle(fd.fileno())
|
||||
# Windows上的非阻塞设置更复杂,这里使用io模块
|
||||
# 实际上,对于Windows上的管道,通常使用异步I/O
|
||||
# 这里我们采用简化方案:使用超时读取
|
||||
pass
|
||||
else:
|
||||
# Unix系统:使用fcntl
|
||||
import fcntl
|
||||
import os as unix_os
|
||||
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, fl | unix_os.O_NONBLOCK)
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning("fcntl模块在Windows上不可用")
|
||||
except Exception as e:
|
||||
logger.warning(f"设置非阻塞模式失败: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def read_with_timeout(fd, timeout=0.5):
|
||||
"""
|
||||
带超时的读取(跨平台)
|
||||
返回:(has_data, data) 或 (False, None)
|
||||
"""
|
||||
if is_windows():
|
||||
# Windows: 使用select检查socket,但对于普通文件/管道可能不支持
|
||||
# 这里我们使用简单的轮询
|
||||
import time
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
data = fd.read(1) # 尝试读取1个字节
|
||||
if data:
|
||||
# 读取剩余部分
|
||||
remaining = fd.read()
|
||||
if remaining:
|
||||
data += remaining
|
||||
return True, data
|
||||
except (IOError, OSError) as e:
|
||||
if "句柄无效" in str(e) or "bad file descriptor" in str(e):
|
||||
return False, None
|
||||
# 其他错误,继续等待
|
||||
pass
|
||||
time.sleep(0.01)
|
||||
return False, None
|
||||
else:
|
||||
# Unix: 使用select
|
||||
try:
|
||||
ready, _, _ = select.select([fd], [], [], timeout)
|
||||
if ready:
|
||||
data = fd.read()
|
||||
return True, data
|
||||
except (ValueError, OSError):
|
||||
# 可能文件描述符无效
|
||||
pass
|
||||
return False, None
|
||||
|
||||
|
||||
def safe_readline(fd, timeout=1.0):
|
||||
"""安全读取一行(跨平台)"""
|
||||
if is_windows():
|
||||
# Windows: 使用简单的读取,带超时
|
||||
import time
|
||||
line = b""
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# 尝试读取一个字节
|
||||
char = fd.read(1)
|
||||
if char:
|
||||
line += char
|
||||
if char == b'\n':
|
||||
return line.decode('utf-8', errors='ignore').strip()
|
||||
else:
|
||||
# 没有数据
|
||||
if line:
|
||||
# 有部分数据但没换行符
|
||||
return line.decode('utf-8', errors='ignore').strip()
|
||||
return None
|
||||
except (IOError, OSError) as e:
|
||||
if "句柄无效" in str(e) or "bad file descriptor" in str(e):
|
||||
return None
|
||||
time.sleep(0.01)
|
||||
|
||||
# 超时
|
||||
if line:
|
||||
return line.decode('utf-8', errors='ignore').strip()
|
||||
return None
|
||||
else:
|
||||
# Unix: 使用普通readline
|
||||
try:
|
||||
line = fd.readline()
|
||||
if line:
|
||||
return line.decode('utf-8', errors='ignore').strip()
|
||||
except (IOError, OSError):
|
||||
pass
|
||||
return None
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
# resource_monitor.py
|
||||
import psutil
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
|
||||
from log import logger
|
||||
from global_data import gd
|
||||
|
||||
try:
|
||||
import GPUtil
|
||||
|
||||
GPU_AVAILABLE = True
|
||||
except ImportError:
|
||||
GPU_AVAILABLE = False
|
||||
logger.warning("GPUtil 未安装,GPU监控不可用")
|
||||
|
||||
|
||||
class ResourceMonitor:
|
||||
"""系统资源监控器"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.running = True
|
||||
self.monitor_thread = None
|
||||
self.resource_history = []
|
||||
self.max_history = 100
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# 资源限制
|
||||
self.resource_limits = config.get('resource_limits', {
|
||||
'max_cpu_percent': 80,
|
||||
'max_memory_percent': 80,
|
||||
'max_gpu_memory_percent': 80,
|
||||
'max_concurrent_tasks': 5,
|
||||
'min_concurrent_tasks': 1
|
||||
})
|
||||
|
||||
# 动态调整参数
|
||||
self.current_max_tasks = self.resource_limits['max_concurrent_tasks']
|
||||
self.adjustment_factor = 1.0
|
||||
self.last_adjustment = time.time()
|
||||
|
||||
# GPU信息
|
||||
self.gpu_info = None
|
||||
if GPU_AVAILABLE and torch.cuda.is_available():
|
||||
self.init_gpu_monitor()
|
||||
|
||||
def init_gpu_monitor(self):
|
||||
"""初始化GPU监控"""
|
||||
try:
|
||||
gpus = GPUtil.getGPUs()
|
||||
self.gpu_info = []
|
||||
for gpu in gpus:
|
||||
self.gpu_info.append({
|
||||
'id': gpu.id,
|
||||
'name': gpu.name,
|
||||
'memory_total': gpu.memoryTotal,
|
||||
'driver_version': torch.version.cuda if torch.cuda.is_available() else 'Unknown'
|
||||
})
|
||||
logger.info(f"GPU监控已初始化: {len(gpus)}个GPU")
|
||||
except Exception as e:
|
||||
logger.error(f"GPU监控初始化失败: {str(e)}")
|
||||
self.gpu_info = None
|
||||
|
||||
def get_system_resources(self):
|
||||
"""获取系统资源使用情况"""
|
||||
resources = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'cpu_percent': psutil.cpu_percent(interval=0.1),
|
||||
'memory_percent': psutil.virtual_memory().percent,
|
||||
'memory_used': psutil.virtual_memory().used // (1024 * 1024), # MB
|
||||
'memory_total': psutil.virtual_memory().total // (1024 * 1024), # MB
|
||||
'disk_percent': psutil.disk_usage('/').percent,
|
||||
'network_io': psutil.net_io_counters()._asdict(),
|
||||
'process_count': len(psutil.pids()),
|
||||
}
|
||||
|
||||
# GPU信息
|
||||
if self.gpu_info is not None:
|
||||
gpus = GPUtil.getGPUs()
|
||||
gpu_data = []
|
||||
for i, gpu in enumerate(gpus):
|
||||
gpu_data.append({
|
||||
'id': gpu.id,
|
||||
'name': gpu.name,
|
||||
'load': gpu.load * 100,
|
||||
'memory_used': gpu.memoryUsed,
|
||||
'memory_total': gpu.memoryTotal,
|
||||
'memory_percent': (gpu.memoryUsed / gpu.memoryTotal) * 100,
|
||||
'temperature': gpu.temperature,
|
||||
'driver_version': self.gpu_info[i]['driver_version'] if i < len(self.gpu_info) else 'Unknown'
|
||||
})
|
||||
resources['gpus'] = gpu_data
|
||||
|
||||
return resources
|
||||
|
||||
def check_resource_limits(self, resources):
|
||||
"""检查资源是否超过限制"""
|
||||
violations = []
|
||||
|
||||
# CPU检查
|
||||
if resources['cpu_percent'] > self.resource_limits['max_cpu_percent']:
|
||||
violations.append(f"CPU使用率过高: {resources['cpu_percent']:.1f}%")
|
||||
|
||||
# 内存检查
|
||||
if resources['memory_percent'] > self.resource_limits['max_memory_percent']:
|
||||
violations.append(f"内存使用率过高: {resources['memory_percent']:.1f}%")
|
||||
|
||||
# GPU检查
|
||||
if 'gpus' in resources:
|
||||
for gpu in resources['gpus']:
|
||||
if gpu['memory_percent'] > self.resource_limits['max_gpu_memory_percent']:
|
||||
violations.append(f"GPU{gpu['id']}内存使用率过高: {gpu['memory_percent']:.1f}%")
|
||||
|
||||
return violations
|
||||
|
||||
def adjust_concurrent_tasks(self, resources):
|
||||
"""根据资源使用动态调整并发任务数"""
|
||||
current_time = time.time()
|
||||
|
||||
# 检查调整间隔
|
||||
if current_time - self.last_adjustment < self.resource_limits['check_interval']:
|
||||
return self.current_max_tasks
|
||||
|
||||
violations = self.check_resource_limits(resources)
|
||||
|
||||
if violations:
|
||||
# 资源使用过高,减少并发任务
|
||||
if self.current_max_tasks > self.resource_limits['min_concurrent_tasks']:
|
||||
self.current_max_tasks -= 1
|
||||
logger.warning(f"资源使用过高,减少并发任务数至: {self.current_max_tasks}")
|
||||
logger.warning(f"违规项: {', '.join(violations)}")
|
||||
else:
|
||||
# 资源使用正常,尝试增加并发任务
|
||||
safety_margin = 0.8 # 安全边际
|
||||
cpu_headroom = (self.resource_limits['max_cpu_percent'] - resources['cpu_percent']) / 100
|
||||
memory_headroom = (self.resource_limits['max_memory_percent'] - resources['memory_percent']) / 100
|
||||
|
||||
# 考虑GPU内存
|
||||
gpu_headroom = 1.0
|
||||
if 'gpus' in resources:
|
||||
gpu_headrooms = []
|
||||
for gpu in resources['gpus']:
|
||||
gpu_headrooms.append((self.resource_limits['max_gpu_memory_percent'] - gpu['memory_percent']) / 100)
|
||||
gpu_headroom = min(gpu_headrooms) if gpu_headrooms else 1.0
|
||||
|
||||
# 计算可用资源比例
|
||||
available_resources = min(cpu_headroom, memory_headroom, gpu_headroom)
|
||||
|
||||
# 根据可用资源调整任务数
|
||||
if available_resources > 0.3: # 有30%以上余量
|
||||
if self.current_max_tasks < self.resource_limits['max_concurrent_tasks']:
|
||||
self.current_max_tasks += 1
|
||||
logger.info(f"资源充足,增加并发任务数至: {self.current_max_tasks}")
|
||||
elif available_resources < 0.1: # 余量不足10%
|
||||
if self.current_max_tasks > self.resource_limits['min_concurrent_tasks']:
|
||||
self.current_max_tasks -= 1
|
||||
logger.warning(f"资源紧张,减少并发任务数至: {self.current_max_tasks}")
|
||||
|
||||
self.last_adjustment = current_time
|
||||
return self.current_max_tasks
|
||||
|
||||
def monitor_loop(self):
|
||||
"""监控循环"""
|
||||
logger.info("资源监控线程启动")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 获取资源使用情况
|
||||
resources = self.get_system_resources()
|
||||
|
||||
with self.lock:
|
||||
# 保存历史记录
|
||||
self.resource_history.append(resources)
|
||||
if len(self.resource_history) > self.max_history:
|
||||
self.resource_history.pop(0)
|
||||
|
||||
# 动态调整并发任务数
|
||||
self.adjust_concurrent_tasks(resources)
|
||||
|
||||
# 更新全局数据
|
||||
gd.set_value('system_resources', resources)
|
||||
gd.set_value('max_concurrent_tasks', self.current_max_tasks)
|
||||
|
||||
# 记录资源使用情况(每分钟一次)
|
||||
if len(self.resource_history) % 12 == 0: # 5秒 * 12 = 60秒
|
||||
self.log_resource_summary(resources)
|
||||
|
||||
time.sleep(5) # 5秒检查一次
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"资源监控异常: {str(e)}")
|
||||
time.sleep(10)
|
||||
|
||||
logger.info("资源监控线程停止")
|
||||
|
||||
def log_resource_summary(self, resources):
|
||||
"""记录资源使用摘要"""
|
||||
summary = [
|
||||
f"CPU: {resources['cpu_percent']:.1f}%",
|
||||
f"内存: {resources['memory_percent']:.1f}% ({resources['memory_used']}/{resources['memory_total']}MB)",
|
||||
]
|
||||
|
||||
if 'gpus' in resources:
|
||||
for gpu in resources['gpus']:
|
||||
summary.append(f"GPU{gpu['id']}: {gpu['load']:.1f}%负载, {gpu['memory_percent']:.1f}%内存")
|
||||
|
||||
summary.append(f"并发任务限制: {self.current_max_tasks}")
|
||||
|
||||
logger.info("资源使用摘要: " + " | ".join(summary))
|
||||
|
||||
def start(self):
|
||||
"""启动监控"""
|
||||
self.running = True
|
||||
self.monitor_thread = threading.Thread(target=self.monitor_loop, daemon=True)
|
||||
self.monitor_thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""停止监控"""
|
||||
self.running = False
|
||||
if self.monitor_thread and self.monitor_thread.is_alive():
|
||||
self.monitor_thread.join(5.0)
|
||||
|
||||
def get_resource_history(self, count=10):
|
||||
"""获取最近资源历史"""
|
||||
with self.lock:
|
||||
return self.resource_history[-count:] if self.resource_history else []
|
||||
|
||||
def get_current_resources(self):
|
||||
"""获取当前资源使用情况"""
|
||||
with self.lock:
|
||||
return self.resource_history[-1] if self.resource_history else None
|
||||
|
||||
|
||||
# 全局资源监控器实例
|
||||
resource_monitor = None
|
||||
|
||||
|
||||
def init_resource_monitor(config):
|
||||
"""初始化资源监控器"""
|
||||
global resource_monitor
|
||||
resource_monitor = ResourceMonitor(config)
|
||||
resource_monitor.start()
|
||||
return resource_monitor
|
||||
795
server.py
795
server.py
|
|
@ -1,127 +1,531 @@
|
|||
# server.py
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from flask import Flask, jsonify, request, render_template
|
||||
from flask_socketio import SocketIO
|
||||
from flask_cors import CORS
|
||||
from config import get_default_config
|
||||
from detectionThread import DetectionThread
|
||||
import global_data as gd
|
||||
from mandatory_model_crypto import MandatoryModelEncryptor
|
||||
from task_manager import task_manager # 导入任务管理器
|
||||
from global_data import gd
|
||||
from log import logger
|
||||
from mapping_cn import class_mapping_cn
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# Flask初始化
|
||||
app = Flask(__name__, static_url_path='/static')
|
||||
CORS(app)
|
||||
socketio = SocketIO(app,
|
||||
cors_allowed_origins="*",
|
||||
async_mode='threading',
|
||||
allow_unsafe_werkzeug=True,
|
||||
max_http_buffer_size=5 * 1024 * 1024) # 增加WebSocket缓冲区
|
||||
max_http_buffer_size=5 * 1024 * 1024)
|
||||
|
||||
_initialized = False
|
||||
|
||||
|
||||
@app.before_request
|
||||
def initialize_once():
|
||||
global _initialized
|
||||
if not _initialized:
|
||||
with app.app_context():
|
||||
gd.set_value('task_manager', task_manager)
|
||||
logger.info("任务管理器初始化完成")
|
||||
_initialized = True
|
||||
|
||||
|
||||
# ======================= Flask路由 =======================
|
||||
@app.route('/', methods=['GET'])
|
||||
def main():
|
||||
@app.route('/')
|
||||
def task_management():
|
||||
"""任务管理页面"""
|
||||
return render_template("task_management.html")
|
||||
|
||||
|
||||
@app.route('/video_player')
|
||||
def video_player():
|
||||
"""视频播放页面"""
|
||||
return render_template("flv2.html")
|
||||
|
||||
|
||||
@app.route('/start_detection', methods=['POST'])
|
||||
def start_detection():
|
||||
detection_active = gd.get_value('detection_active')
|
||||
if detection_active:
|
||||
return jsonify({"status": "error", "message": "检测已在运行"}), 400
|
||||
@app.route('/api/tasks/create', methods=['POST'])
|
||||
def create_task():
|
||||
"""创建新任务 - 强制模型加密和密钥验证"""
|
||||
try:
|
||||
config = get_default_config()
|
||||
config['socketIO'] = socketio
|
||||
|
||||
config = get_default_config()
|
||||
config['socketIO'] = socketio
|
||||
if not request.json:
|
||||
return jsonify({"status": "error", "message": "请求体不能为空"}), 400
|
||||
|
||||
# 配置更新逻辑
|
||||
if request.json:
|
||||
# 更新RTMP地址
|
||||
if 'rtmp_url' in request.json:
|
||||
config['rtmp']['url'] = request.json['rtmp_url']
|
||||
# 更新推流地址
|
||||
if 'push_url' in request.json and request.json['push_url'] is not None:
|
||||
config['push']['url'] = request.json['push_url']
|
||||
# minio文件夹名称
|
||||
if 'taskname' in request.json:
|
||||
config['task']['taskname'] = request.json['taskname']
|
||||
# 标签
|
||||
data = request.json
|
||||
logger.info(f"收到创建任务请求: {data.get('taskname', '未命名')}")
|
||||
|
||||
if 'tag' in request.json and request.json['tag'] is not {}:
|
||||
config['task']['tag'] = request.json['tag']
|
||||
# 验证必须的参数
|
||||
if 'rtmp_url' not in data:
|
||||
return jsonify({"status": "error", "message": "必须提供rtmp_url"}), 400
|
||||
|
||||
if 'models' not in data or not isinstance(data['models'], list):
|
||||
return jsonify({"status": "error", "message": "必须提供models列表"}), 400
|
||||
|
||||
# 检查模型数量
|
||||
if len(data['models']) == 0:
|
||||
return jsonify({"status": "error", "message": "models列表不能为空"}), 400
|
||||
|
||||
# 更新配置
|
||||
config['rtmp']['url'] = data['rtmp_url']
|
||||
|
||||
if 'push_url' in data and data['push_url']:
|
||||
config['push']['url'] = data['push_url']
|
||||
|
||||
if 'taskname' in data:
|
||||
config['task']['taskname'] = data['taskname']
|
||||
else:
|
||||
config['task']['tag'] = class_mapping_cn
|
||||
config['task']['taskname'] = f"task_{int(time.time())}"
|
||||
|
||||
if 'taskid' in request.json:
|
||||
config['task']['taskid'] = request.json['taskid']
|
||||
# 性能参数调整
|
||||
if 'imgsz' in request.json:
|
||||
config['predict']['imgsz'] = max(128, min(1920, request.json['imgsz']))
|
||||
if 'frame_skip' in request.json:
|
||||
config['predict']['frame_skip'] = request.json['frame_skip']
|
||||
if 'model_name' in request.json:
|
||||
config['model']['path'] = request.json['model_name']
|
||||
if 'AlgoId' in request.json:
|
||||
config['task']['aiid'] = request.json['AlgoId']
|
||||
if 'device' in request.json:
|
||||
if request.json['device'] == "cuda:0" or "cpu":
|
||||
config['predict']['device'] = request.json['device']
|
||||
else:
|
||||
config['predict']['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
# 创建并启动线程
|
||||
detection_thread = DetectionThread(config)
|
||||
gd.set_value('detection_thread', detection_thread)
|
||||
detection_thread.start()
|
||||
gd.set_value('detection_active', True)
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"message": "目标检测已启动"
|
||||
})
|
||||
if 'AlgoId' in data:
|
||||
config['task']['aiid'] = data['AlgoId']
|
||||
|
||||
# 处理多模型配置 - 强制加密验证
|
||||
config['models'] = []
|
||||
encryption_checker = MandatoryModelEncryptor()
|
||||
|
||||
for i, model_data in enumerate(data['models']):
|
||||
# 必须提供加密密钥
|
||||
encryption_key = model_data.get('encryption_key')
|
||||
if not encryption_key:
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"模型 {i} ({model_data.get('path', 'unknown')}) 必须提供encryption_key"
|
||||
}), 400
|
||||
|
||||
model_path = model_data.get('path', f'model_{i}.pt')
|
||||
model_name = os.path.basename(model_path).split('.')[0]
|
||||
|
||||
# 检查模型文件是否加密(如果是本地文件)
|
||||
local_model_path = os.path.join(os.path.basename(model_path))
|
||||
# 如果本地文件存在,验证加密格式
|
||||
if os.path.exists(local_model_path):
|
||||
if not encryption_checker.is_properly_encrypted(local_model_path):
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"模型 {i} ({model_name}) 未正确加密"
|
||||
}), 400
|
||||
|
||||
# 构建模型配置
|
||||
model_config = {
|
||||
'path': model_path,
|
||||
'encryption_key': encryption_key, # 必须提供
|
||||
'encrypted': True, # 强制加密
|
||||
'tags': model_data.get('tags', {}),
|
||||
'conf_thres': float(model_data.get('conf_thres', 0.25)),
|
||||
'iou_thres': float(model_data.get('iou_thres', 0.45)),
|
||||
'imgsz': max(128, min(1920, int(model_data.get('imgsz', 640)))),
|
||||
'color': model_data.get('color'),
|
||||
'line_width': int(model_data.get('line_width', 1)),
|
||||
'device': model_data.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu'),
|
||||
'half': model_data.get('half', True),
|
||||
'enabled': model_data.get('enabled', True),
|
||||
'download_url': model_data.get('download_url') # 可选的下载地址
|
||||
}
|
||||
|
||||
config['models'].append(model_config)
|
||||
logger.info(f"添加加密模型 {i}: {model_name}")
|
||||
|
||||
# 创建任务
|
||||
logger.info(f"开始创建任务,包含 {len(config['models'])} 个加密模型...")
|
||||
|
||||
try:
|
||||
task_id = task_manager.create_task(config, socketio)
|
||||
logger.info(f"任务创建成功,ID: {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"任务创建失败: {str(e)}")
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"任务创建失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
# 启动任务
|
||||
logger.info(f"启动任务 {task_id}...")
|
||||
success = task_manager.start_task(task_id)
|
||||
|
||||
if success:
|
||||
logger.info(f"任务启动成功: {task_id}")
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"message": "任务创建并启动成功",
|
||||
"task_id": task_id,
|
||||
"models_count": len(config['models']),
|
||||
"encryption_required": True
|
||||
})
|
||||
else:
|
||||
logger.error(f"任务启动失败: {task_id}")
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": "任务创建成功但启动失败",
|
||||
"task_id": task_id
|
||||
}), 500
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建任务失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"创建任务失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/stop_detection', methods=['POST'])
|
||||
def stop_detection():
|
||||
detection_active = gd.get_value('detection_active')
|
||||
detection_thread = gd.get_value('detection_thread')
|
||||
if not detection_active or not detection_thread:
|
||||
return jsonify({"status": "error", "message": "检测未运行"}), 400
|
||||
@app.route('/api/system/resources', methods=['GET'])
|
||||
def get_system_resources():
|
||||
"""获取系统资源使用情况"""
|
||||
try:
|
||||
resources = gd.get_value('system_resources')
|
||||
max_tasks = gd.get_value('max_concurrent_tasks', 5)
|
||||
|
||||
# 停止线程
|
||||
detection_thread.stop()
|
||||
if not resources:
|
||||
# 实时获取资源
|
||||
import psutil
|
||||
resources = {
|
||||
'cpu_percent': psutil.cpu_percent(),
|
||||
'memory_percent': psutil.virtual_memory().percent,
|
||||
'memory_used': psutil.virtual_memory().used // (1024 * 1024),
|
||||
'memory_total': psutil.virtual_memory().total // (1024 * 1024),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 等待不超过3秒
|
||||
detection_thread.join(3.0)
|
||||
# 获取任务统计
|
||||
active_tasks = task_manager.get_active_tasks_count()
|
||||
total_tasks = len(task_manager.tasks)
|
||||
|
||||
if detection_thread.is_alive():
|
||||
logger.warning("检测线程未在规定时间停止")
|
||||
else:
|
||||
logger.info("检测线程已停止")
|
||||
gd.set_value('detection_active', False)
|
||||
gd.set_value('detection_thread', None)
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"message": "目标检测已停止"
|
||||
})
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"data": {
|
||||
"resources": resources,
|
||||
"tasks": {
|
||||
"active": active_tasks,
|
||||
"total": total_tasks,
|
||||
"max_concurrent": max_tasks
|
||||
}
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取系统资源失败: {str(e)}")
|
||||
return jsonify({
|
||||
"status": "error",
|
||||
"message": f"获取系统资源失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/status', methods=['GET'])
|
||||
def get_status():
|
||||
detection_active = gd.get_value('detection_active')
|
||||
detection_thread = gd.get_value('detection_thread')
|
||||
if detection_active and detection_thread:
|
||||
status = {
|
||||
"active": True,
|
||||
"fps": round(detection_thread.last_fps, 1), # 使用稳定FPS值
|
||||
"frame_count": detection_thread.frame_count,
|
||||
"detections_count": detection_thread.detections_count,
|
||||
"rtmp_url": detection_thread.rtmp_url,
|
||||
"reconnect_attempts": detection_thread.reconnect_attempts
|
||||
@app.route('/api/models/encrypt', methods=['POST'])
def encrypt_model():
    """Encrypt a model file (encryption is mandatory).

    Expects JSON with model_path, output_path, password and download_url.
    Downloads the model first when the local file is missing, then encrypts
    it with MandatoryModelEncryptor.
    """
    try:
        config = get_default_config()
        data = request.json
        model_path = data.get('model_path')
        output_path = data.get('output_path')
        password = data.get('password')
        download_url = data.get('download_url')
        if not all([model_path, output_path, password, download_url]):
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # Resolve both paths relative to the configured model directory.
        local_path = os.path.join(config['model_path'], model_path)
        output_path = os.path.join(config['model_path'], output_path)

        # BUGFIX: the existence check must use the resolved local_path rather
        # than the raw relative model_path, otherwise the file looks "missing"
        # whenever the working directory differs from the model directory and
        # the subsequent encrypt_model(local_path, ...) call would fail.
        if not os.path.exists(local_path):
            from model_crypto import ModelManager
            model_d = ModelManager(data)
            down_status = model_d.download_model({"path": model_path, "download_url": download_url})
            if not down_status:
                return jsonify({"status": "error", "message": f"模型文件不存在: {model_path}"}), 400

        # Use the mandatory encryptor so unencrypted output is rejected.
        from mandatory_model_crypto import MandatoryModelEncryptor
        encryptor = MandatoryModelEncryptor()

        result = encryptor.encrypt_model(local_path, output_path, password, require_encryption=True)

        if result['success']:
            return jsonify({
                "status": "success",
                "message": f"模型加密成功: {output_path}",
                "data": {
                    "model_hash": result.get('model_hash'),
                    "key_hash": result.get('key_hash'),
                    "output_path": result.get('output_path')
                }
            })
        else:
            return jsonify({
                "status": "error",
                "message": f"模型加密失败: {result.get('error', '未知错误')}"
            }), 500

    except Exception as e:
        logger.error(f"加密模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"加密模型失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/models/verify_key', methods=['POST'])
def verify_model_key():
    """Check that a client-supplied key can decrypt an encrypted model file."""
    try:
        payload = request.json
        model_path = payload.get('model_path')
        encryption_key = payload.get('encryption_key')

        if not model_path or not encryption_key:
            return jsonify({"status": "error", "message": "缺少必要参数"}), 400

        # Only files inside the encrypted_models directory are considered.
        full_path = os.path.join('encrypted_models', os.path.basename(model_path))
        if not os.path.exists(full_path):
            return jsonify({"status": "error", "message": f"模型文件不存在: {full_path}"}), 400

        from mandatory_model_crypto import MandatoryModelEncryptor
        encryptor = MandatoryModelEncryptor()

        # Reject files that were never produced by the mandatory encryptor.
        if not encryptor.is_properly_encrypted(full_path):
            return jsonify({
                "status": "error",
                "message": "模型文件未正确加密",
                "valid": False
            }), 400

        verify_result = encryptor.decrypt_model(full_path, encryption_key, verify_key=True)

        if not verify_result['success']:
            return jsonify({
                "status": "error",
                "message": f"密钥验证失败: {verify_result.get('error', '未知错误')}",
                "data": {
                    "valid": False,
                    "error": verify_result.get('error', '未知错误')
                }
            }), 400

        return jsonify({
            "status": "success",
            "message": "密钥验证成功",
            "data": {
                "valid": True,
                "model_hash": verify_result.get('model_hash', '')[:16],
                "model_size": verify_result.get('original_size', 0)
            }
        })

    except Exception as e:
        logger.error(f"验证密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"验证密钥失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/models/generate_key', methods=['POST'])
def generate_secure_encryption_key():
    """Generate and return a fresh secure model-encryption key."""
    try:
        from mandatory_model_crypto import MandatoryModelEncryptor

        key_info = MandatoryModelEncryptor().generate_secure_key()

        return jsonify({
            "status": "success",
            "message": "加密密钥生成成功",
            "data": {
                "key": key_info['key'],
                "key_hash": key_info['key_hash'],
                "short_hash": key_info['short_hash'],
            }
        })

    except Exception as e:
        logger.error(f"生成加密密钥失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"生成加密密钥失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
# server.py 中的 stop_detection 路由修改
|
||||
|
||||
@app.route('/api/tasks/<task_id>/stop', methods=['POST'])
def stop_task(task_id):
    """Stop the task identified by task_id via the task manager."""
    try:
        logger.info(f"接收到停止任务请求: {task_id}")
        if task_manager.stop_task(task_id):
            logger.info(f"任务停止成功: {task_id}")
            return jsonify({
                "status": "success",
                "message": f"任务停止成功: {task_id}"
            })
        logger.warning(f"停止任务失败: {task_id}")
        return jsonify({
            "status": "error",
            "message": f"停止任务失败: {task_id}"
        }), 500
    except Exception as e:
        logger.error(f"停止任务异常: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "error",
            "message": f"停止任务异常: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/tasks/<task_id>/status', methods=['GET'])
def get_task_status(task_id):
    """Return the status of one task, enriched with per-model details.

    NOTE(review): leftover lines from the legacy single-task /status endpoint
    (torch gpu_memory bookkeeping, `return jsonify(status)` and
    `return jsonify({"active": False})`) were interleaved here by a bad merge
    and returned before enhanced_status was ever used; they have been removed.
    """
    status = task_manager.get_task_status(task_id)
    if status:
        # Enrich the response with the model list alongside the core fields.
        enhanced_status = {
            'task_id': status['task_id'],
            'status': status['status'],
            'config': status['config'],
            'models': status.get('models', []),  # per-model summaries
            'stats': status['stats'],
            'created_at': status['created_at']
        }
        return jsonify({
            "status": "success",
            "data": enhanced_status
        })

    return jsonify({
        "status": "error",
        "message": f"任务不存在: {task_id}"
    }), 404
|
||||
|
||||
|
||||
# WebSocket事件
|
||||
@app.route('/api/tasks', methods=['GET'])
def get_all_tasks():
    """List every task with its config summary, models and stats."""
    tasks = task_manager.get_all_tasks()

    enhanced_tasks = []
    for task in tasks:
        cfg = task['config']
        enhanced_tasks.append({
            'task_id': task['task_id'],
            'status': task['status'],
            'config': {
                'rtmp_url': cfg['rtmp_url'],
                'taskname': cfg['taskname'],
                'push_url': cfg.get('push_url', ''),
                'enable_push': cfg.get('enable_push', False)
            },
            'models': task.get('models', []),  # per-model summaries
            'stats': task['stats'],
            'created_at': task['created_at']
        })

    return jsonify({
        "status": "success",
        "data": {
            "tasks": enhanced_tasks,
            "total": len(tasks),
            "active": task_manager.get_active_tasks_count(),
            "models_count": sum(len(t.get('models', [])) for t in enhanced_tasks)
        }
    })
|
||||
|
||||
|
||||
@app.route('/api/tasks/<task_id>/cleanup', methods=['POST'])
def cleanup_task(task_id):
    """Release all resources held by the given task."""
    if task_manager.cleanup_task(task_id):
        return jsonify({
            "status": "success",
            "message": f"任务资源已清理: {task_id}"
        })
    return jsonify({
        "status": "error",
        "message": f"清理任务失败: {task_id}"
    }), 500
|
||||
|
||||
|
||||
@app.route('/api/tasks/cleanup_all', methods=['POST'])
def cleanup_all_tasks():
    """Release the resources of every known task."""
    task_manager.cleanup_all_tasks()
    return jsonify({
        "status": "success",
        "message": "所有任务资源已清理"
    })
|
||||
|
||||
|
||||
@app.route('/api/system/status', methods=['GET'])
def get_system_status():
    """Report host CPU/memory/disk usage, task counts and (best-effort) GPU stats."""
    try:
        import psutil

        # Core host metrics plus task-manager counters.
        system_info = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "disk_percent": psutil.disk_usage('/').percent,
            "active_tasks": task_manager.get_active_tasks_count(),
            "total_tasks": len(task_manager.tasks)
        }

        # GPU info (if available): prefer GPUtil, fall back to torch.
        try:
            import GPUtil
            gpus = GPUtil.getGPUs()
            gpu_info = []
            for gpu in gpus:
                gpu_info.append({
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load * 100,  # GPUtil reports load in [0, 1]
                    "memory_used": gpu.memoryUsed,
                    "memory_total": gpu.memoryTotal,
                    "temperature": gpu.temperature
                })
            system_info["gpus"] = gpu_info
        except ImportError:
            # GPUtil not installed: fall back to torch's CUDA introspection.
            if torch.cuda.is_available():
                gpu_info = []
                for i in range(torch.cuda.device_count()):
                    gpu_info.append({
                        "id": i,
                        "name": torch.cuda.get_device_name(i),
                        "memory_used": torch.cuda.memory_allocated(i) / 1024 ** 2,  # MiB
                        "memory_total": torch.cuda.get_device_properties(i).total_memory / 1024 ** 2
                    })
                system_info["gpus"] = gpu_info

        return jsonify({
            "status": "success",
            "data": system_info
        })
    except Exception as e:
        logger.error(f"获取系统状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取系统状态失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
# WebSocket事件 - 按任务ID发送
|
||||
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO client connection."""
    logger.info(f"Socket客户端已连接: {request.sid}")
|
||||
|
|
@ -130,3 +534,214 @@ def handle_connect():
|
|||
@socketio.on('disconnect')
def handle_disconnect():
    """Log each Socket.IO client disconnection."""
    logger.info(f"Socket客户端断开: {request.sid}")
|
||||
|
||||
|
||||
@socketio.on('subscribe_task')
def handle_subscribe_task(data):
    """Register a client's interest in one task's event stream."""
    task_id = data.get('task_id')
    if not task_id:
        return {"status": "error", "message": "需要提供task_id"}

    # Subscription bookkeeping could be stored here; for now just log it.
    logger.info(f"客户端 {request.sid} 订阅任务: {task_id}")
    return {"status": "subscribed", "task_id": task_id}
|
||||
|
||||
|
||||
@app.route('/api/system/resource_limits', methods=['GET', 'POST'])
def manage_resource_limits():
    """Read (GET) or update (POST) the resource limits enforced by the monitor."""
    if request.method == 'GET':
        # Report the monitor's current limits.
        resource_monitor = gd.get_value('resource_monitor')
        if resource_monitor:
            return jsonify({
                "status": "success",
                "data": resource_monitor.resource_limits
            })
        else:
            return jsonify({
                "status": "error",
                "message": "资源监控器未初始化"
            }), 500

    elif request.method == 'POST':
        # Apply new limits to both the monitor and the task manager.
        try:
            data = request.json
            resource_monitor = gd.get_value('resource_monitor')

            if resource_monitor and data:
                # Update the monitor's limits in place...
                resource_monitor.resource_limits.update(data)

                # ...and keep the task manager's copy in sync.
                task_manager.resource_limits.update(data)

                logger.info(f"更新资源限制: {data}")
                return jsonify({
                    "status": "success",
                    "message": "资源限制更新成功"
                })
            else:
                return jsonify({
                    "status": "error",
                    "message": "资源监控器未初始化或请求数据无效"
                }), 400
        except Exception as e:
            logger.error(f"更新资源限制失败: {str(e)}")
            return jsonify({
                "status": "error",
                "message": f"更新失败: {str(e)}"
            }), 500
|
||||
|
||||
|
||||
@app.route('/api/tasks/<task_id>/models', methods=['GET'])
def get_task_models(task_id):
    """Return the per-model configuration of one task (keys are not exposed)."""
    try:
        task = task_manager.get_task_status(task_id)
        if not task:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        # Pull the raw model configs straight from the task manager's store.
        models_info = []
        if task_id in task_manager.tasks:
            task_info = task_manager.tasks[task_id]
            config = task_info.get('config', {})
            models_config = config.get('models', [])

            for i, model_config in enumerate(models_config):
                models_info.append({
                    'id': i,
                    # Model name = file name without its extension.
                    'name': os.path.basename(model_config.get('path', '')).split('.')[0],
                    'path': model_config.get('path'),
                    'conf_thres': model_config.get('conf_thres'),
                    'tags': model_config.get('tags', {}),
                    'color': model_config.get('color'),
                    'enabled': model_config.get('enabled', True)
                })

        return jsonify({
            "status": "success",
            "data": {
                "task_id": task_id,
                "models": models_info
            }
        })
    except Exception as e:
        logger.error(f"获取任务模型失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
|
||||
# server.py 添加以下路由
|
||||
|
||||
@app.route('/api/tasks/<task_id>/stream/status', methods=['GET'])
def get_task_stream_status(task_id):
    """Combine a task's status with its push-stream telemetry."""
    try:
        from task_stream_manager import task_stream_manager

        # Make sure the task exists before asking about its stream.
        task_status = task_manager.get_task_status(task_id)
        if not task_status:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        # Per-task streamer telemetry ({} when no streamer exists).
        stream_info = task_stream_manager.get_all_task_streams_info().get(task_id, {})

        # Merge task and stream views into one payload.
        result = {
            "task_id": task_id,
            "task_status": task_status['status'],
            "stream_enabled": task_status['config'].get('enable_push', False),
            "stream_info": stream_info,
            "push_url": task_status['config'].get('push_url', '')
        }

        return jsonify({
            "status": "success",
            "data": result
        })
    except Exception as e:
        logger.error(f"获取任务推流状态失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/tasks/<task_id>/stream/restart', methods=['POST'])
def restart_task_stream(task_id):
    """Force-restart the FFmpeg streamer attached to a task."""
    try:
        from task_stream_manager import task_stream_manager

        if task_id not in task_manager.tasks:
            return jsonify({
                "status": "error",
                "message": f"任务不存在: {task_id}"
            }), 404

        # NOTE: reaches into a private helper of the stream manager.
        restarted = task_stream_manager._restart_task_streamer(task_id)

        if restarted:
            return jsonify({
                "status": "success",
                "message": f"任务推流重启成功: {task_id}"
            })
        return jsonify({
            "status": "error",
            "message": f"重启失败: {task_id}"
        }), 500
    except Exception as e:
        logger.error(f"重启任务推流失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"重启失败: {str(e)}"
        }), 500
|
||||
|
||||
|
||||
@app.route('/api/system/streams/info', methods=['GET'])
def get_all_streams_info():
    """Summarize the push-stream state of every task."""
    try:
        from task_stream_manager import task_stream_manager

        streams_info = task_stream_manager.get_all_task_streams_info()
        running = [info for info in streams_info.values() if info.get('running', False)]

        return jsonify({
            "status": "success",
            "data": {
                "total_streams": len(streams_info),
                "active_streams": len(running),
                "streams": streams_info
            }
        })
    except Exception as e:
        logger.error(f"获取所有推流信息失败: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"获取失败: {str(e)}"
        }), 500
|
||||
|
||||
# 初始化函数,可在主程序中调用
|
||||
def init_app():
    """Publish the task manager into global state and return the Flask app."""
    with app.app_context():
        gd.set_value('task_manager', task_manager)
        logger.info("任务管理器初始化完成")
    return app
|
||||
|
|
|
|||
|
|
@ -0,0 +1,384 @@
|
|||
# task_manager.py
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime
|
||||
from log import logger
|
||||
from global_data import gd
|
||||
from resource_monitor import resource_monitor
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""任务管理器,支持多任务并行"""
|
||||
|
||||
def __init__(self, resource_config=None):
|
||||
self.tasks = {}
|
||||
self.task_lock = threading.Lock()
|
||||
|
||||
# 资源监控配置
|
||||
if resource_config:
|
||||
self.resource_limits = resource_config.get('resource_limits', {
|
||||
'max_concurrent_tasks': 5,
|
||||
'min_concurrent_tasks': 1
|
||||
})
|
||||
else:
|
||||
self.resource_limits = {
|
||||
'max_concurrent_tasks': 5,
|
||||
'min_concurrent_tasks': 1
|
||||
}
|
||||
|
||||
# 当前最大并发数(由资源监控动态调整)
|
||||
self.current_max_tasks = self.resource_limits['max_concurrent_tasks']
|
||||
|
||||
# 初始化资源监控
|
||||
self.init_resource_monitor(resource_config)
|
||||
|
||||
def init_resource_monitor(self, config):
|
||||
"""初始化资源监控"""
|
||||
try:
|
||||
from resource_monitor import init_resource_monitor
|
||||
if config:
|
||||
init_resource_monitor(config)
|
||||
logger.info("资源监控器初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化资源监控器失败: {str(e)}")
|
||||
|
||||
def get_current_max_tasks(self):
|
||||
"""获取当前最大并发任务数(考虑资源限制)"""
|
||||
try:
|
||||
# 从资源监控器获取当前限制
|
||||
current_limit = gd.get_value('max_concurrent_tasks', self.resource_limits['max_concurrent_tasks'])
|
||||
|
||||
# 确保在最小和最大范围内
|
||||
self.current_max_tasks = max(
|
||||
self.resource_limits['min_concurrent_tasks'],
|
||||
min(current_limit, self.resource_limits['max_concurrent_tasks'])
|
||||
)
|
||||
|
||||
return self.current_max_tasks
|
||||
except:
|
||||
return self.resource_limits['max_concurrent_tasks']
|
||||
|
||||
    def can_create_task(self):
        """Decide whether a new task may be created given current load.

        Returns True when below the dynamic concurrency limit, or when at the
        limit but CPU/memory/GPU usage are all under their thresholds; returns
        False (conservative) when telemetry is missing or an error occurs.
        """
        try:
            active_count = self.get_active_tasks_count()
            max_tasks = self.get_current_max_tasks()

            logger.info(f"当前活动任务: {active_count}/{max_tasks}")

            # At the cap: allow an overflow task only if resources are idle.
            if active_count >= max_tasks:
                resources = gd.get_value('system_resources')
                if resources:
                    cpu_usage = resources.get('cpu_percent', 0)
                    memory_usage = resources.get('memory_percent', 0)
                    gpu_usage = 0

                    # Use the busiest GPU as the representative figure.
                    if 'gpus' in resources and resources['gpus']:
                        gpu_usage = max(gpu['load'] for gpu in resources['gpus'])

                    # Utilization thresholds for granting an overflow slot.
                    cpu_threshold = 70  # 70%
                    memory_threshold = 75  # 75%
                    gpu_threshold = 80  # 80%

                    # All three must be under threshold to exceed the cap.
                    if (cpu_usage < cpu_threshold and
                            memory_usage < memory_threshold and
                            gpu_usage < gpu_threshold):
                        logger.info(
                            f"资源充足,允许创建额外任务 (CPU: {cpu_usage:.1f}%, 内存: {memory_usage:.1f}%, GPU: {gpu_usage:.1f}%)")
                        return True
                    else:
                        logger.warning(
                            f"资源紧张,拒绝创建新任务 (CPU: {cpu_usage:.1f}%, 内存: {memory_usage:.1f}%, GPU: {gpu_usage:.1f}%)")
                        return False
                else:
                    # No telemetry available: refuse rather than risk overload.
                    logger.warning("无法获取资源数据,保守拒绝创建新任务")
                    return False

            return True

        except Exception as e:
            logger.error(f"检查任务创建条件失败: {str(e)}")
            return False  # fail closed on errors
|
||||
|
||||
def create_task(self, config, socketio):
|
||||
"""创建新任务 - 强制加密验证"""
|
||||
try:
|
||||
with self.task_lock:
|
||||
# 检查资源限制
|
||||
if not self.can_create_task():
|
||||
raise Exception(f"达到资源限制,当前最大并发任务数: {self.get_current_max_tasks()}")
|
||||
|
||||
# 生成唯一任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 验证配置格式 - 必须是多模型
|
||||
if 'models' not in config or not isinstance(config['models'], list):
|
||||
raise Exception("配置必须包含models列表")
|
||||
|
||||
# 验证所有模型都有加密密钥
|
||||
for i, model_cfg in enumerate(config['models']):
|
||||
if not model_cfg.get('encryption_key'):
|
||||
raise Exception(f"模型 {i} ({model_cfg.get('path', 'unknown')}) 必须提供加密密钥")
|
||||
|
||||
# 准备任务信息
|
||||
task_info = {
|
||||
'task_id': task_id,
|
||||
'config': config.copy(),
|
||||
'status': 'creating',
|
||||
'created_at': datetime.now().isoformat(),
|
||||
'thread': None,
|
||||
'socketio': socketio,
|
||||
'stats': {
|
||||
'total_frames': 0,
|
||||
'detections': 0,
|
||||
'avg_fps': 0,
|
||||
'start_time': time.time(),
|
||||
'models_loaded': len(config['models']),
|
||||
'encrypted_models': len(config['models']), # 所有模型都加密
|
||||
'key_validation_required': True
|
||||
},
|
||||
'key_validation': {} # 存储密钥验证结果
|
||||
}
|
||||
|
||||
# 更新配置中的任务ID
|
||||
task_info['config']['task']['taskid'] = task_id
|
||||
|
||||
# 存储任务信息
|
||||
self.tasks[task_id] = task_info
|
||||
gd.get_or_create_dict('tasks')[task_id] = task_info
|
||||
|
||||
logger.info(f"创建加密任务成功: {task_id}, 加密模型数: {len(config['models'])}")
|
||||
return task_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建任务失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
    def start_task(self, task_id):
        """Spin up the detection thread for a previously created task.

        Returns True on successful start, False if the task is missing,
        already running, or the thread fails to start.
        """
        try:
            if task_id not in self.tasks:
                raise Exception(f"任务不存在: {task_id}")

            task_info = self.tasks[task_id]

            # Idempotence guard: never double-start a task.
            if task_info['status'] in ['running', 'starting']:
                logger.warning(f"任务已在运行或启动中: {task_id}")
                return False

            try:
                from detectionThread import DetectionThread

                # The detection thread emits progress through this socket.
                task_info['config']['socketIO'] = task_info['socketio']

                detection_thread = DetectionThread(task_info['config'])
                detection_thread.task_id = task_id  # tag the thread with its task

                task_info['thread'] = detection_thread
                task_info['status'] = 'starting'

                detection_thread.start()

                # The thread itself flips the status to 'running' later.
                logger.info(f"任务启动成功: {task_id}")
                return True

            except Exception as e:
                logger.error(f"启动任务线程失败: {task_id}, 错误: {str(e)}")
                task_info['status'] = 'failed'
                return False

        except Exception as e:
            logger.error(f"启动任务失败: {task_id}, 错误: {str(e)}")
            return False
|
||||
|
||||
    def stop_task(self, task_id, force=False):
        """Stop a running task's thread.

        With force=False waits up to 5 s for the thread to exit; with
        force=True skips the wait. Returns True when the task ends up stopped
        (or was not running), False on error or unknown id.
        """
        try:
            with self.task_lock:
                if task_id not in self.tasks:
                    logger.warning(f"任务不存在: {task_id}")
                    return False

                task_info = self.tasks[task_id]

                # Already stopped/failed: nothing to do, report success.
                if task_info['status'] not in ['running', 'starting']:
                    logger.info(f"任务未运行: {task_id}")
                    return True

                try:
                    thread = task_info['thread']
                    if thread and thread.is_alive():
                        logger.info(f"正在停止任务: {task_id}")

                        # Ask the thread to shut down cooperatively.
                        if hasattr(thread, 'stop'):
                            thread.stop()

                        # Grace period unless a forced stop was requested.
                        if not force:
                            thread.join(5.0)  # wait up to 5 seconds

                        if thread.is_alive():
                            if force:
                                logger.warning(f"强制停止任务: {task_id}")
                                # forced-kill logic could go here
                            else:
                                logger.warning(f"任务停止超时: {task_id}")

                        task_info['status'] = 'stopped'
                        logger.info(f"任务已停止: {task_id}")
                    else:
                        # No live thread: just mark the task stopped.
                        task_info['status'] = 'stopped'
                        logger.info(f"任务线程不存在或未运行: {task_id}")

                    return True

                except Exception as e:
                    logger.error(f"停止任务失败: {task_id}, 错误: {str(e)}")
                    task_info['status'] = 'error'
                    return False
        except Exception as e:
            logger.error(f"停止任务异常: {task_id}, 错误: {str(e)}")
            return False
|
||||
|
||||
    def get_task_status(self, task_id):
        """Build a status snapshot for one task, including encryption info.

        Returns None for unknown ids. Model entries never expose the
        encryption key itself, only whether one was provided.
        """
        if task_id not in self.tasks:
            return None

        task_info = self.tasks[task_id]

        # Flattened view of the nested config for API consumers.
        result = {
            'task_id': task_id,
            'status': task_info['status'],
            'config': {
                'rtmp_url': task_info['config']['rtmp']['url'],
                'push_url': task_info['config'].get('push', {}).get('url', ''),
                'taskname': task_info['config']['task']['taskname'],
                'enable_push': task_info['config'].get('push', {}).get('enable_push', False)
            },
            'models': [],
            'stats': task_info['stats'],
            'encryption_info': {
                'required': True,
                'models_count': task_info['stats']['encrypted_models'],
                'key_validation_required': task_info['stats']['key_validation_required']
            },
            'created_at': task_info['created_at']
        }

        # Per-model summaries (keys deliberately omitted).
        if 'models' in task_info['config'] and isinstance(task_info['config']['models'], list):
            for i, model_cfg in enumerate(task_info['config']['models']):
                model_info = {
                    'id': i,
                    # Model name = file name without its extension.
                    'name': os.path.basename(model_cfg.get('path', 'unknown')).split('.')[0],
                    'path': model_cfg.get('path', 'unknown'),
                    'enabled': model_cfg.get('enabled', True),
                    'color': model_cfg.get('color'),
                    'conf_thres': model_cfg.get('conf_thres', 0.25),
                    'encrypted': True,
                    'key_provided': bool(model_cfg.get('encryption_key'))
                }
                result['models'].append(model_info)

        return result
|
||||
|
||||
def get_all_tasks(self):
|
||||
"""获取所有任务信息"""
|
||||
result = []
|
||||
for task_id in self.tasks:
|
||||
task_status = self.get_task_status(task_id)
|
||||
if task_status:
|
||||
result.append(task_status)
|
||||
return result
|
||||
|
||||
def get_active_tasks_count(self):
|
||||
"""获取活动任务数量"""
|
||||
try:
|
||||
count = 0
|
||||
for task_id in self.tasks:
|
||||
task_status = self.tasks[task_id].get('status', '')
|
||||
if task_status in ['running', 'starting', 'creating']:
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动任务数失败: {str(e)}")
|
||||
return 0
|
||||
|
||||
    def cleanup_task(self, task_id):
        """Stop (if needed) and fully discard a task's resources and records.

        Returns True when the task existed and was removed, else False.
        """
        if task_id in self.tasks:
            # Force-stop anything still running before tearing state down.
            if self.tasks[task_id]['status'] in ['running', 'starting']:
                self.stop_task(task_id, force=True)

            # Drop the thread reference after a best-effort final join.
            task_info = self.tasks[task_id]
            if task_info['thread']:
                try:
                    # Give the thread one last second to exit cleanly.
                    if task_info['thread'].is_alive():
                        task_info['thread'].join(1.0)
                except:
                    pass
                task_info['thread'] = None

            # Remove from the globally shared registry...
            tasks_dict = gd.get_or_create_dict('tasks')
            if task_id in tasks_dict:
                del tasks_dict[task_id]

            # ...and from the manager's own table.
            del self.tasks[task_id]

            logger.info(f"任务资源已清理: {task_id}")
            return True
        return False
|
||||
|
||||
    def update_task_status(self, task_id, status):
        """Thread-safely set a task's status and mirror it to global state.

        Returns True when the task exists, else False.
        """
        with self.task_lock:
            if task_id in self.tasks:
                old_status = self.tasks[task_id].get('status', 'unknown')
                self.tasks[task_id]['status'] = status
                logger.info(f"更新任务状态: {task_id} {old_status} -> {status}")

                # Keep the globally shared task registry in sync.
                tasks_dict = gd.get_or_create_dict('tasks')
                if task_id in tasks_dict:
                    tasks_dict[task_id]['status'] = status
                return True
            return False
|
||||
|
||||
def cleanup_all_tasks(self):
|
||||
"""清理所有任务"""
|
||||
task_ids = list(self.tasks.keys())
|
||||
for task_id in task_ids:
|
||||
self.cleanup_task(task_id)
|
||||
logger.info(f"已清理所有任务,共{len(task_ids)}个")
|
||||
|
||||
|
||||
# Module-level singleton task manager shared by the Flask routes in server.py.
task_manager = TaskManager()
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
# task_stream_manager.py
|
||||
import threading
|
||||
import time
|
||||
from log import logger
|
||||
from ffmpegStreamer import FFmpegStreamer
|
||||
|
||||
|
||||
class TaskStreamManager:
|
||||
"""任务独立的推流管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.task_streams = {} # task_id -> streamer
|
||||
self.lock = threading.Lock()
|
||||
self.health_check_thread = None
|
||||
self.running = True
|
||||
|
||||
def create_streamer_for_task(self, task_id, config, fps, width, height):
|
||||
"""为任务创建独立的推流器"""
|
||||
with self.lock:
|
||||
if task_id in self.task_streams:
|
||||
logger.warning(f"任务 {task_id} 已有推流器,先清理")
|
||||
self.stop_task_streamer(task_id)
|
||||
|
||||
try:
|
||||
streamer = FFmpegStreamer(config, fps, width, height)
|
||||
streamer.task_id = task_id # 标记属于哪个任务
|
||||
self.task_streams[task_id] = {
|
||||
'streamer': streamer,
|
||||
'config': config,
|
||||
'fps': fps,
|
||||
'width': width,
|
||||
'height': height,
|
||||
'created_at': time.time(),
|
||||
'last_active': time.time(),
|
||||
'frame_count': 0,
|
||||
'status': 'initializing'
|
||||
}
|
||||
|
||||
# 启动推流
|
||||
streamer.start()
|
||||
|
||||
# 等待初始化完成
|
||||
time.sleep(0.5)
|
||||
|
||||
if streamer.running:
|
||||
self.task_streams[task_id]['status'] = 'running'
|
||||
logger.info(f"任务 {task_id} 推流器创建成功")
|
||||
return streamer
|
||||
else:
|
||||
logger.error(f"任务 {task_id} 推流器启动失败")
|
||||
del self.task_streams[task_id]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建任务推流器失败 {task_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_task_streamer(self, task_id):
|
||||
"""获取任务的推流器"""
|
||||
with self.lock:
|
||||
task_info = self.task_streams.get(task_id)
|
||||
return task_info['streamer'] if task_info else None
|
||||
|
||||
def push_frame(self, task_id, frame):
|
||||
"""为指定任务推送帧"""
|
||||
with self.lock:
|
||||
task_info = self.task_streams.get(task_id)
|
||||
if not task_info:
|
||||
logger.warning(f"任务 {task_id} 没有推流器")
|
||||
return False
|
||||
|
||||
streamer = task_info['streamer']
|
||||
if not streamer or not streamer.running:
|
||||
logger.warning(f"任务 {task_id} 推流器未运行")
|
||||
# 尝试重启
|
||||
self._restart_task_streamer(task_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
success = streamer.add_frame(frame)
|
||||
if success:
|
||||
task_info['frame_count'] += 1
|
||||
task_info['last_active'] = time.time()
|
||||
task_info['status'] = 'active'
|
||||
else:
|
||||
task_info['status'] = 'error'
|
||||
# 连续错误处理
|
||||
if self._check_streamer_health(task_id) == 'unhealthy':
|
||||
self._restart_task_streamer(task_id)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推流失败 {task_id}: {str(e)}")
|
||||
task_info['status'] = 'error'
|
||||
return False
|
||||
|
||||
def _check_streamer_health(self, task_id):
    """Classify the health of one task's streamer.

    Returns one of 'not_found', 'unhealthy', 'inactive', 'low_fps' or
    'healthy'.  Read-only: never mutates task state.
    """
    task_info = self.task_streams.get(task_id)
    if not task_info:
        return 'not_found'

    streamer = task_info['streamer']
    if not streamer:
        return 'unhealthy'

    # Optional hook exposed by the streamer implementation: process died.
    if hasattr(streamer, 'process_failed') and streamer.process_failed():
        return 'unhealthy'

    # No frame pushed for more than 10 seconds.
    if time.time() - task_info['last_active'] > 10:
        return 'inactive'

    # Throughput far below target (under 10% of the configured fps).
    if task_info['frame_count'] > 0:
        elapsed = time.time() - task_info['created_at']
        # FIX: guard against a zero interval right after (re)creation,
        # which previously raised ZeroDivisionError.
        if elapsed > 0:
            actual_fps = task_info['frame_count'] / elapsed
            if actual_fps < task_info['fps'] * 0.1:
                return 'low_fps'

    return 'healthy'
|
||||
|
||||
def _restart_task_streamer(self, task_id):
    """Tear down and recreate the FFmpeg streamer for *task_id*.

    Returns True on success, False when the task is unknown or the new
    streamer could not be started.
    """
    entry = self.task_streams.get(task_id)
    if not entry:
        return False

    logger.info(f"重启任务推流器: {task_id}")

    try:
        previous = entry['streamer']
        if previous:
            previous.stop()

        replacement = FFmpegStreamer(
            entry['config'],
            entry['fps'],
            entry['width'],
            entry['height']
        )
        replacement.task_id = task_id
        replacement.start()

        # Reset bookkeeping so health checks measure the new process.
        now = time.time()
        entry['streamer'] = replacement
        entry['created_at'] = now
        entry['last_active'] = now
        entry['frame_count'] = 0
        entry['status'] = 'restarted'

        logger.info(f"任务推流器重启成功: {task_id}")
        return True
    except Exception as e:
        logger.error(f"重启任务推流器失败 {task_id}: {str(e)}")
        entry['status'] = 'failed'
        return False
|
||||
|
||||
def stop_task_streamer(self, task_id):
    """Stop and unregister a task's streamer.

    Returns True when stopped or already absent; False only when teardown
    raised.
    """
    with self.lock:
        if task_id not in self.task_streams:
            # Nothing to do — treat as already stopped.
            return True
        try:
            entry = self.task_streams[task_id]
            worker = entry['streamer']
            if worker:
                worker.stop()
            del self.task_streams[task_id]
            logger.info(f"任务推流器停止成功: {task_id}")
            return True
        except Exception as e:
            logger.error(f"停止任务推流器失败 {task_id}: {str(e)}")
            return False
|
||||
|
||||
def get_all_task_streams_info(self):
    """Build a status summary dict for every registered task streamer."""
    with self.lock:
        summaries = {}
        for task_id, entry in self.task_streams.items():
            worker = entry['streamer']
            summaries[task_id] = {
                'status': entry['status'],
                'fps': entry['fps'],
                'resolution': f"{entry['width']}x{entry['height']}",
                'frame_count': entry['frame_count'],
                # seconds since the last successfully pushed frame
                'last_active': time.time() - entry['last_active'],
                'running': worker.running if worker else False,
                'health': self._check_streamer_health(task_id)
            }
        return summaries
|
||||
|
||||
def start_health_monitor(self):
    """Start a daemon thread that periodically checks streamer health.

    The loop runs until ``self.running`` becomes False.  Unhealthy or
    inactive streamers are restarted in place.
    """

    def health_monitor_loop():
        logger.info("任务推流健康监控启动")
        while self.running:
            try:
                # Snapshot ids under the lock, then check/restart outside it
                # so a slow restart does not block frame pushing.
                with self.lock:
                    task_ids = list(self.task_streams.keys())

                for task_id in task_ids:
                    health = self._check_streamer_health(task_id)
                    if health in ['unhealthy', 'inactive']:
                        logger.warning(f"任务 {task_id} 推流器健康状态异常: {health}")
                        self._restart_task_streamer(task_id)

                time.sleep(5)  # sweep every 5 seconds
            except Exception as e:
                # Keep the monitor alive; back off a little after an error.
                logger.error(f"健康监控异常: {str(e)}")
                time.sleep(10)

    self.health_check_thread = threading.Thread(
        target=health_monitor_loop,
        daemon=True,
        name="TaskStreamHealthMonitor"
    )
    self.health_check_thread.start()
|
||||
|
||||
def stop_health_monitor(self):
    """Signal the monitor loop to exit and wait briefly for the thread."""
    self.running = False
    monitor = self.health_check_thread
    if monitor is not None and monitor.is_alive():
        monitor.join(3.0)
|
||||
|
||||
def cleanup_all(self):
    """Stop every task streamer, then shut down the health monitor."""
    logger.info("清理所有任务推流器")
    with self.lock:
        # Copy the keys: stop_task_streamer mutates the dict while we loop.
        for task_id in list(self.task_streams.keys()):
            self.stop_task_streamer(task_id)
    self.stop_health_monitor()
|
||||
|
||||
|
||||
# Module-level singleton: the one task-stream manager shared by the process.
task_stream_manager = TaskStreamManager()
|
||||
|
|
@ -0,0 +1,490 @@
|
|||
# task_stream_manager_windows.py
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import psutil
|
||||
from log import logger
|
||||
from ffmpegStreamer import FFmpegStreamer
|
||||
|
||||
|
||||
class WindowsTaskStreamManager:
    """Windows系统专用任务推流管理器"""

    def __init__(self):
        # task_id -> bookkeeping dict ('streamer', 'config', counters, status...)
        self.task_streams = {}
        # Re-entrant: public methods call each other while holding the lock.
        self.lock = threading.RLock()
        self.health_check_thread = None
        self.running = True
        self.streaming_check_interval = 3   # seconds between health sweeps
        self.ffmpeg_timeout = 20            # seconds of FFmpeg silence before restart

        # Windows-specific configuration.
        self.windows_ffmpeg_path = self._find_windows_ffmpeg()
        # Argv template ('{}' placeholders for size, fps and output URL);
        # raw BGR frames are fed on stdin, encoded with low-latency libx264.
        self.windows_ffmpeg_args = [
            '-loglevel', 'verbose',
            '-hide_banner',
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', '{}x{}',
            '-r', '{}',
            '-i', '-',
            '-c:v', 'libx264',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',
            '-f', 'flv',
            '-g', '10',
            '-bf', '0',
            '-max_delay', '0',
            '-flags', '+global_header',
            '-rtbufsize', '100M',
            '-b:v', '2000k',
            '-bufsize', '2000k',
            '{}'
        ]
        # Timestamps used by the output monitor and restart throttling.
        self.last_output_time = 0
        self.last_restart_time = time.time()
|
||||
|
||||
def _find_windows_ffmpeg(self):
    """Locate the ffmpeg executable to use for pushing streams.

    Checks a list of common install locations first, then falls back to
    PATH resolution.  Always returns a usable command string ('ffmpeg'
    as a last resort).
    """
    import shutil  # stdlib; local import keeps the module's import block unchanged

    # Common ffmpeg install locations, most specific first.
    possible_paths = [
        r'C:\ffmpeg\bin\ffmpeg.exe',
        r'C:\Program Files\ffmpeg\bin\ffmpeg.exe',
        r'C:\Program Files (x86)\ffmpeg\bin\ffmpeg.exe',
        os.path.join(os.getcwd(), 'ffmpeg.exe'),
        'ffmpeg'  # resolved through PATH
    ]

    for path in possible_paths:
        # FIX: shutil.which replaces the old os.system(f'where {path}') probe —
        # that shell call broke on paths containing spaces ("Program Files"),
        # spawned a console window per candidate, and only worked on Windows.
        if os.path.exists(path) or shutil.which(path):
            logger.info(f"找到FFmpeg: {path}")
            return path

    logger.warning("未找到FFmpeg,将使用系统PATH中的ffmpeg")
    return 'ffmpeg'
|
||||
|
||||
def create_streamer_for_task(self, task_id, config, fps, width, height):
    """Create, register and start a Windows-tuned FFmpeg streamer for a task.

    Any existing streamer for the same task is stopped first.  Returns the
    running streamer, or None when startup failed (the registration is
    removed again in that case).
    """
    with self.lock:
        if task_id in self.task_streams:
            logger.warning(f"任务 {task_id} 已有推流器,先清理")
            self.stop_task_streamer(task_id)

        try:
            # Force Windows-safe push settings (software encode, FLV, ...).
            win_config = self._create_windows_config(config, width, height, fps)

            streamer = FFmpegStreamer(win_config, fps, width, height)
            streamer.task_id = task_id

            # Override the command builder with the Windows-specific variant.
            streamer.build_ffmpeg_command = self._build_windows_ffmpeg_command(
                streamer, win_config, width, height, fps
            )

            # Register bookkeeping before starting so monitors can find it.
            self.task_streams[task_id] = {
                'streamer': streamer,
                'config': win_config,
                'fps': fps,
                'width': width,
                'height': height,
                'created_at': time.time(),
                'last_active': time.time(),
                'frame_count': 0,
                'status': 'initializing',
                'last_ffmpeg_output': '',
                'output_lines': []
            }

            streamer.start()

            # Windows needs extra time for the FFmpeg process to come up.
            time.sleep(1)

            if streamer.running and streamer.process:
                self.task_streams[task_id]['status'] = 'running'
                logger.info(f"Windows任务 {task_id} 推流器创建成功")

                # Drain/inspect FFmpeg stderr in the background.
                self._start_output_monitor(task_id)

                return streamer
            else:
                logger.error(f"Windows任务 {task_id} 推流器启动失败")
                self.stop_task_streamer(task_id)
                return None

        except Exception as e:
            logger.error(f"Windows创建任务推流器失败 {task_id}: {str(e)}", exc_info=True)
            return None
|
||||
|
||||
def _create_windows_config(self, config, width, height, fps):
    """Return a copy of *config* whose push section is forced to
    Windows-safe, low-latency values (software encode, FLV, CBR)."""
    overrides = {
        'video_codec': 'libx264',      # software encode is most reliable on Windows
        'gpu_acceleration': False,     # hardware acceleration is flaky on Windows
        'preset': 'ultrafast',
        'tune': 'zerolatency',
        'pixel_format': 'bgr24',
        'format': 'flv',
        'crf': 23,
        'bitrate': '2000k',
        'bufsize': '2000k',
        'framerate': fps,
        'extra_args': [
            '-max_delay', '0',
            '-flags', '+global_header',
            '-rtbufsize', '100M',
            '-g', '10',
            '-bf', '0'
        ]
    }

    win_config = config.copy()
    push_section = config.get('push', {}).copy()
    push_section.update(overrides)
    win_config['push'] = push_section
    return win_config
|
||||
|
||||
def _build_windows_ffmpeg_command(self, streamer, config, width, height, fps):
    """Return a zero-argument callable that produces the Windows ffmpeg argv."""

    def assemble():
        push = config['push']

        # Input side: raw frames arrive on stdin.
        cmd = [
            self.windows_ffmpeg_path,
            '-y',
            '-loglevel', 'verbose',  # verbose logs make field debugging easier
            '-hide_banner',
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', push.get('pixel_format', 'bgr24'),
            '-s', f'{width}x{height}',
            '-r', str(fps),
            '-i', '-',
        ]

        # Output side: low-latency encode settings.
        cmd += [
            '-c:v', push['video_codec'],
            '-preset', push['preset'],
            '-tune', push['tune'],
            '-f', push['format'],
            '-g', '10',
            '-bf', '0',
            '-max_delay', '0',
            '-flags', '+global_header',
            '-rtbufsize', '100M',
            '-b:v', push['bitrate'],
            '-bufsize', push['bufsize'],
        ]

        if 'extra_args' in push:
            cmd.extend(push['extra_args'])

        cmd.append(push['url'])

        logger.info(f"Windows FFmpeg命令: {' '.join(cmd)}")
        return cmd

    return assemble
|
||||
|
||||
def _start_output_monitor(self, task_id):
    """Spawn a daemon thread that drains and inspects FFmpeg's stderr.

    Each line is recorded in the task's output buffer and fed into
    _parse_ffmpeg_output; when FFmpeg stays silent longer than
    self.ffmpeg_timeout seconds the streamer is restarted.
    """

    def monitor():
        task_info = self.task_streams.get(task_id)
        if not task_info or not task_info.get('streamer'):
            return

        streamer = task_info['streamer']
        # FIX: establish a fresh baseline.  __init__ leaves last_output_time
        # at 0, which made the "no output" timeout fire on the very first
        # loop iteration and restart a perfectly healthy streamer.
        self.last_output_time = time.time()

        while self.running and task_id in self.task_streams:
            try:
                if streamer.process and streamer.process.stderr:
                    try:
                        line = streamer.process.stderr.readline()
                        if line:
                            line = line.decode('utf-8', errors='ignore').strip()
                            if line:
                                task_info['last_ffmpeg_output'] = line
                                task_info['output_lines'].append(line)

                                # FIX: bound the buffer (the cap was commented
                                # out) — otherwise output_lines grows without
                                # limit for the life of the task.
                                if len(task_info['output_lines']) > 100:
                                    task_info['output_lines'].pop(0)

                                self._parse_ffmpeg_output(task_id, line)
                                self.last_output_time = time.time()

                                # Surface the important lines in our own log.
                                if any(keyword in line.lower() for keyword in
                                       ['error', 'failed', 'invalid', 'cannot']):
                                    logger.error(f"FFmpeg[{task_id}]: {line}")
                                elif 'frame=' in line and 'fps=' in line:
                                    logger.debug(f"FFmpeg[{task_id}]: {line}")
                    except Exception:
                        # Best effort: a bad read must not kill the monitor.
                        pass

                # FFmpeg silent for too long — assume it crashed.
                if time.time() - self.last_output_time > self.ffmpeg_timeout:
                    logger.warning(f"FFmpeg[{task_id}] 长时间无输出,可能已崩溃")
                    self._safe_restart_streamer(task_id)
                    break

                time.sleep(0.1)

            except Exception:
                time.sleep(1)

    threading.Thread(target=monitor, daemon=True, name=f"FFmpegMonitor-{task_id}").start()
|
||||
|
||||
def _parse_ffmpeg_output(self, task_id, line):
    """Extract the reported fps from an FFmpeg log line and react to
    fatal errors by restarting the streamer."""
    task_info = self.task_streams.get(task_id)
    if not task_info:
        return

    # Progress lines carry the encoder's measured frame rate.
    if 'fps=' in line:
        import re
        found = re.search(r'fps=\s*(\d+)', line)
        if found:
            task_info['actual_fps'] = int(found.group(1))

    fatal_markers = [
        'connection refused',
        'cannot open',
        'invalid data',
        'broken pipe',
        'timed out',
        'access denied'
    ]

    lowered = line.lower()
    if any(marker in lowered for marker in fatal_markers):
        logger.error(f"FFmpeg[{task_id}] 关键错误: {line}")
        self._safe_restart_streamer(task_id)
|
||||
|
||||
def _safe_restart_streamer(self, task_id):
    """Restart a task's streamer, rate-limited to one attempt per 10 s."""
    with self.lock:
        if task_id not in self.task_streams:
            return

        entry = self.task_streams[task_id]

        # Throttle: never restart more often than every 10 seconds.
        if time.time() - self.last_restart_time < 10:
            logger.debug(f"跳过频繁重启: {task_id}")
            return

        self.last_restart_time = time.time()
        logger.info(f"安全重启推流器: {task_id}")

        try:
            previous = entry['streamer']
            if previous:
                previous.stop()

            # Give the old process a moment to release its resources.
            time.sleep(1)

            replacement = self.create_streamer_for_task(
                task_id,
                entry['config'],
                entry['fps'],
                entry['width'],
                entry['height']
            )

            if replacement:
                logger.info(f"推流器重启成功: {task_id}")
            else:
                logger.error(f"推流器重启失败: {task_id}")

        except Exception as e:
            logger.error(f"重启推流器异常 {task_id}: {str(e)}")
|
||||
|
||||
def push_frame(self, task_id, frame):
    """Feed one frame to the task's streamer (Windows-tuned variant).

    Returns True when the frame was accepted.  A dead streamer or 10
    consecutive rejected frames triggers a (rate-limited) restart.
    """
    with self.lock:
        task_info = self.task_streams.get(task_id)
        if not task_info:
            logger.warning(f"任务 {task_id} 没有推流器")
            return False

        streamer = task_info['streamer']
        if not streamer or not streamer.running:
            logger.warning(f"任务 {task_id} 推流器未运行")
            self._safe_restart_streamer(task_id)
            return False

        try:
            success = streamer.add_frame(frame)

            if success:
                task_info['frame_count'] += 1
                task_info['last_active'] = time.time()
                task_info['status'] = 'active'
                # FIX: a success ends the failure streak.  Without this
                # reset the "consecutive" counter accumulated over the whole
                # task lifetime and eventually forced a spurious restart.
                task_info['consecutive_failures'] = 0

                # Progress heartbeat every 100 frames.
                if task_info['frame_count'] % 100 == 0:
                    logger.info(f"任务 {task_id} 已推流 {task_info['frame_count']} 帧")
            else:
                task_info['status'] = 'error'
                logger.warning(f"任务 {task_id} 推流失败")

                fail_count = task_info.get('consecutive_failures', 0) + 1
                task_info['consecutive_failures'] = fail_count

                if fail_count >= 10:
                    logger.error(f"任务 {task_id} 连续失败 {fail_count} 次,尝试重启")
                    self._safe_restart_streamer(task_id)
                    task_info['consecutive_failures'] = 0

            return success

        except Exception as e:
            logger.error(f"推流失败 {task_id}: {str(e)}")
            task_info['status'] = 'error'
            return False
|
||||
|
||||
def stop_task_streamer(self, task_id):
    """Stop and unregister a task's streamer (Windows variant).

    Forcibly terminates the FFmpeg child process (terminate, then kill
    after 3 s) before the regular streamer shutdown.  Returns True when
    stopped or already absent; False only when teardown raised.
    """
    with self.lock:
        if task_id in self.task_streams:
            try:
                task_info = self.task_streams[task_id]
                streamer = task_info['streamer']

                if streamer:
                    # Windows needs an explicit terminate/kill of the child;
                    # streamer.stop() alone can leave the process hanging.
                    # FIX: bare `except:` narrowed to `except Exception` so
                    # KeyboardInterrupt/SystemExit still propagate; unused
                    # `import signal` removed.
                    try:
                        if streamer.process:
                            try:
                                streamer.process.terminate()
                                streamer.process.wait(timeout=3)
                            except Exception:
                                streamer.process.kill()
                    except Exception:
                        pass

                    streamer.stop()

                del self.task_streams[task_id]
                logger.info(f"Windows任务推流器停止成功: {task_id}")
                return True

            except Exception as e:
                logger.error(f"停止任务推流器失败 {task_id}: {str(e)}")
                return False
    return True
|
||||
|
||||
def get_task_stream_info(self, task_id):
    """Snapshot of one task's streaming state, or None for unknown tasks."""
    with self.lock:
        if task_id not in self.task_streams:
            return None

        entry = self.task_streams[task_id]
        worker = entry['streamer']
        proc = worker.process if worker and worker.process else None

        return {
            'status': entry['status'],
            'fps': entry['fps'],
            'resolution': f"{entry['width']}x{entry['height']}",
            'frame_count': entry['frame_count'],
            # seconds since the last successfully pushed frame
            'last_active': time.time() - entry['last_active'],
            'running': worker.running if worker else False,
            'process_alive': proc.poll() is None if proc else False,
            'output_lines': entry['output_lines'][-10:],  # last 10 stderr lines
            'last_ffmpeg_output': entry.get('last_ffmpeg_output', '')
        }
|
||||
|
||||
def get_all_task_streams_info(self):
    """Per-task status snapshots for every registered streamer."""
    with self.lock:
        return {tid: self.get_task_stream_info(tid) for tid in self.task_streams}
|
||||
|
||||
def start_health_monitor(self):
    """Start the Windows health-monitor daemon thread.

    Every ``streaming_check_interval`` seconds it restarts streamers that
    have been inactive for >30 s (after having pushed at least one frame)
    or whose FFmpeg child process has exited.
    """

    def health_monitor_loop():
        logger.info("Windows推流健康监控启动")
        while self.running:
            try:
                with self.lock:
                    task_ids = list(self.task_streams.keys())

                for task_id in task_ids:
                    task_info = self.task_streams.get(task_id)
                    if not task_info:
                        continue

                    # Stalled: frames used to flow but stopped >30 s ago.
                    inactive_time = time.time() - task_info['last_active']
                    if inactive_time > 30 and task_info['frame_count'] > 0:
                        logger.warning(f"任务 {task_id} 已 {inactive_time:.0f} 秒无活动")
                        self._safe_restart_streamer(task_id)

                    # FIX: previously only hasattr(streamer, 'process') was
                    # checked, so a None process raised AttributeError and
                    # degraded the monitor to log-and-sleep(10).
                    streamer = task_info['streamer']
                    process = getattr(streamer, 'process', None) if streamer else None
                    if process is not None and process.poll() is not None:
                        logger.warning(f"任务 {task_id} FFmpeg进程已退出")
                        self._safe_restart_streamer(task_id)

                time.sleep(self.streaming_check_interval)

            except Exception as e:
                logger.error(f"健康监控异常: {str(e)}")
                time.sleep(10)

    self.health_check_thread = threading.Thread(
        target=health_monitor_loop,
        daemon=True,
        name="WindowsTaskStreamHealthMonitor"
    )
    self.health_check_thread.start()
|
||||
|
||||
def stop_health_monitor(self):
    """Ask the monitor loop to finish and give the thread 3 s to exit."""
    self.running = False
    if not self.health_check_thread:
        return
    if self.health_check_thread.is_alive():
        self.health_check_thread.join(3.0)
|
||||
|
||||
def cleanup_all(self):
    """Tear down every task streamer, then the monitor thread."""
    logger.info("清理所有Windows任务推流器")
    with self.lock:
        # stop_task_streamer mutates the dict, so iterate over a copy.
        pending = list(self.task_streams.keys())
        for tid in pending:
            self.stop_task_streamer(tid)
    self.stop_health_monitor()
|
||||
|
||||
|
||||
# Module-level singleton: the Windows-specific stream manager for this process.
windows_task_stream_manager = WindowsTaskStreamManager()
|
||||
1810
templates/flv2.html
1810
templates/flv2.html
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,171 @@
|
|||
import requests
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def test_create_multi_model_task():
    """Create a task that loads two model configs and verify its status.

    Smoke test against a locally running server (port 9309); prints the
    outcome rather than asserting, so it can be run against a live system.
    """
    url = "http://localhost:9309/api/tasks/create"

    task_data = {
        "rtmp_url": "rtmp://175.27.168.120:6019/live/8UUXN5400A079H",
        "push_url": "rtmp://localhost:1935/live/15",
        "taskname": "多模型道路监控",
        "models": [
            {
                "path": "yolov8n_encrypted.pt",
                "encryption_key":"123456",
                "tags": {
                    "0": {"name": "汽车", "reliability": 0.4},
                    "1": {"name": "行人", "reliability": 0.3},
                    "2": {"name": "自行车", "reliability": 0.35}
                },
                "conf_thres": 0.3,
                "imgsz": 640,
                "color": [0, 255, 0],  # green boxes
                "device": "cuda:0",
                "line_width": 2,
                "enabled": True
            },
            {
                "path": "yolov8n_encrypted.pt",
                "encryption_key": "123456",
                "tags": {
                    "0": {"name": "轿车", "reliability": 0.5},
                    "1": {"name": "卡车", "reliability": 0.6},
                    "2": {"name": "公交车", "reliability": 0.55}
                },
                "conf_thres": 0.35,
                "imgsz": 1280,
                "color": [255, 0, 0],  # red boxes
                "device": "cuda:0",
                "line_width": 2,
                "enabled": True
            }
        ]
    }

    try:
        print("发送多模型任务创建请求...")
        response = requests.post(url, json=task_data, timeout=30)
        print(f"响应状态码: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            print(f"响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}")

            if result.get('status') == 'success':
                task_id = result.get('task_id')
                print(f"\n✅ 多模型任务创建成功!")
                print(f"任务ID: {task_id}")
                print(f"任务名称: {task_data['taskname']}")
                print(f"加载模型数: {len(task_data['models'])}")

                # Give the server a moment to spin the task up.
                time.sleep(2)

                # Fetch the task detail and verify both models were loaded.
                status_url = f"http://localhost:9309/api/tasks/{task_id}/status"
                status_response = requests.get(status_url)
                if status_response.status_code == 200:
                    status_data = status_response.json()
                    print(f"\n任务状态: {status_data.get('data', {}).get('status')}")
                    print(
                        f"模型信息: {json.dumps(status_data.get('data', {}).get('models'), indent=2, ensure_ascii=False)}")

                    models = status_data.get('data', {}).get('models', [])
                    if len(models) == 2:
                        print("✅ 模型数量正确")
                    else:
                        print(f"❌ 模型数量错误: 期望2,实际{len(models)}")
                else:
                    print(f"❌ 获取任务状态失败: {status_response.text}")
        else:
            print(f"❌ 请求失败: {response.text}")

    except Exception as e:
        print(f"❌ 测试失败: {e}")
|
||||
|
||||
|
||||
def test_get_all_tasks():
    """List every task the server knows about and print a summary.

    Prints counts plus per-task / per-model details; purely informational.
    """
    url = "http://localhost:9309/api/tasks"

    try:
        response = requests.get(url)
        if response.status_code == 200:
            result = response.json()
            print(f"\n获取所有任务成功:")
            print(f"总任务数: {result.get('data', {}).get('total', 0)}")
            print(f"活动任务数: {result.get('data', {}).get('active', 0)}")
            print(f"总模型数: {result.get('data', {}).get('models_count', 0)}")

            tasks = result.get('data', {}).get('tasks', [])
            for i, task in enumerate(tasks):
                print(f"\n任务 {i + 1}:")
                print(f" ID: {task.get('task_id')}")
                print(f" 名称: {task.get('config', {}).get('taskname')}")
                print(f" 状态: {task.get('status')}")
                print(f" 模型数: {len(task.get('models', []))}")

                # Per-model detail (threshold and draw color).
                for model in task.get('models', []):
                    print(f" - {model.get('name')}: 阈值={model.get('conf_thres')}, 颜色={model.get('color')}")
        else:
            print(f"❌ 获取任务列表失败: {response.text}")

    except Exception as e:
        print(f"❌ 测试失败: {e}")
|
||||
|
||||
|
||||
def test_error_cases():
    """Exercise the create endpoint with malformed payloads.

    Each case prints its label, sends the payload and prints the server's
    response; the server is expected to reject every one of them.
    """
    url = "http://localhost:9309/api/tasks/create"

    cases = [
        ("\n测试1: 缺少rtmp_url",
         {"taskname": "错误测试", "models": [{"path": "yolov8n.pt"}]}),
        ("\n测试2: 缺少models",
         {"rtmp_url": "rtmp://localhost:1935/live/14", "taskname": "错误测试"}),
        ("\n测试3: models不是列表",
         {"rtmp_url": "rtmp://localhost:1935/live/14", "taskname": "错误测试",
          "models": {"path": "yolov8n.pt"}}),  # should be a list, sent as dict
    ]

    for label, payload in cases:
        print(label)
        response = requests.post(url, json=payload)
        print(f"响应: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("=" * 50)
    print("多模型专用测试脚本")
    print("=" * 50)

    # Happy path: create a task that loads two models.
    test_create_multi_model_task()

    # List everything the server currently knows about.
    test_get_all_tasks()

    # Malformed payloads must be rejected cleanly.
    test_error_cases()

    print("\n" + "=" * 50)
    print("测试完成!")
    print("=" * 50)
|
||||
Binary file not shown.
|
|
@ -0,0 +1,181 @@
|
|||
# windows_stream_diagnose.py
|
||||
import time
|
||||
import json
|
||||
from log import logger
|
||||
from windows_utils import WindowsSystemUtils
|
||||
|
||||
|
||||
class WindowsStreamDiagnoser:
    """Windows推流诊断工具"""

    def __init__(self, task_id=None):
        # Optional task identifier; None means a system-wide diagnosis.
        self.task_id = task_id
|
||||
def run_full_diagnosis(self):
    """Run every check, print the report, and return the collected results."""
    logger.info("开始Windows推流诊断...")

    # Checks run in order: system -> ffmpeg -> network, then suggestions
    # are derived from the accumulated results.
    results = {
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'system': self.check_system(),
        'ffmpeg': self.check_ffmpeg(),
        'network': self.check_network(),
        'recommendations': []
    }
    results['recommendations'] = self.generate_recommendations(results)

    self.print_report(results)
    return results
|
||||
def check_system(self):
    """Collect OS version info and current resource usage."""
    return {
        'os_info': WindowsSystemUtils.get_windows_version(),
        'resources': WindowsSystemUtils.get_system_resources(),
        'issues': []
    }
|
||||
|
||||
def check_ffmpeg(self):
    """Probe the FFmpeg installation and flag it when missing."""
    probe = WindowsSystemUtils.check_ffmpeg_installation()
    issues = [] if probe['installed'] else ["FFmpeg未安装"]
    return {'info': probe, 'issues': issues}
|
||||
|
||||
def check_network(self):
    """Probe the usual local RTMP ports for reachability."""
    common_ports = [1935, 1936, 8080]
    return {
        f'port_{port}': WindowsSystemUtils.check_rtmp_server_accessibility(
            f"rtmp://localhost:{port}/live/test"
        )
        for port in common_ports
    }
|
||||
|
||||
def generate_recommendations(self, results):
    """Derive prioritized fix suggestions from the diagnosis *results*."""

    def entry(priority, action, details):
        return {'priority': priority, 'action': action, 'details': details}

    recommendations = []

    # A missing FFmpeg is a showstopper.
    if not results['ffmpeg']['info']['installed']:
        recommendations.append(entry(
            'critical',
            '安装FFmpeg',
            '从 https://ffmpeg.org/download.html 下载并添加到PATH'
        ))

    resources = results['system']['resources']
    if resources['cpu_percent'] > 80:
        recommendations.append(entry(
            'high',
            '降低CPU使用率',
            '关闭不必要的程序,降低推理帧率'
        ))
    if resources['memory_percent'] > 85:
        recommendations.append(entry(
            'high',
            '增加可用内存',
            '关闭内存占用大的程序,考虑增加物理内存'
        ))

    # Only blame the firewall when every probed port is unreachable.
    if all(not r.get('accessible', False) for r in results['network'].values()):
        recommendations.append(entry(
            'high',
            '配置防火墙规则',
            '运行: netsh advfirewall firewall add rule name="RTMP" dir=in action=allow protocol=TCP localport=1935'
        ))

    # General advice, always included.
    recommendations.append(entry(
        'medium',
        '使用软件编码',
        'Windows上使用libx264而非硬件编码更稳定'
    ))
    recommendations.append(entry(
        'low',
        '降低推流质量',
        '尝试降低分辨率和码率: 640x480 @ 1000kbps'
    ))

    return recommendations
|
||||
|
||||
def print_report(self, results):
    """Pretty-print the diagnosis *results* to stdout."""
    rule = "=" * 80
    print("\n" + rule)
    print("Windows推流诊断报告")
    print(rule)

    # System section.
    os_info = results['system']['os_info']
    usage = results['system']['resources']
    print("\n[系统信息]")
    print(f"操作系统: {os_info.get('caption', 'Unknown')}")
    print(f"架构: {os_info.get('architecture', 'Unknown')}")
    print(f"CPU使用率: {usage['cpu_percent']}%")
    print(f"内存使用率: {usage['memory_percent']}%")

    # FFmpeg section.
    print("\n[FFmpeg信息]")
    ffmpeg_info = results['ffmpeg']['info']
    print(f"已安装: {'是' if ffmpeg_info['installed'] else '否'}")
    if ffmpeg_info['installed']:
        print(f"版本: {ffmpeg_info['version']}")

    # Network section.
    print("\n[网络信息]")
    for port, result in results['network'].items():
        status = "✓ 可达" if result.get('accessible') else "✗ 不可达"
        print(f"端口 {port}: {status}")
        if not result.get('accessible') and result.get('error'):
            print(f" 错误: {result['error']}")

    # Recommendations, highest priority first.
    print("\n[修复建议]")
    rank = {'critical': 0, 'high': 1, 'medium': 2, 'low': 3}
    icons = {'critical': '🔴', 'high': '🟠', 'medium': '🟡', 'low': '🟢'}
    for rec in sorted(results['recommendations'], key=lambda r: rank[r['priority']]):
        print(f"{icons[rec['priority']]} [{rec['priority'].upper()}] {rec['action']}")
        print(f" {rec['details']}")

    print("\n" + rule)
|
||||
|
||||
|
||||
def diagnose_streaming():
    """Run a full, system-wide streaming diagnosis and return its results."""
    return WindowsStreamDiagnoser().run_full_diagnosis()
|
||||
|
||||
|
||||
# Run a one-shot diagnosis when invoked as a script.
if __name__ == "__main__":
    diagnose_streaming()
|
||||
|
|
@ -0,0 +1,682 @@
|
|||
# windows_utils.py
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
import ctypes
|
||||
import traceback
|
||||
from log import logger
|
||||
|
||||
|
||||
class WindowsSystemUtils:
    """Windows系统工具类"""

    @staticmethod
    def is_windows():
        """Return True when the interpreter is running on Windows."""
        if os.name == 'nt':
            return True
        return sys.platform.startswith('win')
|
||||
|
||||
@staticmethod
def get_windows_version():
    """Return a dict describing the Windows version, or None off-Windows.

    Always includes the ``platform``-module fields; when the native
    ``RtlGetVersion`` call succeeds it adds exact major/minor/build
    numbers and a human-readable ``name``.
    """
    if not WindowsSystemUtils.is_windows():
        return None

    version_info = {
        'system': platform.system(),
        'release': platform.release(),
        'version': platform.version(),
        'machine': platform.machine(),
        'processor': platform.processor(),
        'architecture': platform.architecture()[0]
    }

    try:
        # RtlGetVersion (ntdll) is not subject to the manifest-based version
        # lie that affects GetVersionEx, so it reports the real build number.
        class OSVERSIONINFOEX(ctypes.Structure):
            _fields_ = [
                ('dwOSVersionInfoSize', ctypes.c_ulong),
                ('dwMajorVersion', ctypes.c_ulong),
                ('dwMinorVersion', ctypes.c_ulong),
                ('dwBuildNumber', ctypes.c_ulong),
                ('dwPlatformId', ctypes.c_ulong),
                ('szCSDVersion', ctypes.c_wchar * 128),
                ('wServicePackMajor', ctypes.c_ushort),
                ('wServicePackMinor', ctypes.c_ushort),
                ('wSuiteMask', ctypes.c_ushort),
                ('wProductType', ctypes.c_byte),
                ('wReserved', ctypes.c_byte)
            ]

        os_version = OSVERSIONINFOEX()
        os_version.dwOSVersionInfoSize = ctypes.sizeof(OSVERSIONINFOEX)

        # RtlGetVersion returns STATUS_SUCCESS (0) on success.
        if ctypes.windll.Ntdll.RtlGetVersion(ctypes.byref(os_version)) == 0:
            version_info['major_version'] = os_version.dwMajorVersion
            version_info['minor_version'] = os_version.dwMinorVersion
            version_info['build_number'] = os_version.dwBuildNumber
            version_info['service_pack'] = os_version.wServicePackMajor

            # Map (major, minor) to a marketing name.
            if version_info['major_version'] == 10:
                version_info['name'] = f"Windows 10/11 (Build {version_info['build_number']})"
            elif version_info['major_version'] == 6:
                if version_info['minor_version'] == 3:
                    version_info['name'] = "Windows 8.1"
                elif version_info['minor_version'] == 2:
                    version_info['name'] = "Windows 8"
                elif version_info['minor_version'] == 1:
                    version_info['name'] = "Windows 7"
                elif version_info['minor_version'] == 0:
                    version_info['name'] = "Windows Vista"
            elif version_info['major_version'] == 5:
                if version_info['minor_version'] == 2:
                    version_info['name'] = "Windows XP Professional x64"
                elif version_info['minor_version'] == 1:
                    version_info['name'] = "Windows XP"
    except Exception as e:
        # Best effort: fall back to the platform-module fields only.
        logger.warning(f"获取Windows版本信息失败: {str(e)}")

    return version_info
|
||||
|
||||
@staticmethod
def check_ffmpeg_installation():
    """Check whether FFmpeg is installed and usable.

    Returns:
        dict: on success ``{'installed': True, 'path', 'version', 'details'}``;
        on failure ``{'installed': False, ...}`` with an ``'error'`` or
        ``'version': 'not found'`` entry.
    """
    import shutil  # local import so the module import block is untouched

    try:
        # shutil.which is cross-platform and avoids spawning `where` in a shell.
        ffmpeg_path = shutil.which('ffmpeg')
        if ffmpeg_path:
            # Query the version banner.
            version_result = subprocess.run(
                ['ffmpeg', '-version'],
                capture_output=True,
                text=True,
                timeout=5
            )

            version_info = {}
            # BUG FIX: `lines` was previously unbound when the version call
            # failed, making the return expression raise NameError.
            lines = []
            if version_result.returncode == 0:
                lines = version_result.stdout.split('\n')
                for line in lines:
                    if 'version' in line.lower():
                        version_info['version'] = line.strip()
                        break

            return {
                'installed': True,
                'path': ffmpeg_path,
                'version': version_info.get('version', 'unknown'),
                'details': lines[0] if lines else 'unknown'
            }
    except subprocess.TimeoutExpired:
        return {'installed': False, 'error': 'Timeout checking ffmpeg'}
    except FileNotFoundError:
        pass
    except Exception as e:
        logger.warning(f"检查FFmpeg失败: {str(e)}")

    return {'installed': False, 'version': 'not found'}
|
||||
|
||||
@staticmethod
def check_rtmp_server_accessibility(rtmp_url):
    """Probe an RTMP server with a plain TCP connect and report reachability.

    Returns a dict: ``{'accessible': True, 'host', 'port'}`` on success, or
    ``{'accessible': False, ...}`` with error details otherwise.
    """
    if not rtmp_url.startswith('rtmp://'):
        return {'accessible': False, 'error': 'Invalid RTMP URL format'}

    try:
        # Extract "host[:port]" from rtmp://host[:port]/app/stream.
        authority = rtmp_url.replace('rtmp://', '').split('/')[0]
        host, sep, port_text = authority.partition(':')
        # No explicit port -> RTMP default 1935.
        port = int(port_text) if sep else 1935

        # Plain TCP reachability probe.
        import socket
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        probe.settimeout(5)
        status = probe.connect_ex((host, port))
        probe.close()

        if status == 0:
            return {'accessible': True, 'host': host, 'port': port}

        # Common Winsock error codes mapped to readable messages.
        error_map = {
            10061: 'Connection refused (服务器拒绝连接)',
            10060: 'Connection timed out (连接超时)',
            10013: 'Permission denied (权限被拒绝)',
            10048: 'Address already in use (地址已被使用)',
            10049: 'Cannot assign requested address (无法分配请求的地址)',
            10050: 'Network is down (网络断开)',
            10051: 'Network is unreachable (网络不可达)',
        }
        return {
            'accessible': False,
            'host': host,
            'port': port,
            'error_code': status,
            'error': error_map.get(status, f'Unknown error code: {status}')
        }

    except Exception as e:
        return {'accessible': False, 'error': str(e)}
|
||||
|
||||
@staticmethod
def optimize_windows_for_streaming():
    """Collect Windows tuning hints for streaming workloads.

    Returns a dict of the current power scheme plus static firewall,
    network and GPU recommendations. Purely advisory — nothing is changed.
    """
    tips = {}

    # 1. Power plan: read the active scheme via powercfg.
    try:
        probe = subprocess.run(
            'powercfg /getactivescheme',
            shell=True,
            capture_output=True,
            text=True
        )
        tips['power_scheme'] = (
            probe.stdout.strip() if probe.returncode == 0
            else 'Cannot check power scheme'
        )

        # Suggest the stock High Performance plan (well-known GUID).
        tips['power_recommendation'] = (
            'For best streaming performance, use High Performance power plan. '
            'Run: powercfg /setactive 8c5e7fda-e8bf-4a96-9a85-a6e23a8c635c'
        )
    except Exception as e:
        tips['power_scheme'] = f'Error checking power: {str(e)}'

    # 2. Firewall hint.
    tips['firewall'] = 'Consider adding firewall rules for RTMP ports (1935, 1936)'

    # 3. TCP stack tuning commands.
    tips['network_optimizations'] = [
        'Increase TCP buffer size: netsh int tcp set global autotuninglevel=normal',
        'Disable TCP auto-tuning (if unstable): netsh int tcp set global autotuninglevel=disabled',
        'Enable TCP fast open: netsh int tcp set global fastopen=enabled',
        'Set TCP keepalive: netsh int tcp set global keepalivetime=30000'
    ]

    # 4. GPU driver / scheduling hints.
    tips['gpu_recommendations'] = [
        'Update graphics drivers to latest version',
        'In NVIDIA Control Panel: Set Power Management Mode to "Prefer Maximum Performance"',
        'In Windows Graphics Settings: Add ffmpeg.exe and set to "High Performance"'
    ]

    return tips
|
||||
|
||||
@staticmethod
def create_windows_firewall_rule(port, name="RTMP Streaming"):
    """Add inbound TCP and UDP allow rules for *port* via netsh.

    Requires an elevated (administrator) process; otherwise each rule is
    reported as failed. Returns a list of per-command result dicts.
    """
    rules = [
        f'netsh advfirewall firewall add rule name="{name}" dir=in action=allow protocol=TCP localport={port}',
        f'netsh advfirewall firewall add rule name="{name} UDP" dir=in action=allow protocol=UDP localport={port}'
    ]

    outcomes = []
    for rule_cmd in rules:
        try:
            # netsh silently no-ops without elevation, so verify admin first.
            import ctypes
            if ctypes.windll.shell32.IsUserAnAdmin() == 0:
                outcomes.append({
                    'command': rule_cmd,
                    'success': False,
                    'error': '需要管理员权限运行此命令'
                })
                continue

            proc = subprocess.run(rule_cmd, shell=True, capture_output=True, text=True)
            outcomes.append({
                'command': rule_cmd,
                'success': proc.returncode == 0,
                'output': proc.stdout,
                'error': proc.stderr
            })
        except Exception as e:
            outcomes.append({
                'command': rule_cmd,
                'success': False,
                'error': str(e)
            })

    return outcomes
|
||||
|
||||
@staticmethod
def get_system_resources():
    """Collect a snapshot of Windows system resources.

    Returns:
        dict: CPU / memory / process counters plus best-effort ``disk_io``,
        ``network_io`` and ``gpu_info`` sections. IO sections fall back to
        ``{'available': False}`` when the platform does not expose them.
    """
    import psutil

    # Hoist the psutil probes that were previously called repeatedly:
    # cpu_freq() was invoked twice (check + use, racy) and
    # virtual_memory() three times.
    freq = psutil.cpu_freq()  # may be None on some platforms
    mem = psutil.virtual_memory()

    resources = {
        'cpu_percent': psutil.cpu_percent(interval=0.1),
        'cpu_count': psutil.cpu_count(),
        'cpu_freq': getattr(freq, 'current', 0) if freq else 0,
        'memory_percent': mem.percent,
        'memory_used_gb': mem.used / (1024 ** 3),
        'memory_total_gb': mem.total / (1024 ** 3),
        'process_count': len(psutil.pids()),
    }

    try:
        # Disk IO counters can be unavailable on some systems (e.g. VMs).
        disk_io = psutil.disk_io_counters()
        if disk_io:
            resources['disk_io'] = {
                'read_bytes': disk_io.read_bytes,
                'write_bytes': disk_io.write_bytes,
                'read_count': disk_io.read_count,
                'write_count': disk_io.write_count
            }
    except Exception:  # BUG FIX: was a bare except (swallowed SystemExit too)
        resources['disk_io'] = {'available': False}

    try:
        net_io = psutil.net_io_counters()
        resources['network_io'] = {
            'bytes_sent': net_io.bytes_sent,
            'bytes_recv': net_io.bytes_recv,
            'packets_sent': net_io.packets_sent,
            'packets_recv': net_io.packets_recv
        }
    except Exception:  # BUG FIX: was a bare except
        resources['network_io'] = {'available': False}

    # GPU info via the multi-strategy helper (torch / dxdiag / registry).
    resources['gpu_info'] = WindowsSystemUtils._get_gpu_info()

    return resources
|
||||
|
||||
@staticmethod
def _get_gpu_info():
    """Detect GPUs via several fallback strategies.

    Tries, in order: PyTorch CUDA enumeration, ``dxdiag`` output parsing,
    then the Windows display-adapter registry key. Returns
    ``{'method': <strategy>, 'gpus': [<per-gpu dicts>]}``; ``method`` stays
    ``'none'`` when nothing was found.
    """
    gpu_info = {
        'method': 'none',
        'gpus': []
    }

    # Strategy 1: PyTorch CUDA enumeration (if torch is available).
    try:
        import torch
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_data = {
                    'name': torch.cuda.get_device_name(i),
                    'total_memory_mb': torch.cuda.get_device_properties(i).total_memory / (1024 ** 2),
                    'method': 'torch'
                }
                gpu_info['gpus'].append(gpu_data)

            if gpu_info['gpus']:
                gpu_info['method'] = 'torch'
                return gpu_info
    except ImportError:
        logger.debug("torch未安装,跳过PyTorch GPU检测")
    except Exception as e:
        logger.debug(f"PyTorch GPU检测失败: {str(e)}")

    # Strategy 2: dxdiag text report (Windows native).
    try:
        import tempfile
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        temp_file.close()

        try:
            result = subprocess.run(
                f'dxdiag /t {temp_file.name}',
                shell=True,
                capture_output=True,
                text=True,
                timeout=10
            )

            if result.returncode == 0:
                with open(temp_file.name, 'r', encoding='utf-8', errors='ignore') as f:
                    dxdiag_content = f.read()

                import re

                # The report lists GPUs under a "Display Devices" header.
                display_sections = re.split(r'Display Devices\n-+', dxdiag_content)
                if len(display_sections) > 1:
                    for section in display_sections[1:]:
                        name_match = re.search(r'Card name:\s*(.+)', section)
                        memory_match = re.search(r'Display Memory:\s*(\d+)', section)
                        driver_match = re.search(r'Driver Version:\s*(.+)', section)

                        gpu_data = {
                            'name': name_match.group(1).strip() if name_match else 'Unknown GPU',
                            'display_memory_mb': int(memory_match.group(1)) if memory_match else 0,
                            'driver_version': driver_match.group(1).strip() if driver_match else 'Unknown',
                            'method': 'dxdiag'
                        }

                        gpu_info['gpus'].append(gpu_data)

                    if gpu_info['gpus']:
                        gpu_info['method'] = 'dxdiag'
        finally:
            # BUG FIX: cleanup ran only on the happy path, leaking the temp
            # file when dxdiag timed out or parsing raised.
            try:
                os.unlink(temp_file.name)
            except OSError:  # BUG FIX: was a bare except
                pass

    except Exception as e:
        logger.debug(f"dxdiag GPU检测失败: {str(e)}")

    # Strategy 3: Windows registry display-adapter class key.
    try:
        import winreg

        # GUID {4d36e968-...} is the display-adapter device class.
        reg_path = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}"

        # BUG FIX: previously any pre-existing entries caused the method to
        # be relabeled 'registry'; only relabel when this pass found GPUs.
        found_before = len(gpu_info['gpus'])

        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, reg_path) as key:
            i = 0
            while True:
                try:
                    subkey_name = winreg.EnumKey(key, i)
                    subkey_path = f"{reg_path}\\{subkey_name}"

                    with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, subkey_path) as subkey:
                        try:
                            driver_desc = winreg.QueryValueEx(subkey, "DriverDesc")[0]
                            # Skip the generic fallback adapter entry.
                            if "basic display" not in driver_desc.lower():
                                gpu_data = {
                                    'name': driver_desc,
                                    'method': 'registry'
                                }
                                gpu_info['gpus'].append(gpu_data)
                        except OSError:  # BUG FIX: WindowsError is undefined off-Windows; OSError is its alias
                            pass
                    i += 1
                except OSError:  # end of enumeration
                    break

        if len(gpu_info['gpus']) > found_before:
            gpu_info['method'] = 'registry'

    except Exception as e:
        logger.debug(f"注册表GPU检测失败: {str(e)}")

    return gpu_info
|
||||
|
||||
@staticmethod
|
||||
def test_ffmpeg_streaming():
|
||||
"""测试FFmpeg推流功能"""
|
||||
test_results = []
|
||||
|
||||
# 测试1: 基本的FFmpeg功能
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['ffmpeg', '-version'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
test_results.append({
|
||||
'test': 'ffmpeg_version',
|
||||
'success': result.returncode == 0,
|
||||
'output': result.stdout[:100] if result.stdout else 'No output'
|
||||
})
|
||||
except Exception as e:
|
||||
test_results.append({
|
||||
'test': 'ffmpeg_version',
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
# 测试2: 编码器支持
|
||||
|
||||
# 测试3: 生成测试视频并推流到null (本地测试)
|
||||
try:
|
||||
# 创建一个简单的测试命令
|
||||
test_command = [
|
||||
'ffmpeg',
|
||||
'-f', 'lavfi',
|
||||
'-i', 'testsrc=duration=2:size=640x480:rate=30',
|
||||
'-c:v', 'libx264',
|
||||
'-t', '1', # 只运行1秒
|
||||
'-f', 'null', # 输出到null
|
||||
'-'
|
||||
]
|
||||
|
||||
result = subprocess.run(
|
||||
test_command,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
test_results.append({
|
||||
'test': 'ffmpeg_basic_encode',
|
||||
'success': result.returncode == 0,
|
||||
'output': 'Success' if result.returncode == 0 else result.stderr[:200]
|
||||
})
|
||||
except Exception as e:
|
||||
test_results.append({
|
||||
'test': 'ffmpeg_basic_encode',
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return test_results
|
||||
|
||||
|
||||
# 系统检测和配置
|
||||
def detect_and_configure_windows():
    """Detect and configure a Windows system for streaming.

    Runs the full detection pipeline (system info, FFmpeg checks, GPU
    detection, optimization hints, resource snapshot) and returns a summary
    dict whose ``'status'`` is ``'ready'`` / ``'warning'`` / ``'error'``.
    Returns None on non-Windows systems.
    """
    if not WindowsSystemUtils.is_windows():
        logger.info("非Windows系统,跳过Windows特定配置")
        return None

    logger.info("检测到Windows系统,进行系统检测和配置...")

    config_result = {
        'system_info': {},
        'ffmpeg_info': {},
        'gpu_info': {},
        'optimizations': {},
        'resources': {},
        'ffmpeg_tests': [],
        'status': 'unknown',
        'issues': [],
        'recommendations': []
    }

    try:
        # 1. System info
        system_info = WindowsSystemUtils.get_windows_version()
        config_result['system_info'] = system_info
        logger.info(f"Windows系统: {system_info.get('name', 'Unknown')}")

        # 2. FFmpeg availability + functional tests
        ffmpeg_info = WindowsSystemUtils.check_ffmpeg_installation()
        config_result['ffmpeg_info'] = ffmpeg_info

        if not ffmpeg_info['installed']:
            config_result['issues'].append("FFmpeg未安装")
            config_result['recommendations'].append({
                'priority': 'critical',
                'action': '安装FFmpeg',
                'details': '从 https://github.com/BtbN/FFmpeg-Builds/releases 下载 ffmpeg-master-latest-win64-gpl.zip,解压并添加bin目录到PATH'
            })
            logger.error("FFmpeg未安装!")
        else:
            logger.info(f"FFmpeg版本: {ffmpeg_info['version']}")
            logger.info(f"FFmpeg路径: {ffmpeg_info.get('path', 'unknown')}")

            logger.info("运行FFmpeg功能测试...")
            ffmpeg_tests = WindowsSystemUtils.test_ffmpeg_streaming()
            config_result['ffmpeg_tests'] = ffmpeg_tests

            failed_tests = [t for t in ffmpeg_tests if not t.get('success', False)]
            if failed_tests:
                config_result['issues'].append(f"FFmpeg测试失败 ({len(failed_tests)}个)")
                for test in failed_tests:
                    logger.warning(f"FFmpeg测试失败: {test.get('test')}")
            else:
                # FIX: removed pointless f-string (no placeholders).
                logger.info("FFmpeg功能测试成功!")

        # 3. GPU detection
        gpu_info = WindowsSystemUtils._get_gpu_info()
        config_result['gpu_info'] = gpu_info

        if gpu_info['gpus']:
            logger.info(f"检测到 {len(gpu_info['gpus'])} 个GPU:")
            for i, gpu in enumerate(gpu_info['gpus']):
                logger.info(f"  GPU{i}: {gpu.get('name', 'Unknown')}")
                # BUG FIX: _get_gpu_info reports memory as 'total_memory_mb'
                # (torch) or 'display_memory_mb' (dxdiag); the old key
                # 'adapter_ram_mb' was never set, so VRAM never printed.
                mem_mb = (gpu.get('adapter_ram_mb', 0)
                          or gpu.get('total_memory_mb', 0)
                          or gpu.get('display_memory_mb', 0))
                if mem_mb > 0:
                    logger.info(f"    显存: {mem_mb} MB")
        else:
            logger.warning("未检测到GPU信息")
            config_result['issues'].append("未检测到GPU信息")

        # 4. Optimization hints (advisory only)
        optimizations = WindowsSystemUtils.optimize_windows_for_streaming()
        config_result['optimizations'] = optimizations

        logger.info("系统优化建议:")
        for key, value in optimizations.items():
            if isinstance(value, list):
                logger.info(f"  {key}:")
                for item in value:
                    logger.info(f"    - {item}")
            else:
                logger.info(f"  {key}: {value}")

        # 5. Resource snapshot
        resources = WindowsSystemUtils.get_system_resources()
        config_result['resources'] = resources

        logger.info(f"系统资源: CPU {resources['cpu_percent']}%, 内存 {resources['memory_percent']}%")

        # Flag overloaded CPU / memory.
        if resources['cpu_percent'] > 85:
            config_result['issues'].append(f"CPU使用率过高: {resources['cpu_percent']}%")
            config_result['recommendations'].append({
                'priority': 'high',
                'action': '降低CPU使用率',
                'details': '关闭不必要的程序,减少后台进程'
            })

        if resources['memory_percent'] > 90:
            config_result['issues'].append(f"内存使用率过高: {resources['memory_percent']}%")
            config_result['recommendations'].append({
                'priority': 'high',
                'action': '释放内存',
                'details': '关闭内存占用大的程序,重启系统'
            })

        # 6. Network test suggestion
        config_result['recommendations'].append({
            'priority': 'medium',
            'action': '测试RTMP服务器连接',
            'details': '运行: python -c "import socket; sock=socket.socket(); sock.settimeout(5); print(\'OK\' if sock.connect_ex((\'your-server\', 1935))==0 else \'FAIL\')"'
        })

        # 7. Streaming configuration suggestions
        config_result['recommendations'].append({
            'priority': 'medium',
            'action': '使用软件编码',
            'details': 'Windows上建议使用libx264而非硬件编码,更稳定'
        })

        config_result['recommendations'].append({
            'priority': 'low',
            'action': '调整推流参数',
            'details': '尝试: -preset ultrafast -tune zerolatency -b:v 1500k -maxrate 1500k -bufsize 3000k'
        })

        # Final status: any issue downgrades to 'warning'.
        if config_result['issues']:
            config_result['status'] = 'warning'
            logger.warning(f"发现 {len(config_result['issues'])} 个问题")
        else:
            config_result['status'] = 'ready'
            logger.info("Windows系统检测完成,状态正常")

    except Exception as e:
        logger.error(f"Windows系统检测异常: {str(e)}")
        logger.error(traceback.format_exc())
        config_result['status'] = 'error'
        config_result['error'] = str(e)

    return config_result
|
||||
|
||||
|
||||
# Windows快速诊断函数
|
||||
def quick_windows_diagnosis():
    """Run a quick Windows diagnosis.

    Gathers FFmpeg / system / GPU / resource status, prints a short
    human-readable summary to stdout and returns the raw results dict.
    """
    logger.info("开始Windows快速诊断...")

    results = {
        'ffmpeg': WindowsSystemUtils.check_ffmpeg_installation(),
        'system': WindowsSystemUtils.get_windows_version(),
        'gpu': WindowsSystemUtils._get_gpu_info(),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    # Best-effort resource snapshot; psutil may be missing.
    try:
        import psutil
        results['resources'] = {
            'cpu': psutil.cpu_percent(),
            'memory': psutil.virtual_memory().percent
        }
    except Exception:  # BUG FIX: was a bare except (also caught SystemExit/KeyboardInterrupt)
        results['resources'] = {'error': 'Failed to get resources'}

    # Print a human-readable summary.
    print("\n" + "=" * 60)
    print("Windows快速诊断结果")
    print("=" * 60)

    print(f"\n系统: {results['system'].get('name', 'Unknown')}")
    print(f"时间: {results['timestamp']}")

    if results['ffmpeg']['installed']:
        print(f"✓ FFmpeg: {results['ffmpeg']['version']}")
    else:
        print("✗ FFmpeg: 未安装")

    gpu_count = len(results['gpu'].get('gpus', []))
    if gpu_count > 0:
        print(f"✓ GPU: {gpu_count}个检测到")
        # Show at most the first two GPUs (unused loop index removed).
        for gpu in results['gpu']['gpus'][:2]:
            print(f"  {gpu.get('name', 'Unknown GPU')}")
    else:
        print("⚠ GPU: 未检测到")

    if 'resources' in results and 'error' not in results['resources']:
        print(f"系统负载: CPU {results['resources']['cpu']}%, 内存 {results['resources']['memory']}%")

    print("\n" + "=" * 60)

    return results
|
||||
110594
yolo_detection.log
110594
yolo_detection.log
File diff suppressed because it is too large
Load Diff
155563
yolo_detection.log.1
155563
yolo_detection.log.1
File diff suppressed because one or more lines are too long
191109
yolo_detection.log.2
191109
yolo_detection.log.2
File diff suppressed because it is too large
Load Diff
153074
yolo_detection.log.3
153074
yolo_detection.log.3
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue