172 lines
6.0 KiB
Python
172 lines
6.0 KiB
Python
import requests
|
||
import json
|
||
import time
|
||
|
||
|
||
def test_create_multi_model_task():
|
||
"""测试创建多模型任务"""
|
||
url = "http://localhost:9309/api/tasks/create"
|
||
|
||
task_data = {
|
||
"rtmp_url": "rtmp://175.27.168.120:6019/live/8UUXN5400A079H",
|
||
"push_url": "rtmp://localhost:1935/live/15",
|
||
"taskname": "多模型道路监控",
|
||
"models": [
|
||
{
|
||
"path": "yolov8n_encrypted.pt",
|
||
"encryption_key":"123456",
|
||
"tags": {
|
||
"0": {"name": "汽车", "reliability": 0.4},
|
||
"1": {"name": "行人", "reliability": 0.3},
|
||
"2": {"name": "自行车", "reliability": 0.35}
|
||
},
|
||
"conf_thres": 0.3,
|
||
"imgsz": 640,
|
||
"color": [0, 255, 0], # 绿色
|
||
"device": "cuda:0",
|
||
"line_width": 2,
|
||
"enabled": True
|
||
},
|
||
{
|
||
"path": "yolov8n_encrypted.pt",
|
||
"encryption_key": "123456",
|
||
"tags": {
|
||
"0": {"name": "轿车", "reliability": 0.5},
|
||
"1": {"name": "卡车", "reliability": 0.6},
|
||
"2": {"name": "公交车", "reliability": 0.55}
|
||
},
|
||
"conf_thres": 0.35,
|
||
"imgsz": 1280,
|
||
"color": [255, 0, 0], # 红色
|
||
"device": "cuda:0",
|
||
"line_width": 2,
|
||
"enabled": True
|
||
}
|
||
]
|
||
}
|
||
|
||
try:
|
||
print("发送多模型任务创建请求...")
|
||
response = requests.post(url, json=task_data, timeout=30)
|
||
print(f"响应状态码: {response.status_code}")
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
print(f"响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}")
|
||
|
||
if result.get('status') == 'success':
|
||
task_id = result.get('task_id')
|
||
print(f"\n✅ 多模型任务创建成功!")
|
||
print(f"任务ID: {task_id}")
|
||
print(f"任务名称: {task_data['taskname']}")
|
||
print(f"加载模型数: {len(task_data['models'])}")
|
||
|
||
# 等待任务启动
|
||
time.sleep(2)
|
||
|
||
# 测试获取任务详情
|
||
status_url = f"http://localhost:9309/api/tasks/{task_id}/status"
|
||
status_response = requests.get(status_url)
|
||
if status_response.status_code == 200:
|
||
status_data = status_response.json()
|
||
print(f"\n任务状态: {status_data.get('data', {}).get('status')}")
|
||
print(
|
||
f"模型信息: {json.dumps(status_data.get('data', {}).get('models'), indent=2, ensure_ascii=False)}")
|
||
|
||
# 验证模型数量
|
||
models = status_data.get('data', {}).get('models', [])
|
||
if len(models) == 2:
|
||
print("✅ 模型数量正确")
|
||
else:
|
||
print(f"❌ 模型数量错误: 期望2,实际{len(models)}")
|
||
else:
|
||
print(f"❌ 获取任务状态失败: {status_response.text}")
|
||
else:
|
||
print(f"❌ 请求失败: {response.text}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
|
||
|
||
def test_get_all_tasks():
|
||
"""测试获取所有任务"""
|
||
url = "http://localhost:9309/api/tasks"
|
||
|
||
try:
|
||
response = requests.get(url)
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
print(f"\n获取所有任务成功:")
|
||
print(f"总任务数: {result.get('data', {}).get('total', 0)}")
|
||
print(f"活动任务数: {result.get('data', {}).get('active', 0)}")
|
||
print(f"总模型数: {result.get('data', {}).get('models_count', 0)}")
|
||
|
||
tasks = result.get('data', {}).get('tasks', [])
|
||
for i, task in enumerate(tasks):
|
||
print(f"\n任务 {i + 1}:")
|
||
print(f" ID: {task.get('task_id')}")
|
||
print(f" 名称: {task.get('config', {}).get('taskname')}")
|
||
print(f" 状态: {task.get('status')}")
|
||
print(f" 模型数: {len(task.get('models', []))}")
|
||
|
||
# 显示每个模型的详细信息
|
||
for model in task.get('models', []):
|
||
print(f" - {model.get('name')}: 阈值={model.get('conf_thres')}, 颜色={model.get('color')}")
|
||
else:
|
||
print(f"❌ 获取任务列表失败: {response.text}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
|
||
|
||
def test_error_cases():
|
||
"""测试错误情况"""
|
||
url = "http://localhost:9309/api/tasks/create"
|
||
|
||
# 测试1: 缺少rtmp_url
|
||
print("\n测试1: 缺少rtmp_url")
|
||
task_data = {
|
||
"taskname": "错误测试",
|
||
"models": [{"path": "yolov8n.pt"}]
|
||
}
|
||
response = requests.post(url, json=task_data)
|
||
print(f"响应: {response.status_code} - {response.text}")
|
||
|
||
# 测试2: 缺少models
|
||
print("\n测试2: 缺少models")
|
||
task_data = {
|
||
"rtmp_url": "rtmp://localhost:1935/live/14",
|
||
"taskname": "错误测试"
|
||
}
|
||
response = requests.post(url, json=task_data)
|
||
print(f"响应: {response.status_code} - {response.text}")
|
||
|
||
# 测试3: models不是列表
|
||
print("\n测试3: models不是列表")
|
||
task_data = {
|
||
"rtmp_url": "rtmp://localhost:1935/live/14",
|
||
"taskname": "错误测试",
|
||
"models": {"path": "yolov8n.pt"} # 应该是列表
|
||
}
|
||
response = requests.post(url, json=task_data)
|
||
print(f"响应: {response.status_code} - {response.text}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
print("=" * 50)
|
||
print("多模型专用测试脚本")
|
||
print("=" * 50)
|
||
|
||
# 测试正常情况
|
||
test_create_multi_model_task()
|
||
|
||
# 测试获取所有任务
|
||
test_get_all_tasks()
|
||
|
||
# 测试错误情况
|
||
test_error_cases()
|
||
|
||
print("\n" + "=" * 50)
|
||
print("测试完成!")
|
||
print("=" * 50)
|