# model_crypto.py
"""Encrypted model storage and loading.

``ModelEncryptor`` wraps a model file in a password-protected envelope
(PBKDF2-HMAC-SHA256 key derivation + Fernet), stored as a pickled dict.
``ModelManager`` downloads model files on demand and loads them with
ultralytics YOLO, decrypting them first when the model config asks for it.
"""
import base64
import os
import pickle
import tempfile
from pathlib import Path
from typing import Optional

import requests
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from log import logger


class _PayloadUnpickler(pickle.Unpickler):
    """Unpickler for the encrypted-model envelope that refuses arbitrary globals.

    Model files may come from a remote ``download_url``, so a stock
    ``pickle.load`` on them would let a tampered file execute arbitrary code
    while being unpickled. The envelope written by
    ``ModelEncryptor.encrypt_model`` is a plain ``dict`` of ``bytes``/``int``
    values, which needs no global lookups at pickle protocol >= 3. The
    whitelist below only covers how protocols 0-2 encode ``bytes``.
    """

    _ALLOWED_GLOBALS = frozenset({
        ("_codecs", "encode"),
        ("builtins", "bytes"),
        ("builtins", "bytearray"),
        # protocol < 3 pickles written by Python 3 use Python 2 module names
        ("__builtin__", "bytes"),
        ("__builtin__", "bytearray"),
    })

    def find_class(self, module, name):
        if (module, name) in self._ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(
            f"forbidden global in model payload: {module}.{name}"
        )


def _safe_pickle_load(f):
    """Unpickle one object from binary file *f* without importing arbitrary globals."""
    return _PayloadUnpickler(f).load()


