This class provides alternate vectorization semantics for
tfd.JointDistributionNamed, which in many cases eliminate the need to
explicitly account for batch shapes in the model specification.
Instead of simply summing the log_probs of component distributions
(which may have different shapes), it first reduces the component log_probs
to ensure that jd.log_prob(jd.sample()) always returns a scalar, unless
a nonzero batch_ndims is specified.

The essential changes are:

An event of JointDistributionNamedAutoBatched is the dictionary of
tensors produced by .sample(); thus, the event_shape is the
dictionary containing the shapes of sampled tensors. These combine both
the event and batch dimensions of the component distributions. By contrast,
the event shape of a base JointDistribution does not include batch
dimensions of component distributions.

The batch_shape is a global property of the entire model, rather
than a per-component property as in base JointDistributions.
The global batch shape must be a prefix of the batch shapes of
each component; the length of this prefix is specified by an optional
argument batch_ndims. If batch_ndims is not specified, the model has
batch shape [].

Notice the 1:1 correspondence between "math" and "code". In a standard
JointDistributionNamed, we would have wrapped the first variable as
`e = tfd.Independent(tfd.Exponential(rate=[100, 120]), reinterpreted_batch_ndims=1)`
to specify that the log_prob of the Exponential should be a scalar, summing
over both elements of its length-2 batch. This behavior is implicit in
JointDistributionNamedAutoBatched.

Args

model

Python dict, collections.OrderedDict, or namedtuple of
distribution-making functions each with required args corresponding
only to other keys.

validate_args

Python bool. Whether to validate input with asserts.
If validate_args is False, and the inputs are invalid,
correct behavior is not guaranteed.
Default value: False.

name

The name for ops managed by the distribution.
Default value: None (i.e., "JointDistributionNamed").

Attributes

allow_nan_stats

Python bool describing behavior when a stat is undefined.

Stats return +/- infinity when it makes sense. E.g., the variance of a
Cauchy distribution is infinity. However, sometimes the statistic is
undefined, e.g., if a distribution's pdf does not achieve a maximum within
the support of the distribution, the mode is undefined. If the mean is
undefined, then by definition the variance is undefined. E.g., the mean for
Student's T for df = 1 is undefined (no clear way to say it is either + or -
infinity), so the variance = E[(X - mean)**2] is also undefined.

batch_ndims

batch_shape

dtype

The DType of Tensors handled by this Distribution.

event_shape

model

name

Name prepended to all ops created by this Distribution.

parameters

Dictionary of parameters used to instantiate this Distribution.

reparameterization_type

Describes how samples from the distribution are reparameterized.

Currently this is one of the static instances
tfd.FULLY_REPARAMETERIZED or tfd.NOT_REPARAMETERIZED.

cross_entropy

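Denote this distribution (self) by P and the other distribution by
Q. Assuming P, Q are absolutely continuous with respect to
one another and permit densities p(x) dr(x) and q(x) dr(x), (Shannon)
cross entropy is defined as:

H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)

where F denotes the support of the random variable X ~ P.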
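As a numerical illustration of H[P, Q] = E_p[-log q(X)], take univariate Gaussians P = N(mu_p, sigma_p^2) and Q = N(mu_q, sigma_q^2) with Lebesgue measure as dr. Then the cross entropy has the closed form log(sigma_q * sqrt(2 pi)) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 sigma_q^2). The NumPy sketch below compares that closed form with a Monte Carlo estimate; the helper names are hypothetical and this is not how the library computes `cross_entropy`.

```python
import numpy as np


def normal_log_density(x, loc, scale):
    """Log density of N(loc, scale**2) evaluated at x (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def gaussian_cross_entropy(loc_p, scale_p, loc_q, scale_q):
    """Closed-form H[P, Q] = E_p[-log q(X)] for univariate Gaussians."""
    return (np.log(scale_q) + 0.5 * np.log(2.0 * np.pi)
            + (scale_p**2 + (loc_p - loc_q)**2) / (2.0 * scale_q**2))


rng = np.random.default_rng(0)
loc_p, scale_p, loc_q, scale_q = 0.0, 1.0, 1.0, 2.0

# Monte Carlo: draw X ~ P, then average -log q(X).
samples = rng.normal(loc_p, scale_p, size=200_000)
mc_estimate = -np.mean(normal_log_density(samples, loc_q, scale_q))
exact = gaussian_cross_entropy(loc_p, scale_p, loc_q, scale_q)
print(exact, mc_estimate)
```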

log_prob

The measure methods of `JointDistribution` (`log_prob`, `prob`, etc.)
can be called either by passing a single structure of tensors or by using
named args for each part of the joint distribution state. For example,

`JointDistribution` component distribution names are resolved via
`jd._flat_resolve_names()`, which is implemented by each `JointDistribution`
subclass (see subclass documentation for details). Generally, for components
where a name was provided (either explicitly as the `name` argument to a
distribution or as a key in a dict-valued JointDistribution, or implicitly,
e.g., by the argument name of a `JointDistributionSequential`
distribution-making function), the provided name will be used. Otherwise the
component will receive a dummy name; these may change without warning and
should not be relied upon.

Note: not all `JointDistribution` subclasses support all calling styles;
for example, `JointDistributionNamed` does not support positional arguments
(aka "unnamed arguments") unless the provided model specifies an ordering of
variables (i.e., is a `collections.OrderedDict` or `collections.namedtuple`
rather than a plain `dict`).

Note: care is taken to resolve any potential ambiguity. This is generally
possible by inspecting the structure of the provided argument and "aligning"
it to the joint distribution output structure (defined by `jd.dtype`). For
example,

Notice that in the first call, `[4.]` is interpreted as a list of one
scalar, while in the second call the input is a scalar. Hence both inputs
result in identical scalar outputs. If we wanted to pass an explicit
vector to the `Exponential` component, creating a vector-shaped batch
of `log_prob`s, we could instead write
`trivial_jd.log_prob(np.array([4]))`.

Args:
*args: Positional arguments: a `value` structure or component values
(see above).
**kwargs: Keyword arguments: a `value` structure or component values
(see above). May also include `name`, specifying a Python string name
for ops generated by this method.

Returns

log_prob

a Tensor of shape sample_shape(x) + self.batch_shape with
values of type self.dtype.

param_static_shapes

This is a class method that describes what key/value arguments are required
to instantiate the given Distribution so that a particular shape is
returned for that instance's call to sample(). Assumes that the sample's
shape is known statically.

prob

The measure methods of `JointDistribution` (`log_prob`, `prob`, etc.)
can be called either by passing a single structure of tensors or by using
named args for each part of the joint distribution state. For example,

`JointDistribution` component distribution names are resolved via
`jd._flat_resolve_names()`, which is implemented by each `JointDistribution`
subclass (see subclass documentation for details). Generally, for components
where a name was provided (either explicitly as the `name` argument to a
distribution or as a key in a dict-valued JointDistribution, or implicitly,
e.g., by the argument name of a `JointDistributionSequential`
distribution-making function), the provided name will be used. Otherwise the
component will receive a dummy name; these may change without warning and
should not be relied upon.

Note: not all `JointDistribution` subclasses support all calling styles;
for example, `JointDistributionNamed` does not support positional arguments
(aka "unnamed arguments") unless the provided model specifies an ordering of
variables (i.e., is a `collections.OrderedDict` or `collections.namedtuple`
rather than a plain `dict`).

Note: care is taken to resolve any potential ambiguity. This is generally
possible by inspecting the structure of the provided argument and "aligning"
it to the joint distribution output structure (defined by `jd.dtype`). For
example,

Notice that in the first call, `[4.]` is interpreted as a list of one
scalar, while in the second call the input is a scalar. Hence both inputs
result in identical scalar outputs. If we wanted to pass an explicit
vector to the `Exponential` component, creating a vector-shaped batch
of `prob`s, we could instead write
`trivial_jd.prob(np.array([4]))`.

Args:
*args: Positional arguments: a `value` structure or component values
(see above).
**kwargs: Keyword arguments: a `value` structure or component values
(see above). May also include `name`, specifying a Python string name
for ops generated by this method.

Returns

prob

a Tensor of shape sample_shape(x) + self.batch_shape with
values of type self.dtype.