The Line Fitting Case

Normal Distribution
t Distribution
Mixture Model
Outlier

This notebook is part of a set of examples for teaching Bayesian inference methods and probabilistic programming, using the numpyro library.

Imports
import io

import arviz as az
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer
import pandas as pd

from arviz.labels import MapLabeller
from IPython.display import display
Constants
FIGSIZE = (6.4, 4.8)
TEXTSIZE = 10

LABELLER = MapLabeller({"alpha": r"$\alpha$", "beta": r"$\beta$"})
RNG Setup
np_rng = np.random.default_rng(1)
rng_key = jax.random.PRNGKey(1)

In this example, we will fit a line to a set of data points using three separate models that handle outliers in different ways, and then compare the models.

Data

Data
csv = """
x,y,y_err
-1.991,-1.265,0.2
-1.942,-1.718,0.2
-1.866,-0.539,0.2
-1.451,-1.468,0.2
-1.383,-0.04,0.2
-0.947,-1.069,0.2
-0.865,-0.935,0.2
0.135,-0.006,0.2
0.424,0.219,0.2
0.96,1.366,0.2
1.411,-0.392,0.2
1.603,1.777,0.2
1.675,1.857,0.2
1.777,1.986,0.2
1.828,1.59,0.2
"""

data = pd.read_csv(io.StringIO(csv))
Figure
def plot_data(scatter_kwargs=None, ax=None):
    if ax is None:
        _, ax = plt.subplots(figsize=FIGSIZE)

    scatter_kwargs = {"color": "black", **(scatter_kwargs or {})}
    data.plot.scatter(x="x", y="y", yerr="y_err", **scatter_kwargs, ax=ax)

    return ax


plot_data()
plt.show()

Figure 1: A set of 15 data points that have some linear correlation.
x y y_err
0 -1.991 -1.265 0.2
1 -1.942 -1.718 0.2
2 -1.866 -0.539 0.2
3 -1.451 -1.468 0.2
4 -1.383 -0.040 0.2
5 -0.947 -1.069 0.2
6 -0.865 -0.935 0.2
7 0.135 -0.006 0.2
8 0.424 0.219 0.2
9 0.960 1.366 0.2
10 1.411 -0.392 0.2
11 1.603 1.777 0.2
12 1.675 1.857 0.2
13 1.777 1.986 0.2
14 1.828 1.590 0.2

Models

Model A

Let’s start with a minimal linear model that does not take outliers into account, parametrized by a slope \(\beta\) and an intercept \(\alpha\):

\[y(x_i)=\alpha+\beta\,x_i\]

To be able to choose convenient, minimally informative priors, let’s reparametrize this model in terms of the angle \(\theta\) between the line and the \(x\)-axis instead of the slope, and the perpendicular intercept \(\alpha_p\) instead of the \(y\)-axis intercept:

\[\theta=\arctan(\beta)\] \[\alpha_p=\alpha\cos(\theta)\]
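
A quick change-of-variables check shows why this helps: a prior that is uniform in the slope \(\beta\) would put most of its mass on steep lines, whereas a uniform prior on \(\theta\) over \((-\pi/2,\pi/2)\) induces

\[p(\beta)=p(\theta)\left|\frac{d\theta}{d\beta}\right|=\frac{1}{\pi}\,\frac{1}{1+\beta^2},\]

which treats all line orientations equally. (A similar argument motivates putting the uniform prior on the perpendicular intercept \(\alpha_p\) rather than directly on \(\alpha\).)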

def model_A(df):
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)

    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))

    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))

    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)
        numpyro.sample("y_obs", dist.Normal(y, y_err), obs=y_meas)
Figure
numpyro.render_model(
    model=model_A,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)

Figure 2: The graph of model A.

Model B

To make the model more robust against outliers, let’s use the \(t\) distribution as the sampling distribution instead of the normal distribution. The \(t\) distribution has heavier tails, whose weight is controlled by an additional free parameter \(\nu\) (the degrees of freedom): the smaller \(\nu\), the heavier the tails.
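
As a quick standalone check (reusing the imports from the top of the notebook; not part of the model itself), we can compare how much density the two distributions place a few scale units away from their location:

normal = dist.Normal(0.0, 1.0)
student = dist.StudentT(df=1.0, loc=0.0, scale=1.0)

# Density four scale units from the location: the t distribution with a
# small nu keeps far more probability in its tails than the normal does.
print(jnp.exp(normal.log_prob(4.0)))   # ~1.3e-4
print(jnp.exp(student.log_prob(4.0)))  # ~1.9e-2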

def model_B(df):
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)

    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    nu = numpyro.sample("nu", dist.InverseGamma(1, 1))

    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))

    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)
        numpyro.sample("y_obs", dist.StudentT(nu, y, y_err), obs=y_meas)
Figure
numpyro.render_model(
    model=model_B,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)

Figure 3: The graph of model B.

Model C

Let’s also extend model A into a two-component mixture model: we keep model A as the “foreground” component and add an explicit “background” component for outliers, modeled as a broad normal distribution.
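
Written out, each observation \(y_i\) with measurement error \(\sigma_i\) then comes from the background with probability \(q\) and from the straight line with probability \(1-q\):

\[p(y_i)=q\,\mathcal{N}\!\left(y_i\mid\mu_\mathrm{bg},\sqrt{\sigma_i^2+\sigma_\mathrm{bg}^2}\right)+(1-q)\,\mathcal{N}\!\left(y_i\mid\alpha+\beta x_i,\sigma_i\right)\]

where \(\mu_\mathrm{bg}\) and \(\sigma_\mathrm{bg}\) are the mean and width of the background component (bg_mean and bg_sigma in the code below), and the second argument of each normal is its standard deviation.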

def model_C(df):
    x_meas = numpyro.param("x_meas", df["x"].values)
    y_meas = numpyro.param("y_meas", df["y"].values)
    y_err = numpyro.param("y_err", df["y_err"].values)

    alpha_p = numpyro.sample("alpha_p", dist.Uniform(-0.5, 0.5))
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))

    alpha = numpyro.deterministic("alpha", alpha_p / jnp.cos(theta))
    beta = numpyro.deterministic("beta", jnp.tan(theta))

    bg_mean = numpyro.sample("bg_mean", dist.Normal(0.0, 1.0))
    bg_sigma = numpyro.sample("bg_sigma", dist.HalfNormal(3.0))

    q = numpyro.sample("q", dist.Uniform(0, 1))
    mixing_dist = dist.Categorical(probs=jnp.array([q, 1 - q]))

    with numpyro.plate("data", len(df)):
        y = numpyro.deterministic("y", alpha + beta * x_meas)

        fg_dist = dist.Normal(y, y_err)
        bg_dist = dist.Normal(bg_mean, jnp.sqrt(y_err**2 + bg_sigma**2))
        mixture = dist.MixtureGeneral(mixing_dist, [bg_dist, fg_dist])

        y_obs = numpyro.sample("y_obs", mixture, obs=y_meas)

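        # component_log_probs returns, per data point, the log mixing weight
        # plus the component log density; normalizing by their logsumexp gives
        # the log posterior probability that the point belongs to component 0,
        # i.e. the background (outlier) component.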
        log_probs = mixture.component_log_probs(y_obs)
        logsumexp = jax.nn.logsumexp(log_probs, axis=-1, keepdims=True)
        numpyro.deterministic("log_p", log_probs[:, 0] - logsumexp[:, 0])
Figure
numpyro.render_model(
    model=model_C,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)

Figure 4: The graph of model C.

Sampling

def sample_model(key, model):
    sampler = numpyro.infer.MCMC(
        sampler=numpyro.infer.NUTS(model),
        num_warmup=2000,
        num_samples=10000,
        num_chains=4,
        progress_bar=False,
    )

    key, _key = jax.random.split(key)
    sampler.run(_key, data)

    return key, sampler, az.from_numpyro(sampler)

Model A

rng_key, sampler_A, inf_data_A = sample_model(rng_key, model_A)
Figure
def plot_trace_A(ax=None):
    return az.plot_trace(
        data=inf_data_A,
        figsize=(FIGSIZE[0], 2 * 2),
        var_names=["~alpha", "~beta", "~y"],
        axes=ax,
    )


plot_trace_A()
plt.tight_layout()
plt.show()

Figure 5: The MCMC trace plot for model A.
Table
az.summary(inf_data_A, kind="all", var_names=["~y"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 0.123 0.051 0.028 0.220 0.0 0.0 38063.0 27611.0 1.0
alpha_p 0.098 0.041 0.021 0.175 0.0 0.0 38019.0 26810.0 1.0
beta 0.757 0.036 0.690 0.824 0.0 0.0 38251.0 26859.0 1.0
theta 0.648 0.023 0.605 0.691 0.0 0.0 38251.0 26859.0 1.0

Model B

rng_key, sampler_B, inf_data_B = sample_model(rng_key, model_B)
Figure
def plot_trace_B(ax=None):
    return az.plot_trace(
        data=inf_data_B,
        figsize=(FIGSIZE[0], 3 * 2),
        var_names=["~alpha", "~beta", "~y"],
        axes=ax,
    )


plot_trace_B()
plt.tight_layout()
plt.show()

Figure 6: The MCMC trace plot for model B.
Table
az.summary(inf_data_B, kind="all", var_names=["~y"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 0.053 0.110 -0.148 0.255 0.001 0.001 20228.0 14070.0 1.0
alpha_p 0.038 0.081 -0.110 0.184 0.001 0.001 19871.0 13157.0 1.0
beta 0.983 0.074 0.845 1.119 0.001 0.000 20588.0 19928.0 1.0
nu 1.138 0.448 0.438 1.949 0.003 0.002 26560.0 26607.0 1.0
theta 0.775 0.038 0.704 0.844 0.000 0.000 20588.0 19928.0 1.0

Model C

rng_key, sampler_C, inf_data_C = sample_model(rng_key, model_C)
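# Convert the per-point log membership probabilities from model C into
# plain probabilities of being an outlier.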
inf_data_C.posterior["p"] = np.exp(inf_data_C.posterior["log_p"])
Figure
def plot_trace_C(ax=None):
    return az.plot_trace(
        data=inf_data_C,
        figsize=(FIGSIZE[0], 5 * 2),
        var_names=["~alpha", "~beta", "~y", "~log_p", "~p"],
        axes=ax,
    )


plot_trace_C()
plt.tight_layout()
plt.show()

Figure 7: The MCMC trace plot for model C.
Table
az.summary(inf_data_C, kind="all", var_names=["~y", "~log_p", "~p"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 0.052 0.077 -0.095 0.195 0.000 0.000 28809.0 26061.0 1.0
alpha_p 0.037 0.055 -0.066 0.139 0.000 0.000 28310.0 25799.0 1.0
beta 1.007 0.057 0.902 1.112 0.000 0.000 24814.0 22155.0 1.0
bg_mean -0.403 0.415 -1.186 0.419 0.003 0.002 22564.0 17542.0 1.0
bg_sigma 0.785 0.539 0.001 1.708 0.004 0.003 18153.0 10511.0 1.0
q 0.332 0.126 0.103 0.563 0.001 0.000 33729.0 24843.0 1.0
theta 0.788 0.029 0.735 0.839 0.000 0.000 24814.0 22155.0 1.0

Results

Figure
def plot_pair(*items, ax=None):
    for i, inf_data in enumerate(items):
        ax = az.plot_pair(
            data=inf_data,
            figsize=FIGSIZE,
            var_names=["alpha", "beta"],
            labeller=LABELLER,
            kind=["scatter", "kde"],
            scatter_kwargs={"color": f"C{i}"},
            marginals=True,
            marginal_kwargs={"color": f"C{i}"},
            point_estimate="mean",
            reference_values={
                LABELLER.var_name_to_str("alpha"): 0,
                LABELLER.var_name_to_str("beta"): 1,
            },
            textsize=TEXTSIZE,
            ax=ax,
        )

    return ax


plot_pair(inf_data_A, inf_data_B, inf_data_C)
plt.show()

Figure 8: The joint posterior distribution of \(\alpha\) and \(\beta\) for models A (blue), B (orange), and C (green).
p_outlier = az.extract(inf_data_C, var_names=["p"]).mean("sample")
Figure
def plot_samples(*items, p=None, ax=None):
    marker_style = {"marker": "o", "s": 36, "zorder": 2}
    ax = plot_data(scatter_kwargs=marker_style, ax=ax)

    x_min, x_max = -2.8, 2.8
    x = np.linspace(x_min, x_max, 50)
    for i, inf_data in enumerate(items):
        samples = az.extract(
            data=inf_data, var_names=["alpha", "beta"], num_samples=100
        ).to_dataframe()
        for _, d in samples.iterrows():
            y = d["alpha"] + d["beta"] * x
            ax.plot(x, y, c=f"C{i}", alpha=0.1, zorder=1)

    if p is not None:
        marker_style = {**marker_style, "zorder": 3}
        ax.scatter(data["x"], data["y"], c=p, cmap="gray", ec="black", **marker_style)

    ax.set_xlim(x_min, x_max)
    return ax


plot_samples(inf_data_A, inf_data_B, inf_data_C, p=p_outlier)
plt.show()

Figure 9: A few realizations of models A (blue), B (orange), and C (green), shown together with the original data points. The gray shading of each data point indicates its estimated probability of being an outlier under model C (lighter points are more likely outliers).
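
Finally, let's rank the three models by their expected log pointwise predictive density (ELPD), estimated with leave-one-out cross-validation (LOO); the weights below come from the Bayesian-bootstrap pseudo-BMA method requested in az.compare.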
comp_data = az.compare(
    compare_dict={"A": inf_data_A, "B": inf_data_B, "C": inf_data_C},
    ic="loo",
    method="BB-pseudo-BMA",
    seed=np_rng,
)
rank elpd_loo p_loo elpd_diff weight se dse warning scale
C 0 -14.548210 5.000393 0.000000 8.824918e-01 32.993495 0.000000 True log
B 1 -18.289472 4.520744 3.741261 1.175078e-01 5.073979 2.878680 False log
A 2 -72.965036 19.576818 58.416826 3.312725e-07 3.815517 32.273156 True log
Figure
def plot_compare(ax=None):
    ax = az.plot_compare(
        comp_df=comp_data,
        figsize=(FIGSIZE[0], 0.5 * len(comp_data)),
        plot_ic_diff=False,
        legend=False,
        title=False,
        textsize=TEXTSIZE,
        ax=ax,
    )

    ax.set_xlabel("ELPD (LOO)")
    ax.set_ylabel("Model")
    return ax


plot_compare()
plt.show()

Figure 10: The ranking of models A, B, and C based on their ELPD.

Resources

Watermark

Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.20.0

jax       : 0.4.23
arviz     : 0.17.0
matplotlib: 3.8.2
numpyro   : 0.13.2
numpy     : 1.26.3
pandas    : 2.1.4