Skip to content

zamba.models.model_manager

Classes

ModelManager

Mediates loading of the model configuration and dispatches training and prediction calls.

Parameters:

Name Type Description Default
config ModelConfig

Instantiated ModelConfig.

required
__init__(self, config: ModelConfig) special
Source code in zamba/models/model_manager.py
def __init__(self, config: ModelConfig):
    """Wrap an already-instantiated model configuration.

    Args:
        config (ModelConfig): Configuration whose sub-configs are forwarded by ``train`` and
            ``predict``.
    """
    self.config = config
from_yaml(config) classmethod
Source code in zamba/models/model_manager.py
@classmethod
def from_yaml(cls, config):
    """Build a manager from either a ModelConfig or something ``ModelConfig.parse_file`` accepts.

    Args:
        config: An instantiated ModelConfig (used unchanged) or a file path that is parsed with
            ``ModelConfig.parse_file``.

    Returns:
        An instance of ``cls`` wrapping the resolved ModelConfig.
    """
    resolved = config if isinstance(config, ModelConfig) else ModelConfig.parse_file(config)
    return cls(resolved)
predict(self)
Source code in zamba/models/model_manager.py
def predict(self):
    """Run inference using this manager's configuration.

    Forwards ``config.predict_config`` and ``config.video_loader_config`` to ``predict_model``.
    The value returned by ``predict_model`` is discarded; this method returns None.
    """
    cfg = self.config
    predict_model(
        predict_config=cfg.predict_config,
        video_loader_config=cfg.video_loader_config,
    )
train(self)
Source code in zamba/models/model_manager.py
def train(self):
    """Train a model using this manager's configuration.

    Forwards ``config.train_config`` and ``config.video_loader_config`` to ``train_model``.
    The value returned by ``train_model`` is discarded; this method returns None.
    """
    cfg = self.config
    train_model(
        train_config=cfg.train_config,
        video_loader_config=cfg.video_loader_config,
    )

Functions

instantiate_model(checkpoint: Union[os.PathLike, str], weight_download_region: RegionEnum, scheduler_config: Optional[zamba.models.config.SchedulerConfig], model_cache_dir: Optional[os.PathLike], labels: Optional[pandas.core.frame.DataFrame], from_scratch: bool = False, model_name: Optional[zamba.models.config.ModelEnum] = None, predict_all_zamba_species: bool = True) -> ZambaVideoClassificationLightningModule

Instantiates the model from a checkpoint and detects whether the model head should be replaced. The model head is replaced if the labels contain species that the model does not predict, or if predict_all_zamba_species=False.

Supports model instantiation for the following cases:

- train from scratch (from_scratch=True)
- finetune with new species (from_scratch=False, labels contain different species than the model)
- finetune with a subset of zamba species and output only the species in the labels file (predict_all_zamba_species=False)
- finetune with a subset of zamba species but output all zamba species (predict_all_zamba_species=True)
- predict using a pretrained model (labels=None)

Parameters:

Name Type Description Default
checkpoint path or str

Either the path to a checkpoint on disk or the name of a checkpoint file in the S3 bucket, i.e., one that is discoverable by download_weights.

required
weight_download_region RegionEnum

Server region for downloading weights.

required
scheduler_config SchedulerConfig

SchedulerConfig to use for training or finetuning. Only used if labels is not None.

required
model_cache_dir path

Directory in which to store pretrained model weights.

required
labels pd.DataFrame

Dataframe where filepath is the index and columns are one hot encoded species.

required
from_scratch bool

Whether to instantiate the model with base weights. This means starting from the imagenet weights for image based models and the Kinetics weights for video models. Defaults to False. Only used if labels is not None.

False
model_name ModelEnum

Model name used to look up default hparams used for that model. Only relevant if training from scratch.

None
predict_all_zamba_species bool

Whether the model should output all zamba species. If you want the model classes to be only the species in your labels file, set to False. Defaults to True. Only used if labels is not None.

