Skip to content

Diagnostics

Residuals

Signed sqrt of pointwise deviance contribution, shape (n_ages, n_years).

For Poisson (link='log'):
- D > 0: sign(D - D̂) * sqrt(2*(D*log(D/D̂) - (D - D̂))) * sqrt(wxt)
- D = 0: -sqrt(2*D̂) * sqrt(wxt) [sign = -1, contribution = 2*D̂]

For Binomial (link='logit'): sign(D - E*q̂) * sqrt(2*(D*log(D/(E*q̂)) + (E-D)*log((E-D)/(E*(1-q̂))))) * sqrt(wxt)

Cells with wxt=0 are set to 0.

Note: wxt is binary 0/1, so sqrt(wxt) == wxt, but we use sqrt(wxt) so that squared residuals sum to the weighted deviance.

Source code in src/pystmomo/diagnostics/residuals.py
def deviance_residuals(fit: FitStMoMo) -> np.ndarray:
    """Return signed deviance residuals, one per (age, year) cell.

    Each residual is the signed square root of the cell's contribution to
    the deviance, scaled by ``sqrt(wxt)``.  The result has the same shape
    as ``fit.Dxt``, i.e. (n_ages, n_years).

    Poisson (``fit.model.link == 'log'``):
      * D > 0:  sign(D - D̂) * sqrt(2*(D*log(D/D̂) - (D - D̂))) * sqrt(wxt)
      * D <= 0: -sqrt(2*D̂) * sqrt(wxt)   (contribution 2*D̂, sign -1)

    Binomial (link='logit'; every link other than 'log' takes this path):
      sign(D - E*q̂) * sqrt(2*(D*log(D/(E*q̂))
                              + (E-D)*log((E-D)/(E*(1-q̂))))) * sqrt(wxt)

    Cells whose weight is not strictly positive are left at 0.

    The weights are binary 0/1, so ``sqrt(wxt) == wxt``; the square root is
    used anyway so that the squared residuals add up to the weighted
    deviance.
    """
    deaths = fit.Dxt
    fitted_deaths = fit.fitted_deaths
    exposure = fit.Ext
    weights = fit.wxt
    is_poisson = fit.model.link == "log"

    resid = np.zeros_like(deaths, dtype=float)
    root_w = np.sqrt(np.maximum(weights, 0.0))
    active = weights > 0

    if not is_poisson:
        # Binomial: fitted probabilities are kept strictly inside (0, 1) so
        # that neither logarithm below sees a zero denominator.
        q_fit = np.clip(fit.fitted_rates, _EPS, 1.0 - _EPS)
        expected_all = exposure * q_fit  # fitted deaths E*q̂

        obs = deaths[active]
        expo = exposure[active]
        expected = expected_all[active]
        survivors = np.maximum(expo - obs, 0.0)
        expected_survivors = expo * (1.0 - q_fit[active])

        # Both ratios are floored at _EPS so log() stays finite; the
        # np.where masks then zero the terms whose multiplier is 0
        # (D = 0 or D = E), which is their limiting value.
        death_ratio = np.maximum(obs / np.maximum(expected, _EPS), _EPS)
        survivor_ratio = np.maximum(
            survivors / np.maximum(expected_survivors, _EPS), _EPS
        )
        death_part = np.where(obs > 0, obs * np.log(death_ratio), 0.0)
        survivor_part = np.where(
            survivors > 0, survivors * np.log(survivor_ratio), 0.0
        )

        # Round-off can push the contribution slightly below 0; clamp
        # before taking the square root.
        magnitude = np.sqrt(np.maximum(2.0 * (death_part + survivor_part), 0.0))
        resid[active] = np.sign(obs - expected) * magnitude * root_w[active]
        return resid

    # Poisson, cells with observed deaths.
    with_deaths = active & (deaths > 0)
    obs = deaths[with_deaths]
    expected = np.maximum(fitted_deaths[with_deaths], _EPS)
    deviance = 2.0 * (obs * np.log(obs / expected) - (obs - expected))
    magnitude = np.sqrt(np.maximum(deviance, 0.0))
    resid[with_deaths] = np.sign(obs - expected) * magnitude * root_w[with_deaths]

    # Poisson, cells without deaths: D*log(D/D̂) vanishes, leaving 2*D̂
    # with a negative sign.
    without_deaths = active & (deaths <= 0)
    expected_zero = np.maximum(fitted_deaths[without_deaths], 0.0)
    resid[without_deaths] = -np.sqrt(2.0 * expected_zero) * root_w[without_deaths]
    return resid

