324 lines
11 KiB
Python
324 lines
11 KiB
Python
# mandatory_model_crypto.py
|
|
import os
|
|
import tempfile
|
|
import hashlib
|
|
import pickle
|
|
import traceback
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
import torch
|
|
from cryptography.fernet import Fernet
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
import base64
|
|
from log import logger
|
|
|
|
|
|
class MandatoryModelEncryptor:
|
|
"""强制模型加密器 - 所有模型必须加密"""
|
|
|
|
@staticmethod
|
|
def encrypt_model(model_path, output_path, password, require_encryption=True):
|
|
"""加密模型文件 - 强制模式"""
|
|
try:
|
|
# 读取模型文件
|
|
with open(model_path, 'rb') as f:
|
|
model_data = f.read()
|
|
|
|
# 计算模型哈希
|
|
model_hash = hashlib.sha256(model_data).hexdigest()
|
|
|
|
# 生成加密密钥
|
|
salt = os.urandom(16)
|
|
kdf = PBKDF2HMAC(
|
|
algorithm=hashes.SHA256(),
|
|
length=32,
|
|
salt=salt,
|
|
iterations=100000,
|
|
)
|
|
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
|
fernet = Fernet(key)
|
|
|
|
# 加密数据
|
|
encrypted_data = fernet.encrypt(model_data)
|
|
|
|
# 保存加密数据
|
|
encrypted_payload = {
|
|
'salt': salt,
|
|
'data': encrypted_data,
|
|
'model_hash': model_hash,
|
|
'original_size': len(model_data),
|
|
'encrypted': True,
|
|
'version': '1.0'
|
|
}
|
|
|
|
with open(output_path, 'wb') as f:
|
|
pickle.dump(encrypted_payload, f)
|
|
|
|
logger.info(f"模型强制加密成功: {model_path} -> {output_path}")
|
|
logger.info(f"模型哈希: {model_hash[:16]}...")
|
|
|
|
# 返回密钥信息(用于验证)
|
|
return {
|
|
'success': True,
|
|
'model_hash': model_hash,
|
|
'key_hash': hashlib.sha256(key).hexdigest()[:16],
|
|
'output_path': output_path
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"模型强制加密失败: {str(e)}")
|
|
return {'success': False, 'error': str(e)}
|
|
|
|
@staticmethod
|
|
def decrypt_model(encrypted_path, password, verify_key=True):
|
|
"""解密模型文件 - 带密钥验证"""
|
|
try:
|
|
if not os.path.exists(encrypted_path):
|
|
return {'success': False, 'error': '加密模型文件不存在'}
|
|
|
|
# 读取加密文件
|
|
with open(encrypted_path, 'rb') as f:
|
|
encrypted_payload = pickle.load(f)
|
|
|
|
# 验证加密格式
|
|
if not encrypted_payload.get('encrypted', False):
|
|
return {'success': False, 'error': '模型未加密'}
|
|
|
|
salt = encrypted_payload['salt']
|
|
encrypted_data = encrypted_payload['data']
|
|
expected_hash = encrypted_payload.get('model_hash', '')
|
|
|
|
# 生成解密密钥
|
|
kdf = PBKDF2HMAC(
|
|
algorithm=hashes.SHA256(),
|
|
length=32,
|
|
salt=salt,
|
|
iterations=100000,
|
|
)
|
|
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
|
|
|
|
# 验证密钥(可选)
|
|
if verify_key:
|
|
key_hash = hashlib.sha256(key).hexdigest()[:16]
|
|
logger.debug(f"解密密钥哈希: {key_hash}")
|
|
|
|
fernet = Fernet(key)
|
|
|
|
# 解密数据
|
|
decrypted_data = fernet.decrypt(encrypted_data)
|
|
|
|
# 验证模型哈希
|
|
actual_hash = hashlib.sha256(decrypted_data).hexdigest()
|
|
if expected_hash and actual_hash != expected_hash:
|
|
return {
|
|
'success': False,
|
|
'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
|
|
}
|
|
|
|
# 保存到临时文件
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
|
|
tmp.write(decrypted_data)
|
|
temp_path = tmp.name
|
|
|
|
logger.info(f"模型解密验证成功: {encrypted_path}")
|
|
|
|
return {
|
|
'success': True,
|
|
'temp_path': temp_path,
|
|
'model_hash': actual_hash,
|
|
'original_size': len(decrypted_data)
|
|
}
|
|
|
|
except Exception as e:
|
|
error_msg = str(e)
|
|
if "InvalidToken" in error_msg or "Invalid signature" in error_msg:
|
|
return {'success': False, 'error': '解密密钥错误'}
|
|
return {'success': False, 'error': f'解密失败: {error_msg}'}
|
|
|
|
@staticmethod
|
|
def is_properly_encrypted(model_path):
|
|
"""检查模型是否被正确加密"""
|
|
try:
|
|
with open(model_path, 'rb') as f:
|
|
data = pickle.load(f)
|
|
|
|
# 检查必要的加密字段
|
|
required_fields = ['salt', 'data', 'encrypted', 'model_hash']
|
|
for field in required_fields:
|
|
if field not in data:
|
|
return False
|
|
|
|
return data.get('encrypted', False) is True
|
|
|
|
except:
|
|
return False
|
|
|
|
@staticmethod
|
|
def generate_secure_key():
|
|
"""生成安全的加密密钥"""
|
|
# 生成随机密钥
|
|
key = Fernet.generate_key()
|
|
|
|
# 生成密钥指纹
|
|
key_hash = hashlib.sha256(key).hexdigest()
|
|
|
|
return {
|
|
'key': key.decode('utf-8'),
|
|
'key_hash': key_hash,
|
|
'short_hash': key_hash[:16]
|
|
}
|
|
|
|
|
|
class MandatoryModelManager:
|
|
"""强制加密模型管理器"""
|
|
|
|
def __init__(self, config):
|
|
self.config = config
|
|
self.models_dir = "encrypted_models"
|
|
os.makedirs(self.models_dir, exist_ok=True)
|
|
|
|
# 加载加密器
|
|
self.encryptor = MandatoryModelEncryptor()
|
|
|
|
# 模型缓存
|
|
self.model_cache = {}
|
|
|
|
def load_encrypted_model(self, model_config):
|
|
"""加载加密模型 - 必须提供密钥"""
|
|
try:
|
|
model_path = model_config['path']
|
|
encryption_key = model_config.get('encryption_key')
|
|
|
|
# 必须提供密钥
|
|
if not encryption_key:
|
|
raise ValueError(f"模型 {model_path} 必须提供加密密钥")
|
|
|
|
# 构建本地路径
|
|
local_path = os.path.join(self.models_dir, os.path.basename(model_path))
|
|
|
|
# 检查本地文件是否存在
|
|
if not os.path.exists(local_path):
|
|
# 尝试下载(如果提供下载地址)
|
|
if not self.download_encrypted_model(model_config, local_path):
|
|
raise FileNotFoundError(f"加密模型文件不存在且无法下载: {local_path}")
|
|
|
|
# 验证是否为正确加密的模型
|
|
if not self.encryptor.is_properly_encrypted(local_path):
|
|
raise ValueError(f"模型文件未正确加密: {local_path}")
|
|
|
|
# 解密模型
|
|
decrypt_result = self.encryptor.decrypt_model(local_path, encryption_key)
|
|
|
|
if not decrypt_result['success']:
|
|
raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")
|
|
|
|
temp_path = decrypt_result['temp_path']
|
|
|
|
# 加载YOLO模型
|
|
from ultralytics import YOLO
|
|
model = YOLO(temp_path)
|
|
|
|
# 应用设备配置
|
|
device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
model = model.to(device)
|
|
|
|
# 应用半精度配置
|
|
if model_config.get('half', False) and 'cuda' in device:
|
|
model = model.half()
|
|
logger.info(f"启用半精度推理: {model_path}")
|
|
|
|
# 清理临时文件
|
|
try:
|
|
os.unlink(temp_path)
|
|
except:
|
|
pass
|
|
|
|
# 记录模型信息
|
|
model_hash = decrypt_result.get('model_hash', 'unknown')[:16]
|
|
logger.info(f"加密模型加载成功: {model_path} -> {device} [哈希: {model_hash}...]")
|
|
|
|
return model
|
|
|
|
except Exception as e:
|
|
logger.error(f"加载加密模型失败 {model_config.get('path')}: {str(e)}")
|
|
logger.error(traceback.format_exc())
|
|
return None
|
|
|
|
def download_encrypted_model(self, model_config, save_path):
|
|
"""下载加密模型文件"""
|
|
try:
|
|
download_url = model_config.get('download_url')
|
|
|
|
if not download_url:
|
|
logger.error(f"加密模型无下载地址: {model_config['path']}")
|
|
return False
|
|
|
|
logger.info(f"下载加密模型: {download_url} -> {save_path}")
|
|
|
|
response = requests.get(download_url, stream=True, timeout=30)
|
|
response.raise_for_status()
|
|
|
|
total_size = int(response.headers.get('content-length', 0))
|
|
downloaded = 0
|
|
|
|
with open(save_path, 'wb') as f:
|
|
for chunk in response.iter_content(chunk_size=8192):
|
|
if chunk:
|
|
downloaded += len(chunk)
|
|
f.write(chunk)
|
|
|
|
if total_size > 0:
|
|
progress = (downloaded * 100) // total_size
|
|
if progress % 25 == 0:
|
|
logger.info(f"下载进度: {progress}%")
|
|
|
|
logger.info(f"加密模型下载完成: {save_path} ({downloaded} 字节)")
|
|
|
|
# 验证下载的文件是否正确加密
|
|
if not self.encryptor.is_properly_encrypted(save_path):
|
|
logger.error(f"下载的文件不是正确加密的模型: {save_path}")
|
|
os.remove(save_path)
|
|
return False
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"下载加密模型失败: {str(e)}")
|
|
return False
|
|
|
|
def encrypt_existing_model(self, model_path, output_path, password):
|
|
"""加密现有模型文件"""
|
|
return self.encryptor.encrypt_model(model_path, output_path, password)
|
|
|
|
def verify_model_key(self, model_path, encryption_key):
|
|
"""验证模型密钥是否正确"""
|
|
try:
|
|
if not os.path.exists(model_path):
|
|
return {'valid': False, 'error': '模型文件不存在'}
|
|
|
|
if not self.encryptor.is_properly_encrypted(model_path):
|
|
return {'valid': False, 'error': '模型文件未正确加密'}
|
|
|
|
# 尝试解密(不保存文件)
|
|
result = self.encryptor.decrypt_model(model_path, encryption_key)
|
|
|
|
if result['success']:
|
|
# 清理临时文件
|
|
if 'temp_path' in result and os.path.exists(result['temp_path']):
|
|
try:
|
|
os.unlink(result['temp_path'])
|
|
except:
|
|
pass
|
|
|
|
return {
|
|
'valid': True,
|
|
'model_hash': result.get('model_hash', '')[:16],
|
|
'original_size': result.get('original_size', 0)
|
|
}
|
|
else:
|
|
return {'valid': False, 'error': result.get('error', '解密失败')}
|
|
|
|
except Exception as e:
|
|
return {'valid': False, 'error': str(e)} |