ISubGVQA.sampling.methods.target

Attributes

logger

Classes

BaseTargetDistribution

Abstract base class for target distributions; subclasses implement params.

TargetDistribution

Creates a generator of target distributions parameterized by alpha and beta.

Module Contents

ISubGVQA.sampling.methods.target.logger
class ISubGVQA.sampling.methods.target.BaseTargetDistribution

Bases: abc.ABC

Abstract base class for target distributions. Subclasses must implement params(theta, dy), which returns the parameters of the target distribution.

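abstract params(theta: torch.Tensor, dy: torch.Tensor) → torch.Tensor

Return the parameters of the target distribution computed from the initial distribution parameters theta and the downstream gradient dy.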
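The contract imposed by the abstract params method can be illustrated with a minimal sketch. The names TargetDistributionInterface and PassThroughTarget are hypothetical stand-ins, not part of ISubGVQA, and plain Any annotations replace torch.Tensor so the sketch runs without PyTorch. Because params is abstract, the base cannot be instantiated, while a subclass that overrides it can.

```python
from abc import ABC, abstractmethod
from typing import Any


class TargetDistributionInterface(ABC):
    """Hypothetical stand-in with the same abstract method as BaseTargetDistribution."""

    @abstractmethod
    def params(self, theta: Any, dy: Any) -> Any:
        """Return the target distribution parameters for theta and dy."""


class PassThroughTarget(TargetDistributionInterface):
    """Hypothetical subclass: ignores dy and returns theta unchanged."""

    def params(self, theta: Any, dy: Any) -> Any:
        return theta


try:
    TargetDistributionInterface()  # abstract method left unimplemented
except TypeError as err:
    print("cannot instantiate:", err)

print(PassThroughTarget().params(0.5, 1.0))  # 0.5
```

Any subclass that forgets to override params fails in the same way as the base, at instantiation time rather than at the first call.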
class ISubGVQA.sampling.methods.target.TargetDistribution(alpha: float = 1.0, beta: float = 1.0)

Bases: BaseTargetDistribution

Creates a generator of target distributions parameterized by alpha and beta.

Example:

>>> import torch
>>> target_distribution = TargetDistribution(alpha=1.0, beta=1.0)
>>> target_distribution.params(theta=torch.tensor([1.0]), dy=torch.tensor([1.0]))
tensor([2.])

Args:

alpha (float): weight of the initial distribution parameters theta.
beta (float): weight of the downstream gradient dy.

alpha = 1.0
beta = 1.0
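params(theta: torch.Tensor, dy: torch.Tensor) → torch.Tensor

Return the target distribution parameters obtained by weighting theta with alpha and dy with beta.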
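The reference does not show how params combines its inputs. The doctest (alpha = beta = 1, theta = dy = 1, result 2) is consistent with the linear form alpha * theta + beta * dy, and the sketch below assumes exactly that. The real implementation may differ, for example in its sign convention for dy. LinearTargetDistributionSketch is a hypothetical name, and floats stand in for tensors so the sketch runs without PyTorch. The same expression works elementwise on torch.Tensor inputs.

```python
from typing import Any


class LinearTargetDistributionSketch:
    """Hypothetical stand-in for TargetDistribution.

    Assumes params() returns alpha * theta + beta * dy; this form is
    inferred from the doctest only and is not confirmed by the reference.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
        self.alpha = alpha  # weight of the initial distribution parameters theta
        self.beta = beta    # weight of the downstream gradient dy

    def params(self, theta: Any, dy: Any) -> Any:
        # Assumed linear combination: elementwise for tensors,
        # ordinary arithmetic for plain floats.
        return self.alpha * theta + self.beta * dy


# Matches the doctest value: 1.0 * 1.0 + 1.0 * 1.0 == 2.0
print(LinearTargetDistributionSketch().params(1.0, 1.0))          # 2.0
# With beta = 0 the gradient is ignored and theta is returned as is
print(LinearTargetDistributionSketch(beta=0.0).params(3.0, 5.0))  # 3.0
```

Under this assumption, alpha scales how much of the initial distribution survives and beta controls how far the downstream gradient moves the target away from it.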