# Source code for pyroengine.engine

# Copyright (C) 2022-2024, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import glob
import io
import json
import logging
import os
import shutil
import signal
import time
from collections import deque
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
from PIL import Image
from pyroclient import client
from requests.exceptions import ConnectionError
from requests.models import Response

from pyroengine.utils import box_iou, nms

from .vision import Classifier

__all__ = ["Engine"]

logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)


def handler(signum, frame):
    """Signal handler that aborts a heartbeat which is taking too long.

    Args:
        signum: number of the received signal (unused)
        frame: stack frame that was interrupted by the signal (unused)

    Raises:
        TimeoutError: always, so that the interrupted heartbeat call is abandoned
    """
    raise TimeoutError("Heartbeat check timed out")


def heartbeat_with_timeout(api_instance, cam_id, timeout=1):
    """Send a heartbeat for a camera, giving up after ``timeout`` seconds.

    The timeout relies on ``SIGALRM``. Timeouts and connection errors are logged as
    warnings and swallowed, so a failing heartbeat never interrupts the caller.

    NOTE(review): ``signal.signal`` / ``signal.alarm`` are only usable on Unix and from
    the main thread of the main interpreter - confirm this is how the engine is run.

    Args:
        api_instance: object exposing a ``heartbeat(cam_id)`` method
        cam_id: identifier of the camera to ping
        timeout: number of whole seconds to wait before giving up
    """
    # Remember whatever handler was installed before so it can be put back afterwards
    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        # Arm the alarm inside the try block so that a timeout firing at any point
        # from here on is caught below and the cleanup always runs
        signal.alarm(timeout)
        api_instance.heartbeat(cam_id)
    except TimeoutError:
        logging.warning(f"Heartbeat check timed out for {cam_id}")
    except ConnectionError:
        logging.warning(f"Unable to reach the pyro-api with {cam_id}")
    finally:
        # Cancel the pending alarm first, then restore the previous handler
        signal.alarm(0)
        # ``previous_handler`` is None when the old handler was not installed from Python;
        # it cannot be re-installed through ``signal.signal`` in that case
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


