Restarting training
You may want to restart training from a previous state, e.g., due to job length constraints on HPC systems.
By default, hippynn saves checkpoint information for the best model found so far as measured by validation performance.
The checkpoint contains the training modules, the experiment controller, and the metrics of the experiment so far. This can be seen by breaking down setup_and_train() into its component steps, setup_training() and train_model():
from hippynn.experiment import setup_training, train_model
training_modules, controller, metrics = setup_training(
    training_modules=training_modules,
    setup_params=experiment_params,
)
train_model(
    training_modules,
    database,
    controller,
    metrics,
    callbacks=None,
    batch_callbacks=None,
)
Simple restart
To restart training later, you can use the following:
from hippynn.experiment.serialization import load_checkpoint
check = load_checkpoint("./experiment_structure.pt", "./best_checkpoint.pt")
train_model(**check, callbacks=None, batch_callbacks=None)
or to use the default filenames and load from the current directory:
from hippynn.experiment.serialization import load_checkpoint_from_cwd
check = load_checkpoint_from_cwd()
train_model(**check, callbacks=None, batch_callbacks=None)
Note
In release 0.0.4, the restore_db argument has been renamed to restart_db for internal consistency. restore_db in all scripts using hippynn > 0.0.3 should be replaced with restart_db. The affected functions are load_checkpoint, load_checkpoint_from_cwd, and restore_checkpoint. If hippynn <= 0.0.3 is used, please keep the original restore_db keyword.
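For example, a minimal sketch of the rename in use; the False value is an arbitrary illustration of skipping the database restart:
from hippynn.experiment.serialization import load_checkpoint_from_cwd

# hippynn > 0.0.3: new keyword name
check = load_checkpoint_from_cwd(restart_db=False)

# hippynn <= 0.0.3: original keyword name
# check = load_checkpoint_from_cwd(restore_db=False)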
If all you want to do is use a previously trained model, here is how to load the model only:
from hippynn.experiment.serialization import load_model_from_cwd
model = load_model_from_cwd()
The returned model object will have the original model with the best parameters loaded. This can then be used with, for example, the Predictor.
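For instance, a minimal sketch of wrapping the loaded model in a Predictor; the input keyword names (Z and R here) and the placeholder arrays are hypothetical, and must match the db_name values of your model's input nodes:
from hippynn.graphs import Predictor
from hippynn.experiment.serialization import load_model_from_cwd

model = load_model_from_cwd()
predictor = Predictor.from_graph(model)

# Hypothetical inputs; species_array and position_array are placeholders,
# and the keyword names must match the model's input db_names.
outputs = predictor(Z=species_array, R=position_array)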
Cross-device restart
If a model was trained on a device that is no longer accessible (due to a change of configuration or loading on a different computer), you may want to load a checkpoint to a different device. The standard pytorch argument map_location is a bit tricky to handle in this case, as not all tensors in the checkpoint still belong on the device. If the model_device keyword is specified instead, hippynn will attempt to automatically move the correct tensors to the correct device. To perform cross-device loading, pass the model_device argument to load_checkpoint_from_cwd() or load_checkpoint():
from hippynn.experiment.serialization import load_checkpoint_from_cwd
check = load_checkpoint_from_cwd(model_device='cuda:0')
train_model(**check, callbacks=None, batch_callbacks=None)
The string ‘auto’ can be provided to transfer to the default device.
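For example:
check = load_checkpoint_from_cwd(model_device="auto")  # transfer to the default device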
Note
To avoid cross-device restarts as much as possible, set the environment variable CUDA_VISIBLE_DEVICES before importing hippynn. In that case, if, for example, only one GPU is used, it will always be labeled as device 0, no matter which physical device it is.
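A minimal sketch of setting this from within a script, where "1" is an arbitrary example of a physical GPU index:
import os

# Must be set before CUDA is initialized, i.e. before importing hippynn/torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1

import hippynn  # within this process, that GPU is now visible as 'cuda:0'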
Advanced Details
The checkpoint file contains the torch RNG state, and restoring a checkpoint resets the torch RNG.
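As a sketch of what this implies, assuming the RNG state is applied at load time:
import torch
from hippynn.experiment.serialization import load_checkpoint_from_cwd

torch.manual_seed(0)               # any RNG state set here...
check = load_checkpoint_from_cwd()
# ...is overwritten by the RNG state saved in the checkpoint, so random
# draws such as torch.rand(1) continue the sequence from training time.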
If your database is not Restartable, you will have to explicitly reload it and pass it to train_model as well. If your database is restartable, any pre-processing of the database is not recorded in the checkpoint file, so pre-processing steps such as moving the database to the GPU need to be performed before calling train_model. The dictionary containing the database information is stored as training_modules.evaluator.db_info, so you can use this dictionary to easily reload your database.
hippynn does not include support for serializing and restarting callback objects; to restart a training that involves callbacks, the callbacks will have to be retrieved using user code.
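For example, a sketch of manually reloading a directory-based database from the checkpointed db_info; the dataset name, directory, split sizes, and seed are hypothetical placeholders:
from hippynn.databases import DirectoryDatabase
from hippynn.experiment import train_model
from hippynn.experiment.serialization import load_checkpoint_from_cwd

check = load_checkpoint_from_cwd(restart_db=False)  # skip the stored database
db_info = check["training_modules"].evaluator.db_info

database = DirectoryDatabase(
    name="data-",           # hypothetical file prefix of the dataset arrays
    directory="./dataset",  # hypothetical dataset location
    test_size=0.1,
    valid_size=0.1,
    seed=2001,
    **db_info,              # inputs/targets recorded in the checkpoint
)
train_model(
    check["training_modules"],
    database,
    check["controller"],
    check["metrics"],
    callbacks=None,
    batch_callbacks=None,
)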
It is not a good idea to wholesale transfer tensors in a checkpoint off of the CPU using a keyword such as map_location=torch.device(0). This will map all tensors to GPU 0, which breaks the RNG state, as it only supports a CPU tensor. Doing so, you will see errors like TypeError: RNG state must be a torch.ByteTensor. Moving everything to CPU with map_location="cpu" always works. If map_location is used and the value is anything other than None or "cpu", you are likely to get an error during loading or training. The argument is passed directly to torch.load; for more details of this option, check the torch.load documentation.
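A sketch of the safe and unsafe variants described above; the keyword is passed through to torch.load as stated:
import torch
from hippynn.experiment.serialization import load_checkpoint_from_cwd

# Safe: bring every tensor to CPU first.
check = load_checkpoint_from_cwd(map_location="cpu")

# Unsafe: maps all tensors, including the CPU-only RNG state, to GPU 0 and
# raises: TypeError: RNG state must be a torch.ByteTensor
# check = load_checkpoint_from_cwd(map_location=torch.device(0))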
Here is a list of objects and their final devices after loading:

Objects                           Destinations
training_modules.model            model_device
training_modules.loss             model_device
training_modules.evaluator.loss   CPU
controller.optimizer              Some on model_device and some on CPU, depending on details of the implementation.
database                          CPU
Anything not mentioned            CPU