Skip to content

Plotting

Parameter Plots

Multi-panel plot of fitted model parameters.

Panels shown (only for non-None arrays):

- α_x (ax) vs ages
- β_x^(i) vs ages for each period term i
- κ_t^(i) vs years for each period term i
- β_x^(0) (b0x) vs ages (if present)
- γ_c (gc) vs cohorts (if present)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `fit` | `FitStMoMo` | A fitted FitStMoMo object. | *required* |
| `fig` | `Figure \| None` | Optional existing Figure to plot into. If None, a new Figure is created. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | The matplotlib Figure containing the plot. |
Source code in src/pystmomo/plot/parameters.py
def plot_parameters(fit: FitStMoMo, *, fig: Figure | None = None) -> Figure:
    """Multi-panel plot of fitted model parameters.

    Panels shown (only for non-None arrays):
    - α_x (ax) vs ages
    - β_x^(i) vs ages for each period term i
    - κ_t^(i) vs years for each period term i
    - β_x^(0) (b0x) vs ages (if present)
    - γ_c (gc) vs cohorts (if present)

    Panels are laid out on a grid with at most three columns; grid cells
    left over after the last panel are hidden.

    Parameters
    ----------
    fit:
        A fitted FitStMoMo object.  ``fit.bx`` may be 1-D (one period term,
        one value per age) or 2-D with shape ``(n_ages, N)``; ``fit.kt`` may
        be 1-D (one value per year) or 2-D with shape ``(N, n_years)``.
    fig:
        Optional existing Figure to plot into.  If None, a new Figure is created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the panels.  When every parameter array is None,
        the figure is returned without any axes.
    """
    # Each panel is (title, x-axis label, xvals, yvals)
    panels: list[tuple[str, str, np.ndarray, np.ndarray]] = []

    if fit.ax is not None:
        panels.append(("$\\alpha_x$", "Age", fit.ages, fit.ax))

    if fit.bx is not None:
        # Normalise to (n_ages, N).  np.atleast_2d must NOT be used here: it
        # turns a 1-D (n_ages,) vector into a (1, n_ages) row, which would
        # produce one single-point panel per age instead of one β_x panel.
        bx = np.asarray(fit.bx)
        if bx.ndim == 1:
            bx = bx[:, np.newaxis]
        n_bx_terms = bx.shape[1]
        for i in range(n_bx_terms):
            label = f"$\\beta_x^{{({i+1})}}$" if n_bx_terms > 1 else "$\\beta_x$"
            panels.append((label, "Age", fit.ages, bx[:, i]))

    if fit.kt is not None:
        # Normalise to (N, n_years).  Here atleast_2d is the right tool: a
        # 1-D vector becomes a single row, i.e. the one-term layout.
        kt = np.atleast_2d(np.asarray(fit.kt))
        n_kt_terms = kt.shape[0]
        for i in range(n_kt_terms):
            label = f"$\\kappa_t^{{({i+1})}}$" if n_kt_terms > 1 else "$\\kappa_t$"
            panels.append((label, "Year", fit.years, kt[i, :]))

    if fit.b0x is not None:
        # Plain string (not an f-string), so use single braces for the group.
        panels.append(("$\\beta_x^{(0)}$", "Age", fit.ages, fit.b0x))

    if fit.gc is not None:
        panels.append(("$\\gamma_c$", "Cohort", fit.cohorts, fit.gc))

    n_panels = len(panels)
    if n_panels == 0:
        # Nothing to plot: hand back an empty figure rather than a 0-column grid.
        if fig is None:
            fig = Figure(figsize=(6, 4))
        return fig

    # Layout: up to 3 columns
    ncols = min(n_panels, 3)
    nrows = (n_panels + ncols - 1) // ncols

    if fig is None:
        fig = Figure(figsize=(5 * ncols, 4 * nrows))

    axes = fig.subplots(nrows, ncols, squeeze=False)

    for idx, (label, xlabel, xvals, yvals) in enumerate(panels):
        ax = axes[idx // ncols][idx % ncols]
        ax.plot(xvals, yvals)
        ax.set_title(label)
        ax.set_xlabel(xlabel)

    # Hide unused axes in the last row
    for idx in range(n_panels, nrows * ncols):
        axes[idx // ncols][idx % ncols].set_visible(False)

    fig.tight_layout()
    return fig

Forecast Fan Charts

Plot κ_t forecasts with confidence bands.

Expects fc to expose:

- fc.kt_central: shape (N, h) central forecast of period indexes
- fc.kt_lower: shape (N, h) lower confidence bound
- fc.kt_upper: shape (N, h) upper confidence bound
- fc.years: forecast year labels, length h
- fc.fit: original FitStMoMo (for historical kt and years)

Falls back gracefully if confidence bands are absent.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `fc` | | A ForStMoMo forecast object. | *required* |
| `ages` | `list[int] \| None` | Ignored (present for API symmetry with rate-based plots). | `None` |
| `fig` | `Figure \| None` | Optional existing Figure; if None a new one is created. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | The matplotlib Figure containing the plot. |
Source code in src/pystmomo/plot/forecast_plot.py
def plot_forecast(
    fc,
    *,
    ages: list[int] | None = None,
    fig: Figure | None = None,
) -> Figure:
    """Plot κ_t forecasts with confidence bands.

    One panel is drawn per period index, with at most three panels per row.
    Each panel shows the historical κ_t from the original fit (when
    available), the central forecast and, when available, the confidence
    band.

    Expects ``fc`` to expose:
    - ``fc.kt_central``: shape (N, h) central forecast of period indexes
    - ``fc.kt_lower``: shape (N, h) lower confidence bound
    - ``fc.kt_upper``: shape (N, h) upper confidence bound
    - ``fc.years``: forecast year labels, length h
    - ``fc.fit``: original FitStMoMo (for historical kt and years)

    Falls back gracefully if confidence bands are absent:

    - The band is drawn only when both ``kt_lower`` and ``kt_upper`` exist
      and are not None.
    - The historical series is drawn only when ``fc.fit`` and ``fc.fit.kt``
      are present, and only for period indexes the fit actually has.

    Parameters
    ----------
    fc:
        A ForStMoMo forecast object.
    ages:
        Ignored (present for API symmetry with rate-based plots).
    fig:
        Optional existing Figure; if None a new one is created.

    Returns
    -------
    matplotlib.figure.Figure
        The figure with one panel per period index (no axes when N == 0).
    """
    # ``ages`` is deliberately unused; it only mirrors the rate-plot signature.
    kt_central = np.atleast_2d(np.asarray(fc.kt_central))  # (N, h)
    n_terms = kt_central.shape[0]
    years_fc = np.asarray(fc.years)

    # A band needs both bounds; a half-specified interval cannot be filled.
    kt_lower = getattr(fc, "kt_lower", None)
    kt_upper = getattr(fc, "kt_upper", None)
    has_bands = kt_lower is not None and kt_upper is not None
    if has_bands:
        kt_lower = np.atleast_2d(np.asarray(kt_lower))
        kt_upper = np.atleast_2d(np.asarray(kt_upper))

    # Historical kt from the original fit (optional).  A fit without period
    # terms has kt None; atleast_2d(None) would yield a bogus object array.
    base_fit = getattr(fc, "fit", None)
    kt_hist = None
    years_hist = None
    if base_fit is not None and getattr(base_fit, "kt", None) is not None:
        kt_hist = np.atleast_2d(np.asarray(base_fit.kt))  # (N, n_hist)
        years_hist = base_fit.years

    if n_terms == 0:
        # Nothing to draw; avoid building a zero-column grid.
        if fig is None:
            fig = Figure(figsize=(6, 4))
        return fig

    ncols = min(n_terms, 3)
    nrows = (n_terms + ncols - 1) // ncols
    if fig is None:
        fig = Figure(figsize=(5 * ncols, 4 * nrows))

    axes = fig.subplots(nrows, ncols, squeeze=False)

    for i in range(n_terms):
        ax = axes[i // ncols][i % ncols]
        label = f"$\\kappa_t^{{({i+1})}}$" if n_terms > 1 else "$\\kappa_t$"

        # Historical (skip rows the fit does not have)
        if kt_hist is not None and i < kt_hist.shape[0]:
            ax.plot(years_hist, kt_hist[i], label="Historical")

        # Forecast central
        ax.plot(years_fc, kt_central[i], label="Forecast")

        # Confidence bands
        if has_bands:
            ax.fill_between(
                years_fc,
                kt_lower[i],
                kt_upper[i],
                alpha=0.2,
                label="CI",
            )

        ax.set_title(label)
        ax.set_xlabel("Year")
        ax.legend()

    # Hide unused subplots
    for i in range(n_terms, nrows * ncols):
        axes[i // ncols][i % ncols].set_visible(False)

    fig.tight_layout()
    return fig

Fan chart of simulated mortality rates at a given age.

Expects sim to expose:

- sim.rates: shape (n_ages, h, nsim) array of simulated rates
- sim.ages: age labels, length n_ages
- sim.years: forecast year labels, length h
- sim.fit: original FitStMoMo (optional, for historical rates)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `sim` | | A SimStMoMo simulation object. | *required* |
| `age` | `int` | The age at which to draw the fan chart. | *required* |
| `levels` | `tuple[float, ...]` | Quantile coverage levels for fan bands (e.g. 0.95 → 2.5th–97.5th percentiles). | `(0.5, 0.8, 0.95)` |
| `fig` | `Figure \| None` | Optional existing Figure; if None a new one is created. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | The matplotlib Figure containing the plot. |
Source code in src/pystmomo/plot/forecast_plot.py
def plot_fan(
    sim,
    age: int,
    *,
    levels: tuple[float, ...] = (0.5, 0.8, 0.95),
    fig: Figure | None = None,
) -> Figure:
    """Fan chart of simulated mortality rates at a given age.

    Draws the median of the simulated rates across simulations.  For each
    coverage level it then shades the central band between the matching
    lower and upper percentiles.  All bands share the median line's colour,
    so overlapping bands get darker towards the centre.  When ``sim.fit`` is
    available and contains ``age``, its fitted rates are overlaid.

    Expects ``sim`` to expose:
    - ``sim.rates``: shape (n_ages, h, nsim) array of simulated rates
    - ``sim.ages``: age labels, length n_ages
    - ``sim.years``: forecast year labels, length h
    - ``sim.fit``: original FitStMoMo (optional, for historical rates)

    Parameters
    ----------
    sim:
        A SimStMoMo simulation object.
    age:
        The age at which to draw the fan chart.
    levels:
        Quantile coverage levels for fan bands (e.g. 0.95 → 2.5th–97.5th pct).
        Each level must lie in [0, 1].
    fig:
        Optional existing Figure; if None a new one is created.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``age`` is not in ``sim.ages``, or a level lies outside [0, 1].
    """
    ages_arr = np.asarray(sim.ages)
    idx = np.flatnonzero(ages_arr == age)
    if idx.size == 0:
        raise ValueError(f"Age {age} not found in simulation ages.")
    age_idx = int(idx[0])

    # Validate before drawing anything so a bad level cannot leave a
    # half-populated figure behind.
    sorted_levels = sorted(levels)
    for level in sorted_levels:
        if not 0.0 <= level <= 1.0:
            raise ValueError(f"Fan levels must lie in [0, 1]; got {level!r}.")

    # sim.rates shape: (n_ages, h, nsim)
    rates_sim = np.asarray(sim.rates)[age_idx, :, :]  # (h, nsim)
    years_fc = np.asarray(sim.years)

    if fig is None:
        fig = Figure(figsize=(9, 5))

    ax = fig.add_subplot(111)

    # Central (median)
    median = np.median(rates_sim, axis=1)
    (median_line,) = ax.plot(years_fc, median, label="Median")

    # Fan bands: a single colour (the median's) with translucent overlap,
    # instead of letting each band pull a new colour from the cycle.
    band_color = median_line.get_color()
    for level in sorted_levels:
        lo_pct = 100.0 * (1.0 - level) / 2.0
        hi_pct = 100.0 - lo_pct
        lo, hi = np.percentile(rates_sim, [lo_pct, hi_pct], axis=1)
        # ":g" avoids int() truncation (0.57 * 100 == 56.999...) and keeps
        # fractional levels such as 97.5.
        ax.fill_between(years_fc, lo, hi, color=band_color, alpha=0.15,
                        label=f"{level * 100:g}% CI")

    # Historical rates (optional)
    base_fit = getattr(sim, "fit", None)
    if base_fit is not None:
        hist_idx = np.flatnonzero(np.asarray(base_fit.ages) == age)
        if hist_idx.size > 0:
            hist_rates = base_fit.fitted_rates[int(hist_idx[0]), :]
            ax.plot(base_fit.years, hist_rates, color="#64748b", label="Historical (fitted)")

    ax.set_xlabel("Year")
    ax.set_ylabel("Mortality rate")
    ax.set_title(f"Fan chart — age {age}")
    ax.legend()

    fig.tight_layout()
    return fig

Residual Plots

Heatmap of residuals on the ages × years grid.

Uses matplotlib imshow with the RdBu_r colormap, centred at zero.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `fit` | `FitStMoMo` | A fitted FitStMoMo object. | *required* |
| `kind` | `str` | Type of residuals: `"deviance"` (default), `"pearson"`, or `"response"`. | `'deviance'` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | The matplotlib Figure containing the plot. |
Source code in src/pystmomo/plot/residual_plot.py
def plot_residual_heatmap(fit: FitStMoMo, kind: str = "deviance") -> Figure:
    """Heatmap of residuals on the ages × years grid.

    Uses matplotlib ``imshow`` with the RdBu_r colormap, centred at zero.

    - The colour scale is symmetric, ``[-vmax, vmax]``.  ``vmax`` is the
      largest finite absolute residual, or 1.0 when there is none or it is
      zero.
    - Each cell is centred on its year/age label: the image extent is padded
      by half a grid step on every side.  This assumes evenly spaced years
      and ages; a single year or age gets a unit-wide cell.

    Parameters
    ----------
    fit:
        A fitted FitStMoMo object.
    kind:
        Type of residuals: ``"deviance"`` (default), ``"pearson"``, or
        ``"response"``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    res = _get_residuals(fit, kind)

    fig = Figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # Ignore NaN/inf when scaling; an infinite vmax would wash out the map.
    abs_res = np.abs(np.asarray(res, dtype=float))
    finite = abs_res[np.isfinite(abs_res)]
    vmax = float(finite.max()) if finite.size else 0.0
    vmax = vmax if vmax > 0 else 1.0

    # imshow's extent gives the outer pixel EDGES, not the cell centres, so
    # pad by half a step so each row/column sits on its age/year label.
    years = np.asarray(fit.years, dtype=float)
    ages = np.asarray(fit.ages, dtype=float)
    half_year = (years[-1] - years[0]) / (2 * (years.size - 1)) if years.size > 1 else 0.5
    half_age = (ages[-1] - ages[0]) / (2 * (ages.size - 1)) if ages.size > 1 else 0.5

    im = ax.imshow(
        res,
        aspect="auto",
        cmap="RdBu_r",
        vmin=-vmax,
        vmax=vmax,
        origin="upper",
        extent=[
            years[0] - half_year,
            years[-1] + half_year,
            ages[-1] + half_age,
            ages[0] - half_age,
        ],
    )
    fig.colorbar(im, ax=ax, label=f"{kind.capitalize()} residual")
    ax.set_xlabel("Year")
    ax.set_ylabel("Age")
    ax.set_title(f"{kind.capitalize()} residuals — ages × years")

    fig.tight_layout()
    return fig

Residuals vs fitted log-rates scatter plot.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `fit` | `FitStMoMo` | A fitted FitStMoMo object. | *required* |
| `kind` | `str` | Type of residuals: `"deviance"` (default), `"pearson"`, or `"response"`. | `'deviance'` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | The matplotlib Figure containing the plot. |
Source code in src/pystmomo/plot/residual_plot.py
def plot_residual_scatter(fit: FitStMoMo, kind: str = "deviance") -> Figure:
    """Scatter residuals against the log of the fitted mortality rates.

    Only cells with positive weight (``fit.wxt > 0``) and a strictly
    positive fitted rate are shown; for the others the log-rate is
    undefined or the cell was excluded from fitting.  A dashed horizontal
    line marks zero.

    Parameters
    ----------
    fit:
        A fitted FitStMoMo object.
    kind:
        Type of residuals: ``"deviance"`` (default), ``"pearson"``, or
        ``"response"``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    residual_grid = _get_residuals(fit, kind)
    fitted = fit.fitted_rates

    keep = np.logical_and(fit.wxt > 0, fitted > 0)
    x_log_rate = np.log(fitted[keep])
    y_resid = residual_grid[keep]

    pretty_kind = kind.capitalize()

    fig = Figure(figsize=(8, 5))
    ax = fig.add_subplot(111)

    ax.scatter(x_log_rate, y_resid, s=12, alpha=0.4)
    ax.axhline(0.0, color="#dc2626", linewidth=1.2, linestyle="--")
    ax.set_xlabel("Fitted log-rate")
    ax.set_ylabel(f"{pretty_kind} residual")
    ax.set_title(f"{pretty_kind} residuals vs fitted log-rates")

    fig.tight_layout()
    return fig