203 lines
6.5 KiB
Python
203 lines
6.5 KiB
Python
|
|
# model_crypto.py
|
|||
|
|
import os
|
|||
|
|
import tempfile
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
from cryptography.fernet import Fernet
|
|||
|
|
from cryptography.hazmat.primitives import hashes
|
|||
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|||
|
|
import base64
|
|||
|
|
import pickle
|
|||
|
|
from pathlib import Path
|
|||
|
|
from log import logger
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelEncryptor:
    """Password-based encryption and decryption of model files.

    An encrypted file is a pickled dict with three entries: ``'salt'`` (the
    PBKDF2 salt), ``'data'`` (the Fernet token wrapping the model bytes) and
    ``'original_size'`` (length of the plaintext in bytes).
    """

    # The payload stores the salt but not the iteration count, so changing
    # this value makes previously encrypted files impossible to decrypt.
    _KDF_ITERATIONS = 100000
    _SALT_BYTES = 16

    class _PayloadUnpickler(pickle.Unpickler):
        """Unpickler that refuses to resolve any global.

        SECURITY: a plain ``pickle.load`` executes arbitrary code from the
        file being read. The payload written by ``encrypt_model`` consists
        only of dict/str/bytes/int values, none of which need a global
        lookup, so rejecting every global keeps valid files loadable while
        blocking code execution from a tampered or foreign file.
        """

        def find_class(self, module, name):
            raise pickle.UnpicklingError(
                f"global '{module}.{name}' is not allowed in a model payload"
            )

    @staticmethod
    def _read_payload(path):
        """Unpickle the file at *path* without letting it import or call anything."""
        with open(path, 'rb') as f:
            return ModelEncryptor._PayloadUnpickler(f).load()

    @staticmethod
    def generate_key(password: str, salt: bytes = None):
        """Derive a Fernet key from *password* with PBKDF2-HMAC-SHA256.

        Args:
            password: Secret the key is derived from.
            salt: Salt to reuse (needed when decrypting). A fresh random
                salt is generated when omitted.

        Returns:
            A ``(key, salt)`` tuple: the url-safe base64 encoded 32-byte key
            and the salt that was used to derive it.
        """
        if salt is None:
            salt = os.urandom(ModelEncryptor._SALT_BYTES)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=ModelEncryptor._KDF_ITERATIONS,
        )
        # Fernet expects its 32-byte key in url-safe base64 form.
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return key, salt

    @staticmethod
    def encrypt_model(model_path: str, output_path: str, password: str):
        """Encrypt the file at *model_path* and write the payload to *output_path*.

        Args:
            model_path: Path of the plaintext model file.
            output_path: Path the pickled payload is written to.
            password: Password the encryption key is derived from.

        Returns:
            True on success. Any failure is logged and reported as False
            instead of being raised.
        """
        try:
            with open(model_path, 'rb') as f:
                model_data = f.read()

            # A new random salt is generated for every encryption.
            key, salt = ModelEncryptor.generate_key(password)
            encrypted_data = Fernet(key).encrypt(model_data)

            # The salt is stored next to the ciphertext so that decryption
            # can derive the same key again.
            encrypted_payload = {
                'salt': salt,
                'data': encrypted_data,
                'original_size': len(model_data),
            }

            with open(output_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)

            logger.info(f"模型加密成功: {model_path} -> {output_path}")
            return True

        except Exception as e:
            logger.error(f"模型加密失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_model(encrypted_path: str, password: str):
        """Decrypt *encrypted_path* into a temporary ``.pt`` file.

        The plaintext is written to disk, not kept in memory. The temporary
        file is created with ``delete=False``, so the caller owns it and is
        responsible for removing it.

        Args:
            encrypted_path: Path of a payload written by ``encrypt_model``.
            password: Password used when the file was encrypted.

        Returns:
            Path of the temporary plaintext file, or None when reading or
            decrypting fails (the error is logged, not raised).
        """
        try:
            encrypted_payload = ModelEncryptor._read_payload(encrypted_path)

            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']

            # Re-derive the key from the stored salt.
            key, _ = ModelEncryptor.generate_key(password, salt)
            decrypted_data = Fernet(key).decrypt(encrypted_data)

            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.pt')
            temp_path = tmp.name
            try:
                with tmp:
                    tmp.write(decrypted_data)
            except Exception:
                # Do not leave a partially written plaintext model behind.
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass
                raise

            logger.info(f"模型解密成功: {encrypted_path}")
            return temp_path

        except Exception as e:
            logger.error(f"模型解密失败: {str(e)}")
            return None

    @staticmethod
    def is_encrypted(model_path: str):
        """Report whether *model_path* holds a payload in the encrypted format.

        Returns:
            True when the file unpickles (under the restricted unpickler) to
            a dict containing both ``'salt'`` and ``'data'``; False for any
            other content or when the file cannot be read.
        """
        try:
            payload = ModelEncryptor._read_payload(model_path)
        except Exception:
            # Unreadable file or not a payload pickle: treat as not encrypted.
            return False
        return isinstance(payload, dict) and 'salt' in payload and 'data' in payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelManager:
    """Model manager that downloads, optionally decrypts, and loads YOLO models."""

    def __init__(self, config):
        """Store *config* and make sure the local ``models`` directory exists."""
        self.config = config
        self.models_dir = "models"
        os.makedirs(self.models_dir, exist_ok=True)

    def _local_path(self, model_config):
        """Return the path under ``models_dir`` for ``model_config['path']``."""
        return os.path.join(self.models_dir, os.path.basename(model_config['path']))

    @staticmethod
    def _apply_precision(model, model_config):
        """Return *model* in half precision when ``half`` is set and the device is CUDA."""
        # str() keeps the membership test from raising TypeError when the
        # configured device is not a plain string.
        if model_config.get('half', False) and 'cuda' in str(model_config['device']):
            return model.half()
        return model

    def load_model(self, model_config):
        """Load the model described by *model_config*, decrypting it if needed.

        Keys read from *model_config*: ``path`` (required), ``device``
        (required for loading), ``encrypted``, ``encryption_key``, ``half``
        and, when the file is not present locally, ``download_url``.

        The file is treated as encrypted only when ``encrypted`` is truthy
        AND an ``encryption_key`` is supplied; otherwise it is loaded as a
        plain model file.

        Returns:
            The loaded model, or None when downloading, decrypting or
            loading fails (failures are logged, not raised).
        """
        local_path = self._local_path(model_config)
        encrypted = model_config.get('encrypted', False)
        encryption_key = model_config.get('encryption_key')

        # Download the model only when it is not already present locally.
        if not os.path.exists(local_path) and not self.download_model(model_config):
            return None

        if encrypted and encryption_key:
            return self._load_encrypted(local_path, encryption_key, model_config)
        return self._load_plain(local_path, model_config)

    def _load_encrypted(self, local_path, encryption_key, model_config):
        """Decrypt *local_path* to a temporary file and load the model from it."""
        if not ModelEncryptor.is_encrypted(local_path):
            logger.warning(f"模型未加密或密钥错误: {local_path}")
            return None

        decrypted_path = ModelEncryptor.decrypt_model(local_path, encryption_key)
        if not decrypted_path:
            return None

        try:
            from ultralytics import YOLO
            model = YOLO(decrypted_path).to(model_config['device'])
            # Same precision handling as the plain path.
            return self._apply_precision(model, model_config)
        except Exception as e:
            logger.error(f"加载解密模型失败: {str(e)}")
            return None
        finally:
            # Always remove the plaintext copy, including when loading failed.
            try:
                os.unlink(decrypted_path)
            except OSError:
                pass

    def _load_plain(self, local_path, model_config):
        """Load an unencrypted model file from *local_path*."""
        try:
            from ultralytics import YOLO
            model = YOLO(local_path).to(model_config['device'])
            return self._apply_precision(model, model_config)
        except Exception as e:
            logger.error(f"加载模型失败: {str(e)}")
            return None

    def download_model(self, model_config):
        """Download ``model_config['download_url']`` into the models directory.

        The body is streamed into a ``.part`` file that is renamed onto the
        final path only after the whole download succeeded, so an interrupted
        download never leaves a truncated file that looks like a complete
        model to later ``load_model`` calls.

        Returns:
            True on success; False when no URL is configured or the download
            fails (the error is logged, not raised).
        """
        try:
            model_path = model_config['path']
            download_url = model_config.get('download_url')

            if not download_url:
                logger.error(f"模型无下载地址: {model_path}")
                return False

            local_path = self._local_path(model_config)
            partial_path = local_path + ".part"

            logger.info(f"下载模型: {download_url} -> {local_path}")

            downloaded = 0
            last_logged = -1
            try:
                # Context manager releases the streamed connection on every path.
                with requests.get(download_url, stream=True, timeout=30) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))

                    with open(partial_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            if not chunk:
                                continue
                            f.write(chunk)
                            downloaded += len(chunk)

                            if total_size > 0:
                                # Log once per 10% step instead of once per chunk.
                                progress = downloaded * 100 // total_size
                                progress -= progress % 10
                                if progress > last_logged:
                                    last_logged = progress
                                    logger.info(f"下载进度: {progress}%")

                os.replace(partial_path, local_path)
            finally:
                # After a successful rename the partial file is gone; anything
                # still here is the leftover of a failed download.
                if os.path.exists(partial_path):
                    try:
                        os.remove(partial_path)
                    except OSError:
                        pass

            logger.info(f"模型下载完成: {local_path} ({downloaded} 字节)")
            return True

        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            return False
|