serialization module
Full documentation for the hippynn.experiment.serialization module.
Checkpoint and state generation.
As a user, in most cases you will only need the load functions here.
- check_mapping_devices(map_location, model_device)
Check options for restarting across devices.
- Parameters:
map_location – device mapping argument for torch.load.
model_device – automatically handle device mapping.
- Raises:
TypeError – if both map_location and model_device are specified
- Returns:
processed map_location and model_device
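The documented contract can be restated as a short, self-contained sketch. The function name `check_mapping_devices_sketch` is hypothetical. Loading tensors onto the CPU when `model_device` is given is an assumption about what "processed" means; it does not describe hippynn's code.

```python
from typing import Any, Optional, Tuple


def check_mapping_devices_sketch(
    map_location: Optional[Any], model_device: Optional[Any]
) -> Tuple[Optional[Any], Optional[Any]]:
    """Hypothetical restatement of the documented rule; not hippynn's source."""
    # Documented behavior: specifying both options is an error.
    if map_location is not None and model_device is not None:
        raise TypeError("Specify map_location or model_device, not both.")
    # Assumption: when model_device is given, tensors are first loaded onto
    # the CPU and the model is moved to model_device afterwards.
    if model_device is not None:
        map_location = "cpu"
    return map_location, model_device


print(check_mapping_devices_sketch(None, "cuda:0"))  # ('cpu', 'cuda:0')
print(check_mapping_devices_sketch("cpu", None))     # ('cpu', None)
try:
    check_mapping_devices_sketch("cpu", "cuda:0")
except TypeError as err:
    print("TypeError:", err)
```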
- create_state(model: GraphModule, controller: Controller, metric_tracker: MetricTracker) -> dict
Create an experiment state dictionary.
- Parameters:
model – current model
controller – controller
metric_tracker – current metrics
- Returns:
dictionary containing experiment state.
- Return type:
dict
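The sketch below illustrates the kind of dictionary this function produces. It collects per-object state together with a random-number-generator state, which `restore_checkpoint` later sets. Several details are assumptions made to keep the example self-contained: the key names, the `StubModule` class, and the use of Python's `random` module in place of torch's RNG.

```python
import random
from typing import Any, Dict


class StubModule:
    """Hypothetical stand-in for a model or controller with a state_dict()."""

    def __init__(self, **params: Any) -> None:
        self.params = dict(params)

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.params)


def create_state_sketch(model: StubModule, controller: StubModule,
                        metric_tracker: Any) -> Dict[str, Any]:
    # Key names are illustrative, not hippynn's actual layout.
    return {
        "model": model.state_dict(),
        "controller": controller.state_dict(),
        "metric_tracker": metric_tracker,
        # random.getstate() stands in for the torch RNG state.
        "rng_state": random.getstate(),
    }


state = create_state_sketch(StubModule(w=1.0), StubModule(lr=1e-3), {"best_loss": 0.5})
print(sorted(state))  # ['controller', 'metric_tracker', 'model', 'rng_state']
```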
- create_structure_file(training_modules: TrainingModules, database: Database, controller: Controller, fname='experiment_structure.pt') -> None
Save an experiment structure (i.e., the full model, not just its state_dict).
- Parameters:
training_modules – contains model, controller, and loss
database – database for training
controller – controller
fname – filename to save the checkpoint
- Returns:
None
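A structure file differs from a state file in that it stores whole Python objects rather than just their parameters. The minimal sketch below illustrates that idea. It uses `pickle`, which `torch.save` relies on internally, so it runs without torch. The function name and dictionary keys are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any


def create_structure_file_sketch(training_modules: Any, database: Any, controller: Any,
                                 fname: str = "experiment_structure.pt") -> None:
    # Whole objects are serialized, not just their parameters.
    structure = {
        "training_modules": training_modules,
        "database": database,
        "controller": controller,
    }
    with open(fname, "wb") as f:
        pickle.dump(structure, f)


# Usage: write into a temporary directory, then read the objects back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "experiment_structure.pt")
    create_structure_file_sketch({"model": "model", "loss": "loss"}, ["db"], {"lr": 1e-3},
                                 fname=path)
    with open(path, "rb") as f:
        reloaded = pickle.load(f)
```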
- load_checkpoint(structure_fname: str, state_fname: str, restart_db=False, map_location=None, model_device=None, **kwargs) -> dict
Load a checkpoint from the given structure and state files.
For more details on how to use this function, see Restarting training.
- Parameters:
structure_fname – name of the structure file
state_fname – name of the state file
restart_db – whether to restore the database, defaults to False
map_location – device mapping argument for torch.load, defaults to None
model_device – automatically handle device mapping, defaults to None
- Returns:
experiment structure
- load_checkpoint_from_cwd(map_location=None, model_device=None, **kwargs) -> dict
Same as load_checkpoint, but using default filenames.
- Parameters:
map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None
model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None
- Returns:
experiment structure
- Return type:
dict
- load_model_from_cwd(map_location=None, model_device=None, **kwargs) -> GraphModule
Load only the model from the current working directory.
- Parameters:
map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None
model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None
- Returns:
model with reloaded parameters
- load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) -> Tuple[dict, dict]
Load torch tensors from the structure and state files.
- Parameters:
structure_fname – name of the structure file
state_fname – name of the state file
- Returns:
loaded dictionaries of checkpoint and model parameters
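The following stdlib-only sketch reads the two files and returns their contents as a pair. Here `pickle.load` stands in for `torch.load`. The function name `load_saved_tensors_sketch` and the `state.pt` filename are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Tuple


def load_saved_tensors_sketch(structure_fname: str, state_fname: str) -> Tuple[dict, dict]:
    # pickle.load stands in for torch.load in this sketch.
    with open(structure_fname, "rb") as f:
        structure = pickle.load(f)
    with open(state_fname, "rb") as f:
        state = pickle.load(f)
    return structure, state


# Usage: write two small files into a temporary directory and read them back.
with tempfile.TemporaryDirectory() as tmp:
    structure_path = os.path.join(tmp, "experiment_structure.pt")
    state_path = os.path.join(tmp, "state.pt")  # hypothetical filename
    with open(structure_path, "wb") as f:
        pickle.dump({"model": "model-object"}, f)
    with open(state_path, "wb") as f:
        pickle.dump({"model": {"w": 1.0}}, f)
    structure, state = load_saved_tensors_sketch(structure_path, state_path)
```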
- restore_checkpoint(structure: dict, state: dict, restart_db=False) -> dict
This function loads the parameters from the state dictionary into the modules, optionally tries to restart the database, and sets the RNG state.
- Parameters:
structure – experiment structure object
state – experiment state object
restart_db – whether to attempt to restore the database (True/False)
- Returns:
experiment structure
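The steps this function performs (loading parameters into the modules and setting the RNG state) can be sketched with stub objects. Several details are assumptions: `StubModule`, the dictionary keys, and the use of Python's `random` module in place of torch's RNG. The database step is left out because this page does not describe how it is done.

```python
import random
from typing import Any, Dict


class StubModule:
    """Hypothetical object with a PyTorch-style load_state_dict()."""

    def __init__(self) -> None:
        self.params: Dict[str, Any] = {}

    def load_state_dict(self, params: Dict[str, Any]) -> None:
        self.params = dict(params)


def restore_checkpoint_sketch(structure: dict, state: dict) -> dict:
    # 1. Copy saved parameters into the live objects (key names hypothetical).
    for key in ("model", "controller"):
        structure[key].load_state_dict(state[key])
    # 2. Database restoration (restart_db) is omitted in this sketch.
    # 3. Restore the RNG so later random draws continue from the saved point.
    random.setstate(state["rng_state"])
    return structure


random.seed(0)
saved_rng = random.getstate()
expected_next = random.random()  # the draw that follows the saved RNG state

state = {"model": {"w": 2.0}, "controller": {"lr": 1e-3}, "rng_state": saved_rng}
structure = restore_checkpoint_sketch({"model": StubModule(), "controller": StubModule()}, state)
# After restoring, the next random.random() call yields expected_next again.
```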