env_pytorch module
Full Documentation for hippynn.custom_kernels.env_pytorch module.
Pure pytorch implementation of envsum operations.
Note that these implementations are usually replaced by custom CUDA kernels, as explained in the Custom Kernels section of the documentation.
- envsum(sensitivities: Tensor, features: Tensor, pair_first: Tensor, pair_second: Tensor) → Tensor
Computes the outer product of each pair's sensitivities with the features of the atom at pair_second, accumulating the results onto the atoms indexed by pair_first.
See the Custom Kernels section of the documentation for more information.
- Parameters:
sensitivities – (n_pairs, n_sensitivities) floating point tensor
features – (n_atoms, n_features) floating point tensor
pair_first – (n_pairs,) index tensor indicating first atom of pair
pair_second – (n_pairs,) index tensor indicating second atom of pair
- Returns:
env – (n_atoms, n_sensitivities, n_features) floating point tensor
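For intuition, the envsum contract can be written as env[i, s, f] = Σ over pairs p with pair_first[p] == i of sensitivities[p, s] · features[pair_second[p], f]. The following is a minimal sketch of that operation in plain PyTorch. It assumes the shapes listed above, and it is not necessarily how hippynn itself implements envsum. The name envsum_sketch is hypothetical.

```python
import torch


def envsum_sketch(sensitivities, features, pair_first, pair_second):
    """Hypothetical reference envsum (illustration only):
    env[i, s, f] = sum over pairs p with pair_first[p] == i of
                   sensitivities[p, s] * features[pair_second[p], f]
    """
    n_atoms = features.shape[0]
    # Per-pair outer product, shape (n_pairs, n_sensitivities, n_features)
    pair_env = torch.einsum("ps,pf->psf", sensitivities, features[pair_second])
    env = torch.zeros(
        n_atoms, sensitivities.shape[1], features.shape[1],
        dtype=features.dtype, device=features.device,
    )
    # Scatter-add each pair's contribution onto its first atom
    return env.index_add(0, pair_first, pair_env)


# Two pairs, both with first atom 0, second atoms 1 and 2
features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
sensitivities = torch.tensor([[1.0], [2.0]])
pair_first = torch.tensor([0, 0])
pair_second = torch.tensor([1, 2])
env = envsum_sketch(sensitivities, features, pair_first, pair_second)
print(env[0, 0].tolist())  # [13.0, 16.0]  (1*[3, 4] + 2*[5, 6])
```

This sketch materializes the full (n_pairs, n_sensitivities, n_features) intermediate before scattering, so it favours simplicity over memory use.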
- featsum(env, sense, pair_first, pair_second)
Computes the inner product, over the sensitivity index, of each pair's sensitivities with the environment tensor of the atom at pair_first, accumulating the results onto the atoms indexed by pair_second.
The summation order differs from envsum because this signature lets featsum serve naturally as the backward pass of envsum, and vice versa.
See the Custom Kernels section of the documentation for more information.
- Parameters:
env – (n_atoms, n_sensitivities, n_features) floating point tensor
sense – (n_pairs, n_sensitivities) floating point tensor
pair_first – (n_pairs,) index tensor indicating first atom of pair
pair_second – (n_pairs,) index tensor indicating second atom of pair
- Returns:
feat – (n_atoms, n_features) floating point tensor
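One reading of featsum is feat[j, f] = Σ over pairs p with pair_second[p] == j of Σ_s env[pair_first[p], s, f] · sense[p, s]. The sketch below implements that reading and then checks the backward-pass claim numerically. If env is produced from features by envsum, autograd's gradient with respect to features should equal featsum applied to the upstream gradient. The name featsum_sketch is hypothetical, and the envsum used in the check is a plain-autograd stand-in. Neither is the hippynn implementation.

```python
import torch


def featsum_sketch(env, sense, pair_first, pair_second):
    """Hypothetical reference featsum (illustration only):
    feat[j, f] = sum over pairs p with pair_second[p] == j of
                 sum_s env[pair_first[p], s, f] * sense[p, s]
    """
    # Contract the sensitivity index per pair: shape (n_pairs, n_features)
    pair_feat = torch.einsum("psf,ps->pf", env[pair_first], sense)
    feat = torch.zeros(env.shape[0], env.shape[2], dtype=env.dtype, device=env.device)
    # Scatter-add each pair's contribution onto its second atom
    return feat.index_add(0, pair_second, pair_feat)


# Numerical check of the adjoint relationship with envsum
torch.manual_seed(0)
n_atoms, n_pairs, n_sens, n_feat = 5, 12, 3, 4
pair_first = torch.randint(0, n_atoms, (n_pairs,))
pair_second = torch.randint(0, n_atoms, (n_pairs,))
sense = torch.randn(n_pairs, n_sens, dtype=torch.float64)
features = torch.randn(n_atoms, n_feat, dtype=torch.float64, requires_grad=True)

# Stand-in envsum built from differentiable PyTorch ops
env = torch.zeros(n_atoms, n_sens, n_feat, dtype=torch.float64).index_add(
    0, pair_first, torch.einsum("ps,pf->psf", sense, features[pair_second])
)
grad_env = torch.randn_like(env)  # arbitrary upstream gradient
env.backward(grad_env)

print(torch.allclose(features.grad, featsum_sketch(grad_env, sense, pair_first, pair_second)))  # True
```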
- sensesum(env, features, pair_first, pair_second)
Computes the product of the environment of the atom at pair_first with the features of the atom at pair_second, summing over the feature index.
See the Custom Kernels section of the documentation for more information.
- Parameters:
env – (n_atoms, n_sensitivities, n_features) floating point tensor
features – (n_atoms, n_features) floating point tensor
pair_first – (n_pairs,) index tensor indicating first atom of pair
pair_second – (n_pairs,) index tensor indicating second atom of pair
- Returns:
sense – (n_pairs, n_sensitivities) floating point tensor
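Unlike envsum and featsum, sensesum does not accumulate onto atoms. Its output has one row per pair: sense[p, s] = Σ_f env[pair_first[p], s, f] · features[pair_second[p], f]. Below is a minimal sketch of that formula using torch.einsum. The name sensesum_sketch is hypothetical, and the code is an illustration rather than the hippynn source.

```python
import torch


def sensesum_sketch(env, features, pair_first, pair_second):
    """Hypothetical reference sensesum (illustration only):
    sense[p, s] = sum_f env[pair_first[p], s, f] * features[pair_second[p], f]
    No scatter step is needed: the result has one row per pair.
    """
    return torch.einsum("psf,pf->ps", env[pair_first], features[pair_second])


env = torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]])  # (3 atoms, 1 sensitivity, 2 features)
features = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
pair_first = torch.tensor([0, 2])
pair_second = torch.tensor([1, 2])
print(sensesum_sketch(env, features, pair_first, pair_second).tolist())  # [[2.0], [11.0]]
```

Under the same index conventions, sensesum applied to an upstream gradient of env gives the gradient of envsum with respect to sensitivities. That is the counterpart to featsum giving the gradient with respect to features.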