True

Returns:

Type Description
ZambaVideoClassificationLightningModule

Instantiated model

Source code in zamba/models/model_manager.py
def instantiate_model(
    checkpoint: Union[os.PathLike, str],
    weight_download_region: RegionEnum,
    scheduler_config: Optional[SchedulerConfig],
    model_cache_dir: Optional[os.PathLike],
    labels: Optional[pd.DataFrame],
    from_scratch: bool = False,
    model_name: Optional[ModelEnum] = None,
    predict_all_zamba_species: bool = True,
) -> ZambaVideoClassificationLightningModule:
    """Instantiates the model from a checkpoint and detects whether the model head should be replaced.
    The model head is replaced if labels contain species that the model does not predict, or if
    predict_all_zamba_species=False.

    Supports model instantiation for the following cases:
    - train from scratch (from_scratch=True)
    - finetune with new species (from_scratch=False, labels contains different species than model)
    - finetune with a subset of zamba species and output only the species in the labels file (predict_all_zamba_species=False)
    - finetune with a subset of zamba species but output all zamba species (predict_all_zamba_species=True)
    - predict using pretrained model (labels=None)

    Args:
        checkpoint (path or str): Either the path to a checkpoint on disk or the name of a
            checkpoint file in the S3 bucket, i.e., one that is discoverable by `download_weights`.
            If this path does not exist, `model_cache_dir / checkpoint` is tried before downloading.
        weight_download_region (RegionEnum): Server region for downloading weights.
        scheduler_config (SchedulerConfig, optional): SchedulerConfig to use for training or finetuning.
            None or "default" keeps the scheduler stored in the model hparams.
            Only used if labels is not None.
        model_cache_dir (path, optional): Directory in which to store pretrained model weights.
        labels (pd.DataFrame, optional): Dataframe where filepath is the index and columns are one hot
            encoded species. When resuming training with a subset of the model's species, this
            dataframe is modified in place: missing ``species_*`` columns are added with value 0 and
            all columns are sorted.
        from_scratch (bool): Whether to instantiate the model with base weights. This means starting
            from the imagenet weights for image based models and the Kinetics weights for video models.
            Defaults to False. Only used if labels is not None.
        model_name (ModelEnum, optional): Model name used to look up default hparams used for that model.
            Only relevant if training from scratch.
        predict_all_zamba_species (bool): Whether the species outputted by the model should be all zamba species.
            If you want the model classes to only be the species in your labels file, set to False.
            Defaults to True. Only used if labels is not None.

    Returns:
        ZambaVideoClassificationLightningModule: Instantiated model
    """
    if from_scratch:
        # get hparams from official model
        with (MODELS_DIRECTORY / f"{model_name}/hparams.yaml").open() as f:
            hparams = yaml.safe_load(f)

    else:
        # accept plain str paths as well as path-like objects
        checkpoint = Path(checkpoint)

        if not checkpoint.exists():
            cached_checkpoint = (
                Path(model_cache_dir) / checkpoint if model_cache_dir is not None else None
            )
            if cached_checkpoint is not None and cached_checkpoint.exists():
                # reuse weights downloaded on a previous run instead of loading a missing path
                checkpoint = cached_checkpoint
            else:
                # neither a local nor a cached checkpoint exists
                logger.info("Downloading weights for model.")
                checkpoint = download_weights(
                    filename=str(checkpoint),
                    weight_region=weight_download_region,
                    destination_dir=model_cache_dir,
                )

        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        hparams = torch.load(checkpoint, map_location=torch.device("cpu"))["hyper_parameters"]

    model_class = available_models[hparams["model_class"]]

    logger.info(f"Instantiating model: {model_class.__name__}")

    if labels is None:
        # predict; load from checkpoint uses associated hparams
        logger.info("Loading from checkpoint.")
        return model_class.load_from_checkpoint(checkpoint_path=checkpoint)

    # get species from labels file
    species = labels.filter(regex=r"^species_").columns.tolist()
    species = [s.split("species_", 1)[1] for s in species]

    # check if species in label file are a subset of pretrained model species
    is_subset = set(species).issubset(set(hparams["species"]))

    # None or "default" keeps the scheduler stored in the hparams; otherwise override it.
    # Applied identically in every training branch below.
    if scheduler_config is not None and scheduler_config != "default":
        hparams.update(scheduler_config.dict())

    # train from scratch
    if from_scratch:
        logger.info("Training from scratch.")
        hparams.update({"species": species})
        model = model_class(**hparams)

    # replace the head
    elif not predict_all_zamba_species or not is_subset:

        if not predict_all_zamba_species:
            logger.info(
                "Limiting only to species in labels file. Replacing model head and finetuning."
            )
        else:
            logger.info(
                "Provided species do not fully overlap with Zamba species. Replacing model head and finetuning."
            )

        hparams.update({"species": species})
        model = model_class(finetune_from=checkpoint, **hparams)

    # resume training (predict_all_zamba_species and is_subset are both True here);
    # add additional species columns to labels file if needed
    else:
        logger.info(
            "Provided species fully overlap with Zamba species. Resuming training from latest checkpoint."
        )
        model = model_class.load_from_checkpoint(checkpoint_path=checkpoint, **hparams)

        # add in remaining columns for species that are not present
        for c in set(hparams["species"]).difference(set(species)):
            # labels are still OHE at this point
            labels[f"species_{c}"] = 0

        # sort columns so columns on dataloader are the same as columns on model
        labels.sort_index(axis=1, inplace=True)

    logger.info(f"Using learning rate scheduler: {model.hparams['scheduler']}")
    logger.info(f"Using scheduler params: {model.hparams['scheduler_params']}")

    return model

