# Flask + Socket.IO control server for an RTMP object-detection pipeline.
# Third-party
import torch
from flask import Flask, jsonify, request, render_template
from flask_socketio import SocketIO

# Project-local
import global_data as gd
from config import get_default_config
from detectionThread import DetectionThread
from log import logger
from mapping_cn import class_mapping_cn
# ---- Flask application setup ----
app = Flask(__name__, static_url_path='/static')

# Socket.IO layered on the Flask app. Threading async mode lets the
# detection worker thread emit events; the 5 MiB buffer accommodates
# large WebSocket payloads (e.g. encoded frames).
socketio = SocketIO(
    app,
    cors_allowed_origins="*",
    async_mode='threading',
    allow_unsafe_werkzeug=True,
    max_http_buffer_size=5 * 1024 * 1024,
)
# ======================= Flask routes =======================
@app.route('/', methods=['GET'])
def main():
    """Serve the front-end video player page."""
    return render_template("flv2.html")
@app.route('/start_detection', methods=['POST'])
def start_detection():
    """Start the detection worker thread.

    Builds a config from defaults, applies any overrides supplied in the
    JSON request body, then spawns a DetectionThread and records it in
    global state.

    Returns:
        200 with a success payload, or 400 if detection is already running.
    """
    if gd.get_value('detection_active'):
        return jsonify({"status": "error", "message": "检测已在运行"}), 400

    config = get_default_config()
    config['socketIO'] = socketio

    # Apply per-request overrides, if a JSON body was provided.
    body = request.json
    if body:
        # RTMP pull (source) address
        if 'rtmp_url' in body:
            config['rtmp']['url'] = body['rtmp_url']
        # Push (output stream) address
        if 'push_url' in body and body['push_url'] is not None:
            config['push']['url'] = body['push_url']
        # MinIO folder name
        if 'taskname' in body:
            config['task']['taskname'] = body['taskname']

        # Class-tag mapping.
        # BUG FIX: the original tested `body['tag'] is not {}`, which is
        # always True because a `{}` literal is a fresh object; use
        # truthiness so a missing or empty tag falls back to the default
        # Chinese class mapping, as the else-branch clearly intended.
        if body.get('tag'):
            config['task']['tag'] = body['tag']
        else:
            config['task']['tag'] = class_mapping_cn

        if 'taskid' in body:
            config['task']['taskid'] = body['taskid']

        # Performance tuning parameters
        if 'imgsz' in body:
            # Clamp inference image size to a sane [128, 1920] range.
            config['predict']['imgsz'] = max(128, min(1920, body['imgsz']))
        if 'frame_skip' in body:
            config['predict']['frame_skip'] = body['frame_skip']
        if 'model_name' in body:
            config['model']['path'] = body['model_name']
        if 'aiid' in body:
            # BUG FIX: the original checked for 'aiid' but then read
            # body['AlgoId'], raising KeyError whenever 'aiid' was present
            # without 'AlgoId'. Read the key that was actually checked.
            # NOTE(review): if clients really send 'AlgoId', change the
            # membership test instead — confirm against the caller.
            config['task']['aiid'] = body['aiid']
        if 'device' in body:
            # BUG FIX: the original condition `== "cuda:0" or "cpu"` is
            # always truthy (the literal "cpu" short-circuits it); restrict
            # to the two supported device strings and otherwise auto-select.
            if body['device'] in ("cuda:0", "cpu"):
                config['predict']['device'] = body['device']
            else:
                config['predict']['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Create and start the worker thread, then mark detection as active.
    detection_thread = DetectionThread(config)
    gd.set_value('detection_thread', detection_thread)
    detection_thread.start()
    gd.set_value('detection_active', True)
    return jsonify({
        "status": "success",
        "message": "目标检测已启动"
    })
@app.route('/stop_detection', methods=['POST'])
def stop_detection():
    """Stop the running detection thread and reset the global state.

    Returns:
        200 with a success payload, or 400 if detection is not running.
    """
    active = gd.get_value('detection_active')
    thread = gd.get_value('detection_thread')
    if not (active and thread):
        return jsonify({"status": "error", "message": "检测未运行"}), 400

    # Signal the worker to stop and wait at most 3 seconds for it to exit.
    thread.stop()
    thread.join(3.0)

    if thread.is_alive():
        logger.warning("检测线程未在规定时间停止")
    else:
        logger.info("检测线程已停止")

    # State is cleared even when the thread overran its deadline,
    # matching the original behavior.
    gd.set_value('detection_active', False)
    gd.set_value('detection_thread', None)
    return jsonify({"status": "success", "message": "目标检测已停止"})
@app.route('/status', methods=['GET'])
def get_status():
    """Report whether detection is running, plus live runtime metrics."""
    thread = gd.get_value('detection_thread')
    if not (gd.get_value('detection_active') and thread):
        return jsonify({"active": False})

    payload = {
        "active": True,
        "fps": round(thread.last_fps, 1),  # smoothed FPS value
        "frame_count": thread.frame_count,
        "detections_count": thread.detections_count,
        "rtmp_url": thread.rtmp_url,
        "reconnect_attempts": thread.reconnect_attempts,
    }
    if torch.cuda.is_available():
        # Currently allocated GPU memory, in MiB.
        payload['gpu_memory'] = torch.cuda.memory_allocated() // (1024 * 1024)
    return jsonify(payload)
# WebSocket events
@socketio.on('connect')
def handle_connect():
    """Log each new Socket.IO client connection with its session id."""
    logger.info(f"Socket客户端已连接: {request.sid}")
@socketio.on('disconnect')
def handle_disconnect():
    """Log each Socket.IO client disconnect with its session id."""
    logger.info(f"Socket客户端断开: {request.sid}")