636 lines
25 KiB
Python
636 lines
25 KiB
Python
import datetime
|
||
import gc
|
||
import json
|
||
import os
|
||
import queue
|
||
import threading
|
||
import time
|
||
import traceback
|
||
from pathlib import Path
|
||
|
||
import cv2
|
||
import paho.mqtt.client as mqtt
|
||
import requests
|
||
import torch
|
||
from ultralytics import YOLO
|
||
|
||
from _minio import MinioUploader
|
||
from ffmpegStreamer import FFmpegStreamer
|
||
from log import logger
|
||
import global_data as gd
|
||
from mapping_cn import class_mapping_cn as cmc
|
||
|
||
class DetectionThread(threading.Thread):
    """Background thread: read an RTMP stream, run YOLO detection, fan out results.

    For every frame that is actually processed the thread:

    * runs the YOLO model and renders the detections (``process_results``),
    * pushes the annotated frame to an ``FFmpegStreamer`` when
      ``config['push']['enable_push']`` is set,
    * emits a compact summary over ``config['socketIO']`` at most once per second,
    * queues an annotated snapshot for MinIO upload + result-API notification at
      most once every ``upload_interval`` seconds, and only when something was
      detected.

    When ``config['mqtt']['enable']`` is true, the most recent JSON message
    received on the configured MQTT topic is attached to API notifications as
    ``drone_info``.

    Args:
        config: Nested configuration dict. Sections read here: ``rtmp``,
            ``task``, ``predict``, ``model``, ``minio``, ``push``, ``socketIO``
            and the optional ``mqtt``.
    """

    # Seconds to wait for the result API before abandoning one notification.
    HTTP_TIMEOUT = 10
    # (connect, read) timeouts in seconds for the model download request.
    DOWNLOAD_TIMEOUT = (10, 60)
    # Number of processed frames aggregated into one performance log line.
    PERF_REPORT_EVERY = 50

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.running = True
        self.model = None
        self.cap = None
        self.streamer = None
        self.frame_count = 0
        self.frame_skip_counter = 0
        # Consecutive failed reads; reset after every successful read.
        self.reconnect_attempts = 0
        self.last_frame_time = time.time()
        # Exponential moving average of the loop rate (frames/second); seeded
        # with the stream's reported FPS in initialize_resources().
        self.fps = 0
        # FPS reported by the source stream; used to configure the push streamer.
        self.stream_fps = 30
        self.detections_count = 0
        self.stop_event = threading.Event()
        self.rtmp_url = config['rtmp']['url']
        self.max_reconnect_attempts = config['rtmp']['max_reconnect_attempts']
        self.reconnect_delay = config['rtmp']['reconnect_delay']
        self.buffer_size = config['rtmp']['buffer_size']
        # NOTE(review): read from config but not used anywhere in this class.
        self.timeout_ms = config['rtmp']['timeout_ms']
        self.taskname = config['task']['taskname']
        self.taskid = config['task']['taskid']
        # Per-task class metadata keyed by str(class_id): {'name', 'reliability'}.
        self.tag = config['task']['tag']
        self.aiid = config['task']['aiid']
        self.last_log_time = time.time()
        self.daemon = True
        # Consecutive frames skipped by the low-FPS rule (capped at 2).
        self.frame_skip = 0
        # If reading/pre-processing a frame took longer than this many seconds,
        # the frame is dropped without running inference.
        self.target_latency = 0.05
        self.max_processing_time = 0.033
        self.last_processing_time = 0
        self.prev_frame = None
        self.prev_results = None
        # Per-frame processing durations since the last performance report.
        self.processing_times = []
        self.avg_process_time = 0.033
        self.last_status_update = 0
        # Snapshot of self.fps refreshed at most every 0.5 s; sent over Socket.IO.
        self.last_fps = 0
        self.original_width = 0
        self.original_height = 0
        self.imgsz = config['predict']['imgsz']
        self.minio_uploader = MinioUploader(config['minio'])
        # Class ids already reported as unmapped (avoids one warning per frame).
        self._unknown_class_ids = set()

        # Upload queue and worker thread.
        self.upload_queue = queue.Queue(maxsize=50)
        self.upload_thread = None
        self.upload_active = False
        # Minimum number of seconds between two queued uploads.
        self.upload_interval = 2
        self.last_upload_time = 0
        self.res_api = config['task']['res_api']

        # MQTT configuration.
        self.mqtt_config = config.get('mqtt', {})
        self.mqtt_enabled = self.mqtt_config.get('enable', False)
        self.mqtt_topic = self.mqtt_config.get('topic', 'drone/data')
        self.mqtt_client = None
        self.mqtt_connected = False
        # Most recent decoded MQTT payload; guarded by mqtt_data_lock.
        self.latest_drone_data = None
        self.mqtt_data_lock = threading.Lock()

    # ------------------------------------------------------------------ MQTT

    def on_mqtt_connect(self, client, userdata, flags, rc):
        """paho-mqtt connect callback: subscribe to the topic on success (rc == 0)."""
        if rc == 0:
            client.subscribe(self.mqtt_topic)
            self.mqtt_connected = True
        else:
            logger.error(f"MQTT连接失败,错误码: {rc}")
            self.mqtt_connected = False

    def on_mqtt_disconnect(self, client, userdata, rc):
        """paho-mqtt disconnect callback: mark the link as down (rc != 0 means unexpected)."""
        self.mqtt_connected = False
        if rc != 0:
            logger.warning(f"MQTT连接断开,错误码: {rc}")

    def on_mqtt_message(self, client, userdata, msg):
        """paho-mqtt message callback: keep the latest JSON payload for later uploads."""
        try:
            drone_data = json.loads(msg.payload.decode())
            with self.mqtt_data_lock:
                self.latest_drone_data = drone_data
            logger.debug(f"收到MQTT消息: {drone_data}")
        except Exception as e:
            logger.error(f"解析MQTT消息失败: {str(e)}")

    def start_mqtt_client(self):
        """Create, connect and start the MQTT network loop.

        Returns:
            True if the client was started, False if MQTT is disabled or
            starting failed (the error is logged).
        """
        if not self.mqtt_enabled:
            logger.info("MQTT功能未启用")
            return False
        try:
            logger.info("启动MQTT客户端...")
            client_id = self.mqtt_config.get('client_id', 'yolo_detection')
            if hasattr(mqtt, 'CallbackAPIVersion'):
                # paho-mqtt >= 2.0 requires an explicit callback API version;
                # VERSION1 keeps the (client, userdata, flags, rc) signatures used here.
                self.mqtt_client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1, client_id=client_id)
            else:
                self.mqtt_client = mqtt.Client(client_id=client_id)
            self.mqtt_client.on_connect = self.on_mqtt_connect
            self.mqtt_client.on_disconnect = self.on_mqtt_disconnect
            self.mqtt_client.on_message = self.on_mqtt_message

            if 'username' in self.mqtt_config and 'password' in self.mqtt_config:
                self.mqtt_client.username_pw_set(
                    self.mqtt_config['username'],
                    self.mqtt_config['password'],
                )

            self.mqtt_client.connect(
                self.mqtt_config['broker'],
                self.mqtt_config.get('port', 1883),
                self.mqtt_config.get('keepalive', 60),
            )
            self.mqtt_client.loop_start()
            logger.info("MQTT客户端已启动")
            return True
        except Exception as e:
            logger.error(f"启动MQTT客户端失败: {str(e)}")
            return False

    def stop_mqtt_client(self):
        """Disconnect and stop the MQTT client if one exists (best effort)."""
        if self.mqtt_client:
            try:
                # Disconnect while the network loop is still running so the
                # DISCONNECT packet gets flushed, then stop the loop.
                self.mqtt_client.disconnect()
                self.mqtt_client.loop_stop()
                logger.info("MQTT客户端已停止")
            except Exception as e:
                logger.error(f"停止MQTT客户端失败: {str(e)}")
            finally:
                self.mqtt_client = None
                self.mqtt_connected = False

    # ------------------------------------------------------------- main loop

    def run(self):
        """Thread entry point: initialise resources, then process frames until stopped.

        Always runs cleanup() and clears the global ``detection_active`` flag on
        exit, even if initialisation, the loop, or cleanup itself fails.
        """
        try:
            logger.info("启动优化检测线程")
            # Loads and warms up the model, opens the stream, starts the streamer.
            self.initialize_resources()
            logger.info("资源初始化完成")

            if self.mqtt_enabled:
                self.start_mqtt_client()

            self.upload_active = True
            self.upload_thread = threading.Thread(target=self._upload_worker, daemon=True)
            self.upload_thread.start()
            logger.info("图片上传线程已启动")

            while self.running and not self.stop_event.is_set():
                self._process_next_frame()

            logger.info("检测线程主循环结束")
        except Exception as e:
            logger.error(f"检测线程异常: {str(e)}")
            logger.error(traceback.format_exc())
        finally:
            try:
                self.cleanup()
            except Exception as e:
                logger.error(f"资源清理异常: {e}")
            logger.info("检测线程已安全停止")
            gd.set_value('detection_active', False)

    def _process_next_frame(self):
        """Read one frame and, unless it is dropped, run detection and fan out results."""
        start_time = time.perf_counter()

        if self.cap is None:
            ret, frame = False, None
        else:
            ret, frame = self.cap.read()
        logger.debug(f'读取帧结果: {ret}')
        if not ret:
            self.handle_reconnect()
            return
        self.reconnect_attempts = 0

        # Keep every frame at the stream's resolution.
        if self.original_width <= 0 or self.original_height <= 0:
            # The stream did not report its size (cv2.resize rejects a 0x0
            # target): adopt the first frame's size and restart the streamer.
            self.original_height, self.original_width = frame.shape[:2]
            logger.warning(f"流未报告分辨率,使用首帧分辨率 {self.original_width}x{self.original_height}")
            self.stop_streamer()
            self.start_streamer()
        elif frame.shape[1] != self.original_width or frame.shape[0] != self.original_height:
            logger.warning(
                f"帧分辨率不匹配: 预期 {self.original_width}x{self.original_height}, 实际 {frame.shape[1]}x{frame.shape[0]}")
            frame = cv2.resize(frame, (self.original_width, self.original_height))

        # Loop-rate EMA; last_fps is the value published over Socket.IO.
        current_time = time.time()
        time_diff = current_time - self.last_frame_time
        if time_diff > 0:
            self.fps = 0.9 * self.fps + 0.1 / time_diff
            if current_time - self.last_status_update > 0.5:
                self.last_fps = self.fps
                self.last_status_update = current_time
        self.last_frame_time = current_time

        # Low-latency rule: drop frames whose read already took too long.
        if time.perf_counter() - start_time > self.target_latency:
            self.frame_skip_counter += 1
            self.frame_count += 1
            return

        # Low-FPS rule: skip up to 2 consecutive frames while the loop is slow.
        if self.fps < 15 and self.frame_skip < 2:
            self.frame_skip += 1
            self.frame_count += 1
            return
        self.frame_skip = 0

        results = self._infer(frame)
        self.prev_results = results
        self.prev_frame = frame

        annotated_frame, detection_data = self.process_results(frame, results)
        self.detections_count = len(detection_data)

        # Queue an upload only when something was detected and the interval elapsed.
        if detection_data and current_time - self.last_upload_time >= self.upload_interval:
            self._enqueue_upload(annotated_frame, detection_data, current_time)

        if self.config['push']['enable_push'] and self.streamer:
            try:
                self.streamer.add_frame(annotated_frame)
            except queue.Full:
                logger.warning("推流队列已满,丢弃帧")

        # Socket.IO updates are throttled to one per second.
        if current_time - self.last_log_time >= 1:
            self.send_to_websocket(detection_data)
            self.last_log_time = current_time

        self.frame_count += 1

        elapsed = time.perf_counter() - start_time
        self.processing_times.append(elapsed)
        self.avg_process_time = self.avg_process_time * 0.9 + elapsed * 0.1
        # Report per processed-frame batch (frame_count also advances on skipped
        # frames, so a modulo test on it could miss reports and let the list grow).
        if len(self.processing_times) >= self.PERF_REPORT_EVERY:
            self._report_performance()

    def _infer(self, frame):
        """Run the YOLO model on one frame using the settings in ``config['predict']``."""
        predict_cfg = self.config['predict']
        return self.model(
            frame,
            stream=False,
            verbose=False,
            conf=predict_cfg['conf_thres'],
            iou=predict_cfg['iou_thres'],
            imgsz=self.imgsz,
            device=predict_cfg['device'],
            half=predict_cfg.get('half', False),
        )

    def _enqueue_upload(self, annotated_frame, detection_data, current_time):
        """Queue an annotated snapshot for the upload worker without blocking detection."""
        try:
            filename = f"DJI_{int(current_time * 1000)}.jpg"
            self.upload_queue.put_nowait({
                "image": annotated_frame.copy(),
                "filename": filename,
                "detection_data": detection_data,
                "timestamp": current_time,
            })
        except queue.Full:
            logger.warning("上传队列已满,丢弃本次上传")
        except Exception as e:
            logger.error(f"添加上传任务失败: {e}")
        # Rate-limit attempts, not only successes, so a full queue does not
        # produce a warning on every frame.
        self.last_upload_time = current_time

    def _report_performance(self):
        """Log the mean processing time of the last batch and adjust ``predict.frame_skip``."""
        avg_time = sum(self.processing_times) / len(self.processing_times)
        est_fps = 1.0 / avg_time if avg_time > 0 else 0.0
        logger.info(f"帧处理耗时: {avg_time * 1000:.2f}ms | 平均FPS: {est_fps:.1f}")
        self.processing_times = []

        # NOTE(review): config['predict']['frame_skip'] is not read by this class;
        # the in-loop skip rule uses a fixed cap of 2. Confirm whether another
        # component consumes it.
        predict_cfg = self.config['predict']
        frame_skip = predict_cfg.get('frame_skip', 0)
        if avg_time > 0.15:
            predict_cfg['frame_skip'] = min(4, frame_skip + 1)
            logger.info(f"增加跳帧至 {self.config['predict']['frame_skip']}")
        elif avg_time < 0.05 and frame_skip > 0:
            predict_cfg['frame_skip'] = max(0, frame_skip - 1)
            logger.info(f"减少跳帧至 {self.config['predict']['frame_skip']}")

    # ---------------------------------------------------------------- upload

    def _upload_worker(self):
        """Upload-thread body: save, upload and report each queued snapshot.

        Runs until ``upload_active`` is cleared and the queue is drained, or a
        ``None`` sentinel is dequeued.
        """
        logger.info("上传工作线程启动")
        output_dir = "output_frames"
        os.makedirs(output_dir, exist_ok=True)

        while self.upload_active or not self.upload_queue.empty():
            try:
                task = self.upload_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            if task is None:
                break
            try:
                self._upload_task(task, output_dir)
            except Exception as e:
                logger.error(f"上传任务处理异常: {e}")

        logger.info("上传工作线程已停止")

    def _upload_task(self, task, output_dir):
        """Write one snapshot to disk, upload it to MinIO, notify the result API, delete it."""
        start_time = time.time()
        filename = task["filename"]
        filepath = os.path.join(output_dir, filename)
        try:
            if not cv2.imwrite(filepath, task["image"]):
                logger.error(f"图片写入失败: {filepath}")
                return

            foldername = self.config['task']['taskname']
            object_path = f"{foldername}/{filename}"
            self.minio_uploader.upload_file(filepath, object_path)

            payload = {
                "taskid": self.taskid,
                "path": object_path,
                "tag": task["detection_data"],
                "aiid": self.aiid,
            }
            # Attach the latest drone telemetry when MQTT is up.
            if self.mqtt_enabled and self.mqtt_connected:
                with self.mqtt_data_lock:
                    if self.latest_drone_data:
                        payload["drone_info"] = self.latest_drone_data
                        logger.info("添加无人机数据到API调用")

            headers = {"Content-Type": "application/json"}
            # A timeout keeps a hung API from wedging this worker (and cleanup's join).
            response = requests.post(self.res_api, json=payload, headers=headers,
                                     timeout=self.HTTP_TIMEOUT)
            if response.status_code == 200:
                logger.info(f"已上传帧至 MinIO: {object_path} | 耗时: {time.time() - start_time:.2f}s")
            else:
                logger.warning(f"API调用失败: {response.status_code} - {response.text}")
        except Exception as e:
            logger.error(f"上传/API调用失败: {e}")
        finally:
            try:
                os.remove(filepath)
            except OSError:
                pass  # never written or already removed

    # --------------------------------------------------------------- results

    def process_results(self, frame, results):
        """Render detections and build the per-detection summary for one frame.

        Only ``results[0]`` is used. Class metadata (display ``name`` and minimum
        ``reliability``) is looked up in the task ``tag`` mapping first, then in
        ``class_mapping_cn``; classes found in neither are skipped with a
        one-time warning. Detections whose confidence is below the class's
        ``reliability`` are dropped.

        Returns:
            (annotated_frame, detection_data): the rendered image, resized to
            *frame*'s shape if needed, and a list of dicts with keys
            ``class_id``, ``class_name``, ``confidence``, ``box`` (xyxy list)
            and ``reliability``.
        """
        result = results[0]
        predict_cfg = self.config['predict']
        annotated_frame = result.plot(
            img=frame,
            line_width=predict_cfg['line_width'],
            font=predict_cfg['font'],
            conf=True,
            labels=True,
            probs=True,
        )
        if annotated_frame.shape != frame.shape:
            logger.warning(f"渲染帧分辨率不匹配: 输入 {frame.shape}, 输出 {annotated_frame.shape}")
            annotated_frame = cv2.resize(annotated_frame, (frame.shape[1], frame.shape[0]))

        detection_data = []
        for box in result.boxes:
            class_id = int(box.cls)
            confidence = float(box.conf)
            class_info = self.tag.get(str(class_id))
            if class_info is None:
                class_info = cmc.get(str(class_id))
            if class_info is None:
                # Unmapped class: skip instead of crashing the detection thread.
                if class_id not in self._unknown_class_ids:
                    self._unknown_class_ids.add(class_id)
                    logger.warning(f"未配置的类别ID: {class_id},已忽略")
                continue
            if confidence < class_info['reliability']:
                continue
            detection_data.append({
                'class_id': class_id,
                'class_name': class_info['name'],
                'confidence': confidence,
                'box': box.xyxy[0].tolist(),
                'reliability': class_info['reliability'],
            })
        return annotated_frame, detection_data

    # ---------------------------------------------------------------- set-up

    def initialize_resources(self):
        """Load (downloading if missing) and warm up the model, open the stream, start the streamer."""
        model_path = self.config['model']['path']
        local_model_path = os.path.join('models', model_path)

        if not Path(local_model_path).exists():
            logger.info(r"模型不存在,开始下载模型")
            try:
                downloaded = self._download_model(self.config['task']['api'] + model_path, local_model_path)
                logger.info(f"模型下载成功: {local_model_path} ({downloaded} 字节)")
            except Exception as e:
                logger.error(f"下载模型失败: {str(e)}")
            if not Path(local_model_path).exists():
                logger.error("没有可用的模型文件,使用默认模型")
                local_model_path = r"models/yolov8n.pt"

        device = self.config['predict']['device']
        use_half = self.config['predict'].get('half', False) and 'cuda' in device

        logger.info(f"加载模型...{local_model_path}")
        self.model = YOLO(local_model_path).to(device)
        if use_half:
            self.model = self.model.half()
            logger.info("启用半精度推理")

        logger.info("模型预热...")
        dummy_input = torch.zeros(1, 3, self.imgsz, self.imgsz).to(device)
        if use_half:
            dummy_input = dummy_input.half()
        for _ in range(5):
            self.model.predict(dummy_input)
        logger.info(f"模型加载成功 | 设备: {device.upper()}")

        self.original_width, self.original_height, fps = self.initialize_rtmp()
        self.fps = fps  # seeds the loop-rate EMA
        self.stream_fps = fps  # used when (re)creating the push streamer
        self.start_streamer()

    def _download_model(self, download_url, target_path):
        """Stream *download_url* into *target_path*, logging progress in 10% steps.

        Data goes to ``<target_path>.part`` and is renamed only on success, so an
        interrupted download never leaves a truncated model that a later run
        would mistake for a valid one.

        Returns:
            Number of bytes downloaded.

        Raises:
            requests.RequestException / OSError on network or file errors.
        """
        target_dir = os.path.dirname(target_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        tmp_path = target_path + ".part"
        downloaded = 0
        try:
            with requests.get(download_url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024 * 1024
                progress = 0
                with open(tmp_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        downloaded += len(data)
                        f.write(data)
                        new_progress = int(downloaded * 100 / total_size) if total_size > 0 else 0
                        if new_progress != progress and (new_progress % 10 == 0 or downloaded == total_size):
                            logger.info(f"下载进度: {new_progress}% ({downloaded}/{total_size} 字节)")
                            progress = new_progress
            os.replace(tmp_path, target_path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        return downloaded

    def initialize_rtmp(self):
        """Open the RTMP stream with the FFmpeg backend.

        Returns:
            (width, height, fps) as reported by the capture; fps falls back to
            30 when the stream reports 0.

        Raises:
            IOError: if the stream cannot be opened.
        """
        logger.info(f"连接RTMP: {self.rtmp_url}")
        self.cap = cv2.VideoCapture()

        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, self.buffer_size)
        self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'))
        self.cap.set(cv2.CAP_PROP_FPS, 30)

        if self.config['rtmp'].get('gpu_decode', False):
            # NOTE(review): OpenCV documents HW acceleration as an open() parameter;
            # setting it before open() may have no effect — verify on target build.
            try:
                self.cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
                logger.info("启用硬件加速解码")
            except Exception:
                logger.warning("硬件解码不可用,使用软件解码")

        if not self.cap.open(self.rtmp_url, cv2.CAP_FFMPEG):
            logger.error(f"连接RTMP失败: {self.rtmp_url}")
            raise IOError(f"无法连接RTMP流 ({self.rtmp_url})")

        frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30

        logger.info(f"视频属性: {frame_width}x{frame_height} @ {fps}fps")
        return frame_width, frame_height, fps

    def handle_reconnect(self):
        """Back off and reopen the stream after a failed read.

        Stops the thread (``running = False``) once more than
        ``max_reconnect_attempts`` consecutive reads have failed. The back-off
        wait is interrupted by stop().
        """
        self.reconnect_attempts += 1
        if self.reconnect_attempts > self.max_reconnect_attempts:
            logger.error("达到最大重连次数")
            self.running = False
            return

        delay = min(10, self.reconnect_attempts * self.reconnect_delay)
        logger.warning(f"流中断,{delay}秒后重连 ({self.reconnect_attempts}/{self.max_reconnect_attempts})")

        # Tear down the streamer and the capture before reconnecting.
        self.stop_streamer()
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        # Wait on the stop event instead of sleeping so stop() is honoured promptly.
        if self.stop_event.wait(delay):
            return

        try:
            width, height, fps = self.initialize_rtmp()
            self.original_width = width
            self.original_height = height
            self.stream_fps = fps
            self.prev_frame = None
            self.prev_results = None
            self.start_streamer()
            logger.info("重连成功")
        except Exception as e:
            logger.error(f"重连异常: {str(e)}")

    # ------------------------------------------------------------ publishing

    def send_to_websocket(self, detection_data):
        """Emit a compact ``detection_results`` event via ``config['socketIO']`` (errors are logged)."""
        try:
            now = datetime.datetime.now()
            time_str = now.strftime("%H:%M:%S")
            simplified_data = [{
                'class_id': d['class_id'],
                'class_name': d['class_name'],
                'confidence': round(d['confidence'], 2),
                'box': [round(float(c), 1) for c in d['box']],
            } for d in detection_data]

            self.config['socketIO'].emit('detection_results', {
                'detections': simplified_data,
                'timestamp': time.time_ns() // 1000000,
                'fps': round(self.last_fps, 1),
                'frame_count': self.frame_count,
                'taskid': self.taskid,
                'time_str': time_str,
            })
        except Exception as e:
            logger.error(f"WebSocket发送错误: {str(e)}")

    def start_streamer(self):
        """Create and start a push streamer when pushing is enabled.

        ``self.streamer`` is only replaced after ``start()`` succeeds, so a
        streamer that failed to start is never fed frames.
        """
        if not self.config['push']['enable_push']:
            return
        try:
            streamer = FFmpegStreamer(self.config, self.stream_fps, self.original_width, self.original_height)
            streamer.start()
        except Exception as e:
            logger.error(f"启动推流线程失败: {str(e)}")
            return
        self.streamer = streamer
        logger.info("推流线程已启动")

    def stop_streamer(self):
        """Detach and stop the current push streamer, if any (errors are logged)."""
        if not self.streamer:
            return
        streamer, self.streamer = self.streamer, None
        try:
            streamer.stop()
            logger.info("推流线程已停止")
        except Exception as e:
            logger.error(f"停止推流线程失败: {str(e)}")

    # -------------------------------------------------------------- shutdown

    def stop(self):
        """Ask the detection loop to exit; run() then performs cleanup."""
        self.stop_event.set()
        self.running = False

    def cleanup(self):
        """Release MQTT, the upload worker, the capture, the streamer, and the model."""
        logger.info("清理资源...")
        if self.mqtt_enabled:
            self.stop_mqtt_client()

        # Signal the upload worker; it drains queued tasks before the sentinel.
        self.upload_active = False
        if self.upload_thread and self.upload_thread.is_alive():
            try:
                self.upload_queue.put(None, timeout=0.5)
            except queue.Full:
                pass  # worker still exits once upload_active is False and the queue drains
            self.upload_thread.join(2.0)
            if self.upload_thread.is_alive():
                logger.warning("上传线程未能正常停止")
        logger.info("上传资源清理完成")

        if self.cap is not None:
            try:
                self.cap.release()
            except Exception as e:
                logger.warning(f"释放视频流失败: {e}")
            self.cap = None
            logger.info("视频流已释放")

        self.stop_streamer()

        if self.model is not None:
            # Best effort: drop references to predictor/network to free memory sooner.
            try:
                if hasattr(self.model, 'predictor'):
                    del self.model.predictor
                if hasattr(self.model, 'model'):
                    del self.model.model
            except Exception as e:
                logger.debug(f"释放模型属性失败: {e}")
            self.model = None
            logger.info("模型已释放")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        logger.info("资源清理完成")