zamba.models.model_manager
Classes
ModelManager
Mediates loading, configuration, and logic of model calls.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ModelConfig | Instantiated ModelConfig. | required |
__init__(self, config: ModelConfig) (special method)
Source code in zamba/models/model_manager.py
def __init__(self, config: ModelConfig):
    self.config = config
from_yaml(config) (classmethod)
Source code in zamba/models/model_manager.py
@classmethod
def from_yaml(cls, config):
    if not isinstance(config, ModelConfig):
        config = ModelConfig.parse_file(config)
    return cls(config)
predict(self)
Source code in zamba/models/model_manager.py
def predict(self):
    predict_model(
        predict_config=self.config.predict_config,
        video_loader_config=self.config.video_loader_config,
    )
train(self)
Source code in zamba/models/model_manager.py
def train(self):
    train_model(
        train_config=self.config.train_config,
        video_loader_config=self.config.video_loader_config,
    )
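For orientation, here is a minimal usage sketch of the class. The import paths and the example_config.yaml filename are illustrative assumptions; the config file must parse into a ModelConfig whose train_config or predict_config is populated.

```python
from zamba.models.config import ModelConfig  # assumed import location
from zamba.models.model_manager import ModelManager

# from_yaml parses the file into a ModelConfig (via ModelConfig.parse_file)
# if it is not already one, then wraps it in a manager.
manager = ModelManager.from_yaml("example_config.yaml")  # placeholder filename

manager.train()      # dispatches to train_model using config.train_config
# manager.predict()  # dispatches to predict_model using config.predict_config
```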
Functions
instantiate_model(checkpoint: Union[os.PathLike, str], weight_download_region: RegionEnum, scheduler_config: Optional[zamba.models.config.SchedulerConfig], model_cache_dir: Optional[os.PathLike], labels: Optional[pandas.core.frame.DataFrame], from_scratch: bool = False, model_name: Optional[zamba.models.config.ModelEnum] = None, predict_all_zamba_species: bool = True) -> ZambaVideoClassificationLightningModule
Instantiates the model from a checkpoint and detects whether the model head should be replaced. The model head is replaced if labels contain species that are not on the model or predict_all_zamba_species=False.
Supports model instantiation for the following cases:

- train from scratch (from_scratch=True)
- finetune with new species (from_scratch=False, labels contains different species than model)
- finetune with a subset of zamba species and output only the species in the labels file (predict_all_zamba_species=False)
- finetune with a subset of zamba species but output all zamba species (predict_all_zamba_species=True)
- predict using pretrained model (labels=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint | path or str | Either the path to a checkpoint on disk or the name of a checkpoint file in the S3 bucket, i.e., one that is discoverable by `download_weights`. | required |
| weight_download_region | RegionEnum | Server region for downloading weights. | required |
| scheduler_config | SchedulerConfig | SchedulerConfig to use for training or finetuning. Only used if labels is not None. | required |
| model_cache_dir | path | Directory in which to store pretrained model weights. | required |
| labels | pd.DataFrame | DataFrame where filepath is the index and columns are one-hot encoded species. | required |
| from_scratch | bool | Whether to instantiate the model with base weights, i.e., starting from the ImageNet weights for image-based models and the Kinetics weights for video models. Only used if labels is not None. | False |
| model_name | ModelEnum | Model name used to look up the default hparams for that model. Only relevant if training from scratch. | None |
| predict_all_zamba_species | bool | Whether the model should output all zamba species. If you want the model classes to be only the species in your labels file, set to False. Only used if labels is not None. | True |
Returns:

| Type | Description |
|---|---|
| ZambaVideoClassificationLightningModule | Instantiated model |
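As a rough sketch of the finetuning cases described above, the example below builds a small one-hot labels DataFrame and restricts the model head to those species. The checkpoint filename, species names, and cache directory are placeholders, and passing scheduler_config="default" keeps the scheduler stored with the pretrained checkpoint, as in the source below.

```python
from pathlib import Path

import pandas as pd

from zamba.models.model_manager import instantiate_model

# Hypothetical one-hot labels: filepath index, one species_<name> column per class.
labels = pd.DataFrame(
    {"species_elephant": [1, 0], "species_duiker": [0, 1]},
    index=["videos/a.mp4", "videos/b.mp4"],
)

model = instantiate_model(
    checkpoint=Path("time_distributed.ckpt"),  # placeholder local checkpoint path
    weight_download_region="us",  # assumed RegionEnum value; only used if weights must be downloaded
    scheduler_config="default",   # reuse the scheduler saved with the pretrained checkpoint
    model_cache_dir=Path("model_cache"),  # placeholder cache directory
    labels=labels,
    predict_all_zamba_species=False,  # replace the head so only elephant and duiker are predicted
)
```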
Source code in zamba/models/model_manager.py
def instantiate_model(
    checkpoint: Union[os.PathLike, str],
    weight_download_region: RegionEnum,
    scheduler_config: Optional[SchedulerConfig],
    model_cache_dir: Optional[os.PathLike],
    labels: Optional[pd.DataFrame],
    from_scratch: bool = False,
    model_name: Optional[ModelEnum] = None,
    predict_all_zamba_species: bool = True,
) -> ZambaVideoClassificationLightningModule:
    """Instantiates the model from a checkpoint and detects whether the model head should be replaced.
    The model head is replaced if labels contain species that are not on the model or predict_all_zamba_species=False.

    Supports model instantiation for the following cases:
    - train from scratch (from_scratch=True)
    - finetune with new species (from_scratch=False, labels contains different species than model)
    - finetune with a subset of zamba species and output only the species in the labels file (predict_all_zamba_species=False)
    - finetune with a subset of zamba species but output all zamba species (predict_all_zamba_species=True)
    - predict using pretrained model (labels=None)

    Args:
        checkpoint (path or str): Either the path to a checkpoint on disk or the name of a
            checkpoint file in the S3 bucket, i.e., one that is discoverable by `download_weights`.
        weight_download_region (RegionEnum): Server region for downloading weights.
        scheduler_config (SchedulerConfig, optional): SchedulerConfig to use for training or finetuning.
            Only used if labels is not None.
        model_cache_dir (path, optional): Directory in which to store pretrained model weights.
        labels (pd.DataFrame, optional): Dataframe where filepath is the index and columns are one hot encoded species.
        from_scratch (bool): Whether to instantiate the model with base weights. This means starting
            from the imagenet weights for image based models and the Kinetics weights for video models.
            Defaults to False. Only used if labels is not None.
        model_name (ModelEnum, optional): Model name used to look up default hparams used for that model.
            Only relevant if training from scratch.
        predict_all_zamba_species (bool): Whether the species outputted by the model should be all zamba species.
            If you want the model classes to only be the species in your labels file, set to False.
            Defaults to True. Only used if labels is not None.

    Returns:
        ZambaVideoClassificationLightningModule: Instantiated model
    """
    if from_scratch:
        # get hparams from official model
        with (MODELS_DIRECTORY / f"{model_name}/hparams.yaml").open() as f:
            hparams = yaml.safe_load(f)
    else:
        # download if neither local checkpoint nor cached checkpoint exist
        if not checkpoint.exists() and not (model_cache_dir / checkpoint).exists():
            logger.info("Downloading weights for model.")
            checkpoint = download_weights(
                filename=str(checkpoint),
                weight_region=weight_download_region,
                destination_dir=model_cache_dir,
            )
        hparams = torch.load(checkpoint, map_location=torch.device("cpu"))["hyper_parameters"]

    model_class = available_models[hparams["model_class"]]
    logger.info(f"Instantiating model: {model_class.__name__}")

    if labels is None:
        # predict; load from checkpoint uses associated hparams
        logger.info("Loading from checkpoint.")
        return model_class.load_from_checkpoint(checkpoint_path=checkpoint)

    # get species from labels file
    species = labels.filter(regex=r"^species_").columns.tolist()
    species = [s.split("species_", 1)[1] for s in species]

    # check if species in label file are a subset of pretrained model species
    is_subset = set(species).issubset(set(hparams["species"]))

    # train from scratch
    if from_scratch:
        logger.info("Training from scratch.")
        # default would use scheduler used for pretrained model
        if scheduler_config != "default":
            hparams.update(scheduler_config.dict())
        hparams.update({"species": species})
        model = model_class(**hparams)

    # replace the head
    elif not predict_all_zamba_species or not is_subset:
        if not predict_all_zamba_species:
            logger.info(
                "Limiting only to species in labels file. Replacing model head and finetuning."
            )
        else:
            logger.info(
                "Provided species do not fully overlap with Zamba species. Replacing model head and finetuning."
            )
        # update in case we want to finetune with different scheduler
        if scheduler_config != "default":
            hparams.update(scheduler_config.dict())
        hparams.update({"species": species})
        model = model_class(finetune_from=checkpoint, **hparams)

    # resume training; add additional species columns to labels file if needed
    elif is_subset:
        logger.info(
            "Provided species fully overlap with Zamba species. Resuming training from latest checkpoint."
        )
        # update in case we want to resume with different scheduler
        if scheduler_config != "default":
            hparams.update(scheduler_config.dict())
        model = model_class.load_from_checkpoint(checkpoint_path=checkpoint, **hparams)

        # add in remaining columns for species that are not present
        for c in set(hparams["species"]).difference(set(species)):
            # labels are still OHE at this point
            labels[f"species_{c}"] = 0

        # sort columns so columns on dataloader are the same as columns on model
        labels.sort_index(axis=1, inplace=True)

    logger.info(f"Using learning rate scheduler: {model.hparams['scheduler']}")
    logger.info(f"Using scheduler params: {model.hparams['scheduler_params']}")

    return model
predict_model(predict_config: PredictConfig, video_loader_config: VideoLoaderConfig = None)
Predicts from a model and writes out predictions to a csv.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predict_config | PredictConfig | Pydantic config for performing inference. | required |
| video_loader_config | VideoLoaderConfig | Pydantic config for preprocessing videos. If None, will use default for model specified in PredictConfig. | None |
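A hedged sketch of calling this function directly. The PredictConfig field names (checkpoint, filepaths, save_dir) appear in the source below, but the values and the assumption that filepaths accepts a CSV of video paths are illustrative only.

```python
from zamba.models.config import PredictConfig  # assumed import location
from zamba.models.model_manager import predict_model

predict_config = PredictConfig(
    checkpoint="time_distributed.ckpt",  # placeholder: local path or downloadable checkpoint name
    filepaths="videos_to_predict.csv",   # assumed: CSV listing the videos to score
    save_dir="predictions",              # configuration and predictions are written here
)

# With video_loader_config=None, the default loader for the model is used.
df = predict_model(predict_config=predict_config, video_loader_config=None)
print(df.head())  # one row per video, one probability column per species (rounded by default)
```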
Source code in zamba/models/model_manager.py
def predict_model(
    predict_config: PredictConfig,
    video_loader_config: VideoLoaderConfig = None,
):
    """Predicts from a model and writes out predictions to a csv.

    Args:
        predict_config (PredictConfig): Pydantic config for performing inference.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in PredictConfig.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            predict_config=predict_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=predict_config.checkpoint,
        weight_download_region=predict_config.weight_download_region,
        model_cache_dir=predict_config.model_cache_dir,
        scheduler_config=None,
        labels=None,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        predict_metadata=predict_config.filepaths,
        batch_size=predict_config.batch_size,
        num_workers=predict_config.num_workers,
    )

    validate_species(model, data_module)

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    trainer = pl.Trainer(
        gpus=predict_config.gpus, logger=False, fast_dev_run=predict_config.dry_run
    )

    configuration = {
        "model_class": model.model_class,
        "species": model.species,
        "predict_config": json.loads(predict_config.json(exclude={"filepaths"})),
        "inference_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if predict_config.save is not False:
        config_path = predict_config.save_dir / "predict_configuration.yaml"
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    dataloader = data_module.predict_dataloader()
    logger.info("Starting prediction...")
    probas = trainer.predict(model=model, dataloaders=dataloader)

    df = pd.DataFrame(
        np.vstack(probas), columns=model.species, index=dataloader.dataset.original_indices
    )

    # change output format if specified
    if predict_config.proba_threshold is not None:
        df = (df > predict_config.proba_threshold).astype(int)
    elif predict_config.output_class_names:
        df = df.idxmax(axis=1)
    else:  # round to a useful number of places
        df = df.round(5)

    if predict_config.save is not False:
        preds_path = predict_config.save_dir / "zamba_predictions.csv"
        logger.info(f"Saving out predictions to {preds_path}.")
        with preds_path.open("w") as fp:
            df.to_csv(fp, index=True)

    return df
train_model(train_config: TrainConfig, video_loader_config: Optional[zamba.data.video.VideoLoaderConfig] = None)
Trains a model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| train_config | TrainConfig | Pydantic config for training. | required |
| video_loader_config | VideoLoaderConfig | Pydantic config for preprocessing videos. If None, will use default for model specified in TrainConfig. | None |
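A hedged sketch of a direct call, mirroring how ModelManager.train dispatches here. The TrainConfig field names (labels, model_name, save_dir, max_epochs) appear in the source below; the labels CSV layout and the "time_distributed" model name are assumptions.

```python
from zamba.models.config import TrainConfig  # assumed import location
from zamba.models.model_manager import train_model

train_config = TrainConfig(
    labels="labels.csv",            # assumed: labels file mapping video filepaths to species
    model_name="time_distributed",  # placeholder ModelEnum value to finetune from
    save_dir="training_runs",       # TensorBoard logs, checkpoints, and configs are written here
    max_epochs=5,
)

# With video_loader_config=None, the default loader for the chosen model is used.
trainer = train_model(train_config=train_config, video_loader_config=None)
```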
Source code in zamba/models/model_manager.py
def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in TrainConfig.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        monitor=train_config.early_stopping_config.monitor,
        mode=train_config.early_stopping_config.mode,
    )

    callbacks = [model_checkpoint]

    if train_config.early_stopping_config is not None:
        callbacks.append(EarlyStopping(**train_config.early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if data_module.multiprocessing_context is not None else None,
        plugins=DDPPlugin(find_unused_parameters=False)
        if data_module.multiprocessing_context is not None
        else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if not train_config.dry_run:
        config_path = Path(logging_and_save_dir) / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        if trainer.datamodule.test_dataloader() is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=trainer.datamodule.test_dataloader())[0]
            with (Path(logging_and_save_dir) / "test_metrics.json").open("w") as fp:
                json.dump(test_metrics, fp, indent=2)

        if trainer.datamodule.val_dataloader() is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=trainer.datamodule.val_dataloader())[0]
            with (Path(logging_and_save_dir) / "val_metrics.json").open("w") as fp:
                json.dump(val_metrics, fp, indent=2)

    return trainer
validate_species(model: ZambaVideoClassificationLightningModule, data_module: ZambaDataModule)

Checks that the species in each available dataloader (train, validation, and test) match the species the model predicts, raising a ValueError that lists both sets if any dataloader differs.
Source code in zamba/models/model_manager.py
def validate_species(model: ZambaVideoClassificationLightningModule, data_module: ZambaDataModule):
    conflicts = []
    for dataloader_name, dataloader in zip(
        ("Train", "Val", "Test"),
        (
            data_module.train_dataloader(),
            data_module.val_dataloader(),
            data_module.test_dataloader(),
        ),
    ):
        # record any split whose species differ from the species the model predicts
        if (dataloader is not None) and (dataloader.dataset.species != model.species):
            conflicts.append(
                f"""{dataloader_name} dataset includes:\n{", ".join(dataloader.dataset.species)}\n"""
            )

    if len(conflicts) > 0:
        conflicts.append(f"""Model predicts:\n{", ".join(model.species)}""")

        conflict_msg = "\n\n".join(conflicts)
        raise ValueError(
            f"""Dataloader species and model species do not match.\n\n{conflict_msg}"""
        )