# mandatory_model_crypto.py - validation and decryption of mandatory-encrypted model files
|
||
import os
|
||
import tempfile
|
||
import hashlib
|
||
import pickle
|
||
import traceback
|
||
import time
|
||
import secrets
|
||
import string
|
||
from pathlib import Path
|
||
import base64
|
||
|
||
import requests
|
||
import torch
|
||
from cryptography.fernet import Fernet
|
||
from cryptography.hazmat.primitives import hashes
|
||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||
|
||
from log import logger
|
||
|
||
|
||
class MandatoryModelValidator:
    """Verify and decrypt mandatory-encrypted model files.

    This class only validates and decrypts; it never creates encrypted files.
    An encrypted model file is a pickled ``dict`` that carries at least the
    keys ``salt``, ``data``, ``encrypted``, ``model_hash`` and ``version``;
    ``original_size`` is optional.
    """

    @staticmethod
    def verify_model_encryption(encrypted_path):
        """Check that *encrypted_path* contains a well-formed encrypted payload.

        Args:
            encrypted_path: Path of the pickled payload file.

        Returns:
            ``{'valid': True, 'encrypted_data': payload}`` when the payload has
            every required field, ``encrypted is True`` and a supported
            version; otherwise ``{'valid': False, 'error': message}``.
            Ordinary exceptions are reported through ``error``.
        """
        try:
            if not os.path.exists(encrypted_path):
                return {'valid': False, 'error': '加密模型文件不存在'}

            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only use this on model files from a trusted source.
            with open(encrypted_path, 'rb') as f:
                data = pickle.load(f)

            # Reject non-dict payloads explicitly instead of relying on a
            # later AttributeError/TypeError to surface the problem.
            if not isinstance(data, dict):
                return {'valid': False, 'error': '验证失败: 加密数据格式无效'}

            # Every one of these fields must be present.
            required_fields = ['salt', 'data', 'encrypted', 'model_hash', 'version']
            for field in required_fields:
                if field not in data:
                    return {'valid': False, 'error': f'缺少加密字段: {field}'}

            if data.get('encrypted', False) is not True:
                return {'valid': False, 'error': '模型未加密'}

            # Only payload versions 1.0 and 2.0 are accepted.
            if data.get('version') not in ['1.0', '2.0']:
                return {'valid': False, 'error': f'不支持的加密版本: {data.get("version")}'}

            return {'valid': True, 'encrypted_data': data}

        except Exception as e:
            return {'valid': False, 'error': f'验证失败: {str(e)}'}

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Return True when the file looks like an encrypted payload.

        Simplified check: the pickled object must be a ``dict`` containing
        ``salt``, ``data``, ``encrypted`` and ``model_hash`` with
        ``encrypted is True``. The ``version`` field is not inspected.
        Any exception is logged as a warning and reported as ``False``.
        """
        try:
            # SECURITY: see verify_model_encryption - pickle.load is unsafe on
            # untrusted files.
            with open(encrypted_path, 'rb') as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                return False

            required_fields = ['salt', 'data', 'encrypted', 'model_hash']
            if any(field not in data for field in required_fields):
                return False

            return data.get('encrypted', False) is True

        except Exception as e:
            logger.warning(f"检查加密格式失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_and_verify(encrypted_path, password, verify_key=True):
        """Decrypt a model file and verify its integrity.

        The Fernet key is derived from *password* with PBKDF2-HMAC-SHA256
        (100000 iterations, 32 bytes) using the salt stored in the payload.
        After decryption the SHA-256 of the plaintext is compared with the
        stored ``model_hash`` and, when ``original_size`` is positive, the
        plaintext length is compared with it.

        Args:
            encrypted_path: Path of the encrypted payload file.
            password: Password string used to derive the key.
            verify_key: Not used by this implementation; kept so existing
                callers that pass it keep working.

        Returns:
            On success ``{'success': True, 'model_hash', 'original_size',
            'decrypted_data', 'version'}``; on failure
            ``{'success': False, 'error': message}``.
        """
        try:
            # Validate the payload format first.
            verify_result = MandatoryModelValidator.verify_model_encryption(encrypted_path)
            if not verify_result['valid']:
                return {'success': False, 'error': verify_result['error']}

            encrypted_payload = verify_result['encrypted_data']
            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']
            expected_hash = encrypted_payload['model_hash']

            # Derive the decryption key from the password and stored salt.
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=salt,
                iterations=100000,
            )
            key = base64.urlsafe_b64encode(kdf.derive(password.encode()))

            fernet = Fernet(key)

            # Fernet raises InvalidToken without a message, so str(e) is empty
            # and a substring test on the text can never detect a wrong key.
            # Match on the exception type instead.
            from cryptography.fernet import InvalidToken

            try:
                decrypted_data = fernet.decrypt(encrypted_data)
            except InvalidToken:
                # Raised for a wrong key as well as for a tampered token.
                return {'success': False, 'error': '解密密钥错误'}
            except Exception as e:
                return {'success': False, 'error': f'解密失败: {str(e)}'}

            # Verify the plaintext hash against the stored one.
            actual_hash = hashlib.sha256(decrypted_data).hexdigest()
            if actual_hash != expected_hash:
                return {
                    'success': False,
                    'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
                }

            # Verify the plaintext size when the payload recorded one.
            original_size = encrypted_payload.get('original_size', 0)
            if original_size > 0 and len(decrypted_data) != original_size:
                return {
                    'success': False,
                    'error': f'文件大小不匹配: 期望{original_size}, 实际{len(decrypted_data)}'
                }

            logger.info(f"模型解密验证成功: {encrypted_path}")
            logger.debug(f"模型哈希: {actual_hash[:16]}...")

            return {
                'success': True,
                'model_hash': actual_hash,
                'original_size': len(decrypted_data),
                'decrypted_data': decrypted_data,  # plaintext bytes for the caller
                'version': encrypted_payload.get('version', '1.0')
            }

        except Exception as e:
            return {'success': False, 'error': f'解密失败: {str(e)}'}

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Legacy-compatible alias that forwards to decrypt_and_verify."""
        return MandatoryModelValidator.decrypt_and_verify(
            encrypted_path, password, verify_key=verify_key
        )
