base_models
You can define your model by implementing the abstract methods BaseModel.train_data_sources(), BaseModel.predict(), and BaseModel.loss(). See Defining a model for instructions. Currently, there are three base models that you can inherit from: BaseModel, BaseSPSS, and BaseVAE.
BaseModel
-
class morgana.base_models.BaseModel
Bases: torch.nn.modules.module.Module
Creates an abstract model class with utility functions.
Any additional kwargs specified in __init__ should be passed to the command line argument model_kwargs.
-
normalisers
Normalisers specified within the morgana.data._DataSource instances in self.train_data_sources.
- Type
-
mode
Stage of training, set in morgana.experiment_builder.ExperimentBuilder.*_epoch, for use with self.metrics.
- Type
{‘’, ‘train’, ‘valid’, ‘test’}
-
metrics
Handler for tracking metrics in an online fashion (over multiple batches).
-
step
Step in training, calculated using epoch number, batch number, and number of batches per epoch. This is updated automatically by morgana.experiment_builder.ExperimentBuilder. Useful for logging to self.tensorboard.
- Type
-
tensorboard
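The exact formula that ExperimentBuilder uses for step is not given on this page. As a minimal sketch, assuming zero-based epoch and batch indices, a single increasing counter can be derived from the three quantities named in the step description. Note that global_step is a hypothetical helper and not part of morgana.

```python
def global_step(epoch, batch_idx, n_batches_per_epoch):
    """Hypothetical helper: maps (epoch, batch) to a single increasing counter.

    Assumes zero-based `epoch` and `batch_idx`; morgana may count differently.
    """
    if not 0 <= batch_idx < n_batches_per_epoch:
        raise ValueError('batch_idx must lie in [0, n_batches_per_epoch)')
    return epoch * n_batches_per_epoch + batch_idx


# Third batch (index 2) of the second epoch (index 1), with 100 batches per epoch.
step = global_step(epoch=1, batch_idx=2, n_batches_per_epoch=100)
```

A counter of this kind gives every logged value a unique position on the x-axis, which is why it is convenient when writing to a TensorBoard log.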
-
train_data_sources(self)
Specifies the data that will be loaded and used in training.
Only specifies what data will be loaded, but not where from.
- Returns
The data sources used by morgana.experiment_builder.ExperimentBuilder for the training data; can be any data structure containing morgana.data._DataSource instances.
- Return type
features
-
valid_data_sources(self)
Specifies the data that will be loaded and used in validation.
Only specifies what data will be loaded, but not where from.
- Returns
The data sources used by morgana.experiment_builder.ExperimentBuilder for the validation data; can be any data structure containing morgana.data._DataSource instances.
- Return type
features
-
test_data_sources(self)
Specifies the data that will be loaded and used in testing.
Only specifies what data will be loaded, but not where from.
- Returns
The data sources used by morgana.experiment_builder.ExperimentBuilder for the testing data; can be any data structure containing morgana.data._DataSource instances.
- Return type
features
-
forward(self, features)
Defines the computation graph, including calculation of loss.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
- Returns
loss (float) – Loss of the model, as defined by self.loss.
output_features – Predictions made by the model; can be any data structure containing torch.Tensor instances.
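The relationship between forward, predict, and loss can be sketched with a stand-in class. This is not morgana's implementation: ToyModel is hypothetical, does not inherit from BaseModel, and uses plain Python lists in place of tensors so that it runs without dependencies. It only illustrates forward returning the loss together with the predictions, under the assumption that forward simply composes the two user-defined methods.

```python
class ToyModel:
    """Hypothetical stand-in mirroring the documented forward/predict/loss contract."""

    def __init__(self, scale=2.0):
        self.scale = scale

    def predict(self, features):
        # Produce predictions from the input features only.
        return {'y': [self.scale * x for x in features['x']]}

    def loss(self, features, output_features):
        # Pair the ground truth 'y' with the predicted 'y' (mean squared error).
        pairs = zip(features['y'], output_features['y'])
        errors = [(target - pred) ** 2 for target, pred in pairs]
        return sum(errors) / len(errors)

    def forward(self, features):
        # Assumed composition: predict first, then score the predictions.
        output_features = self.predict(features)
        return self.loss(features, output_features), output_features


features = {'x': [1.0, 2.0, 3.0], 'y': [2.0, 4.0, 7.0]}
loss, output_features = ToyModel().forward(features)
```

Keeping predict free of any loss computation means the same method can be reused at test time, when ground truth targets may not be available.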
-
predict(self, features)
Defines the computation graph.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
- Returns
Predictions made by the model; can be any data structure containing torch.Tensor instances.
- Return type
output_features
-
loss(self, features, output_features)
Defines which predictions should be scored against which ground truth features.
Typically this method should use _loss() to calculate the sequence loss for the target-prediction pairs.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
output_features (torch.Tensor or list[torch.Tensor] or dict[str, torch.Tensor]) – Predictions output by user-defined predict().
- Returns
Overall loss between (user-defined) pairs of values in features and output_features.
- Return type
-
_loss(self, targets, predictions, seq_lens=None, loss_weights=None)
Defines the sequence loss for multiple target-prediction pairs.
If targets and predictions are iterables they must be in the same order, i.e. when zipped, corresponding elements will be used as a target-prediction pair for calculating the loss.
The loss value between two frames of the target and prediction is given by loss_fn(). Currently this must be the same for all target-prediction pairs.
- Parameters
targets (list[torch.Tensor] or torch.Tensor, shape (batch_size, seq_len, feat_dim)) – Ground truth tensor(s).
predictions (list[torch.Tensor] or torch.Tensor, shape (batch_size, seq_len, feat_dim)) – Prediction tensor(s).
seq_lens (None or list[torch.Tensor] or torch.Tensor, shape (batch_size,)) – Sequence length features. If one tensor is given it will be used for all target-prediction pairs; otherwise the length of the list given must match the length of targets and predictions.
loss_weights (None or list[float], shape (num_pairs)) – The weight for each target-prediction pair’s loss. If None then returns the average of all pairs’ losses.
- Returns
Overall (average or weighted) loss.
- Return type
- Raises
ValueError – If targets, predictions, seq_lens, or loss_weights are lists with non-matching lengths.
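The behaviour described for _loss can be sketched in plain Python, with nested lists standing in for tensors of shape (batch_size, seq_len, feat_dim). This is an illustration under assumptions, not morgana's code: the names frame_loss, pair_loss, and sequence_loss are hypothetical, squared error is an arbitrary choice of frame-wise loss, and the reduction (sum over features, mean over valid frames, mean over the batch) is one plausible reading of the description.

```python
def frame_loss(target, prediction):
    # Frame-wise squared error with the same shape as the inputs: (seq_len, feat_dim).
    return [[(t - p) ** 2 for t, p in zip(t_frame, p_frame)]
            for t_frame, p_frame in zip(target, prediction)]


def pair_loss(targets, predictions, seq_lens):
    # Mean over the batch of each sequence's per-frame loss, ignoring padded frames.
    total = 0.0
    for target, prediction, seq_len in zip(targets, predictions, seq_lens):
        frames = frame_loss(target, prediction)[:seq_len]  # drop padding
        total += sum(sum(frame) for frame in frames) / seq_len
    return total / len(targets)


def sequence_loss(targets, predictions, seq_lens, loss_weights=None):
    # `targets` and `predictions` are lists with one batch per target-prediction pair.
    if len(targets) != len(predictions):
        raise ValueError('targets and predictions must have the same length')
    if loss_weights is not None and len(loss_weights) != len(targets):
        raise ValueError('loss_weights must contain one weight per pair')
    losses = [pair_loss(t, p, seq_lens) for t, p in zip(targets, predictions)]
    if loss_weights is None:
        return sum(losses) / len(losses)  # plain average over pairs
    return sum(w * l for w, l in zip(loss_weights, losses))  # weighted sum


# batch_size=2, seq_len=2, feat_dim=1; the second sequence has one padded frame.
target_a = [[[1.0], [2.0]], [[3.0], [0.0]]]
pred_a = [[[1.0], [4.0]], [[2.0], [9.0]]]
seq_lens = [2, 1]

average = sequence_loss([target_a, target_a], [pred_a, target_a], seq_lens)
weighted = sequence_loss([target_a, target_a], [pred_a, target_a], seq_lens,
                         loss_weights=[1.0, 0.5])
```

The large error on the padded frame of the second sequence (0.0 against 9.0) never reaches the result, which is the purpose of passing seq_lens.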
-
loss_fn(self, target, prediction)
Defines the frame-wise loss calculation between ground truth and predictions.
- Parameters
target (torch.Tensor, shape (batch_size, seq_len, feat_dim)) – Ground truth feature.
prediction (torch.Tensor, shape (batch_size, seq_len, feat_dim)) – Predicted feature.
- Returns
Loss between target and prediction.
- Return type
torch.Tensor, shape (batch_size, seq_len, feat_dim)
-
load_parameters(self, checkpoint_path, strict=True, device=None)
Loads a state_dict from a .pt file.
- Parameters
checkpoint_path (str) – The file path of the .pt file containing the state_dict to be loaded.
strict (bool) – Whether to strictly enforce that the keys in the loaded state_dict match this model’s structure.
device (str or torch.device or dict or callable) – Specifies how to remap storage locations, passed to torch.load().
- Returns
state_dict – Parameters and persistent buffers that define the model.
- Return type
-
analysis_for_train_batch(self, features, output_features, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after training batches for some epochs.
Can be used to save output or generate visualisations.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
output_features (torch.Tensor or list[torch.Tensor] or dict[str, torch.Tensor]) – Predictions output by user-defined self.predict.
out_dir (str) – The directory used to save output (changes for each epoch).
kwargs (dict) – Additional keyword arguments used for generating output.
-
analysis_for_valid_batch(self, features, output_features, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after validation batches for some epochs.
Can be used to save output or generate visualisations.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
output_features (torch.Tensor or list[torch.Tensor] or dict[str, torch.Tensor]) – Predictions output by user-defined self.predict.
out_dir (str) – The directory used to save output (changes for each epoch).
kwargs (dict) – Additional keyword arguments used for generating output.
-
analysis_for_test_batch(self, features, output_features, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after each testing batch.
Can be used to save output or generate visualisations.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
output_features (torch.Tensor or list[torch.Tensor] or dict[str, torch.Tensor]) – Predictions output by user-defined predict().
out_dir (str) – The directory used to save output (changes for each epoch).
kwargs (dict) – Additional keyword arguments used for generating output.
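As an illustration of how a batch hook might save output, the sketch below writes each predicted feature to a JSON file under out_dir. ToyModel is a hypothetical stand-in that does not inherit from a morgana base model, the features are plain Python lists rather than tensors, and the one-file-per-feature layout is an arbitrary choice made for the example.

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical stand-in; a real model would inherit from a morgana base model."""

    def analysis_for_test_batch(self, features, output_features, out_dir, **kwargs):
        # Save each prediction under out_dir, one JSON file per output feature.
        os.makedirs(out_dir, exist_ok=True)
        for name, values in output_features.items():
            with open(os.path.join(out_dir, f'{name}.json'), 'w') as f:
                json.dump(values, f)


out_dir = tempfile.mkdtemp()
ToyModel().analysis_for_test_batch(
    features={'x': [1.0, 2.0]},
    output_features={'y': [2.0, 4.0]},
    out_dir=out_dir,
)

# Read the file back to confirm what the hook wrote.
with open(os.path.join(out_dir, 'y.json')) as f:
    saved = json.load(f)
```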
-
analysis_for_train_epoch(self, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after some training epochs.
Can be used to save output or generate visualisations.
-
analysis_for_valid_epoch(self, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after some validation epochs.
Can be used to save output or generate visualisations.
-
analysis_for_test_epoch(self, out_dir, **kwargs)
Hook used by morgana.experiment_builder.ExperimentBuilder after each testing epoch.
Can be used to save output or generate visualisations.
-
BaseSPSS
BaseVAE
-
class morgana.base_models.BaseVAE(z_dim=16, kld_weight=1.0)
Bases: morgana.base_models.BaseSPSS
Creates an abstract VAE model, where the decoder corresponds to an SPSS model.
- Parameters
-
encode(self, features)
VAE encoder.
- Parameters
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
- Returns
mean (torch.Tensor, shape (batch_size, z_dim))
log_variance (torch.Tensor, shape (batch_size, z_dim))
-
sample(self, mean, log_variance)
Takes one sample from the approximate posterior (an isotropic Gaussian).
- Parameters
mean (torch.Tensor, shape (batch_size, z_dim))
log_variance (torch.Tensor, shape (batch_size, z_dim))
- Returns
latent_sample
- Return type
torch.Tensor, shape (batch_size, z_dim)
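Sampling from a Gaussian parameterised by a mean and a log variance is commonly done with the reparameterisation trick. The sketch below shows that technique in plain Python, with nested lists of shape (batch_size, z_dim) standing in for tensors. It is an assumption that morgana samples this way. The free function sample and its rng argument are hypothetical and exist only for this example.

```python
import math
import random


def sample(mean, log_variance, rng=random):
    """Draws z = mean + exp(0.5 * log_variance) * eps, with eps ~ N(0, 1) per dimension."""
    return [
        [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
         for m, lv in zip(mean_row, log_variance_row)]
        for mean_row, log_variance_row in zip(mean, log_variance)
    ]


# batch_size=2, z_dim=3. A very negative log variance gives a near-zero standard
# deviation, so the sample collapses onto the mean.
mean = [[0.0, 1.0, -1.0], [2.0, 2.0, 2.0]]
log_variance = [[-100.0] * 3, [-100.0] * 3]
latent_sample = sample(mean, log_variance, rng=random.Random(0))
```

Writing the sample as a deterministic function of the mean, the log variance, and independent noise is what allows gradients to reach the encoder when tensors with automatic differentiation are used.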
-
decode(self, latent, features)
VAE decoder.
- Parameters
latent (torch.Tensor, shape (batch_size, z_dim))
features (dict[str, torch.Tensor]) – The ground truth features produced by self.*_data_sources.
- Returns
output_features – Reconstructions from the model; can be any data structure containing torch.Tensor instances.
- Return type
torch.Tensor or list[torch.Tensor] or dict[str, torch.Tensor]
-
forward(self, features)
Encodes the input features, samples from the encoding, reconstructs the input, and calculates the loss.
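This page does not spell out how the VAE loss is assembled, but the kld_weight constructor argument suggests the usual objective: a reconstruction loss plus a weighted KL divergence between the approximate posterior and a standard normal prior. The sketch below shows that standard formulation in plain Python, as an assumption rather than morgana's code. The functions kl_divergence and vae_loss are hypothetical names.

```python
import math


def kl_divergence(mean, log_variance):
    """KL(N(mean, exp(log_variance)) || N(0, I)) for a single latent vector."""
    return -0.5 * sum(1.0 + lv - m ** 2 - math.exp(lv)
                      for m, lv in zip(mean, log_variance))


def vae_loss(reconstruction_loss, mean, log_variance, kld_weight=1.0):
    # Batch-averaged KL term, scaled by kld_weight and added to the reconstruction loss.
    kld = sum(kl_divergence(m, lv) for m, lv in zip(mean, log_variance)) / len(mean)
    return reconstruction_loss + kld_weight * kld


# batch_size=2, z_dim=1: unit variance and a mean of 1.0 give a KL of 0.5 per item.
total = vae_loss(2.0, mean=[[1.0], [1.0]], log_variance=[[0.0], [0.0]], kld_weight=0.1)
```

A small kld_weight lets the model prioritise reconstruction, while a larger one pulls the posterior towards the prior.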
-
predict(self, features)
Runs the model in testing mode: the encoder is not used, so the latent must be provided as an input.