Pearson residuals, shape (n_ages, n_years).

For Poisson: (D - D̂) / sqrt(D̂) * sqrt(wxt)
For Binomial: (D - E*q̂) / sqrt(E*q̂*(1-q̂)) * sqrt(wxt)

Cells with wxt=0 or zero variance are set to 0.

Source code in src/pystmomo/diagnostics/residuals.py
def pearson_residuals(fit: FitStMoMo) -> np.ndarray:
    """Pearson residuals, shape (n_ages, n_years).

    For Poisson (``fit.model.link == 'log'``):
        (D - D̂) / sqrt(D̂) * sqrt(wxt)
    For Binomial (any other link):
        (D - E*q̂) / sqrt(E*q̂*(1-q̂)) * sqrt(wxt)

    Cells with wxt=0 or zero variance are set to 0.  "Zero variance" means
    the fitted variance (D̂ for Poisson, E*q̂*(1-q̂) for Binomial) is not
    strictly positive; such cells carry no information and would otherwise
    produce an enormous residual from dividing by the ``_EPS`` floor.

    Parameters
    ----------
    fit:
        Fitted model object; this function reads ``Dxt``, ``fitted_deaths``,
        ``Ext``, ``wxt``, ``model.link`` and, on the Binomial path,
        ``fitted_rates``.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``fit.Dxt``.
    """
    D = fit.Dxt
    Dhat = fit.fitted_deaths
    E = fit.Ext
    wxt = fit.wxt
    link = fit.model.link

    out = np.zeros_like(D, dtype=float)
    sqrt_w = np.sqrt(np.maximum(wxt, 0.0))
    pos = wxt > 0

    if link == "log":
        # Poisson: mean and variance are both D̂.
        expected = Dhat[pos]
        var = expected
    else:
        # Binomial: mean E*q̂, variance E*q̂*(1-q̂); q̂ is kept strictly
        # inside (0, 1) so the variance only vanishes when E does.
        qhat = np.clip(fit.fitted_rates[pos], _EPS, 1.0 - _EPS)
        expected = E[pos] * qhat
        var = expected * (1.0 - qhat)

    # The _EPS floor keeps the division finite for tiny positive variances.
    resid = (D[pos] - expected) / np.sqrt(np.maximum(var, _EPS)) * sqrt_w[pos]

    # Honour the documented contract: zero-variance cells are 0.  The test
    # is written as ``var <= 0`` so NaN variances are *not* masked and keep
    # propagating as NaN instead of being silently hidden.
    out[pos] = np.where(var <= 0, 0.0, resid)

    return out

Raw (response) residuals, shape (n_ages, n_years).

(D - D̂) * wxt

Source code in src/pystmomo/diagnostics/residuals.py
def response_residuals(fit: FitStMoMo) -> np.ndarray:
    """Return weighted raw (response) residuals, shape (n_ages, n_years).

    The residual of a cell is the observed minus the fitted death count,
    multiplied by the cell weight: ``(D - D̂) * wxt``.
    """
    observed_minus_fitted = fit.Dxt - fit.fitted_deaths
    weighted = observed_minus_fitted * fit.wxt
    return weighted

Cross-Validation

Period-based leave-last-out cross-validation.

Holds out the last n_test_years = n_years // n_folds consecutive years, refits the model on the remaining years, forecasts h = n_test_years steps, and compares predicted rates to observed rates.

Parameters:

Name Type Description Default
fit FitStMoMo

A fitted FitStMoMo object.

required
n_folds int

Number of folds. Determines the test window size as n_test_years = n_years // n_folds.

5
metric Literal['mse', 'log_mse']

Primary metric to return in 'metric_value': "mse" (mean squared error on rates) or "log_mse" (MSE on log rates).

'mse'

Returns:

Type Description
dict with keys:
  • 'mse': MSE between observed and predicted rates on the test set.
  • 'log_mse': MSE on log scale.
  • 'years_test': array of held-out year labels.
  • 'rates_obs': observed rates, shape (n_ages, n_test_years).
  • 'rates_pred': predicted rates, shape (n_ages, n_test_years).
  • 'metric_value': value of the requested metric.
