swyft.networks

class swyft.networks.BatchNorm1dWithChannel(num_channels, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]

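Applies batch normalization over the batch dimension N only. Input must have shape (N, C, L).

Otherwise behaves like torch.nn.BatchNorm1d, with an additional num_channels argument. It does not support the temporal batch norm case, in which statistics are also pooled over L.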

Parameters:
  • num_channels (int) –

  • num_features (int) –

  • eps (float) –

  • momentum (float) –

  • affine (bool) –

  • track_running_stats (bool) –

forward(input)[source]

Defines the computation performed at every call.

Note

Call the Module instance rather than forward() directly: the instance call runs the registered hooks, while a direct forward() call silently ignores them.
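The difference between calling the instance and calling forward() can be seen with a forward hook. This is a minimal sketch that uses torch.nn.Linear as a stand-in rather than a swyft class; the hook function record_output_shape is a name made up for this example.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)  # stand-in module, not part of swyft
calls = []


def record_output_shape(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(tuple(output.shape))


handle = layer.register_forward_hook(record_output_shape)
x = torch.zeros(3, 2)

layer(x)          # goes through __call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped
handle.remove()

print(len(calls))  # only the first call was recorded
```

The same applies to every forward method documented on this page.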

Parameters:

input (Tensor) –

Return type:

Tensor
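As an illustration of normalizing over N only for input of shape (N, C, L), the sketch below flattens the channel and feature axes and applies torch.nn.BatchNorm1d to the result, so that each (channel, feature) position gets its own statistics. ChannelBatchNormSketch is a hypothetical stand-in written for this example. It reflects an assumption about the behaviour described above and is not taken from the swyft source.

```python
import torch
from torch import nn


class ChannelBatchNormSketch(nn.Module):
    """Hypothetical stand-in: batch norm over N for input of shape (N, C, L)."""

    def __init__(self, num_channels: int, num_features: int) -> None:
        super().__init__()
        # One mean/variance pair per (channel, feature) position.
        self.norm = nn.BatchNorm1d(num_channels * num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, channels, length = x.shape
        flat = x.reshape(n, channels * length)  # (N, C * L)
        return self.norm(flat).reshape(n, channels, length)


torch.manual_seed(0)
layer = ChannelBatchNormSketch(num_channels=3, num_features=4)
x = 5.0 + 2.0 * torch.randn(64, 3, 4)
y = layer(x)  # training mode: normalizes with the statistics of this batch
```

By contrast, torch.nn.BatchNorm1d(C) applied directly to (N, C, L) input pools its statistics over both N and L, which is the temporal case this class does not cover.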

class swyft.networks.LinearWithChannel(channels, in_features, out_features)[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • channels (int) –

  • in_features (int) –

  • out_features (int) –

forward(x)[source]

Parameters:

x (Tensor) –

Return type:

Tensor
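The signature suggests a linear layer with a separate weight matrix for each channel. The sketch below shows one way such a layer could be written, assuming input of shape (N, channels, in_features) and output of shape (N, channels, out_features). PerChannelLinearSketch, its parameter shapes, and its initialization are assumptions made for this example and are not taken from the swyft source.

```python
import torch
from torch import nn


class PerChannelLinearSketch(nn.Module):
    """Hypothetical stand-in: an independent linear map for every channel."""

    def __init__(self, channels: int, in_features: int, out_features: int) -> None:
        super().__init__()
        scale = in_features ** -0.5  # simple illustrative initialization
        self.weight = nn.Parameter(
            scale * torch.randn(channels, out_features, in_features)
        )
        self.bias = nn.Parameter(torch.zeros(channels, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, in_features) -> (N, C, out_features); no mixing across channels.
        return torch.einsum("coi,nci->nco", self.weight, x) + self.bias


torch.manual_seed(0)
layer = PerChannelLinearSketch(channels=3, in_features=4, out_features=2)
x = torch.randn(5, 3, 4)
y = layer(x)

# Reference: apply an ordinary linear map to each channel separately.
reference = torch.stack(
    [
        nn.functional.linear(x[:, c], layer.weight[c], layer.bias[c])
        for c in range(3)
    ],
    dim=1,
)
```

The einsum form computes all channels in one call, which avoids the Python loop used for the reference.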

class swyft.networks.MarginalClassifier(n_marginals, n_combined_features, hidden_features, num_blocks, dropout_probability=0.0, use_batch_norm=True)[source]

Parameters:
  • n_marginals (int) –

  • n_combined_features (int) –

  • hidden_features (int) –

  • num_blocks (int) –

  • dropout_probability (float) –

  • use_batch_norm (bool) –

forward(features, marginal_block)[source]

Parameters:
  • features (Tensor) –

  • marginal_block (Tensor) –

Return type:

Tensor

class swyft.networks.Network(observation_transform, parameter_transform, marginal_classifier)[source]

Parameters:
  • observation_transform (Module) –

  • parameter_transform (Module) –

  • marginal_classifier (Module) –

forward(observation, parameters)[source]

Parameters:
  • observation (Dict[Hashable, Tensor]) –

  • parameters (Tensor) –

Return type:

Tensor

head(observation)[source]

Convert the observation into a tensor of features.

Parameters:

observation (Dict[Hashable, Tensor]) – the observation to convert

Returns:

a tensor of features that can be passed to tail

Return type:

Tensor

tail(features, parameters)[source]

Finish the forward pass using the features computed by head.

Parameters:
  • features (Tensor) – output of head

  • parameters (Tensor) – the parameters normally passed to forward

Returns:

the same output as forward(observation, parameters)

Return type:

Tensor
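Splitting the forward pass into head and tail makes it possible to compute the observation features once and reuse them for many parameter values. The sketch below illustrates the pattern with a hypothetical toy module. TwoStageSketch, its layers, and the observation key "x" are assumptions for this example and do not reproduce swyft.networks.Network.

```python
from typing import Dict, Hashable

import torch
from torch import nn


class TwoStageSketch(nn.Module):
    """Hypothetical module whose forward pass is tail(head(observation), parameters)."""

    def __init__(self, n_observation_features: int, n_parameters: int) -> None:
        super().__init__()
        self.embed = nn.Linear(n_observation_features, 8)
        self.classify = nn.Linear(8 + n_parameters, 1)

    def head(self, observation: Dict[Hashable, torch.Tensor]) -> torch.Tensor:
        # Observation-only work: this part can be cached.
        return torch.relu(self.embed(observation["x"]))

    def tail(self, features: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        return self.classify(torch.cat([features, parameters], dim=-1))

    def forward(
        self, observation: Dict[Hashable, torch.Tensor], parameters: torch.Tensor
    ) -> torch.Tensor:
        return self.tail(self.head(observation), parameters)


torch.manual_seed(0)
net = TwoStageSketch(n_observation_features=5, n_parameters=2).eval()
observation = {"x": torch.randn(1, 5)}
candidates = torch.randn(100, 2)  # many parameter values, one observation

with torch.no_grad():
    features = net.head(observation)  # computed once
    cached = net.tail(features.expand(100, -1), candidates)
    direct = net({"x": observation["x"].expand(100, -1)}, candidates)
```

Both paths give the same result, but the cached path evaluates the observation embedding a single time.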

class swyft.networks.ObservationTransform(observation_key, observation_shapes, online_z_score)[source]

Parameters:
  • observation_key (Hashable) –

  • observation_shapes (Mapping[Hashable, Union[Size, Tuple[int, ...]]]) –

  • online_z_score (bool) –

forward(observation)[source]

Parameters:

observation (Dict[Hashable, Tensor]) –

Return type:

Tensor

class swyft.networks.OnlineDictStandardizingLayer(shapes, stable=False, epsilon=1e-10, use_average_std=False)[source]

Parameters:
  • shapes (Dict[Hashable, Tuple[int, ...]]) –

  • stable (bool) –

  • epsilon (float) –

  • use_average_std (bool) –

forward(x)[source]

Parameters:

x (Dict[Hashable, Tensor]) –

Return type:

Tensor

class swyft.networks.OnlineStandardizingLayer(shape, stable=False, epsilon=1e-10, use_average_std=False)[source]

Accumulates the mean and variance online using the “parallel algorithm” from [1].

Parameters:
  • shape (Tuple[int, ...]) – shape of the mean, variance, and std arrays. Do not include the batch dimension.

  • stable (bool) – (optional) if True, compute using the stable version of the algorithm [1].

  • epsilon (float) – (optional) added to the computation of the standard deviation for numerical stability.

  • use_average_std (bool) – (optional) if True, normalize using the std averaged over the whole observation; if False, normalize using the std of each component of the observation.

References

[1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
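The parallel algorithm in [1] merges the count, mean, and sum of squared deviations of two batches without revisiting earlier data. The sketch below applies it to scalar values in plain Python. The function names batch_stats and merge are made up for this example, and the stable and epsilon options of the class are deliberately left out.

```python
from statistics import fmean, pvariance
from typing import Sequence, Tuple

# (count, mean, M2), where M2 is the sum of squared deviations from the mean.
Stats = Tuple[int, float, float]


def batch_stats(batch: Sequence[float]) -> Stats:
    mean = fmean(batch)
    m2 = sum((value - mean) ** 2 for value in batch)
    return len(batch), mean, m2


def merge(a: Stats, b: Stats) -> Stats:
    """Combine the statistics of two disjoint batches."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
    return n, mean, m2


batches = [[1.0, 2.0, 3.0], [10.0, 12.0], [4.0, 4.0, 5.0, 7.0]]
running: Stats = (0, 0.0, 0.0)
for batch in batches:
    running = merge(running, batch_stats(batch))

count, mean, m2 = running
variance = m2 / count  # population variance of everything seen so far

# Check against a single pass over all the data.
everything = [value for batch in batches for value in batch]
reference_mean = fmean(everything)
reference_variance = pvariance(everything)
```

Only three numbers per component need to be stored between batches, which is what makes the accumulation online.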

forward(x)[source]

Parameters:

x (Tensor) –

Return type:

Tensor

class swyft.networks.ParameterTransform(n_parameters, marginal_indices, online_z_score)[source]

Parameters:
  • n_parameters (int) –

  • marginal_indices (Union[int, Sequence[int], Sequence[Sequence[int]]]) –

  • online_z_score (bool) –

forward(parameters)[source]

Parameters:

parameters (Tensor) –

Return type:

Tensor

class swyft.networks.ResidualNetWithChannel(channels, in_features, out_features, hidden_features, num_blocks=2, activation=<function relu>, dropout_probability=0.0, use_batch_norm=False)[source]

A general-purpose residual network. Works only with channelized 1-dimensional inputs.

Parameters:
  • channels (int) –

  • in_features (int) –

  • out_features (int) –

  • hidden_features (int) –

  • num_blocks (int) –

  • activation (Callable) –

  • dropout_probability (float) –

  • use_batch_norm (bool) –

forward(inputs)[source]

Parameters:

inputs (Tensor) –

Return type:

Tensor