Source code for hippynn.layers.transform

"""
Layers that wrap other layers.
"""

import torch


# RESNET WRAPPER for the layers.
[docs] class ResNetWrapper(torch.nn.Module): """ Resnet Wrapper Class """ def __init__(self, base_layer, nf_in, nf_middle, nf_out, activation=torch.nn.functional.softplus): """ Constructor :param base_layer: The ResLayer pytorch module :param nf_in: Input dimensions :param nf_out: Output dimensions :param activation: a nonlinearity function """ super().__init__() self.activation = activation self.base_layer = base_layer self.res_layer = torch.nn.Linear(nf_middle, nf_out) if nf_in != nf_out: self.adjust_layer = torch.nn.Linear(nf_in, nf_out, False) self.needs_size_adjust = True else: self.needs_size_adjust = False
[docs] def regularization_params(self): params = [self.res_layer.weight] if hasattr(self.base_layer, "regularization_params"): params.extend(self.base_layer.regularization_params()) else: params.append(self.base_layer.weight) if self.needs_size_adjust: params.append(self.adjust_layer.weight) return params
[docs] def forward(self, *input): """ :param input: list of inputs for the layer; the first one is taken to be the features which will be computed as a residual :return: Pytorch module output """ middle_activation = self.activation(self.base_layer(*input)) difference_activation = self.res_layer(middle_activation) input_activation = input[0] if self.needs_size_adjust: input_activation = self.adjust_layer(input_activation) return difference_activation + input_activation