Source code in src/pystmomo/diagnostics/crossval.py
def cv_stmomo(
    fit: FitStMoMo,
    n_folds: int = 5,
    metric: Literal["mse", "log_mse"] = "mse",
) -> dict:
    """Period-based leave-last-out cross-validation.

    Holds out the last ``n_test_years = n_years // n_folds`` consecutive years,
    refits the model on the remaining years, forecasts ``h = n_test_years`` steps,
    and compares predicted rates to observed rates.

    Only this single hold-out split is evaluated; ``n_folds`` controls the
    size of the test window, not the number of refits.

    Parameters
    ----------
    fit:
        A fitted FitStMoMo object.
    n_folds:
        Number of folds.  Determines the test window size as
        ``n_test_years = n_years // n_folds``.  Must be at least 2 so that
        some years remain for training.
    metric:
        Primary metric to return in 'metric_value': ``"mse"`` (mean squared
        error on rates) or ``"log_mse"`` (MSE on log rates).

    Returns
    -------
    dict with keys:
        * ``'mse'``: MSE between observed and predicted rates on the test set.
        * ``'log_mse'``: MSE on log scale.
        * ``'years_test'``: array of held-out year labels.
        * ``'rates_obs'``: observed rates, shape (n_ages, n_test_years).
        * ``'rates_pred'``: predicted rates, shape (n_ages, n_test_years).
        * ``'metric_value'``: value of the requested metric.

        Both metrics are averaged only over cells where the observed and the
        predicted rate are strictly positive, and are NaN when no such cell
        exists.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names, if ``n_folds`` is
        smaller than 1, or if the resulting split leaves no test years or
        no training years.
    AttributeError
        If the forecast object exposes neither ``rates_central`` nor
        ``rates``.
    """
    # Validate everything up front so a bad argument fails before the
    # (potentially expensive) refit rather than after it.
    if metric not in ("mse", "log_mse"):
        raise ValueError(f"metric must be 'mse' or 'log_mse', got {metric!r}")
    if n_folds < 1:
        raise ValueError(f"n_folds must be a positive integer, got {n_folds}")

    n_years = len(fit.years)
    n_test = n_years // n_folds
    if n_test < 1:
        raise ValueError(
            f"n_folds={n_folds} too large for n_years={n_years}: "
            f"n_test_years = n_years // n_folds = {n_test} < 1"
        )

    n_train = n_years - n_test
    if n_train < 1:
        raise ValueError(
            f"n_folds={n_folds} leaves no training years for n_years={n_years}: "
            f"n_test_years = {n_test}; use n_folds >= 2"
        )

    years_train = fit.years[:n_train]
    years_test = fit.years[n_train:]

    Dxt_train = fit.Dxt[:, :n_train]
    Ext_train = fit.Ext[:, :n_train]
    wxt_train = fit.wxt[:, :n_train]

    # Refit on training data
    fit_train = fit.model.fit(
        Dxt_train,
        Ext_train,
        ages=fit.ages,
        years=years_train,
        wxt=wxt_train,
    )

    # Forecast h steps ahead
    fc = fit_train.forecast(h=n_test)

    # Predicted rates: central forecast, shape (n_ages, n_test).
    # The forecast object is accessed by duck typing: prefer
    # 'rates_central', fall back to 'rates'.
    if hasattr(fc, "rates_central"):
        rates_pred = np.asarray(fc.rates_central)
    elif hasattr(fc, "rates"):
        rates_pred = np.asarray(fc.rates)
    else:
        raise AttributeError(
            "ForStMoMo object has neither 'rates_central' nor 'rates' attribute."
        )

    # Observed crude rates from held-out years; cells without exposure get
    # rate 0 (the floor on the denominator only silences the division).
    Ext_test = fit.Ext[:, n_train:]
    Dxt_test = fit.Dxt[:, n_train:]
    eps = 1e-15
    rates_obs = np.where(Ext_test > 0, Dxt_test / np.maximum(Ext_test, eps), 0.0)

    # Compute metrics (only on cells where both are positive, so that the
    # log-scale metric is defined on exactly the same cells as the MSE).
    valid = (rates_obs > 0) & (rates_pred > 0)
    if valid.any():
        obs_valid = rates_obs[valid]
        pred_valid = rates_pred[valid]
        mse = float(np.mean((obs_valid - pred_valid) ** 2))
        log_mse = float(np.mean((np.log(obs_valid) - np.log(pred_valid)) ** 2))
    else:
        mse = float("nan")
        log_mse = float("nan")

    metric_value = mse if metric == "mse" else log_mse

    return {
        "mse": mse,
        "log_mse": log_mse,
        "years_test": years_test,
        "rates_obs": rates_obs,
        "rates_pred": rates_pred,
        "metric_value": metric_value,
    }