Yolov/mandatory_model_crypto.py

870 lines
31 KiB
Python
Raw Normal View History

2025-12-12 16:04:22 +08:00
# mandatory_model_crypto.py - 强制加密模型验证和解密模块
2025-12-11 13:41:07 +08:00
import os
import tempfile
import hashlib
import pickle
import traceback
2025-12-12 16:04:22 +08:00
import time
import secrets
import string
2025-12-11 13:41:07 +08:00
from pathlib import Path
2025-12-12 16:04:22 +08:00
import base64
2025-12-11 13:41:07 +08:00
import requests
import torch
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
2025-12-12 16:04:22 +08:00
2025-12-11 13:41:07 +08:00
from log import logger
2025-12-12 16:04:22 +08:00
class MandatoryModelValidator:
    """Validate and decrypt mandatory-encrypted model files.

    An encrypted model file is a pickled ``dict`` carrying at least the keys
    ``salt``, ``data``, ``encrypted``, ``model_hash`` and ``version``.  The
    value under ``data`` is a Fernet token whose key is derived from the
    caller's password with PBKDF2-HMAC-SHA256 (100000 iterations).

    All methods are static and report failures through their return value
    instead of raising.
    """

    # Keys every encrypted payload must carry to pass full verification.
    _REQUIRED_FIELDS = ('salt', 'data', 'encrypted', 'model_hash', 'version')
    # Payload format versions this module accepts.
    _SUPPORTED_VERSIONS = ('1.0', '2.0')

    @staticmethod
    def verify_model_encryption(encrypted_path):
        """Check that *encrypted_path* holds a well-formed encrypted payload.

        Args:
            encrypted_path: Path of the encrypted model file.

        Returns:
            ``{'valid': True, 'encrypted_data': <payload dict>}`` when the
            file passes every check, otherwise
            ``{'valid': False, 'error': <message>}``.
        """
        try:
            if not os.path.exists(encrypted_path):
                return {'valid': False, 'error': '加密模型文件不存在'}

            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file.  Only call this on model files from a trusted source.
            with open(encrypted_path, 'rb') as f:
                data = pickle.load(f)

            # A non-dict payload used to surface as an opaque AttributeError
            # message further down; reject it explicitly instead.
            if not isinstance(data, dict):
                return {'valid': False, 'error': '加密文件格式无效: 不是字典结构'}

            # Check the mandatory encryption fields.
            for field in MandatoryModelValidator._REQUIRED_FIELDS:
                if field not in data:
                    return {'valid': False, 'error': f'缺少加密字段: {field}'}

            if data.get('encrypted', False) is not True:
                return {'valid': False, 'error': '模型未加密'}

            # Check version compatibility.
            if data.get('version') not in MandatoryModelValidator._SUPPORTED_VERSIONS:
                return {'valid': False, 'error': f'不支持的加密版本: {data.get("version")}'}

            return {'valid': True, 'encrypted_data': data}
        except Exception as e:
            return {'valid': False, 'error': f'验证失败: {str(e)}'}

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Return ``True`` if the file looks like an encrypted payload.

        This is a lighter check than :meth:`verify_model_encryption`: it does
        not require the ``version`` field and does not validate its value.
        Any error (missing file, unreadable pickle, ...) is logged as a
        warning and yields ``False``.
        """
        try:
            # SECURITY: see verify_model_encryption -- pickle.load must only
            # be used on trusted files.
            with open(encrypted_path, 'rb') as f:
                data = pickle.load(f)

            # Basic structural checks.
            if not isinstance(data, dict):
                return False
            for field in ('salt', 'data', 'encrypted', 'model_hash'):
                if field not in data:
                    return False
            return data.get('encrypted', False) is True
        except Exception as e:
            logger.warning(f"检查加密格式失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_and_verify(encrypted_path, password, verify_key=True):
        """Decrypt a model file and verify the integrity of the plaintext.

        Args:
            encrypted_path: Path of the encrypted model file.
            password: Password (``str``) the Fernet key is derived from.
            verify_key: Accepted for interface compatibility; it currently
                has no effect on the checks performed.

        Returns:
            On success a dict with ``success=True``, ``model_hash`` (SHA-256
            hex digest of the plaintext), ``original_size``,
            ``decrypted_data`` (the plaintext bytes) and ``version``.
            On failure ``{'success': False, 'error': <message>}``; a wrong
            password is reported as ``'解密密钥错误'``.
        """
        try:
            # Validate the container format first.
            verify_result = MandatoryModelValidator.verify_model_encryption(encrypted_path)
            if not verify_result['valid']:
                return {'success': False, 'error': verify_result['error']}

            # Imported locally so the wrong-key case can be caught by type.
            # InvalidToken is raised without a message, so inspecting str(e)
            # (as this method used to) can never recognise it.
            from cryptography.fernet import InvalidToken

            encrypted_payload = verify_result['encrypted_data']
            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']
            expected_hash = encrypted_payload['model_hash']

            # Derive the Fernet key from the password.
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=salt,
                iterations=100000,
            )
            key = base64.urlsafe_b64encode(kdf.derive(password.encode()))

            # Decrypt the payload.
            fernet = Fernet(key)
            try:
                decrypted_data = fernet.decrypt(encrypted_data)
            except InvalidToken:
                # Wrong password, or the token was tampered with / corrupted.
                return {'success': False, 'error': '解密密钥错误'}
            except Exception as e:
                return {'success': False, 'error': f'解密失败: {str(e)}'}

            # Verify the plaintext hash.
            actual_hash = hashlib.sha256(decrypted_data).hexdigest()
            if actual_hash != expected_hash:
                return {
                    'success': False,
                    'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
                }

            # Verify the plaintext size when the payload records one.
            original_size = encrypted_payload.get('original_size', 0)
            if original_size > 0 and len(decrypted_data) != original_size:
                return {
                    'success': False,
                    'error': f'文件大小不匹配: 期望{original_size}, 实际{len(decrypted_data)}'
                }

            logger.info(f"模型解密验证成功: {encrypted_path}")
            logger.debug(f"模型哈希: {actual_hash[:16]}...")
            return {
                'success': True,
                'model_hash': actual_hash,
                'original_size': len(decrypted_data),
                'decrypted_data': decrypted_data,  # returned for further processing
                'version': encrypted_payload.get('version', '1.0')
            }
        except Exception as e:
            error_msg = str(e)
            if "InvalidToken" in error_msg or "Invalid signature" in error_msg:
                return {'success': False, 'error': '解密密钥错误'}
            return {'success': False, 'error': f'解密失败: {error_msg}'}

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt a model file (legacy-compatible alias of ``decrypt_and_verify``)."""
        return MandatoryModelValidator.decrypt_and_verify(
            encrypted_path, password, verify_key=verify_key
        )
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
class ModelEncryptionService:
    """Model encryption service: key generation and key-strength validation."""

    # Special characters allowed in generated keys and required by
    # validate_key_strength.
    _SPECIAL_CHARS = "!@#$%^&*"

    @staticmethod
    def generate_secure_key(length=32):
        """Generate a random encryption key.

        The key is drawn with the ``secrets`` CSPRNG from ASCII letters,
        digits and ``!@#$%^&*``.  When *length* is at least 4 the key is
        guaranteed to contain an upper-case letter, a lower-case letter, a
        digit and a special character, so a generated key of 16 or more
        characters always passes :meth:`validate_key_strength`.  (Purely
        uniform sampling occasionally produced keys that this class itself
        rejected.)

        Args:
            length: Number of characters in the key (default 32).

        Returns:
            dict with ``key``, ``key_hash`` (SHA-256 hex digest of the key)
            and ``short_hash`` (first 16 hex characters of ``key_hash``).
        """
        specials = ModelEncryptionService._SPECIAL_CHARS
        required_pools = (
            string.ascii_uppercase,
            string.ascii_lowercase,
            string.digits,
            specials,
        )
        alphabet = string.ascii_letters + string.digits + specials

        if length >= len(required_pools):
            # One character from each required class, the rest uniform.
            chars = [secrets.choice(pool) for pool in required_pools]
            chars.extend(
                secrets.choice(alphabet)
                for _ in range(length - len(required_pools))
            )
            # Shuffle so the guaranteed characters are not at fixed positions.
            secrets.SystemRandom().shuffle(chars)
        else:
            # Too short to hold every class; fall back to uniform sampling.
            chars = [secrets.choice(alphabet) for _ in range(length)]
        key = ''.join(chars)

        # Hash of the key, so it can be identified without storing it.
        key_hash = hashlib.sha256(key.encode()).hexdigest()

        return {
            'key': key,
            'key_hash': key_hash,
            'short_hash': key_hash[:16]
        }

    @staticmethod
    def validate_key_strength(key):
        """Check that *key* is long and varied enough.

        Returns:
            ``(True, message)`` when the key has at least 16 characters and
            contains upper-case, lower-case, digit and special (``!@#$%^&*``)
            characters; otherwise ``(False, reason)``.
        """
        if len(key) < 16:
            return False, "密钥长度至少16位"

        # Check which character classes are present.
        has_upper = any(c.isupper() for c in key)
        has_lower = any(c.islower() for c in key)
        has_digit = any(c.isdigit() for c in key)
        has_special = any(c in ModelEncryptionService._SPECIAL_CHARS for c in key)
        if not (has_upper and has_lower and has_digit and has_special):
            return False, "密钥应包含大小写字母、数字和特殊字符"
        return True, "密钥强度足够"

    @staticmethod
    def create_key_pair():
        """Create a master key plus verification token.

        The original author describes this as intended for client-server
        communication.

        Returns:
            dict with ``master_key`` (URL-safe random token),
            ``verification_token`` (first 20 hex characters of the SHA-256 of
            the master key concatenated with the creation timestamp) and
            ``created_at`` (that same timestamp).
        """
        # Take the timestamp once so the token and created_at agree.
        created_at = time.time()
        # Master key.
        master_key = secrets.token_urlsafe(32)
        # Verification token.
        verification_token = hashlib.sha256(
            (master_key + str(created_at)).encode()
        ).hexdigest()[:20]
        return {
            'master_key': master_key,
            'verification_token': verification_token,
            'created_at': created_at
        }
