# Source code for pyroengine.engine

# Copyright (C) 2022-2026, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import io
import logging
import shutil
import signal
import time
from collections import deque
from pathlib import Path
from typing import Any, Dict, Never, Optional, Tuple

import numpy as np
import requests
from PIL import Image
from pyro_predictor import Predictor
from pyroclient import client
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.models import Response

# Public API of this module
__all__ = ["Engine"]

# force=True replaces any logging handlers installed earlier (e.g. by an importing
# application) so the engine's timestamped format is always in effect.
logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)
logger = logging.getLogger(__name__)


def handler(_signum: int, _frame: object) -> Never:
    raise TimeoutError("Heartbeat check timed out")


def heartbeat_with_timeout(api_instance: Any, cam_id: str, timeout: int = 1) -> None:  # noqa: ANN401
    """Call ``api_instance.heartbeat(cam_id)``, aborting after ``timeout`` seconds.

    Relies on SIGALRM, so this only works on Unix and from the main thread.
    Timeouts and connection errors are logged and swallowed (best-effort ping).

    Args:
        api_instance: object exposing a ``heartbeat(cam_id)`` method
        cam_id: identifier of the camera being pinged
        timeout: whole seconds before the call is aborted (``signal.alarm``
            only supports integer granularity)
    """
    # Save the previous SIGALRM disposition so we can restore it afterwards;
    # the original code left our handler installed globally after every call.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        api_instance.heartbeat(cam_id)
    except TimeoutError:
        logger.warning(f"Heartbeat check timed out for {cam_id}")
    except RequestsConnectionError:
        logger.warning(f"Unable to reach the pyro-api with {cam_id}")
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous_handler)  # restore prior disposition


