routines module

Routines for setting up and performing training

class SetupParams(device: torch.device | str | None = None, controller: hippynn.experiment.controllers.Controller | typing.Callable | None = None, stopping_key: str | None = None, optimizer: torch.optim.optimizer.Optimizer | typing.Callable = <class 'torch.optim.adam.Adam'>, learning_rate: float | None = None, scheduler: torch.optim.lr_scheduler._LRScheduler | typing.Callable = None, batch_size: int | None = None, eval_batch_size: int | None = None, max_epochs: int | None = None, fraction_train_eval: float | None = 0.1)

  • stopping_key – name of validation metric for stopping

  • controller – Optional – Controller object for LR scheduling and ending experiment. If not provided, will be constructed from parameters below.

  • device – Where to train the model. If not given, falls back to CUDA if available. Specify a tuple of device numbers to use DataParallel.

  • max_epochs – Optional – maximum number of epochs to train. Mandatory if controller not provided.

  • batch_size – Only used if the controller itself is not specified. Mandatory if controller not provided.

  • eval_batch_size – Only used if the controller itself is not specified.

  • scheduler – scheduler passed to the controller

  • optimizer – Pytorch optimizer or optimizer class. Defaults to Adam.

  • learning_rate – If an optimizer class is provided, the learning rate is used to construct the optimizer.

  • fraction_train_eval – What fraction of the training dataset to evaluate in the evaluation phase

All params after stopping_key, controller, and device are optional and can be built into a controller.
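The rules above about which parameters become mandatory can be summarized in a small sketch. The `SetupParamsSketch` dataclass below is a hypothetical stand-in, not the hippynn class: it mirrors a subset of the documented fields, and the `missing_required` helper and the metric name `"valid_loss"` are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SetupParamsSketch:
    """Hypothetical stand-in mirroring a subset of the documented fields."""
    stopping_key: Optional[str] = None
    controller: Optional[Any] = None
    device: Optional[Any] = None
    learning_rate: Optional[float] = None
    batch_size: Optional[int] = None
    eval_batch_size: Optional[int] = None
    max_epochs: Optional[int] = None
    fraction_train_eval: Optional[float] = 0.1


def missing_required(params: SetupParamsSketch) -> List[str]:
    """Return the fields the text calls mandatory when no controller is given."""
    if params.controller is not None:
        # With a user-supplied controller, these settings are not needed here.
        return []
    required = ("max_epochs", "batch_size")
    return [name for name in required if getattr(params, name) is None]


with_controller = SetupParamsSketch(controller=object())
without_controller = SetupParamsSketch(stopping_key="valid_loss", batch_size=64)

print(missing_required(with_controller))     # → []
print(missing_required(without_controller))  # → ['max_epochs']
```

The second call reports `max_epochs` because no controller was supplied and only `batch_size` was set.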


Multi-GPU training is an experimental feature that is still being debugged.


setup_and_train(training_modules: TrainingModules, database: Database, setup_params: SetupParams, store_all_better=False, store_best=True, store_every=0)

training_modules: see setup_training()


database: see train_model()


setup_params: see setup_training()


store_all_better: Save the state dict for each model doing better than a previous one


store_best: Save a checkpoint for the best model


store_every: Save a checkpoint every store_every epochs


See train_model()

Shortcut for setup_training followed by train_model.


The training loop will capture KeyboardInterrupt exceptions to abort the experiment early. If you would like to gracefully kill training programmatically, see train_model() with callbacks argument.
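The interrupt handling described above follows a common Python pattern. The following is a minimal sketch of that pattern, not hippynn's code: `run_epochs` and `fake_epoch` are hypothetical names, and the interrupt is raised by hand to simulate pressing Ctrl-C.

```python
def run_epochs(max_epochs, epoch_fn):
    """Run epoch_fn once per epoch; stop early on KeyboardInterrupt."""
    completed = 0
    try:
        for epoch in range(max_epochs):
            epoch_fn(epoch)
            completed += 1
    except KeyboardInterrupt:
        # Abort the loop, but let the caller carry on with cleanup and testing.
        print(f"Interrupted after {completed} completed epochs")
    return completed


def fake_epoch(epoch):
    # Simulate the user pressing Ctrl-C during the fourth epoch.
    if epoch == 3:
        raise KeyboardInterrupt


completed = run_epochs(10, fake_epoch)
```

Because the exception is caught inside the function, code placed after the call still runs, which is what allows an interrupted experiment to finish cleanly.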


Saves files in the current running directory; recommend you switch to a fresh directory with a descriptive name for your experiment.
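One way to follow this recommendation is to create the directory and change into it before training starts. The sketch below uses only the standard library; the `working_directory` helper and the directory name are illustrative, and the temporary parent directory is there only to keep the example self-contained.

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def working_directory(path):
    """Create path if needed, change into it, and restore the old cwd on exit."""
    previous = os.getcwd()
    os.makedirs(path, exist_ok=True)
    os.chdir(path)
    try:
        yield path
    finally:
        os.chdir(previous)


start_dir = os.getcwd()
with tempfile.TemporaryDirectory() as parent:
    run_dir = os.path.join(parent, "example_experiment")
    with working_directory(run_dir):
        # A call such as setup_and_train(...) would go here, so that
        # its output files land in run_dir.
        expected = os.path.realpath(run_dir)
        inside = os.path.realpath(os.getcwd())
restored = os.getcwd()
```

Restoring the previous directory in a `finally` clause keeps later code unaffected even if training raises an exception.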

setup_training(training_modules: TrainingModules, setup_params: SetupParams)

Prepares training_modules for training with setup_params.


training_modules: Tuple of model, training loss, and evaluation losses (Can be built from graph using graphs.assemble_training_modules)


setup_params: parameters controlling how training is performed (See SetupParams)


  • sets devices for training modules

  • if no controller given:

    • instantiates and links optimizer to the learnable params on the model

    • instantiates and links scheduler to optimizer

    • builds a default controller with setup params

  • creates a MetricTracker for storing the training metrics
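The controller-related branch of the steps above can be sketched as follows. This is an illustration under stated assumptions, not hippynn's implementation: `ToyOptimizer`, `ToyController`, and `build_controller` are hypothetical stand-ins, device placement and the metric tracker are left out, and treating any callable scheduler as a factory that takes the optimizer is an assumption.

```python
class ToyOptimizer:
    """Stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


class ToyController:
    """Stand-in for a controller holding the optimizer and run settings."""

    def __init__(self, optimizer, scheduler, batch_size, max_epochs):
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.batch_size = batch_size
        self.max_epochs = max_epochs


def build_controller(model_params, controller=None, optimizer=ToyOptimizer,
                     learning_rate=None, scheduler=None,
                     batch_size=None, max_epochs=None):
    """Return the given controller, or assemble a default one."""
    if controller is not None:
        return controller
    # An optimizer class is instantiated on the learnable parameters;
    # an already-built optimizer instance is used as-is.
    if isinstance(optimizer, type):
        optimizer = optimizer(model_params, lr=learning_rate)
    # Assumption: a callable scheduler is a factory linked to the optimizer.
    if callable(scheduler):
        scheduler = scheduler(optimizer)
    return ToyController(optimizer, scheduler, batch_size, max_epochs)


controller = build_controller([0.1, 0.2], learning_rate=1e-3,
                              batch_size=32, max_epochs=100)
existing = ToyController(None, None, 8, 1)
same = build_controller([], controller=existing)
```

When a controller is passed in, it is returned untouched, which matches the statement that the remaining parameters are only used if the controller is not specified.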



test_model(database, evaluator, batch_size, when, metric_tracker=None)

Tests the model on the database according to the model_evaluator metrics. If a plot_maker is attached to the model evaluator, it will make plots. The plots will go in a sub-folder specified by when the testing is taking place. The results are then printed.

  • database – The database to test the model on.

  • evaluator – The evaluator containing model and evaluation losses to measure.

  • when – A string naming when the testing is taking place; it determines the sub-folder the plots are saved in.

  • metric_tracker – (Optional) metric tracker to save metrics on. If not provided, a blank one will be constructed.


Returns the metric tracker.

train_model(training_modules, database, controller, metric_tracker, callbacks, batch_callbacks, store_all_better=False, store_best=True, store_every=0, store_structure_file=True, store_metrics=True, quiet=False)

Performs the training loop, which can be stopped with a keyboard interrupt. When done, reinstates the best model, plots the metrics over time, and tests the model.

  • training_modules – tuple-like of model, loss, and evaluator

  • database – Database

  • controller – Controller

  • metric_tracker – MetricTracker for storing model performance

  • callbacks – callbacks to perform after every epoch.

  • batch_callbacks – callbacks to perform after every batch

  • store_best – Save a checkpoint for the best model

  • store_all_better – Save the state dict for each model doing better than a previous one

  • store_every – Save a checkpoint every store_every epochs

  • store_structure_file – Save the structure file for this experiment

  • store_metrics – Save the metric tracker for this experiment.

  • quiet – If True, disable printing during training (still prints testing results).
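How the three storage flags might combine is sketched below. The exact conditions hippynn uses are not spelled out here, so this rests on two loudly stated assumptions: that `store_every=0` disables periodic checkpoints, and that a periodic checkpoint is due whenever the epoch number is a multiple of `store_every`. The function name and the action labels are hypothetical.

```python
def checkpoint_actions(epoch, new_best, store_best=True,
                       store_all_better=False, store_every=0):
    """Return labels for what would be saved after this epoch (illustrative)."""
    actions = []
    if new_best and store_best:
        # Checkpoint for the best model so far.
        actions.append("best_checkpoint")
    if new_best and store_all_better:
        # Separate state dict for every model that improves on the previous best.
        actions.append("better_state_dict")
    # Assumption: 0 means "no periodic checkpoints".
    if store_every and epoch % store_every == 0:
        actions.append("periodic_checkpoint")
    return actions


print(checkpoint_actions(5, new_best=True))                   # → ['best_checkpoint']
print(checkpoint_actions(10, new_best=False, store_every=5))  # → ['periodic_checkpoint']
```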




callbacks take the form of an iterable of callables and will each be called with cb(epoch, new_best)

  • epoch indicates the epoch number

  • new_best indicates if the model is a new best model


batch_callbacks take the form of an iterable of callables and will each be called with cb(batch_inputs, batch_model_outputs, batch_targets)


If you want your callbacks to store other state, an easy way is to make them callable objects.


Callback state is not managed by hippynn. If you wish to save or load callback state, you will have to manage that manually (possibly with a callback itself).
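A callable object that keeps its own state could look like the sketch below. The class names are hypothetical; only the call signatures `cb(epoch, new_best)` and `cb(batch_inputs, batch_model_outputs, batch_targets)` come from the description above. The loop at the end merely simulates how a training loop would invoke the callbacks, using placeholder data.

```python
class BestEpochRecorder:
    """Epoch callback that remembers which epochs produced a new best model."""

    def __init__(self):
        self.best_epochs = []

    def __call__(self, epoch, new_best):
        if new_best:
            self.best_epochs.append(epoch)


class BatchCounter:
    """Batch callback that counts how many batches it has seen."""

    def __init__(self):
        self.count = 0

    def __call__(self, batch_inputs, batch_model_outputs, batch_targets):
        self.count += 1


recorder = BestEpochRecorder()
counter = BatchCounter()
callbacks = [recorder]
batch_callbacks = [counter]

# Simulated invocation: three epochs of four batches each.
for epoch, new_best in enumerate([True, False, True]):
    for _ in range(4):
        for cb in batch_callbacks:
            cb(None, None, None)
    for cb in callbacks:
        cb(epoch, new_best)

print(recorder.best_epochs)  # → [0, 2]
print(counter.count)         # → 12
```

Because the state lives on the objects, it can be inspected after training or saved by hand, for example from another callback.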

training_loop(training_modules: TrainingModules, database, controller: Controller, metric_tracker: MetricTracker, callbacks, batch_callbacks, store_all_better, store_best, store_every, quiet)

Performs a high-level training loop.

  • training_modules – training modules from assemble_modules

  • database – database to train to

  • controller – controller for early stopping and/or learning rate decay

  • metric_tracker – the metric tracker

  • callbacks – list of callbacks for each epoch

  • batch_callbacks – list of callbacks for each batch

  • store_best – Save a checkpoint for the best model

  • store_all_better – Save the state dict for each model doing better than a previous one

  • store_every – Save a checkpoint every store_every epochs

  • quiet – whether to suppress printed information. Setting quiet to True won’t prevent progress bars.


metrics – the state of the experiment after training


Saves files in the current running directory; recommend switching to a fresh directory with a descriptive name for each experiment.

Rough structure.

Loop over Epochs, performing:

  • Loop over batches

    • Make predictions

    • Calculate loss, perform backward pass

    • Optimizer Step

    • Batch Callbacks

  • Perform validation and print results

  • Controller/Scheduler Step

  • Epoch callbacks

  • If new best, save the model state_dict and a checkpoint
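The outline above can be mirrored in a toy loop written in plain Python. This is only a sketch of the structure, not hippynn's training loop: the "model" is a single weight `w` fitted to y = w * x, the gradient is computed by hand in place of a backward pass, the fixed learning-rate decay stands in for the controller/scheduler step, and `train_sketch` is a hypothetical name.

```python
def train_sketch(train_batches, valid_batch, max_epochs, lr,
                 callbacks=(), batch_callbacks=()):
    """Toy loop fitting y = w * x, following the outlined order of steps."""
    w = 0.0  # the only "model parameter"
    best_loss, best_w = float("inf"), w
    for epoch in range(max_epochs):
        for xs, ys in train_batches:
            # Make predictions.
            preds = [w * x for x in xs]
            # Mean squared error gradient w.r.t. w (stands in for "backwards").
            grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
            # Optimizer step (plain gradient descent).
            w -= lr * grad
            # Batch callbacks.
            for cb in batch_callbacks:
                cb(xs, preds, ys)
        # Perform validation and print results.
        vx, vy = valid_batch
        valid_loss = sum((w * x - y) ** 2 for x, y in zip(vx, vy)) / len(vx)
        print(f"epoch {epoch}: validation loss {valid_loss:.3e}")
        new_best = valid_loss < best_loss
        # Controller/scheduler step: a fixed decay, chosen only for illustration.
        lr *= 0.9
        # Epoch callbacks.
        for cb in callbacks:
            cb(epoch, new_best)
        # If new best, keep the model state (here just the weight).
        if new_best:
            best_loss, best_w = valid_loss, w
    return best_w, best_loss


train_batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
valid_batch = ([4.0], [8.0])
epoch_log = []
best_w, best_loss = train_sketch(
    train_batches, valid_batch, max_epochs=20, lr=0.05,
    callbacks=[lambda epoch, new_best: epoch_log.append((epoch, new_best))],
)
```

With this data the weight approaches 2, and the epoch callback receives one `(epoch, new_best)` pair per epoch.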