|
||
|
||
|
||
class ModelEncryptionService:
    """Key generation and key-strength validation for model encryption."""

    # Special characters used by both key generation and strength validation.
    _SPECIAL_CHARS = "!@#$%^&*"

    @staticmethod
    def _has_all_char_classes(key):
        """Return True if *key* has an upper, a lower, a digit and a special char."""
        specials = ModelEncryptionService._SPECIAL_CHARS
        return (
            any(c.isupper() for c in key)
            and any(c.islower() for c in key)
            and any(c.isdigit() for c in key)
            and any(c in specials for c in key)
        )

    @staticmethod
    def generate_secure_key(length=32):
        """Generate a random key from letters, digits and ``!@#$%^&*``.

        Characters are drawn with ``secrets.choice``. For ``length >= 4`` the
        key is resampled until it contains every character class that
        ``validate_key_strength`` requires; a uniformly sampled key can
        otherwise lack one (for example a special character) and would then be
        rejected by this module's own validation. Shorter lengths cannot hold
        all four classes and are returned as sampled.

        Args:
            length: Number of characters in the key.

        Returns:
            ``{'key': str, 'key_hash': sha256 hex digest of the key,
            'short_hash': first 16 hex characters of key_hash}``.
        """
        alphabet = string.ascii_letters + string.digits + ModelEncryptionService._SPECIAL_CHARS
        while True:
            key = ''.join(secrets.choice(alphabet) for _ in range(length))
            if length < 4 or ModelEncryptionService._has_all_char_classes(key):
                break

        key_hash = hashlib.sha256(key.encode()).hexdigest()

        return {
            'key': key,
            'key_hash': key_hash,
            'short_hash': key_hash[:16]
        }

    @staticmethod
    def validate_key_strength(key):
        """Check key length and character variety.

        Returns:
            ``(ok, message)``: ``ok`` is False when the key is shorter than 16
            characters or lacks an uppercase letter, a lowercase letter, a
            digit or one of ``!@#$%^&*``.
        """
        if len(key) < 16:
            return False, "密钥长度至少16位"

        if not ModelEncryptionService._has_all_char_classes(key):
            return False, "密钥应包含大小写字母、数字和特殊字符"

        return True, "密钥强度足够"

    @staticmethod
    def create_key_pair():
        """Create a master key with a verification token.

        Returns:
            ``{'master_key', 'verification_token', 'created_at'}`` where the
            token is the first 20 hex characters of
            ``sha256(master_key + str(created_at))``.
        """
        master_key = secrets.token_urlsafe(32)

        # Use one timestamp for both the token and created_at so the token
        # can be recomputed from the returned values.
        created_at = time.time()
        verification_token = hashlib.sha256(
            (master_key + str(created_at)).encode()
        ).hexdigest()[:20]

        return {
            'master_key': master_key,
            'verification_token': verification_token,
            'created_at': created_at
        }
