experiment_builder
After defining a model (as described in base_models), you can train it (or generate from a checkpoint) using an experiment builder, which provides a command line interface to the train/validation/test loops.
If you want to define custom train/validation/test loops, you can create a subclass of ExperimentBuilder and override the methods you wish to change.
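The subclassing pattern can be sketched as follows. This is only an illustration under stated assumptions: morgana is not imported, and `StandInExperimentBuilder`, `MedianLossExperimentBuilder`, and `ToyModel` are hypothetical names invented so the snippet runs on its own. In real use the parent class would be morgana.experiment_builder.ExperimentBuilder.

```python
class StandInExperimentBuilder:
    """Hypothetical stand-in for ExperimentBuilder (not the morgana class)."""

    def __init__(self, model_class, experiment_name, **kwargs):
        self.model = model_class()
        self.experiment_name = experiment_name
        # Mirrors the documented behaviour: arguments become instance attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def valid_epoch(self, data_loader, model=None, gen_output=False, out_dir=None):
        # Default loop in this sketch: mean of the per-batch losses.
        model = self.model if model is None else model
        losses = [model.loss(batch) for batch in data_loader]
        return sum(losses) / len(losses)


class MedianLossExperimentBuilder(StandInExperimentBuilder):
    """Overrides only the validation loop; everything else is inherited."""

    def valid_epoch(self, data_loader, model=None, gen_output=False, out_dir=None):
        model = self.model if model is None else model
        losses = sorted(model.loss(batch) for batch in data_loader)
        return losses[len(losses) // 2]  # median (upper median for even counts)


class ToyModel:
    """Hypothetical model whose 'loss' is simply the mean of a batch."""

    def loss(self, batch):
        return sum(batch) / len(batch)


builder = MedianLossExperimentBuilder(ToyModel, "demo", ema_decay=0.0)
median_loss = builder.valid_epoch([[1.0, 3.0], [10.0, 10.0], [2.0, 2.0]])
```

Only `valid_epoch` is replaced here; the constructor and any other methods are reused unchanged, which is the point of overriding selectively.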
ExperimentBuilder
-
class morgana.experiment_builder.ExperimentBuilder(model_class, experiment_name, **kwargs)
Bases: object
Interface for running training, validation, and generation. Works as glue for machine learning support code.
- Parameters
model_class (morgana.base_models.BaseModel) – Model to be initialised by the experiment builder. Must contain implementations of all abstract methods.
experiment_name (str) – Name of the experiment. This is used as the directory in which all output is saved (under the experiments_base directory). This is the only command line argument that has no default value and is required.
kwargs (dict[str, *]) – Command line arguments. See add_args() for all options.
-
model
Model instance.
-
ema
Helper for updating a second model instance with an exponential moving average of the parameters.
- Type
morgana.utils.ExponentialMovingAverage (if ema_decay is not 0.)
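For readers unfamiliar with the idea, an exponential moving average of parameters is commonly maintained with the rule `ema = decay * ema + (1 - decay) * param`. The sketch below uses plain Python floats in a dict; `ema_update` is a hypothetical helper, and whether morgana.utils.ExponentialMovingAverage applies exactly this rule is an assumption.

```python
def ema_update(ema_params, params, decay):
    """Return the updated averages: decay * ema + (1 - decay) * param, per name."""
    return {
        name: decay * ema_params[name] + (1.0 - decay) * params[name]
        for name in params
    }


params = {"w": 1.0}
ema_params = dict(params)  # the averaged copy starts equal to the model

# Pretend two optimiser steps moved the parameter to 2.0 and then to 3.0.
for new_value in (2.0, 3.0):
    params["w"] = new_value
    ema_params = ema_update(ema_params, params, decay=0.5)
```

With a decay close to 1 the averaged copy changes slowly, and with a decay of 0 it simply tracks the current parameters.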
-
device
Name of the device on which to place the model and parameters.
- Type
str or torch.device
-
_lr_schedule
Partially initialised learning rate schedule. Depending on the schedule, this will be used per epoch or per batch.
- Type
torch.optim.lr_scheduler._LRScheduler
-
logger
Python Logger. Copies stdout, stderr, and all tqdm output to separate files.
- Type
_logging.Logger
-
train_loader
Attribute is only present if self.train is True. Produces batches from the training data (on a given device) when used as an iterator.
- Type
torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).
-
valid_loader
Attribute is only present if self.valid is True. Produces batches from the validation data (on a given device) when used as an iterator.
- Type
torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).
-
test_loader
Attribute is only present if self.test is True. Produces batches from the testing data (on a given device) when used as an iterator.
- Type
torch.utils.data.DataLoader (in a data._DataLoaderWrapper container).
Notes
All arguments provided as command line arguments (including experiment_name) are saved as instance attributes.
-
classmethod get_experiment_args()
Creates a command line argument parser and returns the dictionary of arguments.
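The general shape of "build a parser, return a dictionary" can be sketched with argparse from the standard library. This is not the morgana implementation: `get_experiment_args_sketch` is a hypothetical function, the flag spellings and default values are invented for illustration, and the real options are those listed by add_args().

```python
import argparse


def get_experiment_args_sketch(argv=None):
    """Parse arguments and return them as a plain dict (illustrative only)."""
    parser = argparse.ArgumentParser(description="Hypothetical experiment CLI.")
    # The one argument without a default; passing it as a flag is an assumption.
    parser.add_argument("--experiment_name", required=True, type=str)
    # Option names appear in this page; the defaults below are invented.
    parser.add_argument("--experiments_base", default="experiments", type=str)
    parser.add_argument("--ema_decay", default=0.0, type=float)
    parser.add_argument("--start_epoch", default=1, type=int)
    parser.add_argument("--end_epoch", default=50, type=int)
    # vars() converts the argparse.Namespace into a dictionary.
    return vars(parser.parse_args(argv))


args = get_experiment_args_sketch(["--experiment_name", "demo", "--ema_decay", "0.99"])
```

A dictionary of this kind can then be unpacked as the `**kwargs` of the constructor.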
-
log_initial_setup(self, **kwargs)
Copies the model definition if the experiment is new. Logs the model summary and config options.
-
resolve_setting_conflicts(self)
Checks settings and modifies any that are inconsistent. Errors for incorrect settings should be raised here.
If a checkpoint file is given and the torch.optim.lr_scheduler.ReduceLROnPlateau (early stopping) learning rate schedule is being used, validation will be forced on.
If training is off, the epoch number will be extracted from the checkpoint file and set as the current epoch.
- Raises
ValueError – If no procedures are specified (e.g. training, validation, or testing).
ValueError – If a checkpoint file is given and self.start_epoch is less than the epoch of this file. This may be overly restrictive, but it will catch some off-by-one errors. Can be avoided by renaming the checkpoint file.
ValueError – If no checkpoint file is given and training is turned off.
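The checks described above can be sketched as a standalone function. Several details here are assumptions rather than documented behaviour: the function and argument names are hypothetical, the checkpoint file name pattern `epoch_<N>` is invented, and the order in which the checks run is a guess.

```python
import re


def checkpoint_epoch(checkpoint_path):
    """Hypothetical: read the epoch from a file name such as 'epoch_12.pt'."""
    match = re.search(r"epoch_(\d+)", checkpoint_path)
    if match is None:
        raise ValueError(f"No epoch number found in {checkpoint_path!r}")
    return int(match.group(1))


def resolve_setting_conflicts_sketch(train, valid, test, start_epoch,
                                     checkpoint_path=None, plateau_schedule=False):
    """Return the adjusted settings, raising ValueError for invalid combinations."""
    if not (train or valid or test):
        raise ValueError("No procedure specified: enable training, validation, or testing.")
    if checkpoint_path is None and not train:
        raise ValueError("A checkpoint file is required when training is turned off.")
    if checkpoint_path is not None:
        epoch = checkpoint_epoch(checkpoint_path)
        if not train:
            # Training is off: take the current epoch from the checkpoint.
            start_epoch = epoch
        elif start_epoch < epoch:
            raise ValueError("start_epoch is less than the epoch of the checkpoint file.")
        if plateau_schedule:
            # A plateau schedule needs validation losses, so force validation on.
            valid = True
    return {"train": train, "valid": valid, "test": test, "start_epoch": start_epoch}


settings = resolve_setting_conflicts_sketch(
    train=False, valid=False, test=True, start_epoch=1,
    checkpoint_path="checkpoints/epoch_12.pt")
```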
-
build_model(self, model_class, model_kwargs, checkpoint_path=None)
Creates model instance. Loads parameters from a checkpoint file, if given. Moves the model to the device.
-
load_data(self, data_sources, data_dir, id_list, normalisers=None, name='', shuffle=True)
Creates a dataset using the data sources and batches it using a data loader.
- Parameters
data_sources (dict[str, _DataSource]) – Specification of the different data to be loaded.
data_dir (str) – The directory containing all data for this dataset split.
id_list (str) – The name of the id-list file containing the base names to load, contained within self.data_root.
normalisers (None or dict[str, _FeatureNormaliser]) – Normalisers to be passed to the morgana.data._DataSource instances.
name (str) – An identifier used for logging.
shuffle (bool) – Whether to shuffle the data every epoch.
- Returns
An instance with the __iter__ method, allowing for iteration over batches of the dataset.
- Return type
torch.utils.data.DataLoader (in a morgana.data._DataLoaderWrapper container)
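To make the idea of "an id list turned into shuffled batches" concrete, here is a dependency-free sketch. It does not use morgana or torch: `iterate_batches` is a hypothetical helper, the one-base-name-per-line id-list format is an assumption, and `batch_size` is passed explicitly only for the sake of the example.

```python
import random


def iterate_batches(file_ids, batch_size, shuffle=True, seed=None):
    """Yield lists of ids. With seed=None a fresh order is drawn on every call."""
    ids = list(file_ids)
    if shuffle:
        random.Random(seed).shuffle(ids)
    for start in range(0, len(ids), batch_size):
        yield ids[start:start + batch_size]


# Assumed id-list format: one base name per line, blank lines ignored.
id_list_text = "utt_001\nutt_002\n\nutt_003\n"
file_ids = [line.strip() for line in id_list_text.splitlines() if line.strip()]

batches = list(iterate_batches(file_ids, batch_size=2, shuffle=False))
shuffled = list(iterate_batches(file_ids, batch_size=2, shuffle=True, seed=0))
```

Calling the generator once per epoch gives the "shuffle every epoch" behaviour; the final batch may be smaller than `batch_size`.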
-
train_epoch(self, data_loader, optimizer, lr_schedule=None, gen_output=False, out_dir=None)
Trains the model once on all batches in the data loader given.
Gradient updates, and EMA gradient updates.
Batch level learning rate schedule updates.
Logging metrics to tqdm and to a metrics.json file.
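The steps listed above can be sketched for a one-parameter model `y = w * x` with a squared-error loss. This is a toy written in plain Python, not morgana's loop: `train_epoch_sketch` and its arguments are hypothetical, the multiplicative `lr_decay` stands in for a batch-level schedule, and metric logging is omitted.

```python
def train_epoch_sketch(batches, w, lr, decay, ema_w, lr_decay=1.0):
    """One pass over `batches` of (x, y) pairs; returns (mean loss, w, ema_w, lr)."""
    total_loss = 0.0
    for batch in batches:
        # Mean squared error over the batch and its gradient with respect to w.
        loss = sum((w * x - y) ** 2 for x, y in batch) / len(batch)
        grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        w -= lr * grad                             # gradient update
        ema_w = decay * ema_w + (1 - decay) * w    # EMA update of the averaged copy
        lr *= lr_decay                             # batch-level schedule step
        total_loss += loss
    # Average of the per-batch losses over the epoch.
    return total_loss / len(batches), w, ema_w, lr


batches = [[(1.0, 2.0)], [(2.0, 4.0)]]
avg_loss, w, ema_w, lr = train_epoch_sketch(
    batches, w=0.0, lr=0.1, decay=0.5, ema_w=0.0, lr_decay=0.5)
```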
- Parameters
data_loader (torch.utils.data.DataLoader (in a data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.
optimizer (torch.optim.Optimizer) –
lr_schedule (torch.optim.lr_scheduler._LRScheduler) – Learning rate schedule, only used if it is a member of morgana.lr_schedules.BATCH_LR_SCHEDULES.
gen_output (bool) – Whether to generate output for this training epoch. Output is defined by morgana.base_models.BaseModel.analysis_for_train_batch().
out_dir (str) – Directory used to save output (changes for each epoch).
- Returns
loss – Average loss for entire batch.
- Return type
-
run_train(self)
Runs training from start_epoch to end_epoch.
Parameter checkpointing, and EMA parameter checkpointing.
Validation and generation.
Epoch level learning rate schedule updates.
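An outer loop of this kind can be sketched as below. Everything here is illustrative: `run_train_sketch`, `checkpoint_every`, and `factor` are hypothetical names, treating end_epoch as inclusive is an assumption, and the learning rate rule is a simplified plateau rule with no patience rather than torch's ReduceLROnPlateau.

```python
def run_train_sketch(start_epoch, end_epoch, train_fn, valid_fn,
                     checkpoint_every=2, lr=0.1, factor=0.5):
    """Run epochs in order; return (epochs that were checkpointed, final lr)."""
    best_valid = float("inf")
    checkpoints = []
    for epoch in range(start_epoch, end_epoch + 1):  # assumption: end_epoch is inclusive
        train_fn(epoch, lr)
        valid_loss = valid_fn(epoch)
        if valid_loss < best_valid:
            best_valid = valid_loss
        else:
            lr *= factor  # epoch-level schedule step when validation stops improving
        if epoch % checkpoint_every == 0:
            checkpoints.append(epoch)  # stands in for saving parameters to disk
    return checkpoints, lr


trained = []  # records (epoch, lr) pairs seen by the training callback
valid_losses = {1: 1.0, 2: 0.8, 3: 0.9, 4: 0.7}
checkpoints, final_lr = run_train_sketch(
    1, 4,
    train_fn=lambda epoch, lr: trained.append((epoch, lr)),
    valid_fn=lambda epoch: valid_losses[epoch])
```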
-
valid_epoch(self, data_loader, model=None, gen_output=False, out_dir=None)
Evaluates the model once on all batches in the data loader given. Performs analysis of model output if requested.
- Parameters
data_loader (torch.utils.data.DataLoader (in a data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.
model (morgana.base_models.BaseModel) – Model instance. If self.ema_decay is non-zero this will be the EMA model.
gen_output (bool) – Whether to generate output for this validation epoch. Output is defined by morgana.base_models.BaseModel.analysis_for_valid_batch().
out_dir (str) – Directory used to save output (changes for each epoch).
- Returns
loss – Average loss for entire batch.
- Return type
-
test_epoch(self, data_loader, model=None, out_dir=None)
Evaluates the model once on all batches in the data loader given. Performs analysis of model predictions.
- Parameters
data_loader (torch.utils.data.DataLoader (in a morgana.data._DataLoaderWrapper container)) – An instance with the __iter__ method, allowing for iteration over batches of the dataset.
model (morgana.base_models.BaseModel) – Model instance. If self.ema_decay is non-zero this will be the EMA model.
out_dir (str) – Directory used to save output (changes for each epoch).