Source code for swyft.networks.classifier

from abc import ABC, abstractmethod
from typing import Dict, Hashable, Tuple

import torch
import torch.nn as nn

import swyft
import swyft.utils
from swyft.networks.channelized import ResidualNetWithChannel
from swyft.networks.standardization import (
    OnlineDictStandardizingLayer,
    OnlineStandardizingLayer,
)
from swyft.types import Array, MarginalIndex, ObsShapeType


class HeadTailClassifier(ABC):
    """Abstract class which ensures that child classifier networks will function with swyft"""

    @abstractmethod
    def head(self, observation: Dict[Hashable, torch.Tensor]) -> torch.Tensor:
        """convert the observation into a tensor of features

        Args:
            observation: dictionary mapping observation keys to tensors

        Returns:
            a tensor of features which can be utilized by tail
        """
        pass

    @abstractmethod
    def tail(self, features: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        """finish the forward pass using features computed by head

        Args:
            features: output of head
            parameters: the parameters normally given to forward pass

        Returns:
            the same output as `forward(observation, parameters)`
        """
        pass


class ObservationTransform(nn.Module):
    def __init__(
        self,
        observation_key: Hashable,
        observation_shapes: ObsShapeType,
        online_z_score: bool,
    ) -> None:
        super().__init__()
        self.observation_key = observation_key
        self.observation_shapes = observation_shapes
        self.flatten = nn.Flatten()
        if online_z_score:
            self.online_z_score = OnlineDictStandardizingLayer(self.observation_shapes)
        else:
            self.online_z_score = nn.Identity()

    def forward(self, observation: Dict[Hashable, torch.Tensor]) -> torch.Tensor:
        z_scored_observation = self.online_z_score(observation)
        return self.flatten(z_scored_observation[self.observation_key])  # B, O

    @property
    def n_features(self) -> int:
        with torch.no_grad():
            fabricated_observation = {
                key: torch.rand(2, *shape)
                for key, shape in self.observation_shapes.items()
            }
            _, n_features = self.forward(fabricated_observation).shape
        return n_features
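
A minimal usage sketch of ObservationTransform; the key name, shapes, and batch size below are illustrative choices, not fixed by the module:

# Illustrative only: "x" and (10,) are placeholder key/shape choices.
obs_transform = ObservationTransform(
    observation_key="x", observation_shapes={"x": (10,)}, online_z_score=True
)
observation = {"x": torch.rand(16, 10)}  # batch of 16 flat observations
features = obs_transform(observation)    # B, O -> torch.Size([16, 10])
assert obs_transform.n_features == 10
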
class ParameterTransform(nn.Module):
    def __init__(
        self, n_parameters: int, marginal_indices: MarginalIndex, online_z_score: bool
    ) -> None:
        super().__init__()
        self.register_buffer(
            "marginal_indices",
            torch.tensor(swyft.utils.tupleize_marginal_indices(marginal_indices)),
        )
        self.n_parameters = torch.Size([n_parameters])
        if online_z_score:
            self.online_z_score = OnlineStandardizingLayer(self.n_parameters)
        else:
            self.online_z_score = nn.Identity()

    def forward(self, parameters: torch.Tensor) -> torch.Tensor:
        parameters = self.online_z_score(parameters)
        return self.get_marginal_block(parameters, self.marginal_indices)  # B, M, P

    @property
    def marginal_block_shape(self) -> Tuple[int, int]:
        return self.get_marginal_block_shape(self.marginal_indices)

    @staticmethod
    def is_marginal_block_possible(marginal_indices: MarginalIndex) -> bool:
        marginal_indices = swyft.utils.tupleize_marginal_indices(marginal_indices)
        return all(len(marginal_indices[0]) == len(mi) for mi in marginal_indices)

    @classmethod
    def get_marginal_block_shape(
        cls, marginal_indices: MarginalIndex
    ) -> Tuple[int, int]:
        marginal_indices = swyft.utils.tupleize_marginal_indices(marginal_indices)
        assert cls.is_marginal_block_possible(
            marginal_indices
        ), f"Each tuple in {marginal_indices} must have the same length."
        return len(marginal_indices), len(marginal_indices[0])

    @classmethod
    def get_marginal_block(
        cls, parameters: Array, marginal_indices: MarginalIndex
    ) -> torch.Tensor:
        depth = swyft.utils.depth(marginal_indices)
        tuple_marginal_indices = swyft.utils.tupleize_marginal_indices(marginal_indices)
        assert cls.is_marginal_block_possible(
            tuple_marginal_indices
        ), f"Each tuple in {tuple_marginal_indices} must have the same length."
        if depth in [0, 1, 2]:
            return torch.stack(
                [parameters[..., mi] for mi in tuple_marginal_indices], dim=1
            )
        else:
            raise ValueError(
                f"{marginal_indices} must be of the form (a) 2, (b) [2, 3], (c) [2, [1, 3]], or (d) [[0, 1], [1, 2]]."
            )
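
For intuition, a small sketch of what get_marginal_block produces; the parameter count and marginal indices are illustrative:

# Illustrative only: two 2D marginals over a 3-parameter problem.
params = torch.rand(5, 3)  # B, n_parameters
block = ParameterTransform.get_marginal_block(params, [[0, 1], [1, 2]])
print(block.shape)         # torch.Size([5, 2, 2]) -> B, M, P
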
def spectral_embedding(z, Lmax=8):
    """Append sin/cos features at frequencies 2**0 ... 2**(Lmax - 1) to the last dimension.

    The embedding multiplies the last dimension size by (1 + 2 * Lmax).
    """
    device = z.device
    DB = z.shape[-1]
    f = 2 ** torch.arange(Lmax, device=device)
    ZF = z.repeat_interleave(Lmax, dim=-1) * f.repeat(DB)
    return torch.cat([z, torch.sin(ZF), torch.cos(ZF)], dim=-1)
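
A quick shape check of spectral_embedding under arbitrary example sizes:

# Illustrative only: the last dimension grows from D to D * (1 + 2 * Lmax).
z = torch.rand(4, 2, 6)              # e.g. B, M, combined features
emb = spectral_embedding(z, Lmax=8)
print(emb.shape)                     # torch.Size([4, 2, 102]) = 6 * (1 + 2 * 8)
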
class MarginalClassifier(nn.Module):
    def __init__(
        self,
        n_marginals: int,
        n_combined_features: int,
        hidden_features: int,
        num_blocks: int,
        dropout_probability: float = 0.0,
        use_batch_norm: bool = True,
        Lmax: int = 0,
    ) -> None:
        super().__init__()
        self.n_marginals = n_marginals
        self.n_combined_features = n_combined_features
        self.net = ResidualNetWithChannel(
            channels=self.n_marginals,
            in_features=self.n_combined_features * (1 + 2 * Lmax),
            out_features=1,
            hidden_features=hidden_features,
            num_blocks=num_blocks,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        self.Lmax = Lmax

    def forward(
        self, features: torch.Tensor, marginal_block: torch.Tensor
    ) -> torch.Tensor:
        if len(features.shape) == 2:  # Input shape is B, O
            fb = features.unsqueeze(1).expand(-1, self.n_marginals, -1)  # B, M, O
        else:
            fb = features  # Input shape is already B, M, O
        combined = torch.cat([fb, marginal_block], dim=2)  # B, M, O + P
        if self.Lmax > 0:
            combined = spectral_embedding(combined, Lmax=self.Lmax)
        return self.net(combined).squeeze(-1)  # B, M
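
A rough sketch of the forward shapes of MarginalClassifier; all sizes are placeholders chosen for illustration:

# Illustrative only: 2 marginals, 10 observation features, 2 parameters per marginal.
clf = MarginalClassifier(
    n_marginals=2, n_combined_features=12, hidden_features=32, num_blocks=2
)
features = torch.rand(5, 10)            # B, O
marginal_block = torch.rand(5, 2, 2)    # B, M, P
logits = clf(features, marginal_block)  # B, M -> one logit per marginal
print(logits.shape)                     # torch.Size([5, 2])
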
class Network(nn.Module, HeadTailClassifier):
    def __init__(
        self,
        observation_transform: nn.Module,
        parameter_transform: nn.Module,
        marginal_classifier: nn.Module,
    ) -> None:
        super().__init__()
        self.observation_transform = observation_transform
        self.parameter_transform = parameter_transform
        self.marginal_classifier = marginal_classifier

    def forward(
        self, observation: Dict[Hashable, torch.Tensor], parameters: torch.Tensor
    ) -> torch.Tensor:
        features = self.observation_transform(observation)  # B, O
        marginal_block = self.parameter_transform(parameters)  # B, M, P
        return self.marginal_classifier(features, marginal_block)  # B, M

    def head(self, observation: Dict[Hashable, torch.Tensor]) -> torch.Tensor:
        return self.observation_transform(observation)  # B, O

    def tail(self, features: torch.Tensor, parameters: torch.Tensor) -> torch.Tensor:
        marginal_block = self.parameter_transform(parameters)  # B, M, P
        return self.marginal_classifier(features, marginal_block)  # B, M
def get_marginal_classifier(
    observation_key: Hashable,
    marginal_indices: MarginalIndex,
    observation_shapes: ObsShapeType,
    n_parameters: int,
    hidden_features: int,
    num_blocks: int,
    observation_online_z_score: bool = True,
    parameter_online_z_score: bool = True,
) -> nn.Module:
    observation_transform = ObservationTransform(
        observation_key, observation_shapes, online_z_score=observation_online_z_score
    )
    n_observation_features = observation_transform.n_features

    parameter_transform = ParameterTransform(
        n_parameters, marginal_indices, online_z_score=parameter_online_z_score
    )
    n_marginals, n_block_parameters = parameter_transform.marginal_block_shape

    marginal_classifier = MarginalClassifier(
        n_marginals,
        n_observation_features + n_block_parameters,
        hidden_features=hidden_features,
        num_blocks=num_blocks,
    )

    return Network(
        observation_transform,
        parameter_transform,
        marginal_classifier,
    )


if __name__ == "__main__":
    pass
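
Putting the pieces together, a hedged end-to-end sketch; the observation key, shapes, marginal indices, and network sizes are placeholders:

# Illustrative only: all keys, shapes, and sizes below are placeholder choices.
network = get_marginal_classifier(
    observation_key="x",
    marginal_indices=[[0, 1], [1, 2]],
    observation_shapes={"x": (10,)},
    n_parameters=3,
    hidden_features=32,
    num_blocks=2,
)
observation = {"x": torch.rand(8, 10)}
parameters = torch.rand(8, 3)
logits = network(observation, parameters)                           # B, M
logits_again = network.tail(network.head(observation), parameters)  # head/tail split of the same forward pass
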