Automatic Guide Generation
AutoGuide
class AutoGuide(model, prefix='auto')
Bases: object
Base class for automatic guides.
Derived classes must implement the __call__() method.
Auto guides can be used individually or combined in an AutoGuideList object.
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix prepended to the names of all internal param sites
__call__(*args, **kwargs)
A guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
AutoGuideList
class AutoGuideList(model, prefix='auto')
Bases: pyro.infer.autoguide.guides.AutoGuide
Container class to combine multiple automatic guides.
Example usage:
    from pyro import poutine
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoDiagonalNormal, AutoDiscreteParallel, AutoGuideList
    guide = AutoGuideList(model)
    guide.add(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
    svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix prepended to the names of all internal param sites
__call__(*args, **kwargs)
A composite guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
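The way a composite guide combines its parts can be pictured with a plain-Python sketch. This is not Pyro's implementation: GuideListSketch, global_part, and local_part are hypothetical names, and the parts here are ordinary functions returning dicts rather than real guides.

```python
class GuideListSketch:
    """Hypothetical stand-in for a container of guides (not the Pyro class)."""

    def __init__(self):
        self.parts = []

    def add(self, part):
        # Each part is a callable returning {site name: value}.
        self.parts.append(part)

    def __call__(self, *args, **kwargs):
        # Call every part with the same arguments and merge the results.
        result = {}
        for part in self.parts:
            result.update(part(*args, **kwargs))
        return result


def global_part(data):
    # Hypothetical part covering a continuous site.
    return {"scale": 1.0}


def local_part(data):
    # Hypothetical part covering a discrete site, one value per datum.
    return {"assignment": [0 for _ in data]}


guide = GuideListSketch()
guide.add(global_part)
guide.add(local_part)
sites = guide([0.3, 1.7, 2.2])
```

Each part sees the same arguments as the model, and the merged dict contains one entry per sample site handled by any part.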
AutoCallable
class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
    def my_local_guide(*args, **kwargs):
        ...
    guide = AutoGuideList(model)
    guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
    guide.add(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, wrap the guide explicitly instead:
    def my_local_median(*args, **kwargs):
        ...
    guide.add(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.
Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
AutoDelta
class AutoDelta(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note: This class does MAP inference in constrained space.
Usage:
    guide = AutoDelta(model)
    svi = SVI(model, guide, ...)
By default latent variables are initialized using init_loc_fn(). To change this default behavior the user should call pyro.param() before beginning inference, with "auto_" prefixed to the targeted sample site names, e.g. for sample sites named “level” and “concentration”, initialize via:
    pyro.param("auto_level", torch.tensor([-1., 0., 1.]))
    pyro.param("auto_concentration", torch.ones(k), constraint=constraints.positive)
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
AutoContinuous
class AutoContinuous(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
Each derived class implements its own get_posterior() method.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
[1] Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
__call__(*args, **kwargs)
An automatic guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)
Returns posterior quantiles of each latent variable. Example:
    print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
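As a rough picture of what median() and quantiles() report, the sketch below computes quantiles for independent normal posteriors using only the Python standard library. It assumes each latent variable is unconstrained and has a known location and scale; the site names and numbers are made up for illustration, and a real guide would also map values back to each site's constrained space.

```python
from statistics import NormalDist

# Hypothetical fitted posterior: site name -> (loc, scale) of a normal.
posterior = {"level": (0.0, 1.0), "trend": (2.0, 0.5)}


def quantiles(qs):
    # For each site, evaluate the inverse CDF at every requested quantile.
    return {
        name: [NormalDist(loc, scale).inv_cdf(q) for q in qs]
        for name, (loc, scale) in posterior.items()
    }


def median():
    # The median is the 0.5 quantile.
    return {name: values[0] for name, values in quantiles([0.5]).items()}


result = quantiles([0.05, 0.5, 0.95])
medians = median()
```

The returned structure mirrors the description above: one dict entry per sample site, each holding one value per requested quantile.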
AutoMultivariateNormal
class AutoMultivariateNormal(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoMultivariateNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the Cholesky factor is initialized to the identity. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    # scale_tril must be a square (latent_dim x latent_dim) lower-triangular matrix
    pyro.param("auto_scale_tril", torch.tril(torch.rand(latent_dim, latent_dim)),
               constraint=constraints.lower_cholesky)
AutoDiagonalNormal
class AutoDiagonalNormal(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the scale is initialized to a vector of ones (an identity covariance). To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_scale", torch.ones(latent_dim), constraint=constraints.positive)
AutoLowRankMultivariateNormal
class AutoLowRankMultivariateNormal(model, prefix='auto', init_loc_fn=<function init_to_median>, rank=1)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoLowRankMultivariateNormal(model, rank=10)
    svi = SVI(model, guide, ...)
By default the cov_diag is initialized to 1/2 and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) is half the identity matrix. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    rank = 10  # must match the rank passed to the guide
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_cov_factor", torch.randn(latent_dim, rank))
    pyro.param("auto_cov_diag", torch.randn(latent_dim).exp(), constraint=constraints.positive)
Parameters:
- model (callable) – a generative model
- rank (int) – the rank of the low-rank part of the covariance matrix
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- prefix (str) – a prefix prepended to the names of all internal param sites
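To make the "low rank plus diagonal" structure concrete, the sketch below assembles the covariance cov_factor · cov_factorᵀ + diag(cov_diag) in plain Python and compares its parameter count with that of a full Cholesky factor. The helper name is hypothetical, and the scaling of the random factor is an assumption chosen so that the low-rank product averages to half the identity; it is not taken from the library source.

```python
import random


def low_rank_covariance(cov_factor, cov_diag):
    """Return cov_factor @ cov_factor.T + diag(cov_diag) as nested lists."""
    n = len(cov_diag)
    cov = [
        [sum(a * b for a, b in zip(cov_factor[i], cov_factor[j])) for j in range(n)]
        for i in range(n)
    ]
    for i in range(n):
        cov[i][i] += cov_diag[i]
    return cov


latent_dim, rank = 10, 2
rng = random.Random(0)

# Assumed initialization: factor entries with variance 0.5 / rank, so each
# diagonal entry of the low-rank product has expected value 0.5.
std = (0.5 / rank) ** 0.5
cov_factor = [[rng.gauss(0.0, std) for _ in range(rank)] for _ in range(latent_dim)]
cov_diag = [0.5] * latent_dim

cov = low_rank_covariance(cov_factor, cov_diag)

# Parameters needed for the covariance in each parameterization.
low_rank_params = latent_dim * rank + latent_dim    # factor + diagonal
full_params = latent_dim * (latent_dim + 1) // 2    # lower-triangular factor
```

With latent_dim = 10 and rank = 2 the low-rank form uses 30 covariance parameters where a full Cholesky factor uses 55; the gap widens quickly as the latent dimension grows.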
AutoIAFNormal
class AutoIAFNormal(model, hidden_dim=None, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveFlow to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoIAFNormal(model, hidden_dim=latent_dim)
    svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model
- hidden_dim (int) – number of hidden dimensions in the IAF
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- prefix (str) – a prefix prepended to the names of all internal param sites
AutoLaplaceApproximation
class AutoLaplaceApproximation(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, ...)
    # ...then train the delta_guide...
    guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to zero. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
laplace_approximation(*args, **kwargs)
Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by Laplace approximation.
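The recipe described above (find the MAP point, then invert the Hessian of \(-\log p(x, z)\) there) can be illustrated in one dimension with the standard library alone. This is a sketch under stated assumptions, not the library's code: neg_log_joint is a made-up function of a single unconstrained latent z, and derivatives are taken by finite differences rather than automatic differentiation.

```python
import math


def neg_log_joint(z):
    # Hypothetical -log p(x, z) for a single unconstrained latent z.
    # Its minimum is at z = log(3), where the second derivative equals 3.
    return math.exp(z) - 3.0 * z


def derivative(f, z, eps=1e-4):
    # Central finite-difference estimate of f'(z).
    return (f(z + eps) - f(z - eps)) / (2 * eps)


def second_derivative(f, z, eps=1e-4):
    # Central finite-difference estimate of f''(z).
    return (f(z + eps) - 2 * f(z) + f(z - eps)) / eps ** 2


# Step 1: locate the MAP point with a few Newton iterations.
z_map = 0.0
for _ in range(50):
    z_map -= derivative(neg_log_joint, z_map) / second_derivative(neg_log_joint, z_map)

# Step 2: the approximating normal has variance equal to the inverse Hessian.
hessian = second_derivative(neg_log_joint, z_map)
variance = 1.0 / hessian
scale = math.sqrt(variance)
```

In several dimensions the scalar second derivative becomes the Hessian matrix, and the Cholesky factor of its inverse plays the role of scale_tril.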
AutoDiscreteParallel
Initialization
The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.
The standard interface for initialization is a function that takes a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
init_to_feasible(site)
Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_median(site, num_samples=15)
Initialize to the prior median; fall back to a feasible point if the median is undefined.
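The median-with-fallback strategy can be sketched in plain Python. The function below is a hypothetical stand-in, not the library routine: instead of a Pyro trace site dict it takes a zero-argument sampler for the prior and an explicit fallback value, both of which are assumptions made for illustration.

```python
import math
import random
import statistics


def init_to_median_sketch(sample_prior, feasible_value, num_samples=15):
    """Hypothetical stand-in: median of prior draws, else a feasible value."""
    try:
        samples = [sample_prior() for _ in range(num_samples)]
        value = statistics.median(samples)
        if not math.isfinite(value):
            raise ValueError("median is undefined")
        return value
    except ValueError:
        # Fall back to an arbitrary feasible point.
        return feasible_value


rng = random.Random(0)

# A well-behaved prior: the sample median lands near the true median 3.0.
good = init_to_median_sketch(lambda: rng.gauss(3.0, 1.0), feasible_value=0.0)

# A degenerate prior that only yields NaN: the fallback value is used.
bad = init_to_median_sketch(lambda: float("nan"), feasible_value=0.0)
```

Using a modest number of draws keeps initialization cheap while still giving a starting point that reflects the prior's scale.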
class InitMessenger(init_fn)
Bases: pyro.poutine.messenger.Messenger
Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
Parameters: init_fn (callable) – An initialization function.