Source code for hippynn.interfaces.schnetpack_interface
"""
Interface to use schnetpack networks with hippynn.
Note: Only open boundary conditions are supported for now.
"""
try:
import schnetpack
except ImportError as ie:
raise ImportError("Schnetpack installation is required to use the schnetpack interface.") from ie
import torch
from schnetpack import Properties
from ...graphs.nodes.base import SingleNode, AutoKw
from ...graphs.nodes.networks import Network
from ...graphs.indextypes import IdxType
[docs]
class SchNetWrapper(torch.nn.Module):
    """
    Adapt a schnetpack SchNet-style representation to hippynn's padded arrays.

    The wrapped module is called with the dictionary produced by
    ``create_schnetpack_inputs``. Its per-atom outputs are returned only for
    real (nonblank) atoms.

    :param schnet: the schnetpack representation module to wrap. It must expose
        ``interactions`` (each with a ``dense`` layer) and a
        ``return_intermediate`` flag.
    """

    def __init__(self, schnet):
        super().__init__()
        self.schnet = schnet
        # One feature size per interaction block, taken from its output dense layer.
        feature_sizes = [block.dense.out_features for block in schnet.interactions]
        # Without intermediates, only the final representation is produced.
        # Same attribute name as used in forward().
        if not schnet.return_intermediate:
            feature_sizes = [feature_sizes[-1]]
        self.feature_sizes = feature_sizes

    def forward(self, z_arr, r_arr, nonblank):
        """
        Wrap a call into the underlying schnet, which uses dictionaries.

        :param z_arr: padded species array (entries > 0 are real atoms).
        :param r_arr: padded positions array.
        :param nonblank: mask of real atoms, indexed like ``z_arr``.
        :return: list of feature tensors, each flattened over the real atoms.
        """
        packed = create_schnetpack_inputs(z_arr, r_arr, nonblank)
        outputs = self.schnet(packed)
        if self.schnet.return_intermediate:
            # With intermediates requested, the second element of the result is
            # used as the list of feature tensors.
            # NOTE(review): confirm its length matches self.feature_sizes; the
            # intermediates list may also contain the initial embedding.
            outputs = outputs[1]
        else:
            outputs = [outputs]
        # Keep only real atoms, flattening the (molecule, atom) padding away.
        mask = packed["nonblank"]
        return [x[mask] for x in outputs]
[docs]
class SchNetNode(AutoKw, Network, SingleNode):
    """
    Graph node wrapping a schnetpack SchNet representation.

    Its inputs are named by ``_input_names`` (species, positions, nonblank), and
    its index state is ``IdxType.Atoms``.

    :param name: name of the node.
    :param parents: parent nodes supplying the inputs listed in ``_input_names``.
    :param module: a ready-made module, or ``"auto"`` to build one through
        ``auto_module`` (the auto class is ``SchNetWrapper``).
    :param module_kwargs: keyword arguments stored on the node when
        ``module == "auto"``. Presumably consumed by ``AutoKw.auto_module``;
        verify there.
    """

    _input_names = "species", "positions", "nonblank"
    _index_state = IdxType.Atoms
    _auto_module_class = SchNetWrapper

    def __init__(self, name, parents, module="auto", module_kwargs=None):
        build_automatically = module == "auto"
        if build_automatically:
            # Must be stored before auto_module() runs.
            self.module_kwargs = module_kwargs
        chosen_module = self.auto_module() if build_automatically else module
        super().__init__(name, parents, module=chosen_module)
[docs]
def create_schnetpack_inputs(z_arr, r_arr, nonblank):
    """
    Build the input dictionary consumed by schnetpack from hippynn padded arrays.

    Each real atom gets every other real atom of its molecule as a neighbor
    (no cutoff is applied here). The cell and cell offsets are all zeros, so only
    open boundary conditions are represented.

    :param z_arr: species array of shape (n_mols, n_atoms_max); entries > 0 are
        real atoms. Masks are built from per-molecule atom counts, so real atoms
        are assumed to be packed at the start of each row.
    :param r_arr: positions array indexed like ``z_arr`` along its first two
        axes (presumably (n_mols, n_atoms_max, 3)). Its dtype and device are used
        for the floating-point outputs.
    :param nonblank: mask of real atoms, indexed like ``z_arr``.
    :return: dict keyed by schnetpack ``Properties`` names, plus ``"nonblank"``
        holding the input mask truncated to the largest molecule in the batch.
    """
    dtype = r_arr.dtype
    device = r_arr.device

    n_atoms_per_mol = (z_arr > 0).sum(dim=1)
    n_mols = n_atoms_per_mol.shape[0]
    # Largest molecule in the batch, as a Python int usable for slicing and sizes.
    n_atoms = int(n_atoms_per_mol.max())

    # Drop padding columns that no molecule in this batch uses.
    z_arr = z_arr[:, :n_atoms]
    r_arr = r_arr[:, :n_atoms]
    nonblank = nonblank[:, :n_atoms]

    # atom_mask[m, i] is True for the first n_atoms_per_mol[m] atoms of molecule m.
    # The range must live on the same device as the counts it is compared with.
    atom_range = torch.arange(n_atoms, device=device)
    atom_mask = atom_range.unsqueeze(0) < n_atoms_per_mol.unsqueeze(1)

    # Row i of neighbor_base lists every atom index except i:
    # [0, ..., i-1, i+1, ..., n_atoms-1]. Adding the upper-triangular ones shifts
    # the entries at column >= i up by one.
    # Integer dtype because schnetpack uses these as indices into the atoms.
    neighbor_base = torch.arange(n_atoms - 1, device=device, dtype=torch.long)
    neighbor_base = neighbor_base.unsqueeze(0).expand(n_atoms, -1)
    neighbor_base = neighbor_base + torch.triu(torch.ones_like(neighbor_base), diagonal=0)

    # Neighbor slot j of atom i is valid iff atom i is real and atom j+1 is real.
    # For a molecule with k packed atoms, each real atom has valid slots 0..k-2,
    # i.e. exactly its k-1 real partners.
    neighbor_mask = atom_mask.unsqueeze(2) & atom_mask.unsqueeze(1)[:, :, 1:]
    # Invalid slots are set to index 0.
    neighbors = neighbor_base.unsqueeze(0) * neighbor_mask.long()

    # Open boundaries: zero cell and zero cell offsets.
    cell = torch.zeros((n_mols, 3, 3), device=device, dtype=dtype)
    cell_offset = torch.zeros((n_mols, n_atoms, n_atoms - 1, 3), device=device, dtype=dtype)

    return {
        Properties.atom_mask: atom_mask.to(dtype),
        Properties.neighbors: neighbors,
        Properties.neighbor_mask: neighbor_mask.to(dtype),
        Properties.Z: z_arr,
        Properties.R: r_arr,
        Properties.cell: cell,
        Properties.cell_offset: cell_offset,
        "nonblank": nonblank,
    }