def _remove_quietly(path):
    """Best-effort delete of *path*: a missing file is ignored, other errors are logged."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"临时文件清理失败: {path} ({e})")


class ModelEncryptor:
    """Encrypt / decrypt model files with a password."""

    @staticmethod
    def generate_key(password: str, salt: Optional[bytes] = None):
        """Derive a Fernet key from *password*.

        Args:
            password: the secret the key is derived from.
            salt: salt to reuse (e.g. the one stored in an encrypted file);
                a fresh random 16-byte salt is generated when None.

        Returns:
            ``(key, salt)`` where *key* is the url-safe base64 Fernet key.
        """
        if salt is None:
            salt = os.urandom(16)
        # The KDF parameters are not stored in the envelope, so changing them
        # would make previously encrypted models undecryptable.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
        return key, salt

    @staticmethod
    def encrypt_model(model_path: str, output_path: str, password: str):
        """Encrypt the file at *model_path* and write the envelope to *output_path*.

        The output is a pickled dict ``{'salt', 'data', 'original_size'}`` where
        ``data`` is the Fernet token of the whole input file.

        Returns:
            True on success, False on any error (the error is logged).
        """
        try:
            with open(model_path, 'rb') as f:
                model_data = f.read()

            key, salt = ModelEncryptor.generate_key(password)
            encrypted_payload = {
                'salt': salt,
                'data': Fernet(key).encrypt(model_data),
                'original_size': len(model_data),
            }

            with open(output_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)

            logger.info(f"模型加密成功: {model_path} -> {output_path}")
            return True
        except Exception as e:
            logger.error(f"模型加密失败: {str(e)}")
            return False

    @staticmethod
    def decrypt_model(encrypted_path: str, password: str):
        """Decrypt *encrypted_path* into a new temporary ``.pt`` file.

        The caller owns the returned file and is responsible for deleting it.
        It contains the plaintext model.

        Returns:
            Path of the temporary file, or None on any error (the error is
            logged and no temporary file is left behind).
        """
        temp_path = None
        try:
            with open(encrypted_path, 'rb') as f:
                encrypted_payload = _safe_pickle_load(f)

            salt = encrypted_payload['salt']
            encrypted_data = encrypted_payload['data']

            key, _ = ModelEncryptor.generate_key(password, salt)
            decrypted_data = Fernet(key).decrypt(encrypted_data)

            with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp:
                temp_path = tmp.name
                tmp.write(decrypted_data)

            logger.info(f"模型解密成功: {encrypted_path}")
            return temp_path
        except InvalidToken:
            # str(InvalidToken()) is empty, so spell out the likely cause.
            logger.error(f"模型解密失败: 密钥错误或文件已损坏: {encrypted_path}")
        except Exception as e:
            logger.error(f"模型解密失败: {str(e)}")

        # Failure path: never leave a partially written plaintext file behind.
        if temp_path is not None:
            _remove_quietly(temp_path)
        return None

    @staticmethod
    def is_encrypted(model_path: str):
        """Return True if *model_path* looks like an envelope written by ``encrypt_model``."""
        try:
            with open(model_path, 'rb') as f:
                data = _safe_pickle_load(f)
        except Exception:
            # Not a (safe) pickle at all, e.g. a plain zip-format .pt file.
            return False
        return isinstance(data, dict) and 'salt' in data and 'data' in data


class ModelManager:
    """Model manager that downloads models on demand and loads encrypted ones."""

    def __init__(self, config):
        # config is kept for callers; nothing in this class reads it.
        self.config = config
        # Relative to the current working directory.
        self.models_dir = "models"
        os.makedirs(self.models_dir, exist_ok=True)

    def load_model(self, model_config):
        """Load the model described by *model_config*, decrypting it if required.

        Keys read from *model_config*: ``path`` (basename selects the file under
        ``models_dir``), ``device``, and optionally ``encrypted``,
        ``encryption_key``, ``half``, ``download_url``.

        Returns:
            The loaded YOLO model, or None on any failure (details are logged).
        """
        model_path = model_config['path']
        encrypted = model_config.get('encrypted', False)
        encryption_key = model_config.get('encryption_key')

        local_path = os.path.join(self.models_dir, os.path.basename(model_path))

        # Download the model if it is not cached locally yet.
        if not os.path.exists(local_path):
            if not self.download_model(model_config):
                return None

        if encrypted and encryption_key:
            return self._load_encrypted_model(local_path, encryption_key, model_config)
        return self._load_plain_model(local_path, model_config)

    @staticmethod
    def _build_model(weights_path, model_config):
        """Construct a YOLO model from *weights_path* and apply device / half settings."""
        from ultralytics import YOLO  # heavy optional dependency, imported lazily

        device = model_config['device']
        model = YOLO(weights_path).to(device)
        if model_config.get('half', False) and 'cuda' in device:
            model = model.half()
        return model

    def _load_plain_model(self, local_path, model_config):
        """Load an unencrypted model file; return None on failure."""
        try:
            return self._build_model(local_path, model_config)
        except Exception as e:
            logger.error(f"加载模型失败: {str(e)}")
            return None

    def _load_encrypted_model(self, local_path, encryption_key, model_config):
        """Decrypt *local_path* to a temp file, load it, and always delete the temp file."""
        if not ModelEncryptor.is_encrypted(local_path):
            logger.warning(f"模型未加密或密钥错误: {local_path}")
            return None

        decrypted_path = ModelEncryptor.decrypt_model(local_path, encryption_key)
        if not decrypted_path:
            return None

        try:
            return self._build_model(decrypted_path, model_config)
        except Exception as e:
            logger.error(f"加载解密模型失败: {str(e)}")
            return None
        finally:
            # Remove the plaintext weights whether or not loading succeeded.
            _remove_quietly(decrypted_path)

    def download_model(self, model_config):
        """Download the model file from ``model_config['download_url']``.

        The file is saved as ``<models_dir>/<basename(model_config['path'])>``.
        Data is streamed into a ``.part`` file next to it and renamed into
        place only when complete. An interrupted download therefore never
        leaves a truncated file that ``load_model`` would later treat as a
        cached model.

        Returns:
            True on success, False on any error (the error is logged).
        """
        try:
            model_path = model_config['path']
            download_url = model_config.get('download_url')

            if not download_url:
                logger.error(f"模型无下载地址: {model_path}")
                return False

            local_path = os.path.join(self.models_dir, os.path.basename(model_path))
            logger.info(f"下载模型: {download_url} -> {local_path}")

            tmp_path = local_path + '.part'
            try:
                with open(tmp_path, 'wb') as f:
                    downloaded = self._stream_download(download_url, f)
                # Same directory, so the rename replaces the target atomically.
                os.replace(tmp_path, local_path)
            except BaseException:
                _remove_quietly(tmp_path)
                raise

            logger.info(f"模型下载完成: {local_path} ({downloaded} 字节)")
            return True
        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            return False

    @staticmethod
    def _stream_download(download_url, f):
        """Stream *download_url* into the open binary file *f* and return the byte count.

        When the server sends Content-Length, progress is logged once per
        10% step.
        """
        with requests.get(download_url, stream=True, timeout=30) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            last_step = -1

            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    progress = downloaded * 100 // total_size
                    # Log each 10% step once, rather than on every chunk whose
                    # percentage happens to be a multiple of 10.
                    step = progress // 10
                    if step > last_step:
                        last_step = step
                        logger.info(f"下载进度: {progress}%")
        return downloaded