Automatic Guide Generation
AutoGuide
class AutoGuide(model, prefix='auto')
Bases: object
Base class for automatic guides.
Derived classes must implement the __call__() method.
Auto guides can be used individually or combined in an AutoGuideList object.
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
__call__(*args, **kwargs)
A guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
AutoGuideList
class AutoGuideList(model, prefix='auto')
Bases: pyro.infer.autoguide.guides.AutoGuide
Container class to combine multiple automatic guides.
Example usage:
    guide = AutoGuideList(model)
    guide.add(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
    svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
__call__(*args, **kwargs)
A composite guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
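The composite behavior described above can be pictured as calling each part with the same arguments and merging the returned dicts. The following is a minimal conceptual sketch in plain Python, not Pyro's implementation; GuideListSketch, global_part, and local_part are hypothetical names introduced only for illustration.

```python
class GuideListSketch:
    """Hypothetical container: calls each part and merges the sampled values."""

    def __init__(self):
        self.parts = []

    def add(self, part):
        # Each part is any callable returning a dict of site name -> value.
        self.parts.append(part)

    def __call__(self, *args, **kwargs):
        result = {}
        for part in self.parts:
            # Every part receives the same *args, **kwargs as the model would.
            result.update(part(*args, **kwargs))
        return result


def global_part(data):
    # Stand-in for a guide over a global latent variable.
    return {"scale": 1.0}


def local_part(data):
    # Stand-in for a guide over per-datum discrete assignments.
    return {"assignment": [0 for _ in data]}


guide = GuideListSketch()
guide.add(global_part)
guide.add(local_part)
sampled = guide([1.2, 3.4])
```

In this sketch each part is responsible for a disjoint set of sample sites, which mirrors how the usage example splits the model with poutine.block.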
AutoCallable
class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
    def my_local_guide(*args, **kwargs):
        ...
    guide = AutoGuideList(model)
    guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
    guide.add(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, wrap the guide explicitly instead:
    def my_local_median(*args, **kwargs):
        ...
    guide.add(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.
Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
AutoDelta
class AutoDelta(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note
This class does MAP inference in constrained space.
Usage:
    guide = AutoDelta(model)
    svi = SVI(model, guide, ...)
By default latent variables are initialized using init_loc_fn(). To change this default behavior the user should call pyro.param() before beginning inference, with "auto_" prefixed to the targeted sample site names, e.g. for sample sites named “level” and “concentration”, initialize via:
    pyro.param("auto_level", torch.tensor([-1., 0., 1.]))
    pyro.param("auto_concentration", torch.ones(k), constraint=constraints.positive)
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
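The note about MAP inference in constrained space matters because the mode of a density moves under reparameterization. The sketch below uses only the Python standard library (it does not call Pyro) and a LogNormal(0, 1) density as an assumed example: maximizing over z > 0 directly gives exp(-1), while maximizing over u = log z and mapping back gives 1. The grid search and all function names are illustrative assumptions.

```python
import math


def lognormal_logpdf(z, mu=0.0, sigma=1.0):
    """Log density of LogNormal(mu, sigma) at z > 0 (constrained space)."""
    return (-math.log(z) - math.log(sigma) - 0.5 * math.log(2 * math.pi)
            - (math.log(z) - mu) ** 2 / (2 * sigma ** 2))


def normal_logpdf(u, mu=0.0, sigma=1.0):
    """Log density of u = log(z), a Normal(mu, sigma) (unconstrained space)."""
    return (-math.log(sigma) - 0.5 * math.log(2 * math.pi)
            - (u - mu) ** 2 / (2 * sigma ** 2))


def argmax_on_grid(fn, lo, hi, steps=200001):
    """Brute-force maximizer on an evenly spaced grid (illustration only)."""
    best_x, best_val = lo, fn(lo)
    for i in range(1, steps):
        x = lo + (hi - lo) * i / (steps - 1)
        val = fn(x)
        if val > best_val:
            best_x, best_val = x, val
    return best_x


# Mode found by optimizing directly over the constrained variable z.
z_mode_constrained = argmax_on_grid(lognormal_logpdf, 0.01, 5.0)

# Mode found by optimizing over u = log(z), then mapping back to z.
u_mode = argmax_on_grid(normal_logpdf, -3.0, 3.0)
z_mode_from_unconstrained = math.exp(u_mode)
```

The two point estimates differ because the change of variables contributes a Jacobian term to the density, so it is worth knowing which space a MAP guide optimizes in.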
AutoContinuous
class AutoContinuous(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
Each derived class implements its own get_posterior() method.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
[1] Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
__call__(*args, **kwargs)
An automatic guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)
Returns posterior quantiles for each latent variable. Example:
    print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
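One way to picture per-site quantiles for a continuous guide: if a latent variable is Normal in unconstrained space and is mapped onto its support by a monotone increasing transform, quantiles can be computed in unconstrained space and pushed through the transform. The sketch below assumes exactly that setup and uses only the standard library; it is not Pyro's implementation, and the loc, scale, and exp transform are made-up values for a positive-valued site.

```python
import math
from statistics import NormalDist


def transformed_normal_quantiles(loc, scale, transform, quantiles):
    """Quantiles of transform(u) where u ~ Normal(loc, scale).

    Valid only when `transform` is monotone increasing, so that it
    preserves the ordering of quantiles.
    """
    base = NormalDist(mu=loc, sigma=scale)
    return [transform(base.inv_cdf(q)) for q in quantiles]


# Hypothetical positive-valued site: unconstrained Normal(0.5, 0.25), support via exp.
positive_site_quantiles = transformed_normal_quantiles(
    0.5, 0.25, math.exp, [0.05, 0.5, 0.95]
)
```

The middle entry is the posterior median of the site under these assumptions, which is why a median can be read off as the 0.5 quantile.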
AutoMultivariateNormal
class AutoMultivariateNormal(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoMultivariateNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the Cholesky factor is initialized to the identity. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    # torch.tril needs a 2-D input, so draw a square matrix
    pyro.param("auto_scale_tril", torch.tril(torch.rand(latent_dim, latent_dim)), constraint=constraints.lower_cholesky)
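To make the role of the Cholesky factor concrete, the sketch below draws from a multivariate normal as loc + L @ eps, with eps standard normal and L lower triangular, so that the covariance is L @ L^T. It uses plain Python lists and the standard library rather than torch; the 2-dimensional loc and scale_tril values are arbitrary assumptions.

```python
import random


def covariance_from_scale_tril(scale_tril):
    """Covariance implied by a lower-triangular factor: L @ L^T."""
    n = len(scale_tril)
    return [[sum(scale_tril[i][k] * scale_tril[j][k] for k in range(n))
             for j in range(n)] for i in range(n)]


def sample_mvn(loc, scale_tril, rng):
    """One draw z = loc + L @ eps with eps ~ Normal(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    return [m + sum(scale_tril[i][j] * eps[j] for j in range(i + 1))
            for i, m in enumerate(loc)]


loc = [0.0, 1.0]
scale_tril = [[2.0, 0.0],
              [0.5, 1.0]]

cov = covariance_from_scale_tril(scale_tril)

rng = random.Random(0)
draws = [sample_mvn(loc, scale_tril, rng) for _ in range(20000)]
mean0 = sum(d[0] for d in draws) / len(draws)
mean1 = sum(d[1] for d in draws) / len(draws)
# Empirical cross-covariance, which should approach cov[0][1].
cross = sum((d[0] - mean0) * (d[1] - mean1) for d in draws) / len(draws)
```

Parameterizing by the factor rather than the covariance keeps the matrix positive definite as long as the diagonal of L stays positive, which is what a lower-Cholesky constraint enforces.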
AutoDiagonalNormal
class AutoDiagonalNormal(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the scale is initialized to the identity (a vector of ones). To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_scale", torch.ones(latent_dim), constraint=constraints.positive)
AutoLowRankMultivariateNormal
class AutoLowRankMultivariateNormal(model, prefix='auto', init_loc_fn=<function init_to_median>, rank=1)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoLowRankMultivariateNormal(model, rank=10)
    svi = SVI(model, guide, ...)
By default the cov_diag is initialized to 1/2 and the cov_factor is initialized randomly such that cov_factor.matmul(cov_factor.t()) is, in expectation, half the identity matrix. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    rank = 10  # should match the rank passed to the guide
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_cov_factor", torch.randn(latent_dim, rank))
    pyro.param("auto_cov_diag", torch.randn(latent_dim).exp(), constraint=constraints.positive)
Parameters:
- model (callable) – a generative model
- rank (int) – the rank of the low-rank part of the covariance matrix
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
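A low rank plus diagonal covariance has the form cov_factor @ cov_factor^T + diag(cov_diag). The sketch below, written in plain Python under that assumption, builds such a matrix for a tiny example and compares how many covariance parameters a diagonal, a low-rank, and a full Cholesky parameterization would need. The helper names and the numbers are illustrative only.

```python
def low_rank_plus_diag_cov(cov_factor, cov_diag):
    """Dense covariance W @ W^T + diag(D) from a (n x r) factor and a length-n diagonal."""
    n = len(cov_diag)
    r = len(cov_factor[0])
    return [[sum(cov_factor[i][k] * cov_factor[j][k] for k in range(r))
             + (cov_diag[i] if i == j else 0.0)
             for j in range(n)] for i in range(n)]


def num_covariance_params(latent_dim, rank):
    """Free covariance parameters for three guide families (loc excluded)."""
    return {
        "diagonal": latent_dim,
        "low_rank": latent_dim * rank + latent_dim,
        "full_cholesky": latent_dim * (latent_dim + 1) // 2,
    }


cov_factor = [[1.0], [2.0], [0.0]]   # rank-1 factor for a 3-dimensional latent
cov_diag = [0.5, 0.5, 0.5]
cov = low_rank_plus_diag_cov(cov_factor, cov_diag)

counts = num_covariance_params(latent_dim=1000, rank=10)
```

For large latent dimensions the low-rank form sits between the diagonal and full parameterizations in cost, while still capturing correlations along a few directions.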
AutoIAFNormal
class AutoIAFNormal(model, hidden_dim=None, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveFlow to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoIAFNormal(model, hidden_dim=latent_dim)
    svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model
- hidden_dim (int) – number of hidden dimensions in the IAF
- init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
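The idea behind an inverse autoregressive flow can be sketched in two dimensions: each output is an affine function of the matching input, with a mean and log-scale that depend only on earlier inputs, so the Jacobian is triangular and its log-determinant is the sum of the log-scales. The code below is a hand-written toy under those assumptions, not Pyro's InverseAutoregressiveFlow; in a real flow the conditioner would be a learned network, and autoregressive_params here is a hypothetical stand-in.

```python
import math


def autoregressive_params(x):
    """Hypothetical conditioner: output i depends only on x[:i]."""
    mean = [0.3, 0.8 * x[0]]
    log_scale = [-0.2, 0.5 * math.tanh(x[0])]
    return mean, log_scale


def iaf_forward(x):
    """Map base noise x to z and return log|det dz/dx|."""
    mean, log_scale = autoregressive_params(x)
    z = [m + math.exp(s) * xi for m, s, xi in zip(mean, log_scale, x)]
    # Triangular Jacobian: the determinant is the product of the scales.
    log_det = sum(log_scale)
    return z, log_det


def numeric_log_det(x, h=1e-6):
    """Finite-difference check of log|det J| for the 2-D case."""
    def column(j):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += h
        x_minus[j] -= h
        z_plus, _ = iaf_forward(x_plus)
        z_minus, _ = iaf_forward(x_minus)
        return [(a - b) / (2 * h) for a, b in zip(z_plus, z_minus)]

    c0, c1 = column(0), column(1)   # columns of the Jacobian dz/dx
    det = c0[0] * c1[1] - c1[0] * c0[1]
    return math.log(abs(det))


x = [0.4, -1.2]
z, log_det = iaf_forward(x)
log_det_check = numeric_log_det(x)
```

Because the log-determinant is cheap to compute, the density of the transformed sample stays tractable even though the transform itself can be quite flexible.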
AutoLaplaceApproximation
class AutoLaplaceApproximation(model, prefix='auto', init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, ...)
    # ...then train the delta_guide...
    guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to zero. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
laplace_approximation(*args, **kwargs)
Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by Laplace approximation.
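A one-dimensional worked example may help fix the idea: find the MAP point of the log joint, then use the inverse of the second derivative of the negative log joint at that point as the variance. The sketch below assumes an unnormalized log joint of the form a*u - b*exp(u) (the log density of the log of a Gamma(a, b) variable, up to a constant) and uses Newton's method from the standard library only. The closed-form answers for this assumed target are loc = log(a / b) and variance = 1 / a.

```python
import math


def neg_log_joint_grad(u, a, b):
    """First derivative of -(a*u - b*exp(u))."""
    return -(a - b * math.exp(u))


def neg_log_joint_hess(u, a, b):
    """Second derivative of -(a*u - b*exp(u)); positive everywhere."""
    return b * math.exp(u)


def laplace_approximation_1d(a=3.0, b=2.0, u0=0.0, steps=50):
    """Return (loc, variance) of the Gaussian fitted at the MAP point."""
    u = u0
    for _ in range(steps):
        # Newton step on the (convex) negative log joint.
        u -= neg_log_joint_grad(u, a, b) / neg_log_joint_hess(u, a, b)
    variance = 1.0 / neg_log_joint_hess(u, a, b)
    return u, variance


loc, variance = laplace_approximation_1d()
```

In more than one dimension the second derivative becomes a Hessian matrix and the variance becomes its matrix inverse, as in the description above.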
AutoDiscreteParallel
Initialization
The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.
The standard interface for initialization is a function that takes a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
init_to_feasible(site)
Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_median(site, num_samples=15)
Initialize to the prior median; fall back to a feasible point if the median is undefined.
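The initialization interface can be illustrated with a small stand-in that does not use Pyro. The sketch assumes the median is estimated empirically from num_samples prior draws, and that an undefined median shows up as NaN; both are assumptions about the strategy rather than a description of the library code. The site dict here is hypothetical, with made-up "draw" and "feasible_point" keys in place of a real trace site.

```python
import math
import random
import statistics


def init_to_feasible_sketch(site):
    """Hypothetical fallback: return a fixed point inside the support."""
    return site["feasible_point"]


def init_to_median_sketch(site, num_samples=15, rng=None):
    """Empirical prior median, falling back to a feasible point if undefined."""
    rng = rng or random.Random(0)
    samples = [site["draw"](rng) for _ in range(num_samples)]
    value = statistics.median(samples)
    if math.isnan(value):
        # Median is undefined for this prior, so use the fallback.
        return init_to_feasible_sketch(site)
    return value


# Hypothetical site dicts; real Pyro trace sites have a different layout.
lognormal_site = {
    "name": "concentration",
    "draw": lambda rng: rng.lognormvariate(0.0, 1.0),
    "feasible_point": 1.0,
}
undefined_site = {
    "name": "undefined_median",
    "draw": lambda rng: float("nan"),
    "feasible_point": 1.0,
}

init_value = init_to_median_sketch(lognormal_site)
fallback_value = init_to_median_sketch(undefined_site)
```

Either function returns a value in the constrained support of the site, which matches the interface described at the start of this section.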
class InitMessenger(init_fn)
Bases: pyro.poutine.messenger.Messenger
Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
Parameters: init_fn (callable) – An initialization function.