|
||
|
||
|
||
class PreTaskModelValidator:
    """Validate all models of a task before the task is created."""

    def __init__(self, config=None):
        """Store the config and make sure the encrypted-models directory exists."""
        settings = config or {}
        self.config = settings
        self.encrypted_models_dir = settings.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # Validation-state cache and its lock; expiry is 300 seconds.
        self.validation_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 300

    def validate_models_before_task(self, task_config):
        """Validate every model listed under ``task_config['models']``.

        A missing required field, a missing file or a weak key aborts
        immediately with ``{'success': False, 'error': ...}``. Decryption
        failures are collected per model instead, and the summary is returned
        with ``success``/``valid`` set to False.
        """
        try:
            entries = task_config.get('models', [])
            if not entries:
                return {
                    'success': False,
                    'error': '任务配置中未找到模型列表'
                }

            reports = []

            for index, entry in enumerate(entries):
                # Both 'path' and 'encryption_key' are mandatory.
                absent = next(
                    (name for name in ('path', 'encryption_key') if name not in entry),
                    None,
                )
                if absent is not None:
                    return {
                        'success': False,
                        'error': f'模型 {index} 缺少必要字段: {absent}'
                    }

                relative_path = entry['path']
                secret = entry['encryption_key']
                resolved_path = self._get_full_model_path(relative_path)

                if not os.path.exists(resolved_path):
                    return {
                        'success': False,
                        'error': f'模型文件不存在: {resolved_path}'
                    }

                strong_enough, reason = ModelEncryptionService.validate_key_strength(secret)
                if not strong_enough:
                    return {
                        'success': False,
                        'error': f'模型 {index} 密钥强度不足: {reason}'
                    }

                # The key is proven correct by actually decrypting the file.
                outcome = MandatoryModelValidator().decrypt_and_verify(resolved_path, secret)
                decrypted_ok = outcome['success']

                report = {
                    'model_index': index,
                    'model_path': relative_path,
                    'full_path': resolved_path,
                    'file_exists': True,
                    'key_valid': decrypted_ok,
                    'model_hash': outcome.get('model_hash', '')[:16] if decrypted_ok else '',
                    'model_size': outcome.get('original_size', 0) if decrypted_ok else 0,
                    'validation_time': time.time()
                }
                if not decrypted_ok:
                    report['error'] = outcome.get('error', '验证失败')

                reports.append(report)

            everything_ok = all(item['key_valid'] for item in reports)
            passed = sum(1 for item in reports if item['key_valid'])

            return {
                'success': everything_ok,
                'valid': everything_ok,
                'total_models': len(entries),
                'valid_models': passed,
                'validation_results': reports,
                'timestamp': time.time()
            }

        except Exception as e:
            logger.error(f"预验证模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'预验证失败: {str(e)}'
            }

    def verify_single_model(self, model_path, encryption_key):
        """Validate one model: file existence, key strength, then decryption."""
        try:
            resolved_path = self._get_full_model_path(model_path)

            if not os.path.exists(resolved_path):
                return {
                    'success': False,
                    'error': f'模型文件不存在: {resolved_path}'
                }

            strong_enough, reason = ModelEncryptionService.validate_key_strength(encryption_key)
            if not strong_enough:
                return {
                    'success': False,
                    'error': f'密钥强度不足: {reason}'
                }

            outcome = MandatoryModelValidator().decrypt_and_verify(resolved_path, encryption_key)

            if not outcome['success']:
                return {
                    'success': False,
                    'valid': False,
                    'error': outcome.get('error', '验证失败'),
                    'timestamp': time.time()
                }

            return {
                'success': True,
                'valid': True,
                'model_hash': outcome.get('model_hash', '')[:16],
                'model_size': outcome.get('original_size', 0),
                'file_path': resolved_path,
                'timestamp': time.time()
            }

        except Exception as e:
            logger.error(f"验证单个模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'验证失败: {str(e)}'
            }

    def _get_full_model_path(self, model_path):
        """Resolve *model_path* to a file path.

        Absolute paths are returned unchanged. Any relative path, whether or
        not it starts with ``encrypted_models/``, is reduced to its file name
        and joined onto ``self.encrypted_models_dir``.
        """
        if os.path.isabs(model_path):
            return model_path

        return os.path.join(self.encrypted_models_dir, os.path.basename(model_path))

    def get_model_info(self, model_path):
        """Return payload metadata without checking any key, or None if unavailable."""
        try:
            resolved_path = self._get_full_model_path(model_path)

            if not os.path.exists(resolved_path):
                return None

            check = MandatoryModelValidator().verify_model_encryption(resolved_path)
            if not check['valid']:
                return None

            payload = check['encrypted_data']
            return {
                'encrypted': True,
                'model_hash': payload.get('model_hash', '')[:16],
                'original_size': payload.get('original_size', 0),
                'version': payload.get('version', 'unknown'),
                'file_path': resolved_path,
                'file_size': os.path.getsize(resolved_path)
            }

        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None
