# Yolov/model_crypto.py

# 203 lines
# 6.5 KiB
# Python
# Raw Blame History

# This file contains ambiguous Unicode characters!

# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# model_crypto.py
import os
import tempfile
import requests
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
import pickle
from pathlib import Path
from log import logger
class ModelEncryptor:
    """Encrypt / decrypt model files with a password-derived Fernet key.

    The on-disk format written by ``encrypt_model`` is a pickled dict::

        {'salt': bytes, 'data': bytes (Fernet token), 'original_size': int}

    The Fernet key is re-derived from the password and the stored salt with
    PBKDF2-HMAC-SHA256 (see ``generate_key``).
    """

    class _PayloadUnpickler(pickle.Unpickler):
        """Unpickler that refuses to resolve any global (class or callable).

        Encrypted model files may come from a download URL. A plain
        ``pickle.load`` on such a file allows arbitrary code execution.
        The payload produced by ``encrypt_model`` only contains
        dict/str/bytes/int. Pickle protocol >= 3, the Python 3 default used
        by ``pickle.dump``, encodes those without any global lookup, so
        legitimate files still load while crafted ones are rejected.
        """

        def find_class(self, module, name):
            raise pickle.UnpicklingError(
                f"forbidden global in encrypted model payload: {module}.{name}"
            )

    @staticmethod
    def _load_payload(path: str):
        """Read the pickled payload at *path* with the restricted unpickler."""
        with open(path, 'rb') as f:
            return ModelEncryptor._PayloadUnpickler(f).load()

    @staticmethod
    def generate_key(password: str, salt: bytes = None):
        """Derive a Fernet key from *password*.

        Args:
            password: User password (``str.encode()``-d before derivation).
            salt: Salt to reuse (decryption). When ``None`` a fresh random
                16-byte salt is generated (encryption).

        Returns:
            ``(key, salt)``. ``key`` is the urlsafe-base64 encoding of the
            32-byte PBKDF2-HMAC-SHA256 output (100000 iterations), the form
            ``Fernet`` expects.
        """
        if salt is None:
            salt = os.urandom(16)
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return key, salt

    @staticmethod
    def encrypt_model(model_path: str, output_path: str, password: str):
        """Encrypt *model_path* with *password* and write the payload to *output_path*.

        A fresh random salt is generated per call and stored in the payload.

        Returns:
            ``True`` on success, ``False`` on any failure (logged).
        """
        try:
            with open(model_path, 'rb') as f:
                model_data = f.read()
            key, salt = ModelEncryptor.generate_key(password)
            encrypted_data = Fernet(key).encrypt(model_data)
            # The salt is not secret. It is stored so decrypt_model can
            # re-derive the same key from the password.
            encrypted_payload = {
                'salt': salt,
                'data': encrypted_data,
                'original_size': len(model_data),
            }
            # The default protocol (>= 3) keeps this payload free of globals,
            # which _PayloadUnpickler relies on when reading it back.
            with open(output_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)
            logger.info(f"模型加密成功: {model_path} -> {output_path}")
            return True
        except Exception as e:
            logger.error(f"模型加密失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_model(encrypted_path: str, password: str):
        """Decrypt *encrypted_path* into a new temporary ``.pt`` file.

        The plaintext is written to disk (the temp file is created with
        ``delete=False``). The caller owns the returned file and must
        delete it.

        Returns:
            Path of the temporary plaintext file, or ``None`` on failure
            (wrong password, corrupt or unsupported file, I/O error; logged).
        """
        try:
            encrypted_payload = ModelEncryptor._load_payload(encrypted_path)
            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']
            key, _ = ModelEncryptor.generate_key(password, salt)
            # Fernet tokens are authenticated: a wrong password raises here.
            decrypted_data = Fernet(key).decrypt(encrypted_data)
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.pt')
            try:
                with tmp:
                    tmp.write(decrypted_data)
            except Exception:
                # Never leave a partially written plaintext model behind.
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass
                raise
            logger.info(f"模型解密成功: {encrypted_path}")
            return tmp.name
        except Exception as e:
            logger.error(f"模型解密失败: {str(e)}")
            return None

    @staticmethod
    def is_encrypted(model_path: str):
        """Return True if *model_path* holds a payload written by ``encrypt_model``.

        Any unreadable, non-pickle, or unsafe (global-referencing) file
        yields False. Interrupts such as KeyboardInterrupt are not swallowed.
        """
        try:
            data = ModelEncryptor._load_payload(model_path)
        except Exception:
            return False
        return isinstance(data, dict) and 'salt' in data and 'data' in data
class ModelManager:
    """Model manager: downloads models on demand and loads them, decrypting
    encrypted model files when configured."""

    def __init__(self, config):
        self.config = config
        # Downloaded models are cached in this directory (relative to the CWD).
        self.models_dir = "models"
        os.makedirs(self.models_dir, exist_ok=True)

    def load_model(self, model_config):
        """Load a YOLO model described by *model_config*.

        Keys read here: ``path`` (its basename names the cached file under
        ``models_dir``), ``device``, and optionally ``encrypted``,
        ``encryption_key`` and ``half``. ``download_url`` is read by
        ``download_model`` when the file is not cached yet.

        Returns:
            The loaded model, or ``None`` on failure (failures are logged).
        """
        model_path = model_config['path']
        encrypted = model_config.get('encrypted', False)
        encryption_key = model_config.get('encryption_key')
        local_path = os.path.join(self.models_dir, os.path.basename(model_path))

        # Download the model if it is not cached yet.
        if not os.path.exists(local_path):
            if not self.download_model(model_config):
                return None

        if encrypted and encryption_key:
            if not ModelEncryptor.is_encrypted(local_path):
                logger.warning(f"模型未加密或密钥错误: {local_path}")
                return None
            decrypted_path = ModelEncryptor.decrypt_model(local_path, encryption_key)
            if not decrypted_path:
                return None
            try:
                return self._load_yolo(decrypted_path, model_config, "加载解密模型失败")
            finally:
                # Always remove the plaintext copy, also when loading failed.
                try:
                    os.unlink(decrypted_path)
                except OSError:
                    pass

        # Plain (unencrypted) model.
        return self._load_yolo(local_path, model_config, "加载模型失败")

    @staticmethod
    def _load_yolo(path, model_config, error_prefix):
        """Build a YOLO model from *path*, move it to ``model_config['device']``
        and apply ``half`` precision on CUDA. On error, log and return None."""
        try:
            from ultralytics import YOLO
            model = YOLO(path).to(model_config['device'])
            if model_config.get('half', False) and 'cuda' in model_config['device']:
                model = model.half()
            return model
        except Exception as e:
            logger.error(f"{error_prefix}: {str(e)}")
            return None

    def download_model(self, model_config):
        """Download ``model_config['download_url']`` into ``models_dir``.

        The body is streamed to a ``.part`` file. That file is renamed to
        the final name only after the whole body was received, so a failed
        download never leaves a truncated file that ``load_model`` would
        later treat as already cached.

        Returns:
            ``True`` on success, ``False`` on failure (failures are logged).
        """
        part_path = None
        try:
            model_path = model_config['path']
            download_url = model_config.get('download_url')
            if not download_url:
                logger.error(f"模型无下载地址: {model_path}")
                return False
            local_path = os.path.join(self.models_dir, os.path.basename(model_path))
            part_path = local_path + '.part'
            logger.info(f"下载模型: {download_url} -> {local_path}")
            response = requests.get(download_url, stream=True, timeout=30)
            try:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0
                # Log each 10% step once, instead of once per chunk.
                last_decile = -1
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        downloaded += len(chunk)
                        f.write(chunk)
                        if total_size > 0:
                            progress = min(downloaded * 100 // total_size, 100)
                            decile = progress // 10
                            if decile > last_decile:
                                last_decile = decile
                                logger.info(f"下载进度: {decile * 10}%")
            finally:
                # Release the streamed connection even if reading failed.
                response.close()
            os.replace(part_path, local_path)
            logger.info(f"模型下载完成: {local_path} ({downloaded} 字节)")
            return True
        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            if part_path is not None:
                try:
                    os.remove(part_path)
                except OSError:
                    pass
            return False