ISubGVQA.sampling.methods.wrapper

Attributes

logger

Functions

imle([function, target_distribution, ...])

Turns a black-box combinatorial solver in an Exponential Family distribution via Perturb-and-MAP and I-MLE [1].

Module Contents

ISubGVQA.sampling.methods.wrapper.logger
ISubGVQA.sampling.methods.wrapper.imle(function: Callable[[torch.Tensor], torch.Tensor] = None, target_distribution: ISubGVQA.sampling.methods.target.BaseTargetDistribution | None = None, noise_distribution: ISubGVQA.sampling.methods.noise.BaseNoiseDistribution | None = None, nb_samples: int = 1, input_noise_temperature: float = 1.0, target_noise_temperature: float = 1.0)

Turns a black-box combinatorial solver in an Exponential Family distribution via Perturb-and-MAP and I-MLE [1].

The input function (solver) needs to return the solution to the problem of finding a MAP state for a constrained exponential family distribution – this is the case for most black-box combinatorial solvers [2]. If this condition is violated though, the result would not hold and there is no guarantee on the validity of the obtained gradients.

This function can be used directly or as a decorator.

[1] Mathias Niepert, Pasquale Minervini, Luca Franceschi - Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions. NeurIPS 2021 (https://arxiv.org/abs/2106.01798) [2] Marin Vlastelica, Anselm Paulus, Vít Musil, Georg Martius, Michal Rolínek - Differentiation of Blackbox Combinatorial Solvers. ICLR 2020 (https://arxiv.org/abs/1912.02175)

Example:

>>> from imle.wrapper import imle
>>> from imle.target import TargetDistribution
>>> from imle.noise import SumOfGammaNoiseDistribution
>>> target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
>>> noise_distribution = SumOfGammaNoiseDistribution(k=21, nb_iterations=100)
>>> @imle(target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=100,
>>>       input_noise_temperature=input_noise_temperature, target_noise_temperature=5.0)
>>> def imle_solver(weights_batch: Tensor) -> Tensor:
>>>     return torch_solver(weights_batch)
Args:

function (Callable[[Tensor], Tensor]): black-box combinatorial solver target_distribution (Optional[BaseTargetDistribution]): factory for target distributions noise_distribution (Optional[BaseNoiseDistribution]): noise distribution nb_samples (int): number of noise sammples input_noise_temperature (float): noise temperature for the input distribution target_noise_temperature (float): noise temperature for the target distribution