pymc.sample_posterior_predictive#
- pymc.sample_posterior_predictive(trace, model=None, *, var_names=None, sample_vars=None, freeze_vars=None, sample_dims=None, random_seed=None, progressbar=True, progressbar_theme=<rich.theme.Theme object>, return_inferencedata=True, extend_inferencedata=False, predictions=False, idata_kwargs=None, backend=None, compile_kwargs=None)[source]#
Generate forward samples for var_names, conditioned on the posterior samples of variables found in the trace.
This method can be used to perform different kinds of model predictions, including posterior predictive checks.
The matching of unobserved model variables and posterior samples in the trace is based on variable names. Therefore, a different model than the one used for posterior sampling may be used for posterior predictive sampling, as long as the variables whose posterior we want to condition on have the same names and compatible shapes and coordinates.
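For instance, the most basic workflow is to sample a posterior and then draw forward samples conditioned on it (a minimal sketch with a placeholder model; fuller examples appear in the Examples section below):

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    y = pm.Normal("y", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])
    idata = pm.sample()
    # Draw `y` for every posterior sample of `mu` and store the result
    # in the `posterior_predictive` group of `idata`.
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)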
- Parameters:
  - trace : backend, list, xarray.Dataset, xarray.DataTree, or MultiTrace
    Trace generated from MCMC sampling, or a list of dicts (e.g. points or from find_MAP()), or an xarray.Dataset (e.g. DataTree.posterior or DataTree.prior).
  - model : Model (optional if in with context)
    Model to be used to generate the posterior predictive samples. It will generally be the model used to generate the trace, but it doesn't need to be.
  - sample_vars : str or list of str, optional
    Random variables or deterministics to regenerate on each draw rather than copy from the trace. Regeneration propagates volatility downstream: an RV that is in the trace and not listed here keeps its trace value, but if one of its ancestors is volatile (listed here, or a changed Data/coord) an ImplicitFreezeWarning flags it so the user can opt in by adding it here, or silence the warning via freeze_vars. Empty by default; RVs missing from the trace (including observed RVs) are always regenerated automatically. Cannot overlap with freeze_vars.
  - freeze_vars : str or list of str, optional
    Trace variables (RVs or deterministics) to reuse from the trace. Cannot overlap with sample_vars. Trace RVs not in sample_vars are already implicitly frozen, so the practical effect of listing an RV here is to silence its ImplicitFreezeWarning. Deterministics don't trigger that warning at all (a volatile deterministic simply recomputes with the current upstream values), so listing one only matters when you want to keep the trace value instead (see example below).
  - var_names : str or list of str, optional
    Controls only which variables appear in the output; does not trigger resampling. Each listed name is either computed fresh or copied from the input trace, depending on whether it or any of its upstream variables is volatile (see the behavior section below). Defaults to sample_vars when that is specified; otherwise (the classic posterior-predictive default) to the observed variables plus any deterministic that depends on them.
  - sample_dims : list of str, optional
    Dimensions over which to loop and generate posterior predictive samples. When sample_dims is None (default) both "chain" and "draw" are considered sample dimensions. Only taken into account when trace is an xarray.DataTree or Dataset.
  - random_seed : int, RandomState or Generator, optional
    Seed for the random number generator.
  - progressbar : bool
    Whether to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion ("expected time of arrival"; ETA).
  - return_inferencedata : bool, default True
    Whether to return an xarray.DataTree (True) object or a dictionary (False).
  - extend_inferencedata : bool, default False
    Whether to automatically use xarray.DataTree.update() to add the posterior predictive samples to trace or not. If True, trace is modified inplace but still returned. If the DataTree already contains a group that would be added (e.g. posterior_predictive), a warning is issued and the existing group is overwritten.
  - predictions : bool, default False
    Flag used to set the location of posterior predictive samples within the returned xarray.DataTree object. If False, assumes samples are generated based on the fitting data to be used for posterior predictive checks, and samples are stored in the posterior_predictive group. If True, assumes samples are generated based on out-of-sample data as predictions, and samples are stored in the predictions group.
  - idata_kwargs : dict, optional
    Keyword arguments for pymc.to_inference_data() if predictions=False or to pymc.predictions_to_inference_data() otherwise.
  - backend : str, optional
    Which computational backend to use. Recommended to be one of "numba", "c", and "jax".
  - compile_kwargs : dict, optional
    Keyword arguments for pymc.pytensorf.compile(). compile_kwargs["mode"] cannot be combined with backend.
- Returns:
  - xarray.DataTree or Dict
    An xarray.DataTree object containing the posterior predictive samples (default), or a dictionary with variable names as keys and samples as numpy arrays.
Examples
Posterior predictive checks and predictions#
The most common use of sample_posterior_predictive is to perform posterior predictive checks (in-sample predictions) and new model predictions (out-of-sample predictions). Deterministics that depend on Data are recomputed automatically when the data changes; no extra work needed:

import pymc as pm

with pm.Model(coords={"trial": [0, 1, 2]}) as model:
    x = pm.Data("x", [-1, 0, 1], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    linpred = pm.Deterministic("linpred", x * beta, dims=["trial"])
    y = pm.Normal("y", mu=linpred, sigma=noise, observed=[-2, 0, 3], dims=["trial"])

    idata = pm.sample()
    # in-sample posterior predictive
    posterior_predictive = pm.sample_posterior_predictive(idata).posterior_predictive

with model:
    pm.set_data({"x": [-2, 2]}, coords={"trial": [3, 4]})
    # out-of-sample predictions. `linpred` is recomputed with the new `x`
    # (and the trace's `beta`); `y` is resampled from the new `linpred`.
    pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
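Because predictions=True and extend_inferencedata=True were passed, the out-of-sample draws are stored in the predictions group of the trace. A sketch of how you might inspect them, assuming DataTree dictionary-style group access:

# The out-of-sample draws were added to `idata` under the "predictions" group
# (the earlier in-sample draws were returned separately as `posterior_predictive`).
print(idata["predictions"])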
Freezing deterministics#
A deterministic is normally recomputed whenever its inputs change. Occasionally, though, a deterministic captures something that should stay anchored to the training data, e.g. an HSGP standardization computed from pm.Data that must not be rederived from the prediction data. Pass the deterministic in freeze_vars to keep its trace value:

import pymc as pm

with pm.Model() as model:
    x = pm.Data("x", [1.0, 2.0, 3.0])
    x_mean = pm.Deterministic("x_mean", x.mean())
    centered = pm.Deterministic("centered", x - x_mean)
    mu = pm.Normal("mu")
    obs = pm.Normal("obs", mu + centered, 1, observed=[0, 0, 0])

    idata = pm.sample()

# New x values. Without freezing, `x_mean` would be recomputed as the new mean.
with model:
    pm.set_data({"x": [100.0, 200.0, 300.0]})
    pm.sample_posterior_predictive(idata, freeze_vars=["x_mean"])
Forcing a deterministic to recompute#
If do() swaps a new expression into a deterministic while every RV and Data value stays unchanged, sample_posterior_predictive sees nothing volatile and reuses the deterministic from the trace. List it in sample_vars to force recomputation from the current graph:

with pm.Model() as model:
    x = pm.Normal("x")
    pm.Deterministic("det", x**2)
    pm.Normal("obs", model["det"], 1, observed=[0.0])

    idata = pm.sample()

with pm.do(model, {model["det"]: model["x"] ** 3}) as intervened_model:
    # Force recomputation using the new `x**3` graph.
    pm.sample_posterior_predictive(idata, sample_vars=["det", "obs"])
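For contrast, a sketch reusing intervened_model from above: without sample_vars nothing looks volatile, so the deterministic is copied straight from the trace and only the observed variable is regenerated.

with intervened_model:
    # `det` is reused from the trace, so the new x**3 expression is never evaluated;
    # `obs` has no trace values and is regenerated as usual.
    pm.sample_posterior_predictive(idata)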
Using different models#
It’s common to use the same model for posterior and posterior predictive sampling, but this is not required. The matching between unobserved model variables and posterior samples is based on the name alone.
For the last example we could have created a new predictions model. Since the new y has no observations, we request it via the sample_vars argument.

import pymc as pm

with pm.Model(coords={"trial": [0, 1, 2]}) as train_model:
    x = pm.Data("x", [-1, 0, 1], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    y = pm.Normal("y", mu=x * beta, sigma=noise, observed=[-2, 0, 3], dims=["trial"])

    idata = pm.sample()

with pm.Model(coords={"trial": [3, 4]}) as prediction_model:
    x = pm.Data("x", [-2, 2], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    y = pm.Normal("y", mu=x * beta, sigma=noise, dims=["trial"])

    predictions = pm.sample_posterior_predictive(
        idata,
        sample_vars=["y"],
        predictions=True,
    )
The new model may even have a different structure and unobserved variables that don't exist in the trace. These variables will be sampled automatically because they have no trace values to fall back on. In the following example we added a new extra_noise variable between the inferred posterior noise and the new StudentT observational distribution y:

with pm.Model(coords={"trial": [3, 4]}) as distinct_predictions_model:
    x = pm.Data("x", [-2, 2], dims=["trial"])
    beta = pm.Normal("beta")
    noise = pm.HalfNormal("noise")
    extra_noise = pm.HalfNormal("extra_noise", sigma=noise)
    y = pm.StudentT("y", nu=4, mu=x * beta, sigma=extra_noise, dims=["trial"])

    predictions = pm.sample_posterior_predictive(idata, var_names=["y"], predictions=True)
For more about out-of-model predictions, see this blog post.
The behavior of sample_vars, freeze_vars, and var_names#
Each of these three arguments controls one aspect of the operation:
- sample_vars: trace variables to treat as volatile, i.e. regenerate them (from their distribution or expression) instead of copying from the trace. Empty by default.
- freeze_vars: trace variables to reuse explicitly (silences the implicit-freeze warning below).
- var_names: which variables appear in the output. Does not trigger resampling of variables in the trace. Defaults to sample_vars.
Volatility. Volatility originates from three sources: variables listed in sample_vars, changed Data/coords, and RVs missing from the trace (including observed RVs, which are always regenerated since they have no trace value to reuse). It then propagates downstream through deterministics and other RVs. An RV that is in the trace and not listed in sample_vars keeps its trace value, even when one of its ancestors is being resampled. This prevents a single sample_vars=["x"] call, or a set_data call, from silently invalidating the posterior values for every downstream variable. When an auto-frozen trace variable has a volatile ancestor, an ImplicitFreezeWarning flags it so the user can opt in by adding it to sample_vars (to resample) or opt out by adding it to freeze_vars (to silence the warning while keeping the trace value). The log lists all the RVs being resampled in any given call.
The following examples use this model:

from logging import getLogger

import pymc as pm

# Some environments like google colab suppress
# the default logging output of PyMC
getLogger("pymc").setLevel("INFO")

kwargs = {"progressbar": False, "random_seed": 0}

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y")
    z = pm.Normal("z", x + y**2)
    det = pm.Deterministic("det", pm.math.exp(z))
    obs = pm.Normal("obs", det, 1, observed=[20])

    idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)
Default behavior: Generate samples of obs conditioned on the posterior samples of z found in the trace. These are often referred to as posterior predictive samples in the literature:

with model:
    pm.sample_posterior_predictive(idata, **kwargs)
# Sampling: [obs]
Copy the trace values for z and det. Nothing is resampled without explicit sample_vars:

with model:
    pm.sample_posterior_predictive(idata, var_names=["z", "det"], **kwargs)
# Sampling: []
Generate new samples of z and det, conditioned on the posterior samples of x and y found in the trace.
with model:
    pm.sample_posterior_predictive(idata, var_names=["z", "det"], sample_vars=["z"], **kwargs)
# Sampling: [z]
Generate samples of y, z and det, conditioned on the posterior samples of x found in the trace.
Warning
The samples of y are equivalent to its prior, since it does not depend on any other variables.
In contrast, the samples of z and det depend on the new samples of y and the posterior samples of x found in the trace.

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y", "z"], **kwargs)
# Sampling: [y, z]
Note that if z is not placed in sample_vars it won't be resampled even though it depends on the freshly drawn y; the cascade stops at RVs that are in the trace. A warning flags this behavior for z:

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y"], **kwargs)
# ImplicitFreezeWarning: 'z' (ancestor is resampled (y))
# Sampling: [y]
If this is the intended behavior, z can be added to freeze_vars explicitly, and the warning is avoided.

with model:
    pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], sample_vars=["y"], freeze_vars=["z"], **kwargs)
# Sampling: [y]
Passing every RV to sample_vars makes this equivalent to sample_prior_predictive(). Including obs in sample_vars is redundant; it isn't in the trace so it is always regenerated:

with model:
    pm.sample_posterior_predictive(
        idata,
        var_names=["x", "y", "z", "det", "obs"],
        sample_vars=["x", "y", "z", "obs"],
        **kwargs,
    )
# Sampling: [obs, x, y, z]
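For comparison, fully prior-based forward draws could also be obtained directly with sample_prior_predictive. A sketch, where the draws count is chosen here to roughly match the 2 chains x 10 draws of the example trace:

with model:
    # Draw every free RV and the observed RV from the prior,
    # without conditioning on anything stored in `idata`.
    prior_idata = pm.sample_prior_predictive(draws=20, random_seed=0)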
Controlling the number of samples#
You can manipulate the DataTree to control the number of samples.

import pymc as pm

with pm.Model() as model:
    ...
    idata = pm.sample()
Generate 1 posterior predictive sample for every 5 posterior samples.
thinned_idata = idata.sel(draw=slice(None, None, 5))
with model:
    idata.update(pm.sample_posterior_predictive(thinned_idata))
Generate 5 posterior predictive samples for every posterior sample.
expanded_idata = idata.copy()
expanded_idata.posterior = idata.posterior.expand_dims(pred_id=5)
with model:
    pm.sample_posterior_predictive(
        expanded_idata,
        sample_dims=["chain", "draw", "pred_id"],
        extend_inferencedata=True,
    )