

For the next major release, we can incorporate some lessons from numpyro, along with issues raised on our forum, to improve the HMC interface as follows:

Provide a potential_fn argument, as in numpyro, so that users do not need to pass in their model but can instead generate samples from any callable whose log density can be evaluated (the potential energy is the negative log density). This will be useful for integrating with funsors and will eventually let us get rid of our TraceTreeEvaluator wrappers (pyro-ppl/funsor#123).
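To make the idea concrete, here is a minimal pure-Python sketch of an HMC transition that consumes only a potential_fn and a flat list of floats, with no model or trace anywhere. It is not Pyro's or numpyro's API: `hmc_step` and `numerical_grad` are hypothetical names, central finite differences stand in for autograd, and the mass matrix is fixed to the identity.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
PotentialFn = Callable[[Vector], float]


def numerical_grad(potential_fn: PotentialFn, z: Vector, eps: float = 1e-5) -> Vector:
    """Central finite-difference gradient; stands in for autograd in this sketch."""
    grad = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        grad.append((potential_fn(z_plus) - potential_fn(z_minus)) / (2 * eps))
    return grad


def hmc_step(potential_fn: PotentialFn, z: Vector, step_size: float,
             num_steps: int, rng: random.Random) -> Vector:
    """One HMC transition needing only potential_fn(z) = -log p(z) + const."""
    # Draw momentum from a standard normal (identity mass matrix).
    r0 = [rng.gauss(0.0, 1.0) for _ in z]
    z_new, r = list(z), list(r0)
    grad = numerical_grad(potential_fn, z_new)
    # Leapfrog: half momentum step, full position step, half momentum step.
    for _ in range(num_steps):
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad)]
        z_new = [zi + step_size * ri for zi, ri in zip(z_new, r)]
        grad = numerical_grad(potential_fn, z_new)
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, grad)]
    # Metropolis correction on the Hamiltonian H(z, r) = U(z) + |r|^2 / 2.
    h_old = potential_fn(z) + 0.5 * sum(ri * ri for ri in r0)
    h_new = potential_fn(z_new) + 0.5 * sum(ri * ri for ri in r)
    accept_prob = math.exp(min(0.0, h_old - h_new))
    return z_new if rng.random() < accept_prob else list(z)


# Usage: a standard 2-D normal given only by its potential; no model is involved.
def std_normal_potential(z: Vector) -> float:
    return 0.5 * sum(zi * zi for zi in z)


rng = random.Random(0)
z = [3.0, -3.0]
samples = []
for _ in range(2000):
    z = hmc_step(std_normal_potential, z, step_size=0.2, num_steps=10, rng=rng)
    samples.append(z)

# Discard warm-up draws; the remaining per-dimension means should be close to 0.
kept = samples[500:]
means = [sum(s[d] for s in kept) / len(kept) for d in range(2)]
```

Because the kernel only ever calls `potential_fn`, anything that can evaluate a log density (a Pyro model wrapped by a helper, a funsor, or a hand-written function) could in principle drive it.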

If an initial trace is provided, we should make sure that we never run the model to set up the initial state. This matters because we often want to sample from distributions that do not define a sample method, and we should still be able to run HMC on such models.
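A minimal sketch of the intended behaviour, assuming a hypothetical `init_state` helper (not an existing Pyro function) and a plain dict standing in for the initial trace: when initial values are supplied, only `potential_fn` is evaluated, so a model whose distribution lacks a `sample` method is never executed.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

Params = Dict[str, float]


@dataclass
class InitialState:
    z: Params
    potential_energy: float


def init_state(potential_fn: Callable[[Params], float],
               init_params: Optional[Params] = None,
               model: Optional[Callable[[], Params]] = None) -> InitialState:
    """Hypothetical initializer: the model runs only if no initial values are given."""
    if init_params is None:
        if model is None:
            raise ValueError("need either init_params or a model to run")
        # Fallback: execute the model once to draw initial values.
        init_params = model()
    # With init_params supplied, only potential_fn is evaluated; model is never called.
    return InitialState(z=dict(init_params), potential_energy=potential_fn(init_params))


def model_without_sampler() -> Params:
    # Stands in for a model whose distribution has a log density but no sample method.
    raise NotImplementedError("this distribution has no sample method")


def quartic_potential(params: Params) -> float:
    # Unnormalized density p(x) proportional to exp(-x**4), so U(x) = x**4.
    return params["x"] ** 4


state = init_state(quartic_potential, init_params={"x": 0.5}, model=model_without_sampler)
print(state.potential_energy)  # 0.0625
```

Calling `init_state(quartic_potential, model=model_without_sampler)` without `init_params` would take the fallback path and raise, which is exactly the case the proposal wants to avoid when a trace is already available.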

Related to the points above, it would be nice if, as in numpyro, given a potential_fn and an initial_sample we could generate subsequent samples without making any assumptions about the container data structure. Right now we assume that this container is a Pyro trace object, but that is only needed to interface with the TracePosterior class. Dropping this assumption would make it much simpler to integrate NUTS/HMC into other libraries. I think a middle ground might be to move the trace wrapping/unwrapping logic into the MCMC class, but this needs more discussion.
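One way to picture the middle ground: the kernel only sees a flat vector, and a thin outer layer (the MCMC class, under this proposal) owns the conversion to and from the user's container. The sketch below uses hypothetical `ravel` and `flat_potential` helpers on a dict of floats and lists, loosely in the spirit of `jax.flatten_util.ravel_pytree`; it is not an existing Pyro utility.

```python
from typing import Callable, Dict, List, Tuple, Union

Value = Union[float, List[float]]
Container = Dict[str, Value]


def ravel(params: Container) -> Tuple[List[float], Callable[[List[float]], Container]]:
    """Flatten a dict of scalars/lists into one list and return the inverse mapping."""
    layout = []  # (key, length) pairs; length None marks a scalar entry
    flat: List[float] = []
    for k in sorted(params):
        v = params[k]
        if isinstance(v, list):
            layout.append((k, len(v)))
            flat.extend(float(x) for x in v)
        else:
            layout.append((k, None))
            flat.append(float(v))

    def unravel(vec: List[float]) -> Container:
        out: Container = {}
        i = 0
        for k, size in layout:
            if size is None:
                out[k] = vec[i]
                i += 1
            else:
                out[k] = list(vec[i:i + size])
                i += size
        return out

    return flat, unravel


def flat_potential(potential_fn: Callable[[Container], float],
                   unravel: Callable[[List[float]], Container]) -> Callable[[List[float]], float]:
    """Adapt a container-level potential_fn into one over a flat vector for the kernel."""
    return lambda vec: potential_fn(unravel(vec))


# Usage: the user's potential works on named values; the kernel would see only `flat`.
def regression_potential(params: Container) -> float:
    return 0.5 * params["mu"] ** 2 + 0.5 * sum(b * b for b in params["beta"])


init = {"mu": 1.0, "beta": [2.0, -1.0]}
flat, unravel = ravel(init)
print(flat)                                               # [2.0, -1.0, 1.0]
print(unravel(flat) == init)                              # True
print(flat_potential(regression_potential, unravel)(flat))  # 3.0
```

With this split, only the outer layer needs to know whether the container is a Pyro trace, a dict, or something else; the HMC/NUTS kernel itself stays container-agnostic.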


I think following the numpyro approach is a good idea. It will also help in the future, once equivalents of lax.cond and lax.while_loop are available in PyTorch (disclaimer: I don't know the current state of the PyTorch JIT), because we could then JIT-compile the whole trajectory, as numpyro does, to improve speed.


Absolutely! Please feel free to create or mark off any of the sub-issues above that you are working on. I was planning to look at changes to the TracePosterior interface, so please feel free to take up the potential_fn refactor, which we also need for interfacing with funsors.