class PreTaskModelValidator:
"""任务前模型验证器 - 在创建任务前验证所有模型"""
def __init__(self, config=None):
self.config = config or {}
self.encrypted_models_dir = self.config.get('encrypted_models_dir', 'encrypted_models')
os.makedirs(self.encrypted_models_dir, exist_ok=True)
# 验证状态缓存
self.validation_cache = {}
self.cache_lock = threading.Lock()
self.cache_expiry = 300 # 5分钟缓存有效期
def validate_models_before_task(self, task_config):
"""在创建任务前验证所有模型和密钥"""
try:
models_config = task_config.get('models', [])
if not models_config:
return {
'success': False,
'error': '任务配置中未找到模型列表'
}
validation_results = []
all_valid = True
for i, model_config in enumerate(models_config):
# 检查必要的配置项
required_fields = ['path', 'encryption_key']
for field in required_fields:
if field not in model_config:
return {
'success': False,
'error': f'模型 {i} 缺少必要字段: {field}'
}
model_path = model_config['path']
encryption_key = model_config['encryption_key']
# 构建完整的模型文件路径
full_model_path = self._get_full_model_path(model_path)
# 验证模型文件是否存在
if not os.path.exists(full_model_path):
return {
'success': False,
'error': f'模型文件不存在: {full_model_path}'
}
# 验证密钥强度
key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
if not key_valid:
return {
'success': False,
'error': f'模型 {i} 密钥强度不足: {key_msg}'
}
# 验证密钥是否正确(尝试解密)
validator = MandatoryModelValidator()
decrypt_result = validator.decrypt_and_verify(full_model_path, encryption_key)
validation_result = {
'model_index': i,
'model_path': model_path,
'full_path': full_model_path,
'file_exists': True,
'key_valid': decrypt_result['success'],
'model_hash': decrypt_result.get('model_hash', '')[:16] if decrypt_result['success'] else '',
'model_size': decrypt_result.get('original_size', 0) if decrypt_result['success'] else 0,
'validation_time': time.time()
}
if not decrypt_result['success']:
validation_result['error'] = decrypt_result.get('error', '验证失败')
all_valid = False
validation_results.append(validation_result)
return {
'success': all_valid,
'valid': all_valid,
'total_models': len(models_config),
'valid_models': sum(1 for r in validation_results if r['key_valid']),
'validation_results': validation_results,
'timestamp': time.time()
}
except Exception as e:
logger.error(f"预验证模型失败: {str(e)}")
return {
'success': False,
'error': f'预验证失败: {str(e)}'
}
def verify_single_model(self, model_path, encryption_key):
"""验证单个模型"""
try:
# 构建完整路径
full_model_path = self._get_full_model_path(model_path)
# 检查文件是否存在
if not os.path.exists(full_model_path):
return {
'success': False,
'error': f'模型文件不存在: {full_model_path}'
}
# 验证密钥强度
key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
if not key_valid:
return {
'success': False,
'error': f'密钥强度不足: {key_msg}'
}
# 解密验证
validator = MandatoryModelValidator()
decrypt_result = validator.decrypt_and_verify(full_model_path, encryption_key)
if decrypt_result['success']:
return {
'success': True,
'valid': True,
'model_hash': decrypt_result.get('model_hash', '')[:16],
'model_size': decrypt_result.get('original_size', 0),
'file_path': full_model_path,
'timestamp': time.time()
}
else:
return {
'success': False,
'valid': False,
'error': decrypt_result.get('error', '验证失败'),
'timestamp': time.time()
}
except Exception as e:
logger.error(f"验证单个模型失败: {str(e)}")
return {
'success': False,
'error': f'验证失败: {str(e)}'
}
def _get_full_model_path(self, model_path):
"""获取完整的模型文件路径"""
if os.path.isabs(model_path):
return model_path
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
# 如果路径以 encrypted_models/ 开头
if model_path.startswith('encrypted_models/'):
model_filename = os.path.basename(model_path)
return os.path.join(self.encrypted_models_dir, model_filename)
# 否则直接使用文件名
model_filename = os.path.basename(model_path)
return os.path.join(self.encrypted_models_dir, model_filename)
def get_model_info(self, model_path):
"""获取模型信息(不验证密钥)"""
try:
full_model_path = self._get_full_model_path(model_path)
if not os.path.exists(full_model_path):
return None
verify_result = MandatoryModelValidator().verify_model_encryption(full_model_path)
if verify_result['valid']:
data = verify_result['encrypted_data']
return {
'encrypted': True,
'model_hash': data.get('model_hash', '')[:16],
'original_size': data.get('original_size', 0),
'version': data.get('version', 'unknown'),
'file_path': full_model_path,
'file_size': os.path.getsize(full_model_path)
}
return None
except Exception as e:
logger.error(f"获取模型信息失败: {str(e)}")
return None
class SecureModelManager:
    """Secure model manager: locates, verifies, decrypts and loads encrypted models."""

    def __init__(self, config):
        """Create the manager and make sure its model directories exist.

        Args:
            config: dict; ``models_dir`` (default ``'models'``) and
                ``encrypted_models_dir`` (default ``'encrypted_models'``)
                select where model files are looked up.
        """
        self.config = config
        self.models_dir = config.get('models_dir', 'models')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # Validator instances.
        self.validator = MandatoryModelValidator()
        self.pre_validator = PreTaskModelValidator(config)

        # Loaded-model cache, guarded by cache_lock.
        self.model_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 600  # seconds (10 minutes)

        # Successful key verifications, guarded by the same lock.
        self.verification_status = {}
        self.verification_expiry = 300  # seconds (5 minutes)

    @staticmethod
    def _make_cache_key(local_path, encryption_key):
        """Build the cache key for a (model file, encryption key) pair.

        The full SHA-256 digest of the key is used.  The previous 8-hex-digit
        MD5 prefix carried only 32 bits, so a wrong key that collides with a
        cached entry was cheap to find and would have been accepted from the
        cache without any decryption.
        """
        key_digest = hashlib.sha256(encryption_key.encode()).hexdigest()
        return f"{local_path}_{key_digest}"

    def ensure_model_available(self, model_config):
        """Check that the configured model file exists and is a valid payload.

        Returns:
            ``{'available': True, 'local_path': ..., 'model_hash': ...}`` or
            ``{'available': False, 'error': <message>}``.
        """
        try:
            model_path = model_config.get('path')
            if not model_path:
                return {'available': False, 'error': '模型路径未配置'}

            # Resolve the local path of the model.
            local_path = self._get_local_model_path(model_path, model_config)

            # Check the model file exists.
            if not os.path.exists(local_path):
                return {'available': False, 'error': f'模型文件不存在: {local_path}'}

            # Check the encrypted container format.
            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return {'available': False, 'error': f'模型加密格式无效: {verify_result.get("error", "未知错误")}'}

            return {
                'available': True,
                'local_path': local_path,
                'model_hash': verify_result['encrypted_data'].get('model_hash', '')[:16]
            }
        except Exception as e:
            return {'available': False, 'error': f'确保模型可用失败: {str(e)}'}

    def verify_model_key(self, model_config, encryption_key):
        """Verify that *encryption_key* decrypts the configured model.

        Successful verifications are cached for ``verification_expiry``
        seconds; failures are never cached.

        Returns:
            ``{'valid': True, 'model_hash', 'original_size', 'local_path',
            'verified_at'}`` or ``{'valid': False, 'error': <message>}``.
        """
        try:
            # Make sure the model file exists.
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return {'valid': False, 'error': availability['error']}

            local_path = availability['local_path']
            model_hash_short = availability.get('model_hash', 'unknown')

            # Reuse a recent verification to avoid decrypting again.
            cache_key = self._make_cache_key(local_path, encryption_key)
            with self.cache_lock:
                cached_result = self.verification_status.get(cache_key)
                if cached_result is not None:
                    if time.time() - cached_result.get('verified_at', 0) < self.verification_expiry:
                        if cached_result['valid']:
                            logger.info(f"使用缓存的验证结果: {model_hash_short}")
                            # Hand out a copy so callers cannot mutate the cache entry.
                            return dict(cached_result)
                    else:
                        # Entry expired: drop it.
                        del self.verification_status[cache_key]

            # Verify by decrypting.
            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if decrypt_result['success']:
                result = {
                    'valid': True,
                    'model_hash': decrypt_result.get('model_hash', '')[:16],
                    'original_size': decrypt_result.get('original_size', 0),
                    'local_path': local_path,
                    'verified_at': time.time()
                }
                # Cache the successful verification.
                with self.cache_lock:
                    self.verification_status[cache_key] = dict(result)
                logger.info(f"模型密钥验证成功: {model_hash_short}")
                return result

            result = {'valid': False, 'error': decrypt_result.get('error', '验证失败')}
            logger.warning(f"模型密钥验证失败: {model_hash_short} - {result['error']}")
            return result
        except Exception as e:
            error_msg = f"验证模型密钥失败: {str(e)}"
            logger.error(error_msg)
            return {'valid': False, 'error': error_msg}

    def load_verified_model(self, model_config, encryption_key):
        """Verify, decrypt and load a model; return it, or ``None`` on failure.

        The decrypted weights are written to a temporary ``.pt`` file only
        for the duration of loading; that file is removed even when loading
        fails, so plaintext weights are not left on disk.

        Args:
            model_config: dict with ``path`` and optionally ``device``,
                ``half`` and ``encrypted``.
            encryption_key: Password for the encrypted model.

        Returns:
            The loaded model object, or ``None`` if any step failed (the
            error and traceback are logged).
        """
        try:
            # Verify the key first.
            verify_result = self.verify_model_key(model_config, encryption_key)
            if not verify_result['valid']:
                raise ValueError(f"模型验证未通过: {verify_result.get('error', '未知错误')}")

            local_path = verify_result['local_path']

            # Serve from the model cache when the entry is still fresh.
            cache_key = self._make_cache_key(local_path, encryption_key)
            with self.cache_lock:
                if cache_key in self.model_cache:
                    cached_info = self.model_cache[cache_key]
                    if time.time() - cached_info.get('loaded_at', 0) < self.cache_expiry:
                        logger.info(f"使用缓存的模型: {local_path}")
                        return cached_info['model']

            # Decrypt the model with the verified key.
            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if not decrypt_result['success']:
                raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")

            temp_path = None
            try:
                # Write the plaintext to a temporary file for the loader.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                    temp_path = tmp.name
                    tmp.write(decrypt_result['decrypted_data'])

                # Load the YOLO model.
                from ultralytics import YOLO
                model = YOLO(temp_path)

                # Apply the device setting.
                device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
                model = model.to(device)

                # Apply half precision when requested on a CUDA device.
                if model_config.get('half', False) and 'cuda' in device:
                    model = model.half()
                    logger.info(f"启用半精度推理: {model_config.get('path', '未知模型')}")
            finally:
                # Always remove the plaintext file, including on load errors.
                if temp_path is not None:
                    try:
                        os.unlink(temp_path)
                    except OSError as cleanup_error:
                        logger.warning(f"清理临时模型文件失败: {temp_path} - {cleanup_error}")

            # Cache the loaded model.
            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'device': device,
                    'loaded_at': time.time(),
                    'original_size': decrypt_result.get('original_size', 0)
                }
            logger.info(f"加密模型加载成功: {model_config.get('path', '未知模型')} -> {device}")
            return model
        except Exception as e:
            logger.error(f"加载已验证模型失败: {str(e)}")
            logger.error(traceback.format_exc())
            return None

    def _get_local_model_path(self, model_path, model_config):
        """Map a configured model path to a local file path.

        A path that already exists is returned unchanged.  Otherwise the file
        name is looked up in ``encrypted_models_dir`` (with ``.enc`` appended
        if missing) when ``model_config['encrypted']`` is truthy, or in
        ``models_dir`` when it is not.
        """
        # Already a usable local path.
        if os.path.exists(model_path):
            return model_path

        model_filename = os.path.basename(model_path)
        if model_config.get('encrypted', False):
            # Look in the encrypted-model directory.
            if not model_filename.endswith('.enc'):
                model_filename += '.enc'
            return os.path.join(self.encrypted_models_dir, model_filename)
        # Look in the plain-model directory.
        return os.path.join(self.models_dir, model_filename)

    def get_model_info(self, model_config):
        """Return metadata of the configured model without checking any key.

        Returns:
            dict with ``encrypted``, ``model_hash`` (first 16 characters),
            ``original_size``, ``version``, ``local_path`` and ``file_size``;
            or ``None`` when the model is unavailable or an error occurs.
        """
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return None

            local_path = availability['local_path']
            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return None

            data = verify_result['encrypted_data']
            return {
                'encrypted': True,
                'model_hash': data.get('model_hash', '')[:16],
                'original_size': data.get('original_size', 0),
                'version': data.get('version', 'unknown'),
                'local_path': local_path,
                'file_size': os.path.getsize(local_path)
            }
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None

    def clear_cache(self):
        """Drop every cached model and every cached verification result."""
        with self.cache_lock:
            self.model_cache.clear()
            self.verification_status.clear()
        logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return the sizes and keys of the model and verification caches."""
        with self.cache_lock:
            return {
                'model_cache_size': len(self.model_cache),
                'verification_cache_size': len(self.verification_status),
                'model_cache_keys': list(self.model_cache.keys()),
                'verification_cache_keys': list(self.verification_status.keys())
            }
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
class SimpleModelEncryptor:
    """Simplified model encryptor, meant only for testing or local conversion."""

    @staticmethod
    def create_encrypted_payload(model_data, password):
        """Build the encrypted payload dict for *model_data*.

        The original author describes this as mimicking the encryption done
        by the .NET server.

        Args:
            model_data: Raw model bytes.
            password: Password (``str``) the Fernet key is derived from.

        Returns:
            ``(payload, key)`` where *payload* holds ``salt``, ``data``,
            ``model_hash``, ``original_size``, ``encrypted``, ``version`` and
            ``created_at``, and *key* is the derived Fernet key.
        """
        # Fresh random salt, then PBKDF2-HMAC-SHA256 key derivation.
        salt = os.urandom(16)
        derived = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        ).derive(password.encode())
        key = base64.urlsafe_b64encode(derived)

        # Encrypt, then fingerprint the plaintext.
        token = Fernet(key).encrypt(model_data)
        plaintext_digest = hashlib.sha256(model_data).hexdigest()

        # Assemble the payload.
        payload = {
            'salt': salt,
            'data': token,
            'model_hash': plaintext_digest,
            'original_size': len(model_data),
            'encrypted': True,
            'version': '2.0',
            'created_at': time.time()
        }
        return payload, key

    @staticmethod
    def encrypt_model_file(input_path, output_path, password):
        """Encrypt the file at *input_path* and pickle the payload to *output_path*.

        Returns:
            On success a dict with ``success`` True, ``output_path``,
            ``model_hash`` (first 16 characters), ``original_size`` and
            ``encrypted_size``.  On any error the error is logged and
            ``{'success': False, 'error': <message>}`` is returned.
        """
        try:
            # Read the raw model.
            with open(input_path, 'rb') as source:
                raw_bytes = source.read()

            # Build the payload; the derived key is not needed here.
            payload, _ = SimpleModelEncryptor.create_encrypted_payload(raw_bytes, password)

            # Persist the payload.
            with open(output_path, 'wb') as target:
                pickle.dump(payload, target)

            logger.info(f"模型加密成功: {input_path} -> {output_path}")
            summary = {
                'success': True,
                'output_path': output_path,
                'model_hash': payload['model_hash'][:16],
                'original_size': len(raw_bytes),
                'encrypted_size': os.path.getsize(output_path)
            }
            return summary
        except Exception as e:
            logger.error(f"加密模型文件失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    @staticmethod
    def verify_encryption(encrypted_path, password):
        """Verify an encrypted file by decrypting it with *password*."""
        checker = MandatoryModelValidator()
        return checker.decrypt_and_verify(encrypted_path, password)
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
# ==================== 全局实例和接口函数 ====================
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
# 导入线程模块
import threading
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
# Module-level singletons; get_model_manager() and get_pre_task_validator()
# fill in the first two on first use.
_model_manager = None
_pre_task_validator = None
# NOTE(review): _upload_manager is not referenced anywhere else in the visible
# file -- confirm whether it is still needed.
_upload_manager = None
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
def get_model_manager(config=None):
    """Return the shared ``SecureModelManager``, creating it on first use.

    Args:
        config: Configuration for the manager.  Only consulted when the
            singleton does not exist yet; when omitted, the project default
            from ``config.get_default_config()`` is used.
    """
    global _model_manager
    if _model_manager is not None:
        return _model_manager

    if config is None:
        from config import get_default_config
        config = get_default_config()
    _model_manager = SecureModelManager(config)
    return _model_manager
2025-12-11 13:41:07 +08:00
2025-12-12 16:04:22 +08:00
def get_pre_task_validator(config=None):
    """Return the shared ``PreTaskModelValidator``, creating it on first use.

    Args:
        config: Configuration dict.  Only consulted when the singleton does
            not exist yet; when omitted, the project default from
            ``config.get_default_config()`` is used.

    NOTE(review): only the ``'upload'`` sub-section of the configuration is
    handed to the validator -- confirm that is where
    ``encrypted_models_dir`` is expected to live.
    """
    global _pre_task_validator
    if _pre_task_validator is not None:
        return _pre_task_validator

    if config is None:
        from config import get_default_config
        config = get_default_config()
    upload_section = config.get('upload', {})
    _pre_task_validator = PreTaskModelValidator(upload_section)
    return _pre_task_validator
def validate_models_before_task(task_config):
    """Validate every model of a task (external entry point).

    Delegates to the shared pre-task validator and returns its result dict.
    """
    return get_pre_task_validator().validate_models_before_task(task_config)
def verify_single_model_api(model_path, encryption_key):
    """API entry point: validate a single model file against a key.

    Delegates to the shared pre-task validator and returns its result dict.
    """
    return get_pre_task_validator().verify_single_model(model_path, encryption_key)
def verify_model_key_api(model_config, encryption_key):
    """API entry point: verify the encryption key of a configured model.

    Delegates to the shared model manager and returns its result dict.
    """
    return get_model_manager().verify_model_key(model_config, encryption_key)
def load_verified_model_api(model_config, encryption_key):
    """API entry point: load a model after verifying its key.

    Delegates to the shared model manager and returns whatever its
    ``load_verified_model`` returns.
    """
    return get_model_manager().load_verified_model(model_config, encryption_key)
def generate_secure_key_api(length=32):
    """API entry point: generate a secure key of *length* characters."""
    key_info = ModelEncryptionService.generate_secure_key(length)
    return key_info
def validate_key_strength_api(key):
    """API entry point: check key strength; returns ``(ok, message)``."""
    verdict = ModelEncryptionService.validate_key_strength(key)
    return verdict
def test_decryption_api(encrypted_path, password):
    """API entry point: try to decrypt *encrypted_path* with *password*.

    Returns the result dict of ``MandatoryModelValidator.decrypt_and_verify``.
    """
    return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)
# ==================== 向后兼容的接口 ====================
class MandatoryModelEncryptor:
    """Backward-compatible facade over the newer encryptor/validator classes."""

    @staticmethod
    def encrypt_model(input_path, output_path, password, require_encryption=True):
        """Encrypt a model file (legacy interface).

        ``require_encryption`` is accepted for compatibility and is not used.
        """
        outcome = SimpleModelEncryptor.encrypt_model_file(input_path, output_path, password)
        return outcome

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt a model file (legacy interface)."""
        legacy_validator = MandatoryModelValidator()
        outcome = legacy_validator.decrypt_model(encrypted_path, password, verify_key=verify_key)
        return outcome

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Report whether the file is a properly encrypted payload (legacy interface)."""
        looks_encrypted = MandatoryModelValidator.is_properly_encrypted(encrypted_path)
        return looks_encrypted

    @staticmethod
    def generate_secure_key():
        """Generate a secure key with the default length (legacy interface)."""
        key_info = ModelEncryptionService.generate_secure_key()
        return key_info

    @staticmethod
    def verify_model_key(encrypted_path, password):
        """Verify a model key by decrypting the file (legacy interface)."""
        legacy_validator = MandatoryModelValidator()
        outcome = legacy_validator.decrypt_and_verify(encrypted_path, password)
        return outcome
# ==================== 测试代码 ====================
if __name__ == "__main__":
    # Manual smoke test: exercises key generation, key-strength checks and the
    # validators against sample (mostly non-existent) model files.
    print("=== 强制加密模型验证模块测试 ===")

    # Sample manager configuration.
    config = {
        'models_dir': './models',
        'encrypted_models_dir': './encrypted_models',
        'cache_verification': True
    }

    # Sample model configuration.
    model_config = {
        'path': 'yolov8n.enc',  # encrypted model file path
        'device': 'cpu',
        'half': False,
        'encrypted': True,
        'encryption_key': 'test-password-123!@#'  # test key
    }

    # 1. Key generation.
    print("\n1. 测试密钥生成:")
    key_info = ModelEncryptionService.generate_secure_key()
    print(f" 生成的密钥: {key_info['key']}")
    print(f" 密钥哈希: {key_info['key_hash']}")
    print(f" 短哈希: {key_info['short_hash']}")

    # 2. Key-strength validation: one weak and one strong sample.
    print("\n2. 测试密钥强度验证:")
    for test_key in ("WeakKey123", "Strong@Password#2024!Complex"):
        valid, msg = ModelEncryptionService.validate_key_strength(test_key)
        print(f" 密钥 '{test_key}': {valid} - {msg}")

    # 3. Model validator against a file that does not exist.
    print("\n3. 测试模型验证器:")
    validator = MandatoryModelValidator()
    result = validator.verify_model_encryption("nonexistent.enc")
    print(f" 不存在的文件验证: {result['valid']} - {result.get('error', '')}")

    # 4. Pre-task validator.
    print("\n4. 测试预验证器:")
    pre_validator = PreTaskModelValidator(config)
    task_config = {
        'models': [
            {
                'path': 'test_model.enc',
                'encryption_key': 'Test@Password#2024!Complex'
            }
        ]
    }
    result = pre_validator.validate_models_before_task(task_config)
    print(f" 任务验证结果: 成功={result['success']}, 错误={result.get('error', '')}")

    # 5. Secure model manager: fetch model metadata.
    print("\n5. 测试安全模型管理器:")
    manager = SecureModelManager(config)
    model_info = manager.get_model_info(model_config)
    if not model_info:
        print(" 模型信息: 不可用")
    else:
        print(f" 模型信息: 哈希={model_info['model_hash']}, 大小={model_info['original_size']}")

    print("\n=== 测试完成 ===")