experiment
Full documentation for the hippynn.experiment package.
Functions for training.
- class SetupParams(device: ~torch.device | str | None = None, controller: ~hippynn.experiment.controllers.Controller | ~typing.Callable | None = None, stopping_key: str | None = None, optimizer: ~torch.optim.optimizer.Optimizer | ~typing.Callable = <class 'torch.optim.adam.Adam'>, learning_rate: float | None = None, scheduler: ~torch.optim.lr_scheduler._LRScheduler | ~typing.Callable = None, batch_size: int | None = None, eval_batch_size: int | None = None, max_epochs: int | None = None, fraction_train_eval: float | None = 0.1)[source]
Bases: object
- Parameters:
stopping_key – name of the validation metric used for stopping
controller – Optional – Controller object for LR scheduling and ending the experiment. If not provided, one will be constructed from the parameters below.
device – where to train the model. Falls back to CUDA if available. Specify a tuple of device numbers to use DataParallel.
max_epochs – Optional – maximum number of epochs to train. Mandatory if controller is not provided.
batch_size – only used if the controller itself is not specified. Mandatory if controller is not provided.
eval_batch_size – only used if the controller itself is not specified.
scheduler – scheduler passed to the controller
optimizer – PyTorch optimizer or optimizer class. Defaults to Adam.
learning_rate – if an optimizer class is provided, the learning rate is used to construct the optimizer.
fraction_train_eval – what fraction of the training dataset to evaluate during the evaluation phase
All parameters after stopping_key, controller, and device are optional; if no controller is given, they are used to build a default one.
Note
Training on multiple GPUs is an experimental feature that is still being debugged.
- optimizer: alias of Adam
- batch_size: int | None = None
- controller: Controller | Callable | None = None
- device: device | str | None = None
- eval_batch_size: int | None = None
- fraction_train_eval: float | None = 0.1
- learning_rate: float | None = None
- max_epochs: int | None = None
- scheduler: _LRScheduler | Callable = None
- stopping_key: str | None = None
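For example, a configuration without an explicit controller might look like the following sketch; the stopping metric name and numeric values here are illustrative assumptions, not values prescribed by hippynn:

import torch
from hippynn.experiment import SetupParams

setup_params = SetupParams(
    device="cuda",                  # where to train; falls back to CUDA if available when omitted
    optimizer=torch.optim.Adam,     # an optimizer class; constructed with learning_rate below
    learning_rate=1e-3,
    batch_size=64,
    eval_batch_size=128,
    max_epochs=100,                 # mandatory when no controller is provided
    stopping_key="Loss",            # validation metric for early stopping (illustrative name)
    fraction_train_eval=0.1,        # evaluate 10% of the training set in the evaluation phase
)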
- assemble_for_training(train_loss, validation_losses, validation_names=None, plot_maker=None)[source]
- Parameters:
train_loss – LossNode
validation_losses – dict of name: loss_node pairs, or a list of loss nodes. If a list is given, each node's own name is used when printing the loss; this can be overridden with validation_names.
validation_names – (optional) list of names for the loss nodes; only used if validation_losses is a list.
plot_maker – optional PlotMaker for model evaluation
- Returns:
training_modules, db_info. db_info is a dict of inputs (inputs to the model) and targets (inputs to the loss) in terms of the db_name.
assemble_for_training computes:
what inputs are needed by the model
what outputs of the model are needed for the loss
what targets are needed from the database for the loss
It then uses this information to create GraphModules for the model and loss, and an Evaluator based on the validation losses (and names), early stopping, and the plot maker.
Note
Model and training loss are always evaluated on the active device. But the validation losses reside by default on the CPU. This helps compute statistics over large datasets. To accomplish this, the modules associated with the loss are copied in the validation loss. Thus, after assembling the modules for training, changes to the loss nodes will not affect the model evaluator. In all likelihood you aren’t planning to do something too fancy like change the loss nodes during training. But if you do plan to do something like that with callbacks, know that you would probably need to construct a new evaluator.
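As a sketch of typical usage, assuming train_loss, rmse_energy, and mae_energy are hypothetical loss nodes built with hippynn's graph tools:

from hippynn.experiment import assemble_for_training

training_modules, db_info = assemble_for_training(
    train_loss,
    validation_losses={"RMSE": rmse_energy, "MAE": mae_energy},
)
# db_info records which database entries are model inputs and which are loss
# targets (by db_name); it can be passed along when constructing the database.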
- load_checkpoint(structure_fname: str, state_fname: str, restart_db=False, map_location=None, model_device=None, **kwargs) dict [source]
Load checkpoint file from given filename.
For more information on how to use this function, see Restarting training.
- Parameters:
structure_fname – name of the structure file
state_fname – name of the state file
restart_db – whether to restore the database, defaults to False
map_location – device mapping argument for torch.load, defaults to None
model_device – automatically handle device mapping, defaults to None
- Returns:
experiment structure
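A sketch of resuming from explicit files; the file names are illustrative assumptions, so use the names your experiment actually wrote:

from hippynn.experiment import load_checkpoint

structure = load_checkpoint(
    "experiment_structure.pt",     # illustrative structure file name
    "best_checkpoint.pt",          # illustrative state file name
    restart_db=True,               # also restore the database
    model_device="cpu",            # handle device mapping automatically
)
# `structure` is a dict holding the restored experiment state.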
- load_checkpoint_from_cwd(map_location=None, model_device=None, **kwargs) dict [source]
Same as load_checkpoint, but using default filenames.
- Parameters:
map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None
model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None
- Returns:
experiment structure
- Return type:
dict
- load_model_from_cwd(map_location=None, model_device=None, **kwargs) GraphModule [source]
Load only the model from the current working directory.
- Parameters:
map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None
model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None
- Returns:
model with reloaded parameters
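For inference-only reloading, a sketch (the directory name is hypothetical):

import os
from hippynn.experiment import load_model_from_cwd

os.chdir("my_experiment")          # hypothetical directory where the experiment was saved
model = load_model_from_cwd(model_device="cpu")
# `model` is a GraphModule with its trained parameters restored.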
- setup_and_train(training_modules: TrainingModules, database: Database, setup_params: SetupParams, store_all_better=False, store_best=True, store_every=0)[source]
- Param:
training_modules: see setup_training()
- Param:
database: see train_model()
- Param:
setup_params: see setup_training()
- Param:
store_all_better: save the state dict for each model that improves on a previous best
- Param:
store_best: save a checkpoint for the best model
- Param:
store_every: save a checkpoint every fixed number of epochs
- Returns:
see train_model()
Shortcut for setup_training followed by train_model.
Note
The training loop will capture KeyboardInterrupt exceptions to abort the experiment early. If you would like to gracefully stop training programmatically, see train_model() with the callbacks argument.
Note
Saves files in the current working directory; we recommend switching to a fresh directory with a descriptive name for your experiment.
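Putting the pieces together, a sketch of the shortcut; training_modules, database, and setup_params are assumed to come from the examples above and hippynn's database tools:

from hippynn.experiment import setup_and_train

metric_tracker = setup_and_train(
    training_modules=training_modules,
    database=database,
    setup_params=setup_params,
)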
- setup_training(training_modules: TrainingModules, setup_params: SetupParams)[source]
Prepares training_modules for training according to setup_params.
- Param:
training_modules: tuple of model, training loss, and evaluation losses (can be built from a graph using graphs.assemble_training_modules)
- Param:
setup_params: parameters controlling how training is performed (see SetupParams)
Roughly:
sets devices for training modules
if no controller given:
instantiates and links optimizer to the learnable params on the model
instantiates and links scheduler to optimizer
builds a default controller with setup params
creates a MetricTracker for storing the training metrics
- Returns:
(optimizer, evaluator, controller, metrics, callbacks)
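A sketch of the manual two-step path, unpacking the return tuple as documented above; all inputs are assumed from the earlier examples:

from hippynn.experiment import setup_training, train_model

optimizer, evaluator, controller, metrics, callbacks = setup_training(
    training_modules=training_modules,
    setup_params=setup_params,
)
metric_tracker = train_model(
    training_modules,
    database,
    controller,
    metrics,              # the MetricTracker created by setup_training
    callbacks=callbacks,
    batch_callbacks=None,
)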
- test_model(database, evaluator, batch_size, when, metric_tracker=None)[source]
Tests the model on the database according to the model evaluator's metrics. If a plot_maker is attached to the model evaluator, it will make plots. The plots go in a sub-folder named by the when argument, which indicates when the testing is taking place. The results are then printed.
- Parameters:
database – the database to test the model on.
evaluator – the evaluator containing the model and evaluation losses to measure.
when – a string describing when the testing takes place; used to label plots and the plot sub-folder.
metric_tracker – (Optional) metric tracker to save metrics on. If not provided, a blank one will be constructed.
- Returns:
metric tracker
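A sketch of standalone testing; database and evaluator are assumed from the training setup above, and the when string is just a label:

from hippynn.experiment import test_model

metric_tracker = test_model(
    database,
    evaluator,
    batch_size=128,
    when="final_test",   # labels the plot sub-folder for this round of testing
)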
- train_model(training_modules, database, controller, metric_tracker, callbacks, batch_callbacks, store_all_better=False, store_best=True, store_every=0, store_structure_file=True, store_metrics=True, quiet=False)[source]
Performs the training loop and allows keyboard interrupt. When done, reinstates the best model, makes plots of the metrics over time, and tests the model.
- Parameters:
training_modules – tuple-like of model, loss, and evaluator
database – Database
controller – Controller
metric_tracker – MetricTracker for storing model performance
callbacks – callbacks to perform after every epoch.
batch_callbacks – callbacks to perform after every batch
store_best – Save a checkpoint for the best model
store_all_better – Save the state dict for each model doing better than a previous one
store_every – save a checkpoint every fixed number of epochs
store_structure_file – Save the structure file for this experiment
store_metrics – Save the metric tracker for this experiment.
quiet – If True, disable printing during training (still prints testing results).
- Returns:
metric_tracker
Note
callbacks take the form of an iterable of callables and will each be called with cb(epoch, new_best)
epoch indicates the epoch number
new_best indicates if the model is a new best model
Note
batch_callbacks take the form of an iterable of callables and will each be called with cb(batch_inputs, batch_model_outputs, batch_targets)
Note
You may want your callbacks to store other state; if so, an easy way is to make them callable objects, as in the sketch after these notes.
Note
Callback state is not managed by hippynn. If you wish to save or load callback state, you will have to manage that manually (possibly with a callback itself).
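For instance, a minimal stateful epoch callback might look like this sketch:

class BestEpochLogger:
    """Epoch callback recording which epochs produced a new best model."""

    def __init__(self):
        self.best_epochs = []          # state lives on the object, not in hippynn

    def __call__(self, epoch, new_best):
        if new_best:
            self.best_epochs.append(epoch)

logger = BestEpochLogger()
# Pass it via train_model(..., callbacks=[logger], ...); saving or restoring
# logger.best_epochs across restarts is up to you (see the note above).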
Submodules