swyft.networks¶
- class swyft.networks.BatchNorm1dWithChannel(num_channels, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]¶
BatchNorm1d over the batch dimension, N. Requires shape (N, C, L).
Otherwise the same as torch.nn.BatchNorm1d, with an extra num_channels dimension. Cannot handle the temporal batch norm case.
- Parameters:
num_channels (int) –
num_features (int) –
eps (float) –
momentum (float) –
affine (bool) –
track_running_stats (bool) –
- forward(input)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
input (Tensor) –
- Return type:
Tensor
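As a rough illustration of the normalization this layer performs (a minimal numpy sketch, not swyft's implementation): statistics are accumulated over the batch axis N only, independently for every (channel, feature) entry, which is why the temporal case is excluded.

```python
import numpy as np

def channel_batch_norm(x, eps=1e-5):
    # x: (N, C, L). Normalize over the batch axis N only,
    # independently per (channel, feature) entry.
    mean = x.mean(axis=0, keepdims=True)   # (1, C, L)
    var = x.var(axis=0, keepdims=True)     # (1, C, L)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(64, 4, 10))
y = channel_batch_norm(x)
print(y.shape)  # (64, 4, 10)
```

The affine rescaling and running-statistics tracking controlled by `affine` and `track_running_stats` are omitted from this sketch.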
- class swyft.networks.LinearWithChannel(channels, in_features, out_features)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
channels (int) –
in_features (int) –
out_features (int) –
- forward(x)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
x (Tensor) –
- Return type:
Tensor
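The idea of a channelized linear layer can be sketched as one independent affine map per channel. The weight layout `(channels, out_features, in_features)` below is an assumption for illustration, not swyft's internal storage format.

```python
import numpy as np

def linear_with_channel(x, weight, bias):
    # x: (batch, channels, in_features) -> (batch, channels, out_features);
    # each channel is transformed by its own weight matrix and bias.
    return np.einsum("bci,coi->bco", x, weight) + bias

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3, 5))
w = rng.normal(size=(3, 7, 5))
b = rng.normal(size=(3, 7))
y = linear_with_channel(x, w, b)
print(y.shape)  # (8, 3, 7)
```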
- class swyft.networks.MarginalClassifier(n_marginals, n_combined_features, hidden_features, num_blocks, dropout_probability=0.0, use_batch_norm=True)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
n_marginals (int) –
n_combined_features (int) –
hidden_features (int) –
num_blocks (int) –
dropout_probability (float) –
use_batch_norm (bool) –
- forward(features, marginal_block)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
features (Tensor) –
marginal_block (Tensor) –
- Return type:
Tensor
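A hypothetical sketch of the per-marginal classification step: the shared data features are paired with each marginal's parameter block, and each combined vector is scored to a single logit. The single linear scoring layer below stands in for the actual residual classifier and is an assumption for illustration.

```python
import numpy as np

def score_marginals(features, marginal_block, w, b):
    # features: (batch, n_features); marginal_block: (batch, n_marginals, d)
    n_marginals = marginal_block.shape[1]
    f = np.broadcast_to(
        features[:, None, :],
        (features.shape[0], n_marginals, features.shape[1]))
    # combined: (batch, n_marginals, n_combined_features)
    combined = np.concatenate([f, marginal_block], axis=2)
    return np.einsum("bmc,mc->bm", combined, w) + b  # one logit per marginal

features = np.ones((4, 6))           # (batch, n_features)
marginal_block = np.ones((4, 3, 2))  # (batch, n_marginals, d)
w = np.ones((3, 8)) * 0.1            # n_combined_features = 6 + 2 = 8
b = np.zeros(3)
logits = score_marginals(features, marginal_block, w, b)
print(logits.shape)  # (4, 3)
```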
- class swyft.networks.Network(observation_transform, parameter_transform, marginal_classifier)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
observation_transform (Module) –
parameter_transform (Module) –
marginal_classifier (Module) –
- forward(observation, parameters)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
observation (Dict[Hashable, Tensor]) –
parameters (Tensor) –
- Return type:
Tensor
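The composition that Network.forward performs can be sketched with toy stand-ins for the three submodules handed to the constructor; the stand-in bodies below are hypothetical placeholders, only the data flow between them reflects the documented signatures.

```python
import numpy as np

def observation_transform(observation):
    # toy stand-in: flatten and concatenate every entry of the dict
    return np.concatenate(
        [v.reshape(v.shape[0], -1) for v in observation.values()], axis=1)

def parameter_transform(parameters):
    # toy stand-in: expose each parameter as its own 1-dim marginal block
    return parameters[:, :, None]                   # (batch, n_marginals, 1)

def marginal_classifier(features, marginal_block):
    # toy stand-in: scale each marginal block by the mean feature
    return features.mean(axis=1, keepdims=True) * marginal_block[..., 0]

def network_forward(observation, parameters):
    features = observation_transform(observation)         # (batch, n_features)
    marginal_block = parameter_transform(parameters)      # (batch, M, d)
    return marginal_classifier(features, marginal_block)  # (batch, M) logits

obs = {"x": np.ones((5, 2, 3))}
theta = np.ones((5, 4))
out = network_forward(obs, theta)
print(out.shape)  # (5, 4)
```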
- class swyft.networks.ObservationTransform(observation_key, observation_shapes, online_z_score)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
observation_key (Hashable) –
observation_shapes (Mapping[Hashable, Union[Size, Tuple[int, ...]]]) –
online_z_score (bool) –
- forward(observation)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
observation (Dict[Hashable, Tensor]) –
- Return type:
Tensor
- class swyft.networks.OnlineDictStandardizingLayer(shapes, stable=False, epsilon=1e-10, use_average_std=False)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
shapes (Dict[Hashable, Tuple[int, ...]]) –
stable (bool) –
epsilon (float) –
use_average_std (bool) –
- forward(x)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
x (Dict[Hashable, Tensor]) –
- Return type:
Tensor
- class swyft.networks.OnlineStandardizingLayer(shape, stable=False, epsilon=1e-10, use_average_std=False)[source]¶
Accumulate mean and variance online using the “parallel algorithm” from [1].
- Parameters:
shape (Tuple[int, ...]) – shape of the mean, variance, and std arrays. Do not include the batch dimension.
stable (bool) – (optional) compute using the stable version of the algorithm [1].
epsilon (float) – (optional) added to the computation of the standard deviation for numerical stability.
use_average_std (bool) – (optional) True to normalize using the std averaged over the whole observation, False to normalize using the std of each component of the observation.
References
[1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
- forward(x)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
x (Tensor) –
- Return type:
Tensor
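The “parallel algorithm” of [1] that the online accumulation is based on merges the running statistics of two disjoint batches exactly; a minimal sketch, tracking (count, mean, M2) where M2 is the sum of squared deviations, so variance = M2 / n:

```python
import numpy as np

def merge_stats(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    # Chan et al. parallel merge: combine (count, mean, M2) of two
    # disjoint batches into the statistics of their union.
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta ** 2 * n_a * n_b / n
    return n, mean, m2

rng = np.random.default_rng(0)
a, b = rng.normal(size=100), rng.normal(size=250)
n, mean, m2 = merge_stats(a.size, a.mean(), a.var() * a.size,
                          b.size, b.mean(), b.var() * b.size)
full = np.concatenate([a, b])
print(np.allclose(mean, full.mean()), np.allclose(m2 / n, full.var()))  # True True
```

Because the merge is exact, the layer can update its statistics batch by batch and still recover the mean and variance of all data seen so far.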
- class swyft.networks.ParameterTransform(n_parameters, marginal_indices, online_z_score)[source]¶
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
n_parameters (int) –
marginal_indices (Union[int, Sequence[int], Sequence[Sequence[int]]]) –
online_z_score (bool) –
- forward(parameters)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
parameters (Tensor) –
- Return type:
Tensor
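A hypothetical sketch of what marginal_indices selects: the parameter columns belonging to each marginal are gathered into one block per marginal. The helper below is illustrative only and assumes every marginal has the same dimension.

```python
import numpy as np

def parameter_block(parameters, marginal_indices):
    # parameters: (batch, n_parameters) -> (batch, n_marginals, d)
    idx = np.asarray(marginal_indices)   # (n_marginals, d)
    return parameters[:, idx]            # numpy fancy indexing gathers columns

theta = np.arange(12.0).reshape(4, 3)    # batch of 4, n_parameters = 3
block = parameter_block(theta, [[0, 1], [0, 2], [1, 2]])  # three 2-dim marginals
print(block.shape)  # (4, 3, 2)
```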
- class swyft.networks.ResidualNetWithChannel(channels, in_features, out_features, hidden_features, num_blocks=2, activation=<function relu>, dropout_probability=0.0, use_batch_norm=False)[source]¶
A general-purpose residual network. Works only with channelized 1-dim inputs.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
channels (int) –
in_features (int) –
out_features (int) –
hidden_features (int) –
num_blocks (int) –
activation (Callable) –
dropout_probability (float) –
use_batch_norm (bool) –
- forward(inputs)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
inputs (Tensor) –
- Return type:
Tensor
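The residual idea the network is built from can be sketched in plain numpy (unchannelized, without batch norm or dropout): each block adds its input back onto a small two-layer transformation, so identity information is preserved through the network.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, w1, b1, w2, b2):
    # two-layer transformation with a skip connection: y = x + f(x)
    h = relu(x @ w1 + b1)
    return x + (h @ w2 + b2)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
w1, b1 = rng.normal(size=(16, 16)) * 0.01, np.zeros(16)
w2, b2 = rng.normal(size=(16, 16)) * 0.01, np.zeros(16)
y = residual_block(x, w1, b1, w2, b2)
print(y.shape)  # (8, 16)
```

With small weights the block stays close to the identity map, which is what makes stacks of such blocks easy to train.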