ISubGVQA.sampling.methods.aimle

Attributes

logger

Functions

aimle([function, target_distribution, ...])

Turns a black-box combinatorial solver into an Exponential Family distribution via Perturb-and-MAP and I-MLE [1].

Module Contents

ISubGVQA.sampling.methods.aimle.logger
ISubGVQA.sampling.methods.aimle.aimle(function: Callable[[torch.Tensor], torch.Tensor] | None = None, target_distribution: ISubGVQA.sampling.methods.target_aimle.BaseTargetDistribution | None = None, noise_distribution: ISubGVQA.sampling.methods.noise.BaseNoiseDistribution | None = None, nb_samples: int = 1, nb_marginal_samples: int = 1, theta_noise_temperature: float = 1.0, target_noise_temperature: float = 1.0, symmetric_perturbation: bool = False, _is_minimization: bool = False)

Turns a black-box combinatorial solver into an Exponential Family distribution via Perturb-and-MAP and I-MLE [1].

The theta function (solver) must return a MAP state of a constrained exponential family distribution, i.e. the solution of the corresponding optimization problem; this is the case for most black-box combinatorial solvers [2]. If this condition is violated, the theoretical results of I-MLE no longer hold and there is no guarantee that the obtained gradients are valid.
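To make the Perturb-and-MAP plus I-MLE idea concrete, here is a minimal, dependency-free sketch that uses a top-k selector as the MAP solver (top-k returns the MAP state of a distribution over k-subsets, so it satisfies the condition above). It assumes the basic I-MLE target θ' = θ - λ·∂L/∂z, and it treats `symmetric=True` as a central difference between θ + λ·∂L/∂z and θ - λ·∂L/∂z; that is our reading of `symmetric_perturbation`, not something this page confirms. `top_k_map`, `imle_forward_backward` and `lam` are hypothetical names; the real `aimle` works on batched torch tensors and draws noise and targets from the distribution objects passed to it.

```python
# Hypothetical sketch of Perturb-and-MAP + I-MLE; not the library implementation.
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def top_k_map(theta: Sequence[float], k: int = 2) -> Vector:
    """MAP state of a k-subset distribution: a 0/1 vector marking the k largest scores."""
    order = sorted(range(len(theta)), key=lambda i: theta[i], reverse=True)
    chosen = set(order[:k])
    return [1.0 if i in chosen else 0.0 for i in range(len(theta))]


def imle_forward_backward(
    theta: Sequence[float],
    grad_z: Sequence[float],
    solver: Callable[[Sequence[float]], Vector],
    noise_samples: Sequence[Sequence[float]],
    lam: float = 1.0,
    symmetric: bool = False,
) -> Tuple[List[Vector], Vector]:
    """Perturb-and-MAP samples plus an I-MLE gradient estimate for theta.

    theta:  solver weights (parameters of the distribution)
    grad_z: upstream gradient dL/dz of the loss w.r.t. the discrete output
    lam:    perturbation strength used to build the target parameters (assumed)
    """
    n, m = len(theta), len(noise_samples)
    samples: List[Vector] = []
    grad_theta = [0.0] * n
    for eps in noise_samples:
        # Forward pass: MAP state of the perturbed parameters theta + eps.
        z = solver([t + e for t, e in zip(theta, eps)])
        samples.append(z)
        # Target parameters theta - lam * dL/dz, perturbed with the same noise.
        z_minus = solver([t - lam * g + e for t, g, e in zip(theta, grad_z, eps)])
        if symmetric:
            # Assumed symmetric variant: central difference around theta.
            z_plus = solver([t + lam * g + e for t, g, e in zip(theta, grad_z, eps)])
            diff = [(p - q) / (2.0 * lam) for p, q in zip(z_plus, z_minus)]
        else:
            # One-sided I-MLE: sample minus target MAP state, scaled by 1/lam.
            diff = [(a - b) / lam for a, b in zip(z, z_minus)]
        # Average the estimate over the noise samples.
        for i in range(n):
            grad_theta[i] += diff[i] / m
    return samples, grad_theta


# Deterministic usage: zero noise, k=2, and a loss that prefers item 2 over item 1.
samples, grad = imle_forward_backward(
    theta=[3.0, 2.0, 1.0, 0.0],
    grad_z=[0.0, 2.0, -2.0, 0.0],
    solver=top_k_map,
    noise_samples=[[0.0] * 4],
)
print(samples, grad)  # [[1.0, 1.0, 0.0, 0.0]] [0.0, 1.0, -1.0, 0.0]
```

A gradient-descent step on this estimate lowers theta[1] and raises theta[2], which is the change that makes the solver select item 2 instead of item 1.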

This function can be used directly or as a decorator.
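The signature (`function` defaults to None) and the decorator example suggest the usual Python convention: passing a solver wraps it immediately, while calling with keyword arguments only returns a decorator that receives the solver later. A sketch of that convention, assuming `aimle` follows it; `wrap_solver` is hypothetical and its `scale` argument merely stands in for `aimle`'s real options:

```python
import functools
from typing import Callable, List, Optional, Sequence


def wrap_solver(function: Optional[Callable] = None, *, scale: float = 1.0):
    """Hypothetical wrapper illustrating the 'direct or decorator' calling convention."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)  # keep the solver's name and docstring
        def wrapped(weights: Sequence[float]) -> List[float]:
            # Placeholder for the real work (perturbation, gradient estimation, ...).
            return fn([scale * w for w in weights])

        return wrapped

    if function is None:
        # Called as @wrap_solver(...): return a decorator.
        return decorator
    # Called as wrap_solver(solver, ...): wrap immediately.
    return decorator(function)


def argmax_one_hot(weights: Sequence[float]) -> List[float]:
    """Toy MAP solver: one-hot vector of the largest weight."""
    best = max(range(len(weights)), key=lambda i: weights[i])
    return [1.0 if i == best else 0.0 for i in range(len(weights))]


# Direct use: pass the solver as the first argument.
direct = wrap_solver(argmax_one_hot, scale=2.0)


# Decorator use: call with keyword arguments only, then decorate.
@wrap_solver(scale=2.0)
def decorated(weights: Sequence[float]) -> List[float]:
    return argmax_one_hot(weights)
```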

[1] Mathias Niepert, Pasquale Minervini, Luca Franceschi - Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions. NeurIPS 2021 (https://arxiv.org/abs/2106.01798)

[2] Marin Vlastelica, Anselm Paulus, Vít Musil, Georg Martius, Michal Rolínek - Differentiation of Blackbox Combinatorial Solvers. ICLR 2020 (https://arxiv.org/abs/1912.02175)

Example:

>>> from torch import Tensor
>>> from imle.aimle import aimle
>>> from imle.target import TargetDistribution
>>> from imle.noise import SumOfGammaNoiseDistribution
>>> target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
>>> noise_distribution = SumOfGammaNoiseDistribution(k=21, nb_iterations=100)
>>> @aimle(target_distribution=target_distribution, noise_distribution=noise_distribution, nb_samples=100,
...        theta_noise_temperature=1.0, target_noise_temperature=5.0)
... def aimle_solver(weights_batch: Tensor) -> Tensor:
...     # torch_solver is the caller's batched black-box MAP solver (not defined here)
...     return torch_solver(weights_batch)
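The example above configures `SumOfGammaNoiseDistribution(k=21, nb_iterations=100)`. Below is a dependency-free sketch of Sum-of-Gamma noise, assuming it follows the formula from the I-MLE paper [1], ε = (τ/k)(Σ_{i=1}^{s} Gamma(1/k, k/i) - log s) with s = nb_iterations, and assuming the temperature arguments play the role of τ by scaling each sample. `sum_of_gamma_noise` and its `rng` argument are hypothetical; the library class samples torch tensors instead.

```python
import math
import random
from typing import List, Optional


def sum_of_gamma_noise(
    size: int,
    k: int = 21,
    nb_iterations: int = 100,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw `size` i.i.d. Sum-of-Gamma samples (assumed formula, see lead-in)."""
    rng = rng if rng is not None else random.Random()
    samples = []
    for _ in range(size):
        # random.gammavariate(alpha, beta) takes shape alpha and scale beta.
        total = sum(rng.gammavariate(1.0 / k, k / i) for i in range(1, nb_iterations + 1))
        # Center by log(s), divide by k, and scale by the temperature.
        samples.append(temperature * (total - math.log(nb_iterations)) / k)
    return samples


print(sum_of_gamma_noise(4, rng=random.Random(0)))
```

Seeding the generator makes the draws reproducible, which is convenient when comparing temperatures on the same noise.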
Args:

    function (Optional[Callable[[Tensor], Tensor]]): black-box combinatorial solver; if None, a decorator is returned
    target_distribution (Optional[BaseTargetDistribution]): factory for target distributions
    noise_distribution (Optional[BaseNoiseDistribution]): noise distribution
    nb_samples (int): number of noise samples
    nb_marginal_samples (int): number of noise samples used to compute the marginals
    theta_noise_temperature (float): noise temperature for the theta distribution
    target_noise_temperature (float): noise temperature for the target distribution
    symmetric_perturbation (bool): whether to use the symmetric version of I-MLE
    _is_minimization (bool): whether the MAP solver solves an argmin (rather than argmax) problem
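The `_is_minimization` flag exists because a solver may minimize rather than maximize a linear score. For linear objectives the two views are interchangeable, since argmin_z ⟨θ, z⟩ = argmax_z ⟨-θ, z⟩. How `aimle` uses the flag internally is not documented on this page, so the sketch below only illustrates the identity, using hypothetical helpers `bottom_k_map` and `as_argmax_solver`.

```python
from typing import Callable, List, Sequence

Solver = Callable[[Sequence[float]], List[float]]


def top_k_map(theta: Sequence[float], k: int = 2) -> List[float]:
    """Argmax solver: 0/1 vector marking the k largest scores."""
    order = sorted(range(len(theta)), key=lambda i: theta[i], reverse=True)
    chosen = set(order[:k])
    return [1.0 if i in chosen else 0.0 for i in range(len(theta))]


def bottom_k_map(costs: Sequence[float], k: int = 2) -> List[float]:
    """Argmin solver: 0/1 vector marking the k smallest costs."""
    order = sorted(range(len(costs)), key=lambda i: costs[i])
    chosen = set(order[:k])
    return [1.0 if i in chosen else 0.0 for i in range(len(costs))]


def as_argmax_solver(argmin_solver: Solver) -> Solver:
    """Hypothetical adapter: argmax_z <theta, z> equals argmin_z <-theta, z>."""

    def solver(theta: Sequence[float]) -> List[float]:
        # Negate the scores so the argmin solver returns the argmax state.
        return argmin_solver([-t for t in theta])

    return solver


theta = [3.0, 2.0, 1.0, 0.0]
print(as_argmax_solver(bottom_k_map)(theta) == top_k_map(theta))  # True
```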