pymc.sample_posterior_predictive

pymc.sample_posterior_predictive(trace, model=None, *, var_names=None, sample_vars=None, freeze_vars=None, sample_dims=None, random_seed=None, progressbar=True, progressbar_theme=<rich.theme.Theme object>, return_inferencedata=True, extend_inferencedata=False, predictions=False, idata_kwargs=None, backend=None, compile_kwargs=None)

Generate forward samples for var_names, conditioned on the posterior samples of variables found in the trace.

This method can be used to perform different kinds of model predictions, including posterior predictive checks.

Unobserved model variables are matched to posterior samples in the trace by variable name. Therefore, a different model than the one used for posterior sampling may be used for posterior predictive sampling, as long as the variables whose posterior we want to condition on have the same names and compatible shapes and coordinates.

Parameters:
trace : backend, list, xarray.Dataset, xarray.DataTree, or MultiTrace

Trace generated from MCMC sampling, a list of dicts (e.g. points, or the output of find_MAP()), or an xarray.Dataset (e.g. DataTree.posterior or DataTree.prior).

model : Model (optional if in with context)

Model to be used to generate the posterior predictive samples. It will generally be the model used to generate the trace, but it doesn’t need to be.

sample_vars : str or list of str, optional

Random variables or deterministics to regenerate on each draw instead of copying them from the trace. Volatility propagates downstream from these variables, but an RV that is in the trace and not listed here keeps its trace value. If such an RV has a volatile ancestor (a variable listed here, or a changed Data value or coordinate), an ImplicitFreezeWarning flags it so the user can opt in to resampling by adding it here, or silence the warning via freeze_vars. Empty by default; RVs missing from the trace (including observed RVs) are always regenerated automatically. Cannot overlap with freeze_vars.

freeze_vars : str or list of str, optional

Trace variables (RVs or deterministics) to reuse from the trace. Cannot overlap with sample_vars. Trace RVs not in sample_vars are already implicitly frozen, so the practical effect of listing an RV here is to silence its ImplicitFreezeWarning. Deterministics never trigger that warning: a volatile deterministic is simply recomputed from the current upstream values. Listing a deterministic here therefore only matters when you want to keep its trace value instead (see the example below).

var_names : str or list of str, optional

Controls only which variables appear in the output; does not trigger resampling. Each listed name is either computed fresh or copied from the input trace, depending on whether it or any of its ancestors is volatile (see the behavior section below). Defaults to sample_vars when that is specified; otherwise, following the classic posterior predictive default, to the observed variables plus any deterministics that depend on them.

sample_dims : list of str, optional

Dimensions over which to loop and generate posterior predictive samples. When sample_dims is None (default) both “chain” and “draw” are considered sample dimensions. Only taken into account when trace is xarray.DataTree or Dataset.

random_seed : int, RandomState or Generator, optional

Seed for the random number generator.

progressbar : bool, default True

Whether to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion (“expected time of arrival”; ETA).

return_inferencedata : bool, default True

Whether to return an xarray.DataTree object (True) or a dictionary (False).

extend_inferencedata : bool, default False

Whether to automatically use xarray.DataTree.update() to add the posterior predictive samples to trace. If True, trace is modified in place but still returned. If the DataTree already contains a group that would be added (e.g. posterior_predictive), a warning is issued and the existing group is overwritten.

predictions : bool, default False

Flag used to set the location of posterior predictive samples within the returned xarray.DataTree object. If False, the samples are assumed to be generated from the fitting data for posterior predictive checks, and are stored in the posterior_predictive group. If True, the samples are assumed to be generated from out-of-sample data as predictions, and are stored in the predictions group.

idata_kwargs : dict, optional

Keyword arguments for pymc.to_inference_data() if predictions=False or to pymc.predictions_to_inference_data() otherwise.

backend : str, optional

Which computational backend to use. Recommended values are “numba”, “c”, and “jax”.

compile_kwargs : dict, optional

Keyword arguments for pymc.pytensorf.compile(). compile_kwargs["mode"] cannot be combined with backend.

Returns:
xarray.DataTree or dict

An xarray.DataTree object containing the posterior predictive samples (default), or a dictionary with variable names as keys and samples as numpy arrays.

Examples

Posterior predictive checks and predictions

The most common use of sample_posterior_predictive is to perform posterior predictive checks (in-sample predictions) and new model predictions (out-of-sample predictions). Deterministics that depend on Data are recomputed automatically when the data changes, so no extra work is needed:

import pymc as pm

with pm.Model(coords={"trial": [0, 1, 2]}) as model:
    x = pm.Data("x", [-1, 0, 1], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    linpred = pm.Deterministic("linpred", x * beta, dims=["trial"])
    y = pm.Normal("y", mu=linpred, sigma=noise, observed=[-2, 0, 3], dims=["trial"])

    idata = pm.sample()

    # in-sample posterior predictive
    posterior_predictive = pm.sample_posterior_predictive(idata).posterior_predictive

with model:
    pm.set_data({"x": [-2, 2]}, coords={"trial": [3, 4]})
    # out-of-sample predictions. `linpred` is recomputed with the new `x`
    # (and the trace's `beta`); `y` is resampled from the new `linpred`.
    pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)

Freezing deterministics

A deterministic is normally recomputed whenever its inputs change. Occasionally, though, a deterministic captures something that should stay anchored to the training data — e.g. an HSGP standardization computed from pm.Data that must not be rederived from the prediction data. Pass the deterministic in freeze_vars to keep its trace value:

import pymc as pm

with pm.Model() as model:
    x = pm.Data("x", [1.0, 2.0, 3.0])
    x_mean = pm.Deterministic("x_mean", x.mean())
    centered = pm.Deterministic("centered", x - x_mean)
    mu = pm.Normal("mu")
    obs = pm.Normal("obs", mu + centered, 1, observed=[0, 0, 0])

    idata = pm.sample()

# New x values. Without freezing, `x_mean` would be recomputed as 200.0, the mean of the new data.
with model:
    pm.set_data({"x": [100.0, 200.0, 300.0]})
    pm.sample_posterior_predictive(idata, freeze_vars=["x_mean"])
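The effect on centered can be worked out by hand. The plain-Python sketch below (no PyMC involved) mirrors the model above: because x_mean depends only on Data, its trace value is 2.0 in every posterior draw, while centered is recomputed in both cases because x changed.

```python
from statistics import mean

x_train = [1.0, 2.0, 3.0]
x_new = [100.0, 200.0, 300.0]

# `x_mean` depends only on Data, so every posterior draw stores the same value.
x_mean_trace = mean(x_train)  # 2.0

# freeze_vars=["x_mean"]: keep the trace value, recompute `centered` with the new x.
centered_frozen = [xi - x_mean_trace for xi in x_new]

# Default behavior: `x_mean` is first recomputed from the new data.
centered_default = [xi - mean(x_new) for xi in x_new]

print(centered_frozen)   # [98.0, 198.0, 298.0]
print(centered_default)  # [-100.0, 0.0, 100.0]
```

Freezing keeps the new inputs on the training scale; without it, the new data is re-centered on its own mean.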

Forcing a deterministic to recompute

If pm.do() substitutes a new expression for a deterministic while every RV and Data value stays unchanged, sample_posterior_predictive finds nothing volatile and reuses the deterministic's values from the trace. List it in sample_vars to force recomputation from the current graph:

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    pm.Deterministic("det", x**2)
    pm.Normal("obs", model["det"], 1, observed=[0.0])
    idata = pm.sample()

with pm.do(model, {model["det"]: model["x"] ** 3}) as intervened_model:
    # Force recomputation using the new `x**3` graph.
    pm.sample_posterior_predictive(idata, sample_vars=["det", "obs"])

Using different models

It’s common to use the same model for posterior and posterior predictive sampling, but this is not required. The matching between unobserved model variables and posterior samples is based on the name alone.

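For the out-of-sample predictions in the first example, we could instead have created a separate predictions model. Since y has no observations in the new model, it is not part of the default output, so we request it via the sample_vars argument.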

import pymc as pm

with pm.Model(coords={"trial": [0, 1, 2]}) as train_model:
    x = pm.Data("x", [-1, 0, 1], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    y = pm.Normal("y", mu=x * beta, sigma=noise, observed=[-2, 0, 3], dims=["trial"])

    idata = pm.sample()

with pm.Model(coords={"trial": [3, 4]}) as prediction_model:
    x = pm.Data("x", [-2, 2], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    y = pm.Normal("y", mu=x * beta, sigma=noise, dims=["trial"])

    predictions = pm.sample_posterior_predictive(
        idata,
        sample_vars=["y"],
        predictions=True,
    )

The new model may even have a different structure and unobserved variables that don’t exist in the trace. These variables will be sampled automatically because they have no trace values to fall back on. In the following example we added a new extra_noise variable between the inferred posterior noise and the new StudentT observational distribution y:

with pm.Model(coords={"trial": [3, 4]}) as distinct_predictions_model:
    x = pm.Data("x", [-2, 2], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    extra_noise = pm.HalfNormal("extra_noise", sigma=noise)
    y = pm.StudentT("y", nu=4, mu=x * beta, sigma=extra_noise, dims=["trial"])

    predictions = pm.sample_posterior_predictive(idata, var_names=["y"], predictions=True)
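Because matching happens by name only, the split between reused and regenerated variables in distinct_predictions_model can be sketched with plain Python. The names trace_names and model_rvs are hypothetical stand-ins for the posterior group of idata and the model's random variables; they are not PyMC objects.

```python
# Hypothetical illustration of name-based matching; not PyMC code.
# Variables with posterior draws from `train_model` (`y` was observed, so it has none).
trace_names = {"beta", "noise"}

# Random variables of `distinct_predictions_model`, in definition order.
model_rvs = ["beta", "noise", "extra_noise", "y"]

# Names found in the trace are reused; the rest have no trace values and are sampled.
reused = [name for name in model_rvs if name in trace_names]
regenerated = [name for name in model_rvs if name not in trace_names]

print(reused)       # ['beta', 'noise']
print(regenerated)  # ['extra_noise', 'y']
```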

For more about out-of-model predictions, see this blog post.

The behavior of sample_vars, freeze_vars, and var_names

Each of these three arguments controls one aspect of the operation:

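  • sample_vars: trace variables to treat as volatile, i.e. regenerate them (from their distribution or expression) instead of copying them from the trace. Empty by default.

  • freeze_vars: trace variables to reuse explicitly. For an RV this silences the implicit-freeze warning described below; for a deterministic it keeps the trace value instead of recomputing it.

  • var_names: which variables appear in the output. Does not trigger resampling of variables in the trace. Defaults to sample_vars when that is specified, otherwise to the observed variables plus the deterministics that depend on them.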

Volatility. Volatility originates from three sources: variables listed in sample_vars, changed Data/coords, and RVs missing from the trace (including observed RVs, which are always regenerated because they have no trace value to reuse). It then propagates downstream through deterministics and other RVs.

An RV that is in the trace and not listed in sample_vars keeps its trace value, even when one of its ancestors is being resampled. This prevents a single sample_vars=["x"] call, or a set_data call, from silently invalidating the posterior values of every downstream variable. When such an implicitly frozen trace variable has a volatile ancestor, an ImplicitFreezeWarning flags it so the user can opt in by adding it to sample_vars (to resample it) or opt out by adding it to freeze_vars (to keep the trace value and silence the warning). The Sampling: [...] log message lists every RV resampled in a given call.
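The rule above can be sketched as a small pure-Python function over a toy dependency graph. Everything in it (Node, ancestors, resolve, and the returned dictionary) is hypothetical and written for illustration; it is not PyMC's implementation. It assumes the graph lists parents before children and that every deterministic is stored in the trace, and it only reports which RVs would be resampled, which deterministics recomputed, and which RVs would get an ImplicitFreezeWarning.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """A variable in a toy model graph (hypothetical, for illustration only)."""

    kind: str  # "rv", "det", or "data"
    parents: tuple = ()
    observed: bool = False


def ancestors(graph, names):
    """Return `names` together with every variable they depend on."""
    seen, stack = set(), list(names)
    while stack:
        name = stack.pop()
        if name not in seen:
            seen.add(name)
            stack.extend(graph[name].parents)
    return seen


def resolve(graph, trace, var_names=None, sample_vars=(), freeze_vars=(), changed_data=()):
    """Toy volatility rule. `graph` must list parents before children."""
    sample_vars, freeze_vars = set(sample_vars), set(freeze_vars)
    if sample_vars & freeze_vars:
        raise ValueError("sample_vars and freeze_vars cannot overlap")
    if var_names is None:
        if sample_vars:
            var_names = sorted(sample_vars)
        else:
            # Classic default: observed RVs plus deterministics that depend on them.
            observed = {n for n, node in graph.items() if node.observed}
            var_names = sorted(observed) + [
                n for n, node in graph.items()
                if node.kind == "det" and observed & ancestors(graph, node.parents)
            ]
    needed = ancestors(graph, var_names)  # only these variables are computed at all

    volatile, resampled, recomputed, warnings = set(), [], [], []
    for name, node in graph.items():
        if name not in needed:
            continue
        upstream_volatile = any(p in volatile for p in node.parents)
        if node.kind == "data":
            is_volatile = name in changed_data
        elif node.kind == "rv":
            # Listed in sample_vars, or no trace value to reuse (e.g. observed RVs).
            is_volatile = name in sample_vars or name not in trace
            if is_volatile:
                resampled.append(name)
            elif upstream_volatile and name not in freeze_vars:
                warnings.append(name)  # implicitly frozen: keeps its trace value
        else:  # deterministic, assumed to be stored in the trace
            is_volatile = name not in freeze_vars and (name in sample_vars or upstream_volatile)
            if is_volatile:
                recomputed.append(name)
        if is_volatile:
            volatile.add(name)
    return {"sampling": sorted(resampled), "recomputed": recomputed, "warnings": warnings}


# The model used in the examples below: z depends on x and y, obs on det.
graph = {
    "x": Node("rv"),
    "y": Node("rv"),
    "z": Node("rv", ("x", "y")),
    "det": Node("det", ("z",)),
    "obs": Node("rv", ("det",), observed=True),
}
trace = {"x", "y", "z", "det"}  # obs is observed, so the posterior has no draws of it

print(resolve(graph, trace)["sampling"])  # ['obs']
print(resolve(graph, trace, var_names=["y", "z", "det"], sample_vars=["y"]))
# {'sampling': ['y'], 'recomputed': [], 'warnings': ['z']}
```

Tracking volatility in a single pass in topological order is enough here because each variable's status depends only on its parents' status.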

The following examples use this model:

from logging import getLogger
import pymc as pm

# Some environments like google colab suppress
# the default logging output of PyMC
getLogger("pymc").setLevel("INFO")

kwargs = {"progressbar": False, "random_seed": 0}

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y")
    z = pm.Normal("z", x + y**2)
    det = pm.Deterministic("det", pm.math.exp(z))
    obs = pm.Normal("obs", det, 1, observed=[20])

    idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)

Default behavior: Generate samples of obs conditioned on the posterior samples of z found in the trace. These are often referred to as posterior predictive samples in the literature:

with model:
    pm.sample_posterior_predictive(idata, **kwargs)
    # Sampling: [obs]

Copy the trace values of z and det. Both are in the trace and nothing upstream of them is volatile, so nothing is resampled:

with model:
    pm.sample_posterior_predictive(idata, var_names=["z", "det"], **kwargs)
    # Sampling: []

Generate new samples of z and det, conditioned on the posterior samples of x and y found in the trace:

with model:
    pm.sample_posterior_predictive(idata, var_names=["z", "det"], sample_vars=["z"], **kwargs)
    # Sampling: [z]

Generate samples of y, z and det, conditioned on the posterior samples of x found in the trace.

Warning

The samples of y are equivalent to its prior, since it does not depend on any other variables.

In contrast, the samples of z and det depend on the new samples of y and the posterior samples of x found in the trace.

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y", "z"], **kwargs)
    # Sampling: [y, z]

Note that if z is not listed in sample_vars, it is not resampled even though it depends on the freshly drawn y: propagation stops at RVs that are in the trace. An ImplicitFreezeWarning flags this behavior for z:

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y"], **kwargs)
    # ImplicitFreezeWarning: 'z' (ancestor is resampled (y))
    # Sampling: [y]

If this is the intended behavior, add z to freeze_vars explicitly to avoid the warning:

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y"], freeze_vars=["z"], **kwargs)
    # Sampling: [y]

Passing every RV to sample_vars makes this equivalent to sample_prior_predictive(). Including obs in sample_vars is redundant: it is not in the trace, so it is always regenerated:

with model:
    pm.sample_posterior_predictive(
        idata,
        var_names=["x", "y", "z", "det", "obs"],
        sample_vars=["x", "y", "z", "obs"],
        **kwargs,
    )
    # Sampling: [obs, x, y, z]

Controlling the number of samples

You can manipulate the DataTree passed as trace to control the number of posterior predictive samples.

import pymc as pm

with pm.Model() as model:
    ...
    idata = pm.sample()

Generate 1 posterior predictive sample for every 5 posterior samples.

thinned_idata = idata.sel(draw=slice(None, None, 5))
with model:
    idata.update(pm.sample_posterior_predictive(thinned_idata))

Generate 5 posterior predictive samples for every posterior sample.

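expanded_idata = idata.copy()
# Replace the posterior group with one that has an extra "pred_id" dimension of length 5
expanded_idata["posterior"] = idata.posterior.to_dataset().expand_dims(pred_id=5)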
with model:
    pm.sample_posterior_predictive(
        expanded_idata,
        sample_dims=["chain", "draw", "pred_id"],
        extend_inferencedata=True,
    )
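The effect of these two manipulations on the number of posterior predictive draws is simple arithmetic. The helper below is hypothetical (not part of PyMC) and assumes a trace of 4 chains with 1000 draws each, whose draw coordinates run from 0 to 999 so that the label-based slice keeps every fifth draw.

```python
def n_predictive_draws(chains, draws, step=1, repeats=1):
    """Count posterior predictive draws after thinning and/or expanding (hypothetical helper)."""
    kept_per_chain = len(range(0, draws, step))  # draws kept by slice(None, None, step)
    return chains * kept_per_chain * repeats


# Assuming 4 chains of 1000 draws each:
print(n_predictive_draws(4, 1000))             # 4000 (one per posterior draw)
print(n_predictive_draws(4, 1000, step=5))     # 800 (thinned by 5)
print(n_predictive_draws(4, 1000, repeats=5))  # 20000 (pred_id of length 5)
```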