# Yolov/mandatory_model_crypto.py
#
# Repository web-view header captured with the source:
#   870 lines · 31 KiB · Python  (Raw | Permalink | Blame | History)
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with
# others in your current locale. If your use case is intentional and
# legitimate, you can safely ignore this warning. Use the Escape button to
# highlight these characters.

# mandatory_model_crypto.py - 强制加密模型验证和解密模块
import base64
import hashlib
import os
import pickle
import secrets
import string
import tempfile
import threading
import time
import traceback
from pathlib import Path

import requests
import torch
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from log import logger
class MandatoryModelValidator:
    """Mandatory encrypted-model validator: only verifies and decrypts.

    An encrypted model file is a pickled ``dict`` that must contain ``salt``,
    ``data`` (a Fernet token), ``encrypted``, ``model_hash`` (SHA-256 hex
    digest of the plaintext) and ``version``; ``original_size`` is optional.
    The Fernet key is derived from the password with PBKDF2-HMAC-SHA256
    (100000 iterations, 32-byte output, urlsafe-base64 encoded).

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file,
    so these methods must only be pointed at encrypted model files that come
    from a trusted source.
    """

    # Fields every encrypted payload must carry (checked by verify_model_encryption).
    _REQUIRED_FIELDS = ('salt', 'data', 'encrypted', 'model_hash', 'version')
    # Payload format versions this module accepts.
    _SUPPORTED_VERSIONS = ('1.0', '2.0')
    # PBKDF2 iteration count; must match the value used when the file was encrypted.
    _KDF_ITERATIONS = 100000

    @staticmethod
    def _derive_key(password, salt):
        """Derive the urlsafe-base64 Fernet key for *password* and *salt*.

        *password* may be ``str`` (encoded as UTF-8) or ``bytes``.
        """
        secret = password if isinstance(password, bytes) else password.encode()
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=MandatoryModelValidator._KDF_ITERATIONS,
        )
        return base64.urlsafe_b64encode(kdf.derive(secret))

    @staticmethod
    def _load_payload(encrypted_path):
        """Unpickle and return the raw payload stored at *encrypted_path*."""
        # SECURITY: pickle.load runs code embedded in the file; trusted files only.
        with open(encrypted_path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def verify_model_encryption(encrypted_path):
        """Check that *encrypted_path* holds a well-formed encrypted payload.

        Returns:
            ``{'valid': True, 'encrypted_data': payload}`` when the file exists,
            unpickles to a dict with all required fields, is flagged
            ``encrypted`` and has a supported version; otherwise
            ``{'valid': False, 'error': <message>}``. Exceptions are reported
            through the error dict rather than raised.
        """
        try:
            if not os.path.exists(encrypted_path):
                return {'valid': False, 'error': '加密模型文件不存在'}
            data = MandatoryModelValidator._load_payload(encrypted_path)
            # Reject non-dict payloads explicitly; membership tests on lists or
            # strings would otherwise give misleading "missing field" errors.
            if not isinstance(data, dict):
                return {'valid': False, 'error': '加密数据格式无效'}
            for field in MandatoryModelValidator._REQUIRED_FIELDS:
                if field not in data:
                    return {'valid': False, 'error': f'缺少加密字段: {field}'}
            if data.get('encrypted', False) is not True:
                return {'valid': False, 'error': '模型未加密'}
            version = data.get('version')
            if version not in MandatoryModelValidator._SUPPORTED_VERSIONS:
                return {'valid': False, 'error': f'不支持的加密版本: {version}'}
            return {'valid': True, 'encrypted_data': data}
        except Exception as e:
            return {'valid': False, 'error': f'验证失败: {str(e)}'}

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Lightweight check: is the file a dict flagged ``encrypted`` with the core fields?

        Unlike verify_model_encryption, ``version`` is not required. Returns
        ``False`` (and logs a warning) on any read/unpickle error.
        """
        try:
            data = MandatoryModelValidator._load_payload(encrypted_path)
            if not isinstance(data, dict):
                return False
            if any(field not in data for field in ('salt', 'data', 'encrypted', 'model_hash')):
                return False
            return data.get('encrypted', False) is True
        except Exception as e:
            logger.warning(f"检查加密格式失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_and_verify(encrypted_path, password, verify_key=True):
        """Decrypt an encrypted model file and verify its integrity.

        Args:
            encrypted_path: path of the pickled encrypted payload.
            password: the model password (``str`` or ``bytes``).
            verify_key: accepted for interface compatibility; not used.

        Returns:
            On success ``{'success': True, 'model_hash', 'original_size',
            'decrypted_data', 'version'}`` where ``decrypted_data`` is the
            plaintext bytes; on failure ``{'success': False, 'error': <message>}``.
            A wrong password (or tampered ciphertext) yields the error
            ``'解密密钥错误'``.
        """
        try:
            verify_result = MandatoryModelValidator.verify_model_encryption(encrypted_path)
            if not verify_result['valid']:
                return {'success': False, 'error': verify_result['error']}

            payload = verify_result['encrypted_data']
            expected_hash = payload['model_hash']

            fernet = Fernet(MandatoryModelValidator._derive_key(password, payload['salt']))
            try:
                decrypted_data = fernet.decrypt(payload['data'])
            except InvalidToken:
                # Fernet raises InvalidToken with an EMPTY message for a wrong key
                # or a tampered token, so it has to be matched by type; a
                # substring test on str(e) never matches.
                return {'success': False, 'error': '解密密钥错误'}

            # Integrity check of the plaintext against the stored digest.
            actual_hash = hashlib.sha256(decrypted_data).hexdigest()
            if actual_hash != expected_hash:
                return {
                    'success': False,
                    'error': f'模型哈希不匹配: 期望{expected_hash[:16]}..., 实际{actual_hash[:16]}...'
                }

            # original_size is optional; a missing/None/0 value skips the check.
            original_size = payload.get('original_size') or 0
            if original_size > 0 and len(decrypted_data) != original_size:
                return {
                    'success': False,
                    'error': f'文件大小不匹配: 期望{original_size}, 实际{len(decrypted_data)}'
                }

            logger.info(f"模型解密验证成功: {encrypted_path}")
            logger.debug(f"模型哈希: {actual_hash[:16]}...")
            return {
                'success': True,
                'model_hash': actual_hash,
                'original_size': len(decrypted_data),
                'decrypted_data': decrypted_data,  # plaintext for the caller to load
                'version': payload.get('version', '1.0')
            }
        except Exception as e:
            return {'success': False, 'error': f'解密失败: {str(e)}'}

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt a model file - legacy alias of decrypt_and_verify."""
        return MandatoryModelValidator.decrypt_and_verify(
            encrypted_path, password, verify_key=verify_key
        )
class ModelEncryptionService:
    """Model encryption service: key generation, key-strength checks and key pairs."""

    # Special characters accepted by validate_key_strength and used by generate_secure_key.
    SPECIAL_CHARACTERS = "!@#$%^&*"

    @staticmethod
    def _has_required_character_classes(key):
        """Return True if *key* has an uppercase, a lowercase, a digit and a special char."""
        return (
            any(c.isupper() for c in key)
            and any(c.islower() for c in key)
            and any(c.isdigit() for c in key)
            and any(c in ModelEncryptionService.SPECIAL_CHARACTERS for c in key)
        )

    @staticmethod
    def generate_secure_key(length=32):
        """Generate a random key of *length* characters using ``secrets``.

        For ``length >= 4`` the key is guaranteed to contain an uppercase
        letter, a lowercase letter, a digit and a special character (so keys of
        16+ characters always pass validate_key_strength); candidates missing a
        class are discarded and redrawn, which keeps the result uniform over all
        qualifying keys. Shorter keys cannot hold all four classes and are
        returned as drawn.

        Returns:
            ``{'key': str, 'key_hash': sha256 hex of key, 'short_hash': first 16 hex chars}``.
        """
        alphabet = string.ascii_letters + string.digits + ModelEncryptionService.SPECIAL_CHARACTERS
        while True:
            key = ''.join(secrets.choice(alphabet) for _ in range(length))
            # Without this check ~3% of 32-char (~15% of 16-char) keys lacked a
            # digit or special char and were rejected by validate_key_strength.
            if length < 4 or ModelEncryptionService._has_required_character_classes(key):
                break
        key_hash = hashlib.sha256(key.encode()).hexdigest()
        return {
            'key': key,
            'key_hash': key_hash,
            'short_hash': key_hash[:16]
        }

    @staticmethod
    def validate_key_strength(key):
        """Check *key* is at least 16 chars and has upper, lower, digit and special chars.

        Returns:
            ``(ok: bool, message: str)``.
        """
        if len(key) < 16:
            return False, "密钥长度至少16位"
        if not ModelEncryptionService._has_required_character_classes(key):
            return False, "密钥应包含大小写字母、数字和特殊字符"
        return True, "密钥强度足够"

    @staticmethod
    def create_key_pair():
        """Create a master key plus a derived verification token (client-server use).

        The token is the first 20 hex chars of SHA-256(master_key + str(created_at)),
        where ``created_at`` is the same timestamp returned in the result.
        """
        master_key = secrets.token_urlsafe(32)
        # Read the clock once so created_at matches the timestamp mixed into the token.
        created_at = time.time()
        verification_token = hashlib.sha256(
            (master_key + str(created_at)).encode()
        ).hexdigest()[:20]
        return {
            'master_key': master_key,
            'verification_token': verification_token,
            'created_at': created_at
        }
class PreTaskModelValidator:
    """Up-front validation of every model a task will use (file, key strength, decryptability)."""

    def __init__(self, config=None):
        cfg = config or {}
        self.config = cfg
        self.encrypted_models_dir = cfg.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.encrypted_models_dir, exist_ok=True)
        # Validation-state cache; not read by any method of this class.
        self.validation_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 300  # seconds (5 minutes)

    def validate_models_before_task(self, task_config):
        """Validate all models listed in ``task_config['models']`` before a task is created.

        Each entry needs ``path`` and ``encryption_key``. A missing field, a
        missing file or a weak key aborts immediately with
        ``{'success': False, 'error': ...}``. Decryption failures are recorded
        per model and make the overall ``success``/``valid`` False.
        Exceptions are logged and reported as an error dict.
        """
        try:
            models = task_config.get('models', [])
            if not models:
                return {
                    'success': False,
                    'error': '任务配置中未找到模型列表'
                }

            records = []
            every_model_ok = True
            for idx, entry in enumerate(models):
                # First required field absent from this entry, if any.
                absent = next(
                    (name for name in ('path', 'encryption_key') if name not in entry),
                    None,
                )
                if absent is not None:
                    return {
                        'success': False,
                        'error': f'模型 {idx} 缺少必要字段: {absent}'
                    }

                rel_path = entry['path']
                key = entry['encryption_key']
                resolved = self._get_full_model_path(rel_path)

                if not os.path.exists(resolved):
                    return {
                        'success': False,
                        'error': f'模型文件不存在: {resolved}'
                    }

                strong_enough, strength_msg = ModelEncryptionService.validate_key_strength(key)
                if not strong_enough:
                    return {
                        'success': False,
                        'error': f'模型 {idx} 密钥强度不足: {strength_msg}'
                    }

                # The key is considered correct only if it actually decrypts the file.
                outcome = MandatoryModelValidator().decrypt_and_verify(resolved, key)
                ok = outcome['success']
                record = {
                    'model_index': idx,
                    'model_path': rel_path,
                    'full_path': resolved,
                    'file_exists': True,
                    'key_valid': ok,
                    'model_hash': outcome.get('model_hash', '')[:16] if ok else '',
                    'model_size': outcome.get('original_size', 0) if ok else 0,
                    'validation_time': time.time()
                }
                if not ok:
                    record['error'] = outcome.get('error', '验证失败')
                    every_model_ok = False
                records.append(record)

            return {
                'success': every_model_ok,
                'valid': every_model_ok,
                'total_models': len(models),
                'valid_models': len([r for r in records if r['key_valid']]),
                'validation_results': records,
                'timestamp': time.time()
            }
        except Exception as e:
            logger.error(f"预验证模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'预验证失败: {str(e)}'
            }

    def verify_single_model(self, model_path, encryption_key):
        """Validate one model file against *encryption_key*.

        Returns ``{'success': True, 'valid': True, 'model_hash', 'model_size',
        'file_path', 'timestamp'}`` on success; otherwise a dict with
        ``'success': False`` and an ``'error'`` message. Exceptions are logged
        and reported as an error dict.
        """
        try:
            resolved = self._get_full_model_path(model_path)
            if not os.path.exists(resolved):
                return {
                    'success': False,
                    'error': f'模型文件不存在: {resolved}'
                }

            strong_enough, strength_msg = ModelEncryptionService.validate_key_strength(encryption_key)
            if not strong_enough:
                return {
                    'success': False,
                    'error': f'密钥强度不足: {strength_msg}'
                }

            outcome = MandatoryModelValidator().decrypt_and_verify(resolved, encryption_key)
            if not outcome['success']:
                return {
                    'success': False,
                    'valid': False,
                    'error': outcome.get('error', '验证失败'),
                    'timestamp': time.time()
                }
            return {
                'success': True,
                'valid': True,
                'model_hash': outcome.get('model_hash', '')[:16],
                'model_size': outcome.get('original_size', 0),
                'file_path': resolved,
                'timestamp': time.time()
            }
        except Exception as e:
            logger.error(f"验证单个模型失败: {str(e)}")
            return {
                'success': False,
                'error': f'验证失败: {str(e)}'
            }

    def _get_full_model_path(self, model_path):
        """Resolve *model_path* to the file that will be checked.

        Absolute paths are used as-is. Any relative path (with or without an
        ``encrypted_models/`` prefix) is reduced to its file name and looked up
        inside ``encrypted_models_dir``.
        """
        if os.path.isabs(model_path):
            return model_path
        return os.path.join(self.encrypted_models_dir, os.path.basename(model_path))

    def get_model_info(self, model_path):
        """Return metadata of an encrypted model without checking any key, or None."""
        try:
            resolved = self._get_full_model_path(model_path)
            if not os.path.exists(resolved):
                return None

            check = MandatoryModelValidator().verify_model_encryption(resolved)
            if not check['valid']:
                return None

            payload = check['encrypted_data']
            return {
                'encrypted': True,
                'model_hash': payload.get('model_hash', '')[:16],
                'original_size': payload.get('original_size', 0),
                'version': payload.get('version', 'unknown'),
                'file_path': resolved,
                'file_size': os.path.getsize(resolved)
            }
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None
class SecureModelManager:
    """Secure model manager: locates, verifies and loads encrypted models.

    Two in-memory caches are guarded by ``cache_lock``:

    * ``verification_status`` - successful key verifications, reused for
      ``verification_expiry`` seconds;
    * ``model_cache`` - loaded YOLO models, reused for ``cache_expiry`` seconds.

    Both are keyed by ``_cache_key(local_path, encryption_key)``.
    """

    def __init__(self, config):
        self.config = config
        self.models_dir = config.get('models_dir', 'models')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)
        # Validator instances
        self.validator = MandatoryModelValidator()
        self.pre_validator = PreTaskModelValidator(config)
        # Loaded-model cache
        self.model_cache = {}
        self.cache_lock = threading.Lock()
        self.cache_expiry = 600  # seconds (10 minutes)
        # Key-verification cache
        self.verification_status = {}
        self.verification_expiry = 300  # seconds (5 minutes)
        # Per-instance secret for the keyed hash used in cache keys (see _cache_key).
        self._cache_secret = secrets.token_bytes(32)

    def _cache_key(self, local_path, encryption_key):
        """Return the cache key for a (model file, encryption key) pair.

        The key part is a keyed BLAKE2b MAC (128-bit digest, per-instance
        secret) of the encryption key. The previous 8-hex-digit MD5 prefix gave
        only 32 bits, so a different, wrong key with the same prefix was cheap
        to find and would hit a cached "valid" verification. That prefix was
        also an unsalted password hash exposed via get_cache_info().
        """
        digest = hashlib.blake2b(
            encryption_key.encode(), key=self._cache_secret, digest_size=16
        ).hexdigest()
        return f"{local_path}_{digest}"

    def ensure_model_available(self, model_config):
        """Ensure the configured model file exists and has a valid encrypted format.

        Returns ``{'available': True, 'local_path', 'model_hash'}`` (hash
        shortened to 16 chars) or ``{'available': False, 'error': <message>}``.
        """
        try:
            model_path = model_config.get('path')
            if not model_path:
                return {'available': False, 'error': '模型路径未配置'}
            local_path = self._get_local_model_path(model_path, model_config)
            if not os.path.exists(local_path):
                return {'available': False, 'error': f'模型文件不存在: {local_path}'}
            verify_result = self.validator.verify_model_encryption(local_path)
            if not verify_result['valid']:
                return {'available': False, 'error': f'模型加密格式无效: {verify_result.get("error", "未知错误")}'}
            return {
                'available': True,
                'local_path': local_path,
                'model_hash': verify_result['encrypted_data'].get('model_hash', '')[:16]
            }
        except Exception as e:
            return {'available': False, 'error': f'确保模型可用失败: {str(e)}'}

    def verify_model_key(self, model_config, encryption_key):
        """Verify that *encryption_key* decrypts the configured model (task precondition).

        Returns ``{'valid': True, 'model_hash', 'original_size', 'local_path',
        'verified_at'}`` on success (cached for ``verification_expiry``
        seconds), otherwise ``{'valid': False, 'error': <message>}``.
        """
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return {'valid': False, 'error': availability['error']}
            local_path = availability['local_path']
            model_hash_short = availability.get('model_hash', 'unknown')

            # Reuse a recent successful verification to avoid decrypting again.
            cache_key = self._cache_key(local_path, encryption_key)
            with self.cache_lock:
                cached_result = self.verification_status.get(cache_key)
                if cached_result is not None:
                    if time.time() - cached_result.get('verified_at', 0) < self.verification_expiry:
                        if cached_result['valid']:
                            logger.info(f"使用缓存的验证结果: {model_hash_short}")
                            return cached_result
                    else:
                        # Expired entry: drop it.
                        del self.verification_status[cache_key]

            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if decrypt_result['success']:
                result = {
                    'valid': True,
                    'model_hash': decrypt_result.get('model_hash', '')[:16],
                    'original_size': decrypt_result.get('original_size', 0),
                    'local_path': local_path,
                    'verified_at': time.time()
                }
                with self.cache_lock:
                    self.verification_status[cache_key] = result
                logger.info(f"模型密钥验证成功: {model_hash_short}")
                return result

            result = {'valid': False, 'error': decrypt_result.get('error', '验证失败')}
            logger.warning(f"模型密钥验证失败: {model_hash_short} - {result['error']}")
            return result
        except Exception as e:
            error_msg = f"验证模型密钥失败: {str(e)}"
            logger.error(error_msg)
            return {'valid': False, 'error': error_msg}

    @staticmethod
    def _load_yolo_from_bytes(model_bytes, model_config):
        """Load decrypted weights with ultralytics YOLO and apply device/half settings.

        The plaintext weights are written to a temporary ``.pt`` file, which is
        always removed afterwards, including when loading fails, so decrypted
        model data is not left on disk.

        Returns:
            ``(model, device)``.
        """
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.pt')
        temp_path = tmp.name
        try:
            with tmp:
                tmp.write(model_bytes)
            from ultralytics import YOLO
            model = YOLO(temp_path)
            device = model_config.get('device', 'cuda:0' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            # str(): device may also be given as a torch.device, which `in` cannot search.
            if model_config.get('half', False) and 'cuda' in str(device):
                model = model.half()
                logger.info(f"启用半精度推理: {model_config.get('path', '未知模型')}")
            return model, device
        finally:
            try:
                os.unlink(temp_path)
            except OSError as e:
                # Best effort (e.g. file still locked on Windows); must not mask the load result.
                logger.warning(f"删除临时模型文件失败: {temp_path} - {e}")

    def load_verified_model(self, model_config, encryption_key):
        """Load an encrypted model after its key has been verified.

        Returns the loaded YOLO model (served from ``model_cache`` when a fresh
        entry exists), or ``None`` on any failure; failures are logged with a
        traceback.
        """
        try:
            verify_result = self.verify_model_key(model_config, encryption_key)
            if not verify_result['valid']:
                raise ValueError(f"模型验证未通过: {verify_result.get('error', '未知错误')}")
            local_path = verify_result['local_path']

            cache_key = self._cache_key(local_path, encryption_key)
            with self.cache_lock:
                cached_info = self.model_cache.get(cache_key)
                if cached_info is not None:
                    if time.time() - cached_info.get('loaded_at', 0) < self.cache_expiry:
                        logger.info(f"使用缓存的模型: {local_path}")
                        return cached_info['model']
                    # Expired: evict so the stale model object can be released.
                    del self.model_cache[cache_key]

            # NOTE: verify_model_key may already have decrypted the file once; it
            # does not return the plaintext, so decryption is repeated here.
            decrypt_result = self.validator.decrypt_and_verify(local_path, encryption_key)
            if not decrypt_result['success']:
                raise ValueError(f"模型解密失败: {decrypt_result.get('error', '未知错误')}")

            model, device = self._load_yolo_from_bytes(decrypt_result['decrypted_data'], model_config)

            with self.cache_lock:
                self.model_cache[cache_key] = {
                    'model': model,
                    'device': device,
                    'loaded_at': time.time(),
                    'original_size': decrypt_result.get('original_size', 0)
                }
            logger.info(f"加密模型加载成功: {model_config.get('path', '未知模型')} -> {device}")
            return model
        except Exception as e:
            logger.error(f"加载已验证模型失败: {str(e)}")
            logger.error(traceback.format_exc())
            return None

    def _get_local_model_path(self, model_path, model_config):
        """Map a configured model path to the local file to use.

        An existing path is returned unchanged. Otherwise encrypted models
        (``model_config['encrypted']``) resolve to ``<encrypted_models_dir>/<name>.enc``
        and others to ``<models_dir>/<name>``.
        """
        if os.path.exists(model_path):
            return model_path
        model_filename = os.path.basename(model_path)
        if model_config.get('encrypted', False):
            if not model_filename.endswith('.enc'):
                model_filename += '.enc'
            return os.path.join(self.encrypted_models_dir, model_filename)
        return os.path.join(self.models_dir, model_filename)

    def get_model_info(self, model_config):
        """Return metadata of the configured encrypted model without checking any key, or None."""
        try:
            availability = self.ensure_model_available(model_config)
            if not availability['available']:
                return None
            local_path = availability['local_path']
            verify_result = self.validator.verify_model_encryption(local_path)
            if verify_result['valid']:
                data = verify_result['encrypted_data']
                return {
                    'encrypted': True,
                    'model_hash': data.get('model_hash', '')[:16],
                    'original_size': data.get('original_size', 0),
                    'version': data.get('version', 'unknown'),
                    'local_path': local_path,
                    'file_size': os.path.getsize(local_path)
                }
            return None
        except Exception as e:
            logger.error(f"获取模型信息失败: {str(e)}")
            return None

    def clear_cache(self):
        """Empty both the model cache and the verification cache."""
        with self.cache_lock:
            self.model_cache.clear()
            self.verification_status.clear()
        logger.info("模型缓存已清空")

    def get_cache_info(self):
        """Return sizes and keys of both caches (snapshot taken under the lock)."""
        with self.cache_lock:
            return {
                'model_cache_size': len(self.model_cache),
                'verification_cache_size': len(self.verification_status),
                'model_cache_keys': list(self.model_cache.keys()),
                'verification_cache_keys': list(self.verification_status.keys())
            }
class SimpleModelEncryptor:
    """Minimal model encryptor, intended only for tests or local conversion.

    Writes the pickled payload layout that MandatoryModelValidator reads.
    """

    @staticmethod
    def create_encrypted_payload(model_data, password):
        """Encrypt *model_data* (bytes) with a key derived from *password*.

        Simulates the .NET server's encryption process: a random 16-byte salt,
        PBKDF2-HMAC-SHA256 (100000 iterations, 32-byte output) turned into a
        urlsafe-base64 Fernet key.

        Returns:
            ``(payload_dict, fernet_key_bytes)``; the payload carries ``salt``,
            ``data``, ``model_hash``, ``original_size``, ``encrypted``,
            ``version`` ('2.0') and ``created_at``.
        """
        salt = os.urandom(16)
        derived = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        ).derive(password.encode())
        key = base64.urlsafe_b64encode(derived)
        payload = {
            'salt': salt,
            'data': Fernet(key).encrypt(model_data),
            'model_hash': hashlib.sha256(model_data).hexdigest(),
            'original_size': len(model_data),
            'encrypted': True,
            'version': '2.0',
            'created_at': time.time()
        }
        return payload, key

    @staticmethod
    def encrypt_model_file(input_path, output_path, password):
        """Encrypt the file at *input_path* and pickle the payload to *output_path*.

        Returns ``{'success': True, 'output_path', 'model_hash' (16 chars),
        'original_size', 'encrypted_size'}``, or ``{'success': False, 'error'}``
        after logging the failure.
        """
        try:
            with open(input_path, 'rb') as src:
                raw = src.read()
            payload, _key = SimpleModelEncryptor.create_encrypted_payload(raw, password)
            with open(output_path, 'wb') as dst:
                pickle.dump(payload, dst)
            logger.info(f"模型加密成功: {input_path} -> {output_path}")
            return {
                'success': True,
                'output_path': output_path,
                'model_hash': payload['model_hash'][:16],
                'original_size': len(raw),
                'encrypted_size': os.path.getsize(output_path)
            }
        except Exception as e:
            logger.error(f"加密模型文件失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    @staticmethod
    def verify_encryption(encrypted_path, password):
        """Round-trip check: decrypt *encrypted_path* with *password* and verify it."""
        checker = MandatoryModelValidator()
        return checker.decrypt_and_verify(encrypted_path, password)
# ==================== Global instances and interface functions ====================
# threading provides the Lock objects the manager/validator classes create in __init__.
import threading

# Lazily created module-level singletons (None until first requested).
# _upload_manager is not referenced anywhere in the visible code.
_model_manager = _pre_task_validator = _upload_manager = None
def get_model_manager(config=None):
    """Return the shared SecureModelManager, creating it on the first call.

    *config* is only consulted when the manager is first created; if it is
    omitted, ``config.get_default_config()`` supplies it. Creation is not
    guarded by a lock.
    """
    global _model_manager
    if _model_manager is not None:
        return _model_manager
    if config is None:
        from config import get_default_config
        config = get_default_config()
    _model_manager = SecureModelManager(config)
    return _model_manager
def get_pre_task_validator(config=None):
    """Return the shared PreTaskModelValidator, creating it on the first call.

    The validator is built from the ``'upload'`` section of the configuration
    (``config.get_default_config()`` when *config* is omitted); later calls
    ignore *config*.
    """
    # NOTE(review): unlike get_model_manager this passes only the 'upload'
    # sub-section, so 'encrypted_models_dir' must live there - confirm intended.
    global _pre_task_validator
    if _pre_task_validator is not None:
        return _pre_task_validator
    if config is None:
        from config import get_default_config
        config = get_default_config()
    _pre_task_validator = PreTaskModelValidator(config.get('upload', {}))
    return _pre_task_validator
def validate_models_before_task(task_config):
    """Public entry point: validate every model of *task_config* with the shared pre-task validator."""
    return get_pre_task_validator().validate_models_before_task(task_config)
def verify_single_model_api(model_path, encryption_key):
    """Public entry point: validate one model file against *encryption_key*."""
    return get_pre_task_validator().verify_single_model(model_path, encryption_key)
def verify_model_key_api(model_config, encryption_key):
    """Public entry point: verify *encryption_key* for the model described by *model_config*."""
    return get_model_manager().verify_model_key(model_config, encryption_key)
def load_verified_model_api(model_config, encryption_key):
    """Public entry point: verify the key, then load the model (None on failure)."""
    return get_model_manager().load_verified_model(model_config, encryption_key)
def generate_secure_key_api(length=32):
    """Public entry point: generate a random key of *length* characters plus its SHA-256 hashes."""
    key_info = ModelEncryptionService.generate_secure_key(length)
    return key_info
def validate_key_strength_api(key):
    """Public entry point: return ``(ok, message)`` for the strength of *key*."""
    verdict = ModelEncryptionService.validate_key_strength(key)
    return verdict
def test_decryption_api(encrypted_path, password):
    """Public entry point: try decrypting *encrypted_path* with *password* and verify integrity."""
    return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)
