Source code for pyrovision.models.utils
# Copyright (C) 2022, Pyronear.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
# Borrowed from https://github.com/frgfm/Holocron/blob/main/holocron/models/utils.py
import json
from typing import Any
import torch
import torch.nn as nn
from huggingface_hub.file_download import hf_hub_download
from pyrovision import models
__all__ = ["model_from_hf_hub"]
def model_from_hf_hub(repo_id: str, **kwargs: Any) -> nn.Module:
    """Instantiate & load a pretrained model from HF hub.

    The repo is expected to hold a `config.json` (with at least the "arch" and
    "classes" entries) and a `pytorch_model.bin` checkpoint.

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download`

    Returns:
        Model loaded with the checkpoint

    Raises:
        KeyError: if the config has no "arch" or "classes" entry, or if "arch"
            is not a name exposed by `pyrovision.models`
        TypeError: if "arch" resolves to an attribute that cannot be called
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    # Resolve the architecture name explicitly so that a bad config produces an
    # actionable message instead of a bare KeyError / "not callable" TypeError
    arch = cfg["arch"]
    try:
        builder = models.__dict__[arch]
    except KeyError:
        raise KeyError(f"unknown architecture {arch!r} in the config of '{repo_id}'") from None
    if not callable(builder):
        raise TypeError(f"architecture {arch!r} from the config of '{repo_id}' is not a model constructor")

    model = builder(num_classes=len(cfg["classes"]), pretrained=False)
    # Patch the config
    model.default_cfg.update(cfg)

    # Load the checkpoint
    # NOTE(security): `torch.load` unpickles the downloaded file, which can execute
    # arbitrary code. Only use repos you trust; consider `weights_only=True` once
    # the minimum supported torch version provides it.
    checkpoint_path = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)

    return model