# model_upload_manager.py
import base64
import hashlib
import os
import pickle
import threading
import time

from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from log import logger


class ChunkedUploadManager:
    """Chunked upload manager: receives file chunks, merges them, and encrypts the result."""

    def __init__(self, config):
        self.config = config
        self.uploads_dir = config.get('uploads_dir', 'uploads')
        self.temp_dir = config.get('temp_dir', 'temp_uploads')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')
        # Create the required directories
        os.makedirs(self.uploads_dir, exist_ok=True)
        os.makedirs(self.temp_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)
        # In-memory upload session state, guarded by a lock
        self.upload_sessions = {}
        self.lock = threading.Lock()
        # Clean up expired upload sessions every 10 minutes
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start the background cleanup thread."""
        def cleanup():
            while True:
                try:
                    self.cleanup_expired_sessions()
                    time.sleep(600)  # Run every 10 minutes
                except Exception as e:
                    logger.error(f"Failed to clean up upload sessions: {str(e)}")
                    time.sleep(60)

        thread = threading.Thread(target=cleanup, daemon=True)
        thread.start()

    def create_upload_session(self, filename, total_size, chunk_size, encryption_key=None):
        """Create an upload session."""
        try:
            # Generate a unique session_id (MD5 is used only as a cheap identifier here)
            session_id = hashlib.md5(f"{filename}_{time.time()}".encode()).hexdigest()
            # Create a temporary directory to hold the chunks
            session_dir = os.path.join(self.temp_dir, session_id)
            os.makedirs(session_dir, exist_ok=True)
            # Total number of chunks (ceiling division)
            total_chunks = (total_size + chunk_size - 1) // chunk_size
            session_info = {
                'session_id': session_id,
                'filename': filename,
                'original_filename': filename,
                'total_size': total_size,
                'chunk_size': chunk_size,
                'total_chunks': total_chunks,
                'received_chunks': set(),
                'received_size': 0,
                'created_at': time.time(),
                'last_activity': time.time(),
                'status': 'uploading',
                'session_dir': session_dir,
                'encryption_key': encryption_key,
                'encrypted': encryption_key is not None,
                'merged_file': None,
                'encrypted_file': None
            }
            with self.lock:
                self.upload_sessions[session_id] = session_info
            logger.info(f"Created upload session: {session_id}, file: {filename}, total chunks: {total_chunks}")
            return {
                'success': True,
                'session_id': session_id,
                'total_chunks': total_chunks
            }
        except Exception as e:
            logger.error(f"Failed to create upload session: {str(e)}")
            return {'success': False, 'error': str(e)}

    def upload_chunk(self, session_id, chunk_index, chunk_data):
        """Receive a single chunk."""
        try:
            with self.lock:
                if session_id not in self.upload_sessions:
                    return {'success': False, 'error': 'Session does not exist or has expired'}
                session = self.upload_sessions[session_id]
                # Validate the chunk index
                if chunk_index < 0 or chunk_index >= session['total_chunks']:
                    return {'success': False, 'error': f'Invalid chunk index: {chunk_index}'}
                # Reject chunks that were already received
                if chunk_index in session['received_chunks']:
                    return {'success': False, 'error': f'Chunk {chunk_index} already received'}
                # Persist the chunk to the session directory
                chunk_filename = os.path.join(session['session_dir'], f'chunk_{chunk_index:06d}')
                with open(chunk_filename, 'wb') as f:
                    f.write(chunk_data)
                # Update session state
                session['received_chunks'].add(chunk_index)
                session['received_size'] += len(chunk_data)
                session['last_activity'] = time.time()
                # Compute progress
                progress = (len(session['received_chunks']) / session['total_chunks']) * 100
                logger.debug(f"Received chunk: {session_id} - chunk {chunk_index}, progress: {progress:.1f}%")
                # If every chunk has arrived, start merging in the background
                if len(session['received_chunks']) == session['total_chunks']:
                    session['status'] = 'merging'
                    threading.Thread(target=self._merge_and_encrypt, args=(session_id,), daemon=True).start()
                return {
                    'success': True,
                    'progress': progress,
                    'received_chunks': len(session['received_chunks']),
                    'total_chunks': session['total_chunks']
                }
        except Exception as e:
            logger.error(f"Failed to upload chunk: {str(e)}")
            return {'success': False, 'error': str(e)}

    def _merge_and_encrypt(self, session_id):
        """Merge the received chunks and encrypt the resulting file."""
        try:
            with self.lock:
                session = self.upload_sessions.get(session_id)
                if not session:
                    return
                session['status'] = 'merging'
            logger.info(f"Merging chunks: {session_id}, file: {session['filename']}")
            # Concatenate the chunks in index order
            merged_path = os.path.join(session['session_dir'], 'merged_file')
            with open(merged_path, 'wb') as output:
                for chunk_idx in range(session['total_chunks']):
                    chunk_file = os.path.join(session['session_dir'], f'chunk_{chunk_idx:06d}')
                    with open(chunk_file, 'rb') as input_chunk:
                        output.write(input_chunk.read())
            # Update session state
            with self.lock:
                session['merged_file'] = merged_path
                session['status'] = 'encrypting'
            logger.info(f"Chunk merge finished: {session_id}, starting encryption")
            # Encrypt the merged file
            if session['encryption_key']:
                encrypted_result = self._encrypt_in_memory(merged_path, session['encryption_key'])
                with self.lock:
                    if encrypted_result['success']:
                        session['encrypted_file'] = encrypted_result['encrypted_path']
                        session['model_hash'] = encrypted_result['model_hash']
                        session['key_hash'] = encrypted_result['key_hash']
                        session['status'] = 'completed'
                        session['last_activity'] = time.time()
                        # Remove the merged plaintext file (it is never kept on disk)
                        if os.path.exists(merged_path):
                            os.remove(merged_path)
                        # Remove the chunk files
                        self._cleanup_chunks(session['session_dir'])
                        logger.info(f"Encryption finished: {session_id}, encrypted file: {session['encrypted_file']}")
                    else:
                        session['status'] = 'failed'
                        session['error'] = encrypted_result.get('error', 'Encryption failed')
                        logger.error(f"Encryption failed: {session_id}, error: {session['error']}")
            else:
                # No encryption key was supplied: fail the upload rather than storing the model in plaintext
                with self.lock:
                    session['status'] = 'failed'
                    session['error'] = 'No encryption key provided'
                logger.warning(f"No encryption key provided: {session_id}")
        except Exception as e:
            logger.error(f"Merge/encrypt step failed: {str(e)}")
            with self.lock:
                if session_id in self.upload_sessions:
                    self.upload_sessions[session_id]['status'] = 'failed'
                    self.upload_sessions[session_id]['error'] = str(e)

    def _encrypt_in_memory(self, file_path, password):
        """Encrypt a file entirely in memory."""
        try:
            # Read the whole file into memory
            with open(file_path, 'rb') as f:
                model_data = f.read()
            # Hash of the plaintext model, used later for integrity checks
            model_hash = hashlib.sha256(model_data).hexdigest()
            # Derive a Fernet key from the password with a random salt
            salt = os.urandom(16)
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=salt,
                iterations=100000,
            )
            key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
            key_hash = hashlib.sha256(key).hexdigest()
            # Encrypt the data
            fernet = Fernet(key)
            encrypted_data = fernet.encrypt(model_data)
            # Build the encrypted payload
            encrypted_payload = {
                'salt': salt,
                'data': encrypted_data,
                'model_hash': model_hash,
                'original_size': len(model_data),
                'encrypted': True,
                'version': '2.0',  # New payload format version
                'created_at': time.time(),
                'key_hash': key_hash[:16]  # First 16 hex chars of the key hash, used for verification
            }
            # Name the encrypted file after the model hash
            encrypted_filename = f"{model_hash[:16]}.enc"
            encrypted_path = os.path.join(self.encrypted_models_dir, encrypted_filename)
            # Persist the encrypted payload
            with open(encrypted_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)
            return {
                'success': True,
                'encrypted_path': encrypted_path,
                'model_hash': model_hash,
                'key_hash': key_hash,
                'filename': encrypted_filename
            }
        except Exception as e:
            logger.error(f"In-memory encryption failed: {str(e)}")
            return {'success': False, 'error': str(e)}

    def _cleanup_chunks(self, session_dir):
        """Remove chunk files and the session directory."""
        try:
            for item in os.listdir(session_dir):
                item_path = os.path.join(session_dir, item)
                if os.path.isfile(item_path):
                    os.remove(item_path)
            os.rmdir(session_dir)
        except Exception as e:
            logger.warning(f"Failed to clean up chunk files: {str(e)}")

    def get_upload_status(self, session_id):
        """Return the current status of an upload session."""
        with self.lock:
            if session_id not in self.upload_sessions:
                return {'success': False, 'error': 'Session does not exist'}
            session = self.upload_sessions[session_id]
            # Build the response payload
            result = {
                'session_id': session_id,
                'filename': session['filename'],
                'status': session['status'],
                'progress': (len(session['received_chunks']) / session['total_chunks']) * 100,
                'received_chunks': len(session['received_chunks']),
                'total_chunks': session['total_chunks'],
                'received_size': session['received_size'],
                'total_size': session['total_size'],
                'encrypted': session['encrypted'],
                'created_at': session['created_at'],
                'last_activity': session['last_activity']
            }
            if session['status'] == 'completed':
                result['encrypted_file'] = session['encrypted_file']
                result['model_hash'] = session.get('model_hash')
                result['key_hash'] = session.get('key_hash')
                result['relative_path'] = os.path.basename(session['encrypted_file'])
            elif session['status'] == 'failed':
                result['error'] = session.get('error', 'Unknown error')
            return {'success': True, 'data': result}

    def cleanup_expired_sessions(self, expire_hours=24):
        """Remove upload sessions that have been inactive for too long."""
        try:
            current_time = time.time()
            expired_sessions = []
            with self.lock:
                for session_id, session in list(self.upload_sessions.items()):
                    # Sessions with no activity for expire_hours (24 h by default) are expired
                    if current_time - session['last_activity'] > expire_hours * 3600:
                        expired_sessions.append(session_id)
                for session_id in expired_sessions:
                    session = self.upload_sessions.pop(session_id)
                    # Remove any leftover temporary files
                    if os.path.exists(session['session_dir']):
                        try:
                            self._cleanup_chunks(session['session_dir'])
                        except Exception:
                            pass
                    logger.info(f"Removed expired session: {session_id}")
            return len(expired_sessions)
        except Exception as e:
            logger.error(f"Failed to clean up expired sessions: {str(e)}")
            return 0

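
# Illustrative decryption counterpart (a sketch, not part of the original module).
# It only shows how a consumer that knows the same password could load and verify
# the payload written by ChunkedUploadManager._encrypt_in_memory; the function
# name is hypothetical.
def _example_decrypt_model(encrypted_path, password):
    """Sketch only: decrypt a payload produced by _encrypt_in_memory."""
    with open(encrypted_path, 'rb') as f:
        payload = pickle.load(f)
    # Re-derive the Fernet key from the stored salt with the same KDF parameters
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=payload['salt'],
        iterations=100000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
    model_data = Fernet(key).decrypt(payload['data'])
    # Check integrity against the SHA-256 hash stored alongside the ciphertext
    if hashlib.sha256(model_data).hexdigest() != payload['model_hash']:
        raise ValueError('Model hash mismatch after decryption')
    return model_data
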

# Global upload manager singleton
_upload_manager = None


def get_upload_manager(config=None):
    """Return the upload manager singleton, creating it on first use."""
    global _upload_manager
    if _upload_manager is None:
        if config is None:
            config = {
                'uploads_dir': 'uploads',
                'temp_dir': 'temp_uploads',
                'encrypted_models_dir': 'encrypted_models'
            }
        _upload_manager = ChunkedUploadManager(config)
    return _upload_manager
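

# Minimal usage sketch (assumption: the chunk size, file name, and password below are
# illustrative only and not part of the original module). It drives the full flow:
# create a session, feed chunks from a local file, then poll until the background
# merge/encrypt step settles.
if __name__ == '__main__':
    manager = get_upload_manager()
    chunk_size = 4 * 1024 * 1024  # 4 MiB per chunk (illustrative)
    model_path = 'example_model.pt'  # hypothetical input file
    total_size = os.path.getsize(model_path)
    session = manager.create_upload_session(os.path.basename(model_path), total_size,
                                            chunk_size, encryption_key='change-me')
    with open(model_path, 'rb') as f:
        for index in range(session['total_chunks']):
            manager.upload_chunk(session['session_id'], index, f.read(chunk_size))
    # Merging and encryption run in a daemon thread; poll the status until it finishes.
    while True:
        status = manager.get_upload_status(session['session_id'])['data']
        if status['status'] in ('completed', 'failed'):
            break
        time.sleep(1)
    print(status)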