class Engine:
    """This implements an object to manage predictions and API interactions for wildfire alerts.

    Args:
        model_path: path of the model, passed as ``model_path`` to ``Classifier``
        conf_thresh: confidence threshold to send an alert
        api_url: url of the pyronear API
        cam_creds: api credentials for each camera, the dictionary should be as the one in the example
        latitude: device latitude
        longitude: device longitude
        nb_consecutive_frames: number of past predictions kept per camera and aggregated with the
            current one to compute the confidence
        frame_size: Resize frame to frame_size before sending it to the api in order to save bandwidth (H, W)
        cache_backup_period: number of minutes between each cache backup to disk
        frame_saving_period: Send one frame over N to the api for our dataset
        cache_size: maximum number of alerts to save in cache
        cache_folder: existing folder used for the pending alerts cache, the local backups and
            the ``occlusion_masks/<cam_id>.jpg`` files
        backup_size: maximum size (in MB) of a local backup folder before its oldest days are removed
        jpeg_quality: JPEG quality used to encode the frames uploaded to the api
        day_time_strategy: strategy to define if it's daytime
        save_captured_frames: save all captured frames for debugging
        kwargs: extra keyword arguments (accepted but not forwarded by this constructor)

    Examples:
        >>> from pyroengine import Engine
        >>> cam_creds ={
        >>> "cam_id_1": {'login':'log1', 'password':'pwd1'},
        >>> "cam_id_2": {'login':'log2', 'password':'pwd2'},
        >>> }
        >>> pyroEngine = Engine(None, 0.25, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        conf_thresh: float = 0.15,
        api_url: Optional[str] = None,
        cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None,
        nb_consecutive_frames: int = 4,
        frame_size: Optional[Tuple[int, int]] = None,
        cache_backup_period: int = 60,
        frame_saving_period: Optional[int] = None,
        cache_size: int = 100,
        cache_folder: str = "data/",
        backup_size: int = 30,
        jpeg_quality: int = 80,
        day_time_strategy: Optional[str] = None,
        save_captured_frames: Optional[bool] = False,
        **kwargs: Any,
    ) -> None:
        """Init engine"""
        # Engine Setup
        self.model = Classifier(model_path=model_path, conf=0.05)
        self.conf_thresh = conf_thresh

        # API Setup
        if isinstance(api_url, str):
            assert isinstance(latitude, float) and isinstance(longitude, float) and isinstance(cam_creds, dict)
        self.latitude = latitude
        self.longitude = longitude
        self.api_client: Dict[str, Any] = {}
        if isinstance(api_url, str) and isinstance(cam_creds, dict):
            # Instantiate clients for each camera
            for _id, vals in cam_creds.items():
                self.api_client[_id] = client.Client(api_url, vals["login"], vals["password"])

        # Cache & relaxation
        self.frame_saving_period = frame_saving_period
        self.nb_consecutive_frames = nb_consecutive_frames
        self.frame_size = frame_size
        self.jpeg_quality = jpeg_quality
        self.cache_backup_period = cache_backup_period
        self.day_time_strategy = day_time_strategy
        self.save_captured_frames = save_captured_frames

        # Local backup
        self._backup_size = backup_size
        self._cache = Path(cache_folder)  # with Docker, the path has to be a bind volume

        # Var initialization ("-1" is the key used when no camera id is provided)
        self._states: Dict[str, Dict[str, Any]] = {
            "-1": {"last_predictions": deque([], self.nb_consecutive_frames), "ongoing": False},
        }
        self.occlusion_masks: Dict[str, Optional[np.ndarray]] = {"-1": None}
        if isinstance(cam_creds, dict):
            for cam_id in cam_creds:
                self._states[cam_id] = {
                    "last_predictions": deque([], self.nb_consecutive_frames),
                    "ongoing": False,
                }
                # Build the mask path from the cache folder instead of concatenating strings
                mask_file = self._cache.joinpath("occlusion_masks", f"{cam_id}.jpg")
                if mask_file.is_file():
                    self.occlusion_masks[cam_id] = np.array(Image.open(mask_file).convert("L"))
                else:
                    self.occlusion_masks[cam_id] = None

        # Restore pending alerts cache
        self._alerts: deque = deque([], cache_size)
        assert self._cache.is_dir()
        self._load_cache()
        self.last_cache_dump = datetime.now(timezone.utc)

    def clear_cache(self) -> None:
        """Clear local cache"""
        for file in self._cache.rglob("pending*"):
            file.unlink()

    def _dump_cache(self) -> None:
        """Write the pending alerts (frames + metadata) to the cache folder, replacing the previous dump"""
        json_path = self._cache.joinpath("pending_alerts.json")

        # Remove previous dump
        if json_path.is_file():
            with open(json_path, "rb") as f:
                data = json.load(f)

            for entry in data:
                # A frame that is already gone must not prevent the new dump
                Path(entry["frame_path"]).unlink(missing_ok=True)
            json_path.unlink()

        data = []
        for idx, info in enumerate(self._alerts):
            frame_path = self._cache.joinpath(f"pending_frame{idx}.jpg")
            # Save frame to disk
            info["frame"].save(frame_path)

            # Save path in JSON
            data.append(
                {
                    "frame_path": str(frame_path),
                    "cam_id": info["cam_id"],
                    "ts": info["ts"],
                    "localization": info["localization"],
                }
            )

        # JSON dump
        if len(data) > 0:
            with open(json_path, "w") as f:
                json.dump(data, f)

    def _load_cache(self) -> None:
        """Restore the pending alerts written by `_dump_cache`, if any"""
        # Read json
        json_path = self._cache.joinpath("pending_alerts.json")
        if not json_path.is_file():
            return

        with open(json_path, "rb") as f:
            data = json.load(f)

        for entry in data:
            frame_path = entry["frame_path"]
            if not os.path.isfile(frame_path):
                logging.warning(f"Unable to restore pending alert, missing frame: {frame_path}")
                continue

            # Open image and read its pixels right away: `Image.open` is lazy and the file
            # is deleted by the next call to `_dump_cache`
            frame = Image.open(frame_path, mode="r")
            frame.load()

            # Restore every key `_process_alerts` reads; the API ids are not part of the dump
            self._alerts.append(
                {
                    "frame": frame,
                    "cam_id": entry["cam_id"],
                    "ts": entry["ts"],
                    "media_id": None,
                    "alert_id": None,
                    "localization": entry.get("localization"),
                }
            )

    def heartbeat(self, cam_id: str) -> Response:
        """Updates last ping of device"""
        return self.api_client[cam_id].heartbeat()

    def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> float:
        """Updates the detection states

        Args:
            frame: the frame the predictions were computed on
            preds: predictions of the model for this frame, 5 values per box
                (presumably 4 coordinates followed by a confidence score - confirm against `Classifier`)
            cam_key: key of the camera in the states dictionary

        Returns:
            the confidence aggregated over the current and the stored predictions
        """
        conf_th = self.conf_thresh * self.nb_consecutive_frames
        # Reduce threshold once we are in alert mode to collect more data
        if self._states[cam_key]["ongoing"]:
            conf_th *= 0.8

        # Get last predictions
        boxes = np.zeros((0, 5))
        boxes = np.concatenate([boxes, preds])
        for _, box, _, _, _ in self._states[cam_key]["last_predictions"]:
            if box.shape[0] > 0:
                boxes = np.concatenate([boxes, box])

        conf = 0.0
        output_predictions = np.zeros((0, 5))
        # Get the best ones
        if boxes.shape[0]:
            best_boxes = nms(boxes)
            # We keep only detections with at least two boxes above conf_th
            detections = boxes[boxes[:, -1] > self.conf_thresh, :]
            ious_detections = box_iou(best_boxes[:, :4], detections[:, :4])
            strong_detection = np.sum(ious_detections > 0, 0) > 1
            best_boxes = best_boxes[strong_detection, :]
            if best_boxes.shape[0]:
                ious = box_iou(best_boxes[:, :4], boxes[:, :4])

                best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T])
                combine_predictions = best_boxes[best_boxes_scores > conf_th, :]
                conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1)  # memory + preds

                if len(combine_predictions):
                    # send only preds boxes that match combine_predictions
                    ious = box_iou(combine_predictions[:, :4], preds[:, :4])
                    iou_match = [np.max(iou) > 0 for iou in ious]
                    output_predictions = preds[iou_match, :]

        # Limit bbox size for api
        output_predictions = np.round(output_predictions, 3)  # max 3 digit
        output_predictions = output_predictions[:5, :]  # max 5 bbox

        # The last item flags whether the entry was already staged as an alert
        self._states[cam_key]["last_predictions"].append(
            (frame, preds, output_predictions.tolist(), datetime.now(timezone.utc).isoformat(), False)
        )

        # update state
        self._states[cam_key]["ongoing"] = bool(conf > self.conf_thresh)

        return conf

    def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
        """Computes the confidence that the image contains wildfire cues

        Args:
            frame: a PIL image
            cam_id: the name of the camera that sent this image

        Returns:
            the predicted confidence
        """
        # Heartbeat
        if len(self.api_client) > 0 and isinstance(cam_id, str):
            heartbeat_with_timeout(self, cam_id, timeout=1)

        cam_key = cam_id or "-1"
        # Reduce image size to save bandwidth
        if isinstance(self.frame_size, tuple):
            frame = frame.resize(self.frame_size[::-1], getattr(Image, "BILINEAR"))

        # Inference with ONNX
        preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
        logging.debug(f"Raw predictions: {preds}")
        conf = self._update_states(frame, preds, cam_key)

        if self.save_captured_frames:
            self._local_backup(frame, cam_id, is_alert=False)

        # Log analysis result
        device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
        pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
        logging.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

        # Alert
        if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
            # Save the alert in cache to avoid connection issues
            last_predictions = self._states[cam_key]["last_predictions"]
            # Iterate over a snapshot since the entries are replaced while looping
            for idx, (past_frame, past_preds, localization, ts, is_staged) in enumerate(list(last_predictions)):
                if not is_staged:
                    self._stage_alert(past_frame, cam_id, ts, localization)
                    last_predictions[idx] = past_frame, past_preds, localization, ts, True

        # Check if it's time to backup pending alerts
        now = datetime.now(timezone.utc)
        if now > self.last_cache_dump + timedelta(minutes=self.cache_backup_period):
            self._dump_cache()
            self.last_cache_dump = now

        return float(conf)

    def _stage_alert(self, frame: Image.Image, cam_id: str, ts: str, localization: list) -> None:
        """Queue an alert so that it can be uploaded later by `_process_alerts`

        Args:
            frame: the frame of the alert
            cam_id: the camera the frame comes from
            ts: timestamp of the detection (ISO formatted string when coming from `predict`)
            localization: the bounding boxes sent along with the alert
        """
        # Store information in the queue
        self._alerts.append(
            {
                "frame": frame,
                "cam_id": cam_id,
                "ts": ts,
                "media_id": None,
                "alert_id": None,
                "localization": localization,
            }
        )

    def _process_alerts(self) -> None:
        """Try to upload the pending alerts, oldest first, and stop at the first failure"""
        for _ in range(len(self._alerts)):
            # try to upload the oldest element
            frame_info = self._alerts[0]
            cam_id = frame_info["cam_id"]
            logging.info(f"Camera '{cam_id}' - Sending alert from {frame_info['ts']}...")

            # Save alert on device
            self._local_backup(frame_info["frame"], cam_id)

            try:
                # Media creation (the ids are stored so that a retry does not create them twice)
                if not isinstance(frame_info.get("media_id"), int):
                    frame_info["media_id"] = self.api_client[cam_id].create_media_from_device().json()["id"]
                # Alert creation
                if not isinstance(frame_info.get("alert_id"), int):
                    frame_info["alert_id"] = (
                        self.api_client[cam_id]
                        .send_alert_from_device(
                            lat=self.latitude,
                            lon=self.longitude,
                            media_id=frame_info["media_id"],
                            localization=frame_info.get("localization"),
                        )
                        .json()["id"]
                    )

                # Media upload
                stream = io.BytesIO()
                frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
                response = self.api_client[cam_id].upload_media(
                    frame_info["media_id"],
                    media_data=stream.getvalue(),
                )
                # Force a KeyError if the request failed
                response.json()["id"]
                # Clear
                self._alerts.popleft()
                logging.info(f"Camera '{cam_id}' - alert sent")
            except (KeyError, ConnectionError) as e:
                logging.warning(f"Camera '{cam_id}' - unable to upload cache")
                logging.warning(e)
                break

    def _local_backup(self, img: Image.Image, cam_id: Optional[str], is_alert: bool = True) -> None:
        """Save image on device

        Args:
            img (Image.Image): Image to save
            cam_id (str): camera id (ip address)
            is_alert (bool): is the frame an alert ?
        """
        folder = "alerts" if is_alert else "save"
        backup_cache = self._cache.joinpath(f"backup/{folder}/")
        self._clean_local_backup(backup_cache)  # Dump old cache
        backup_cache = backup_cache.joinpath(f"{time.strftime('%Y%m%d')}/{cam_id}")
        backup_cache.mkdir(parents=True, exist_ok=True)
        file = backup_cache.joinpath(f"{time.strftime('%Y%m%d-%H%M%S')}.jpg")
        img.save(file)

    def _clean_local_backup(self, backup_cache: Path) -> None:
        """Clean local backup when it's bigger than _backup_size MB

        Args:
            backup_cache (Path): backup to clean
        """
        # Sub-folders are named after the day (YYYYMMDD), so sorting puts the oldest first
        backup_by_days = sorted(backup_cache.glob("*"))
        for folder in backup_by_days:
            # The size is computed again on each pass since the previous one may have freed space
            size_mb = (
                sum(
                    os.path.getsize(f)
                    for f in glob.glob(str(backup_cache) + "/**/*", recursive=True)
                    if os.path.isfile(f)
                )
                // 1024**2
            )
            if size_mb > self._backup_size:
                shutil.rmtree(folder)
            else:
                break