predict_model(predict_config: PredictConfig, video_loader_config: VideoLoaderConfig = None)

Predicts from a model and writes out predictions to a csv.

Parameters:

Name Type Description Default
predict_config PredictConfig

Pydantic config for performing inference.

required
video_loader_config VideoLoaderConfig

Pydantic config for preprocessing videos. If None, will use default for model specified in PredictConfig.

None
Source code in zamba/models/model_manager.py
def predict_model(
    predict_config: PredictConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Predicts from a model and writes out predictions to a csv.

    Args:
        predict_config (PredictConfig): Pydantic config for performing inference.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in PredictConfig.

    Returns:
        Predictions indexed by the dataset's original indices. The result is a DataFrame of 0/1
        values if ``proba_threshold`` is set. It is a Series of the most likely species per video if
        ``output_class_names`` is set. Otherwise it is a DataFrame of probabilities rounded to 5
        decimal places.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            predict_config=predict_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=predict_config.checkpoint,
        weight_download_region=predict_config.weight_download_region,
        model_cache_dir=predict_config.model_cache_dir,
        scheduler_config=None,
        labels=None,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        predict_metadata=predict_config.filepaths,
        batch_size=predict_config.batch_size,
        num_workers=predict_config.num_workers,
    )

    validate_species(model, data_module)

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    trainer = pl.Trainer(
        gpus=predict_config.gpus, logger=False, fast_dev_run=predict_config.dry_run
    )

    configuration = {
        "model_class": model.model_class,
        "species": model.species,
        "predict_config": json.loads(predict_config.json(exclude={"filepaths"})),
        "inference_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    save_outputs = predict_config.save is not False

    if save_outputs:
        # make sure the output directory exists before writing (train_model does the same)
        predict_config.save_dir.mkdir(parents=True, exist_ok=True)

        config_path = predict_config.save_dir / "predict_configuration.yaml"
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    dataloader = data_module.predict_dataloader()
    logger.info("Starting prediction...")
    probas = trainer.predict(model=model, dataloaders=dataloader)

    df = pd.DataFrame(
        np.vstack(probas), columns=model.species, index=dataloader.dataset.original_indices
    )

    # change output format if specified
    if predict_config.proba_threshold is not None:
        df = (df > predict_config.proba_threshold).astype(int)

    elif predict_config.output_class_names:
        df = df.idxmax(axis=1)

    else:  # round to a useful number of places
        df = df.round(5)

    if save_outputs:
        preds_path = predict_config.save_dir / "zamba_predictions.csv"
        logger.info(f"Saving out predictions to {preds_path}.")
        with preds_path.open("w") as fp:
            df.to_csv(fp, index=True)

    return df

train_model(train_config: TrainConfig, video_loader_config: Optional[zamba.data.video.VideoLoaderConfig] = None)

Trains a model.

Parameters:

Name Type Description Default
train_config TrainConfig

Pydantic config for training.

required
video_loader_config VideoLoaderConfig

Pydantic config for preprocessing videos. If None, will use default for model specified in TrainConfig.

None
Source code in zamba/models/model_manager.py
def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in TrainConfig.

    Returns:
        pl.Trainer: The trainer after fitting. Unless ``dry_run`` is set, test/val metrics are
        also written as JSON next to the saved configuration.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    early_stopping_config = train_config.early_stopping_config

    # Checkpoint on the early-stopping metric when early stopping is configured. Otherwise
    # use ModelCheckpoint's own defaults instead of dereferencing a None config.
    if early_stopping_config is not None:
        checkpoint_kwargs = {
            "monitor": early_stopping_config.monitor,
            "mode": early_stopping_config.mode,
        }
    else:
        checkpoint_kwargs = {}

    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        **checkpoint_kwargs,
    )

    callbacks = [model_checkpoint]

    if early_stopping_config is not None:
        callbacks.append(EarlyStopping(**early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if data_module.multiprocessing_context is not None else None,
        plugins=DDPPlugin(find_unused_parameters=False)
        if data_module.multiprocessing_context is not None
        else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    save_path = Path(logging_and_save_dir)

    if not train_config.dry_run:
        config_path = save_path / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        # build each dataloader once instead of once for the check and again for the run
        test_dataloader = trainer.datamodule.test_dataloader()
        if test_dataloader is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=test_dataloader)[0]
            with (save_path / "test_metrics.json").open("w") as fp:
                json.dump(test_metrics, fp, indent=2)

        val_dataloader = trainer.datamodule.val_dataloader()
        if val_dataloader is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=val_dataloader)[0]
            with (save_path / "val_metrics.json").open("w") as fp:
                json.dump(val_metrics, fp, indent=2)

    return trainer

validate_species(model: ZambaVideoClassificationLightningModule, data_module: ZambaDataModule)

Source code in zamba/models/model_manager.py
def validate_species(model: ZambaVideoClassificationLightningModule, data_module: ZambaDataModule):
    """Ensure every available dataloader uses the same species as the model.

    Builds the train, val and test dataloaders from ``data_module``. Each one that is not None
    must have ``dataset.species`` equal (compared with ``!=``) to ``model.species``.

    Raises:
        ValueError: If any dataloader's species differ. The message lists the species of each
            mismatched dataset followed by the model's species.
    """
    split_loaders = {
        "Train": data_module.train_dataloader(),
        "Val": data_module.val_dataloader(),
        "Test": data_module.test_dataloader(),
    }

    mismatches = [
        f"""{split} dataset includes:\n{", ".join(loader.dataset.species)}\n"""
        for split, loader in split_loaders.items()
        if loader is not None and loader.dataset.species != model.species
    ]

    if not mismatches:
        return

    mismatches.append(f"""Model predicts:\n{", ".join(model.species)}""")
    conflict_msg = "\n\n".join(mismatches)
    raise ValueError(f"""Dataloader species and model species do not match.\n\n{conflict_msg}""")