swyft.inference¶
The primary marginal objects are contained here.
- class swyft.inference.MarginalRatioEstimator(marginal_indices, network, device)[source]¶
Handles the training and evaluation of a ratio estimator. The ratios to estimate are defined by the marginal_indices attribute. The network must take observation dictionaries and parameter arrays and produce an estimated log_ratio for every marginal of interest.
Define the marginals of interest with marginal_indices and the estimator architecture with network.
- Parameters:
marginal_indices (Union[int, Sequence[int], Sequence[Sequence[int]]]) – marginals of interest defined by the parameter index
network (Module) – a neural network which accepts observation and parameters and returns len(marginal_indices) ratios.
device (Union[device, str]) –
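The marginal_indices argument accepts three shapes: a single int (one 1-dimensional marginal), a sequence of ints (several 1-dimensional marginals), or a sequence of sequences (explicit, possibly multi-dimensional marginals). A plain-Python sketch of how these shapes might be normalized to a canonical form; the helper name is hypothetical, not part of the swyft API:

```python
from typing import Sequence, Union

MarginalIndex = Union[int, Sequence[int], Sequence[Sequence[int]]]

def normalize_marginal_indices(marginal_indices: MarginalIndex) -> tuple:
    """Hypothetical helper: map the accepted shapes of marginal_indices
    onto a tuple of tuples, one inner tuple per marginal."""
    if isinstance(marginal_indices, int):
        # a single 1-dim marginal, e.g. 0 -> ((0,),)
        return ((marginal_indices,),)
    if all(isinstance(i, int) for i in marginal_indices):
        # several 1-dim marginals, e.g. [0, 1] -> ((0,), (1,))
        return tuple((i,) for i in marginal_indices)
    # explicit (possibly multi-dim) marginals, e.g. [[0, 1], [2]] -> ((0, 1), (2,))
    return tuple(tuple(m) for m in marginal_indices)
```

Under this reading, [0, 1] requests two separate 1-dimensional marginals, while [[0, 1]] requests one joint 2-dimensional marginal.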
- classmethod from_state_dict(network, optimizer, scheduler, device, state_dict)[source]¶
Instantiate a MarginalRatioEstimator from a state_dict, along with a few necessary python objects.
- Parameters:
network (Module) – initialized network
optimizer (Optional[Optimizer]) – same optimizer as used by saved model
scheduler (Optional[Union[_LRScheduler, ReduceLROnPlateau]]) – same scheduler as used by saved model
device (Union[device, str]) –
state_dict (dict) –
- Returns:
loaded model
- Return type:
MarginalRatioEstimatorType
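The pattern here is that the caller constructs the mutable objects (network, optimizer, scheduler) fresh, and the classmethod restores their saved state. A toy stdlib-only sketch of that pattern, with a plain dict standing in for a torch module; all names are illustrative, not the swyft implementation:

```python
class Estimator:
    """Toy stand-in for the from_state_dict pattern: the caller builds
    the mutable object (here, a dict in place of a torch network) and the
    classmethod restores its saved state plus any bookkeeping attributes."""

    def __init__(self, network: dict):
        self.network = network
        self.epoch = 0

    def state_dict(self) -> dict:
        return {"network": dict(self.network), "epoch": self.epoch}

    @classmethod
    def from_state_dict(cls, network: dict, state_dict: dict) -> "Estimator":
        # analogous to network.load_state_dict(state_dict["network"])
        network.clear()
        network.update(state_dict["network"])
        obj = cls(network)
        obj.epoch = state_dict["epoch"]
        return obj
```

This is why the real method asks for the same optimizer and scheduler as the saved model: their state dicts only make sense when loaded into objects of the matching type.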
- log_ratio(observation, v, batch_size=None)[source]¶
Evaluate the ratio estimator on a single observation with many parameters. The parameters correspond to v, i.e. the “physical” parameterization. (As opposed to u which is mapped to the hypercube.)
- Parameters:
observation (Dict[Hashable, Union[ndarray, Tensor]]) – a single observation to estimate ratios on (Cannot have a batch dimension!)
v (Union[ndarray, Tensor]) – parameters
batch_size (Optional[int]) – divides the evaluation into batches of this size
- Returns:
the estimated log ratios for each marginal in marginal_indices, keyed by marginal index
- Return type:
MarginalToArray
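When batch_size is given, the parameter array is evaluated in chunks rather than all at once, which bounds memory use for large v. A plain-Python sketch of that chunking; the function name is hypothetical:

```python
def evaluate_in_batches(f, v, batch_size=None):
    """Hypothetical sketch of the batching log_ratio performs: apply f to
    v in chunks of batch_size and concatenate the results. batch_size=None
    evaluates everything in one pass."""
    if batch_size is None:
        return f(v)
    out = []
    for start in range(0, len(v), batch_size):
        out.extend(f(v[start:start + batch_size]))
    return out
```

The result is identical either way; batch_size only trades peak memory against the number of forward passes.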
- train(dataset, batch_size=50, learning_rate=0.0005, validation_percentage=0.1, optimizer=<class 'torch.optim.adam.Adam'>, scheduler=<class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs={'factor': 0.1, 'patience': 5}, early_stopping_patience=25, max_epochs=2147483647, nworkers=0, non_blocking=True, pin_memory=True)[source]¶
Train the ratio estimator on a dataset of observation and parameter pairs.
Note: if the network has already been trained, training resumes where it left off. This effectively ignores optimizer, learning_rate, scheduler, and scheduler_kwargs.
- Parameters:
dataset (Dataset) – torch dataset which returns a tuple of (observation, parameters)
batch_size (int) –
learning_rate (float) –
validation_percentage (float) – approximate fraction of dataset held out for the validation set
optimizer (Callable) – an optimizer from torch.optim. It will be called with only two arguments, parameters and lr. Need more arguments? Use functools.partial.
scheduler (Optional[Callable]) – from torch.optim.lr_scheduler
scheduler_kwargs (dict) – The arguments which get passed to scheduler
early_stopping_patience (Optional[int]) – after this many fruitless epochs (no validation improvement), training stops
max_epochs (int) – maximum number of epochs to train
nworkers (int) – number of workers to divide dataloader duties between. 0 implies one thread for training and dataloading.
non_blocking (bool) – consult torch documentation, generally use True
pin_memory (bool) – consult torch documentation, generally use True
- Return type:
None
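Since train calls the optimizer with only parameters and lr, any further optimizer arguments must be bound beforehand with functools.partial, as the optimizer parameter notes. A sketch of that idiom, using a stand-in factory so the example is self-contained; with torch installed, the same line would read e.g. partial(torch.optim.AdamW, weight_decay=0.01):

```python
from functools import partial

def make_optimizer(parameters, lr, weight_decay=0.0):
    """Stand-in for an optimizer constructor such as torch.optim.AdamW,
    which accepts more arguments than the two that train will pass."""
    return {"parameters": list(parameters), "lr": lr, "weight_decay": weight_decay}

# Bind the extra keyword argument up front, leaving a callable that
# accepts exactly (parameters, lr) -- the signature train expects:
optimizer = partial(make_optimizer, weight_decay=0.01)

opt = optimizer([1, 2, 3], lr=5e-4)
```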