# mandatory_model_crypto.py - Mandatory encrypted-model verification and decryption module
import base64
import hashlib
import hmac
import os
import pickle
import secrets
import string
import tempfile
import threading
import time
import traceback
from pathlib import Path

import requests
import torch
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from log import logger

# PBKDF2 iteration count shared by decryption and SimpleModelEncryptor.
# Changing it changes every derived key, so existing files would no longer decrypt.
_PBKDF2_ITERATIONS = 100000

# Special characters used by key generation and required by the key-strength check.
_SPECIAL_CHARS = "!@#$%^&*"

# Payload format versions accepted by MandatoryModelValidator.verify_model_encryption.
_SUPPORTED_VERSIONS = ('1.0', '2.0')


def _derive_fernet_key(password, salt):
    """Derive the urlsafe-base64 Fernet key for ``password`` and ``salt``.

    Uses PBKDF2-HMAC-SHA256 with a 32-byte output. ``password`` may be ``str``
    (UTF-8 encoded) or ``bytes``.
    """
    secret = password if isinstance(password, bytes) else password.encode()
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=_PBKDF2_ITERATIONS,
    )
    return base64.urlsafe_b64encode(kdf.derive(secret))


def _load_payload(path):
    """Unpickle and return the encrypted payload stored at ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only load model files from a trusted source; a non-pickle container format
    would remove this risk.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _char_classes(key):
    """Return ``(has_upper, has_lower, has_digit, has_special)`` for ``key``."""
    return (
        any(c.isupper() for c in key),
        any(c.islower() for c in key),
        any(c.isdigit() for c in key),
        any(c in _SPECIAL_CHARS for c in key),
    )


class MandatoryModelValidator:
    """Mandatory encrypted-model validator - responsible only for verification and decryption."""

    @staticmethod
    def verify_model_encryption(encrypted_path):
        """Check that ``encrypted_path`` holds a well-formed encrypted payload.

        The payload must be a dict containing ``salt``, ``data``, ``encrypted``,
        ``model_hash`` and ``version``, with ``encrypted is True`` and a
        supported version.

        Returns:
            ``{'valid': True, 'encrypted_data': payload}`` on success, otherwise
            ``{'valid': False, 'error': message}``. Never raises.
        """
        try:
            if not os.path.exists(encrypted_path):
                return {'valid': False, 'error': '加密模型文件不存在'}

            data = _load_payload(encrypted_path)

            # Anything other than a dict cannot carry the required fields.
            if not isinstance(data, dict):
                return {'valid': False, 'error': '加密数据格式无效'}

            for field in ('salt', 'data', 'encrypted', 'model_hash', 'version'):
                if field not in data:
                    return {'valid': False, 'error': f'缺少加密字段: {field}'}

            if data.get('encrypted', False) is not True:
                return {'valid': False, 'error': '模型未加密'}

            if data.get('version') not in _SUPPORTED_VERSIONS:
                return {'valid': False, 'error': f'不支持的加密版本: {data.get("version")}'}

            return {'valid': True, 'encrypted_data': data}
        except Exception as e:
            return {'valid': False, 'error': f'验证失败: {str(e)}'}

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Return True if the file is a dict payload flagged ``encrypted is True``.

        This is a lighter check than ``verify_model_encryption``: the
        ``version`` field is not required. Any read/unpickle error is logged as
        a warning and yields False.
        """
        try:
            data = _load_payload(encrypted_path)
            if not isinstance(data, dict):
                return False
            if any(field not in data for field in ('salt', 'data', 'encrypted', 'model_hash')):
                return False
            return data.get('encrypted', False) is True
        except Exception as e:
            logger.warning(f"检查加密格式失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_and_verify(encrypted_path, password, verify_key=True):
        """Decrypt a model file and verify its SHA-256 hash and optional size.

        Args:
            encrypted_path: path to the pickled encrypted payload.
            password: password (``str`` or ``bytes``) used to derive the Fernet key.
            verify_key: accepted for backward compatibility; currently unused.

        Returns:
            On success ``{'success': True, 'model_hash', 'original_size',
            'decrypted_data', 'version'}`` - ``decrypted_data`` holds the raw
            model bytes. On failure ``{'success': False, 'error': message}``;
            a wrong password is reported as ``'解密密钥错误'``. Never raises.
        """
        try:
            verify_result = MandatoryModelValidator.verify_model_encryption(encrypted_path)
            if not verify_result['valid']:
                return {'success': False, 'error': verify_result['error']}

            payload = verify_result['encrypted_data']
            fernet = Fernet(_derive_fernet_key(password, payload['salt']))

            try:
                decrypted_data = fernet.decrypt(payload['data'])
            except InvalidToken:
                # Fernet signals a wrong key (or a tampered token) with InvalidToken,
                # whose message is empty - it must be matched by type, not by str(e).
                return {'success': False, 'error': '解密密钥错误'}

            expected_hash = payload['model_hash']
            actual_hash = hashlib.sha256(decrypted_data).hexdigest()
            if actual_hash != expected_hash:
                return {
                    'success': False,
                    'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
                }

            # original_size is optional; 0 or absent disables the size check.
            original_size = payload.get('original_size', 0)
            if original_size > 0 and len(decrypted_data) != original_size:
                return {
                    'success': False,
                    'error': f'文件大小不匹配: 期望{original_size}, 实际{len(decrypted_data)}'
                }

            logger.info(f"模型解密验证成功: {encrypted_path}")
            logger.debug(f"模型哈希: {actual_hash[:16]}...")

            return {
                'success': True,
                'model_hash': actual_hash,
                'original_size': len(decrypted_data),
                'decrypted_data': decrypted_data,  # raw model bytes for the caller to load
                'version': payload.get('version', '1.0')
            }
        except InvalidToken:
            return {'success': False, 'error': '解密密钥错误'}
        except Exception as e:
            error_msg = str(e)
            if "InvalidToken" in error_msg or "Invalid signature" in error_msg:
                return {'success': False, 'error': '解密密钥错误'}
            return {'success': False, 'error': f'解密失败: {error_msg}'}

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt a model file - legacy alias of ``decrypt_and_verify``."""
        return MandatoryModelValidator.decrypt_and_verify(
            encrypted_path, password, verify_key=verify_key
        )


class ModelEncryptionService:
    """Model encryption service - key generation and key-strength validation."""

    @staticmethod
    def generate_secure_key(length=32):
        """Generate a random key from letters, digits and ``!@#$%^&*``.

        Keys of length >= 4 are redrawn until they contain an uppercase letter,
        a lowercase letter, a digit and a special character, so a key of
        length >= 16 always passes ``validate_key_strength``.

        Returns:
            dict with ``key``, its SHA-256 hex digest ``key_hash`` and the first
            16 hex characters of that digest as ``short_hash``.
        """
        alphabet = string.ascii_letters + string.digits + _SPECIAL_CHARS
        while True:
            key = ''.join(secrets.choice(alphabet) for _ in range(length))
            # Fewer than 4 characters cannot cover all four classes; accept as drawn.
            if length < 4 or all(_char_classes(key)):
                break

        key_hash = hashlib.sha256(key.encode()).hexdigest()
        return {
            'key': key,
            'key_hash': key_hash,
            'short_hash': key_hash[:16]
        }

    @staticmethod
    def validate_key_strength(key):
        """Check that ``key`` is a string of >= 16 chars with all four character classes.

        Returns:
            ``(ok, message)`` tuple. A non-string key (e.g. ``None``) is rejected
            with the length message instead of raising.
        """
        if not isinstance(key, str) or len(key) < 16:
            return False, "密钥长度至少16位"

        if not all(_char_classes(key)):
            return False, "密钥应包含大小写字母、数字和特殊字符"

        return True, "密钥强度足够"

    @staticmethod
    def create_key_pair():
        """Create a master key plus a time-derived verification token.

        Returns:
            dict with ``master_key`` (urlsafe token from 32 random bytes),
            ``verification_token`` (first 20 hex chars of
            SHA-256(master_key + current time)) and ``created_at``.
        """
        master_key = secrets.token_urlsafe(32)
        verification_token = hashlib.sha256(
            (master_key + str(time.time())).encode()
        ).hexdigest()[:20]

        return {
            'master_key': master_key,
            'verification_token': verification_token,
            'created_at': time.time()
        }


class PreTaskModelValidator:
    """Pre-task model validator - verifies all models before a task is created."""

    def __init__(self, config=None):
        """Create the validator and ensure ``encrypted_models_dir`` exists.

        Args:
            config: optional dict; only ``'encrypted_models_dir'`` is read
                (default ``'encrypted_models'``).
        """
        self.config = config or {}
        self.encrypted_models_dir = self.config.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # Validation-state cache (not read by the methods of this class)
        self.validation_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 300  # cache lifetime in seconds (5 minutes)

    def validate_models_before_task(self, task_config):
        """Validate every model (file, key strength, key correctness) of a task.

        Args:
            task_config: dict whose ``'models'`` entry is a list of dicts, each
                with at least ``'path'`` and ``'encryption_key'``.

        Returns:
            ``{'success': False, 'error': ...}`` immediately on a structural
            problem (no model list, bad entry, missing field, missing file,
            weak key). Otherwise a summary dict with per-model
            ``validation_results``; ``success``/``valid`` are True only if every
            model decrypted and verified. Never raises.
        """
        try:
            models_config = task_config.get('models', [])
            if not models_config:
                return {
                    'success': False,
                    'error': '任务配置中未找到模型列表'
                }

            validation_results = []

            for i, model_config in enumerate(models_config):
                if not isinstance(model_config, dict):
                    return {
                        'success': False,
                        'error': f'模型 {i} 配置格式无效'
                    }

                for field in ('path', 'encryption_key'):
                    if field not in model_config:
                        return {
                            'success': False,
                            'error': f'模型 {i} 缺少必要字段: {field}'
                        }

                model_path = model_config['path']
                encryption_key = model_config['encryption_key']
                full_model_path = self._get_full_model_path(model_path)

                if not os.path.exists(full_model_path):
                    return {
                        'success': False,
                        'error': f'模型文件不存在: {full_model_path}'
                    }

                key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
                if not key_valid:
                    return {
                        'success': False,
                        'error': f'模型 {i} 密钥强度不足: {key_msg}'
                    }

                # Prove the key is correct by actually decrypting the model.
                decrypt_result = MandatoryModelValidator.decrypt_and_verify(full_model_path, encryption_key)
                ok = decrypt_result['success']

                validation_result = {
                    'model_index': i,
                    'model_path': model_path,
                    'full_path': full_model_path,
                    'file_exists': True,
                    'key_valid': ok,
                    'model_hash': decrypt_result.get('model_hash', '')[:16] if ok else '',
                    'model_size': decrypt_result.get('original_size', 0) if ok else 0,
                    'validation_time': time.time()
                }
                if not ok:
                    validation_result['error'] = decrypt_result.get('error', '验证失败')

                validation_results.append(validation_result)

            valid_count = sum(1 for r in validation_results if r['key_valid'])
            all_valid = valid_count == len(validation_results)

            return {
                'success': all_valid,
                'valid': all_valid,
                'total_models': len(models_config),
                'valid_models': valid_count,
                'validation_results': validation_results,
                'timestamp': time.time()
            }

        except Exception as e:
            logger.error(f"预验证模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'预验证失败: {str(e)}'
            }

    def verify_single_model(self, model_path, encryption_key):
        """Validate one model: file exists, key is strong, key decrypts the model.

        Returns:
            On success ``{'success': True, 'valid': True, 'model_hash',
            'model_size', 'file_path', 'timestamp'}``; otherwise a dict with
            ``success`` False and an ``error`` message. Never raises.
        """
        try:
            full_model_path = self._get_full_model_path(model_path)

            if not os.path.exists(full_model_path):
                return {
                    'success': False,
                    'error': f'模型文件不存在: {full_model_path}'
                }

            key_valid, key_msg = ModelEncryptionService.validate_key_strength(encryption_key)
            if not key_valid:
                return {
                    'success': False,
                    'error': f'密钥强度不足: {key_msg}'
                }

            decrypt_result = MandatoryModelValidator.decrypt_and_verify(full_model_path, encryption_key)

            if not decrypt_result['success']:
                return {
                    'success': False,
                    'valid': False,
                    'error': decrypt_result.get('error', '验证失败'),
                    'timestamp': time.time()
                }

            return {
                'success': True,
                'valid': True,
                'model_hash': decrypt_result.get('model_hash', '')[:16],
                'model_size': decrypt_result.get('original_size', 0),
                'file_path': full_model_path,
                'timestamp': time.time()
            }

        except Exception as e:
            logger.error(f"验证单个模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'验证失败: {str(e)}'
            }

    def _get_full_model_path(self, model_path):
        """Resolve ``model_path`` to a file path.

        Absolute paths are returned unchanged. For any relative path (including
        the ``encrypted_models/...`` form) only the basename is kept and it is
        looked up inside ``self.encrypted_models_dir``.
        """
        if os.path.isabs(model_path):
            return model_path
        return os.path.join(self.encrypted_models_dir, os.path.basename(model_path))

    def get_model_info(self, model_path):
        """Return metadata of an encrypted model without checking the key.

        Returns:
            dict with ``encrypted``, ``model_hash`` (16 hex chars),
            ``original_size``, ``version``, ``file_path`` and ``file_size``;
            ``None`` if the file is missing, malformed, or an error occurs.
        """
        try:
            full_model_path = self._get_full_model_path(model_path)
            if not os.path.exists(full_model_path):
                return None

            verify_result = MandatoryModelValidator.verify_model_encryption(full_model_path)
            if not verify_result['valid']:
                return None

            data = verify_result['encrypted_data']
            return {
                'encrypted': True,
                'model_hash': data.get('model_hash', '')[:16],
                'original_size': data.get('original_size', 0),
                'version': data.get('version', 'unknown'),
                'file_path': full_model_path,
                'file_size': os.path.getsize(full_model_path)
            }
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None


class SecureModelManager:
    """Secure model manager - locates, verifies and loads encrypted models."""

    def __init__(self, config):
        """Create the manager and ensure the model directories exist.

        Args:
            config: dict; reads ``'models_dir'`` (default ``'models'``) and
                ``'encrypted_models_dir'`` (default ``'encrypted_models'``).
        """
        self.config = config
        self.models_dir = config.get('models_dir', 'models')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        self.validator = MandatoryModelValidator()
        self.pre_validator = PreTaskModelValidator(config)

        # Loaded-model cache, keyed by _cache_key(path, key)
        self.model_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 600  # seconds (10 minutes)

        # Successful key-verification cache, keyed by _cache_key(path, key)
        self.verification_status = {}
        self.verification_expiry = 300  # seconds (5 minutes)

        # Per-instance secret so cache keys neither reveal nor collide on encryption keys.
        self._cache_secret = secrets.token_bytes(32)

    def _cache_key(self, local_path, encryption_key):
        """Build a cache key for ``(local_path, encryption_key)``.

        Uses a full HMAC-SHA256 under a per-instance random secret rather than a
        truncated plain hash: a different key cannot realistically collide onto
        a cached "valid" entry, and keys exposed by ``get_cache_info`` cannot be
        used to test password guesses offline.
        """
        digest = hmac.new(self._cache_secret, encryption_key.encode(), hashlib.sha256).hexdigest()
        return f"{local_path}_{digest}"

    def ensure_model_available(self, model_config):
        """Check that the configured model file exists and is a valid encrypted payload.

        Returns:
            ``{'available': True, 'local_path', 'model_hash'}`` on success,
            otherwise ``{'available': False, 'error': message}``. Never raises.
        """
        try:
            model_path = model_config.get('path')
            if not model_path:
                return {'available': False, 'error': '模型路径未配置'}

            local_path = self._get_local_model_path(model_path, model_config)

            if not os.path.exists(local_path):
                return {'available': False, 'error': f'模型文件不存在: {local_path}'}

            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return {'available': False, 'error': f'模型加密格式无效: {verify_result.get("error", "未知错误")}'}

            return {
                'available': True,
                'local_path': local_path,
                'model_hash': verify_result['encrypted_data'].get('model_hash', '')[:16]
            }

        except Exception as e:
            return {'available': False, 'error': f'确保模型可用失败: {str(e)}'}

    def verify_model_key(self, model_config, encryption_key):
        """Verify that ``encryption_key`` decrypts the configured model.

        This is the precondition for creating a task. Successful results are
        cached for ``verification_expiry`` seconds; failures are not cached.

        Returns:
            On success ``{'valid': True, 'model_hash', 'original_size',
            'local_path', 'verified_at'}``; otherwise ``{'valid': False,
            'error': message}``. Never raises.
        """
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return {'valid': False, 'error': availability['error']}

            local_path = availability['local_path']
            model_hash_short = availability.get('model_hash', 'unknown')

            # Reuse a recent successful verification to avoid another PBKDF2 + decrypt.
            cache_key = self._cache_key(local_path, encryption_key)
            with self.cache_lock:
                cached_result = self.verification_status.get(cache_key)
                if cached_result is not None:
                    if time.time() - cached_result.get('verified_at', 0) < self.verification_expiry:
                        if cached_result['valid']:
                            logger.info(f"使用缓存的验证结果: {model_hash_short}")
                            return cached_result
                    else:
                        del self.verification_status[cache_key]

            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)

            if not decrypt_result['success']:
                result = {'valid': False, 'error': decrypt_result.get('error', '验证失败')}
                logger.warning(f"模型密钥验证失败: {model_hash_short} - {result['error']}")
                return result

            result = {
                'valid': True,
                'model_hash': decrypt_result.get('model_hash', '')[:16],
                'original_size': decrypt_result.get('original_size', 0),
                'local_path': local_path,
                'verified_at': time.time()
            }
            with self.cache_lock:
                self.verification_status[cache_key] = result

            logger.info(f"模型密钥验证成功: {model_hash_short}")
            return result

        except Exception as e:
            error_msg = f"验证模型密钥失败: {str(e)}"
            logger.error(error_msg)
            return {'valid': False, 'error': error_msg}

    def load_verified_model(self, model_config, encryption_key):
        """Verify the key, decrypt the model and load it with ultralytics YOLO.

        Loaded models are cached for ``cache_expiry`` seconds per (path, key).
        ``model_config`` may set ``'device'`` (default ``'cuda:0'`` when CUDA is
        available, else ``'cpu'``) and ``'half'`` (applied only on CUDA devices).

        Returns:
            The loaded model, or ``None`` on any failure (errors are logged).
        """
        try:
            verify_result = self.verify_model_key(model_config, encryption_key)
            if not verify_result['valid']:
                raise ValueError(f"模型验证未通过: {verify_result.get('error', '未知错误')}")

            local_path = verify_result['local_path']
            cache_key = self._cache_key(local_path, encryption_key)

            with self.cache_lock:
                cached_info = self.model_cache.get(cache_key)
                if cached_info is not None:
                    if time.time() - cached_info.get('loaded_at', 0) < self.cache_expiry:
                        logger.info(f"使用缓存的模型: {local_path}")
                        return cached_info['model']
                    # Drop the expired entry so the stale model can be released.
                    del self.model_cache[cache_key]

            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if not decrypt_result['success']:
                raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")

            model, device = self._load_yolo_from_bytes(decrypt_result['decrypted_data'], model_config)

            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'device': device,
                    'loaded_at': time.time(),
                    'original_size': decrypt_result.get('original_size', 0)
                }

            logger.info(f"加密模型加载成功: {model_config.get('path', '未知模型')} -> {device}")
            return model

        except Exception as e:
            logger.error(f"加载已验证模型失败: {str(e)}")
            logger.error(traceback.format_exc())
            return None

    def _load_yolo_from_bytes(self, model_bytes, model_config):
        """Write decrypted weights to a temp file, load them with YOLO, apply device/half.

        The temp file contains the decrypted (plaintext) model, so it is removed
        in ``finally`` - including when loading or device placement fails.

        Returns:
            ``(model, device)``.
        """
        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                temp_path = tmp.name
                tmp.write(model_bytes)

            from ultralytics import YOLO  # imported lazily; only needed when a model is loaded
            model = YOLO(temp_path)

            device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)

            # str() so a non-string device value cannot make the membership test raise.
            if model_config.get('half', False) and 'cuda' in str(device):
                model = model.half()
                logger.info(f"启用半精度推理: {model_config.get('path', '未知模型')}")

            return model, device
        finally:
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError as e:
                    logger.warning(f"删除临时模型文件失败: {temp_path} - {e}")

    def _get_local_model_path(self, model_path, model_config):
        """Resolve the on-disk path of a model.

        An existing ``model_path`` is returned as-is. Otherwise, when
        ``model_config['encrypted']`` is truthy, the basename (with ``.enc``
        appended if missing) is looked up in ``encrypted_models_dir``; else the
        basename is looked up in ``models_dir``.
        """
        if os.path.exists(model_path):
            return model_path

        model_filename = os.path.basename(model_path)
        if model_config.get('encrypted', False):
            if not model_filename.endswith('.enc'):
                model_filename += '.enc'
            return os.path.join(self.encrypted_models_dir, model_filename)

        return os.path.join(self.models_dir, model_filename)

    def get_model_info(self, model_config):
        """Return metadata of the configured encrypted model without checking the key.

        Returns:
            dict with ``encrypted``, ``model_hash`` (16 hex chars),
            ``original_size``, ``version``, ``local_path`` and ``file_size``;
            ``None`` if unavailable or on error.
        """
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return None

            local_path = availability['local_path']
            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return None

            data = verify_result['encrypted_data']
            return {
                'encrypted': True,
                'model_hash': data.get('model_hash', '')[:16],
                'original_size': data.get('original_size', 0),
                'version': data.get('version', 'unknown'),
                'local_path': local_path,
                'file_size': os.path.getsize(local_path)
            }
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None

    def clear_cache(self):
        """Clear both the loaded-model cache and the verification cache."""
        with self.cache_lock:
            self.model_cache.clear()
            self.verification_status.clear()
        logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return sizes and keys of the model and verification caches."""
        with self.cache_lock:
            return {
                'model_cache_size': len(self.model_cache),
                'verification_cache_size': len(self.verification_status),
                'model_cache_keys': list(self.model_cache.keys()),
                'verification_cache_keys': list(self.verification_status.keys())
            }


class SimpleModelEncryptor:
    """Simple model encryptor - intended for tests or local conversion only."""

    @staticmethod
    def create_encrypted_payload(model_data, password):
        """Encrypt ``model_data`` into a version-2.0 payload dict.

        Per its original description this mimics the .NET server's encryption.
        A fresh 16-byte random salt is used for every call.

        Returns:
            ``(payload, key)`` where ``key`` is the derived Fernet key.
        """
        salt = os.urandom(16)
        key = _derive_fernet_key(password, salt)
        fernet = Fernet(key)

        encrypted_data = fernet.encrypt(model_data)
        model_hash = hashlib.sha256(model_data).hexdigest()

        encrypted_payload = {
            'salt': salt,
            'data': encrypted_data,
            'model_hash': model_hash,
            'original_size': len(model_data),
            'encrypted': True,
            'version': '2.0',
            'created_at': time.time()
        }

        return encrypted_payload, key

    @staticmethod
    def encrypt_model_file(input_path, output_path, password):
        """Encrypt the file at ``input_path`` and pickle the payload to ``output_path``.

        Returns:
            On success ``{'success': True, 'output_path', 'model_hash',
            'original_size', 'encrypted_size'}``; otherwise
            ``{'success': False, 'error': message}``.
        """
        try:
            with open(input_path, 'rb') as f:
                model_data = f.read()

            encrypted_payload, _key = SimpleModelEncryptor.create_encrypted_payload(model_data, password)

            with open(output_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)

            logger.info(f"模型加密成功: {input_path} -> {output_path}")
            return {
                'success': True,
                'output_path': output_path,
                'model_hash': encrypted_payload['model_hash'][:16],
                'original_size': len(model_data),
                'encrypted_size': os.path.getsize(output_path)
            }

        except Exception as e:
            logger.error(f"加密模型文件失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    @staticmethod
    def verify_encryption(encrypted_path, password):
        """Round-trip check: decrypt and verify an encrypted file."""
        return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)


# ==================== Global instances and interface functions ====================

# Lazily created singletons
_model_manager = None
_pre_task_validator = None
_upload_manager = None


def get_model_manager(config=None):
    """Return the SecureModelManager singleton.

    ``config`` is used only on the first call; when omitted, it is taken from
    ``config.get_default_config()``.
    """
    global _model_manager
    if _model_manager is None:
        if config is None:
            from config import get_default_config
            config = get_default_config()
        _model_manager = SecureModelManager(config)
    return _model_manager


def get_pre_task_validator(config=None):
    """Return the PreTaskModelValidator singleton.

    ``config`` is used only on the first call; when omitted, it is taken from
    ``config.get_default_config()``.
    """
    global _pre_task_validator
    if _pre_task_validator is None:
        if config is None:
            from config import get_default_config
            config = get_default_config()
        # NOTE(review): only the 'upload' section is passed here, while
        # get_model_manager receives the whole config - confirm this is intended.
        _pre_task_validator = PreTaskModelValidator(config.get('upload', {}))
    return _pre_task_validator


def validate_models_before_task(task_config):
    """Validate all models of a task (external entry point)."""
    validator = get_pre_task_validator()
    return validator.validate_models_before_task(task_config)


def verify_single_model_api(model_path, encryption_key):
    """Validate a single model (API entry point)."""
    validator = get_pre_task_validator()
    return validator.verify_single_model(model_path, encryption_key)


def verify_model_key_api(model_config, encryption_key):
    """Verify a model key (API entry point)."""
    manager = get_model_manager()
    return manager.verify_model_key(model_config, encryption_key)


def load_verified_model_api(model_config, encryption_key):
    """Load a verified model (API entry point)."""
    manager = get_model_manager()
    return manager.load_verified_model(model_config, encryption_key)


def generate_secure_key_api(length=32):
    """Generate a secure key (API entry point)."""
    return ModelEncryptionService.generate_secure_key(length)


def validate_key_strength_api(key):
    """Validate key strength (API entry point)."""
    return ModelEncryptionService.validate_key_strength(key)


def test_decryption_api(encrypted_path, password):
    """Test decryption of an encrypted file (API entry point)."""
    validator = MandatoryModelValidator()
    return validator.decrypt_and_verify(encrypted_path, password)


# ==================== Backward-compatible interface ====================

class MandatoryModelEncryptor:
    """Backward-compatible encryptor facade delegating to the classes above."""

    @staticmethod
    def encrypt_model(input_path, output_path, password, require_encryption=True):
        """Encrypt a model file (``require_encryption`` is accepted but unused)."""
        return SimpleModelEncryptor.encrypt_model_file(input_path, output_path, password)

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt and verify a model file."""
        validator = MandatoryModelValidator()
        return validator.decrypt_model(encrypted_path, password, verify_key=verify_key)

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Return True if the file is a properly flagged encrypted payload."""
        return MandatoryModelValidator.is_properly_encrypted(encrypted_path)

    @staticmethod
    def generate_secure_key():
        """Generate a secure key with the default length."""
        return ModelEncryptionService.generate_secure_key()

    @staticmethod
    def verify_model_key(encrypted_path, password):
        """Verify that ``password`` decrypts the file (returns the decrypt result)."""
        validator = MandatoryModelValidator()
        return validator.decrypt_and_verify(encrypted_path, password)


# ==================== Manual smoke test ====================

if __name__ == "__main__":
    print("=== 强制加密模型验证模块测试 ===")

    # Example configuration
    config = {
        'models_dir': './models',
        'encrypted_models_dir': './encrypted_models',
        'cache_verification': True
    }

    # Example model configuration
    model_config = {
        'path': 'yolov8n.enc',  # encrypted model file path
        'device': 'cpu',
        'half': False,
        'encrypted': True,
        'encryption_key': 'test-password-123!@#'  # test key
    }

    # Key generation
    print("\n1. 测试密钥生成:")
    key_info = ModelEncryptionService.generate_secure_key()
    print(f" 生成的密钥: {key_info['key']}")
    print(f" 密钥哈希: {key_info['key_hash']}")
    print(f" 短哈希: {key_info['short_hash']}")

    # Key-strength validation
    print("\n2. 测试密钥强度验证:")
    test_key = "WeakKey123"
    valid, msg = ModelEncryptionService.validate_key_strength(test_key)
    print(f" 密钥 '{test_key}': {valid} - {msg}")

    test_key = "Strong@Password#2024!Complex"
    valid, msg = ModelEncryptionService.validate_key_strength(test_key)
    print(f" 密钥 '{test_key}': {valid} - {msg}")

    # Model validator
    print("\n3. 测试模型验证器:")
    validator = MandatoryModelValidator()

    # Non-existent file
    result = validator.verify_model_encryption("nonexistent.enc")
    print(f" 不存在的文件验证: {result['valid']} - {result.get('error', '')}")

    # Pre-task validator
    print("\n4. 测试预验证器:")
    pre_validator = PreTaskModelValidator(config)

    task_config = {
        'models': [
            {
                'path': 'test_model.enc',
                'encryption_key': 'Test@Password#2024!Complex'
            }
        ]
    }

    result = pre_validator.validate_models_before_task(task_config)
    print(f" 任务验证结果: 成功={result['success']}, 错误={result.get('error', '无')}")

    # Secure model manager
    print("\n5. 测试安全模型管理器:")
    manager = SecureModelManager(config)

    # Model info lookup
    model_info = manager.get_model_info(model_config)
    if model_info:
        print(f" 模型信息: 哈希={model_info['model_hash']}, 大小={model_info['original_size']}")
    else:
        print(f" 模型信息: 不可用")

    print("\n=== 测试完成 ===")