# ==================== Backward-compatible interfaces ====================
class MandatoryModelEncryptor:
    """Legacy facade kept for backward compatibility; every method delegates."""

    @staticmethod
    def encrypt_model(input_path, output_path, password, require_encryption=True):
        """Encrypt a model file (legacy API). *require_encryption* is accepted but ignored."""
        outcome = SimpleModelEncryptor.encrypt_model_file(input_path, output_path, password)
        return outcome

    @staticmethod
    def decrypt_model(encrypted_path, password, verify_key=False):
        """Decrypt and verify a model file (legacy API)."""
        return MandatoryModelValidator().decrypt_model(encrypted_path, password, verify_key=verify_key)

    @staticmethod
    def is_properly_encrypted(encrypted_path):
        """Return whether the file looks like a correctly encrypted payload (legacy API)."""
        return MandatoryModelValidator.is_properly_encrypted(encrypted_path)

    @staticmethod
    def generate_secure_key():
        """Generate a default-length secure key (legacy API)."""
        return ModelEncryptionService.generate_secure_key()

    @staticmethod
    def verify_model_key(encrypted_path, password):
        """Check *password* by decrypting and verifying the file (legacy API)."""
        return MandatoryModelValidator().decrypt_and_verify(encrypted_path, password)
# ==================== Self-test ====================
def _run_self_test():
    """Smoke-test the module's components and print the outcome of each step.

    Constructing PreTaskModelValidator / SecureModelManager creates
    ./models and ./encrypted_models if they do not exist.
    """
    print("=== 强制加密模型验证模块测试 ===")

    # Example manager/validator configuration.
    config = {
        'models_dir': './models',
        'encrypted_models_dir': './encrypted_models',
        'cache_verification': True
    }
    # Example model configuration.
    model_config = {
        'path': 'yolov8n.enc',  # encrypted model file path
        'device': 'cpu',
        'half': False,
        'encrypted': True,
        'encryption_key': 'test-password-123!@#'  # test key
    }

    # Step 1: key generation.
    print("\n1. 测试密钥生成:")
    key_info = ModelEncryptionService.generate_secure_key()
    print(f" 生成的密钥: {key_info['key']}")
    print(f" 密钥哈希: {key_info['key_hash']}")
    print(f" 短哈希: {key_info['short_hash']}")

    # Step 2: key-strength checks on a weak and a strong key.
    print("\n2. 测试密钥强度验证:")
    for test_key in ("WeakKey123", "Strong@Password#2024!Complex"):
        valid, msg = ModelEncryptionService.validate_key_strength(test_key)
        print(f" 密钥 '{test_key}': {valid} - {msg}")

    # Step 3: format validation of a file that does not exist.
    print("\n3. 测试模型验证器:")
    result = MandatoryModelValidator().verify_model_encryption("nonexistent.enc")
    print(f" 不存在的文件验证: {result['valid']} - {result.get('error', '')}")

    # Step 4: pre-task validation.
    print("\n4. 测试预验证器:")
    pre_validator = PreTaskModelValidator(config)
    task_config = {
        'models': [
            {
                'path': 'test_model.enc',
                'encryption_key': 'Test@Password#2024!Complex'
            }
        ]
    }
    result = pre_validator.validate_models_before_task(task_config)
    print(f" 任务验证结果: 成功={result['success']}, 错误={result.get('error', '')}")

    # Step 5: secure model manager metadata lookup.
    print("\n5. 测试安全模型管理器:")
    model_info = SecureModelManager(config).get_model_info(model_config)
    if not model_info:
        print(f" 模型信息: 不可用")
    else:
        print(f" 模型信息: 哈希={model_info['model_hash']}, 大小={model_info['original_size']}")

    print("\n=== 测试完成 ===")


if __name__ == "__main__":
    _run_self_test()