359 lines
14 KiB
Python
359 lines
14 KiB
Python
# model_upload_manager.py
|
||
import os
|
||
import hashlib
|
||
import json
|
||
import tempfile
|
||
import time
|
||
import threading
|
||
from pathlib import Path
|
||
from log import logger
|
||
from cryptography.fernet import Fernet
|
||
from cryptography.hazmat.primitives import hashes
|
||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||
import base64
|
||
import pickle
|
||
|
||
|
||
class ChunkedUploadManager:
    """Chunked-upload manager.

    Tracks multi-part upload sessions, persists incoming chunks under a
    per-session temp directory, and — once every chunk has arrived —
    merges them and encrypts the result with a password-derived Fernet
    key.  All session state lives in ``self.upload_sessions`` and is
    guarded by ``self.lock``.
    """

    def __init__(self, config):
        """Initialize the manager and create its working directories.

        Args:
            config: dict-like configuration; optional keys
                ``uploads_dir``, ``temp_dir`` and
                ``encrypted_models_dir``.
        """
        self.config = config
        self.uploads_dir = config.get('uploads_dir', 'uploads')
        self.temp_dir = config.get('temp_dir', 'temp_uploads')
        self.encrypted_models_dir = config.get('encrypted_models_dir', 'encrypted_models')

        # Create the required directories up front.
        os.makedirs(self.uploads_dir, exist_ok=True)
        os.makedirs(self.temp_dir, exist_ok=True)
        os.makedirs(self.encrypted_models_dir, exist_ok=True)

        # session_id -> session-info dict; always accessed under self.lock.
        self.upload_sessions = {}
        self.lock = threading.Lock()

        # Background sweep of expired sessions (every 10 minutes).
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start a daemon thread that periodically purges stale sessions."""

        def cleanup():
            while True:
                try:
                    self.cleanup_expired_sessions()
                    time.sleep(600)  # sweep every 10 minutes
                except Exception as e:
                    logger.error(f"清理上传会话失败: {str(e)}")
                    time.sleep(60)  # short back-off after a failure

        thread = threading.Thread(target=cleanup, daemon=True)
        thread.start()

    def create_upload_session(self, filename, total_size, chunk_size, encryption_key=None):
        """Create an upload session and its chunk directory.

        Args:
            filename: original client-side file name.
            total_size: expected total file size in bytes.
            chunk_size: size of each chunk in bytes; must be positive.
            encryption_key: password later used to encrypt the merged
                file.  If omitted, the session fails at the merge step.

        Returns:
            ``{'success': True, 'session_id': ..., 'total_chunks': ...}``
            on success, otherwise ``{'success': False, 'error': ...}``.
        """
        try:
            # Reject a non-positive chunk size early: it would otherwise
            # raise ZeroDivisionError in the ceiling division below.
            if chunk_size <= 0:
                return {'success': False, 'error': f'无效的分片大小: {chunk_size}'}

            # Unique session id: filename + timestamp + random bytes, so
            # two sessions created within the same clock tick cannot
            # collide (the old code hashed a constant placeholder plus the
            # time only).  Truncated to 32 hex chars to keep the
            # historical md5-sized id format.
            session_id = hashlib.sha256(
                f"{filename}_{time.time()}_{os.urandom(8).hex()}".encode()
            ).hexdigest()[:32]

            # Per-session directory holding the raw chunk files.
            session_dir = os.path.join(self.temp_dir, session_id)
            os.makedirs(session_dir, exist_ok=True)

            # Ceiling division; force at least one chunk so that empty
            # files (total_size == 0) can still complete and progress
            # computations never divide by zero.
            total_chunks = max(1, (total_size + chunk_size - 1) // chunk_size)

            session_info = {
                'session_id': session_id,
                'filename': filename,
                'original_filename': filename,
                'total_size': total_size,
                'chunk_size': chunk_size,
                'total_chunks': total_chunks,
                'received_chunks': set(),
                'received_size': 0,
                'created_at': time.time(),
                'last_activity': time.time(),
                'status': 'uploading',
                'session_dir': session_dir,
                'encryption_key': encryption_key,
                'encrypted': encryption_key is not None,
                'merged_file': None,
                'encrypted_file': None
            }

            with self.lock:
                self.upload_sessions[session_id] = session_info

            logger.info(f"创建上传会话: {session_id}, 文件: {filename}, 总分片: {total_chunks}")

            return {
                'success': True,
                'session_id': session_id,
                'total_chunks': total_chunks
            }

        except Exception as e:
            logger.error(f"创建上传会话失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    def upload_chunk(self, session_id, chunk_index, chunk_data):
        """Store one chunk; start merge+encrypt once all have arrived.

        Args:
            session_id: id returned by :meth:`create_upload_session`.
            chunk_index: zero-based chunk index.
            chunk_data: raw chunk bytes.

        Returns:
            On success a dict with ``progress`` (percent),
            ``received_chunks`` and ``total_chunks``; otherwise
            ``{'success': False, 'error': ...}``.
        """
        try:
            with self.lock:
                if session_id not in self.upload_sessions:
                    return {'success': False, 'error': '会话不存在或已过期'}

                session = self.upload_sessions[session_id]

                # Validate the chunk index against the session.
                if chunk_index < 0 or chunk_index >= session['total_chunks']:
                    return {'success': False, 'error': f'无效的分片索引: {chunk_index}'}

                # Reject duplicates (callers must not resend a chunk).
                if chunk_index in session['received_chunks']:
                    return {'success': False, 'error': f'分片 {chunk_index} 已接收'}

                # Persist the chunk; zero-padded name keeps merge order.
                chunk_filename = os.path.join(session['session_dir'], f'chunk_{chunk_index:06d}')
                with open(chunk_filename, 'wb') as f:
                    f.write(chunk_data)

                # Update session bookkeeping.
                session['received_chunks'].add(chunk_index)
                session['received_size'] += len(chunk_data)
                session['last_activity'] = time.time()

                progress = (len(session['received_chunks']) / session['total_chunks']) * 100

                logger.debug(f"上传分片: {session_id} - 分片 {chunk_index}, 进度: {progress:.1f}%")

                # All chunks in -> merge and encrypt on a worker thread so
                # this request returns immediately.
                if len(session['received_chunks']) == session['total_chunks']:
                    session['status'] = 'merging'
                    threading.Thread(target=self._merge_and_encrypt, args=(session_id,), daemon=True).start()

                return {
                    'success': True,
                    'progress': progress,
                    'received_chunks': len(session['received_chunks']),
                    'total_chunks': session['total_chunks']
                }

        except Exception as e:
            logger.error(f"上传分片失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    def _merge_and_encrypt(self, session_id):
        """Worker: concatenate all chunks, then encrypt the result.

        Runs on a background thread started by :meth:`upload_chunk`.  On
        success the session status becomes ``completed``; on any failure
        it becomes ``failed`` with an ``error`` message.
        """
        try:
            with self.lock:
                session = self.upload_sessions.get(session_id)
                if not session:
                    return

                session['status'] = 'merging'
                logger.info(f"开始合并分片: {session_id}, 文件: {session['filename']}")

            # Concatenate the chunks in index order.  Done outside the
            # lock: this is pure disk I/O and may be slow for large files.
            merged_path = os.path.join(session['session_dir'], 'merged_file')
            with open(merged_path, 'wb') as output:
                for chunk_idx in range(session['total_chunks']):
                    chunk_file = os.path.join(session['session_dir'], f'chunk_{chunk_idx:06d}')
                    with open(chunk_file, 'rb') as input_chunk:
                        output.write(input_chunk.read())

            with self.lock:
                session['merged_file'] = merged_path
                session['status'] = 'encrypting'
                logger.info(f"分片合并完成: {session_id}, 开始加密")

            if session['encryption_key']:
                encrypted_result = self._encrypt_in_memory(merged_path, session['encryption_key'])

                with self.lock:
                    if encrypted_result['success']:
                        session['encrypted_file'] = encrypted_result['encrypted_path']
                        session['model_hash'] = encrypted_result['model_hash']
                        session['key_hash'] = encrypted_result['key_hash']
                        session['status'] = 'completed'
                        session['last_activity'] = time.time()

                        # Remove the merged plaintext — it is never kept
                        # on disk once the encrypted copy exists.
                        if os.path.exists(merged_path):
                            os.remove(merged_path)

                        # Remove the now-useless chunk files.
                        self._cleanup_chunks(session['session_dir'])

                        logger.info(f"文件加密完成: {session_id}, 加密文件: {session['encrypted_file']}")
                    else:
                        session['status'] = 'failed'
                        session['error'] = encrypted_result.get('error', '加密失败')
                        logger.error(f"文件加密失败: {session_id}, 错误: {session['error']}")
            else:
                # No key supplied: fail the session.  Also delete the
                # merged plaintext right away instead of leaving it on
                # disk until the 24h expiry sweep.
                with self.lock:
                    session['status'] = 'failed'
                    session['error'] = '未提供加密密钥'
                    logger.warning(f"未提供加密密钥: {session_id}")
                if os.path.exists(merged_path):
                    os.remove(merged_path)

        except Exception as e:
            logger.error(f"合并加密过程失败: {str(e)}")
            with self.lock:
                if session_id in self.upload_sessions:
                    self.upload_sessions[session_id]['status'] = 'failed'
                    self.upload_sessions[session_id]['error'] = str(e)

    def _encrypt_in_memory(self, file_path, password):
        """Encrypt ``file_path`` with a key derived from ``password``.

        The whole file is read into memory, so peak RAM is roughly twice
        the file size (plaintext + ciphertext).

        Returns:
            dict with ``success`` plus ``encrypted_path``, ``model_hash``,
            ``key_hash`` and ``filename`` on success, or ``error``.
        """
        try:
            # Load the merged file into memory.
            with open(file_path, 'rb') as f:
                model_data = f.read()

            # SHA-256 of the plaintext, used for naming and integrity.
            model_hash = hashlib.sha256(model_data).hexdigest()

            # Derive a Fernet key from the password via PBKDF2-HMAC-SHA256
            # with a fresh random salt per file.
            salt = os.urandom(16)
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=salt,
                iterations=100000,
            )
            key = base64.urlsafe_b64encode(kdf.derive(password.encode()))
            key_hash = hashlib.sha256(key).hexdigest()

            # Encrypt the payload.
            fernet = Fernet(key)
            encrypted_data = fernet.encrypt(model_data)

            # On-disk envelope around the ciphertext.
            encrypted_payload = {
                'salt': salt,
                'data': encrypted_data,
                'model_hash': model_hash,
                'original_size': len(model_data),
                'encrypted': True,
                'version': '2.0',  # on-disk format marker
                'created_at': time.time(),
                'key_hash': key_hash[:16]  # first 16 hex chars, used to verify the key
            }

            # Encrypted file is named after the plaintext hash.
            encrypted_filename = f"{model_hash[:16]}.enc"
            encrypted_path = os.path.join(self.encrypted_models_dir, encrypted_filename)

            # NOTE(security): the envelope is serialized with pickle.
            # Whatever loads these files must only ever unpickle files
            # this service wrote itself — pickle.load on untrusted data
            # executes arbitrary code.
            with open(encrypted_path, 'wb') as f:
                pickle.dump(encrypted_payload, f)

            return {
                'success': True,
                'encrypted_path': encrypted_path,
                'model_hash': model_hash,
                'key_hash': key_hash,
                'filename': encrypted_filename
            }

        except Exception as e:
            logger.error(f"内存加密失败: {str(e)}")
            return {'success': False, 'error': str(e)}

    def _cleanup_chunks(self, session_dir):
        """Delete every file in ``session_dir``, then the directory itself.

        Best-effort: failures are logged, never raised.
        """
        try:
            for item in os.listdir(session_dir):
                item_path = os.path.join(session_dir, item)
                if os.path.isfile(item_path):
                    os.remove(item_path)
            os.rmdir(session_dir)
        except Exception as e:
            logger.warning(f"清理分片文件失败: {str(e)}")

    def get_upload_status(self, session_id):
        """Return a status snapshot for ``session_id``.

        Returns:
            ``{'success': True, 'data': {...}}`` with progress and
            bookkeeping fields, or ``{'success': False, 'error': ...}``
            when the session is unknown.
        """
        with self.lock:
            if session_id not in self.upload_sessions:
                return {'success': False, 'error': '会话不存在'}

            session = self.upload_sessions[session_id]

            # Snapshot of the session for the caller.
            result = {
                'session_id': session_id,
                'filename': session['filename'],
                'status': session['status'],
                'progress': (len(session['received_chunks']) / session['total_chunks']) * 100,
                'received_chunks': len(session['received_chunks']),
                'total_chunks': session['total_chunks'],
                'received_size': session['received_size'],
                'total_size': session['total_size'],
                'encrypted': session['encrypted'],
                'created_at': session['created_at'],
                'last_activity': session['last_activity']
            }

            if session['status'] == 'completed':
                result['encrypted_file'] = session['encrypted_file']
                result['model_hash'] = session.get('model_hash')
                result['key_hash'] = session.get('key_hash')
                result['relative_path'] = os.path.basename(session['encrypted_file'])
            elif session['status'] == 'failed':
                result['error'] = session.get('error', '未知错误')

            return {'success': True, 'data': result}

    def cleanup_expired_sessions(self, expire_hours=24):
        """Drop sessions idle for more than ``expire_hours`` hours.

        Returns:
            Number of sessions removed (0 on error).
        """
        try:
            current_time = time.time()
            expired_sessions = []

            with self.lock:
                # Collect ids first so the dict is never mutated while
                # being iterated.
                for session_id, session in list(self.upload_sessions.items()):
                    if current_time - session['last_activity'] > expire_hours * 3600:
                        expired_sessions.append(session_id)

                for session_id in expired_sessions:
                    session = self.upload_sessions.pop(session_id)
                    # Best-effort removal of the session's temp files.
                    if os.path.exists(session['session_dir']):
                        try:
                            self._cleanup_chunks(session['session_dir'])
                        except Exception:
                            # Was a bare except; narrowed so SystemExit /
                            # KeyboardInterrupt are no longer swallowed.
                            pass
                    logger.info(f"清理过期会话: {session_id}")

            return len(expired_sessions)

        except Exception as e:
            logger.error(f"清理过期会话失败: {str(e)}")
            return 0
|
||
|
||
|
||
# Global upload-manager singleton; creation is guarded by the lock below
# so concurrent first calls cannot build two managers (each manager also
# spawns its own cleanup thread, so duplicates would leak threads).
_upload_manager = None
_upload_manager_lock = threading.Lock()


def get_upload_manager(config=None):
    """Return the process-wide :class:`ChunkedUploadManager` singleton.

    Thread-safe lazy initialization (double-checked locking).

    Args:
        config: optional configuration dict; only consulted on the call
            that actually creates the singleton.  Defaults to the
            standard directory layout.

    Returns:
        The shared ChunkedUploadManager instance.
    """
    global _upload_manager
    if _upload_manager is None:
        with _upload_manager_lock:
            # Re-check under the lock: another thread may have won the race.
            if _upload_manager is None:
                if config is None:
                    config = {
                        'uploads_dir': 'uploads',
                        'temp_dir': 'temp_uploads',
                        'encrypted_models_dir': 'encrypted_models'
                    }
                _upload_manager = ChunkedUploadManager(config)
    return _upload_manager
|
||
|