# Copyright (C) 2022-2026, Pyronear.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import io
import logging
import shutil
import signal
import time
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Never, Optional, Tuple
import numpy as np
from PIL import Image
from pyro_predictor import Predictor
from pyroclient import client
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import RequestException
from requests.models import Response
__all__ = ["Engine"]
logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)
logger = logging.getLogger(__name__)
def handler(_signum: int, _frame: object) -> Never:
raise TimeoutError("Heartbeat check timed out")
def heartbeat_with_timeout(api_instance: Any, cam_id: str, timeout: int = 1) -> None:  # noqa: ANN401
    """Send a heartbeat for ``cam_id`` and give up after ``timeout`` seconds.

    The time limit is enforced with ``signal.alarm``: ``handler`` is installed for SIGALRM
    and raises ``TimeoutError`` when the alarm fires. Timeouts and connection errors are
    logged as warnings and never propagated; any other exception raised by the heartbeat
    call is left to the caller.

    Args:
        api_instance: object exposing a ``heartbeat(cam_id)`` method
        cam_id: identifier of the camera the heartbeat is sent for
        timeout: number of seconds before the heartbeat is abandoned
    """
    # signal.signal returns the handler that was installed before; keep it so this helper
    # does not permanently hijack SIGALRM for the rest of the process.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        api_instance.heartbeat(cam_id)
    except TimeoutError:
        logger.warning(f"Heartbeat check timed out for {cam_id}")
    except RequestsConnectionError:
        logger.warning(f"Unable to reach the pyro-api with {cam_id}")
    finally:
        # Cancel the pending alarm first, then hand SIGALRM back to its previous owner.
        signal.alarm(0)
        # None means the previous handler was not installed from Python: it cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
