swyft.networks
- class swyft.networks.BatchNorm1dWithChannel(*args, **kwargs)
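Batch normalization over the batch dimension N only. Requires input of shape (N, C, L).
Otherwise the same as torch.nn.BatchNorm1d, with an extra num_channels argument. Unlike torch.nn.BatchNorm1d, it cannot handle the temporal case, in which statistics are also pooled over L.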
- Parameters:
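num_channels (int) – number of channels, C.
num_features (int) – number of features per channel, L.
eps (float) – as in torch.nn.BatchNorm1d.
momentum (float) – as in torch.nn.BatchNorm1d.
affine (bool) – as in torch.nn.BatchNorm1d.
track_running_stats (bool) – as in torch.nn.BatchNorm1d.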
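The normalization described above can be sketched in NumPy as follows. This is an illustrative, training-mode-only sketch that assumes each (channel, feature) pair is normalized independently over the batch N; it omits running statistics and the affine parameters, and the function name batch_norm_with_channel is hypothetical, not part of swyft.

```python
import numpy as np


def batch_norm_with_channel(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize x of shape (N, C, L) over the batch axis only.

    Hypothetical helper: every (c, l) entry gets its own mean and variance,
    computed across the N samples, as if BatchNorm1d were applied to the
    flattened (N, C * L) tensor. Running stats and affine terms are omitted.
    """
    if x.ndim != 3:
        raise ValueError("expected shape (N, C, L)")
    n, c, l = x.shape
    flat = x.reshape(n, c * l)                  # (N, C*L): one feature per (c, l)
    mean = flat.mean(axis=0)                    # statistics pooled over N only
    var = flat.var(axis=0)                      # biased variance, as in batch norm
    normed = (flat - mean) / np.sqrt(var + eps)
    return normed.reshape(n, c, l)              # restore the channel layout


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 4, 8))
y = batch_norm_with_channel(x)
print(y.shape)  # (64, 4, 8)
```

In this sketch, flattening to (N, C*L) is what rules out the temporal case: each position along L is treated as a separate feature rather than as extra samples for the statistics.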
- class swyft.networks.LinearWithChannel(*args, **kwargs)
- Parameters:
channels (int)
in_features (int)
out_features (int)
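Judging from its name and parameters, LinearWithChannel presumably applies an independent linear map to each channel. The NumPy sketch below illustrates that reading. It assumes inputs of shape (N, channels, in_features), a weight of shape (channels, in_features, out_features), and a bias of shape (channels, out_features). Both the layout and the function name linear_with_channel are assumptions, not swyft's code.

```python
import numpy as np


def linear_with_channel(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Apply a separate affine map to every channel (hypothetical helper).

    x:      (N, C, in_features)
    weight: (C, in_features, out_features)
    bias:   (C, out_features)
    returns (N, C, out_features)
    """
    # For each sample n and channel c: y[n, c] = x[n, c] @ weight[c] + bias[c]
    return np.einsum("nci,cio->nco", x, weight) + bias


rng = np.random.default_rng(1)
channels, in_features, out_features = 3, 5, 2
x = rng.normal(size=(10, channels, in_features))
w = rng.normal(size=(channels, in_features, out_features))
b = rng.normal(size=(channels, out_features))

y = linear_with_channel(x, w, b)
print(y.shape)  # (10, 3, 2)
```

A single einsum stands in for C separate linear layers, so all channels are processed in one batched operation while still sharing no weights.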
- class swyft.networks.MarginalClassifier(*args, **kwargs)
- Parameters:
n_marginals (int)
n_combined_features (int)
hidden_features (int)
num_blocks (int)
dropout_probability (float)
use_batch_norm (bool)
Lmax (int)
- class swyft.networks.Network(*args, **kwargs)
- Parameters:
observation_transform (torch.nn.Module)
parameter_transform (torch.nn.Module)
marginal_classifier (torch.nn.Module)
- head(observation)
Convert the observation into a tensor of features.
- Parameters:
observation (Dict[Hashable, torch.Tensor]) – the observation, a dictionary mapping keys to tensors.
- Returns:
A tensor of features that can be passed to tail.
- Return type:
torch.Tensor
- tail(features, parameters)
Finish the forward pass using the features computed by head.
- Parameters:
features (torch.Tensor) – the output of head.
parameters (torch.Tensor) – the parameters normally given to the forward pass.
- Returns:
The same output as forward(observation, parameters).
- Return type:
torch.Tensor
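The head/tail split presumably exists so that the observation-dependent part of the forward pass can be computed once and reused across many parameter values. The plain-Python sketch below illustrates that idea under the assumption that head applies observation_transform and tail applies parameter_transform followed by marginal_classifier, so that forward(observation, parameters) equals tail(head(observation), parameters). ToyNetwork and the lambda transforms are hypothetical stand-ins for the torch.nn.Module arguments, not swyft code.

```python
from typing import Callable, Dict, Hashable, List


class ToyNetwork:
    """Hypothetical stand-in illustrating the head/tail decomposition."""

    def __init__(
        self,
        observation_transform: Callable[[Dict[Hashable, List[float]]], List[float]],
        parameter_transform: Callable[[List[float]], List[float]],
        marginal_classifier: Callable[[List[float], List[float]], float],
    ) -> None:
        self.observation_transform = observation_transform
        self.parameter_transform = parameter_transform
        self.marginal_classifier = marginal_classifier

    def head(self, observation: Dict[Hashable, List[float]]) -> List[float]:
        # Observation -> features; does not depend on the parameters.
        return self.observation_transform(observation)

    def tail(self, features: List[float], parameters: List[float]) -> float:
        # Features + transformed parameters -> classifier output.
        return self.marginal_classifier(features, self.parameter_transform(parameters))

    def forward(self, observation: Dict[Hashable, List[float]], parameters: List[float]) -> float:
        return self.tail(self.head(observation), parameters)


net = ToyNetwork(
    observation_transform=lambda obs: [sum(obs["x"])],
    parameter_transform=lambda p: [2.0 * v for v in p],
    marginal_classifier=lambda f, p: f[0] * sum(p),
)

obs = {"x": [1.0, 2.0, 3.0]}
features = net.head(obs)  # computed once
outputs = [net.tail(features, [t]) for t in (0.5, 1.0, 1.5)]
print(outputs)  # [6.0, 12.0, 18.0]
```

Only the three tail calls depend on the parameter values, so in this sketch the cost of head is paid a single time.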
- class swyft.networks.ObservationTransform(*args, **kwargs)
- Parameters:
observation_key (Hashable)
observation_shapes (Mapping[Hashable, torch.Size | Tuple[int, ...]])
online_z_score (bool)
- class swyft.networks.OnlineDictStandardizingLayer(*args, **kwargs)
- Parameters:
shapes (Dict[Hashable, Tuple[int, ...]])
stable (bool)
epsilon (float)
use_average_std (bool)
- class swyft.networks.OnlineStandardizingLayer(*args, **kwargs)
Accumulate the mean and variance online using the “parallel algorithm” from [1].
- Parameters:
shape (Tuple[int, ...]) – shape of the mean, variance, and std arrays. Do not include the batch dimension.
stable (bool) – (optional) compute using the stable version of the algorithm [1].
epsilon (float) – (optional) added in the computation of the standard deviation for numerical stability.
use_average_std (bool) – (optional) True to normalize using the std averaged over the whole observation; False to normalize using the std of each component of the observation.
References
[1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
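The parallel algorithm in [1] merges the count, mean, and sum of squared deviations (M2) of the statistics accumulated so far with those of each new batch. The NumPy sketch below implements that merge for arrays of a fixed shape (batch dimension excluded). The class name RunningStats is hypothetical, and two details are assumptions rather than swyft's documented behavior: that stable selects the weighted-average mean update that [1] recommends when both counts are large and similar, and that epsilon is added after the square root.

```python
import numpy as np


class RunningStats:
    """Online mean/variance via the parallel merge rule from [1].

    Hypothetical helper; `shape` excludes the batch dimension.
    """

    def __init__(self, shape, stable=False, epsilon=1e-10, use_average_std=False):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # sum of squared deviations from the mean
        self.stable = stable
        self.epsilon = epsilon
        self.use_average_std = use_average_std

    def update(self, batch: np.ndarray) -> None:
        n_b = batch.shape[0]
        mean_b = batch.mean(axis=0)
        m2_b = ((batch - mean_b) ** 2).sum(axis=0)

        n = self.n + n_b
        delta = mean_b - self.mean
        if self.stable:
            # Weighted-average form, preferred when both counts are large and similar.
            self.mean = (self.n * self.mean + n_b * mean_b) / n
        else:
            self.mean = self.mean + delta * n_b / n
        self.m2 = self.m2 + m2_b + delta ** 2 * self.n * n_b / n
        self.n = n

    @property
    def var(self) -> np.ndarray:
        # Population variance of all samples seen so far.
        return self.m2 / self.n

    @property
    def std(self) -> np.ndarray:
        std = np.sqrt(self.var) + self.epsilon  # assumption: epsilon added after sqrt
        if self.use_average_std:
            # One std for the whole observation, broadcast back to `shape`.
            std = np.full_like(std, std.mean())
        return std

    def standardize(self, x: np.ndarray) -> np.ndarray:
        return (x - self.mean) / self.std


rng = np.random.default_rng(2)
data = rng.normal(loc=1.0, scale=3.0, size=(300, 4))
stats = RunningStats(shape=(4,))
for batch in np.split(data, [50, 170]):  # three batches of unequal size
    stats.update(batch)
print(np.allclose(stats.var, data.var(axis=0)))  # True
```

Because only the count, mean, and M2 are stored, batches of any size can be folded in without keeping past data, and the result matches the statistics of the concatenated data.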
- class swyft.networks.ParameterTransform(*args, **kwargs)
- Parameters:
n_parameters (int)
marginal_indices (int | Sequence[int] | Sequence[Sequence[int]])
online_z_score (bool)
- class swyft.networks.ResidualNetWithChannel(*args, **kwargs)
A general-purpose residual network. Works only with channelized 1-dim inputs.
- Parameters:
channels (int)
in_features (int)
out_features (int)
hidden_features (int)
num_blocks (int)
activation (Callable)
dropout_probability (float)
use_batch_norm (bool)
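A residual network of this kind commonly maps the input to hidden_features, applies num_blocks blocks of the form h + f(h), and projects to out_features. The NumPy sketch below shows that generic pattern for channelized inputs of assumed shape (N, channels, in_features), using a per-channel linear map. It is a forward-pass-only illustration with random weights and no dropout or batch norm; ResidualNetSketch and its helpers are hypothetical and may differ from swyft's actual block structure.

```python
import numpy as np


def init_linear(rng, channels, n_in, n_out):
    # Per-channel weight (C, n_in, n_out) and bias (C, n_out); hypothetical layout.
    w = rng.normal(scale=1.0 / np.sqrt(n_in), size=(channels, n_in, n_out))
    b = np.zeros((channels, n_out))
    return w, b


def linear(x, params):
    w, b = params
    return np.einsum("nci,cio->nco", x, w) + b


def relu(x):
    return np.maximum(x, 0.0)


class ResidualNetSketch:
    """Forward-only residual MLP over channelized inputs (hypothetical)."""

    def __init__(self, channels, in_features, out_features, hidden_features,
                 num_blocks, activation=relu, seed=0):
        rng = np.random.default_rng(seed)
        self.activation = activation
        self.initial = init_linear(rng, channels, in_features, hidden_features)
        self.blocks = [
            (init_linear(rng, channels, hidden_features, hidden_features),
             init_linear(rng, channels, hidden_features, hidden_features))
            for _ in range(num_blocks)
        ]
        self.final = init_linear(rng, channels, hidden_features, out_features)

    def __call__(self, x):
        h = linear(x, self.initial)
        for first, second in self.blocks:
            # Residual block h + f(h), with f = activation -> linear -> activation -> linear.
            r = linear(self.activation(h), first)
            r = linear(self.activation(r), second)
            h = h + r
        return linear(h, self.final)


net = ResidualNetSketch(channels=3, in_features=5, out_features=1,
                        hidden_features=16, num_blocks=2)
x = np.random.default_rng(3).normal(size=(8, 3, 5))
print(net(x).shape)  # (8, 3, 1)
```

Since every linear map acts per channel and the activation is elementwise, channels never mix in this sketch; each channel is effectively its own small residual network.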