class Engine(Predictor):
    """Manages predictions and API interactions for wildfire alerts.

    Extends Predictor with pyroclient API integration: heartbeats, image uploads,
    alert staging and caching.

    Args:
        model_path: path of the ONNX model (forwarded to Predictor)
        conf_thresh: confidence threshold above which an alert is raised
        model_conf_thresh: raw model confidence threshold (forwarded to Predictor)
        max_bbox_size: predictions wider than this (relative units) are discarded
        api_url: url of the pyronear API
        cam_creds: mapping of camera id to a ``(camera_token, pose_id, bbox_mask_url)``
            triple — see the three unpack sites below
        nb_consecutive_frames: forwarded to Predictor
        frame_size: resize frame to frame_size (H, W) before sending it to the api
            in order to save bandwidth
        cache_backup_period: number of minutes between each cache backup to disk
        frame_saving_period: send one frame over N to the api for our dataset
        cache_size: maximum number of alerts to save in cache
        cache_folder: existing folder used for the alert cache and local backups
            (with Docker this must be a bind volume)
        backup_size: maximum size (MB) of the local image backup
        jpeg_quality: JPEG quality used for images sent to the api
        day_time_strategy: strategy to define if it's daytime
        save_captured_frames: save all captured frames locally for debugging
        save_detections_frames: save all detection frames locally
        send_last_image_period: seconds between periodic "last image" uploads
        last_bbox_mask_fetch_period: seconds between occlusion-mask refreshes
        kwargs: keyword args of Classifier

    Raises:
        ValueError: if ``cache_folder`` does not exist

    Examples:
        >>> from pyroengine import Engine
        >>> cam_creds = {
        >>>     "cam_id_1": ("camera_token_1", "pose_id_1", "https://example.org/mask_1"),
        >>>     "cam_id_2": ("camera_token_2", "pose_id_2", "https://example.org/mask_2"),
        >>> }
        >>> pyroEngine = Engine(None, 0.25, 'https://api.pyronear.org', cam_creds, 48.88, 2.38)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        conf_thresh: float = 0.35,
        model_conf_thresh: float = 0.05,
        max_bbox_size: float = 0.4,
        api_url: Optional[str] = None,
        # NOTE: values are (camera_token, pose_id, bbox_mask_url) triples, not
        # login/password dicts — see the tuple unpacking below.
        cam_creds: Optional[Dict[str, Tuple[str, str, str]]] = None,
        nb_consecutive_frames: int = 7,
        frame_size: Optional[Tuple[int, int]] = None,
        cache_backup_period: int = 60,
        frame_saving_period: Optional[int] = None,
        cache_size: int = 100,
        cache_folder: str = "data/",
        backup_size: int = 30,
        jpeg_quality: int = 80,
        day_time_strategy: Optional[str] = None,
        save_captured_frames: Optional[bool] = False,
        save_detections_frames: Optional[bool] = False,
        send_last_image_period: int = 3600,  # 1H
        last_bbox_mask_fetch_period: int = 3600,  # 1H
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        cam_ids = list(cam_creds.keys()) if isinstance(cam_creds, dict) else None
        super().__init__(
            model_path=model_path,
            conf_thresh=conf_thresh,
            model_conf_thresh=model_conf_thresh,
            max_bbox_size=max_bbox_size,
            nb_consecutive_frames=nb_consecutive_frames,
            frame_size=frame_size,
            cam_ids=cam_ids,
            **kwargs,
        )

        # API Setup: one pyroclient per device IP (camera ids look like "<ip>_<suffix>",
        # so several cameras can share a single client keyed on the IP prefix)
        self.api_client: dict[str, Any] = {}
        if isinstance(api_url, str) and isinstance(cam_creds, dict):
            # Instantiate clients for each camera
            for id_, (camera_token, _, _) in cam_creds.items():
                ip = id_.split("_")[0]
                if ip not in self.api_client:
                    self.api_client[ip] = client.Client(camera_token, api_url)

        # Cache & relaxation
        self.frame_saving_period = frame_saving_period
        self.jpeg_quality = jpeg_quality
        self.cache_backup_period = cache_backup_period
        self.day_time_strategy = day_time_strategy
        self.save_captured_frames = save_captured_frames
        self.save_detections_frames = save_detections_frames
        self.cam_creds = cam_creds
        self.send_last_image_period = send_last_image_period
        self.last_bbox_mask_fetch_period = last_bbox_mask_fetch_period

        # Local backup
        self._backup_size = backup_size

        # Augment states with API-specific fields (timestamps of the last periodic
        # image upload and the last occlusion-mask fetch, per camera)
        for state in self._states.values():
            state["last_image_sent"] = None
            state["last_bbox_mask_fetch"] = None

        # Occlusion masks: cam_id -> (bbox_mask_url, cached mask dict, pose_id);
        # "-1" is the fallback entry used when predict() is called without a cam_id
        self.occlusion_masks: Dict[str, Tuple[Optional[str], Dict[Any, Any], str]] = {"-1": (None, {}, "0")}
        if isinstance(cam_creds, dict):
            for cam_id, (_, pose_id, bbox_mask_url) in cam_creds.items():
                self.occlusion_masks[cam_id] = (bbox_mask_url, {}, pose_id)

        # Restore pending alerts cache
        self._alerts: deque = deque(maxlen=cache_size)
        self._cache = Path(cache_folder)  # with Docker, the path has to be a bind volume
        if not self._cache.is_dir():
            raise ValueError(f"Cache folder does not exist: {self._cache}")

    def _new_state(self) -> Dict[str, Any]:
        # Extend the Predictor per-camera state with the API-related timestamps
        # initialized above, so cameras first seen at predict() time behave the same.
        state = super()._new_state()
        state["last_image_sent"] = None
        state["last_bbox_mask_fetch"] = None
        return state
[docs] def heartbeat(self, cam_id: str) -> Response: """Updates last ping of device""" ip = cam_id.split("_")[0] return self.api_client[ip].heartbeat()
    def predict(
        self,
        frame: Image.Image,
        cam_id: Optional[str] = None,
        occlusion_bboxes: Optional[Dict[Any, Any]] = None,  # noqa: ARG002
        fake_pred: Optional[np.ndarray] = None,
    ) -> float:
        """Computes the confidence that the image contains wildfire cues

        Side effects: sends a heartbeat, periodically uploads the latest frame and
        refreshes occlusion masks, updates per-camera state, and stages alerts in
        the local cache when the confidence exceeds the threshold.

        Args:
            frame: a PIL image
            cam_id: the name of the camera that sent this image
            occlusion_bboxes: ignored — Engine manages occlusion masks internally via URL fetch
            fake_pred: replace model prediction by another one for evaluation purposes,
                need to be given in onnx format:
                fake_pred = [[x1, x2] [y1, y2] [w1, w2] [h1, h2] [conf1, conf2]]

        Returns:
            the predicted confidence
        """
        # "-1" is the fallback state/mask key when no camera id is provided
        cam_key = cam_id or "-1"
        if cam_key not in self._states:
            self._states[cam_key] = self._new_state()

        # Reduce image size to save bandwidth
        if isinstance(self.frame_size, tuple):
            # frame_size is (H, W) but PIL expects (W, H), hence the [::-1]
            frame = frame.resize(self.frame_size[::-1], Image.BILINEAR)  # type: ignore[attr-defined]

        # Heartbeat + periodic API maintenance (only when an API client and a camera id exist)
        if len(self.api_client) > 0 and isinstance(cam_id, str):
            heartbeat_with_timeout(self, cam_id, timeout=1)
            if (
                self._states[cam_key]["last_image_sent"] is None
                or time.time() - self._states[cam_key]["last_image_sent"] > self.send_last_image_period
            ):
                # send image periodically
                logger.info(f"Uploading periodical image for cam {cam_id}")
                self._states[cam_key]["last_image_sent"] = time.time()
                ip = cam_id.split("_")[0]
                if ip in self.api_client:
                    stream = io.BytesIO()
                    frame.save(stream, format="JPEG", quality=self.jpeg_quality)
                    response = self.api_client[ip].update_last_image(stream.getvalue())
                    logger.info(response.text)

            # Update occlusion masks (refreshed at most once per last_bbox_mask_fetch_period)
            if (
                self._states[cam_key]["last_bbox_mask_fetch"] is None
                or time.time() - self._states[cam_key]["last_bbox_mask_fetch"] > self.last_bbox_mask_fetch_period
            ):
                logger.info(f"Update occlusion masks for cam {cam_key}")
                self._states[cam_key]["last_bbox_mask_fetch"] = time.time()
                bbox_mask_url, bbox_mask_dict, pose_id = self.occlusion_masks[cam_key]
                if bbox_mask_url is not None:
                    # Mask files are published per pose: "<url>_<pose_id>.json"
                    full_url = f"{bbox_mask_url}_{pose_id}.json"
                    try:
                        response = requests.get(full_url, timeout=5)
                        response.raise_for_status()
                        bbox_mask_dict = response.json()
                        self.occlusion_masks[cam_key] = (bbox_mask_url, bbox_mask_dict, pose_id)
                        logger.info(f"Downloaded occlusion masks for cam {cam_key} at {bbox_mask_url} :{bbox_mask_dict}")
                    except requests.exceptions.RequestException:
                        # Best-effort: keep the previously cached mask on any network/HTTP error
                        logger.info(f"No occlusion masks available for: {cam_key}")

        # Inference with ONNX
        if fake_pred is None:
            _, bbox_mask_dict, _ = self.occlusion_masks[cam_key]
            preds = self.model(frame.convert("RGB"), bbox_mask_dict)
        else:
            if fake_pred.size == 0:
                # Empty injected prediction: keep the (0, 5) shape downstream code expects
                preds = np.empty((0, 5))
            else:
                # Apply classifier post_process method for confidence filter and nms
                preds = self.model.post_process(fake_pred, pad=(0, 0))
            # Filter predictions larger than max_bbox_size (relative width = x2 - x1)
            preds = preds[(preds[:, 2] - preds[:, 0]) < self.max_bbox_size, :]
            preds = np.reshape(preds, (-1, 5))

        logger.info(f"pred for {cam_key} : {preds}")
        conf = self._update_states(frame, preds, cam_key)

        if self.save_captured_frames:
            self._local_backup(frame, cam_id, is_alert=False)

        # Log analysis result
        device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
        pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
        logger.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

        # Alert
        if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
            # Save the alert in cache to avoid connection issues; each recent
            # prediction is staged once, then marked is_staged=True in place
            for idx, (frame, preds, bboxes, ts, is_staged) in enumerate(self._states[cam_key]["last_predictions"]):
                if not is_staged:
                    self._stage_alert(frame, cam_id, ts, bboxes)
                    self._states[cam_key]["last_predictions"][idx] = (
                        frame,
                        preds,
                        bboxes,
                        ts,
                        True,
                    )

        return float(conf)
    def _stage_alert(self, frame: Image.Image, cam_id: str, ts: int, bboxes: list) -> None:
        """Queue an alert locally so it can be uploaded later by _process_alerts."""
        # Store information in the queue
        self._alerts.append({
            "frame": frame,
            "cam_id": cam_id,
            "ts": ts,
            "media_id": None,
            "alert_id": None,
            "bboxes": bboxes,
        })

    def fill_empty_bboxes(self) -> None:
        """Backfill queued alerts that have no bboxes.

        For each camera, alerts with empty bboxes borrow the boxes of the nearest
        (by queue position) alert of the same camera, with confidence forced to 0
        so borrowed boxes are distinguishable from real detections.
        """
        # Group queue indices by camera id
        cam_id_to_indices: Dict[str, list[int]] = {}
        for i, alert in enumerate(self._alerts):
            cam_id_to_indices.setdefault(alert["cam_id"], []).append(i)

        for indices in cam_id_to_indices.values():
            non_empty_indices = [i for i in indices if self._alerts[i]["bboxes"]]
            if not non_empty_indices:
                # No donor alert for this camera: leave its empty alerts untouched
                continue
            for i in indices:
                if not self._alerts[i]["bboxes"]:
                    closest_index = min(non_empty_indices, key=lambda x: abs(x - i))
                    src = np.array(self._alerts[closest_index]["bboxes"], dtype=float)
                    if src.size == 0:
                        continue
                    filled = src.copy()
                    filled[:, -1] = 0.0  # force confidence to 0 for duplicated boxes
                    self._alerts[i]["bboxes"] = [tuple(row) for row in filled.tolist()]

    def _process_alerts(self) -> None:
        """Upload queued alerts to the API, oldest first.

        Stops at the first upload failure, leaving the remaining alerts in the
        queue for the next attempt.
        """
        if self.cam_creds is not None:
            self.fill_empty_bboxes()
            for _ in range(len(self._alerts)):
                # try to upload the oldest element
                frame_info = self._alerts[0]
                cam_id = frame_info["cam_id"]
                logger.info(f"Camera '{cam_id}' - Sending alert from {frame_info['ts']}...")

                # Save alert on device
                if self.save_detections_frames:
                    self._local_backup(frame_info["frame"], cam_id)

                try:
                    # Detection creation
                    bboxes = self._alerts[0]["bboxes"]
                    if not bboxes:
                        # Still empty after fill_empty_bboxes: drop it rather than retry forever
                        logger.warning(f"Camera '{cam_id}' - skipping alert with empty bboxes")
                        self._alerts.popleft()
                        continue
                    stream = io.BytesIO()
                    frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
                    bboxes = [tuple(bboxe) for bboxe in bboxes]
                    # cam_creds values are (camera_token, pose_id, bbox_mask_url)
                    _, pose_id, _ = self.cam_creds[cam_id]
                    ip = cam_id.split("_")[0]
                    response = self.api_client[ip].create_detection(stream.getvalue(), bboxes, pose_id)
                    # Probe the body so a non-JSON (error) response is surfaced and retried
                    try:
                        response.json()["id"]
                    except ValueError:
                        logger.error(f"Camera '{cam_id}' - non-JSON response body: {response.text}")
                        raise

                    # Clear: only dequeue once the upload is confirmed
                    self._alerts.popleft()
                    logger.info(f"Camera '{cam_id}' - alert sent")
                    stream.seek(0)  # "Rewind" the stream to the beginning so we can read its content
                except (KeyError, RequestsConnectionError, ValueError) as e:
                    # Keep the alert queued and stop; it will be retried on the next pass
                    logger.error(f"Camera '{cam_id}' - unable to upload cache")
                    logger.error(e)
                    break

    def _local_backup(self, img: Image.Image, cam_id: Optional[str], is_alert: bool = True) -> None:
        """Save image on device

        Args:
            img (Image.Image): Image to save
            cam_id (str): camera id (ip address)
            is_alert (bool): is the frame an alert ?
        """
        folder = "alerts" if is_alert else "save"
        backup_cache = self._cache.joinpath(f"backup/{folder}/")
        self._clean_local_backup(backup_cache)  # Dump old cache
        # Layout: backup/<alerts|save>/<YYYYMMDD>/<cam_id>/<YYYYMMDD-HHMMSS>.jpg
        backup_cache = backup_cache.joinpath(f"{time.strftime('%Y%m%d')}/{cam_id}")
        backup_cache.mkdir(parents=True, exist_ok=True)
        file = backup_cache.joinpath(f"{time.strftime('%Y%m%d-%H%M%S')}.jpg")
        img.save(file)

    def _clean_local_backup(self, backup_cache: Path) -> None:
        """Clean local backup when it's bigger than _backup_size MB

        Args:
            backup_cache (Path): backup to clean
        """
        backup_by_days = list(backup_cache.glob("*"))
        backup_by_days.sort()  # lexicographic sort == chronological for YYYYMMDD folder names
        for folder in backup_by_days:
            # Recompute total size (MB) after each deletion; stop once under the limit
            s = sum(f.stat().st_size for f in backup_cache.rglob("*") if f.is_file()) // 1024**2
            if s > self._backup_size:
                shutil.rmtree(folder)
            else:
                break