pyrovision.datasets
All datasets are subclasses of torchvision.datasets.vision.VisionDataset, i.e., they implement the __getitem__ and __len__ methods.
Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.
OpenFire
An image classification dataset for wildfire detection in natural environments, built from images referenced by Google Images.
- class pyrovision.datasets.OpenFire(root: str, train: bool = True, download: bool = False, validate_images: bool = True, num_samples: Optional[int] = None, num_threads: Optional[int] = None, prefetch_fn: Optional[Callable[[Tuple[Union[Path, str], Union[Path, str]]], None]] = None, **kwargs: Any)
Implements an image classification dataset for wildfire detection, collected from web searches.
>>> from pyrovision.datasets import OpenFire
>>> train_set = OpenFire("path/to/your/folder", train=True, download=True)
>>> img, target = train_set[0]
- Parameters:
root – Root directory where ‘OpenFire’ is located.
train – If True, returns the training subset; otherwise, returns the validation subset.
download – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
num_samples – Number of samples to download (all samples by default).
num_threads – If download is True, the number of threads used to download the dataset.
prefetch_fn – Optional function applied to all images before data loading.
**kwargs – Optional keyword arguments passed to torchvision.datasets.VisionDataset.
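According to its type hint, prefetch_fn receives a single tuple of two paths (each a Path or str) and returns None. Below is a minimal sketch of such a function. It assumes the tuple is ordered (source, destination), which the parameter description does not state, and that Pillow is installed. The names resize_prefetch and max_size are hypothetical.

```python
from pathlib import Path
from typing import Tuple, Union

from PIL import Image

PathLike = Union[Path, str]


def resize_prefetch(paths: Tuple[PathLike, PathLike], max_size: int = 256) -> None:
    """Hypothetical prefetch_fn: shrink an image so its longer side is at most max_size.

    ASSUMPTION: paths is (source, destination); the ordering is not documented.
    """
    src, dst = paths
    with Image.open(src) as img:
        # convert() loads the pixel data, so dst may safely equal src
        img = img.convert("RGB")
        # thumbnail() resizes in place, keeps the aspect ratio and never upscales
        img.thumbnail((max_size, max_size))
        img.save(dst)
```

If those assumptions hold, you could pass the function as OpenFire("path/to/your/folder", download=True, prefetch_fn=resize_prefetch). Downscaling once before loading would trade image resolution for smaller files on disk.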