swyft.networks

class swyft.networks.BatchNorm1dWithChannel(*args, **kwargs)

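BatchNorm1d over the batch dimension N only. Requires input of shape (N, C, L), where C = num_channels and L = num_features.

Otherwise the same as torch.nn.BatchNorm1d, with the extra argument num_channels. It cannot do the temporal batch norm case, in which torch.nn.BatchNorm1d computes statistics over both N and L for each channel.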

Parameters:
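  • num_channels (int) – number of channels C.

  • num_features (int) – number of features L per channel.

  • eps (float) – value added to the denominator for numerical stability, as in torch.nn.BatchNorm1d.

  • momentum (float) – value used for the running_mean and running_var computation, as in torch.nn.BatchNorm1d.

  • affine (bool) – if True, the layer has learnable affine parameters, as in torch.nn.BatchNorm1d.

  • track_running_stats (bool) – if True, the layer tracks the running mean and variance, as in torch.nn.BatchNorm1d.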
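To make the difference from the temporal case concrete, here is a minimal sketch in plain PyTorch. It assumes that C corresponds to num_channels and L to num_features, and that statistics are kept separately for every (channel, feature) pair. `ChannelBatchNormSketch` is a hypothetical name, not part of swyft, and the defaults shown are torch's, not necessarily swyft's. Flattening (C, L) into a single axis lets the stock torch.nn.BatchNorm1d compute statistics over N only.

```python
import torch
import torch.nn as nn


class ChannelBatchNormSketch(nn.Module):
    """Hypothetical stand-in (not swyft's class): batch norm over N for (N, C, L) inputs.

    Every (channel, feature) pair gets its own batch statistics, running
    statistics and, if affine=True, its own scale and shift.
    """

    def __init__(self, num_channels: int, num_features: int, eps: float = 1e-5,
                 momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.num_features = num_features
        # One BatchNorm1d over C * L flattened features: statistics are taken over N only.
        self.bn = nn.BatchNorm1d(num_channels * num_features, eps=eps, momentum=momentum,
                                 affine=affine, track_running_stats=track_running_stats)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, l = x.shape
        if (c, l) != (self.num_channels, self.num_features):
            raise ValueError(
                f"expected shape (N, {self.num_channels}, {self.num_features}), got {tuple(x.shape)}"
            )
        return self.bn(x.reshape(n, c * l)).reshape(n, c, l)


torch.manual_seed(0)
bn = ChannelBatchNormSketch(num_channels=3, num_features=4)
x = torch.randn(16, 3, 4) * 5 + 2
with torch.no_grad():
    y = bn(x)  # training mode, so batch statistics are used
print(y.shape)  # torch.Size([16, 3, 4])
```

By contrast, passing the same (16, 3, 4) tensor directly to torch.nn.BatchNorm1d(3) would pool the 16 × 4 values of each channel into one set of statistics.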

class swyft.networks.LinearWithChannel(*args, **kwargs)
Parameters:
  • channels (int)

  • in_features (int)

  • out_features (int)
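LinearWithChannel has no description here. Going by its name and parameters, a plausible reading is an independent linear map per channel, taking inputs of shape (N, channels, in_features) to (N, channels, out_features). The sketch below implements that reading with torch.einsum; the class name `ChannelLinearSketch`, the weight layout, and the initialization are assumptions, not swyft's code.

```python
import math

import torch
import torch.nn as nn


class ChannelLinearSketch(nn.Module):
    """Hypothetical sketch (not swyft's code): one linear map per channel.

    Input shape (N, channels, in_features) -> output shape (N, channels, out_features).
    """

    def __init__(self, channels: int, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(channels, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(channels, out_features))
        # Uniform(-1/sqrt(in_features), 1/sqrt(in_features)): the range nn.Linear's default init yields.
        bound = 1.0 / math.sqrt(in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y[n, c, :] = x[n, c, :] @ weight[c] + bias[c]; channels never mix.
        return torch.einsum("nci,cio->nco", x, self.weight) + self.bias


torch.manual_seed(0)
layer = ChannelLinearSketch(channels=3, in_features=5, out_features=2)
x = torch.randn(8, 3, 5)
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([8, 3, 2])
```

A single einsum over a (channels, in_features, out_features) weight tensor computes the same result as `channels` separate nn.Linear layers, but in one batched call.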

class swyft.networks.MarginalClassifier(*args, **kwargs)
Parameters:
  • n_marginals (int)

  • n_combined_features (int)

  • hidden_features (int)

  • num_blocks (int)

  • dropout_probability (float)

  • use_batch_norm (bool)

  • Lmax (int)

class swyft.networks.Network(*args, **kwargs)
Parameters:
  • observation_transform (torch.nn.Module)

  • parameter_transform (torch.nn.Module)

  • marginal_classifier (torch.nn.Module)

head(observation)

Convert the observation into a tensor of features.

Parameters:

observation (Dict[Hashable, torch.Tensor]) – the observation, a dictionary mapping keys to tensors

Returns:

A tensor of features that can be passed to tail.

Return type:

torch.Tensor

tail(features, parameters)

Finish the forward pass using features computed by head.

Parameters:
  • features (torch.Tensor) – output of head

  • parameters (torch.Tensor) – the parameters that would normally be passed to forward

Returns:

The same output as forward(observation, parameters), where features = head(observation).

Return type:

torch.Tensor
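One likely use of the head/tail split is to compute the observation features once and reuse them for many parameter values. The toy module below illustrates that pattern under stated assumptions: `ToyHeadTailNetwork`, the observation key "x", and the concatenate-then-linear tail are all hypothetical and unrelated to how swyft's Network builds its layers.

```python
from typing import Dict

import torch
import torch.nn as nn


class ToyHeadTailNetwork(nn.Module):
    """Hypothetical toy showing the head/tail split; not swyft's Network."""

    def __init__(self, obs_dim: int, n_params: int, n_features: int = 8) -> None:
        super().__init__()
        self.embed = nn.Linear(obs_dim, n_features)           # observation -> features
        self.classify = nn.Linear(n_features + n_params, 1)  # (features, parameters) -> logit

    def head(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        # "x" is a hypothetical observation key.
        return self.embed(observation["x"])

    def tail(self, features: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        return self.classify(torch.cat([features, parameters], dim=-1))

    def forward(self, observation: Dict[str, torch.Tensor], parameters: torch.Tensor) -> torch.Tensor:
        return self.tail(self.head(observation), parameters)


torch.manual_seed(0)
net = ToyHeadTailNetwork(obs_dim=10, n_params=2)
obs = {"x": torch.randn(1, 10)}
params = torch.rand(1000, 2)

with torch.no_grad():
    features = net.head(obs)                               # computed once
    logits = net.tail(features.expand(1000, -1), params)   # reused for 1000 parameter values
    single = net(obs, params[:1])                          # ordinary forward pass
print(logits.shape)  # torch.Size([1000, 1])
```

The first row of `logits` matches the ordinary forward pass on the first parameter value, which is the property the tail documentation describes.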

class swyft.networks.ObservationTransform(*args, **kwargs)
Parameters:
  • observation_key (Hashable)

  • observation_shapes (Mapping[Hashable, torch.Size | Tuple[int, ...]])

  • online_z_score (bool)

class swyft.networks.OnlineDictStandardizingLayer(*args, **kwargs)
Parameters:
  • shapes (Dict[Hashable, Tuple[int, ...]])

  • stable (bool)

  • epsilon (float)

  • use_average_std (bool)

class swyft.networks.OnlineStandardizingLayer(*args, **kwargs)

Accumulates the mean and variance online using the “parallel algorithm” from [1] and uses them to standardize its inputs.

Parameters:
  • shape (Tuple[int, ...]) – shape of the mean, variance, and std arrays. Do not include the batch dimension!

  • stable (bool) – (optional) if True, compute using the stable version of the algorithm [1].

  • epsilon (float) – (optional) small constant added in the computation of the standard deviation for numerical stability.

  • use_average_std (bool) – (optional) if True, normalize using the std averaged over the whole observation; if False, normalize each component of the observation by its own std.

References

[1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
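For reference, the parallel update from [1] merges a batch B into running statistics A via delta = mean_B - mean_A, mean = mean_A + delta * n_B / n, and M2 = M2_A + M2_B + delta^2 * n_A * n_B / n, where n = n_A + n_B and M2 is the sum of squared deviations. The pure-Python sketch below applies it to a single scalar stream. The class name `RunningStats`, the choice of the sample (n - 1) variance, and the way `epsilon` enters `standardize` are assumptions, not taken from swyft.

```python
import math
import statistics
from dataclasses import dataclass
from typing import Sequence


@dataclass
class RunningStats:
    """Hypothetical helper: count, mean and M2 (sum of squared deviations) of a scalar stream."""

    n: int = 0
    mean: float = 0.0
    m2: float = 0.0

    def merge_batch(self, batch: Sequence[float]) -> None:
        n_b = len(batch)
        if n_b == 0:
            return
        mean_b = sum(batch) / n_b
        m2_b = sum((v - mean_b) ** 2 for v in batch)
        n = self.n + n_b
        delta = mean_b - self.mean
        # Parallel update: combine (n_A, mean_A, M2_A) with the batch's (n_B, mean_B, M2_B).
        self.mean += delta * n_b / n
        self.m2 += m2_b + delta ** 2 * self.n * n_b / n
        self.n = n

    @property
    def variance(self) -> float:
        # Sample variance (divides by n - 1); an assumption, see the lead-in.
        return self.m2 / (self.n - 1) if self.n > 1 else 0.0

    def standardize(self, x: float, epsilon: float = 1e-10) -> float:
        # Assumption: epsilon is added to the standard deviation before dividing.
        return (x - self.mean) / (math.sqrt(self.variance) + epsilon)


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
stats = RunningStats()
stats.merge_batch(data[:3])  # first batch
stats.merge_batch(data[3:])  # second batch, merged without revisiting the first
print(stats.mean, stats.variance, statistics.variance(data))
```

The merged statistics agree with a single pass over all the data, which is what lets the layer update them one training batch at a time.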

class swyft.networks.ParameterTransform(*args, **kwargs)
Parameters:
  • n_parameters (int)

  • marginal_indices (int | Sequence[int] | Sequence[Sequence[int]])

  • online_z_score (bool)

class swyft.networks.ResidualNetWithChannel(*args, **kwargs)

A general-purpose residual network. Works only with channelized 1-dim inputs, i.e. tensors of shape (N, channels, in_features).

Parameters:
  • channels (int)

  • in_features (int)

  • out_features (int)

  • hidden_features (int)

  • num_blocks (int)

  • activation (Callable)

  • dropout_probability (float)

  • use_batch_norm (bool)
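As a rough picture of what a channelized residual network can look like, here is a self-contained sketch: an initial per-channel linear layer, `num_blocks` pre-activation residual blocks (optional batch norm, activation, linear, then again with dropout before the second linear), and a final per-channel linear layer. Every class name here is hypothetical and the block layout and defaults are assumptions; swyft's ResidualNetWithChannel may order or initialize things differently.

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelLinear(nn.Module):
    """Independent linear map per channel: (N, C, in) -> (N, C, out)."""

    def __init__(self, channels: int, in_features: int, out_features: int) -> None:
        super().__init__()
        bound = 1.0 / in_features ** 0.5
        self.weight = nn.Parameter(torch.empty(channels, in_features, out_features).uniform_(-bound, bound))
        self.bias = nn.Parameter(torch.empty(channels, out_features).uniform_(-bound, bound))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.einsum("nci,cio->nco", x, self.weight) + self.bias


class ChannelResidualBlock(nn.Module):
    """Pre-activation residual block that never mixes channels."""

    def __init__(self, channels: int, features: int, activation: Callable[[torch.Tensor], torch.Tensor],
                 dropout_probability: float, use_batch_norm: bool) -> None:
        super().__init__()
        self.activation = activation
        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            # BatchNorm1d on the flattened (C * features) axis: statistics over N only.
            self.norms = nn.ModuleList([nn.BatchNorm1d(channels * features) for _ in range(2)])
        self.linears = nn.ModuleList([ChannelLinear(channels, features, features) for _ in range(2)])
        self.dropout = nn.Dropout(dropout_probability)

    def _norm(self, k: int, x: torch.Tensor) -> torch.Tensor:
        if not self.use_batch_norm:
            return x
        n, c, f = x.shape
        return self.norms[k](x.reshape(n, c * f)).reshape(n, c, f)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linears[0](self.activation(self._norm(0, x)))
        h = self.linears[1](self.dropout(self.activation(self._norm(1, h))))
        return x + h  # skip connection


class ResidualNetWithChannelSketch(nn.Module):
    """Hypothetical sketch: per-channel input layer, residual blocks, per-channel output layer."""

    def __init__(self, channels: int, in_features: int, out_features: int, hidden_features: int,
                 num_blocks: int = 2, activation: Callable[[torch.Tensor], torch.Tensor] = F.relu,
                 dropout_probability: float = 0.0, use_batch_norm: bool = False) -> None:
        super().__init__()
        self.initial = ChannelLinear(channels, in_features, hidden_features)
        self.blocks = nn.ModuleList([
            ChannelResidualBlock(channels, hidden_features, activation, dropout_probability, use_batch_norm)
            for _ in range(num_blocks)
        ])
        self.final = ChannelLinear(channels, hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.initial(x)
        for block in self.blocks:
            h = block(h)
        return self.final(h)


torch.manual_seed(0)
net = ResidualNetWithChannelSketch(channels=4, in_features=6, out_features=1,
                                   hidden_features=32, num_blocks=2, use_batch_norm=True)
x = torch.randn(64, 4, 6)
out = net(x)  # training mode: batch norm uses batch statistics
print(out.shape)  # torch.Size([64, 4, 1])

# Channels never mix: perturbing channel 0 leaves the other channels' outputs unchanged.
net.eval()
x_perturbed = x.clone()
x_perturbed[:, 0] += 1.0
with torch.no_grad():
    unchanged = torch.allclose(net(x)[:, 1:], net(x_perturbed)[:, 1:])
print(unchanged)  # True
```

Keeping every layer channel-wise means each channel behaves like its own small residual network, which is the property the final check demonstrates.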