# Source code for pyro_risks.pipeline.load

# Copyright (C) 2021, Pyronear contributors.

# This program is licensed under the GNU Affero General Public License version 3.
# See LICENSE or go to <https://www.gnu.org/licenses/agpl-3.0.txt> for full license details.

from typing import Optional, List
from pyro_risks.datasets.utils import download
from datetime import datetime
from typing import Tuple

import pyro_risks.config as cfg
import pandas as pd
import os

__all__ = ["load_dataset"]


def load_dataset(
    url: Optional[str] = None,
    path: Optional[str] = None,
    usecols: Optional[List[str]] = None,
    pipeline_cols: Optional[List[str]] = None,
    destination: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Load Pyro Risks training datasets.

    Download the Pyro Risks training dataset if no local copy exists at ``path``,
    read it, parse its date column and split it into features and target.

    Args:
        url: Training dataset URL. Defaults to ``cfg.ERA5T_VIIRS_PIPELINE``.
        path: Dataset full path. Defaults to
            ``os.path.join(cfg.DATA_REGISTRY, cfg.DATASET)``.
        usecols: Subset of the dataset columns read from the CSV. Defaults to
            ``[cfg.DATE_VAR, cfg.ZONE_VAR, cfg.TARGET] + cfg.PIPELINE_ERA5T_VARS``.
        pipeline_cols: Subset of the dataset used for training (returned as ``X``).
            Defaults to ``[cfg.DATE_VAR, cfg.ZONE_VAR] + cfg.PIPELINE_ERA5T_VARS``.
        destination: Folder where the dataset should be saved when it has to be
            downloaded. Defaults to ``cfg.DATA_REGISTRY``.

    Returns:
        Tuple[pd.DataFrame, pd.Series]: the feature frame ``X`` (columns
        ``pipeline_cols``) and the target series ``y`` (column ``cfg.TARGET``).

    Raises:
        ValueError: if a non-missing value of the date column does not match
            the ``%Y-%m-%d`` format.
    """
    url = cfg.ERA5T_VIIRS_PIPELINE if url is None else url
    path = os.path.join(cfg.DATA_REGISTRY, cfg.DATASET) if path is None else path
    usecols = (
        [cfg.DATE_VAR, cfg.ZONE_VAR, cfg.TARGET] + cfg.PIPELINE_ERA5T_VARS
        if usecols is None
        else usecols
    )
    pipeline_cols = (
        [cfg.DATE_VAR, cfg.ZONE_VAR] + cfg.PIPELINE_ERA5T_VARS
        if pipeline_cols is None
        else pipeline_cols
    )
    destination = cfg.DATA_REGISTRY if destination is None else destination

    # Only fetch the dataset when no local copy exists at ``path``.
    # NOTE(review): assumes ``download`` writes the file to ``path``; confirm when a
    # custom ``path`` is combined with a different ``destination``.
    if not os.path.isfile(path):
        download(url=url, default_extension="csv", unzip=False, destination=destination)

    df = pd.read_csv(path, usecols=usecols)

    # Parse the date column with the same name used to build the default column
    # lists (previously hard-coded as "day"). Vectorized parsing maps missing
    # values to NaT and raises ValueError on malformed dates. Skip it when the
    # caller's ``usecols`` did not load the date column at all.
    if cfg.DATE_VAR in df.columns:
        df[cfg.DATE_VAR] = pd.to_datetime(df[cfg.DATE_VAR], format="%Y-%m-%d")

    X = df[pipeline_cols]
    y = df[cfg.TARGET]
    return X, y