149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
# key_manager.py
|
|
import hashlib
import hmac
import json
import threading
import time
from datetime import datetime

from global_data import gd
from log import logger
|
|
|
|
|
|
class EncryptionKeyManager:
    """In-memory registry of per-task model encryption keys.

    Raw keys are never stored; only their SHA-256 hex digests are kept.
    State lives in two dicts, both guarded by ``self.lock``:

    * ``keys_store``: ``task_id -> {"model_<index>": model key info}``
    * ``key_history``: ``"<task_id>_<first 8 hex chars of digest>" -> usage
      record``. History records are not removed by ``cleanup_task_keys``.
    """

    def __init__(self):
        self.keys_store = {}  # task_id -> {"model_<index>": model key info}
        self.key_history = {}  # "<task_id>_<digest prefix>" -> usage record
        # Re-entrant lock: validate_task_keys() calls get_last_key_update()
        # while already holding it, and both need to be protected.
        self.lock = threading.RLock()

    @staticmethod
    def _hash_key(key):
        """Return the SHA-256 hex digest of *key*.

        ``bytes``/``bytearray`` are hashed as-is; anything else is encoded
        with ``.encode()`` first (UTF-8 for ``str``).
        """
        if not isinstance(key, (bytes, bytearray)):
            key = key.encode()
        return hashlib.sha256(key).hexdigest()

    def register_task_keys(self, task_id, model_configs):
        """Register the encryption keys found in a task's model configs.

        Args:
            task_id: Identifier of the task the keys belong to.
            model_configs: Iterable of dict-like model configs. Each may
                carry ``'encryption_key'`` and ``'path'``. Configs without a
                truthy ``'encryption_key'`` are skipped. ``None`` is treated
                as an empty iterable.

        Returns:
            dict with ``'task_id'``, ``'models'`` (one info dict per
            registered key), ``'registered_at'`` (ISO timestamp) and
            ``'key_count'``.

        Entries are merged into any existing registration for ``task_id``:
        a model index registered earlier and not present in this call is
        left in place.
        """
        with self.lock:
            task_keys = self.keys_store.setdefault(task_id, {})

            key_info = {
                'task_id': task_id,
                'models': [],
                'registered_at': datetime.now().isoformat(),
                'key_count': 0,
            }

            for i, model_cfg in enumerate(model_configs or ()):
                encryption_key = model_cfg.get('encryption_key')
                if not encryption_key:
                    continue

                # Keep only the digest; the raw key is never stored.
                key_hash = self._hash_key(encryption_key)
                model_path = model_cfg.get('path', 'unknown')

                model_key_info = {
                    'model_index': i,
                    'model_path': model_path,
                    'key_hash': key_hash,
                    'short_hash': key_hash[:16],
                    'key_provided': True,
                }

                task_keys[f'model_{i}'] = model_key_info
                # Hand the caller a copy so that mutating the returned
                # summary cannot alter the stored record.
                key_info['models'].append(dict(model_key_info))
                key_info['key_count'] += 1

                # Record key usage. Two models of the same task that share
                # a key map to the same history entry (last one wins).
                history_key = f"{task_id}_{key_hash[:8]}"
                self.key_history[history_key] = {
                    'task_id': task_id,
                    'model_index': i,
                    'key_hash': key_hash,
                    'used_at': datetime.now().isoformat(),
                    'model_path': model_path,
                }

            logger.info(f"注册任务 {task_id} 的 {key_info['key_count']} 个加密密钥")
            return key_info

    def validate_task_keys(self, task_id) -> dict:
        """Report whether ``task_id`` has a key registration.

        Returns:
            ``{'valid': False, 'error': ..., 'key_count': 0}`` when the task
            is unknown; otherwise ``{'valid': True, 'key_count': ...,
            'models': [...], 'last_updated': ...}``.
        """
        with self.lock:
            task_keys = self.keys_store.get(task_id)
            if task_keys is None:
                return {
                    'valid': False,
                    'error': '任务未注册密钥',
                    'key_count': 0,
                }

            return {
                'valid': True,
                'key_count': len(task_keys),
                'models': list(task_keys),
                'last_updated': self.get_last_key_update(task_id),
            }

    def get_last_key_update(self, task_id):
        """Return the most recent ``used_at`` timestamp recorded for a task.

        Returns ``None`` when the task has no registration in
        ``keys_store`` or no history record.
        """
        with self.lock:
            if task_id not in self.keys_store:
                return None

            # Timestamps come from datetime.now().isoformat(), so their
            # lexicographic order matches chronological order; pick the
            # latest rather than the first record encountered.
            return max(
                (
                    record['used_at']
                    for record in self.key_history.values()
                    if record['task_id'] == task_id
                ),
                default=None,
            )

    def cleanup_task_keys(self, task_id) -> bool:
        """Drop the stored keys of ``task_id``.

        Returns:
            ``True`` if the task had a registration that was removed,
            ``False`` otherwise. ``key_history`` is left untouched.
        """
        with self.lock:
            task_keys = self.keys_store.pop(task_id, None)
            if task_keys is None:
                return False

            key_count = len(task_keys)
            logger.info(f"清理任务 {task_id} 的 {key_count} 个加密密钥")
            return True

    def get_key_statistics(self) -> dict:
        """Return counts of tasks, stored keys and history records."""
        with self.lock:
            return {
                'total_tasks': len(self.keys_store),
                'total_keys': sum(len(keys) for keys in self.keys_store.values()),
                'total_history': len(self.key_history),
                'active_tasks': list(self.keys_store),
            }

    def verify_key_for_model(self, task_id, model_index, provided_key) -> dict:
        """Check ``provided_key`` against the digest stored for one model.

        Args:
            task_id: Task the model belongs to.
            model_index: Index of the model as used at registration time.
            provided_key: Candidate key (``str`` or bytes).

        Returns:
            ``{'valid': True, 'key_hash': ..., 'short_hash': ...}`` on a
            match; otherwise ``{'valid': False, 'error': ...}``, with
            truncated ``'provided_hash'``/``'stored_hash'`` added when the
            digests differ.
        """
        with self.lock:
            task_keys = self.keys_store.get(task_id)
            if task_keys is None:
                return {'valid': False, 'error': '任务未注册'}

            model_info = task_keys.get(f'model_{model_index}')
            if model_info is None:
                return {'valid': False, 'error': '模型未注册密钥'}

            # A missing or non-text key can never match; report it as a
            # failed verification instead of raising on .encode().
            if not isinstance(provided_key, (str, bytes, bytearray)):
                return {'valid': False, 'error': '未提供有效密钥'}

            provided_hash = self._hash_key(provided_key)
            stored_hash = model_info['key_hash']

            # Constant-time comparison avoids leaking, through timing, how
            # many leading digest characters matched.
            if hmac.compare_digest(provided_hash, stored_hash):
                return {
                    'valid': True,
                    'key_hash': stored_hash,
                    'short_hash': stored_hash[:16],
                }

            return {
                'valid': False,
                'error': '密钥不匹配',
                'provided_hash': provided_hash[:16],
                'stored_hash': stored_hash[:16],
            }
|
|
|
|
|
|
# Global key manager instance (module-level object created at import time).
key_manager = EncryptionKeyManager()