class Engine(Predictor):
    """Manages predictions and API interactions for wildfire alerts.

    Extends Predictor with pyroclient API integration: heartbeats, image uploads, alert staging and caching.

    Args:
        model_path: model location, forwarded as-is to Predictor
        conf_thresh: confidence threshold to send an alert
        model_conf_thresh: model-level confidence threshold, forwarded as-is to Predictor
        max_bbox_size: predictions whose box is not smaller than this value are filtered out
        api_url: url of the pyronear API
        cam_creds: api credentials for each camera. Maps a camera id to a ``(camera_token, pose_id)`` pair,
            as in the example. The part of the camera id before the first ``_`` is treated as the camera IP
            and a single API client is created per IP.
        nb_consecutive_frames: forwarded as-is to Predictor
        frame_size: Resize frame to frame_size before sending it to the api in order to save bandwidth (H, W)
        cache_backup_period: number of minutes between each cache backup to disk
        frame_saving_period: Send one frame over N to the api for our dataset
        cache_size: maximum number of alerts to save in cache
        cache_folder: existing folder used for local backups
        backup_size: maximum size (in MB) of a local backup folder before its oldest entries are removed
        jpeg_quality: JPEG quality used to encode frames
        day_time_strategy: strategy to define if it's daytime
        save_captured_frames: save all captured frames for debugging
        save_detections_frames: save all detection frames locally
        send_last_image_period: minimum number of seconds between two periodical image uploads
        last_bbox_mask_fetch_period: minimum number of seconds between two occlusion mask refreshes
        kwargs: extra keyword args forwarded to Predictor

    Raises:
        ValueError: if ``cache_folder`` is not an existing directory

    Examples:
        >>> from pyroengine import Engine
        >>> cam_creds = {
        >>>     "192.168.1.10_1": ("camera_token", 1),
        >>>     "192.168.1.10_2": ("camera_token", 2),
        >>> }
        >>> pyroEngine = Engine(None, 0.25, api_url='https://api.pyronear.org', cam_creds=cam_creds)
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        conf_thresh: float = 0.35,
        model_conf_thresh: float = 0.05,
        max_bbox_size: float = 0.4,
        api_url: Optional[str] = None,
        cam_creds: Optional[Dict[str, Tuple[str, Any]]] = None,
        nb_consecutive_frames: int = 5,
        frame_size: Optional[Tuple[int, int]] = None,
        cache_backup_period: int = 60,
        frame_saving_period: Optional[int] = None,
        cache_size: int = 100,
        cache_folder: str = "data/",
        backup_size: int = 30,
        jpeg_quality: int = 80,
        day_time_strategy: Optional[str] = None,
        save_captured_frames: Optional[bool] = False,
        save_detections_frames: Optional[bool] = False,
        send_last_image_period: int = 3600,  # 1H
        last_bbox_mask_fetch_period: int = 3600,  # 1H
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        cam_ids = list(cam_creds.keys()) if isinstance(cam_creds, dict) else None
        super().__init__(
            model_path=model_path,
            conf_thresh=conf_thresh,
            model_conf_thresh=model_conf_thresh,
            max_bbox_size=max_bbox_size,
            nb_consecutive_frames=nb_consecutive_frames,
            frame_size=frame_size,
            cam_ids=cam_ids,
            **kwargs,
        )

        # API Setup
        self.api_client: dict[str, Any] = {}
        if isinstance(api_url, str) and isinstance(cam_creds, dict):
            # Instantiate one client per camera IP (the first token seen for an IP wins)
            for id_, (camera_token, _) in cam_creds.items():
                ip = id_.split("_")[0]
                if ip not in self.api_client:
                    self.api_client[ip] = client.Client(camera_token, api_url)

        # Cache & relaxation
        self.frame_saving_period = frame_saving_period
        self.jpeg_quality = jpeg_quality
        self.cache_backup_period = cache_backup_period
        self.day_time_strategy = day_time_strategy
        self.save_captured_frames = save_captured_frames
        self.save_detections_frames = save_detections_frames
        self.cam_creds = cam_creds
        self.send_last_image_period = send_last_image_period
        self.last_bbox_mask_fetch_period = last_bbox_mask_fetch_period

        # Local backup
        self._backup_size = backup_size

        # Augment states with API-specific fields
        for state in self._states.values():
            state["last_image_sent"] = None
            state["last_bbox_mask_fetch"] = None
            state["last_pose_image_sent"] = None

        # Occlusion masks: cam_id -> dict of bboxes (keyed by mask id)
        self.occlusion_masks: Dict[str, Dict[Any, Any]] = {}

        # Restore pending alerts cache
        self._alerts: deque = deque(maxlen=cache_size)
        self._cache = Path(cache_folder)  # with Docker, the path has to be a bind volume
        if not self._cache.is_dir():
            raise ValueError(f"Cache folder does not exist: {self._cache}")

    def _new_state(self) -> Dict[str, Any]:
        """Create a per-camera state: the Predictor state plus the API-specific timestamps."""
        state = super()._new_state()
        state["last_image_sent"] = None
        state["last_bbox_mask_fetch"] = None
        state["last_pose_image_sent"] = None
        return state

    def heartbeat(self, cam_id: str) -> Response:
        """Updates last ping of device

        Args:
            cam_id: camera id; the part before the first ``_`` selects the API client

        Returns:
            the response of the API client's ``heartbeat`` call

        Raises:
            KeyError: if no API client is registered for this camera's IP
        """
        ip = cam_id.split("_")[0]
        return self.api_client[ip].heartbeat()

    def predict(
        self,
        frame: Image.Image,
        cam_id: Optional[str] = None,
        occlusion_bboxes: Optional[Dict[Any, Any]] = None,  # noqa: ARG002
        fake_pred: Optional[np.ndarray] = None,
    ) -> float:
        """Computes the confidence that the image contains wildfire cues

        Args:
            frame: a PIL image
            cam_id: the name of the camera that sent this image
            occlusion_bboxes: ignored — Engine manages occlusion masks internally via URL fetch
            fake_pred: replace model prediction by another one for evaluation purposes, need to be given in onnx format:
                fake_pred = [[x1, x2]
                             [y1, y2]
                             [w1, w2]
                             [h1, h2]
                             [conf1, conf2]]

        Returns:
            the predicted confidence
        """
        cam_key = cam_id or "-1"
        if cam_key not in self._states:
            self._states[cam_key] = self._new_state()

        # Reduce image size to save bandwidth
        if isinstance(self.frame_size, tuple):
            target = (self.frame_size[1], self.frame_size[0])  # PIL expects (W, H)
            if frame.size != target:
                frame = frame.resize(target, Image.BILINEAR)  # type: ignore[attr-defined]

        # Canonical bytes: encode once, decode for inference, reuse for every upload/backup.
        # Guarantees model and stored media see byte-identical input.
        buf = io.BytesIO()
        frame.save(buf, format="JPEG", quality=self.jpeg_quality)
        encoded_bytes = buf.getvalue()
        buf.seek(0)
        frame = Image.open(buf).convert("RGB")

        # Heartbeat and periodic uploads
        if len(self.api_client) > 0 and isinstance(cam_id, str):
            self._sync_camera_with_api(cam_id, cam_key, encoded_bytes)

        # Update occlusion masks from API
        self._refresh_occlusion_masks(cam_id, cam_key)

        # Inference with ONNX
        if fake_pred is None:
            bbox_mask_dict = self.occlusion_masks.get(cam_key, {})
            preds = self.model(frame.convert("RGB"), bbox_mask_dict)
        elif fake_pred.size == 0:
            preds = np.empty((0, 5))
        else:
            # Apply classifier post_process method for confidence filter and nms
            preds = self.model.post_process(fake_pred, pad=(0, 0))
            # Filter predictions larger than max_bbox_size
            preds = preds[(preds[:, 2] - preds[:, 0]) < self.max_bbox_size, :]
            preds = np.reshape(preds, (-1, 5))
        logger.info(f"pred for {cam_key} : {preds}")

        conf = self._update_states(frame, preds, cam_key, encoded_bytes=encoded_bytes)

        if self.save_captured_frames:
            self._local_backup(frame, cam_id, is_alert=False, encoded_bytes=encoded_bytes)

        # Log analysis result
        device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
        pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
        logger.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

        # Alert
        if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
            # Save the alert in cache to avoid connection issues
            for idx, (frame_, preds_, bboxes, ts, is_staged, jpeg_bytes) in enumerate(
                self._states[cam_key]["last_predictions"]
            ):
                if not is_staged:
                    self._stage_alert(frame_, cam_id, ts, bboxes, jpeg_bytes)
                    # Same entry with the "staged" flag switched on
                    self._states[cam_key]["last_predictions"][idx] = (
                        frame_,
                        preds_,
                        bboxes,
                        ts,
                        True,
                        jpeg_bytes,
                    )

        return float(conf)

    def _sync_camera_with_api(self, cam_id: str, cam_key: str, encoded_bytes: bytes) -> None:
        """Send the heartbeat and the periodic / daily images of a camera to the API.

        All uploads are best-effort: a request failure is logged and never interrupts the prediction.

        Args:
            cam_id: camera id
            cam_key: key of the camera in the internal states
            encoded_bytes: JPEG bytes of the current frame
        """
        ip = cam_id.split("_")[0]
        api = self.api_client.get(ip)
        if api is None:
            # No client was created for this camera's IP: there is nothing to talk to
            return

        heartbeat_with_timeout(self, cam_id, timeout=1)

        state = self._states[cam_key]
        last_image_sent = state["last_image_sent"]
        if last_image_sent is None or time.time() - last_image_sent > self.send_last_image_period:
            # send image periodically
            logger.info(f"Uploading periodical image for cam {cam_id}")
            # The timestamp is set before the upload so a failing API is not retried on every frame
            state["last_image_sent"] = time.time()
            try:
                response = api.update_last_image(encoded_bytes)
                logger.info(response.text)
            except RequestException as e:
                logger.warning(f"Failed to upload periodical image for cam {cam_id}: {e}")

        # Send one pose image per day at 12:00
        if isinstance(self.cam_creds, dict) and cam_id in self.cam_creds:
            now = datetime.now()
            today_noon = now.replace(hour=12, minute=0, second=0, microsecond=0)
            last_pose_sent = state["last_pose_image_sent"]
            if now >= today_noon and (last_pose_sent is None or last_pose_sent < today_noon):
                _, pose_id = self.cam_creds[cam_id]
                logger.info(f"Uploading daily pose image for cam {cam_id} (pose {pose_id})")
                state["last_pose_image_sent"] = now
                try:
                    response = api.update_pose_image(pose_id, encoded_bytes)
                    logger.info(response.text)
                except RequestException as e:
                    logger.warning(f"Failed to upload daily pose image for cam {cam_id} (pose {pose_id}): {e}")

    def _refresh_occlusion_masks(self, cam_id: Optional[str], cam_key: str) -> None:
        """Download the occlusion masks of a camera when the last fetch is older than the fetch period.

        On any failure (request error or malformed payload) the previously stored masks are kept.

        Args:
            cam_id: camera id, or None when the frame is not tied to a camera
            cam_key: key of the camera in the internal states
        """
        state = self._states[cam_key]
        last_fetch = state["last_bbox_mask_fetch"]
        if last_fetch is not None and time.time() - last_fetch <= self.last_bbox_mask_fetch_period:
            return

        logger.info(f"Update occlusion masks for cam {cam_key}")
        state["last_bbox_mask_fetch"] = time.time()

        if not (isinstance(cam_id, str) and isinstance(self.cam_creds, dict) and cam_id in self.cam_creds):
            return
        _, pose_id = self.cam_creds[cam_id]
        api = self.api_client.get(cam_id.split("_")[0])
        if api is None:
            return

        try:
            response = api.list_pose_masks(pose_id)
            response.raise_for_status()
            bbox_mask_dict: Dict[Any, Any] = {}
            for mask_entry in response.json():
                # Masks are received as a "(a,b,c,...)" string
                mask_str = mask_entry["mask"].strip("()")
                bbox_mask_dict[str(mask_entry["id"])] = tuple(float(c) for c in mask_str.split(","))
        except RequestException as e:
            logger.warning(f"Failed to fetch occlusion masks for cam {cam_key} (pose {pose_id}): {e}")
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            logger.warning(f"Malformed occlusion masks for cam {cam_key} (pose {pose_id}): {e}")
        else:
            # Only replace the stored masks once the whole payload has been parsed
            self.occlusion_masks[cam_key] = bbox_mask_dict
            logger.info(f"Downloaded occlusion masks for cam {cam_key}: {bbox_mask_dict}")

    def _stage_alert(
        self,
        frame: Image.Image,
        cam_id: str,
        ts: int,
        bboxes: list,
        jpeg_bytes: Optional[bytes] = None,
    ) -> None:
        """Append an alert to the pending queue; it is uploaded later by ``_process_alerts``.

        Args:
            frame: frame of the alert
            cam_id: id of the camera that produced the frame
            ts: timestamp of the frame
            bboxes: detections of the frame
            jpeg_bytes: pre-encoded JPEG bytes of the frame, if available
        """
        # Store information in the queue
        self._alerts.append({
            "frame": frame,
            "cam_id": cam_id,
            "ts": ts,
            "media_id": None,
            "alert_id": None,
            "bboxes": bboxes,
            "jpeg_bytes": jpeg_bytes,
        })

    def fill_empty_bboxes(self) -> None:
        """Give every pending alert without bboxes the boxes of its nearest neighbour from the same camera.

        "Nearest" is measured in queue positions. Copied boxes get their last value (the confidence)
        forced to 0. Cameras whose pending alerts all have empty bboxes are left untouched.
        """
        cam_id_to_indices: Dict[str, list[int]] = {}
        for i, alert in enumerate(self._alerts):
            cam_id_to_indices.setdefault(alert["cam_id"], []).append(i)

        for indices in cam_id_to_indices.values():
            # Computed once per camera, so only original (not copied) boxes are ever used as source
            non_empty_indices = [i for i in indices if self._alerts[i]["bboxes"]]
            if not non_empty_indices:
                continue
            for i in indices:
                if not self._alerts[i]["bboxes"]:
                    closest_index = min(non_empty_indices, key=lambda x: abs(x - i))
                    src = np.array(self._alerts[closest_index]["bboxes"], dtype=float)
                    if src.size == 0:
                        continue
                    filled = src.copy()
                    filled[:, -1] = 0.0  # force confidence to 0 for duplicated boxes
                    self._alerts[i]["bboxes"] = [tuple(row) for row in filled.tolist()]

    def _process_alerts(self) -> None:
        """Upload the pending alerts, oldest first, and stop at the first failure.

        An alert is removed from the queue once the API answered with a JSON body holding an ``id``
        (or when it has no bboxes to send). On failure it stays at the head of the queue so that it
        is retried by the next call.
        """
        if self.cam_creds is None:
            return

        self.fill_empty_bboxes()
        for _ in range(len(self._alerts)):
            # try to upload the oldest element
            frame_info = self._alerts[0]
            cam_id = frame_info["cam_id"]
            logger.info(f"Camera '{cam_id}' - Sending alert from {frame_info['ts']}...")

            # Save alert on device, once per alert even when its upload has to be retried
            if self.save_detections_frames and not frame_info.get("backed_up", False):
                self._local_backup(
                    frame_info["frame"],
                    cam_id,
                    encoded_bytes=frame_info.get("jpeg_bytes"),
                )
                frame_info["backed_up"] = True

            try:
                # Detection creation
                bboxes = frame_info["bboxes"]
                if not bboxes:
                    logger.warning(f"Camera '{cam_id}' - skipping alert with empty bboxes")
                    self._alerts.popleft()
                    continue

                jpeg_bytes = frame_info.get("jpeg_bytes")
                if jpeg_bytes is None:
                    # Fallback for cached alerts staged before this version
                    stream = io.BytesIO()
                    frame_info["frame"].save(stream, format="JPEG", quality=self.jpeg_quality)
                    jpeg_bytes = stream.getvalue()

                bboxes = [tuple(bboxe) for bboxe in bboxes]
                _, pose_id = self.cam_creds[cam_id]
                ip = cam_id.split("_")[0]
                response = self.api_client[ip].create_detection(jpeg_bytes, bboxes, pose_id)

                try:
                    # The upload only counts as done when the API returns the id of the detection
                    response.json()["id"]
                except ValueError:
                    logger.error(f"Camera '{cam_id}' - non-JSON response body: {response.text}")
                    raise

                # Clear
                self._alerts.popleft()
                logger.info(f"Camera '{cam_id}' - alert sent")
            except (KeyError, RequestException, ValueError) as e:
                # RequestException covers connection errors as well as timeouts: the alert stays cached
                logger.error(f"Camera '{cam_id}' - unable to upload cache")
                logger.error(e)
                break

    def _local_backup(
        self,
        img: Image.Image,
        cam_id: Optional[str],
        is_alert: bool = True,
        encoded_bytes: Optional[bytes] = None,
    ) -> None:
        """Save image on device

        The image is written to ``<cache>/backup/<alerts|save>/<YYYYMMDD>/<cam_id>/<YYYYMMDD-HHMMSS>.jpg``.
        When a file with that name already exists, ``-1``, ``-2``, ... is appended so that images saved
        within the same second do not overwrite each other.

        Args:
            img (Image.Image): Image to save
            cam_id (str): camera id (ip address)
            is_alert (bool): is the frame an alert ?
            encoded_bytes: pre-encoded JPEG bytes — written verbatim when provided so the
                on-disk file is byte-identical to what was scored / uploaded.
        """
        folder = "alerts" if is_alert else "save"
        backup_cache = self._cache.joinpath(f"backup/{folder}/")
        self._clean_local_backup(backup_cache)  # Dump old cache

        # Read the clock once so the day folder and the file name always agree
        now = time.localtime()
        backup_cache = backup_cache.joinpath(f"{time.strftime('%Y%m%d', now)}/{cam_id}")
        backup_cache.mkdir(parents=True, exist_ok=True)

        stem = time.strftime("%Y%m%d-%H%M%S", now)
        file = backup_cache.joinpath(f"{stem}.jpg")
        duplicates = 0
        while file.exists():
            duplicates += 1
            file = backup_cache.joinpath(f"{stem}-{duplicates}.jpg")

        if encoded_bytes is not None:
            file.write_bytes(encoded_bytes)
        else:
            img.save(file)

    def _clean_local_backup(self, backup_cache: Path) -> None:
        """Clean local backup when it's bigger than _backup_size MB

        Entries are removed in name order (day folders are named YYYYMMDD, hence oldest first)
        until the backup fits in ``_backup_size`` MB.

        Args:
            backup_cache (Path): backup to clean
        """
        entries = sorted(backup_cache.glob("*"))
        # Walk the tree a single time; the running total is updated as entries are removed
        sizes = [self._disk_usage(entry) for entry in entries]
        remaining = sum(sizes)
        for entry, size in zip(entries, sizes):
            if remaining // 1024**2 <= self._backup_size:
                break
            if entry.is_dir():
                shutil.rmtree(entry)
            else:
                # Stray file next to the day folders: rmtree would fail on it
                entry.unlink()
            remaining -= size

    @staticmethod
    def _disk_usage(path: Path) -> int:
        """Return the size in bytes of a file, or of all the files found under a folder."""
        if path.is_file():
            return path.stat().st_size
        return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())