|
||
|
||
|
||
class SecureModelManager:
    """Locate, verify, decrypt, load and cache encrypted models."""

    def __init__(self, config):
        """Read directories from *config* and create them if missing.

        Args:
            config: Mapping; ``models_dir`` (default ``'models'``) and
                ``encrypted_models_dir`` (default ``'encrypted_models'``)
                are read from it.
        """
        self.config = config
        self.models_dir = config.get('models_dir', 'models')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # Validator instances.
        self.validator = MandatoryModelValidator()
        self.pre_validator = PreTaskModelValidator(config)

        # Loaded-model cache, guarded by cache_lock; entries expire after 600 s.
        self.model_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 600

        # Key-verification results; entries expire after 300 s.
        self.verification_status = {}
        self.verification_expiry = 300

    @staticmethod
    def _cache_key(local_path, encryption_key):
        """Build the cache key shared by the verification and model caches.

        The MD5 digest only distinguishes cache entries per key; it is not
        used as a security measure.
        """
        return f"{local_path}_{hashlib.md5(encryption_key.encode()).hexdigest()[:8]}"

    def ensure_model_available(self, model_config):
        """Check that the configured model file exists and is a valid payload.

        Returns:
            ``{'available': True, 'local_path', 'model_hash'}`` (hash cut to
            16 characters) or ``{'available': False, 'error': message}``.
        """
        try:
            model_path = model_config.get('path')
            if not model_path:
                return {'available': False, 'error': '模型路径未配置'}

            local_path = self._get_local_model_path(model_path, model_config)

            if not os.path.exists(local_path):
                return {'available': False, 'error': f'模型文件不存在: {local_path}'}

            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return {'available': False, 'error': f'模型加密格式无效: {verify_result.get("error", "未知错误")}'}

            return {
                'available': True,
                'local_path': local_path,
                'model_hash': verify_result['encrypted_data'].get('model_hash', '')[:16]
            }

        except Exception as e:
            return {'available': False, 'error': f'确保模型可用失败: {str(e)}'}

    def verify_model_key(self, model_config, encryption_key):
        """Verify that *encryption_key* decrypts the configured model.

        Successful results are cached for ``verification_expiry`` seconds;
        failures are not cached.

        Returns:
            ``{'valid': True, 'model_hash', 'original_size', 'local_path',
            'verified_at'}`` or ``{'valid': False, 'error': message}``.
        """
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return {'valid': False, 'error': availability['error']}

            local_path = availability['local_path']
            model_hash_short = availability.get('model_hash', 'unknown')

            # Reuse a recent verification result to avoid decrypting again.
            cache_key = self._cache_key(local_path, encryption_key)
            with self.cache_lock:
                if cache_key in self.verification_status:
                    cached_result = self.verification_status[cache_key]
                    if time.time() - cached_result.get('verified_at', 0) < self.verification_expiry:
                        if cached_result['valid']:
                            logger.info(f"使用缓存的验证结果: {model_hash_short}")
                            return cached_result
                    else:
                        # Entry expired: drop it.
                        del self.verification_status[cache_key]

            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)

            if decrypt_result['success']:
                result = {
                    'valid': True,
                    'model_hash': decrypt_result.get('model_hash', '')[:16],
                    'original_size': decrypt_result.get('original_size', 0),
                    'local_path': local_path,
                    'verified_at': time.time()
                }

                with self.cache_lock:
                    self.verification_status[cache_key] = result

                logger.info(f"模型密钥验证成功: {model_hash_short}")
                return result

            result = {'valid': False, 'error': decrypt_result.get('error', '验证失败')}
            logger.warning(f"模型密钥验证失败: {model_hash_short} - {result['error']}")
            return result

        except Exception as e:
            error_msg = f"验证模型密钥失败: {str(e)}"
            logger.error(error_msg)
            return {'valid': False, 'error': error_msg}

    def load_verified_model(self, model_config, encryption_key):
        """Verify the key, then decrypt and load the model as a YOLO instance.

        The decrypted bytes are written to a temporary ``.pt`` file for the
        YOLO loader. That file is removed in a ``finally`` block, so the
        plaintext model is not left on disk when loading fails.

        Args:
            model_config: Mapping; ``path``, ``device`` and ``half`` are read.
                ``device`` defaults to ``'cuda:0'`` when
                ``torch.cuda.is_available()`` is true, else ``'cpu'``.
            encryption_key: Password for the encrypted model.

        Returns:
            The loaded (possibly cached) model, or ``None`` on any failure;
            failures are logged with a traceback.
        """
        try:
            verify_result = self.verify_model_key(model_config, encryption_key)
            if not verify_result['valid']:
                raise ValueError(f"模型验证未通过: {verify_result.get('error', '未知错误')}")

            local_path = verify_result['local_path']

            # Serve from the model cache while the entry is fresh.
            cache_key = self._cache_key(local_path, encryption_key)
            with self.cache_lock:
                if cache_key in self.model_cache:
                    cached_info = self.model_cache[cache_key]
                    if time.time() - cached_info.get('loaded_at', 0) < self.cache_expiry:
                        logger.info(f"使用缓存的模型: {local_path}")
                        return cached_info['model']

            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if not decrypt_result['success']:
                raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")

            temp_path = None
            try:
                # The YOLO loader is given a file path, so the plaintext is
                # written to a temporary file first.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                    temp_path = tmp.name
                    tmp.write(decrypt_result['decrypted_data'])

                from ultralytics import YOLO
                model = YOLO(temp_path)

                device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
                model = model.to(device)

                # str() keeps the check working if device is not a plain string.
                if model_config.get('half', False) and 'cuda' in str(device):
                    model = model.half()
                    logger.info(f"启用半精度推理: {model_config.get('path', '未知模型')}")
            finally:
                # Always remove the plaintext copy, including on load errors.
                if temp_path is not None:
                    try:
                        os.unlink(temp_path)
                    except OSError as cleanup_error:
                        logger.warning(f"清理临时模型文件失败: {temp_path} - {cleanup_error}")

            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'device': device,
                    'loaded_at': time.time(),
                    'original_size': decrypt_result.get('original_size', 0)
                }

            logger.info(f"加密模型加载成功: {model_config.get('path', '未知模型')} -> {device}")

            return model

        except Exception as e:
            logger.error(f"加载已验证模型失败: {str(e)}")
            logger.error(traceback.format_exc())
            return None

    def _get_local_model_path(self, model_path, model_config):
        """Map a configured model path to a local file path.

        An existing path is returned as-is. Otherwise the file name is looked
        up in ``encrypted_models_dir`` (with ``.enc`` appended when missing)
        if ``model_config['encrypted']`` is truthy, else in ``models_dir``.
        """
        if os.path.exists(model_path):
            return model_path

        model_filename = os.path.basename(model_path)

        if model_config.get('encrypted', False):
            if not model_filename.endswith('.enc'):
                model_filename += '.enc'
            return os.path.join(self.encrypted_models_dir, model_filename)

        return os.path.join(self.models_dir, model_filename)

    def get_model_info(self, model_config):
        """Return payload metadata without checking any key, or None if unavailable."""
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return None

            local_path = availability['local_path']
            verify_result = self.validator.verify_model_encryption(local_path)

            if verify_result['valid']:
                data = verify_result['encrypted_data']
                return {
                    'encrypted': True,
                    'model_hash': data.get('model_hash', '')[:16],
                    'original_size': data.get('original_size', 0),
                    'version': data.get('version', 'unknown'),
                    'local_path': local_path,
                    'file_size': os.path.getsize(local_path)
                }
            return None

        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None

    def clear_cache(self):
        """Empty both the model cache and the verification cache."""
        with self.cache_lock:
            self.model_cache.clear()
            self.verification_status.clear()
        logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return the sizes and keys of both caches."""
        with self.cache_lock:
            return {
                'model_cache_size': len(self.model_cache),
                'verification_cache_size': len(self.verification_status),
                'model_cache_keys': list(self.model_cache.keys()),
                'verification_cache_keys': list(self.verification_status.keys())
            }
|
||
|
||
|
||
class SimpleModelEncryptor:
    """Minimal model encryptor, intended for tests or local conversion only."""

    @staticmethod
    def create_encrypted_payload(model_data, password):
        """Build the encrypted payload dict for *model_data*.

        Per the original note, this mimics the encryption done by the .NET
        server. A random 16-byte salt and PBKDF2-HMAC-SHA256 (100000
        iterations) derive the Fernet key from *password*.

        Returns:
            ``(payload, key)``: the payload dict and the derived Fernet key.
        """
        salt = os.urandom(16)
        derived = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        ).derive(password.encode())
        key = base64.urlsafe_b64encode(derived)

        ciphertext = Fernet(key).encrypt(model_data)
        digest = hashlib.sha256(model_data).hexdigest()

        payload = {
            'salt': salt,
            'data': ciphertext,
            'model_hash': digest,
            'original_size': len(model_data),
            'encrypted': True,
            'version': '2.0',
            'created_at': time.time()
        }
        return payload, key

    @staticmethod
    def encrypt_model_file(input_path, output_path, password):
        """Encrypt the file at *input_path* and pickle the payload to *output_path*.

        Returns:
            ``{'success': True, 'output_path', 'model_hash', 'original_size',
            'encrypted_size'}`` or ``{'success': False, 'error': message}``.
        """
        try:
            with open(input_path, 'rb') as source:
                plaintext = source.read()

            # The derived key is not needed here; only the payload is stored.
            payload, _ = SimpleModelEncryptor.create_encrypted_payload(plaintext, password)

            with open(output_path, 'wb') as target:
                pickle.dump(payload, target)

            logger.info(f"模型加密成功: {input_path} -> {output_path}")

            summary = {
                'success': True,
                'output_path': output_path,
                'model_hash': payload['model_hash'][:16],
                'original_size': len(plaintext),
                'encrypted_size': os.path.getsize(output_path)
            }
            return summary

        except Exception as e:
            logger.error(f"加密模型文件失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    @staticmethod
    def verify_encryption(encrypted_path, password):
        """Decrypt and verify an encrypted file with MandatoryModelValidator."""
        checker = MandatoryModelValidator()
        return checker.decrypt_and_verify(encrypted_path, password)
|
||
|
||
|
||
# ==================== 全局实例和接口函数 ====================
|
||
|
||
# 导入线程模块
|
||
import threading
|
||
|
||
# Module-level singleton slots. get_model_manager() and
# get_pre_task_validator() fill the first two on first use.
_model_manager = None
_pre_task_validator = None
# NOTE(review): nothing in the visible source reads or assigns _upload_manager
# after this point - confirm whether it is still needed.
_upload_manager = None
|
||
|
||
|
||
def get_model_manager(config=None):
    """Return the shared SecureModelManager, creating it on the first call.

    *config* is only used when the instance does not exist yet; when it is
    None the default configuration from ``config.get_default_config`` is used.
    """
    global _model_manager
    if _model_manager is not None:
        return _model_manager

    if config is None:
        from config import get_default_config
        config = get_default_config()

    _model_manager = SecureModelManager(config)
    return _model_manager
|
||
|
||
|
||
def get_pre_task_validator(config=None):
    """Return the shared PreTaskModelValidator, creating it on the first call.

    The validator is built from the ``'upload'`` section of *config* (an empty
    dict when that section is absent). When *config* is None the default
    configuration from ``config.get_default_config`` is used.
    """
    global _pre_task_validator
    if _pre_task_validator is not None:
        return _pre_task_validator

    if config is None:
        from config import get_default_config
        config = get_default_config()

    upload_settings = config.get('upload', {})
    _pre_task_validator = PreTaskModelValidator(upload_settings)
    return _pre_task_validator
|
||
|
||
|
||
def validate_models_before_task(task_config):
    """Validate all models of a task via the shared pre-task validator."""
    return get_pre_task_validator().validate_models_before_task(task_config)
|
||
|
||
|
||
def verify_single_model_api(model_path, encryption_key):
    """Validate one model via the shared pre-task validator."""
    return get_pre_task_validator().verify_single_model(model_path, encryption_key)
|
||
|
||
|
||
def verify_model_key_api(model_config, encryption_key):
    """Verify a model key via the shared model manager."""
    return get_model_manager().verify_model_key(model_config, encryption_key)
|
||
|
||
|
||
def load_verified_model_api(model_config, encryption_key):
    """Load a verified model via the shared model manager."""
    return get_model_manager().load_verified_model(model_config, encryption_key)
|
||
|
||
|
||
def generate_secure_key_api(length=32):
    """Generate a secure key of *length* characters (API entry point)."""
    key_info = ModelEncryptionService.generate_secure_key(length)
    return key_info
|
||
|
||
|
||
def validate_key_strength_api(key):
    """Check the strength of *key* (API entry point)."""
    verdict = ModelEncryptionService.validate_key_strength(key)
    return verdict
|
||
|
||
|
||
def test_decryption_api(encrypted_path, password):
    """Try to decrypt and verify *encrypted_path* with *password* (API entry point)."""
    return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)
|
||
|
||
|
||
# ==================== 向后兼容的接口 ====================
|
||
|
||
class MandatoryModelEncryptor:
    """Backward-compatible facade that forwards to the current classes."""

    @staticmethod
    def encrypt_model(input_path, output_path, password, require_encryption=True):
        """Encrypt a model file; ``require_encryption`` is accepted but not used."""
        outcome = SimpleModelEncryptor.encrypt_model_file(input_path, output_path, password)
        return outcome

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt a model file through MandatoryModelValidator.decrypt_model."""
        return MandatoryModelValidator().decrypt_model(
            encrypted_path, password, verify_key=verify_key
        )

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Report whether the file passes the simplified encryption check."""
        looks_encrypted = MandatoryModelValidator.is_properly_encrypted(encrypted_path)
        return looks_encrypted

    @staticmethod
    def generate_secure_key():
        """Generate a secure key with the default length."""
        key_info = ModelEncryptionService.generate_secure_key()
        return key_info

    @staticmethod
    def verify_model_key(encrypted_path, password):
        """Verify a key by decrypting and checking the model file."""
        return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)
|
||
|
||
|
||
# ==================== 测试代码 ====================
|
||
|
||
if __name__ == "__main__":
    import time

    print("=== 强制加密模型验证模块测试 ===")

    # 1. Key generation.
    print("\n1. 测试密钥生成:")
    generated = ModelEncryptionService.generate_secure_key()
    print(f" 生成的密钥: {generated['key']}")
    print(f" 密钥哈希: {generated['key_hash']}")
    print(f" 短哈希: {generated['short_hash']}")

    # 2. Key-strength validation, first a weak key and then a strong one.
    print("\n2. 测试密钥强度验证:")
    for candidate in ("WeakKey123", "Strong@Password#2024!Complex"):
        is_strong, reason = ModelEncryptionService.validate_key_strength(candidate)
        print(f" 密钥 '{candidate}': {is_strong} - {reason}")

    # 3. Format validator against a file that does not exist.
    print("\n3. 测试模型验证器:")
    format_checker = MandatoryModelValidator()
    missing_check = format_checker.verify_model_encryption("nonexistent.enc")
    print(f" 不存在的文件验证: {missing_check['valid']} - {missing_check.get('error', '')}")

    # Example configuration shared by the remaining steps.
    demo_config = {
        'models_dir': './models',
        'encrypted_models_dir': './encrypted_models',
        'cache_verification': True
    }

    # 4. Pre-task validator.
    print("\n4. 测试预验证器:")
    pre_task_checker = PreTaskModelValidator(demo_config)
    demo_task = {
        'models': [
            {
                'path': 'test_model.enc',
                'encryption_key': 'Test@Password#2024!Complex'
            }
        ]
    }
    task_outcome = pre_task_checker.validate_models_before_task(demo_task)
    print(f" 任务验证结果: 成功={task_outcome['success']}, 错误={task_outcome.get('error', '无')}")

    # 5. Secure model manager: fetch model info without a key check.
    print("\n5. 测试安全模型管理器:")
    secure_manager = SecureModelManager(demo_config)
    demo_model = {
        'path': 'yolov8n.enc',  # encrypted model file path
        'device': 'cpu',
        'half': False,
        'encrypted': True,
        'encryption_key': 'test-password-123!@#'  # test key
    }
    info = secure_manager.get_model_info(demo_model)
    if not info:
        print(" 模型信息: 不可用")
    else:
        print(f" 模型信息: 哈希={info['model_hash']}, 大小={info['original_size']}")

    print("\n=== 测试完成 ===")