experiment_builder

After defining a model (as described in base_models) you can train it (or generate from a checkpoint) using an experiment builder, which provides a command line interface to the train/validation/test loops.

If you want to define custom train/validation/test loops, create a subclass of ExperimentBuilder and override the methods you wish to change.
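
For example, a typical entry-point script looks something like the following sketch. MyModel is a hypothetical morgana.base_models.BaseModel subclass, and it is assumed that the dictionary returned by get_experiment_args() (documented below) includes experiment_name.

    import torch

    from morgana.experiment_builder import ExperimentBuilder

    from my_models import MyModel  # hypothetical BaseModel subclass


    def main():
        torch.random.manual_seed(1234567890)

        # Parse the command line options (experiment_name plus any overrides).
        args = ExperimentBuilder.get_experiment_args()

        # Build the experiment and run whichever of train/valid/test were requested.
        experiment = ExperimentBuilder(MyModel, **args)
        experiment.run_experiment()


    if __name__ == '__main__':
        main()

To use custom loops, pass your ExperimentBuilder subclass in place of ExperimentBuilder here.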

ExperimentBuilder

class morgana.experiment_builder.ExperimentBuilder(model_class, experiment_name, **kwargs)[source]

Bases: object

Interface for running training, validation, and generation. Works as glue for machine learning support code.

Parameters
  • model_class (morgana.base_models.BaseModel) – Model to be initialised by the experiment builder. Must contain implementations of all abstract methods.

  • experiment_name (str) – Name of the experiment; this is used as the directory to save all output (under the experiments_base directory). This is the only command line argument with no default value, and it is required.

  • kwargs (dict[str, *]) – Command line arguments. See add_args() for all options.

experiment_dir

Directory path to save all output to.

Type

str

model

Model instance.

Type

morgana.base_models.BaseModel

ema

Helper for updating a second model instance with an exponential moving average of the parameters.

Type

morgana.utils.ExponentialMovingAverage (if ema_decay is not 0.)

epoch

Current epoch of the model (starting from 1).

Type

int

device

Name of the device to place the model and parameters on.

Type

str or torch.device

_lr_schedule

Partially initialised learning rate schedule. Depending on the schedule, this will be used per epoch or batch.

Type

torch.optim.lr_scheduler._LRScheduler

logger

Python Logger. Copies stdout, stderr, and all tqdm output to separate files.

Type

logging.Logger

train_loader

Attribute is only present if self.train is True. Produces batches from the training data (on a given device) when used as an iterator.

Type

torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).

valid_loader

Attribute is only present if self.valid is True. Produces batches from the validation data (on a given device) when used as an iterator.

Type

torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).

test_loader

Attribute is only present if self.test is True. Produces batches from the testing data (on a given device) when used as an iterator.

Type

torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).

Notes

All arguments provided as command line arguments (including experiment_name) are saved as instance attributes.
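
A subclass method can therefore read these attributes directly. A minimal sketch, assuming only the documented ema_decay option and the logger attribute:

    from morgana.experiment_builder import ExperimentBuilder


    class MyExperimentBuilder(ExperimentBuilder):
        def log_initial_setup(self, **kwargs):
            super().log_initial_setup(**kwargs)
            # Every command line option is available as an instance attribute.
            self.logger.info('EMA decay = {}'.format(self.ema_decay))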

classmethod get_experiment_args()[source]

Creates a command line argument parser and returns the dictionary of arguments.

classmethod add_args(parser)[source]

Adds command line arguments to a parser; see usage.
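
A subclass can expose extra command line options by extending this method. A hedged sketch; the --my_option flag is purely illustrative and not part of morgana:

    from morgana.experiment_builder import ExperimentBuilder


    class MyExperimentBuilder(ExperimentBuilder):
        @classmethod
        def add_args(cls, parser):
            super().add_args(parser)
            # Illustrative extra flag; like all options it is saved as an
            # instance attribute (self.my_option) on the experiment builder.
            parser.add_argument('--my_option', type=int, default=4,
                                help='Hypothetical option used by a custom loop.')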

log_initial_setup(self, **kwargs)[source]

Copies model definition if the experiment is new. Logs the model summary and config options.

resolve_setting_conflicts(self)[source]

Checks settings and modifies any that are inconsistent. Errors for incorrect settings should be raised here.

If a checkpoint file is given and the torch.optim.lr_scheduler.ReduceLROnPlateau (early stopping) learning rate schedule is being used, validation will be forced on.

If training is off, the epoch number will be extracted from the checkpoint file and set as the current epoch.

Raises
  • ValueError – If no procedures (training, validation, or testing) are specified.

  • ValueError – If a checkpoint file is given and self.start_epoch is less than the epoch of this file. This may be overly restrictive, but it will catch some off-by-one errors. It can be avoided by renaming the checkpoint file.

  • ValueError – If no checkpoint file is given and training is turned off.
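
A subclass can add its own consistency checks by extending this method. A hedged sketch; self.my_option is the hypothetical flag from the add_args() example above, and self.train is the documented training switch:

    from morgana.experiment_builder import ExperimentBuilder


    class MyExperimentBuilder(ExperimentBuilder):
        def resolve_setting_conflicts(self):
            super().resolve_setting_conflicts()
            # Hypothetical check for the illustrative option added in add_args().
            if not self.train and self.my_option != 4:
                raise ValueError('--my_option has no effect when training is turned off.')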

build_model(self, model_class, model_kwargs, checkpoint_path=None)[source]

Creates model instance. Loads parameters from a checkpoint file, if given. Moves the model to the device.

load_data(self, data_sources, data_dir, id_list, normalisers=None, name='', shuffle=True)[source]

Creates a dataset using the data sources and batches it with a data loader.

Parameters
  • data_sources (dict[str, _DataSource]) – Specification of the different data to be loaded.

  • data_dir (str) – The directory containing all data for this dataset split.

  • id_list (str) – The name of the file id-list containing base names to load, contained within self.data_root.

  • normalisers (None or dict[str, _FeatureNormaliser]) – Normalisers to be passed to the morgana.data._DataSource instances.

  • name (str) – An identifier used for logging.

  • shuffle (bool) – Whether to shuffle the data every epoch.

Returns

An instance with the __iter__ method, allowing for iteration over batches of the dataset.

Return type

torch.utils.data.DataLoader (in a morgana.data._DataLoaderWrapper container)

train_epoch(self, data_loader, optimizer, lr_schedule=None, gen_output=False, out_dir=None)[source]

Trains the model once on all batches in the data loader given.

  • Gradient updates, and EMA gradient updates.

  • Batch level learning rate schedule updates.

  • Logging metrics to tqdm and to a metrics.json file.

Parameters
  • data_loader (torch.utils.data.DataLoader (in a data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.

  • optimizer (torch.optim.Optimizer) –

  • lr_schedule (torch.optim.lr_scheduler._LRScheduler) – Learning rate schedule, only used if it is a member of morgana.lr_schedules.BATCH_LR_SCHEDULES.

  • gen_output (bool) – Whether to generate output for this training epoch. Output is defined by morgana.base_models.BaseModel.analysis_for_train_batch().

  • out_dir (str) – Directory used to save output (changes for each epoch).

Returns

loss – Average loss for the entire epoch.

Return type

float
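
To customise the training loop (as mentioned at the top of this page), override train_epoch() in a subclass. The following is a heavily simplified sketch, not the library's actual implementation: it skips EMA updates, batch-level schedule updates, metric logging, and output generation, and it assumes the model is a torch.nn.Module whose forward call returns a loss tensor alongside its outputs (the real interface is defined in base_models).

    from morgana.experiment_builder import ExperimentBuilder


    class MyExperimentBuilder(ExperimentBuilder):
        def train_epoch(self, data_loader, optimizer, lr_schedule=None,
                        gen_output=False, out_dir=None):
            self.model.train()

            losses = []
            for features in data_loader:
                # Hypothetical model call; see base_models for the real interface.
                loss, output_features = self.model(features)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())

            # Average loss for the entire epoch.
            return sum(losses) / len(losses)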

run_train(self)[source]

Runs training from start_epoch to end_epoch.

  • Parameter checkpointing, and EMA parameter checkpointing.

  • Validation and generation.

  • Epoch level learning rate schedule updates.

valid_epoch(self, data_loader, model=None, gen_output=False, out_dir=None)[source]

Evaluates model once on all batches in data loader given. Performs analysis of model output if requested.

Parameters
  • data_loader (torch.utils.data.DataLoader (in a data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.

  • model (morgana.base_models.BaseModel) – Model instance. If self.ema_decay is non-zero this will be the ema model.

  • gen_output (bool) – Whether to generate output for this validation epoch. Output is defined by morgana.base_models.BaseModel.analysis_for_valid_batch().

  • out_dir (str) – Directory used to save output (changes for each epoch).

Returns

loss – Average loss for the entire epoch.

Return type

float

run_valid(self, gen_output)[source]

Runs evaluation for the current epoch.

test_epoch(self, data_loader, model=None, out_dir=None)[source]

Evaluates the model once on all batches in the data loader given. Performs analysis of model predictions.

Parameters
  • data_loader (torch.utils.data.DataLoader (in a data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.

  • model (morgana.base_models.BaseModel) – Model instance. If self.ema_decay is non-zero this will be the ema model.

  • out_dir (str) – Directory used to save output (changes for each epoch).

run_test(self)[source]

Runs generation for the current epoch.

run_experiment(self)[source]

Runs all procedures requested for the experiment.