Yolov/mandatory_model_crypto.py

324 lines
11 KiB
Python

# mandatory_model_crypto.py
import os
import tempfile
import hashlib
import pickle
import traceback
from pathlib import Path
import requests
import torch
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
from log import logger
class MandatoryModelEncryptor:
"""强制模型加密器 - 所有模型必须加密"""
@staticmethod
def encrypt_model(model_path, output_path, password, require_encryption=True):
"""加密模型文件 - 强制模式"""
try:
# 读取模型文件
with open(model_path, 'rb') as f:
model_data = f.read()
# 计算模型哈希
model_hash = hashlib.sha256(model_data).hexdigest()
# 生成加密密钥
salt = os.urandom(16)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
fernet = Fernet(key)
# 加密数据
encrypted_data = fernet.encrypt(model_data)
# 保存加密数据
encrypted_payload = {
'salt': salt,
'data': encrypted_data,
'model_hash': model_hash,
'original_size': len(model_data),
'encrypted': True,
'version': '1.0'
}
with open(output_path, 'wb') as f:
pickle.dump(encrypted_payload, f)
logger.info(f"模型强制加密成功: {model_path} -> {output_path}")
logger.info(f"模型哈希: {model_hash[:16]}...")
# 返回密钥信息(用于验证)
return {
'success': True,
'model_hash': model_hash,
'key_hash': hashlib.sha256(key).hexdigest()[:16],
'output_path': output_path
}
except Exception as e:
logger.error(f"模型强制加密失败: {str(e)}")
return {'success': False, 'error': str(e)}
@staticmethod
def decrypt_model(encrypted_path, password, verify_key=True):
"""解密模型文件 - 带密钥验证"""
try:
if not os.path.exists(encrypted_path):
return {'success': False, 'error': '加密模型文件不存在'}
# 读取加密文件
with open(encrypted_path, 'rb') as f:
encrypted_payload = pickle.load(f)
# 验证加密格式
if not encrypted_payload.get('encrypted', False):
return {'success': False, 'error': '模型未加密'}
salt = encrypted_payload['salt']
encrypted_data = encrypted_payload['data']
expected_hash = encrypted_payload.get('model_hash', '')
# 生成解密密钥
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
# 验证密钥(可选)
if verify_key:
key_hash = hashlib.sha256(key).hexdigest()[:16]
logger.debug(f"解密密钥哈希: {key_hash}")
fernet = Fernet(key)
# 解密数据
decrypted_data = fernet.decrypt(encrypted_data)
# 验证模型哈希
actual_hash = hashlib.sha256(decrypted_data).hexdigest()
if expected_hash and actual_hash != expected_hash:
return {
'success': False,
'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
}
# 保存到临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
tmp.write(decrypted_data)
temp_path = tmp.name
logger.info(f"模型解密验证成功: {encrypted_path}")
return {
'success': True,
'temp_path': temp_path,
'model_hash': actual_hash,
'original_size': len(decrypted_data)
}
except Exception as e:
error_msg = str(e)
if "InvalidToken" in error_msg or "Invalid signature" in error_msg:
return {'success': False, 'error': '解密密钥错误'}
return {'success': False, 'error': f'解密失败: {error_msg}'}
@staticmethod
def is_properly_encrypted(model_path):
"""检查模型是否被正确加密"""
try:
with open(model_path, 'rb') as f:
data = pickle.load(f)
# 检查必要的加密字段
required_fields = ['salt', 'data', 'encrypted', 'model_hash']
for field in required_fields:
if field not in data:
return False
return data.get('encrypted', False) is True
except:
return False
@staticmethod
def generate_secure_key():
"""生成安全的加密密钥"""
# 生成随机密钥
key = Fernet.generate_key()
# 生成密钥指纹
key_hash = hashlib.sha256(key).hexdigest()
return {
'key': key.decode('utf-8'),
'key_hash': key_hash,
'short_hash': key_hash[:16]
}
class MandatoryModelManager:
"""强制加密模型管理器"""
def __init__(self, config):
self.config = config
self.models_dir = "encrypted_models"
os.makedirs(self.models_dir, exist_ok=True)
# 加载加密器
self.encryptor = MandatoryModelEncryptor()
# 模型缓存
self.model_cache = {}
def load_encrypted_model(self, model_config):
"""加载加密模型 - 必须提供密钥"""
try:
model_path = model_config['path']
encryption_key = model_config.get('encryption_key')
# 必须提供密钥
if not encryption_key:
raise ValueError(f"模型 {model_path} 必须提供加密密钥")
# 构建本地路径
local_path = os.path.join(self.models_dir, os.path.basename(model_path))
# 检查本地文件是否存在
if not os.path.exists(local_path):
# 尝试下载(如果提供下载地址)
if not self.download_encrypted_model(model_config, local_path):
raise FileNotFoundError(f"加密模型文件不存在且无法下载: {local_path}")
# 验证是否为正确加密的模型
if not self.encryptor.is_properly_encrypted(local_path):
raise ValueError(f"模型文件未正确加密: {local_path}")
# 解密模型
decrypt_result = self.encryptor.decrypt_model(local_path, encryption_key)
if not decrypt_result['success']:
raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")
temp_path = decrypt_result['temp_path']
# 加载YOLO模型
from ultralytics import YOLO
model = YOLO(temp_path)
# 应用设备配置
device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# 应用半精度配置
if model_config.get('half', False) and 'cuda' in device:
model = model.half()
logger.info(f"启用半精度推理: {model_path}")
# 清理临时文件
try:
os.unlink(temp_path)
except:
pass
# 记录模型信息
model_hash = decrypt_result.get('model_hash', 'unknown')[:16]
logger.info(f"加密模型加载成功: {model_path} -> {device} [哈希: {model_hash}...]")
return model
except Exception as e:
logger.error(f"加载加密模型失败 {model_config.get('path')}: {str(e)}")
logger.error(traceback.format_exc())
return None
def download_encrypted_model(self, model_config, save_path):
"""下载加密模型文件"""
try:
download_url = model_config.get('download_url')
if not download_url:
logger.error(f"加密模型无下载地址: {model_config['path']}")
return False
logger.info(f"下载加密模型: {download_url} -> {save_path}")
response = requests.get(download_url, stream=True, timeout=30)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
downloaded = 0
with open(save_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
downloaded += len(chunk)
f.write(chunk)
if total_size > 0:
progress = (downloaded * 100) // total_size
if progress % 25 == 0:
logger.info(f"下载进度: {progress}%")
logger.info(f"加密模型下载完成: {save_path} ({downloaded} 字节)")
# 验证下载的文件是否正确加密
if not self.encryptor.is_properly_encrypted(save_path):
logger.error(f"下载的文件不是正确加密的模型: {save_path}")
os.remove(save_path)
return False
return True
except Exception as e:
logger.error(f"下载加密模型失败: {str(e)}")
return False
def encrypt_existing_model(self, model_path, output_path, password):
"""加密现有模型文件"""
return self.encryptor.encrypt_model(model_path, output_path, password)
def verify_model_key(self, model_path, encryption_key):
"""验证模型密钥是否正确"""
try:
if not os.path.exists(model_path):
return {'valid': False, 'error': '模型文件不存在'}
if not self.encryptor.is_properly_encrypted(model_path):
return {'valid': False, 'error': '模型文件未正确加密'}
# 尝试解密(不保存文件)
result = self.encryptor.decrypt_model(model_path, encryption_key)
if result['success']:
# 清理临时文件
if 'temp_path' in result and os.path.exists(result['temp_path']):
try:
os.unlink(result['temp_path'])
except:
pass
return {
'valid': True,
'model_hash': result.get('model_hash', '')[:16],
'original_size': result.get('original_size', 0)
}
else:
return {'valid': False, 'error': result.get('error', '解密失败')}
except Exception as e:
return {'valid': False, 'error': str(e)}