serialization module

Full documentation for the hippynn.experiment.serialization module.

Checkpoint and state generation.

As a user, in most cases you will only need the load functions here.

check_mapping_devices(map_location, model_device)

Check options for restarting across devices.

Parameters:
  • map_location – device mapping argument for torch.load.

  • model_device – automatically handle device mapping.

Raises:

TypeError – if both map_location and model_device are specified

Returns:

processed map_location and model_device
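The documented contract of check_mapping_devices can be illustrated with a small standard-library sketch. The name check_mapping_devices_sketch is hypothetical. Only the TypeError for passing both arguments comes from the documentation above; the step that sets map_location to "cpu" when model_device is given (load on the CPU first, move the model afterwards) is an assumption.

```python
from typing import Any, Optional, Tuple


def check_mapping_devices_sketch(
    map_location: Optional[Any], model_device: Optional[Any]
) -> Tuple[Optional[Any], Optional[Any]]:
    """Hypothetical re-implementation of the documented contract."""
    # Documented: the two options are mutually exclusive.
    if map_location is not None and model_device is not None:
        raise TypeError("Pass either map_location or model_device, not both.")
    # Assumption: with model_device set, tensors are loaded on the CPU and
    # the model is moved to model_device afterwards.
    if model_device is not None:
        map_location = "cpu"
    return map_location, model_device


map_location, model_device = check_mapping_devices_sketch(None, "cuda:0")
print(map_location, model_device)  # cpu cuda:0
```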

create_state(model: GraphModule, controller: Controller, metric_tracker: MetricTracker) → dict

Create an experiment state dictionary.

Parameters:
  • model – current model

  • controller – training controller

  • metric_tracker – current metrics

Returns:

dictionary containing experiment state.

Return type:

dict
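A minimal sketch of what such a state dictionary might hold, assuming each object exposes a PyTorch-style state_dict() method. Because restore_checkpoint is documented to set the RNG state, the sketch assumes the state also captures it. The key names, the _StubModule class, and the use of Python's random module in place of torch's RNG are illustrative assumptions, not the actual contents written by create_state.

```python
import random
from typing import Any, Dict


class _StubModule:
    """Hypothetical stand-in for any object exposing state_dict()."""

    def __init__(self, params: Dict[str, float]):
        self.params = dict(params)

    def state_dict(self) -> Dict[str, float]:
        # Return a copy so later training steps do not mutate the snapshot.
        return dict(self.params)


def create_state_sketch(model: Any, controller: Any, metric_tracker: Any) -> Dict[str, Any]:
    """Collect per-object state plus the RNG state into one dictionary.

    Key names are assumptions; Python's random module stands in for torch's RNG.
    """
    return {
        "model": model.state_dict(),
        "controller": controller.state_dict(),
        "metric_tracker": metric_tracker,
        "rng_state": random.getstate(),
    }


state = create_state_sketch(_StubModule({"w": 1.0}), _StubModule({"lr": 1e-3}), {"best_epoch": 3})
print(sorted(state))  # ['controller', 'metric_tracker', 'model', 'rng_state']
```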

create_structure_file(training_modules: TrainingModules, database: Database, controller: Controller, fname='experiment_structure.pt') → None

Save an experiment structure (i.e., the full model, not just its state_dict).

Parameters:
  • training_modules – contains model, controller, and loss

  • database – database for training

  • controller – training controller

  • fname – name of the file to save the structure to

Returns:

None
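What distinguishes a structure file from a state dictionary is that it stores whole objects. The sketch below uses pickle as a stand-in for torch.save (which is pickle-based); the function name and dictionary keys are hypothetical, and plain Python values stand in for the real training modules, database, and controller.

```python
import os
import pickle
import tempfile
from typing import Any


def create_structure_file_sketch(training_modules: Any, database: Any, controller: Any,
                                 fname: str = "experiment_structure.pt") -> None:
    """Hypothetical sketch: save whole objects, not just their parameters."""
    structure = {
        "training_modules": training_modules,
        "database": database,
        "controller": controller,
    }
    # pickle stands in for torch.save here.
    with open(fname, "wb") as f:
        pickle.dump(structure, f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "experiment_structure.pt")
    create_structure_file_sketch({"model": "m", "loss": "l"}, ["db"], {"lr": 1e-3}, fname=path)
    with open(path, "rb") as f:
        print(sorted(pickle.load(f)))  # ['controller', 'database', 'training_modules']
```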

load_checkpoint(structure_fname: str, state_fname: str, restart_db=False, map_location=None, model_device=None, **kwargs) → dict

Load a checkpoint from the given structure and state files.

For more details on how to use this function, see Restarting training.

Parameters:
  • structure_fname – name of the structure file

  • state_fname – name of the state file

  • restart_db – whether to restore the database, defaults to False

  • map_location – device mapping argument for torch.load, defaults to None

  • model_device – automatically handle device mapping, defaults to None

Returns:

experiment structure
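Conceptually, restarting combines the two files: the structure file supplies the objects, and the state file supplies the parameters to load into them. The sketch below assumes that load_checkpoint behaves like load_saved_tensors followed by restore_checkpoint. It uses pickle in place of torch.save/torch.load and plain dictionaries in place of modules; the function names and filenames are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Dict, Tuple


def _save(obj: object, fname: str) -> None:
    # pickle stands in for torch.save in this standard-library-only sketch.
    with open(fname, "wb") as f:
        pickle.dump(obj, f)


def load_saved_tensors_sketch(structure_fname: str, state_fname: str) -> Tuple[Dict, Dict]:
    # pickle.load stands in for torch.load (which additionally takes map_location).
    with open(structure_fname, "rb") as f:
        structure = pickle.load(f)
    with open(state_fname, "rb") as f:
        state = pickle.load(f)
    return structure, state


def load_checkpoint_sketch(structure_fname: str, state_fname: str) -> Dict:
    structure, state = load_saved_tensors_sketch(structure_fname, state_fname)
    # The structure holds the model as it was when the structure file was
    # written; overwrite its parameters with the saved training state.
    structure["model"].update(state["model"])
    return structure


with tempfile.TemporaryDirectory() as tmp:
    structure_path = os.path.join(tmp, "experiment_structure.pt")
    state_path = os.path.join(tmp, "state.pt")
    _save({"model": {"weight": 0.0}}, structure_path)  # written before training
    _save({"model": {"weight": 2.5}}, state_path)      # written after training
    restored = load_checkpoint_sketch(structure_path, state_path)

print(restored["model"])  # {'weight': 2.5}
```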

load_checkpoint_from_cwd(map_location=None, model_device=None, **kwargs) → dict

Same as load_checkpoint, but using default filenames.

Parameters:
  • map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None

  • model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None

Returns:

experiment structure

Return type:

dict
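A sketch of how the default filenames could be resolved against the current working directory. "experiment_structure.pt" is the documented default of create_structure_file; the state filename "best_checkpoint.pt" and the helper default_checkpoint_paths are assumptions for illustration.

```python
import os
from typing import Optional, Tuple

# Documented default of create_structure_file.
DEFAULT_STRUCTURE_FNAME = "experiment_structure.pt"
# Assumed name of the state file; check the output directory of your run.
ASSUMED_STATE_FNAME = "best_checkpoint.pt"


def default_checkpoint_paths(directory: Optional[str] = None) -> Tuple[str, str]:
    """Hypothetical helper: build (structure_fname, state_fname) for a directory."""
    # Fall back to the current working directory, as the *_from_cwd functions do.
    directory = os.getcwd() if directory is None else directory
    return (
        os.path.join(directory, DEFAULT_STRUCTURE_FNAME),
        os.path.join(directory, ASSUMED_STATE_FNAME),
    )


structure_fname, state_fname = default_checkpoint_paths()
print(os.path.basename(structure_fname), os.path.basename(state_fname))
# experiment_structure.pt best_checkpoint.pt
```

The resulting pair would correspond to the structure_fname and state_fname arguments of load_checkpoint.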

load_model_from_cwd(map_location=None, model_device=None, **kwargs) → GraphModule

Load only the model from the current working directory.

Parameters:
  • map_location (Union[str, dict, torch.device, Callable], optional) – device mapping argument for torch.load, defaults to None

  • model_device (Union[int, str, torch.device], optional) – automatically handle device mapping, defaults to None

Returns:

model with reloaded parameters

load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) → Tuple[dict, dict]

Load torch tensors from the structure and state files.

Parameters:
  • structure_fname – name of the structure file

  • state_fname – name of the state file

Returns:

loaded dictionaries of checkpoint and model parameters

restore_checkpoint(structure: dict, state: dict, restart_db=False) → dict

This function loads the parameters from the state dictionary into the modules, optionally tries to restart the database, and sets the RNG state.

Parameters:
  • structure – experiment structure object

  • state – experiment state object

  • restart_db – whether to attempt to restore the database (True/False)

Returns:

experiment structure
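The parameter-loading and RNG steps of restore_checkpoint can be sketched with standard-library stand-ins: plain dictionaries for modules and their state dicts, and Python's random module for torch's RNG. The function name and key names are hypothetical, and database restoration (restart_db) is left out of the sketch. What it illustrates is why restoring the RNG state matters: after the restore, random draws continue exactly as they would have at the moment the state was saved.

```python
import random
from typing import Any, Dict


def restore_checkpoint_sketch(structure: Dict[str, Any], state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch of restore_checkpoint without database handling."""
    # Load the saved parameters into each object of the structure.
    for name in ("model", "controller"):
        structure[name].update(state[name])
    # Restore the RNG so that later random draws match the original run.
    random.setstate(state["rng_state"])
    return structure


random.seed(0)
saved_rng = random.getstate()
expected_draw = random.random()  # the draw the original run made next
random.seed(123)                 # a fresh process starts from a different RNG state

structure = {"model": {"w": 0.0}, "controller": {"lr": 1e-3}}
state = {"model": {"w": 2.0}, "controller": {"lr": 5e-4}, "rng_state": saved_rng}
restored = restore_checkpoint_sketch(structure, state)
resumed_draw = random.random()
print(restored["model"], resumed_draw == expected_draw)  # {'w': 2.0} True
```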