"""
Layers that wrap other layers.
"""
import torch
# ResNet-style wrapper around a base layer.
class ResNetWrapper(torch.nn.Module):
    """
    Wraps a base layer with a residual (skip) connection: the (optionally
    size-adjusted) input is added back onto the layer's output.
    """
    def __init__(self, base_layer, nf_in, nf_middle, nf_out, activation=torch.nn.functional.softplus):
        """
        Constructor
        :param base_layer: The ResLayer pytorch module
        :param nf_in: Input dimensions
        :param nf_middle: Output dimensions of ``base_layer`` (input dimensions of the residual projection)
        :param nf_out: Output dimensions
        :param activation: a nonlinearity function
        """
        super().__init__()
        self.activation = activation
        self.base_layer = base_layer
        self.res_layer = torch.nn.Linear(nf_middle, nf_out)
        if nf_in != nf_out:
            # Project the input so it can be added to the output when sizes differ.
            self.adjust_layer = torch.nn.Linear(nf_in, nf_out, bias=False)
            self.needs_size_adjust = True
        else:
            self.needs_size_adjust = False
    def regularization_params(self):
        """
        :return: list of weight tensors that should be subject to regularization
        """
        params = [self.res_layer.weight]
        if hasattr(self.base_layer, "regularization_params"):
            params.extend(self.base_layer.regularization_params())
        else:
            params.append(self.base_layer.weight)
        if self.needs_size_adjust:
            params.append(self.adjust_layer.weight)
        return params
    def forward(self, *inputs):
        """
        :param inputs: inputs for the layer; the first one is taken to be
            the features to which the computed residual is added
        :return: Pytorch module output
        """
        middle_activation = self.activation(self.base_layer(*inputs))
        difference_activation = self.res_layer(middle_activation)
        input_activation = inputs[0]
        if self.needs_size_adjust:
            input_activation = self.adjust_layer(input_activation)
        return difference_activation + input_activation
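

# A minimal usage sketch (not part of the original module): it assumes a plain
# torch.nn.Linear as the wrapped base layer and illustrates a forward pass plus
# one way regularization_params() might feed an L2 penalty. The dimensions and
# the weight_decay value below are arbitrary placeholders.
if __name__ == "__main__":
    base = torch.nn.Linear(16, 32)  # assumed base layer: 16 -> 32 features
    wrapper = ResNetWrapper(base, nf_in=16, nf_middle=32, nf_out=16)

    x = torch.randn(8, 16)  # batch of 8 examples
    y = wrapper(x)          # residual output, shape (8, 16)

    # One possible regularization term: sum of squared weights returned by the wrapper.
    weight_decay = 1e-4
    l2_penalty = weight_decay * sum(w.pow(2).sum() for w in wrapper.regularization_params())
    print(y.shape, float(l2_penalty))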