Easy Custom Guides
EasyGuide
class EasyGuide(model)
Bases: object
Base class for “easy guides”.
Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:
- group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
- with self.plate(...): ... should be used instead of pyro.plate.
- self.map_estimate(...) uses a Delta guide for a single site.
Derived classes may also override the init() method to provide custom initialization for model sites.
Parameters: model (callable) – A Pyro model.
init(site)
Model initialization method; may be overridden by the user.
This should take a site as input and return a valid sample from that site. The default behavior is to draw a random sample:
    return site["fn"]()
For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
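The init contract is small: take a site, return a value for it. The toy sketch below mimics that contract in plain Python so it runs without Pyro. Here a "site" is a dict whose "fn" entry is a zero-argument sampler, whereas in Pyro site["fn"] is a distribution. ToyGuide, FixedInitGuide and the "init_value" key are hypothetical names used only for illustration.

```python
import random


class ToyGuide:
    """Hypothetical stand-in for EasyGuide's init() contract (not Pyro code)."""

    def init(self, site):
        # Default behavior described above: draw a random sample.
        return site["fn"]()


class FixedInitGuide(ToyGuide):
    """Override init() to start from a user-chosen value when one is given."""

    def init(self, site):
        # "init_value" is a hypothetical key used only in this sketch.
        if "init_value" in site:
            return site["init_value"]
        # Otherwise fall back to the default random draw.
        return super().init(site)


rng = random.Random(0)
site = {"name": "drift", "fn": lambda: rng.gauss(0.0, 1.0)}
print(ToyGuide().init(site))                                # a random draw
print(FixedInitGuide().init({**site, "init_value": 0.5}))  # 0.5
```

A real override would return a value valid for the Pyro distribution in site["fn"], for example by delegating to one of the helpers in the initialization module linked above.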
plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)
A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.
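The size, subsample_size and subsample arguments are passed in the style of pyro.plate. As a rough, dependency-free illustration, the hypothetical helper plate_indices below computes which indices such a plate would cover: explicit subsample indices win, otherwise subsample_size indices are drawn without replacement, otherwise the full range. The PlateCache class sketches one plausible reason to route plates through a wrapper: caching them by name so repeated self.plate("data", ...) calls share one minibatch. That caching mechanism is an assumption for illustration, not a description of Pyro's internals.

```python
import random
from typing import Dict, List, Optional, Sequence


def plate_indices(size: int,
                  subsample_size: Optional[int] = None,
                  subsample: Optional[Sequence[int]] = None,
                  rng: Optional[random.Random] = None) -> List[int]:
    """Toy model of the indices a (possibly subsampled) plate covers."""
    if subsample is not None:
        # Explicit indices take precedence over subsample_size.
        return list(subsample)
    if subsample_size is None or subsample_size >= size:
        # No subsampling requested: cover the whole range.
        return list(range(size))
    rng = rng or random.Random()
    # Random minibatch drawn without replacement from range(size).
    return rng.sample(range(size), subsample_size)


class PlateCache:
    """Hypothetical name-keyed cache: repeated calls reuse the first plate."""

    def __init__(self) -> None:
        self.plates: Dict[str, List[int]] = {}

    def plate(self, name: str, size: int, subsample_size: Optional[int] = None,
              subsample: Optional[Sequence[int]] = None,
              rng: Optional[random.Random] = None) -> List[int]:
        if name not in self.plates:
            self.plates[name] = plate_indices(size, subsample_size, subsample, rng)
        # Later calls with the same name return the same minibatch.
        return self.plates[name]


cache = PlateCache()
first = cache.plate("data", 100, subsample_size=5)
second = cache.plate("data", 100, subsample_size=5)
print(first is second)  # True
```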
group(match='.*')
Select a Group of model sites for joint guidance.
Parameters: match (str) – A regex string matching names of model sample sites.
Returns: A group of model sites.
Return type: Group
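Since match is a regular expression over site names, the default '.*' selects every site. The sketch below filters a list of hypothetical site names using Python's re.match, which anchors only at the start of the name. Whether EasyGuide anchors the pattern the same way is an assumption here, so the last call shows pinning the end with $ to avoid accidentally matching longer names such as state_10. Raising an error on an empty selection is this sketch's own choice.

```python
import re
from typing import List

# Hypothetical model site names: a global drift plus per-step states and observations.
SITE_NAMES = ["drift", "state_0", "state_1", "state_10", "obs_0", "obs_1"]


def select_sites(names: List[str], match: str = ".*") -> List[str]:
    """Return the names whose beginning matches the regex (re.match semantics)."""
    pattern = re.compile(match)
    selected = [name for name in names if pattern.match(name)]
    if not selected:
        raise ValueError(f"pattern {match!r} matched no sites")
    return selected


print(select_sites(SITE_NAMES))                  # all six names
print(select_sites(SITE_NAMES, "state_[0-9]*"))  # ['state_0', 'state_1', 'state_10']
print(select_sites(SITE_NAMES, "state_1$"))      # ['state_1']
```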
map_estimate(name)
Construct a maximum a posteriori (MAP) guide using Delta distributions.
Parameters: name (str) – The name of a model sample site.
Returns: A sampled value.
Return type: torch.Tensor
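A Delta guide puts all its mass on one learnable point z, so the ELBO reduces to the log joint log p(x, z) evaluated at that point, and maximizing it is MAP estimation. The dependency-free sketch below makes this concrete for a hypothetical conjugate model, z ~ Normal(0, 1) and x_i ~ Normal(z, 1). It runs plain gradient ascent on the log joint, whose gradient is sum(x) - (n + 1) z, and compares the result with the closed-form MAP sum(x) / (n + 1). It does not use Pyro; the data, step size and step count are arbitrary choices.

```python
from typing import Sequence


def log_joint_grad(z: float, xs: Sequence[float]) -> float:
    """d/dz of log N(z | 0, 1) + sum_i log N(x_i | z, 1)."""
    return -z + sum(x - z for x in xs)


def map_by_gradient_ascent(xs: Sequence[float], lr: float = 0.05,
                           steps: int = 2000) -> float:
    """Optimize a single point estimate, as a Delta guide's parameter would be."""
    z = 0.0  # initial value of the point estimate
    for _ in range(steps):
        z += lr * log_joint_grad(z, xs)
    return z


data = [1.2, 0.8, 1.5, 0.9]
z_map = map_by_gradient_ascent(data)
closed_form = sum(data) / (len(data) + 1)
print(round(z_map, 6), round(closed_form, 6))  # both approximately 0.88
```

The fixed step size is only stable while lr * (n + 1) < 2; larger datasets need a smaller step or an adaptive optimizer.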
easy_guide
easy_guide(model)
Convenience decorator to create an EasyGuide. The following are equivalent:
    from pyro.contrib.easyguide import EasyGuide, easy_guide

    # model is your Pyro model; my_guide is any function of guide statements.

    # Version 1. Decorate a function.
    @easy_guide(model)
    def guide(self, foo, bar):
        return my_guide(foo, bar)

    # Version 2. Create and instantiate a subclass of EasyGuide.
    class Guide(EasyGuide):
        def guide(self, foo, bar):
            return my_guide(foo, bar)

    guide = Guide(model)
Parameters: model (callable) – A Pyro model.
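One way a decorator can provide the equivalence above is to build a subclass on the fly, with the decorated function as its guide method, and instantiate it with the model. The sketch below shows that pattern with a minimal stand-in base class. ToyEasyGuide and toy_easy_guide are hypothetical names; the real EasyGuide also manages parameters, plates and groups.

```python
from typing import Any, Callable


class ToyEasyGuide:
    """Hypothetical minimal stand-in for EasyGuide (not Pyro's class)."""

    def __init__(self, model: Callable[..., Any]):
        self.model = model

    def guide(self, *args, **kwargs):
        raise NotImplementedError("derived classes must define guide()")

    def __call__(self, *args, **kwargs):
        # Calling the guide object runs the user-defined guide() method.
        return self.guide(*args, **kwargs)


def toy_easy_guide(model: Callable[..., Any]):
    """Decorator: turn a function into an instance of a fresh subclass."""
    def decorator(fn: Callable[..., Any]) -> ToyEasyGuide:
        # Build a subclass whose guide() method is the decorated function.
        guide_cls = type(fn.__name__, (ToyEasyGuide,), {"guide": fn})
        return guide_cls(model)
    return decorator


def model(foo, bar):
    return foo + bar


# Version 1: decorate a function.
@toy_easy_guide(model)
def guide_v1(self, foo, bar):
    return foo * bar


# Version 2: subclass explicitly, then instantiate.
class GuideV2(ToyEasyGuide):
    def guide(self, foo, bar):
        return foo * bar


guide_v2 = GuideV2(model)
print(guide_v1(3, 4), guide_v2(3, 4))  # 12 12
```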
Group
class Group(guide, sites)
Bases: object
An autoguide helper to match a group of model sites.
Variables:
- event_shape (torch.Size) – The total flattened, concatenated shape of all matching sample sites in the model.
- prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
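Parameters:
- guide (EasyGuide) – The easy guide that owns this group.
- sites (list) – The matching model sample sites.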
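event_shape describes one flat vector holding every matched site side by side. The sketch below shows the arithmetic under two stated assumptions: each site contributes the product of its own shape dimensions once plate (batch) dimensions are removed, and a scalar site counts as one element. The site names and shapes are hypothetical, and the 1-tuple result stands in for a torch.Size of length one.

```python
import math
from typing import Dict, Tuple

# Hypothetical per-site shapes, with plate (batch) dimensions already removed.
SITE_SHAPES: Dict[str, Tuple[int, ...]] = {
    "scale": (),        # scalar site -> 1 element
    "loc": (3,),        # vector site -> 3 elements
    "weights": (2, 4),  # matrix site -> 8 elements
}


def flat_sizes(shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, int]:
    """Number of elements each site occupies in the concatenated vector."""
    # math.prod(()) == 1, so scalar sites occupy a single slot.
    return {name: math.prod(shape) for name, shape in shapes.items()}


def group_event_shape(shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int]:
    """Total flattened, concatenated shape as a 1-tuple."""
    return (sum(flat_sizes(shapes).values()),)


print(flat_sizes(SITE_SHAPES))         # {'scale': 1, 'loc': 3, 'weights': 8}
print(group_event_shape(SITE_SHAPES))  # (12,)
```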
guide
The EasyGuide instance this group belongs to.
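sample(guide_name, fn, infer=None)
Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack it to multiple sample sites for model replay.
Parameters:
- guide_name (str) – The name of the auxiliary guide site.
- fn (Distribution) – A distribution with shape self.event_shape.
- infer (dict) – Optional inference configuration dict, as in pyro.sample().
Returns: A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
Return type: tuple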