203 lines
6.5 KiB
Python
203 lines
6.5 KiB
Python
# model_crypto.py
|
||
import os
|
||
import tempfile
|
||
|
||
import requests
|
||
from cryptography.fernet import Fernet
|
||
from cryptography.hazmat.primitives import hashes
|
||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||
import base64
|
||
import pickle
|
||
from pathlib import Path
|
||
from log import logger
|
||
|
||
|
||
class ModelEncryptor:
    """Encrypt / decrypt model files with a password-derived Fernet key.

    On-disk format written by ``encrypt_model``: a pickled ``dict`` with keys
    ``'salt'`` (the PBKDF2 salt), ``'data'`` (the Fernet token of the raw
    model bytes) and ``'original_size'`` (plaintext length in bytes).

    The outer pickle is NOT authenticated (only ``'data'`` is), so it is
    always read back through ``_PayloadUnpickler``, which refuses to import
    any global. That blocks the arbitrary code execution plain
    ``pickle.load`` allows on a tampered or downloaded file, while still
    reading every payload ``encrypt_model`` produces.
    """

    class _PayloadUnpickler(pickle.Unpickler):
        """Unpickler that only rebuilds builtin containers and scalars.

        Every opcode that needs a global (classes, functions, ``__reduce__``
        callables) is resolved through ``find_class``, which always refuses.
        A dict of bytes/int pickled with protocol >= 3 never needs one.
        """

        def find_class(self, module, name):
            raise pickle.UnpicklingError(
                f"refusing to load global {module}.{name} from model payload"
            )

    @staticmethod
    def _load_payload(path: str):
        """Safely unpickle the encrypted-model payload stored at *path*.

        Raises whatever ``open``/unpickling raises (e.g. OSError, EOFError,
        pickle.UnpicklingError) so callers decide how to report it.
        """
        with open(path, 'rb') as f:
            return ModelEncryptor._PayloadUnpickler(f).load()

    @staticmethod
    def generate_key(password: str, salt: bytes = None):
        """Derive a Fernet key from *password* with PBKDF2-HMAC-SHA256.

        Args:
            password: passphrase; UTF-8 encoded before derivation.
            salt: salt bytes to reuse; when None a fresh 16-byte random salt
                is generated.

        Returns:
            ``(key, salt)``: the urlsafe-base64 encoded 32-byte key (the form
            ``Fernet`` expects) and the salt that was used. Store the salt to
            re-derive the same key later.
        """
        if salt is None:
            salt = os.urandom(16)

        # The KDF parameters are part of the file format: changing them makes
        # previously encrypted models undecryptable.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return key, salt

    @staticmethod
    def encrypt_model(model_path: str, output_path: str, password: str) -> bool:
        """Encrypt the file at *model_path* and write the payload to *output_path*.

        The whole model is read into memory and encrypted with a key derived
        from *password* and a fresh random salt.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            with open(model_path, 'rb') as f:
                model_data = f.read()

            key, salt = ModelEncryptor.generate_key(password)
            encrypted_data = Fernet(key).encrypt(model_data)

            # The salt is stored beside the token so decrypt_model can
            # re-derive the key from the password alone.
            encrypted_payload = {
                'salt': salt,
                'data': encrypted_data,
                'original_size': len(model_data),
            }

            # Python 3's default protocol (>= 3) encodes bytes without any
            # global lookups, so _load_payload can read this back.
            with open(output_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)

            logger.info(f"模型加密成功: {model_path} -> {output_path}")
            return True

        except Exception as e:
            logger.error(f"模型加密失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_model(encrypted_path: str, password: str):
        """Decrypt an ``encrypt_model`` payload into a temporary ``.pt`` file.

        The plaintext is written to a ``delete=False`` temp file because the
        model loader needs a filesystem path. The caller owns that file and
        must delete it.

        Returns:
            The temp file path, or None on any failure (the error is logged).
        """
        try:
            encrypted_payload = ModelEncryptor._load_payload(encrypted_path)
            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']

            key, _ = ModelEncryptor.generate_key(password, salt)
            # Fernet verifies the token's HMAC; a wrong password or tampered
            # data raises InvalidToken here.
            decrypted_data = Fernet(key).decrypt(encrypted_data)

            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.pt')
            try:
                with tmp:
                    tmp.write(decrypted_data)
            except Exception:
                # Never leave a partially written plaintext model behind.
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass
                raise
            temp_path = tmp.name

            logger.info(f"模型解密成功: {encrypted_path}")
            return temp_path

        except Exception as e:
            # InvalidToken carries no message; fall back to the type name so
            # the log line is not empty.
            logger.error(f"模型解密失败: {str(e) or type(e).__name__}")
            return None

    @staticmethod
    def is_encrypted(model_path: str) -> bool:
        """Return True if *model_path* holds an ``encrypt_model``-style payload.

        A payload qualifies when it is a dict containing ``'salt'`` and
        ``'data'``. Anything that cannot be safely unpickled yields False:
        a non-pickle file (presumably a plain checkpoint), an empty or missing
        file, or a payload that tries to import globals.
        """
        try:
            data = ModelEncryptor._load_payload(model_path)
        except Exception:
            return False
        return isinstance(data, dict) and 'salt' in data and 'data' in data
|
||
|
||
|
||
class ModelManager:
    """Model manager: caches, downloads and loads YOLO models, including encrypted ones."""

    def __init__(self, config):
        """Store *config* and make sure the local ``models`` cache directory exists."""
        self.config = config
        self.models_dir = "models"
        os.makedirs(self.models_dir, exist_ok=True)

    def _local_path(self, model_config):
        """Return the cache path: ``models_dir`` joined with the basename of ``model_config['path']``."""
        return os.path.join(self.models_dir, os.path.basename(model_config['path']))

    @staticmethod
    def _load_yolo(path, model_config):
        """Load a YOLO model from *path*, move it to ``model_config['device']`` and apply ``half``.

        Exceptions (including ImportError for ultralytics) propagate to the caller.
        """
        # Imported lazily so the heavy dependency is only needed when loading.
        from ultralytics import YOLO

        model = YOLO(path).to(model_config['device'])

        # Half precision is only requested on CUDA devices.
        if model_config.get('half', False) and 'cuda' in model_config['device']:
            model = model.half()
        return model

    def load_model(self, model_config):
        """Load the model described by *model_config*, decrypting it first if required.

        Keys read: ``'path'`` and ``'device'`` (required), ``'encrypted'``,
        ``'encryption_key'``, ``'half'``, plus ``download_model``'s keys when
        the file is not cached yet.

        Returns:
            The loaded model, or None on any failure (the error is logged).
        """
        local_path = self._local_path(model_config)
        encrypted = model_config.get('encrypted', False)
        encryption_key = model_config.get('encryption_key')

        # Download the model if it is not cached locally.
        if not os.path.exists(local_path):
            if not self.download_model(model_config):
                return None

        if not (encrypted and encryption_key):
            # Plain model.
            # NOTE(review): encrypted=True without an encryption_key also lands
            # here and attempts a plain load; confirm that is intended.
            try:
                return self._load_yolo(local_path, model_config)
            except Exception as e:
                logger.error(f"加载模型失败: {str(e)}")
                return None

        if not ModelEncryptor.is_encrypted(local_path):
            logger.warning(f"模型未加密或密钥错误: {local_path}")
            return None

        decrypted_path = ModelEncryptor.decrypt_model(local_path, encryption_key)
        if not decrypted_path:
            return None

        try:
            # Same device/half handling as the plain branch.
            return self._load_yolo(decrypted_path, model_config)
        except Exception as e:
            logger.error(f"加载解密模型失败: {str(e)}")
            return None
        finally:
            # Always remove the plaintext copy, including when loading failed.
            try:
                os.unlink(decrypted_path)
            except OSError:
                pass

    def download_model(self, model_config):
        """Download ``model_config['download_url']`` into the local model cache.

        The response is streamed into ``<local_path>.part`` and renamed onto
        the final path only after the body has been fully written. An
        interrupted download therefore never leaves a truncated file that
        ``load_model`` would later treat as already cached.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        part_path = None
        try:
            model_path = model_config['path']
            download_url = model_config.get('download_url')

            if not download_url:
                logger.error(f"模型无下载地址: {model_path}")
                return False

            local_path = self._local_path(model_config)
            part_path = local_path + '.part'

            logger.info(f"下载模型: {download_url} -> {local_path}")

            # Context manager releases the streamed connection on every path.
            with requests.get(download_url, stream=True, timeout=30) as response:
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                downloaded = 0
                last_logged = None

                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        downloaded += len(chunk)
                        f.write(chunk)

                        if total_size > 0:
                            progress = downloaded * 100 // total_size
                            # Many chunks share one percentage; log each 10% step once.
                            if progress % 10 == 0 and progress != last_logged:
                                logger.info(f"下载进度: {progress}%")
                                last_logged = progress

            # Publish the file only once it is complete.
            os.replace(part_path, local_path)

            logger.info(f"模型下载完成: {local_path} ({downloaded} 字节)")
            return True

        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            if part_path is not None:
                try:
                    os.unlink(part_path)
                except OSError:
                    pass
            return False
|