Skip to content

simace.plotting

Publication-quality plots for phenotype distributions, correlations, and validation.

plot_style

simace.plotting.plot_style

Central visual style for Nature Genetics-inspired plots.

apply_nature_style

apply_nature_style()

Set matplotlib rcParams for Nature Genetics-inspired figures.

Source code in simace/plotting/plot_style.py
def apply_nature_style() -> None:
    """Install the Nature Genetics-inspired defaults into matplotlib.

    The settings are written into the global ``mpl.rcParams`` mapping, so
    they affect every figure created after this call.
    """
    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica", "Arial", "DejaVu Sans"],
    }
    # Open frame: only the left and bottom spines are kept, no grid.
    axes_frame = {
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.linewidth": 0.7,
        "axes.grid": False,
        "axes.facecolor": "white",
        "axes.labelsize": 11,
        "axes.titlesize": 12,
    }
    line_defaults = {
        "lines.linewidth": 1.2,
        "lines.markersize": 5,
    }
    tick_defaults = {
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.width": 0.7,
        "ytick.major.width": 0.7,
        "xtick.major.size": 4,
        "ytick.major.size": 4,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
    legend_defaults = {
        "legend.frameon": False,
        "legend.fontsize": 9,
    }
    figure_defaults = {
        "figure.facecolor": "white",
        "figure.dpi": 100,
        "savefig.dpi": 150,
        "savefig.bbox": "tight",
    }

    mpl.rcParams.update(
        {
            **typography,
            **axes_frame,
            **line_defaults,
            **tick_defaults,
            **legend_defaults,
            **figure_defaults,
        }
    )

enable_value_gridlines

enable_value_gridlines(ax)

Add faint horizontal gridlines for plots where absolute values matter.

Source code in simace/plotting/plot_style.py
def enable_value_gridlines(ax) -> None:
    """Switch on faint horizontal gridlines and keep them beneath the data.

    Args:
        ax: Axes-like object exposing ``yaxis.grid`` and ``set_axisbelow``.
    """
    faint_grid = {"linewidth": 0.5, "alpha": 0.15, "color": "0.7", "zorder": 0}
    ax.yaxis.grid(True, **faint_grid)
    # Axis artists (ticks, grid) are requested below the plotted data.
    ax.set_axisbelow(True)

add_scenario_label

add_scenario_label(fig, scenario)

Add small grey italic scenario label at the bottom-right of a figure.

Source code in simace/plotting/plot_style.py
def add_scenario_label(fig, scenario: str) -> None:
    """Stamp the scenario name in grey italics at the figure's bottom-right.

    Args:
        fig: Figure-like object exposing ``text`` and ``transFigure``.
        scenario: Label text; a falsy value (e.g. ``""``) draws nothing.
    """
    if not scenario:
        return

    label_style = {
        "fontsize": 11,
        "color": "0.5",
        "ha": "right",
        "va": "bottom",
        "fontstyle": "italic",
        # Coordinates below are figure fractions, not data coordinates.
        "transform": fig.transFigure,
    }
    fig.text(0.99, 0.005, scenario, **label_style)

plot_utils

simace.plotting.plot_utils

Shared plotting utilities for simace.

param_as_float

param_as_float(val, default=0.0)

Convert a scalar or per-generation dict param to a single float.

For per-gen dicts, returns the value at the lowest key (founder generation).

Source code in simace/plotting/plot_utils.py
def param_as_float(val: float | dict | None, default: float = 0.0) -> float:
    """Convert a scalar or per-generation dict param to a single float.

    For per-gen dicts, returns the value at the lowest key (founder generation).
    """
    if val is None:
        return default
    if isinstance(val, dict):
        return float(val[min(val)])
    return float(val)

save_placeholder_plot

save_placeholder_plot(output_path, message, figsize=(6, 4), dpi=150)

Save a single-panel figure with centered message text.

Source code in simace/plotting/plot_utils.py
def save_placeholder_plot(
    output_path: Any, message: str, figsize: tuple[float, float] = (6, 4), dpi: int = 150
) -> None:
    """Save a single-panel figure with centered message text.

    Args:
        output_path: Destination passed straight to ``Figure.savefig``.
        message: Text drawn at the centre of the axes.
        figsize: Figure size in inches as ``(width, height)``.
        dpi: Resolution used when writing the file.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        # Save the figure created here rather than pyplot's implicit
        # "current figure", so the call does not depend on global state.
        fig.savefig(output_path, dpi=dpi)
    finally:
        # Always release the figure, even if writing the file fails.
        plt.close(fig)

annotate_heatmap

annotate_heatmap(ax, proportions, counts, fmt_prop='.2f', prop_size=18, count_size=11)

Add two-line annotations to a heatmap: large bold proportion, smaller count.

PARAMETER DESCRIPTION
ax

Matplotlib axes containing the heatmap.

TYPE: Axes

proportions

2-D array-like of proportion values.

TYPE: ndarray

counts

2-D array-like of count values (int or float).

TYPE: ndarray

fmt_prop

Format spec for proportion values.

TYPE: str DEFAULT: '.2f'

prop_size

Font size for the proportion line.

TYPE: int DEFAULT: 18

count_size

Font size for the count line.

TYPE: int DEFAULT: 11

Source code in simace/plotting/plot_utils.py
def annotate_heatmap(
    ax: plt.Axes,
    proportions: np.ndarray,
    counts: np.ndarray,
    fmt_prop: str = ".2f",
    prop_size: int = 18,
    count_size: int = 11,
) -> None:
    """Add two-line annotations to a heatmap: large bold proportion, smaller count.

    Each cell ``(i, j)`` gets the proportion slightly above the cell centre
    and an ``n=<count>`` line slightly below it, both in white.

    Args:
        ax: Matplotlib axes containing the heatmap.
        proportions: 2-D array-like of proportion values.
        counts: 2-D array-like of count values (int or float). Non-finite
            counts are rendered as ``n=nan`` / ``n=inf`` instead of raising.
        fmt_prop: Format spec for proportion values.
        prop_size: Font size for the proportion line.
        count_size: Font size for the count line.
    """
    proportions = np.asarray(proportions)
    counts = np.asarray(counts)
    n_rows, n_cols = proportions.shape[0], proportions.shape[1]
    for i, j in np.ndindex(n_rows, n_cols):
        p = proportions[i, j]
        c = counts[i, j]
        # ``int()`` raises on NaN/inf, so only take the exact-integer path for
        # finite values; everything else goes through float formatting.
        if np.isfinite(c) and float(c) == int(c):
            c_str = f"n={int(c)}"
        else:
            c_str = f"n={c:.0f}"
        # Cell centres are at (j + 0.5, i + 0.5); the two lines straddle it.
        ax.text(
            j + 0.5,
            i + 0.38,
            f"{p:{fmt_prop}}",
            ha="center",
            va="center",
            fontsize=prop_size,
            fontweight="bold",
            color="white",
        )
        ax.text(j + 0.5, i + 0.62, c_str, ha="center", va="center", fontsize=count_size, color=(1, 1, 1, 0.7))

finalize_plot

finalize_plot(output_path, dpi=150, tight_rect=None, subsample_note='', scenario='')

tight_layout + savefig(bbox_inches='tight') + close current figure.

Source code in simace/plotting/plot_utils.py
def finalize_plot(
    output_path: Any,
    dpi: int = 150,
    tight_rect: list[float] | None = None,
    subsample_note: str = "",
    scenario: str = "",
) -> None:
    """tight_layout + savefig(bbox_inches='tight') + close current figure.

    Operates on pyplot's current figure. Optional footer texts are added at
    the bottom-right before the layout is tightened.

    Args:
        output_path: Destination passed straight to ``Figure.savefig``.
        dpi: Resolution used when writing the file.
        tight_rect: Optional ``rect`` forwarded to ``tight_layout``; ``None``
            uses matplotlib's default.
        subsample_note: Small italic note drawn at the bottom-right; skipped
            when empty.
        scenario: Scenario name drawn via ``add_scenario_label``; skipped
            when empty.
    """
    import warnings

    import matplotlib.pyplot as plt

    from simace.plotting.plot_style import add_scenario_label

    fig = plt.gcf()
    try:
        if scenario:
            add_scenario_label(fig, scenario)
        if subsample_note:
            note_y = 0.015
            if scenario:
                # The scenario label is 11 pt text bottom-anchored at y=0.005
                # (see add_scenario_label). Drawing the note at y=0.015 would
                # print it on top of that label, so lift it by roughly one
                # label line-height, converted from points to figure fraction.
                note_y = 0.005 + (1.3 * 11 / 72.0) / fig.get_figheight()
            fig.text(
                0.99,
                note_y,
                subsample_note,
                fontsize=8,
                color="0.5",
                ha="right",
                va="bottom",
                fontstyle="italic",
                transform=fig.transFigure,
            )
        with warnings.catch_warnings():
            # Some axes layouts are flagged as incompatible with tight_layout;
            # the result is still acceptable, so silence only that warning.
            warnings.filterwarnings("ignore", message=".*tight_layout.*", category=UserWarning)
            fig.tight_layout(rect=tight_rect)
        fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Close the figure even when layout or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

draw_split_violin

draw_split_violin(ax, data_left, data_right, pos, color_left=None, color_right=None, width=0.8)

Draw a split violin at pos (left half / right half).

Replicates seaborn's violinplot(split=True, cut=0) using raw matplotlib, which is significantly faster for large arrays.

Source code in simace/plotting/plot_utils.py
def draw_split_violin(
    ax,
    data_left,
    data_right,
    pos,
    color_left=None,
    color_right=None,
    width=0.8,
):
    """Draw a two-sided violin at *pos*, one dataset per half.

    A raw-matplotlib stand-in for seaborn's ``violinplot(split=True, cut=0)``:
    each dataset is drawn as a full violin whose outline is then clipped to
    its own side of *pos*. A side whose data is ``None`` or has fewer than
    two values is skipped. Each drawn half also gets a Q1-Q3 bar and a white
    median dot.

    Args:
        ax: Axes to draw on.
        data_left: Values for the left half.
        data_right: Values for the right half.
        pos: X position of the violin's centre line.
        color_left: Fill for the left half; defaults to ``COLOR_UNAFFECTED``.
        color_right: Fill for the right half; defaults to ``COLOR_AFFECTED``.
        width: Full violin width passed to ``ax.violinplot``.
    """
    from simace.plotting.plot_style import COLOR_AFFECTED, COLOR_UNAFFECTED

    fill_left = COLOR_UNAFFECTED if color_left is None else color_left
    fill_right = COLOR_AFFECTED if color_right is None else color_right

    # Per half: samples, fill colour, x-range the outline is clipped to, and
    # the x position of the inner box (nudged slightly into that half).
    halves = (
        (data_left, fill_left, (-np.inf, pos), pos - width * 0.06),
        (data_right, fill_right, (pos, np.inf), pos + width * 0.06),
    )

    for samples, fill, (x_min, x_max), x_box in halves:
        if samples is None or len(samples) < 2:
            continue

        violin = ax.violinplot(
            [samples],
            positions=[pos],
            widths=width,
            showmeans=False,
            showmedians=False,
            showextrema=False,
        )
        for half in violin["bodies"]:
            outline = half.get_paths()[0].vertices
            # Flatten the unwanted side onto the centre line, in place.
            outline[:, 0] = np.clip(outline[:, 0], x_min, x_max)
            half.set_facecolor(fill)
            half.set_edgecolor("black")
            half.set_linewidth(0.5)
            half.set_alpha(1.0)

        # Inner box, mirroring seaborn's inner="box".
        lower_q, median, upper_q = np.percentile(samples, [25, 50, 75])
        ax.vlines(x_box, lower_q, upper_q, color="black", linewidth=1.0, zorder=4)
        ax.plot(x_box, median, "o", color="white", ms=3, mew=0, zorder=5)

draw_colored_violins

draw_colored_violins(ax, datasets, positions, colors, alpha=0.7, width=0.8, zorder=3)

Draw violins at positions with per-category colors.

Replicates seaborn's violinplot(inner=None, cut=0) for categorically-coloured violin groups. Only groups with >= 2 values are drawn.

Source code in simace/plotting/plot_utils.py
def draw_colored_violins(
    ax,
    datasets,
    positions,
    colors,
    alpha=0.7,
    width=0.8,
    zorder=3,
):
    """Draw violins at *positions* with per-category *colors*.

    Replicates seaborn's ``violinplot(inner=None, cut=0)`` for
    categorically-coloured violin groups.  Only groups with >= 2 values
    are drawn.
    """
    valid = [(p, d, c) for p, d, c in zip(positions, datasets, colors, strict=True) if len(d) >= 2]
    if not valid:
        return
    v_pos, v_data, v_colors = zip(*valid, strict=True)
    parts = ax.violinplot(
        list(v_data),
        positions=list(v_pos),
        showmeans=False,
        showmedians=False,
        showextrema=False,
        widths=width,
    )
    for body, color in zip(parts["bodies"], v_colors, strict=True):
        body.set_facecolor(color)
        body.set_edgecolor("none")
        body.set_alpha(alpha)
        body.set_zorder(zorder)

setup_pair_type_panel

setup_pair_type_panel(ax, pair_types, n_pairs_per_ptype, n_reps, observed_per_rep, liability_r=None, parametric_r=None, frailty_r=None, show_violins_threshold=4, pair_colors=None, rng_seed=42)

Render one per-pair-type comparison panel except the per-rep observed dots.

For each pair type at x = i:

- faint coloured violin (only when n_reps >= show_violins_threshold)
- mean-of-observed wide cross
- open diamond at mean liability r (if provided)
- filled red star at parametric E[r] (if provided)
- green filled plus at frailty r (if provided)

The per-rep observed dots are deferred so that `finalize_pair_type_panels` can decide a shared y-axis range across panels and clip outliers to the axis edges. Bold pair-type labels and parenthesised pair counts are drawn here; titles and y-labels remain caller-specific.

Returns {"ax", "ref_values", "obs_records"}.

Source code in simace/plotting/plot_utils.py
def setup_pair_type_panel(
    ax,
    pair_types: list[str],
    n_pairs_per_ptype: dict[str, int],
    n_reps: int,
    observed_per_rep: dict[str, list[float]],
    liability_r: dict[str, float] | None = None,
    parametric_r: dict[str, float] | None = None,
    frailty_r: dict[str, float] | None = None,
    show_violins_threshold: int = 4,
    pair_colors: dict[str, str] | None = None,
    rng_seed: int = 42,
) -> dict:
    """Draw everything in a per-pair-type panel except the per-rep dots.

    Pair type number ``i`` sits at ``x = i`` and receives, in this order:
      * a background violin of the observed values (only when
        ``n_reps >= show_violins_threshold`` and some type has >= 2 values)
      * the mean-of-observed marker with its halo
      * the liability marker (if ``liability_r`` has a value for the type)
      * the frailty marker (if ``frailty_r`` has a value for the type)
      * the parametric marker (if ``parametric_r`` has a value for the type)

    The per-rep observed points are only *recorded* here, with a seeded
    horizontal jitter, so that :func:`finalize_pair_type_panels` can choose
    a shared y-range first and then draw or clip them. X tick labels and the
    parenthesised per-rep pair counts are drawn here; titles and y-labels
    are left to the caller.

    Args:
        ax: Axes to draw on.
        pair_types: Pair-type names, in x order.
        n_pairs_per_ptype: Pair counts per type, summed over reps; shown
            divided by ``max(n_reps, 1)``.
        n_reps: Number of replicates.
        observed_per_rep: Observed correlation values per pair type.
        liability_r: Optional liability correlation per pair type.
        parametric_r: Optional parametric expectation per pair type.
        frailty_r: Optional frailty correlation per pair type.
        show_violins_threshold: Minimum ``n_reps`` for drawing violins.
        pair_colors: Colour per pair type; defaults to ``PAIR_COLORS``.
        rng_seed: Seed for the jitter generator.

    Returns:
        ``{"ax", "ref_values", "obs_records"}`` where ``ref_values`` lists the
        y values of every reference marker drawn and ``obs_records`` holds
        ``(ax, x, value)`` tuples for the deferred per-rep dots.
    """
    from simace.plotting.plot_style import enable_value_gridlines

    palette = PAIR_COLORS if pair_colors is None else pair_colors
    slots = list(enumerate(pair_types))

    anchors: list[float] = []
    deferred_dots: list[tuple[Any, float, float]] = []

    # Background violins, only once there are enough reps to show a spread.
    if n_reps >= show_violins_threshold:
        per_type_values = [list(observed_per_rep.get(pt, [])) for pt in pair_types]
        if any(len(values) >= 2 for values in per_type_values):
            fills = [palette[pt] for pt in pair_types]
            draw_colored_violins(ax, per_type_values, list(range(len(pair_types))), fills)

    # Light vertical separators between neighbouring categories.
    for left_slot in range(len(pair_types) - 1):
        ax.axvline(left_slot + 0.5, color="0.88", linewidth=0.6, zorder=0)

    # Observed values: record jittered dots for later, draw the mean now.
    jitter_rng = np.random.default_rng(rng_seed)
    for x_slot, ptype in slots:
        values = list(observed_per_rep.get(ptype, []))
        if not values:
            continue
        # A lone value stays centred; several are spread horizontally.
        offsets = jitter_rng.uniform(-0.08, 0.08, len(values)) if len(values) > 1 else np.zeros(1)
        deferred_dots.extend((ax, float(x), float(v)) for x, v in zip(x_slot + offsets, values, strict=False))
        centre = float(np.mean(values))
        ax.scatter(x_slot, centre, **_marker_obs_mean_halo())
        ax.scatter(x_slot, centre, **_marker_obs_mean(color=palette[ptype]))
        anchors.append(centre)

    # Reference series share one loop; the tuple order fixes the draw order.
    overlays = (
        (liability_r, lambda _pt: _marker_liab()),
        (frailty_r, lambda _pt: _marker_frailty()),
        (parametric_r, lambda pt: _marker_param(color=palette[pt])),
    )
    for lookup, style_for in overlays:
        if not lookup:
            continue
        for x_slot, ptype in slots:
            level = lookup.get(ptype)
            if level is None:
                continue
            ax.scatter(x_slot, float(level), **style_for(ptype))
            anchors.append(float(level))

    # X axis: bold pair-type names with the per-rep pair count underneath.
    ax.set_xticks(range(len(pair_types)))
    ax.set_xticklabels(pair_types, fontsize=15, fontweight="bold")
    ax.tick_params(axis="x", pad=4)
    ax.tick_params(axis="y", labelsize=11)
    reps_divisor = max(n_reps, 1)
    for x_slot, ptype in slots:
        pairs_per_rep = n_pairs_per_ptype.get(ptype, 0) // reps_divisor
        ax.annotate(
            f"({pairs_per_rep:,})",
            xy=(x_slot, 0),
            xytext=(0, -28),
            xycoords=("data", "axes fraction"),
            textcoords="offset points",
            ha="center",
            va="top",
            fontsize=9,
            color="0.35",
        )
    ax.set_xlabel("")
    ax.set_xlim(-0.6, len(pair_types) - 0.4)
    enable_value_gridlines(ax)

    return {"ax": ax, "ref_values": anchors, "obs_records": deferred_dots}

finalize_pair_type_panels

finalize_pair_type_panels(panel_states, sane_band=PAIR_TYPE_SANE_BAND)

Apply a shared y-limit across all panels and draw observed dots.

The y-limit is anchored on reference markers (mean observed, liability, parametric, frailty) plus observed values inside sane_band. Observed values outside the band are rendered as small carets at the axis edge so one or two low-n outliers don't blow out the panel.

Source code in simace/plotting/plot_utils.py
def finalize_pair_type_panels(
    panel_states: list[dict],
    sane_band: tuple[float, float] = PAIR_TYPE_SANE_BAND,
) -> tuple[float, float]:
    """Give all panels one y-range, then draw their deferred observed dots.

    The range is built from every panel's reference values plus those
    observed values that fall inside ``sane_band``. It always reaches down
    to at least 0, is padded (more at the top than the bottom), and is
    capped by ``sane_band``. With nothing to anchor on, ``(-0.1, 1.1)`` is
    used. Observed values outside the final range are shown as grey carets
    on the matching axis edge instead of stretching the panel.

    Args:
        panel_states: Dicts as returned by ``setup_pair_type_panel``.
        sane_band: ``(low, high)`` bounds for trusted observed values and
            for the y-limits themselves.

    Returns:
        The ``(lower, upper)`` y-limits applied to every panel.
    """
    band_lo, band_hi = sane_band

    anchors: list[float] = []
    trusted_obs: list[float] = []
    for panel in panel_states:
        anchors.extend(panel.get("ref_values", []))
        trusted_obs.extend(v for _ax, _x, v in panel.get("obs_records", []) if band_lo <= v <= band_hi)

    candidates = anchors + trusted_obs
    if not candidates:
        y_lo, y_hi = -0.1, 1.1
    else:
        top = max(candidates)
        bottom = min(min(candidates), 0.0)
        # A minimum span keeps near-identical values from collapsing the axis.
        margin = 0.10 * max(top - bottom, 0.05)
        y_lo = max(bottom - margin, band_lo)
        y_hi = min(top + 1.5 * margin, band_hi)

    # Fix the limits on every panel before any dot is drawn.
    for panel in panel_states:
        panel["ax"].set_ylim(y_lo, y_hi)

    for panel in panel_states:
        for target_ax, x, value in panel.get("obs_records", []):
            if value > y_hi:
                target_ax.scatter(x, y_hi, marker="^", s=42, color="0.45", zorder=5)
            elif value < y_lo:
                target_ax.scatter(x, y_lo, marker="v", s=42, color="0.45", zorder=5)
            else:
                target_ax.scatter(x, value, **_marker_obs_per_rep())

    return y_lo, y_hi

pair_type_legend_handles

pair_type_legend_handles(has_observed_mean=True, has_liability=True, has_frailty=False, has_parametric=False)

Return Line2D proxies for fig.legend.

Markers match those used by `setup_pair_type_panel`. Only the requested series are included.

Source code in simace/plotting/plot_utils.py
def pair_type_legend_handles(
    has_observed_mean: bool = True,
    has_liability: bool = True,
    has_frailty: bool = False,
    has_parametric: bool = False,
) -> list:
    """Build the ``Line2D`` legend proxies for pair-type panels.

    The per-rep observed entry is always present; each flag appends one
    further entry, in the order mean, liability, frailty, parametric.

    Args:
        has_observed_mean: Include the "Observed r (mean)" entry.
        has_liability: Include the "Liability r (mean)" entry.
        has_frailty: Include the "Frailty r (uncensored)" entry.
        has_parametric: Include the "Parametric E[r]" entry.

    Returns:
        List of ``Line2D`` handles suitable for ``fig.legend(handles=...)``.
    """
    from matplotlib.lines import Line2D

    from simace.plotting.plot_style import COLOR_AFFECTED, COLOR_UNCENSORED

    def _proxy(label: str, **marker_style):
        # Marker-only artist: a single dummy point and no connecting line.
        return Line2D([0], [0], linestyle="None", label=label, **marker_style)

    entries = [_proxy("Observed r (per rep)", marker="o", color="0.45", markersize=7)]

    if has_observed_mean:
        entries.append(
            _proxy(
                "Observed r (mean)",
                marker=_WIDE_PLUS,
                color="black",
                markersize=14,
                markeredgewidth=2.2,
            )
        )
    if has_liability:
        entries.append(
            _proxy(
                "Liability r (mean)",
                marker="D",
                color="black",
                markersize=9,
                markerfacecolor="white",
                markeredgewidth=1.5,
            )
        )
    if has_frailty:
        entries.append(_proxy("Frailty r (uncensored)", marker="P", color=COLOR_UNCENSORED, markersize=11))
    if has_parametric:
        entries.append(_proxy("Parametric E[r]", marker="*", color=COLOR_AFFECTED, markersize=16))

    return entries

plot_correlations

simace.plotting.plot_correlations

Correlation-related phenotype plots.

Contains: plot_tetrachoric_sibling, plot_tetrachoric_by_generation, plot_cross_trait_tetrachoric, plot_parent_offspring_liability, plot_tetrachoric_by_sex. Heritability pages live in plot_heritability.

plot_tetrachoric_sibling

plot_tetrachoric_sibling(all_stats, output_path, scenario, params=None)

Plot tetrachoric correlations by relationship type using marker-based references.

For each pair type, draws shapes stacked at the same x position: gray dots per rep (observed r), a black wide cross (mean of observed), an open black diamond (mean liability r), a red star (parametric E[r]), and a green plus (frailty r on uncensored frailties, when available). Faint violins appear only when reps >= 4 so the spread is visible without dominating the panel.

Source code in simace/plotting/plot_correlations.py
def plot_tetrachoric_sibling(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str,
    params: dict[str, Any] | None = None,
) -> None:
    """Plot tetrachoric correlations by relationship type using marker-based references.

    One panel per trait (traits 1 and 2, shared y-axis). For each pair type
    the panel stacks, at the same x position: per-rep observed r, the mean of
    observed r, the mean liability r, the parametric E[r] (when ``params``
    yields one) and the frailty r on uncensored frailties (when any rep
    provides ``frailty_corr_uncensored``). Violins are added by
    ``setup_pair_type_panel`` once there are enough reps.

    Args:
        all_stats: One stats dict per replicate.
        output_path: Destination file for the figure.
        scenario: Scenario name stamped on the figure.
        params: Optional simulation parameters used for the parametric
            reference; the legend entry is shown only if ``A1`` or ``A2``
            is present.
    """
    pair_types = PAIR_TYPES
    n_reps = max(len(all_stats), 1)

    fig, axes = plt.subplots(1, 2, figsize=(13, 6.5), sharey=True)

    has_uncens_any = any(s.get("frailty_corr_uncensored") for s in all_stats)
    has_parametric_any = bool(params) and any(params.get(f"A{t}") is not None for t in (1, 2))

    panel_states: list[dict] = []

    for col_idx, trait_num in enumerate([1, 2]):
        ax = axes[col_idx]
        trait_key = f"trait{trait_num}"

        observed, n_pairs = _extract_pair_type_observed(all_stats, "tetrachoric", trait_key, pair_types)
        # ``_tk=trait_key`` binds the current trait into each lambda.
        liability = _mean_per_pair_type(
            all_stats,
            lambda s, pt, _tk=trait_key: s.get("liability_correlations", {}).get(_tk, {}).get(pt),
            pair_types,
        )
        frailty = (
            _mean_per_pair_type(
                all_stats,
                lambda s, pt, _tk=trait_key: s.get("frailty_corr_uncensored", {}).get(_tk, {}).get(pt, {}).get("r"),
                pair_types,
            )
            if has_uncens_any
            else None
        )
        parametric = _parametric_per_pair_type(params, trait_num, pair_types)

        state = setup_pair_type_panel(
            ax,
            pair_types=pair_types,
            n_pairs_per_ptype=n_pairs,
            n_reps=n_reps,
            observed_per_rep=observed,
            liability_r=liability or None,
            parametric_r=parametric or None,
            frailty_r=frailty,
        )
        if col_idx == 0:
            ax.set_ylabel("Tetrachoric correlation", fontsize=12)
        ax.set_title(f"Trait {trait_num}", fontsize=13)
        panel_states.append(state)

    # Shared y-limits across both traits, then the deferred per-rep dots.
    finalize_pair_type_panels(panel_states)

    handles = pair_type_legend_handles(
        has_observed_mean=True,
        has_liability=True,
        has_frailty=has_uncens_any,
        has_parametric=has_parametric_any,
    )
    fig.legend(
        handles=handles,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),
        # Up to five entries are possible; sizing the columns to the handle
        # count keeps them on one row inside the band reserved by tight_rect
        # (a fixed ncol=4 wrapped the fifth entry onto a second row).
        ncol=len(handles),
        fontsize=10,
        frameon=False,
    )

    finalize_plot(output_path, scenario=scenario, tight_rect=[0, 0, 1, 0.94])

plot_tetrachoric_by_generation

plot_tetrachoric_by_generation(all_stats, output_path, scenario='', params=None)

Plot tetrachoric correlations by relationship type, broken out by generation.

2 rows (traits) x N cols (last 3 non-founder generations). Each cell shares the same marker conventions as `plot_tetrachoric_sibling`. The y-axis is shared across cells of the same trait row so generation-to-generation drift is directly comparable.

Source code in simace/plotting/plot_correlations.py
def plot_tetrachoric_by_generation(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    params: dict[str, Any] | None = None,
) -> None:
    """Plot tetrachoric correlations by relationship type, broken out by generation.

    2 rows (traits) x N cols, one column per generation key present in
    ``tetrachoric_by_generation`` of *every* replicate, ordered by generation
    number. Each cell shares the same marker conventions as
    :func:`plot_tetrachoric_sibling`. Y-axis limits are shared across cells
    of the same trait row so generation-to-generation drift is directly
    comparable. A placeholder figure is written when no common generation
    data exists.

    NOTE(review): the columns are presumably the last 3 non-founder
    generations, but that restriction is not enforced here — confirm it is
    applied where the stats are produced.

    Args:
        all_stats: One stats dict per replicate.
        output_path: Destination file for the figure.
        scenario: Scenario name stamped on the figure.
        params: Optional simulation parameters used for the parametric
            reference; the legend entry is shown only if ``A1`` or ``A2``
            is present.
    """

    def _generation_order(key: str) -> tuple[int, int, str]:
        # Keys look like "gen<N>"; order by N so "gen10" follows "gen9"
        # instead of sorting before "gen2". Non-numeric keys go last.
        try:
            return (0, int(key.replace("gen", "")), key)
        except ValueError:
            return (1, 0, key)

    gen_keys_sets = [set(s.get("tetrachoric_by_generation", {}).keys()) for s in all_stats]
    if not gen_keys_sets or not gen_keys_sets[0]:
        save_placeholder_plot(output_path, "No per-generation tetrachoric data")
        return

    # Only generations reported by every replicate are plotted.
    gen_keys = sorted(set.intersection(*gen_keys_sets), key=_generation_order)
    if not gen_keys:
        save_placeholder_plot(output_path, "No per-generation tetrachoric data")
        return

    pair_types = PAIR_TYPES
    n_reps = max(len(all_stats), 1)
    n_cols = len(gen_keys)

    fig, axes = plt.subplots(2, n_cols, figsize=(6.0 * n_cols, 10), squeeze=False)

    has_parametric_any = bool(params) and any(params.get(f"A{t}") is not None for t in (1, 2))

    for row, trait_num in enumerate([1, 2]):
        trait_key = f"trait{trait_num}"
        row_states: list[dict] = []

        for col, gen_key in enumerate(gen_keys):
            ax = axes[row, col]

            # Gather per-rep observed r and total pair counts for this cell.
            observed: dict[str, list[float]] = {pt: [] for pt in pair_types}
            n_pairs: dict[str, int] = dict.fromkeys(pair_types, 0)
            for s in all_stats:
                cell = s.get("tetrachoric_by_generation", {}).get(gen_key, {}).get(trait_key, {})
                for ptype in pair_types:
                    entry = cell.get(ptype, {})
                    r = entry.get("r")
                    if r is not None:
                        observed[ptype].append(float(r))
                    n_pairs[ptype] += int(entry.get("n_pairs", 0) or 0)

            # Default arguments bind the current generation/trait into the lambda.
            liability = _mean_per_pair_type(
                all_stats,
                lambda s, pt, _gk=gen_key, _tk=trait_key: (
                    s.get("tetrachoric_by_generation", {}).get(_gk, {}).get(_tk, {}).get(pt, {}).get("liability_r")
                ),
                pair_types,
            )
            parametric = _parametric_per_pair_type(params, trait_num, pair_types)

            state = setup_pair_type_panel(
                ax,
                pair_types=pair_types,
                n_pairs_per_ptype=n_pairs,
                n_reps=n_reps,
                observed_per_rep=observed,
                liability_r=liability or None,
                parametric_r=parametric or None,
            )

            if row == 0:
                gen_num = gen_key.replace("gen", "")
                ax.set_title(f"Gen {gen_num}", fontsize=13)
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nTetrachoric correlation", fontsize=12)
            row_states.append(state)

        # Y-limits are shared within a trait row, not across rows.
        finalize_pair_type_panels(row_states)

    fig.legend(
        handles=pair_type_legend_handles(
            has_observed_mean=True,
            has_liability=True,
            has_frailty=False,
            has_parametric=has_parametric_any,
        ),
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),
        ncol=4,
        fontsize=10,
        frameon=False,
    )

    finalize_plot(output_path, scenario=scenario, tight_rect=[0, 0, 1, 0.96])

plot_cross_trait_tetrachoric

plot_cross_trait_tetrachoric(all_stats, output_path, scenario='')

Two-panel figure for cross-trait tetrachoric correlations.

Left: Same-person cross-trait r by generation (dots per rep + mean line), with frailty cross-trait reference lines if available.

Right: Cross-person cross-trait r by pair type (violin/dots), showing how relatedness induces cross-trait association.

Source code in simace/plotting/plot_correlations.py
def plot_cross_trait_tetrachoric(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Draw the two-panel cross-trait tetrachoric correlation figure.

    Left panel: same-person cross-trait r per generation — one jittered dot
    per replicate value, a line through the per-generation means, a dashed
    line at the mean overall same-person r, and a dash-dot line at the mean
    uncensored-frailty cross-trait r when replicates provide it.

    Right panel: cross-person cross-trait r per pair type, rendered with the
    shared pair-type panel helpers.

    Either panel shows a centred message when it has no data.

    Args:
        all_stats: One stats dict per replicate.
        output_path: Destination file for the figure.
        scenario: Scenario name stamped on the figure.
    """
    pair_types = PAIR_TYPES

    _fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16, 6))

    # ---- Left panel: same-person r, grouped by generation number ----
    r_by_generation: dict[int, list[float]] = {}
    for rep_stats in all_stats:
        per_gen = rep_stats.get("cross_trait_tetrachoric", {}).get("same_person_by_generation", {})
        for gen_label, payload in per_gen.items():
            gen_number = int(gen_label.replace("gen", ""))
            r_here = payload.get("r")
            if r_here is not None:
                r_by_generation.setdefault(gen_number, []).append(r_here)

    if not r_by_generation:
        ax_left.text(0.5, 0.5, "No generation data", ha="center", va="center", transform=ax_left.transAxes)
    else:
        generations = sorted(r_by_generation.keys())

        # The k-th value of every generation uses the offset drawn from a
        # generator seeded with 42 + k, so the offset is fixed per index.
        most_values = max(len(values) for values in r_by_generation.values())
        offset_for_index = [np.random.default_rng(42 + k).uniform(-0.08, 0.08) for k in range(most_values)]

        for gen in generations:
            for k, r_value in enumerate(r_by_generation[gen]):
                ax_left.scatter(gen + offset_for_index[k], r_value, color=COLOR_OBSERVED, alpha=0.9, s=15, zorder=5)

        per_gen_means = [np.mean(r_by_generation[g]) for g in generations]
        ax_left.plot(
            generations,
            per_gen_means,
            color=COLOR_OBSERVED,
            linewidth=1.2,
            marker="o",
            markersize=5,
            zorder=6,
            label="Per-gen mean",
        )

        # Overall same-person r, averaged over the replicates that report it.
        overall_values = [
            r
            for r in (s.get("cross_trait_tetrachoric", {}).get("same_person", {}).get("r") for s in all_stats)
            if r is not None
        ]
        if overall_values:
            overall_mean = np.mean(overall_values)
            ax_left.axhline(
                y=overall_mean,
                color="black",
                linestyle="--",
                linewidth=1.5,
                alpha=0.7,
                label=f"Overall r = {overall_mean:.3f}",
            )

        # Uncensored-frailty reference line, when replicates provide it.
        frailty_values = [
            r for r in (s.get("frailty_cross_trait_uncensored", {}).get("r") for s in all_stats) if r is not None
        ]
        if frailty_values:
            frailty_mean = np.mean(frailty_values)
            ax_left.axhline(
                y=frailty_mean,
                color=COLOR_UNCENSORED,
                linestyle="-.",
                linewidth=1.0,
                alpha=0.7,
                label=f"Frailty oracle = {frailty_mean:.3f}",
            )

        ax_left.set_xticks(generations)
        ax_left.legend(loc="best", fontsize=9)

    ax_left.set_xlabel("Generation")
    ax_left.set_ylabel("Cross-trait tetrachoric r")
    ax_left.set_title("Same-Person: affected1 vs affected2")

    # ---- Right panel: cross-person r, grouped by pair type ----
    n_reps = max(len(all_stats), 1)

    observed: dict[str, list[float]] = {pt: [] for pt in pair_types}
    n_pairs: dict[str, int] = dict.fromkeys(pair_types, 0)
    for rep_stats in all_stats:
        cross_person = rep_stats.get("cross_trait_tetrachoric", {}).get("cross_person", {})
        for ptype in pair_types:
            entry = cross_person.get(ptype, {})
            r_pair = entry.get("r")
            if r_pair is not None:
                observed[ptype].append(float(r_pair))
            n_pairs[ptype] += int(entry.get("n_pairs", 0) or 0)

    if not any(observed.values()):
        ax_right.text(0.5, 0.5, "No cross-person data", ha="center", va="center", transform=ax_right.transAxes)
    else:
        right_state = setup_pair_type_panel(
            ax_right,
            pair_types=pair_types,
            n_pairs_per_ptype=n_pairs,
            n_reps=n_reps,
            observed_per_rep=observed,
        )
        finalize_pair_type_panels([right_state])

    ax_right.set_ylabel("Cross-trait tetrachoric r", fontsize=12)
    ax_right.set_title("Cross-Person: personA.affected1 vs personB.affected2", fontsize=13)

    finalize_plot(output_path, scenario=scenario)

plot_parent_offspring_liability

plot_parent_offspring_liability(df_samples, all_stats, output_path, scenario='', subsample_note='', params=None)

2 x 3 scatter grid: midparent vs offspring liability by generation.

Source code in simace/plotting/plot_correlations.py
def plot_parent_offspring_liability(
    df_samples: pd.DataFrame,
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
    params: dict[str, Any] | None = None,
) -> None:
    """Scatter grid of midparent vs offspring liability, one column per generation.

    Rows are traits 1 and 2; columns are the (at most three) latest generations
    that have at least one mother present in ``df_samples``. Each panel shows a
    sex-coloured scatter, the regression line with a 95% band, the expected
    line when ``params`` supplies ``A{trait}``, and per-sex lines when
    ``all_stats`` carries ``parent_offspring_corr_by_sex`` slopes.

    Args:
        df_samples: Sample rows; reads ``id``, ``generation``, ``mother``,
            ``father``, ``sex`` and ``liability{1,2}``.
        all_stats: Per-replicate stats dicts. ``parent_offspring_corr`` entries
            are averaged across replicates when present; otherwise the
            regression is fitted on the plotted points.
        output_path: Destination handed to ``finalize_plot`` or
            ``save_placeholder_plot``.
        scenario: Scenario label forwarded to ``finalize_plot``.
        subsample_note: Note forwarded to ``finalize_plot``.
        params: Optional configuration; ``A1`` / ``A2`` give the expected slopes.
    """
    from scipy.stats import t as t_dist

    from simace.plotting.plot_style import COLOR_FEMALE, COLOR_MALE

    # An empty frame has no max id / min generation, so treat it like missing data
    # instead of letting the reductions below raise.
    if "generation" not in df_samples.columns or len(df_samples) == 0:
        save_placeholder_plot(output_path, "No generation data")
        return

    # Dense id -> row-position lookup within df_samples; -1 marks ids not sampled.
    ids_arr = df_samples["id"].values
    max_id = int(ids_arr.max()) + 1
    id_to_row = np.full(max_id, -1, dtype=np.int32)
    id_to_row[ids_arr] = np.arange(len(df_samples), dtype=np.int32)

    # Select non-founder generations whose parents are present in the sample.
    # The earliest phenotyped generation's parents may be outside the phenotype
    # window (e.g. G_pheno < G_ped), so we test each candidate generation.
    generation_arr = df_samples["generation"].values
    min_gen = int(df_samples["generation"].min())
    max_gen = int(df_samples["generation"].max())
    candidate_gens = list(range(max(min_gen + 1, 1), max_gen + 1))
    plot_gens = []
    for gen in candidate_gens:
        mothers = df_samples["mother"].values[generation_arr == gen]
        # Keep the generation if at least one mother is present in the sample
        if np.any(np.isin(mothers[mothers >= 0], ids_arr)):
            plot_gens.append(gen)
    # Keep at most 3 generations for a readable grid
    plot_gens = plot_gens[-3:]

    if not plot_gens:
        save_placeholder_plot(output_path, "No generations with parent data available")
        return

    n_cols = len(plot_gens)
    _fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8), squeeze=False)

    for row, trait_num in enumerate([1, 2]):
        liability = df_samples[f"liability{trait_num}"].values

        for col, gen in enumerate(plot_gens):
            ax = axes[row, col]
            gen_idx = np.where(generation_arr == gen)[0]

            mother_ids = df_samples["mother"].values[gen_idx]
            father_ids = df_samples["father"].values[gen_idx]

            # Only ids inside the lookup table can be resolved to a row.
            has_m = (mother_ids >= 0) & (mother_ids < max_id)
            has_f = (father_ids >= 0) & (father_ids < max_id)

            m_rows = np.full(len(gen_idx), -1, dtype=np.int32)
            f_rows = np.full(len(gen_idx), -1, dtype=np.int32)
            m_rows[has_m] = id_to_row[mother_ids[has_m]]
            f_rows[has_f] = id_to_row[father_ids[has_f]]

            # A trio is usable only when both parents are rows of the sample.
            valid = (m_rows >= 0) & (f_rows >= 0)

            if valid.sum() < 2:
                ax.text(0.5, 0.5, "Insufficient data", ha="center", va="center", transform=ax.transAxes)
                if row == 0:
                    ax.set_title(f"Gen {gen}")
                continue

            offspring_liab = liability[gen_idx[valid]]
            midparent_liab = (liability[m_rows[valid]] + liability[f_rows[valid]]) / 2.0
            x_mean = np.mean(midparent_liab)
            y_mean = np.mean(offspring_liab)

            # Sex-stratified scatter: sex == 0 drawn in COLOR_FEMALE, sex == 1 in COLOR_MALE
            sex_arr = df_samples["sex"].values
            offspring_sex = sex_arr[gen_idx[valid]]
            f_mask = offspring_sex == 0
            m_mask = offspring_sex == 1
            if f_mask.any():
                ax.plot(
                    midparent_liab[f_mask],
                    offspring_liab[f_mask],
                    "o",
                    ms=2,
                    mew=0,
                    alpha=0.25,
                    color=COLOR_FEMALE,
                    rasterized=True,
                )
            if m_mask.any():
                ax.plot(
                    midparent_liab[m_mask],
                    offspring_liab[m_mask],
                    "o",
                    ms=2,
                    mew=0,
                    alpha=0.25,
                    color=COLOR_MALE,
                    rasterized=True,
                )

            # Collect pre-computed stats (averaged across reps)
            r_vals, slope_vals, intercept_vals, n_vals = [], [], [], []
            stderr_vals: list[float] = []
            for rep_stats in all_stats:
                po = rep_stats.get("parent_offspring_corr", {}).get(f"trait{trait_num}", {}).get(f"gen{gen}", {})
                if po and po.get("r") is not None:
                    r_vals.append(po["r"])
                    slope_vals.append(po["slope"])
                    intercept_vals.append(po["intercept"])
                    n_vals.append(po["n_pairs"])
                    if po.get("stderr") is not None:
                        stderr_vals.append(po["stderr"])

            if r_vals:
                mean_r = np.mean(r_vals)
                mean_slope = np.mean(slope_vals)
                mean_intercept = np.mean(intercept_vals)
                mean_n = int(np.mean(n_vals))
                mean_stderr = float(np.mean(stderr_vals)) if stderr_vals else None
            else:
                # No pre-computed stats for this panel: fit on the plotted points.
                from simace.core.numerics import fast_linregress

                mean_slope, mean_intercept, mean_r, mean_stderr, _mean_pvalue = fast_linregress(
                    midparent_liab, offspring_liab
                )
                mean_n = int(valid.sum())

            # Observed regression line
            x_line = np.array([midparent_liab.min(), midparent_liab.max()])
            ax.plot(x_line, mean_slope * x_line + mean_intercept, color=COLOR_AFFECTED, linewidth=1.2)

            # 95% confidence band around regression line
            if mean_stderr is not None and mean_n > 2:
                x_smooth = np.linspace(midparent_liab.min(), midparent_liab.max(), 200)
                y_hat = mean_slope * x_smooth + mean_intercept
                ss_x = np.sum((midparent_liab - x_mean) ** 2)
                if ss_x > 1e-12:
                    # Reconstruct residual SE: stderr_slope = s / sqrt(SS_x)
                    residual_se = mean_stderr * np.sqrt(ss_x)
                    t_crit = t_dist.ppf(0.975, df=mean_n - 2)
                    se_fit = residual_se * np.sqrt(1.0 / mean_n + (x_smooth - x_mean) ** 2 / ss_x)
                    ax.fill_between(
                        x_smooth,
                        y_hat - t_crit * se_fit,
                        y_hat + t_crit * se_fit,
                        alpha=0.15,
                        color=COLOR_AFFECTED,
                        zorder=2,
                    )

            # Expected slope from configured A (h² = A for midparent-offspring)
            if params is not None:
                expected_slope = params.get(f"A{trait_num}")
                if expected_slope is not None:
                    # Anchor the expected line at the centroid of the plotted points.
                    expected_intercept = y_mean - float(expected_slope) * x_mean
                    ax.plot(
                        x_line,
                        float(expected_slope) * x_line + expected_intercept,
                        color=COLOR_UNAFFECTED,
                        linestyle="--",
                        linewidth=1.0,
                        zorder=4,
                    )

            # Sex-stratified regression lines
            sex_slopes: dict[str, float | None] = {}
            for sex_key, sex_color in [("female", COLOR_FEMALE), ("male", COLOR_MALE)]:
                sex_slope_vals = []
                for rep_stats in all_stats:
                    po_s = (
                        rep_stats.get("parent_offspring_corr_by_sex", {})
                        .get(sex_key, {})
                        .get(f"trait{trait_num}", {})
                        .get(f"gen{gen}", {})
                    )
                    if po_s and po_s.get("slope") is not None:
                        sex_slope_vals.append(po_s["slope"])
                if sex_slope_vals:
                    s_slope = np.mean(sex_slope_vals)
                    # Per-sex lines pass through the pooled centroid of the panel.
                    s_intercept = y_mean - s_slope * x_mean
                    ax.plot(
                        x_line,
                        s_slope * x_line + s_intercept,
                        color=sex_color,
                        linewidth=1.5,
                        alpha=0.8,
                        zorder=3,
                    )
                    sex_slopes[sex_key] = s_slope
                else:
                    sex_slopes[sex_key] = None

            # Annotation: lead with h² (slope = heritability estimate)
            ann_lines = []
            if mean_stderr is not None:
                ann_lines.append(f"h\u00b2 = {mean_slope:.4f} \u00b1 {mean_stderr:.4f}")
            else:
                ann_lines.append(f"h\u00b2 = {mean_slope:.4f}")
            if sex_slopes.get("female") is not None:
                ann_lines.append(f"h\u00b2\u2640 = {sex_slopes['female']:.4f}")
            if sex_slopes.get("male") is not None:
                ann_lines.append(f"h\u00b2\u2642 = {sex_slopes['male']:.4f}")
            ann_lines.append(f"r = {mean_r:.4f}")
            ax.text(
                0.05,
                0.95,
                "\n".join(ann_lines),
                transform=ax.transAxes,
                va="top",
                fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
            )

            if row == 0:
                ax.set_title(f"Gen {gen}")
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nOffspring Liability")
            if row == 1:
                ax.set_xlabel("Midparent Liability")

    # Legend on the last axes of the top row
    from matplotlib.lines import Line2D

    has_any_expected = params is not None and any(params.get(f"A{t}") is not None for t in [1, 2])
    legend_handles = [
        Line2D([0], [0], color=COLOR_AFFECTED, linewidth=1.2, label="Observed h\u00b2"),
        Line2D([0], [0], color=COLOR_FEMALE, linewidth=1.2, label="Daughters"),
        Line2D([0], [0], color=COLOR_MALE, linewidth=1.2, label="Sons"),
    ]
    if has_any_expected:
        legend_handles.append(
            Line2D([0], [0], color=COLOR_UNAFFECTED, linestyle="--", linewidth=1.0, label="Expected (A)")
        )
    axes[0, -1].legend(handles=legend_handles, loc="lower right", fontsize=8)

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_tetrachoric_by_sex

plot_tetrachoric_by_sex(all_stats, output_path, scenario='', params=None)

Tetrachoric correlations for same-sex pairs: 2 rows (traits) x 2 cols (F/M).

Same marker conventions as `plot_tetrachoric_sibling`. Each trait row shares its y-axis across the female and male panels so cross-sex magnitude differences are directly comparable; the two trait rows are independent.

Source code in simace/plotting/plot_correlations.py
def plot_tetrachoric_by_sex(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    params: dict[str, Any] | None = None,
) -> None:
    """Draw same-sex tetrachoric correlations on a 2 x 2 grid (trait rows, F/M columns).

    Marker conventions follow :func:`plot_tetrachoric_sibling`. The female and
    male panels of one trait are finalized together so they share a y-axis;
    the two trait rows are finalized separately.
    """
    if not any(rep.get("tetrachoric_by_sex") for rep in all_stats):
        save_placeholder_plot(output_path, "No sex-stratified tetrachoric data")
        return

    pair_types = PAIR_TYPES
    rep_count = max(len(all_stats), 1)
    sex_columns = (("female", "Female"), ("male", "Male"))

    fig, axes = plt.subplots(2, 2, figsize=(14, 10), squeeze=False)

    show_parametric = bool(params) and any(params.get(f"A{t}") is not None for t in (1, 2))

    for row_idx, trait_num in enumerate((1, 2)):
        trait_key = f"trait{trait_num}"
        panel_states: list[dict] = []

        for col_idx, (sex_key, sex_title) in enumerate(sex_columns):
            ax = axes[row_idx, col_idx]

            # Gather per-replicate r values and the summed pair counts.
            observed: dict[str, list[float]] = {pt: [] for pt in pair_types}
            pair_totals: dict[str, int] = dict.fromkeys(pair_types, 0)
            for rep in all_stats:
                trait_cell = rep.get("tetrachoric_by_sex", {}).get(sex_key, {}).get(trait_key, {})
                for pt in pair_types:
                    record = trait_cell.get(pt, {})
                    r_value = record.get("r")
                    if r_value is not None:
                        observed[pt].append(float(r_value))
                    pair_totals[pt] += int(record.get("n_pairs", 0) or 0)

            # Defaults pin the current sex/trait so the getter does not depend
            # on the loop variables after this iteration.
            def _liability_getter(s, pt, _sk=sex_key, _tk=trait_key):
                by_sex = s.get("tetrachoric_by_sex", {})
                return by_sex.get(_sk, {}).get(_tk, {}).get(pt, {}).get("liability_r")

            liability = _mean_per_pair_type(all_stats, _liability_getter, pair_types)
            parametric = _parametric_per_pair_type(params, trait_num, pair_types)

            panel = setup_pair_type_panel(
                ax,
                pair_types=pair_types,
                n_pairs_per_ptype=pair_totals,
                n_reps=rep_count,
                observed_per_rep=observed,
                liability_r=liability or None,
                parametric_r=parametric or None,
            )

            if col_idx == 0:
                ax.set_ylabel(f"Trait {trait_num}\nTetrachoric correlation", fontsize=12)
            if row_idx == 0:
                ax.set_title(sex_title, fontsize=13)
            panel_states.append(panel)

        # Finalize one trait row at a time: y-axis is shared within a row only.
        finalize_pair_type_panels(panel_states)

    legend_entries = pair_type_legend_handles(
        has_observed_mean=True,
        has_liability=True,
        has_frailty=False,
        has_parametric=show_parametric,
    )
    fig.legend(
        handles=legend_entries,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.0),
        ncol=4,
        fontsize=10,
        frameon=False,
    )

    finalize_plot(output_path, scenario=scenario, tight_rect=[0, 0, 1, 0.96])

plot_distributions

simace.plotting.plot_distributions

Distribution-related phenotype plots.

Contains: plot_death_age_distribution, plot_trait_phenotype, plot_trait_regression, plot_cumulative_incidence, plot_cumulative_incidence_by_sex, plot_cumulative_incidence_by_sex_generation, plot_censoring_windows.

plot_death_age_distribution

plot_death_age_distribution(all_stats, censor_age, output_path, scenario='')

Plot mortality rate and cumulative mortality by decade, averaged across reps.

Source code in simace/plotting/plot_distributions.py
def plot_death_age_distribution(
    all_stats: list[dict[str, Any]], censor_age: float, output_path: str | Path, scenario: str = ""
) -> None:
    """Plot mortality rate and cumulative mortality by decade, averaged across reps.

    Args:
        all_stats: Per-replicate stats dicts; reads ``mortality.rates`` from
            every replicate and ``mortality.decade_labels`` from the first.
        censor_age: Accepted for interface compatibility; not used here.
        output_path: Destination handed to ``finalize_plot``.
        scenario: Scenario label forwarded to ``finalize_plot``.
    """
    # Without any replicate there is nothing to average and no decade labels to
    # read; emit a placeholder before a figure is created.
    if not all_stats:
        save_placeholder_plot(output_path, "No mortality data")
        return

    _fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Average mortality rates across reps
    all_rates = np.array([s["mortality"]["rates"] for s in all_stats])
    mean_rates = all_rates.mean(axis=0)
    decade_labels = all_stats[0]["mortality"]["decade_labels"]

    # Left: mortality rate per decade
    axes[0].bar(decade_labels, mean_rates, edgecolor="black", alpha=0.7)
    axes[0].set_title("Mortality Rate by Decade")
    axes[0].set_xlabel("Age Decade")
    axes[0].set_ylabel("Mortality Rate")
    axes[0].tick_params(axis="x", rotation=45)

    # Right: cumulative mortality per decade with survival annotations.
    # Survival is the running product of per-decade survival (1 - rate).
    survival = np.cumprod(1 - mean_rates)
    cumulative = 1 - survival
    bars = axes[1].bar(decade_labels, cumulative, edgecolor="black", alpha=0.7)
    for bar, surv in zip(bars, survival, strict=True):
        axes[1].text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.01,
            f"S={surv:.2f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )
    axes[1].set_title("Cumulative Mortality by Decade")
    axes[1].set_xlabel("Age Decade")
    axes[1].set_ylabel("Cumulative Mortality")
    axes[1].tick_params(axis="x", rotation=45)

    finalize_plot(output_path, scenario=scenario)

plot_trait_phenotype

plot_trait_phenotype(df_samples, output_path, scenario='', subsample_note='')

Plot phenotype distributions for both traits in a 2x2 grid.

Source code in simace/plotting/plot_distributions.py
def plot_trait_phenotype(
    df_samples: pd.DataFrame, output_path: str | Path, scenario: str = "", subsample_note: str = ""
) -> None:
    """Draw a 2x2 grid of age histograms, one row per trait.

    The left column shows ``t_observed`` for affected rows; the right column
    shows it for rows that are unaffected and flagged ``death_censored``.
    """
    _fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for row, trait_num in enumerate([1, 2]):
        age_col = f"t_observed{trait_num}"
        is_affected = df_samples[f"affected{trait_num}"]
        is_death_censored = df_samples[f"death_censored{trait_num}"]

        # (subset, bar colour, panel title) for the left and right panels.
        panels = (
            (
                df_samples[is_affected],
                COLOR_AFFECTED,
                f"Trait {trait_num}: Age at Onset (affected)",
            ),
            (
                df_samples[~is_affected & is_death_censored],
                COLOR_UNAFFECTED,
                f"Trait {trait_num}: Age at Death (death-censored, unaffected)",
            ),
        )

        for ax, (subset, bar_color, title) in zip(axes[row], panels, strict=True):
            ax.hist(
                subset[age_col].dropna(),
                bins=50,
                density=True,
                edgecolor="black",
                alpha=0.7,
                color=bar_color,
            )
            ax.set_title(title)
            ax.set_xlabel("Age")
            ax.set_ylabel("Density")

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_trait_regression

plot_trait_regression(df_samples, all_stats, output_path, scenario='', subsample_note='')

Plot liability vs age at onset for both traits as jointplots side by side.

Source code in simace/plotting/plot_distributions.py
def plot_trait_regression(
    df_samples: pd.DataFrame,
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Plot liability vs age at onset for both traits as jointplots side by side.

    Each trait gets a scatter of affected rows with a regression line, a 95%
    band and marginal histograms. Regression statistics are averaged from
    ``all_stats`` when available, otherwise fitted on the plotted points. A
    trait is skipped when its liability column is missing, when no affected
    row has both liability and onset age, or when there are neither
    pre-computed stats nor at least two points to fit.

    Args:
        df_samples: Sample rows; reads ``affected{n}``, ``t_observed{n}`` and
            ``liability{n}``.
        all_stats: Per-replicate stats dicts with a ``regression`` mapping.
        output_path: Destination handed to ``finalize_plot``.
        scenario: Scenario label forwarded to ``finalize_plot``.
        subsample_note: Note forwarded to ``finalize_plot``.
    """
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
    from scipy.stats import t as t_dist

    fig = plt.figure(figsize=(16, 7))
    outer = GridSpec(1, 2, figure=fig, wspace=0.35)

    for i, trait_num in enumerate([1, 2]):
        affected_col = f"affected{trait_num}"
        t_col = f"t_observed{trait_num}"
        liability_col = f"liability{trait_num}"

        if liability_col not in df_samples.columns:
            continue

        affected = df_samples[df_samples[affected_col]].dropna(subset=[liability_col, t_col])
        x = affected[liability_col].values
        y = affected[t_col].values

        # With no points there is no x-range to draw over (x.min() would raise),
        # even if pre-computed regression stats exist for this trait.
        if len(x) == 0:
            continue

        # Get regression stats from pre-computed stats (averaged across reps)
        reg_stats = [
            rep["regression"][f"trait{trait_num}"]
            for rep in all_stats
            if rep["regression"].get(f"trait{trait_num}") is not None
        ]
        if reg_stats:
            mean_r = np.mean([r["r"] for r in reg_stats])
            mean_slope = np.mean([r["slope"] for r in reg_stats])
            mean_intercept = np.mean([r["intercept"] for r in reg_stats])
            mean_n = int(np.mean([r["n"] for r in reg_stats]))
            stderr_vals = [r["stderr"] for r in reg_stats if r.get("stderr") is not None]
            mean_stderr = float(np.mean(stderr_vals)) if stderr_vals else None
        elif len(x) >= 2:
            from simace.core.numerics import fast_linregress

            mean_slope, mean_intercept, mean_r, mean_stderr, _mean_pvalue = fast_linregress(x, y)
            mean_n = len(x)
        else:
            continue

        # Joint layout: scatter bottom-left, marginals on top and right.
        inner = GridSpecFromSubplotSpec(
            2,
            2,
            subplot_spec=outer[i],
            width_ratios=[4, 1],
            height_ratios=[1, 4],
            hspace=0.05,
            wspace=0.05,
        )
        ax_joint = fig.add_subplot(inner[1, 0])
        ax_marg_x = fig.add_subplot(inner[0, 0], sharex=ax_joint)
        ax_marg_y = fig.add_subplot(inner[1, 1], sharey=ax_joint)

        ax_joint.plot(x, y, "o", ms=2, mew=0, alpha=0.15, rasterized=True)
        x_line = np.array([x.min(), x.max()])
        ax_joint.plot(
            x_line,
            mean_slope * x_line + mean_intercept,
            color=COLOR_AFFECTED,
            linewidth=1.2,
        )

        # 95% confidence band
        if mean_stderr is not None and mean_n > 2:
            x_smooth = np.linspace(x.min(), x.max(), 200)
            y_hat = mean_slope * x_smooth + mean_intercept
            x_mean = np.mean(x)
            ss_x = np.sum((x - x_mean) ** 2)
            if ss_x > 1e-12:
                # Reconstruct residual SE from the slope SE: stderr_slope = s / sqrt(SS_x)
                residual_se = mean_stderr * np.sqrt(ss_x)
                t_crit = t_dist.ppf(0.975, df=mean_n - 2)
                se_fit = residual_se * np.sqrt(1.0 / mean_n + (x_smooth - x_mean) ** 2 / ss_x)
                ax_joint.fill_between(
                    x_smooth,
                    y_hat - t_crit * se_fit,
                    y_hat + t_crit * se_fit,
                    alpha=0.15,
                    color=COLOR_AFFECTED,
                    zorder=2,
                )

        # Annotation: slope, r
        ann_lines = [f"slope = {mean_slope:.4f}", f"r = {mean_r:.4f}"]
        ax_joint.text(
            0.05,
            0.95,
            "\n".join(ann_lines),
            transform=ax_joint.transAxes,
            va="top",
            fontsize=11,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        )
        ax_joint.set_xlabel("Liability")
        ax_joint.set_ylabel("Age at Onset")

        ax_marg_x.hist(x, bins=50, edgecolor="none", alpha=0.7)
        ax_marg_y.hist(y, bins=50, orientation="horizontal", edgecolor="none", alpha=0.7)
        ax_marg_x.set_title(f"Trait {trait_num}", fontsize=12)
        ax_marg_x.tick_params(labelbottom=False, labelleft=False)
        ax_marg_x.set_ylabel("")
        ax_marg_y.tick_params(labelleft=False, labelbottom=False)
        ax_marg_y.set_xlabel("")

        # The top-right cell of the joint layout stays blank.
        ax_corner = fig.add_subplot(inner[0, 1])
        ax_corner.axis("off")

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_cumulative_incidence

plot_cumulative_incidence(all_stats, censor_age, output_path, scenario='')

Plot cumulative incidence by age, mean +/- band across reps.

Source code in simace/plotting/plot_distributions.py
def plot_cumulative_incidence(
    all_stats: list[dict[str, Any]], censor_age: float, output_path: str | Path, scenario: str = ""
) -> None:
    """Plot cumulative incidence by age: mean curve with a min-max band across reps.

    One panel per trait. When the stats carry ``observed_values`` /
    ``true_values`` both curves are drawn; the older ``values``-only format
    yields just the observed curve. Quartile ages (25/50/75% of the final
    cumulative value) are marked on each curve.

    Args:
        all_stats: Per-replicate stats dicts; reads ``cumulative_incidence``
            and ``prevalence`` per trait.
        censor_age: Accepted for interface compatibility; not used here.
        output_path: Destination handed to ``finalize_plot``.
        scenario: Scenario label forwarded to ``finalize_plot``.
    """
    # The age grid and format detection both come from the first replicate, so
    # an empty list cannot be plotted; emit a placeholder before creating a figure.
    if not all_stats:
        save_placeholder_plot(output_path, "No cumulative incidence data")
        return

    _fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
    # A spread band only makes sense with more than one replicate.
    multi_rep = len(all_stats) > 1

    for trait_num, ax in zip([1, 2], axes, strict=True):
        key = f"trait{trait_num}"
        first_rep = all_stats[0]["cumulative_incidence"][key]
        ages = np.array(first_rep["ages"])

        # Support both old ("values") and new ("observed_values"/"true_values") format
        mean_true = None
        if "observed_values" in first_rep:
            all_obs = np.array([s["cumulative_incidence"][key]["observed_values"] for s in all_stats])
            all_true = np.array([s["cumulative_incidence"][key]["true_values"] for s in all_stats])
            mean_true = all_true.mean(axis=0)

            # True incidence
            ax.plot(ages, mean_true, color=COLOR_TRUE, alpha=0.7, linewidth=1.2, label="True")
            if multi_rep:
                ax.fill_between(ages, all_true.min(axis=0), all_true.max(axis=0), alpha=0.1, color=COLOR_TRUE)
        else:
            all_obs = np.array([s["cumulative_incidence"][key]["values"] for s in all_stats])

        mean_obs = all_obs.mean(axis=0)

        # Observed incidence
        ax.plot(ages, mean_obs, color=COLOR_OBSERVED, linewidth=1.2, label="Observed")
        if multi_rep:
            ax.fill_between(ages, all_obs.min(axis=0), all_obs.max(axis=0), alpha=0.2, color=COLOR_OBSERVED)

        # Annotate Q1, Q2 (median), Q3 on both observed and true curves
        quartile_points: dict[str, dict[str, tuple[float, float]]] = {}
        for curve, curve_color, y_offset, curve_key in [
            (mean_obs, COLOR_OBSERVED, -16, "obs"),
            (mean_true, COLOR_TRUE, 16, "true"),
        ]:
            if curve is None:
                continue
            lifetime = curve[-1]
            if lifetime <= 0:
                continue
            for frac, label, ms in [
                (0.25, "Q1", 4),
                (0.50, "Q2", 6),
                (0.75, "Q3", 4),
            ]:
                # First age at which the curve reaches the given fraction of its
                # final value; clamp the index in case the target is never reached.
                target = lifetime * frac
                idx_q = np.searchsorted(curve, target)
                age_q = ages[min(idx_q, len(ages) - 1)]

                ax.plot(age_q, target, "o", color=curve_color, markersize=ms, zorder=5)
                ax.annotate(
                    f"{label}: {age_q:.0f}",
                    xy=(age_q, target),
                    xytext=(10, y_offset),
                    textcoords="offset points",
                    fontsize=9,
                    fontweight="bold",
                    ha="left",
                    va="center",
                    color=curve_color,
                    bbox=dict(boxstyle="round,pad=0.15", facecolor="white", edgecolor="none", alpha=0.8),
                )
                quartile_points.setdefault(label, {})[curve_key] = (age_q, target)

        # Connect matching quartiles between observed and true curves
        for label in ["Q1", "Q2", "Q3"]:
            pts = quartile_points.get(label, {})
            if "obs" in pts and "true" in pts:
                ax.plot(
                    [pts["obs"][0], pts["true"][0]],
                    [pts["obs"][1], pts["true"][1]],
                    color="0.5",
                    linestyle="--",
                    linewidth=0.8,
                    zorder=4,
                )

        # Annotation box: prevalence and censoring rates.
        # Falls back to the observed curve's endpoint when no true curve exists.
        prev = np.mean([s["prevalence"][key] for s in all_stats])
        true_prev = mean_true[-1] if mean_true is not None else mean_obs[-1]
        censored_pct = (true_prev - prev) * 100
        ax.text(
            0.03,
            0.95,
            f"Affected: {prev * 100:.1f}%\nTrue prev: {true_prev * 100:.1f}%\nCensored: {censored_pct:.1f}%",
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        )

        ax.set_title(f"Trait {trait_num}")
        ax.set_xlabel("Age")
        ax.legend(loc="lower right", fontsize=9)

    axes[0].set_ylabel("Cumulative Incidence")
    finalize_plot(output_path, scenario=scenario)

plot_cumulative_incidence_by_sex

plot_cumulative_incidence_by_sex(all_stats, output_path, scenario='')

Plot cumulative incidence curves split by sex, from pre-computed stats.

Source code in simace/plotting/plot_distributions.py
def plot_cumulative_incidence_by_sex(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Draw per-sex cumulative incidence curves (one panel per trait) from stats.

    Curves are the across-replicate mean of the pre-computed
    ``cumulative_incidence_by_sex`` values; replicates lacking that entry are
    ignored, and a placeholder is saved when none have it.
    """
    usable_reps = [rep for rep in all_stats if rep.get("cumulative_incidence_by_sex")]
    if not usable_reps:
        logger.warning("Skipping cumulative_incidence_by_sex: no data in stats")
        save_placeholder_plot(output_path, "No sex-stratified incidence data")
        return

    _fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

    for trait_num, ax in zip([1, 2], axes, strict=True):
        trait_key = f"trait{trait_num}"

        for sex_key, sex_name, line_color in (
            ("female", "Female", COLOR_FEMALE),
            ("male", "Male", COLOR_MALE),
        ):
            # Keep only replicates that report this sex for this trait.
            curves = []
            for rep in usable_reps:
                by_sex = rep["cumulative_incidence_by_sex"]
                if sex_key in by_sex.get(trait_key, {}):
                    curves.append(by_sex[trait_key][sex_key])

            if curves:
                # The age grid is taken from the first contributing replicate.
                ages = np.array(curves[0]["ages"])
                stacked = np.array([curve["values"] for curve in curves])
                avg_curve = stacked.mean(axis=0)
                avg_n = np.mean([curve["n"] for curve in curves])
                avg_prev = np.mean([curve["prevalence"] for curve in curves])

                legend_text = f"{sex_name} ({avg_prev:.1%}, n={int(avg_n)})"
                ax.plot(ages, avg_curve, color=line_color, linewidth=1.2, label=legend_text)

        ax.set_title(f"Trait {trait_num}")
        ax.set_xlabel("Age")
        ax.legend(loc="lower right", fontsize=9)

    axes[0].set_ylabel("Cumulative Incidence")
    finalize_plot(output_path, scenario=scenario)

plot_cumulative_incidence_by_sex_generation

plot_cumulative_incidence_by_sex_generation(all_stats, output_path, scenario='')

Plot cumulative incidence by sex and generation, from pre-computed stats.

Source code in simace/plotting/plot_distributions.py
def plot_cumulative_incidence_by_sex_generation(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Plot cumulative incidence by sex and generation, from pre-computed stats.

    The grid has one row per trait and one column per generation key found
    under ``trait1`` of the first replicate that carries
    ``cumulative_incidence_by_sex_generation``. Columns are ordered by the
    numeric generation, so ``gen10`` follows ``gen9`` rather than ``gen1``.

    Args:
        all_stats: Per-replicate stats dicts.
        output_path: Destination handed to ``finalize_plot`` or
            ``save_placeholder_plot``.
        scenario: Scenario label forwarded to ``finalize_plot``.
    """
    stats_with_data = [s for s in all_stats if s.get("cumulative_incidence_by_sex_generation")]
    if not stats_with_data:
        logger.warning("Skipping cumulative_incidence_by_sex_generation: no data in stats")
        save_placeholder_plot(output_path, "No sex/generation incidence data")
        return

    def _generation_order(gen_key: str) -> tuple[int, int, str]:
        # Numeric suffixes sort by value; anything else sorts after, by name.
        suffix = gen_key.replace("gen", "")
        if suffix.isdecimal():
            return (0, int(suffix), gen_key)
        return (1, 0, gen_key)

    # Discover generation keys from first rep's first trait
    first_trait = stats_with_data[0]["cumulative_incidence_by_sex_generation"].get("trait1", {})
    gen_keys = sorted(first_trait.keys(), key=_generation_order)
    if not gen_keys:
        save_placeholder_plot(output_path, "No generations")
        return

    traits = [1, 2]

    _fig, axes = plt.subplots(
        len(traits),
        len(gen_keys),
        figsize=(5 * len(gen_keys), 4 * len(traits)),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    for col, gk in enumerate(gen_keys):
        gen_num = gk.replace("gen", "")

        for row, trait_num in enumerate(traits):
            ax = axes[row, col]
            key = f"trait{trait_num}"

            for sex_label, display, color in [
                ("female", "Female", COLOR_FEMALE),
                ("male", "Male", COLOR_MALE),
            ]:
                # Only replicates that report this trait/generation/sex contribute.
                rep_data = [
                    s["cumulative_incidence_by_sex_generation"][key][gk][sex_label]
                    for s in stats_with_data
                    if sex_label in s["cumulative_incidence_by_sex_generation"].get(key, {}).get(gk, {})
                ]
                if not rep_data:
                    continue

                ages = np.array(rep_data[0]["ages"])
                all_values = np.array([d["values"] for d in rep_data])
                mean_values = all_values.mean(axis=0)
                mean_n = np.mean([d["n"] for d in rep_data])
                mean_prev = np.mean([d["prevalence"] for d in rep_data])

                ax.plot(
                    ages, mean_values, color=color, linewidth=1.2, label=f"{display} ({mean_prev:.1%}, n={int(mean_n)})"
                )

            # Label only the outer edges of the shared-axis grid.
            if row == 0:
                ax.set_title(f"Gen {gen_num}", fontsize=12)
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nCumulative Incidence")
            if row == len(traits) - 1:
                ax.set_xlabel("Age")
            if col == len(gen_keys) - 1:
                ax.legend(loc="lower right", fontsize=8)

    finalize_plot(output_path, scenario=scenario)

plot_censoring_windows

plot_censoring_windows(all_stats, output_path, scenario='', gen_censoring=None)

Plot per-generation censoring windows, mean +/- band across reps.

Source code in simace/plotting/plot_distributions.py
def plot_censoring_windows(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    gen_censoring: dict[int, list[float]] | None = None,
) -> None:
    """Plot per-generation censoring windows, mean +/- band across reps.

    Draws a grid with one row per trait (1 and 2) and one column per
    phenotyped generation.  Each panel shows the replicate-mean true and
    observed cumulative incidence curves, a min/max envelope across the
    contributing replicates when there is more than one, and a text box with
    the mean affected / censoring percentages.

    Args:
        all_stats: Per-replicate stats dicts.  Replicates whose ``"censoring"``
            entry is missing or ``None`` are ignored.
        output_path: Destination handed to ``finalize_plot`` (or to
            ``save_placeholder_plot`` when there is nothing to draw).
        scenario: Scenario name forwarded to ``finalize_plot``.
        gen_censoring: Optional mapping from generation number to a
            ``[lower, upper]`` window that is appended to the column title.
    """
    from matplotlib.lines import Line2D

    stats_with_censoring = [s for s in all_stats if s.get("censoring") is not None]
    if not stats_with_censoring:
        logger.warning("Skipping censoring_windows plot: no censoring data in stats")
        save_placeholder_plot(output_path, "No censoring data")
        return

    # The age grid is read from the first replicate and reused for every curve.
    ages = np.array(stats_with_censoring[0]["censoring"]["censoring_ages"])

    def _has_phenotyped(stats: dict[str, Any], gen_key: str) -> bool:
        """Return True when *stats* reports n > 0 individuals for *gen_key*."""
        return stats["censoring"]["generations"].get(gen_key, {}).get("n", 0) > 0

    # Collect generation keys (e.g. "gen0", "gen1", ...) from every replicate,
    # keep those phenotyped in at least one replicate, and order them by their
    # numeric suffix so that "gen10" follows "gen9" instead of "gen1".
    candidate_keys = {gk for s in stats_with_censoring for gk in s["censoring"]["generations"]}
    gen_keys = sorted(
        (gk for gk in candidate_keys if any(_has_phenotyped(s, gk) for s in stats_with_censoring)),
        key=lambda gk: int(gk.replace("gen", "")),
    )
    if not gen_keys:
        save_placeholder_plot(output_path, "No phenotyped generations")
        return
    if gen_censoring is None:
        gen_censoring = {}

    gen_labels = []
    for gk in gen_keys:
        gen_num = int(gk.replace("gen", ""))
        win = gen_censoring.get(gen_num)
        if win is not None:
            gen_labels.append(f"Gen {gen_num}\n[{win[0]}, {win[1]}]")
        else:
            gen_labels.append(f"Gen {gen_num}")

    traits = [1, 2]

    _fig, axes = plt.subplots(
        len(traits),
        len(gen_keys),
        figsize=(5 * len(gen_keys), 4 * len(traits)),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    for col, (gen_key, label) in enumerate(zip(gen_keys, gen_labels, strict=True)):
        # Non-empty by construction: gen_keys only holds generations that at
        # least one replicate reports with n > 0.
        gen_data = [s["censoring"]["generations"][gen_key] for s in stats_with_censoring if _has_phenotyped(s, gen_key)]

        for row, trait_num in enumerate(traits):
            ax = axes[row, col]
            key = f"trait{trait_num}"

            all_true = np.array([g[key]["true_incidence"] for g in gen_data])
            all_obs = np.array([g[key]["observed_incidence"] for g in gen_data])

            mean_true = all_true.mean(axis=0)
            mean_obs = all_obs.mean(axis=0)

            ax.plot(ages, mean_true, color=COLOR_TRUE, alpha=0.7, linewidth=1.2, label="True")
            ax.fill_between(ages, mean_true, alpha=0.15, color=COLOR_TRUE)
            ax.plot(ages, mean_obs, color=COLOR_OBSERVED, linewidth=1.2, label="Observed")
            ax.fill_between(ages, mean_obs, alpha=0.2, color=COLOR_OBSERVED)

            # A min/max envelope only carries information when several
            # replicates contribute to this generation.
            if len(gen_data) > 1:
                ax.fill_between(
                    ages,
                    all_true.min(axis=0),
                    all_true.max(axis=0),
                    alpha=0.08,
                    color=COLOR_TRUE,
                )
                ax.fill_between(
                    ages,
                    all_obs.min(axis=0),
                    all_obs.max(axis=0),
                    alpha=0.08,
                    color=COLOR_OBSERVED,
                )

            # Annotation stats: replicate means, scaled from fraction to percent.
            pct_affected = np.mean([g[key]["pct_affected"] for g in gen_data]) * 100
            left_cens = np.mean([g[key]["left_censored"] for g in gen_data]) * 100
            right_cens = np.mean([g[key]["right_censored"] for g in gen_data]) * 100
            death_cens = np.mean([g[key]["death_censored"] for g in gen_data]) * 100

            ax.text(
                0.03,
                0.95,
                f"Affected: {pct_affected:.1f}%\n"
                f"Left-cens: {left_cens:.1f}%\n"
                f"Right-cens: {right_cens:.1f}%\n"
                f"Death-cens: {death_cens:.1f}%",
                transform=ax.transAxes,
                ha="left",
                va="top",
                fontsize=9,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
            )

            if row == 0:
                ax.set_title(label, fontsize=12)
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nCumulative Incidence")
            if row == len(traits) - 1:
                ax.set_xlabel("Age")

    # Proxy handles give a single legend (top-right panel) that lists only the
    # two curves and not the shaded fills.
    legend_elements = [
        Line2D([0], [0], color=COLOR_TRUE, linewidth=1.2, alpha=0.7, label="True"),
        Line2D([0], [0], color=COLOR_OBSERVED, linewidth=1.2, label="Observed"),
    ]
    axes[0, -1].legend(handles=legend_elements, loc="lower right", fontsize=9)
    finalize_plot(output_path, scenario=scenario)

plot_family_structure

plot_family_structure(all_stats, output_path, scenario='')

Plot offspring and mate count distributions, averaged across replicates.

Source code in simace/plotting/plot_distributions.py
def plot_family_structure(all_stats: list[dict], output_path: str | Path, scenario: str = "") -> None:
    """Draw three bar-chart panels summarising family structure across replicates.

    Panels, left to right: offspring per couple, offspring per person (split
    by sex when sex-stratified data is present, pooled otherwise), and
    partners per parent by sex.  Every bar is the mean over the replicates
    that carry a ``"family_size"`` entry; a placeholder is saved when none do.
    """
    replicates = [stats.get("family_size", {}) for stats in all_stats if "family_size" in stats]
    if not replicates:
        save_placeholder_plot(output_path, "Family Structure\nNo family_size data")
        return

    def _mean_fractions(bins, pick):
        """Average the per-bin values of the dict ``pick(rep)`` over replicates."""
        return np.array([[pick(rep).get(b, 0) for b in bins] for rep in replicates]).mean(axis=0)

    def _corner_note(axis, message):
        """Write *message* in a white rounded box in the panel's top-right corner."""
        axis.text(
            0.97,
            0.95,
            message,
            transform=axis.transAxes,
            ha="right",
            va="top",
            fontsize=11,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        )

    _fig, (ax_couple, ax_person, ax_partner) = plt.subplots(1, 3, figsize=(15, 5))

    # --- Left panel: offspring per mating ---
    couple_bins = ["1", "2", "3", "4+"]
    couple_frac = _mean_fractions(couple_bins, lambda rep: rep.get("size_dist", {}))
    ax_couple.bar(couple_bins, couple_frac, color=COLOR_OBSERVED, edgecolor="white")
    for slot, frac in enumerate(couple_frac):
        ax_couple.text(slot, frac + 0.005, f"{frac:.1%}", ha="center", va="bottom", fontsize=10)
    avg_offspring = np.mean([rep.get("mean", 0) for rep in replicates])
    ax_couple.set_title("Offspring per Couple")
    ax_couple.set_xlabel("Number of offspring")
    ax_couple.set_ylabel("Fraction of couples")
    _corner_note(ax_couple, f"mean = {avg_offspring:.2f}")

    # --- Middle panel: offspring per person ---
    person_bins = ["0", "1", "2", "3", "4+"]
    slots = np.arange(len(person_bins))
    pair_width = 0.35
    # Sex-stratified bars are preferred; the pooled distribution is the fallback.
    if any(rep.get("person_offspring_dist_by_sex") for rep in replicates):
        female_frac = _mean_fractions(
            person_bins, lambda rep: rep.get("person_offspring_dist_by_sex", {}).get("female", {})
        )
        male_frac = _mean_fractions(
            person_bins, lambda rep: rep.get("person_offspring_dist_by_sex", {}).get("male", {})
        )
        ax_person.bar(slots - pair_width / 2, female_frac, pair_width, label="Female", color=COLOR_FEMALE, edgecolor="white")
        ax_person.bar(slots + pair_width / 2, male_frac, pair_width, label="Male", color=COLOR_MALE, edgecolor="white")
        for slot, f_val, m_val in zip(slots, female_frac, male_frac):
            # Bars too short to read get no label at all.
            if f_val < 0.005 and m_val < 0.005:
                continue
            if abs(f_val - m_val) < 0.01:
                # Near-identical bars share one centred label showing their average.
                ax_person.text(
                    slot, max(f_val, m_val) + 0.008, f"{(f_val + m_val) / 2:.0%}", ha="center", va="bottom", fontsize=9
                )
                continue
            if f_val > 0.005:
                ax_person.text(slot - pair_width / 2, f_val + 0.005, f"{f_val:.0%}", ha="center", va="bottom", fontsize=8)
            if m_val > 0.005:
                ax_person.text(slot + pair_width / 2, m_val + 0.005, f"{m_val:.0%}", ha="center", va="bottom", fontsize=8)
        ax_person.legend(fontsize=9)
    else:
        pooled_frac = _mean_fractions(person_bins, lambda rep: rep.get("person_offspring_dist", {}))
        ax_person.bar(slots, pooled_frac, 0.6, color=COLOR_OBSERVED, edgecolor="white")
        for slot, frac in enumerate(pooled_frac):
            ax_person.text(slot, frac + 0.005, f"{frac:.1%}", ha="center", va="bottom", fontsize=10)
    ax_person.set_xticks(slots)
    ax_person.set_xticklabels(person_bins)
    ax_person.set_title("Offspring per Person")
    ax_person.set_xlabel("Number of offspring")
    ax_person.set_ylabel("Fraction of individuals")

    # --- Right panel: partners per parent ---
    mate_stats = [rep.get("mates_by_sex", {}) for rep in replicates]

    def _mate_mean(field):
        """Replicate mean of one ``mates_by_sex`` field (0 where absent)."""
        return np.mean([entry.get(field, 0) for entry in mate_stats])

    female_heights = [_mate_mean("female_1"), _mate_mean("female_2+")]
    male_heights = [_mate_mean("male_1"), _mate_mean("male_2+")]
    positions = np.arange(2)
    partner_width = 0.35
    female_bars = ax_partner.bar(
        positions - partner_width / 2, female_heights, partner_width, label="Female", color=COLOR_FEMALE, edgecolor="white"
    )
    male_bars = ax_partner.bar(
        positions + partner_width / 2, male_heights, partner_width, label="Male", color=COLOR_MALE, edgecolor="white"
    )
    for rect in [*female_bars, *male_bars]:
        height = rect.get_height()
        ax_partner.text(
            rect.get_x() + rect.get_width() / 2,
            height + 0.005,
            f"{height:.1%}",
            ha="center",
            va="bottom",
            fontsize=10,
        )
    ax_partner.set_xticks(positions)
    ax_partner.set_xticklabels(["1 partner", "2+ partners"])
    ax_partner.set_ylabel("Fraction of parents")
    ax_partner.set_title("Partners per Parent")
    ax_partner.legend(fontsize=10)
    female_avg = _mate_mean("female_mean")
    male_avg = _mate_mean("male_mean")
    _corner_note(ax_partner, f"mean F={female_avg:.2f}, M={male_avg:.2f}")

    finalize_plot(output_path, scenario=scenario)

plot_liability

simace.plotting.plot_liability

Liability-related phenotype plots.

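Contains: plot_liability_joint, plot_liability_joint_affected, plot_liability_joint_affected_t2, plot_liability_violin, plot_liability_violin_by_generation, plot_liability_violin_by_sex_generation, plot_liability_components_by_generation, plot_joint_affection, plot_censoring_confusion, plot_censoring_cascade, plot_mate_correlation.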

plot_liability_joint

plot_liability_joint(df_samples, output_path, scenario='', subsample_note='')

2x2 grid of jointplots: Liability, A, C, E (trait 1 vs trait 2).

Source code in simace/plotting/plot_liability.py
def plot_liability_joint(
    df_samples: pd.DataFrame,
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Draw the 2x2 jointplot grid (Liability, A, C, E; trait 1 vs trait 2).

    Thin wrapper that delegates to ``_plot_joint_grid`` with affected-status
    colouring switched off.
    """
    grid_options = {"color_by_affected": False, "subsample_note": subsample_note}
    _plot_joint_grid(df_samples, output_path, scenario, **grid_options)

plot_liability_joint_affected

plot_liability_joint_affected(df_samples, output_path, scenario='', subsample_note='')

2x2 grid of jointplots colored by affected status (trait 1).

Source code in simace/plotting/plot_liability.py
def plot_liability_joint_affected(
    df_samples: pd.DataFrame,
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Draw the 2x2 jointplot grid with points coloured by trait 1 affected status.

    Thin wrapper that delegates to ``_plot_joint_grid`` with colouring enabled
    and ``affected_trait=1``.
    """
    grid_options = {
        "color_by_affected": True,
        "affected_trait": 1,
        "subsample_note": subsample_note,
    }
    _plot_joint_grid(df_samples, output_path, scenario, **grid_options)

plot_liability_joint_affected_t2

plot_liability_joint_affected_t2(df_samples, output_path, scenario='', subsample_note='')

2x2 grid of jointplots colored by affected status (trait 2).

Source code in simace/plotting/plot_liability.py
def plot_liability_joint_affected_t2(
    df_samples: pd.DataFrame,
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Draw the 2x2 jointplot grid with points coloured by trait 2 affected status.

    Thin wrapper that delegates to ``_plot_joint_grid`` with colouring enabled
    and ``affected_trait=2``.
    """
    grid_options = {
        "color_by_affected": True,
        "affected_trait": 2,
        "subsample_note": subsample_note,
    }
    _plot_joint_grid(df_samples, output_path, scenario, **grid_options)

plot_liability_violin

plot_liability_violin(df_samples, all_stats, output_path, scenario='', subsample_note='')

Split violin plot of liability by trait, split on affected status.

Source code in simace/plotting/plot_liability.py
def plot_liability_violin(
    df_samples: pd.DataFrame,
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Draw one split violin per trait, halves separated by affected status.

    Tick labels carry the prevalence averaged over the replicate stats in
    *all_stats*; the violins and the diamond mean markers are computed from
    the ``liability{N}`` / ``affected{N}`` columns of *df_samples*.
    """
    from matplotlib.patches import Patch

    trait_ids = (1, 2)
    # Prevalence comes from the pre-computed per-replicate stats, not from df_samples.
    mean_prev = {t: np.mean([s["prevalence"][f"trait{t}"] for s in all_stats]) for t in trait_ids}

    _fig, ax = plt.subplots(figsize=(8, 6))

    # First pass: draw the violins and remember each trait's arrays for the
    # annotation pass below.
    trait_arrays = {}
    for slot, t in enumerate(trait_ids):
        liability = df_samples[f"liability{t}"].values
        is_affected = df_samples[f"affected{t}"].values.astype(bool)
        trait_arrays[t] = (liability, is_affected)
        draw_split_violin(ax, liability[~is_affected], liability[is_affected], pos=slot)

    ax.set_xticks(list(range(len(trait_ids))))
    ax.set_xticklabels([f"Trait {t}\n{mean_prev[t]:.1%}" for t in trait_ids])
    ax.set_ylabel("Liability")

    legend_patches = [
        Patch(facecolor=face, edgecolor="black", linewidth=0.8, label=text)
        for face, text in ((COLOR_UNAFFECTED, "0"), (COLOR_AFFECTED, "1"))
    ]
    ax.legend(handles=legend_patches, title="Affected")
    ax.set_title("Liability by Affected Status")

    # Second pass: mark the mean liability of each non-empty group, affected
    # to the right of the violin centre and unaffected to the left.
    for slot, t in enumerate(trait_ids):
        liability, is_affected = trait_arrays[t]
        sides = (
            (is_affected, 0.05, 0.12, "left"),
            (~is_affected, -0.05, -0.12, "right"),
        )
        for members, marker_dx, text_dx, align in sides:
            if not members.any():
                continue
            group_mean = liability[members].mean()
            ax.plot(slot + marker_dx, group_mean, "D", color="black", markersize=5, zorder=5)
            ax.text(
                slot + text_dx,
                group_mean,
                f"\u03bc={group_mean:.2f}",
                ha=align,
                va="center",
                fontsize=9,
                fontweight="bold",
            )

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_liability_violin_by_generation

plot_liability_violin_by_generation(df_samples, all_stats, output_path, scenario='', subsample_note='')

Split violin of liability by affected status, one column per generation.

Source code in simace/plotting/plot_liability.py
def plot_liability_violin_by_generation(
    df_samples: pd.DataFrame,
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Split violin of liability by affected status, one column per generation.

    Layout: 2 rows (traits 1 and 2) x N columns (generations found in
    ``df_samples["generation"]``).  Each cell's x tick shows the prevalence
    observed in that cell, and diamonds mark the mean liability of the
    affected (right) and unaffected (left) groups.

    Args:
        df_samples: Sample frame with ``generation``, ``liability{N}`` and
            ``affected{N}`` columns.
        all_stats: Not read by this function; kept in the signature for
            existing callers.
        output_path: Destination handed to ``finalize_plot`` (or to
            ``save_placeholder_plot`` when there is nothing to draw).
        scenario: Scenario name forwarded to ``finalize_plot``.
        subsample_note: Note forwarded to ``finalize_plot``.
    """
    if "generation" not in df_samples.columns:
        save_placeholder_plot(output_path, "No generation data")
        return

    gens = sorted(df_samples["generation"].unique())
    n_gens = len(gens)
    # An empty frame yields zero generations; plt.subplots rejects a grid with
    # zero columns, so fall back to the placeholder instead of raising.
    if n_gens == 0:
        save_placeholder_plot(output_path, "No generation data")
        return

    from matplotlib.patches import Patch

    _fig, axes = plt.subplots(2, n_gens, figsize=(4 * n_gens, 8), squeeze=False)

    for row, trait_num in enumerate([1, 2]):
        liab_col = f"liability{trait_num}"
        aff_col = f"affected{trait_num}"

        for col, gen in enumerate(gens):
            ax = axes[row, col]
            gen_mask = df_samples["generation"] == gen
            df_gen = df_samples.loc[gen_mask]

            liab = df_gen[liab_col].values
            aff = df_gen[aff_col].values.astype(bool)

            # Cells with fewer than two samples are left empty.
            if len(liab) > 1:
                draw_split_violin(ax, liab[~aff], liab[aff], pos=0)
                obs_prev = aff.mean()
                ax.set_xticks([0])
                ax.set_xticklabels([f"{obs_prev:.1%}"])
                # One legend for the whole figure, in the top-right cell.
                if row == 0 and col == n_gens - 1:
                    ax.legend(
                        handles=[
                            Patch(facecolor=COLOR_UNAFFECTED, edgecolor="black", linewidth=0.8, label="0"),
                            Patch(facecolor=COLOR_AFFECTED, edgecolor="black", linewidth=0.8, label="1"),
                        ],
                        title="Affected",
                        fontsize=8,
                    )

                # Annotate group means (skipped for an empty group).
                if aff.any():
                    mu = liab[aff].mean()
                    ax.plot(0.05, mu, "D", color="black", markersize=5, zorder=5)
                    ax.text(0.12, mu, f"\u03bc={mu:.2f}", ha="left", va="center", fontsize=8, fontweight="bold")
                if (~aff).any():
                    mu = liab[~aff].mean()
                    ax.plot(-0.05, mu, "D", color="black", markersize=5, zorder=5)
                    ax.text(-0.12, mu, f"\u03bc={mu:.2f}", ha="right", va="center", fontsize=8, fontweight="bold")

            if row == 0:
                label = f"Gen {gen}"
                if col == 0:
                    label += " (oldest)"
                elif col == n_gens - 1:
                    label += " (youngest)"
                ax.set_title(label, fontsize=11)
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nLiability", fontsize=10)
            else:
                ax.set_ylabel("")

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_liability_violin_by_sex_generation

plot_liability_violin_by_sex_generation(df_samples, all_stats, output_path, scenario='', subsample_note='')

Split violin by affected status with side-by-side F|M panels per generation.

Layout: 2 rows (traits) x N cols (generations). Each cell has two side-by-side sub-violins at x=-0.3 (female) and x=+0.3 (male).

Source code in simace/plotting/plot_liability.py
def plot_liability_violin_by_sex_generation(
    df_samples: pd.DataFrame,
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Split violin by affected status with side-by-side F|M panels per generation.

    Layout: 2 rows (traits) x N cols (generations). Each cell has two
    side-by-side sub-violins at x=-0.3 (rows with ``sex == 0``, labelled F)
    and x=+0.3 (rows with ``sex == 1``, labelled M).  Tick labels carry the
    prevalence observed in each sex/generation cell; diamonds mark the group
    means.

    Args:
        df_samples: Sample frame with ``generation``, ``sex``,
            ``liability{N}`` and ``affected{N}`` columns.
        all_stats: Not read by this function; kept in the signature for
            existing callers.
        output_path: Destination handed to ``finalize_plot`` (or to
            ``save_placeholder_plot`` when there is nothing to draw).
        scenario: Scenario name forwarded to ``finalize_plot``.
        subsample_note: Note forwarded to ``finalize_plot``.
    """
    if "generation" not in df_samples.columns:
        save_placeholder_plot(output_path, "No generation data")
        return

    gens = sorted(df_samples["generation"].unique())
    n_gens = len(gens)
    # An empty frame yields zero generations; plt.subplots rejects a grid with
    # zero columns, so fall back to the placeholder instead of raising.
    if n_gens == 0:
        save_placeholder_plot(output_path, "No generation data")
        return

    from matplotlib.patches import Patch

    sex_arr = df_samples["sex"].values

    _fig, axes = plt.subplots(2, n_gens, figsize=(5 * n_gens, 8), squeeze=False)

    for row, trait_num in enumerate([1, 2]):
        liab_col = f"liability{trait_num}"
        aff_col = f"affected{trait_num}"

        for col, gen in enumerate(gens):
            ax = axes[row, col]
            gen_mask = df_samples["generation"] == gen
            # Prevalence text per sex; a sex with too few samples keeps an
            # empty second line on its tick label.
            sex_prev: dict[str, str] = {}

            for sex_val, sex_label, pos in [
                (0, "F", -0.3),
                (1, "M", 0.3),
            ]:
                mask = gen_mask & (sex_arr == sex_val)
                df_sub = df_samples.loc[mask]
                liab = df_sub[liab_col].values
                aff = df_sub[aff_col].values.astype(bool)

                # Sub-violins with fewer than two samples are skipped.
                if len(liab) > 1:
                    draw_split_violin(ax, liab[~aff], liab[aff], pos=pos, width=0.5)

                    # Mark group means (skipped for an empty group).
                    if aff.any():
                        mu = liab[aff].mean()
                        ax.plot(pos + 0.03, mu, "D", color="black", markersize=4, zorder=5)
                    if (~aff).any():
                        mu = liab[~aff].mean()
                        ax.plot(pos - 0.03, mu, "D", color="black", markersize=4, zorder=5)

                    sex_prev[sex_label] = f"{aff.mean():.0%}"

            ax.set_xticks([-0.3, 0.3])
            ax.set_xticklabels(
                [f"F\n{sex_prev.get('F', '')}", f"M\n{sex_prev.get('M', '')}"],
                fontsize=8,
            )

            if row == 0:
                label = f"Gen {gen}"
                if col == 0:
                    label += " (oldest)"
                elif col == n_gens - 1:
                    label += " (youngest)"
                ax.set_title(label, fontsize=11)
            if col == 0:
                ax.set_ylabel(f"Trait {trait_num}\nLiability", fontsize=10)
            else:
                ax.set_ylabel("")

            # One legend for the whole figure, in the top-right cell.
            if row == 0 and col == n_gens - 1:
                ax.legend(
                    handles=[
                        Patch(facecolor=COLOR_UNAFFECTED, edgecolor="black", linewidth=0.8, label="Unaffected"),
                        Patch(facecolor=COLOR_AFFECTED, edgecolor="black", linewidth=0.8, label="Affected"),
                    ],
                    fontsize=8,
                )

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_liability_components_by_generation

plot_liability_components_by_generation(df_samples, output_path, scenario='', subsample_note='')

Mean variance component by affected status across generations.

2x3 grid: rows = traits, columns = A, C, E. Each panel shows mean component value for affected (red), unaffected (grey), and overall (black) individuals per generation. Prevalence annotated on x-tick labels.

Source code in simace/plotting/plot_liability.py
def plot_liability_components_by_generation(
    df_samples: pd.DataFrame,
    output_path: str | Path,
    scenario: str = "",
    subsample_note: str = "",
) -> None:
    """Mean variance component by affected status across generations.

    2x3 grid: rows = traits, columns = A, C, E.  Each panel shows mean
    component value for affected, unaffected, and overall (black)
    individuals per generation.  Prevalence is annotated on the x-tick labels
    of the A column.

    Panels whose ``affected{N}`` or ``{component}{N}`` column is absent are
    left empty.

    Args:
        df_samples: Sample frame with ``generation``, ``affected{N}`` and
            ``A{N}`` / ``C{N}`` / ``E{N}`` columns.
        output_path: Destination handed to ``finalize_plot`` (or to
            ``save_placeholder_plot`` when there is nothing to draw).
        scenario: Scenario name forwarded to ``finalize_plot``.
        subsample_note: Note forwarded to ``finalize_plot``.
    """
    # Same guard as the sibling per-generation plots: without the column there
    # is nothing to group by.
    if "generation" not in df_samples.columns:
        save_placeholder_plot(output_path, "No generation data")
        return

    gens = sorted(df_samples["generation"].unique())
    n_gen = len(gens)
    if n_gen == 0:
        save_placeholder_plot(output_path, "No generation data")
        return

    # Slice each generation once; the slices are reused for every trait and component.
    gen_frames = [df_samples[df_samples["generation"] == gen] for gen in gens]

    components = ["A", "C", "E"]
    _fig, axes = plt.subplots(2, 3, figsize=(max(14, n_gen * 3), 8))

    for row, trait_num in enumerate([1, 2]):
        aff_col = f"affected{trait_num}"
        if aff_col not in df_samples.columns:
            continue

        # Prevalence per generation (shared across the three component columns).
        prev = [g[aff_col].mean() if len(g) > 0 else 0 for g in gen_frames]

        for col_idx, comp in enumerate(components):
            comp_col = f"{comp}{trait_num}"
            if comp_col not in df_samples.columns:
                continue

            ax = axes[row, col_idx]
            mean_aff, mean_unaff, mean_all = [], [], []
            for g in gen_frames:
                # Cast to bool so a 0/1 integer column yields a row mask; an
                # integer Series would be treated as column labels by g[...]
                # and bitwise-inverted by ~.
                is_aff = g[aff_col].astype(bool)
                aff = g[is_aff]
                unaff = g[~is_aff]
                mean_aff.append(aff[comp_col].mean() if len(aff) > 0 else float("nan"))
                mean_unaff.append(unaff[comp_col].mean() if len(unaff) > 0 else float("nan"))
                mean_all.append(g[comp_col].mean() if len(g) > 0 else float("nan"))

            ax.plot(gens, mean_aff, "o-", color=COLOR_AFFECTED, label="Affected", markersize=5)
            ax.plot(gens, mean_unaff, "s-", color=COLOR_UNAFFECTED, label="Unaffected", markersize=5)
            ax.plot(gens, mean_all, "D-", color="black", label="Overall", markersize=4, linewidth=1.0)
            ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")

            ax.set_xticks(gens)
            if col_idx == 0:
                # Show generation + prevalence in the A column for both traits.
                ax.set_xticklabels([f"{g}\n{p:.1%}" for g, p in zip(gens, prev)])
            else:
                ax.set_xticklabels([str(g) for g in gens])
            if row == 1:
                ax.set_xlabel("Generation")

            if col_idx == 0:
                ax.set_ylabel(f"Trait {trait_num}\nMean value")
            if row == 0:
                ax.set_title(comp, fontweight="bold")
                if col_idx == 2:
                    ax.legend(fontsize=8)

    finalize_plot(output_path, subsample_note=subsample_note, scenario=scenario)

plot_censoring_confusion

plot_censoring_confusion(all_stats, output_path, scenario='')

Per-trait 2x2 confusion matrix: true affected vs. observed affected.

Uses pre-computed censoring_confusion stats from full (non-subsampled) data.

Source code in simace/plotting/plot_liability.py
def plot_censoring_confusion(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Draw one 2x2 true-vs-observed affected heatmap per trait.

    Cell counts come from each replicate's ``censoring_confusion`` entry and
    are averaged across replicates before being turned into proportions and
    the sensitivity / specificity / PPV / NPV shown in the panel title.
    """
    replicates = [s for s in all_stats if s.get("censoring_confusion")]
    if not replicates:
        save_placeholder_plot(output_path, "No censoring confusion data")
        return

    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or NaN unless the denominator is positive."""
        if denominator > 0:
            return numerator / denominator
        return float("nan")

    _fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    last_col = len(axes) - 1

    for col, trait in enumerate([1, 2]):
        ax = axes[col]
        key = f"trait{trait}"

        per_rep = [s["censoring_confusion"][key] for s in replicates if key in s["censoring_confusion"]]
        if not per_rep:
            ax.text(0.5, 0.5, f"No data for trait {trait}", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"Trait {trait}")
            continue

        # Replicate-averaged cell counts and sample size.
        tp, fn, fp, tn, n = (np.mean([d[field] for d in per_rep]) for field in ("tp", "fn", "fp", "tn", "n"))

        # Rows are the true status, columns the observed status.
        counts = np.array([[tp, fn], [fp, tn]])
        props = counts / n

        sensitivity = _safe_ratio(tp, tp + fn)
        specificity = _safe_ratio(tn, tn + fp)
        ppv = _safe_ratio(tp, tp + fp)
        npv = _safe_ratio(tn, tn + fn)

        # Only the right-most panel carries the colour bar.
        show_cbar = col == last_col
        cbar_options = {"label": "Proportion"} if show_cbar else {}
        sns.heatmap(
            props,
            annot=False,
            cmap=HEATMAP_CMAP,
            ax=ax,
            xticklabels=["Observed Yes", "Observed No"],
            yticklabels=["True Yes", "True No"],
            vmin=0,
            vmax=1,
            cbar=show_cbar,
            cbar_kws=cbar_options,
        )
        annotate_heatmap(ax, props, counts)

        summary = f"Sens: {sensitivity:.3f}   Spec: {specificity:.3f}   PPV: {ppv:.3f}   NPV: {npv:.3f}   n = {int(n):,}"
        ax.set_title(f"Trait {trait}\n{summary}", fontsize=11)

    finalize_plot(output_path, scenario=scenario)

plot_censoring_cascade

plot_censoring_cascade(all_stats, output_path, scenario='')

Per-trait stacked bar chart decomposing true cases by censoring fate per generation.

Uses pre-computed censoring_cascade stats from full (non-subsampled) data.

Source code in simace/plotting/plot_liability.py
def plot_censoring_cascade(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Per-trait stacked bar chart decomposing true cases by censoring fate per generation.

    Uses pre-computed censoring_cascade stats from full (non-subsampled) data.

    One panel per trait.  Each generation gets a bar stacked, bottom to top,
    from the replicate-mean counts of observed, death-censored,
    right-censored and left-truncated cases.  The tick label carries the
    generation's window (taken from the first replicate that has the
    generation) and its sensitivity (observed / true affected); the panel
    title carries the overall sensitivity (observed / sum of all four
    segments).

    Args:
        all_stats: Per-replicate stats dicts; replicates without a truthy
            ``"censoring_cascade"`` entry are ignored.
        output_path: Destination handed to ``finalize_plot`` (or to
            ``save_placeholder_plot`` when there is nothing to draw).
        scenario: Scenario name forwarded to ``finalize_plot``.
    """
    stats_with_data = [s for s in all_stats if s.get("censoring_cascade")]
    if not stats_with_data:
        save_placeholder_plot(output_path, "No censoring cascade data")
        return

    from simace.plotting.plot_style import CENSORING_COLORS

    # (stats field, bar colour, legend label), listed bottom-to-top in the stack.
    segments = [
        ("observed", CENSORING_COLORS["observed"], "Observed (TP)"),
        ("death_censored", CENSORING_COLORS["death"], "Death-censored"),
        ("right_censored", CENSORING_COLORS["right"], "Right-censored"),
        ("left_truncated", CENSORING_COLORS["left"], "Left-truncated"),
    ]

    def _gen_order(gen_key: str) -> tuple[int, int, str]:
        """Sort "genN" keys by N; keys without a numeric suffix go last, alphabetically."""
        suffix = gen_key.replace("gen", "")
        if suffix.isdigit():
            return (0, int(suffix), gen_key)
        return (1, 0, gen_key)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    for col, trait in enumerate([1, 2]):
        ax = axes[col]
        key = f"trait{trait}"

        rep_data = [s["censoring_cascade"][key] for s in stats_with_data if key in s["censoring_cascade"]]
        if not rep_data:
            ax.text(0.5, 0.5, f"No data for trait {trait}", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"Trait {trait}")
            continue

        # Generation keys come from the first replicate, ordered numerically so
        # that "gen10" follows "gen9" instead of "gen1".
        gen_keys = sorted(rep_data[0].keys(), key=_gen_order)
        if not gen_keys:
            ax.text(0.5, 0.5, "No generations", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"Trait {trait}")
            continue

        segment_counts: dict[str, list[float]] = {field: [] for field, _color, _label in segments}
        sensitivities = []
        x_labels = []

        for gk in gen_keys:
            gen_data = [r[gk] for r in rep_data if gk in r]
            if not gen_data:
                continue

            for field, _color, _label in segments:
                segment_counts[field].append(np.mean([d[field] for d in gen_data]))
            n_obs = segment_counts["observed"][-1]
            n_true = np.mean([d["true_affected"] for d in gen_data])
            window = gen_data[0]["window"]

            sensitivities.append(n_obs / n_true if n_true > 0 else float("nan"))
            gen_num = gk.replace("gen", "")
            x_labels.append(f"Gen {gen_num}\n[{window[0]:.0f}, {window[1]:.0f}]")

        n_bars = len(x_labels)
        x = np.arange(n_bars)
        bar_width = 0.6

        # Stack the segments; `bottom` ends up holding each bar's total height.
        bottom = np.zeros(n_bars)
        heights = []
        for field, color, label in segments:
            segment = np.array(segment_counts[field], dtype=float)
            ax.bar(x, segment, bar_width, bottom=bottom, color=color, label=label)
            bottom += segment
            heights.append(segment)

        # Annotate segments (skip if < 3% of bar height)
        for i in range(n_bars):
            total = bottom[i]
            if total == 0:
                continue
            cum = 0.0
            for segment in heights:
                count = segment[i]
                if count > 0 and count / total >= 0.03:
                    mid = cum + count / 2
                    ax.text(x[i], mid, f"{int(count)}", ha="center", va="center", fontsize=8, fontweight="bold")
                cum += count

        # Fold sensitivity into x-axis tick labels (below bars, no overlap)
        for i, sens in enumerate(sensitivities):
            if not np.isnan(sens):
                x_labels[i] += f"\nsens={sens:.2f}"

        # Overall sensitivity: observed cases over the sum of all four segments.
        total_obs = sum(segment_counts["observed"])
        total_true = sum(sum(segment_counts[field]) for field, _color, _label in segments)
        overall_sens = total_obs / total_true if total_true > 0 else float("nan")
        ax.set_title(f"Trait {trait}  (overall sensitivity: {overall_sens:.3f})", fontsize=11)

        ax.set_xticks(x)
        ax.set_xticklabels(x_labels, fontsize=9)
        ax.set_ylabel("True affected count")

    # Shared legend above the subplots.  Take the handles from the first panel
    # that actually drew bars, so the legend survives a trait with no data.
    handles, labels = [], []
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break
    if handles:
        fig.legend(handles, labels, loc="upper center", ncol=4, fontsize=9, bbox_to_anchor=(0.5, 0.98))

    finalize_plot(output_path, scenario=scenario)

plot_joint_affection

plot_joint_affection(all_stats, output_path, scenario='')

2x2 heatmap of joint affection status (trait1 x trait2).

Uses pre-computed joint_affection and cross_trait_tetrachoric stats.

Source code in simace/plotting/plot_liability.py
def plot_joint_affection(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """2x2 heatmap of joint affection status (trait1 x trait2).

    Uses pre-computed joint_affection and cross_trait_tetrachoric stats.

    The heatmap rows carry trait 1 status and the columns trait 2 status, so
    the cell layout is::

                          Trait 2 affected   Trait 2 unaffected
        Trait 1 affected        both            trait1_only
        Trait 1 unaffected   trait2_only          neither

    Args:
        all_stats: Per-replicate stats dicts; each must contain
            ``joint_affection`` with ``proportions`` and ``counts`` sub-dicts.
        output_path: Where to save the figure.
        scenario: Scenario name forwarded to ``finalize_plot``.
    """
    keys = ["both", "trait1_only", "trait2_only", "neither"]

    def _rep_mean_matrix(section: str) -> np.ndarray:
        """Average one joint_affection section across reps and lay it out 2x2."""
        avg = {k: np.mean([s["joint_affection"][section][k] for s in all_stats]) for k in keys}
        return np.array(
            [
                [avg["both"], avg["trait1_only"]],
                [avg["trait2_only"], avg["neither"]],
            ]
        )

    def _frailty_r(stat_key: str) -> list:
        """Collect the non-missing ``r`` values stored under *stat_key*."""
        # ``or {}`` tolerates a section that is present but null in the stats file.
        found = ((s.get(stat_key) or {}).get("r") for s in all_stats)
        return [r for r in found if r is not None]

    matrix = _rep_mean_matrix("proportions")
    count_matrix = _rep_mean_matrix("counts")

    _fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(
        matrix,
        annot=False,
        cmap=HEATMAP_CMAP,
        ax=ax,
        xticklabels=["Affected", "Unaffected"],
        yticklabels=["Affected", "Unaffected"],
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Proportion"},
    )
    annotate_heatmap(ax, matrix, count_matrix)

    # Build subtitle from whichever correlation stats are present
    label_parts = []

    # Cross-trait tetrachoric correlation from pre-computed stats
    r_tet_vals = []
    for s in all_stats:
        same_person = (s.get("cross_trait_tetrachoric") or {}).get("same_person") or {}
        r_tet = same_person.get("r")
        if r_tet is not None:
            r_tet_vals.append(r_tet)
    if r_tet_vals:
        label_parts.append(f"r_tet = {np.mean(r_tet_vals):.3f}")

    # Cross-trait frailty correlations (averaged across reps)
    uncens_vals = _frailty_r("frailty_cross_trait_uncensored")
    strat_vals = _frailty_r("frailty_cross_trait_stratified")
    naive_vals = _frailty_r("frailty_cross_trait")

    if uncens_vals:
        label_parts.append(f"r_frailty = {np.mean(uncens_vals):.3f}")
    if strat_vals:
        label_parts.append(f"stratified = {np.mean(strat_vals):.3f}")
    if naive_vals:
        label_parts.append(f"naive = {np.mean(naive_vals):.3f}")

    if not uncens_vals and not strat_vals and not naive_vals:
        label_parts.append("r_frailty: not computed")

    r_label = "  |  ".join(label_parts)

    # seaborn maps matrix rows to the y-axis and columns to the x-axis.
    # Row 0 / column 1 holds "trait1_only" (trait 1 affected, trait 2 not),
    # so the rows are trait 1 and the columns are trait 2.
    ax.set_xlabel("Trait 2")
    ax.set_ylabel("Trait 1")
    title = "Joint Affected Status"
    if r_label:
        title += f"\n{r_label}"
    ax.set_title(title, fontsize=14)
    finalize_plot(output_path, scenario=scenario)

plot_mate_correlation

plot_mate_correlation(all_stats, output_path, scenario='', params=None)

Plot 2x2 heatmap of empirical mate liability correlations with expected values.

Source code in simace/plotting/plot_liability.py
def plot_mate_correlation(
    all_stats: list[dict],
    output_path: str | Path,
    scenario: str = "",
    params: dict | None = None,
) -> None:
    """Draw expected and observed 2x2 mate liability correlation heatmaps side by side.

    Args:
        all_stats: Per-replicate stats dicts; replicates without a
            ``mate_correlation`` entry are ignored.
        output_path: Where to save the figure.
        scenario: Scenario name forwarded to ``finalize_plot``.
        params: Simulation parameters used to compute the expected matrix.
            When ``None`` the expected panel shows an all-zero matrix.
    """
    from simace.simulation.mate_correlation import expected_mate_corr_matrix

    # Collect one observed matrix per replicate that has mate-correlation data.
    observed = [
        np.array(entry["matrix"])
        for entry in (rep.get("mate_correlation") for rep in all_stats)
        if entry is not None
    ]
    if not observed:
        save_placeholder_plot(output_path, "No mate correlation data")
        return

    # NaN-aware mean across replicates.
    obs = np.nanmean(np.stack(observed), axis=0)

    # Expected matrix: zeros unless parameters were supplied.
    if params is None:
        exp = np.zeros((2, 2))
    else:
        am = params.get("assort_matrix", None)
        exp = expected_mate_corr_matrix(
            assort1=float(params.get("assort1", 0)),
            assort2=float(params.get("assort2", 0)),
            rA=float(params.get("rA", 0)),
            rC=float(params.get("rC", 0)),
            A1=float(params.get("A1", 0)),
            C1=param_as_float(params.get("C1", 0)),
            A2=float(params.get("A2", 0)),
            C2=param_as_float(params.get("C2", 0)),
            assort_matrix=am,
            rE=float(params.get("rE", 0)),
            E1=param_as_float(params.get("E1", 0)),
            E2=param_as_float(params.get("E2", 0)),
        )

    xlabels = ["Male trait 1", "Male trait 2"]
    ylabels = ["Female trait 1", "Female trait 2"]

    def _panel_kwargs() -> dict:
        """Return a fresh copy of the heatmap options both panels share."""
        return {
            "cmap": "RdBu_r",
            "vmin": -1,
            "vmax": 1,
            "center": 0,
            "annot": True,
            "fmt": ".2f",
            "annot_kws": {"fontsize": 16, "fontweight": "bold"},
            "square": True,
            "xticklabels": xlabels,
        }

    _fig, (ax_exp, ax_obs) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: expected (parametric); no colour bar, carries the y tick labels.
    sns.heatmap(exp, ax=ax_exp, cbar=False, yticklabels=ylabels, **_panel_kwargs())
    if params:
        a1 = float(params.get("assort1", 0))
        a2 = float(params.get("assort2", 0))
    else:
        a1 = 0
        a2 = 0
    exp_title = "Expected"
    if not (a1 == 0 and a2 == 0):
        exp_title += f"\nassort1={a1}, assort2={a2}"
    ax_exp.set_title(exp_title, fontsize=13)

    # Right panel: observed (realized); owns the colour bar, y tick labels suppressed.
    sns.heatmap(obs, ax=ax_obs, cbar_kws={"label": "Pearson r"}, yticklabels=[], **_panel_kwargs())
    ax_obs.set_title("Observed", fontsize=13)

    finalize_plot(output_path, scenario=scenario)

plot_pedigree_counts

simace.plotting.plot_pedigree_counts

Pedigree relationship pair counts diagram.

Draws a schematic multi-generational pedigree centred on a highlighted "Proband" individual. Each of the 10 relationship types is represented by colouring the border of the related individual's node and placing a labelled annotation box nearby. Mean pair counts (averaged across replicates) are shown inside each annotation box.

Family structure (4 generations):

- Gen 0: Great-grandparents (GGF + GGM)
- Gen 1: Grandfather + Grandmother | Great-uncle (sib of Grandfather)
- Gen 2: Father + Mother | Uncle (sib of Father) | GU-child
- Gen 3: Proband, MZ-twin, Full-sib, Pat-HS, Mat-HS, Cousin, 2nd-Cousin

plot_pedigree_relationship_counts

plot_pedigree_relationship_counts(all_stats, output_path, scenario='', stats_key='pair_counts', generations_label='', max_degree=2)

Draw a proband-centric pedigree diagram with relationship pair counts.

PARAMETER DESCRIPTION
all_stats

Per-replicate stats dicts.

TYPE: list[dict[str, Any]]

output_path

Where to save the figure.

TYPE: str | Path

scenario

Scenario name for the title.

TYPE: str DEFAULT: ''

stats_key

Key in stats dict to read pair counts from.

TYPE: str DEFAULT: 'pair_counts'

generations_label

Label appended to title (e.g. "G_ped = 6").

TYPE: str DEFAULT: ''

max_degree

Maximum kinship degree shown in the diagram.

TYPE: int DEFAULT: 2

Source code in simace/plotting/plot_pedigree_counts.py
def plot_pedigree_relationship_counts(
    all_stats: list[dict[str, Any]],
    output_path: str | Path,
    scenario: str = "",
    stats_key: str = "pair_counts",
    generations_label: str = "",
    max_degree: int = 2,
) -> None:
    """Render the proband-centred pedigree schematic annotated with mean pair counts.

    Pair counts are summed over the replicates that provide ``stats_key`` and
    divided by the number of such replicates. If no replicate provides them,
    a placeholder figure is written instead.

    Args:
        all_stats: Per-replicate stats dicts.
        output_path: Where to save the figure.
        scenario: Scenario name for the title.
        stats_key: Key in stats dict to read pair counts from.
        generations_label: Label appended to title (e.g. "G_ped = 6").
        max_degree: Maximum kinship degree shown in the diagram; below 5 the
            "2C" relationship is labelled as not computed.
    """
    output_path = Path(output_path)

    # Replicates that actually carry (non-empty) pair counts.
    rep_pair_counts = [pc for pc in (rep.get(stats_key) for rep in all_stats) if pc]
    if not rep_pair_counts:
        save_placeholder_plot(
            output_path,
            "No pair count data available\n(re-run stats to generate)",
        )
        return

    # Sum per relationship, then divide by the number of contributing replicates.
    n_reps = len(rep_pair_counts)
    totals: dict[str, float] = {}
    for pc in rep_pair_counts:
        for rel, cnt in pc.items():
            totals[rel] = totals.get(rel, 0) + cnt
    counts = {rel: total / n_reps for rel, total in totals.items()}

    def _not_computed(rel: str) -> bool:
        """Return True when *rel* is the "2C" entry and max_degree is below 5."""
        return max_degree < 5 and rel == "2C"

    # Relationship -> colour, and node -> colour of the relationship it represents.
    rel_colors = {rel: PEDIGREE_COLORS[rel] for rel in RELATIONSHIP_ORDER}
    node_rel_color = {RELATIONSHIP_NODES[rel]: rel_colors[rel] for rel in RELATIONSHIP_ORDER}

    # Canvas: fixed data limits, equal aspect, no axes decoration.
    _fig, ax = plt.subplots(figsize=(14, 8))
    ax.set_xlim(-3.0, 14.5)
    ax.set_ylim(-0.5, 11.5)
    ax.set_aspect("equal")
    ax.set_axis_off()

    title = "Pedigree Relationship Pair Counts"
    if generations_label:
        title += f"  ({generations_label})"
    ax.set_title(title, fontsize=14, fontweight="bold", pad=16)

    # Generation labels down the left margin.
    for gen_idx, gen_y in enumerate((10.0, 7.5, 5.0, 2.0)):
        ax.text(
            -2.7,
            gen_y,
            f"Gen {gen_idx}",
            fontsize=10,
            ha="center",
            va="center",
            fontstyle="italic",
            color="grey",
        )

    # Structural elements: marriage lines, descent lines, MZ twin bracket.
    for spouse_a, spouse_b in MARRIAGES:
        _draw_marriage(ax, spouse_a, spouse_b)
    for parent_spec, children, ls in DESCENTS:
        _draw_descent(ax, parent_spec, children, linestyle=ls)
    _draw_mz_bracket(ax, *MZ_TWIN_NODES)

    # Nodes: proband filled black; related nodes get a lightened fill with the
    # relationship colour as edge; everything else uses the default look.
    for node_name, (x, y, sex) in NODES.items():
        if node_name == PROBAND_NODE:
            _draw_node(ax, x, y, sex, fill="black", linewidth=1.4)
            continue
        if node_name not in node_rel_color:
            _draw_node(ax, x, y, sex)
            continue
        color = node_rel_color[node_name]
        # Blend each RGB channel towards white (60% white, 40% colour).
        light = tuple(0.6 + 0.4 * channel for channel in mcolors.to_rgb(color))
        _draw_node(
            ax,
            x,
            y,
            sex,
            fill=light,
            edgecolor=color,
            linewidth=1.4,
        )

    # Proband caption below its node.
    px, py = _nc(PROBAND_NODE)
    ax.text(
        px,
        py - NODE_RADIUS - 0.25,
        "Proband",
        fontsize=12,
        ha="center",
        va="top",
        fontweight="bold",
    )

    # Per-relationship text label, offset from the node it refers to.
    for rel_name in RELATIONSHIP_ORDER:
        nx, ny = _nc(RELATIONSHIP_NODES[rel_name])
        dx, dy, ha, va = LABEL_OFFSETS[rel_name]
        display = _SHORT_LABELS.get(rel_name, rel_name)
        if _not_computed(rel_name):
            label = f"{display}\nnot computed"
        else:
            label = f"{display}\n({counts.get(rel_name, 0):,.0f})"

        ax.text(
            nx + dx,
            ny + dy,
            label,
            fontsize=10,
            ha=ha,
            va=va,
            color=rel_colors[rel_name],
            fontweight="bold",
            zorder=5,
        )

    # Legend: one colour patch per relationship with its mean count.
    handles = [
        mpatches.Patch(
            color=rel_colors[rel],
            label=f"{rel} (not computed)" if _not_computed(rel) else f"{rel} ({counts.get(rel, 0):,.0f})",
        )
        for rel in RELATIONSHIP_ORDER
    ]
    ax.legend(
        handles=handles,
        loc="upper right",
        fontsize=10,
        title="Relationship (mean pairs)",
        title_fontsize=11,
    )

    # Footer: replicate count plus population metadata, whose keys depend on
    # which pair-count set is being plotted.
    if stats_key == "pair_counts_ped":
        n_ind_key = "n_individuals_ped"
        n_gen_key = "n_generations_ped"
    else:
        n_ind_key = "n_individuals"
        n_gen_key = "n_generations"
    mean_n_ind = _mean_stat(all_stats, n_ind_key)
    mean_n_gen = _mean_stat(all_stats, n_gen_key)

    plural = "s" if n_reps != 1 else ""
    footer_parts = [f"Mean across {n_reps} replicate{plural}"]
    if mean_n_gen is not None:
        footer_parts.append(f"{int(mean_n_gen)} generations")
    if mean_n_ind is not None:
        footer_parts.append(f"{int(mean_n_ind):,} individuals")
    ax.text(
        0.99,
        0.01,
        "  |  ".join(footer_parts),
        transform=ax.transAxes,
        fontsize=9,
        ha="right",
        va="bottom",
        color="grey",
    )

    finalize_plot(output_path, scenario=scenario)
    logger.info("Pedigree counts plot saved to %s", output_path)

cli

cli()

Command-line interface for pedigree relationship counts plot.

Source code in simace/plotting/plot_pedigree_counts.py
def cli() -> None:
    """Parse command-line arguments and draw the pedigree relationship counts plot."""
    from simace.core.cli_base import add_logging_args, init_logging
    from simace.core.yaml_io import load_yaml

    arg_parser = argparse.ArgumentParser(description="Plot pedigree relationship pair counts diagram")
    add_logging_args(arg_parser)

    # Plot-specific options, registered in this order after the logging options.
    option_specs = (
        ("--stats", {"nargs": "+", "required": True, "help": "Stats YAML paths"}),
        ("--output", {"required": True, "help": "Output image path"}),
        ("--scenario", {"default": "", "help": "Scenario name for title"}),
    )
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)

    ns = arg_parser.parse_args()
    init_logging(ns)

    # One stats dict per YAML file given on the command line.
    loaded_stats = [load_yaml(stats_path) for stats_path in ns.stats]
    plot_pedigree_relationship_counts(loaded_stats, ns.output, ns.scenario)

plot_phenotype

simace.plotting.plot_phenotype

Plot phenotype distributions from pre-computed per-rep statistics.

Reads phenotype_stats.yaml and phenotype_samples.parquet files (one per rep) produced by compute_phenotype_stats.py. No full phenotype parquet loading needed.

main

main(stats_paths, sample_paths, output_dir, censor_age, gen_censoring=None, plot_ext='png', validation_paths=None, max_degree=2)

Generate all phenotype plots from pre-computed stats.

Source code in simace/plotting/plot_phenotype.py
def main(
    stats_paths: list[str],
    sample_paths: list[str],
    output_dir: str,
    censor_age: float,
    gen_censoring: dict[int, list[float]] | None = None,
    plot_ext: str = "png",
    validation_paths: list[str] | None = None,
    max_degree: int = 2,
) -> None:
    """Generate all phenotype plots from pre-computed stats.

    Every figure is written into ``output_dir`` (created if missing) with the
    extension ``plot_ext``. Where an input is missing (no censoring windows,
    no validation files) a placeholder figure is saved under the same filename.

    Args:
        stats_paths: Per-replicate stats YAML files.
        sample_paths: Per-replicate sample parquet files; they are concatenated
            and randomly subsampled to ``MAX_PLOT_POINTS`` rows if larger.
        output_dir: Directory that receives the plots. The scenario name is
            taken from the name of its parent directory.
        censor_age: Passed to the mortality and cumulative-incidence plots.
        gen_censoring: Per-generation censoring windows; ``None`` yields a
            placeholder censoring plot.
        plot_ext: File extension used for every output figure.
        validation_paths: Optional validation YAML files. The ``parameters``
            entry of the first file is forwarded as ``params`` to the plots
            that accept it.
        max_degree: Forwarded to the pedigree relationship counts plots.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Scenario name is the directory one level above the output directory.
    scenario = out_dir.parent.name
    from simace.plotting.plot_style import apply_nature_style

    apply_nature_style()

    all_stats = [load_yaml(p) for p in stats_paths]

    # Pool samples from all replicates; cap the number of plotted points with a
    # fixed seed so repeated runs draw the same subsample.
    df_samples = pd.concat([pd.read_parquet(p) for p in sample_paths], ignore_index=True)
    subsample_note = ""
    if len(df_samples) > MAX_PLOT_POINTS:
        original_n = len(df_samples)
        df_samples = df_samples.sample(n=MAX_PLOT_POINTS, random_state=42).reset_index(drop=True)
        subsample_note = f"Subsampled: {MAX_PLOT_POINTS:,} of {original_n:,} individuals shown"

    ext = plot_ext

    # Load validation data early so params are available for correlation plots
    all_validations = None
    validation_params = None
    if validation_paths:
        all_validations = [load_yaml(p) for p in validation_paths]
        # Only the first validation file's parameters are used.
        validation_params = all_validations[0].get("parameters", {})

    # Pedigree relationship pair counts
    # (first from the "pair_counts_ped" stats, then from the default "pair_counts")
    plot_pedigree_relationship_counts(
        all_stats,
        out_dir / f"pedigree_counts.ped.{ext}",
        scenario,
        stats_key="pair_counts_ped",
        generations_label="G_ped",
        max_degree=max_degree,
    )
    plot_pedigree_relationship_counts(
        all_stats,
        out_dir / f"pedigree_counts.{ext}",
        scenario,
        generations_label="G_pheno",
        max_degree=max_degree,
    )

    # Family structure (offspring and mate distributions)
    plot_family_structure(
        all_stats,
        out_dir / f"family_structure.{ext}",
        scenario,
    )

    # Mate correlation heatmap
    plot_mate_correlation(
        all_stats,
        out_dir / f"mate_correlation.{ext}",
        scenario,
        params=validation_params,
    )

    # Distribution plots
    plot_death_age_distribution(
        all_stats,
        censor_age,
        out_dir / f"mortality.{ext}",
        scenario,
    )
    plot_trait_phenotype(
        df_samples,
        out_dir / f"age_at_onset_death.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_trait_regression(
        df_samples,
        all_stats,
        out_dir / f"liability_vs_aoo.{ext}",
        scenario,
        subsample_note=subsample_note,
    )

    # Liability plots
    plot_liability_joint(
        df_samples,
        out_dir / f"cross_trait.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_liability_joint_affected(
        df_samples,
        out_dir / f"cross_trait.phenotype.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_liability_joint_affected_t2(
        df_samples,
        out_dir / f"cross_trait.phenotype.t2.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_liability_violin(
        df_samples,
        all_stats,
        out_dir / f"liability_violin.phenotype.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_liability_violin_by_generation(
        df_samples,
        all_stats,
        out_dir / f"liability_violin.phenotype.by_generation.{ext}",
        scenario,
        subsample_note=subsample_note,
    )
    plot_liability_violin_by_sex_generation(
        df_samples,
        all_stats,
        out_dir / f"liability_violin.phenotype.by_sex.by_generation.{ext}",
        scenario,
        subsample_note=subsample_note,
    )

    # Genetic selection by generation
    plot_liability_components_by_generation(
        df_samples,
        out_dir / f"liability_components.by_generation.{ext}",
        scenario,
        subsample_note=subsample_note,
    )

    # Survival / incidence plots
    plot_cumulative_incidence(
        all_stats,
        censor_age,
        out_dir / f"cumulative_incidence.phenotype.{ext}",
        scenario,
    )
    plot_cumulative_incidence_by_sex(
        all_stats,
        out_dir / f"cumulative_incidence.by_sex.{ext}",
        scenario,
    )
    plot_cumulative_incidence_by_sex_generation(
        all_stats,
        out_dir / f"cumulative_incidence.by_sex.by_generation.{ext}",
        scenario,
    )
    plot_joint_affection(
        all_stats,
        out_dir / f"joint_affected.phenotype.{ext}",
        scenario,
    )

    # Censoring
    # (a placeholder keeps the expected output file present when no windows are configured)
    if gen_censoring is not None:
        plot_censoring_windows(
            all_stats,
            out_dir / f"censoring.{ext}",
            scenario,
            gen_censoring=gen_censoring,
        )
    else:
        save_placeholder_plot(out_dir / f"censoring.{ext}", "No censoring windows configured")

    plot_censoring_confusion(
        all_stats,
        out_dir / f"censoring_confusion.{ext}",
        scenario,
    )
    plot_censoring_cascade(
        all_stats,
        out_dir / f"censoring_cascade.{ext}",
        scenario,
    )

    # Correlation plots
    plot_tetrachoric_sibling(
        all_stats,
        out_dir / f"tetrachoric.phenotype.{ext}",
        scenario,
        params=validation_params,
    )
    plot_tetrachoric_by_sex(
        all_stats,
        out_dir / f"tetrachoric.phenotype.by_sex.{ext}",
        scenario,
        params=validation_params,
    )
    plot_tetrachoric_by_generation(
        all_stats,
        out_dir / f"tetrachoric.phenotype.by_generation.{ext}",
        scenario,
        params=validation_params,
    )
    plot_cross_trait_tetrachoric(
        all_stats,
        out_dir / f"cross_trait_tetrachoric.{ext}",
        scenario,
    )
    plot_parent_offspring_liability(
        df_samples,
        all_stats,
        out_dir / f"parent_offspring_liability.by_generation.{ext}",
        scenario,
        subsample_note=subsample_note,
        params=validation_params,
    )
    # Per-generation heritability (requires validation data)
    if all_validations:
        plot_heritability_by_generation(
            all_validations,
            out_dir / f"heritability.by_generation.{ext}",
            scenario,
        )
        plot_broad_heritability_by_generation(
            all_validations,
            out_dir / f"additive_shared.by_generation.{ext}",
            scenario,
        )
    else:
        # Without validation data, write placeholders under the same filenames.
        for name in ["heritability.by_generation", "additive_shared.by_generation"]:
            save_placeholder_plot(out_dir / f"{name}.{ext}", "No validation data available")

    # PO-regression heritability by sex
    plot_heritability_by_sex_generation(
        all_stats,
        out_dir / f"heritability.by_sex.by_generation.{ext}",
        scenario,
        params=validation_params,
    )

    # Observed-scale heritability from binary affected status + Dempster-Lerner lift
    plot_observed_heritability(
        all_stats,
        out_dir / f"observed_h2.{ext}",
        scenario,
        params=validation_params,
    )

    logger.info("Phenotype plots saved to %s", out_dir)

cli

cli()

Command-line interface for generating phenotype plots.

Source code in simace/plotting/plot_phenotype.py
def cli() -> None:
    """Parse command-line arguments and generate the full set of phenotype plots."""
    from simace.core.cli_base import add_logging_args, init_logging

    arg_parser = argparse.ArgumentParser(description="Plot phenotype distributions")
    add_logging_args(arg_parser)

    # Plot-specific options, registered in this order after the logging options.
    option_specs = (
        ("--stats", {"nargs": "+", "required": True, "help": "Stats YAML paths"}),
        ("--samples", {"nargs": "+", "required": True, "help": "Sample parquet paths"}),
        ("--output-dir", {"required": True, "help": "Output directory"}),
        ("--censor-age", {"type": float, "required": True, "help": "Maximum follow-up age"}),
        (
            "--gen-censoring",
            {"type": str, "default": None, "help": "Per-generation censoring windows as JSON dict"},
        ),
        (
            "--plot-format",
            {"choices": ["png", "pdf"], "default": "png", "help": "Output plot format (default: png)"},
        ),
        ("--validations", {"nargs": "*", "default": None, "help": "Validation YAML paths"}),
    )
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)

    ns = arg_parser.parse_args()
    init_logging(ns)

    # JSON object keys are strings; convert them to integer generation indices.
    if ns.gen_censoring:
        windows = {int(gen): window for gen, window in json.loads(ns.gen_censoring).items()}
    else:
        windows = None

    main(
        ns.stats,
        ns.samples,
        ns.output_dir,
        ns.censor_age,
        gen_censoring=windows,
        plot_ext=ns.plot_format,
        validation_paths=ns.validations,
    )

plot_validation

simace.plotting.plot_validation

Plot validation results summarized across replicates per scenario.

stripplot

stripplot(df, ax, y, expected=None, expected_func=None)

Stripplot of observed values with optional expected markers.

PARAMETER DESCRIPTION
df

Gathered metrics DataFrame with scenario column.

TYPE: DataFrame

ax

Matplotlib axes to plot on.

TYPE: Axes

y

Column name for the observed metric to plot.

TYPE: str

expected

Column name for per-scenario expected values, or a fixed number.

TYPE: str | float | None DEFAULT: None

expected_func

Callable(scenario_df) returning expected value.

TYPE: Callable[[DataFrame], float] | None DEFAULT: None

Source code in simace/plotting/plot_validation.py
def stripplot(
    df: pd.DataFrame,
    ax: Axes,
    y: str,
    expected: str | float | None = None,
    expected_func: Callable[[pd.DataFrame], float] | None = None,
) -> None:
    """Draw per-scenario observed values as a strip plot, with optional expected markers.

    If the ``y`` column is entirely NaN, a "no data" note is drawn instead and
    the function returns early. Otherwise the y-limits are tightened around
    the observed values and any numeric expected values.

    Args:
        df: Gathered metrics DataFrame with ``scenario`` column.
        ax: Matplotlib axes to plot on.
        y: Column name for the observed metric to plot.
        expected: Column name for per-scenario expected values, or a fixed number.
        expected_func: Callable(scenario_df) returning expected value; takes
            precedence over ``expected`` when both are given.
    """
    scenarios = df["scenario"].unique()

    # Metric not computed at all: annotate the empty axes and stop.
    if df[y].isna().all():
        ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes, fontsize=12, color="0.5")
        ax.set_ylabel(y)
        return

    sns.stripplot(data=df, x="scenario", y=y, ax=ax, alpha=0.9, color=COLOR_OBSERVED, jitter=0.15)

    # Expected-value markers: one horizontal tick per scenario, placed at the
    # scenario's index in order of first appearance.
    wants_expected = not (expected_func is None and expected is None)
    if wants_expected:
        for x_pos, scenario in enumerate(scenarios):
            subset = df[df["scenario"] == scenario]
            if expected_func:
                raw = expected_func(subset)
            elif isinstance(expected, str):
                raw = subset[expected].iloc[0]
            else:
                assert expected is not None
                raw = expected
            # Values that cannot be converted to a finite float are skipped.
            try:
                marker_y = float(raw)
            except (TypeError, ValueError):
                continue
            if not np.isfinite(marker_y):
                continue
            ax.scatter(
                x_pos,
                marker_y,
                marker="_",
                s=200,
                linewidths=3,
                color=COLOR_EXPECTED,
                zorder=10,
            )

    ax.set_xlabel("")

    # Rotate tick labels when there are many scenarios or the names are long.
    n_scenarios = len(scenarios)
    has_long_name = max((len(str(s)) for s in scenarios), default=0) > 12
    if n_scenarios > 3 or (n_scenarios > 1 and has_long_name):
        ax.tick_params(axis="x", rotation=30)
        for tick_label in ax.get_xticklabels():
            tick_label.set_ha("right")
    if n_scenarios == 1:
        ax.set_xlim(-0.5, 0.5)

    # Candidate values for the y-range: observed data plus expected values.
    candidates = list(df[y].dropna().values)
    if expected_func is not None:
        for scenario in scenarios:
            candidates.append(expected_func(df[df["scenario"] == scenario]))
    elif expected is not None:
        if isinstance(expected, str):
            candidates.extend(df[expected].dropna().values)
        else:
            candidates.append(expected)

    # Keep only entries convertible to float (drops e.g. per-generation dict strings).
    as_floats = []
    for candidate in candidates:
        try:
            as_floats.append(float(candidate))
        except (TypeError, ValueError):
            continue
    finite_vals = np.array(as_floats, dtype=float)
    finite_vals = finite_vals[np.isfinite(finite_vals)]
    if len(finite_vals) == 0:
        return

    # Pad by 15% of the span, with a floor so a flat series still gets a visible range.
    lo = float(finite_vals.min())
    hi = float(finite_vals.max())
    floor_pad = max(0.002, abs(lo + hi) / 2 * 0.01)
    pad = max((hi - lo) * 0.15, floor_pad)
    ax.set_ylim(lo - pad, hi + pad)

save

save(fig, path)

Save figure to disk and close it.

Source code in simace/plotting/plot_validation.py
def save(fig: Figure, path: str | Path) -> None:
    """Save figure to disk and close it.

    The figure is closed even when layout or saving raises, so a failed save
    does not leave the figure registered with pyplot. Any exception is
    re-raised unchanged.

    Args:
        fig: Figure to write.
        path: Destination file path.
    """
    try:
        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches="tight", pad_inches=0.3)
    finally:
        plt.close(fig)

plot_variance_components

plot_variance_components(df, out, ext='png')

Plot observed vs expected A, C, E variance components per trait.

Source code in simace/plotting/plot_validation.py
def plot_variance_components(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw a 2x3 grid of observed vs expected A, C, E variance components (one row per trait)."""
    fig, axes = plt.subplots(2, 3, figsize=_figsize(nrows=2, ncols=3))
    # Rows are traits 1 and 2; columns are the A, C and E components.
    for trait, axes_row in zip((1, 2), axes):
        for component, panel in zip(("A", "C", "E"), axes_row):
            label = f"{component}{trait}"
            stripplot(df, panel, f"variance_{label}", expected=label)
            panel.set_title(f"Trait {trait}: {label}")
            panel.set_ylabel("Variance Proportion")
            enable_value_gridlines(panel)
    save(fig, out / f"variance_components.{ext}")

plot_twin_rate

plot_twin_rate(df, out, ext='png')

Plot observed MZ twin rate vs expected across scenarios.

Source code in simace/plotting/plot_validation.py
def plot_twin_rate(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw the observed MZ twin rate per scenario against the configured ``p_mztwin``."""
    figure, panel = plt.subplots(figsize=_figsize())
    stripplot(df, panel, "observed_twin_rate", expected="p_mztwin")
    panel.set_ylabel("Twin Rate")
    panel.set_title("MZ Twin Rate: Observed vs Expected")
    enable_value_gridlines(panel)
    target = out / f"twin_rate.{ext}"
    save(figure, target)

plot_A_correlations

plot_A_correlations(df, out, ext='png')

Plot MZ twin, DZ sibling, and half-sibling additive genetic (A1) correlations, plus the midparent-offspring A1 R², against their expected values.

Source code in simace/plotting/plot_validation.py
def plot_A_correlations(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Plot additive genetic (A1) relatedness checks against their expected values.

    Four panels: MZ twin, DZ sibling and half-sibling A1 correlations, and the
    midparent-offspring A1 R². Each panel shows the per-scenario observations
    and a dashed reference line at the expected value.

    Args:
        df: Gathered metrics DataFrame with ``scenario`` column.
        out: Output directory.
        ext: File extension of the saved figure.
    """
    # (metric column, expected value, panel title, y-axis label)
    panels = [
        ("mz_twin_A1_corr", 1.0, "MZ Twin A1 Correlation", "Correlation"),
        ("dz_sibling_A1_corr", 0.5, "DZ Sibling A1 Correlation", "Correlation"),
        ("half_sib_A1_corr", 0.25, "Half-Sibling A1 Correlation", "Correlation"),
        # This metric is an R², not a correlation, so it gets its own axis label.
        ("parent_offspring_A1_r2", 0.5, "Midparent-Offspring A1 R²", "R²"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=_figsize(nrows=2, ncols=2))
    for ax, (col, exp, title, ylabel) in zip(axes.flat, panels, strict=True):
        stripplot(df, ax, col, expected=exp)
        ax.axhline(y=exp, color=COLOR_EXPECTED, linestyle="--", alpha=0.7)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        enable_value_gridlines(ax)
    save(fig, out / f"correlations_A.{ext}")

plot_phenotype_correlations

plot_phenotype_correlations(df, out, ext='png')

Plot MZ twin, DZ sibling, and half-sibling liability correlations, plus the midparent-offspring liability slope, against their expected values.

Source code in simace/plotting/plot_validation.py
def plot_phenotype_correlations(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Plot trait-1 liability relatedness checks against their expected values.

    Four panels: MZ twin, DZ sibling and half-sibling liability correlations,
    and the midparent-offspring liability slope. Expected values are computed
    per scenario from the first row's ``A1`` and ``C1`` columns.

    Args:
        df: Gathered metrics DataFrame with ``scenario``, ``A1`` and ``C1`` columns.
        out: Output directory.
        ext: File extension of the saved figure.
    """
    # (metric column, expected-value function, panel title, y-axis label)
    panels = [
        (
            "mz_twin_liability1_corr",
            lambda d: d["A1"].iloc[0] + d["C1"].iloc[0],
            "MZ Twin Liability1 Corr",
            "Correlation",
        ),
        (
            "dz_sibling_liability1_corr",
            lambda d: 0.5 * d["A1"].iloc[0] + d["C1"].iloc[0],
            "DZ Sibling Liability1 Corr",
            "Correlation",
        ),
        (
            "half_sib_liability1_corr",
            lambda d: 0.25 * d["A1"].iloc[0],
            "Half-Sib Liability1 Corr",
            "Correlation",
        ),
        # A regression slope, labelled "Slope" as in plot_heritability_estimates.
        (
            "parent_offspring_liability1_slope",
            lambda d: d["A1"].iloc[0],
            "Midparent-Offspring Liability1 Slope",
            "Slope",
        ),
    ]
    fig, axes = plt.subplots(2, 2, figsize=_figsize(nrows=2, ncols=2))
    for ax, (col, efn, title, ylabel) in zip(axes.flat, panels, strict=True):
        stripplot(df, ax, col, expected_func=efn)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        enable_value_gridlines(ax)
    save(fig, out / f"correlations_phenotype.{ext}")

plot_heritability_estimates

plot_heritability_estimates(df, out, ext='png')

Plot Falconer heritability estimates vs configured A values.

Source code in simace/plotting/plot_validation.py
def plot_heritability_estimates(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw heritability estimators next to the configured additive variance.

    Produces a 2x2 grid with one row per trait: the Falconer h² estimate on
    the left and the midparent-offspring liability slope on the right. The
    configured ``A1`` / ``A2`` column is passed to ``stripplot`` as the
    expected value of the corresponding row.

    Args:
        df: Validation metrics holding the estimator columns and ``A1``/``A2``.
        out: Directory that receives ``heritability_estimates.<ext>``.
        ext: Image file extension (without the dot).
    """
    # One tuple per panel, in row-major grid order:
    # (observed column, expected column, title, y-axis label).
    panel_specs = (
        ("falconer_h2_trait1", "A1", "Falconer h² Trait 1", "Heritability"),
        ("parent_offspring_liability1_slope", "A1", "Midparent-Offspring Liability1", "Slope"),
        ("falconer_h2_trait2", "A2", "Falconer h² Trait 2", "Heritability"),
        ("parent_offspring_liability2_slope", "A2", "Midparent-Offspring Liability2", "Slope"),
    )
    grid_size = _figsize(nrows=2, ncols=2)
    fig, axes = plt.subplots(2, 2, figsize=grid_size)
    for panel_ax, spec in zip(axes.flat, panel_specs, strict=True):
        observed_col, expected_col, heading, axis_label = spec
        stripplot(df, panel_ax, observed_col, expected=expected_col)
        panel_ax.set_title(heading)
        panel_ax.set_ylabel(axis_label)
        enable_value_gridlines(panel_ax)
    target = out / f"heritability_estimates.{ext}"
    save(fig, target)

plot_half_sib_proportions

plot_half_sib_proportions(df, out, ext='png')

Plot observed vs expected half-sib proportions.

Source code in simace/plotting/plot_validation.py
def plot_half_sib_proportions(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw the two half-sibling proportion panels side by side.

    The left panel shows the observed half-sibling pair proportion with the
    ``half_sib_prop_expected`` column as its reference; the right panel shows
    the proportion of offspring that have a half-sibling, with no reference.

    Args:
        df: Validation metrics containing the half-sib proportion columns.
        out: Directory that receives ``half_sib_proportions.<ext>``.
        ext: Image file extension (without the dot).
    """
    fig, axes = plt.subplots(1, 2, figsize=_figsize(ncols=2))
    # (axes index, observed column, extra stripplot kwargs, panel title)
    panel_specs = (
        (0, "half_sib_prop_observed", {"expected": "half_sib_prop_expected"}, "Half-Sibling Pair Proportion"),
        (1, "offspring_with_half_sib_observed", {}, "Proportion of Offspring with Half-Siblings"),
    )
    for idx, column, extra, heading in panel_specs:
        panel_ax = axes[idx]
        stripplot(df, panel_ax, column, **extra)
        panel_ax.set_title(heading)
        panel_ax.set_ylabel("Proportion")
        enable_value_gridlines(panel_ax)
    save(fig, out / f"half_sib_proportions.{ext}")

plot_cross_trait_correlations

plot_cross_trait_correlations(df, out, ext='png')

Plot cross-trait genetic and environmental correlations vs expected.

Source code in simace/plotting/plot_validation.py
def plot_cross_trait_correlations(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw observed cross-trait rA, rC and rE in a one-row, three-panel figure.

    rA and rC are compared with the configured ``rA`` / ``rC`` columns. rE has
    no configured column here, so its panel gets a dashed reference line at
    zero instead.

    Args:
        df: Validation metrics with ``observed_rA``/``observed_rC``/``observed_rE``
            and the configured ``rA``/``rC`` columns.
        out: Directory that receives ``cross_trait_correlations.<ext>``.
        ext: Image file extension (without the dot).
    """
    # (observed column, expected column or None, panel title)
    panel_specs = (
        ("observed_rA", "rA", "Cross-Trait rA"),
        ("observed_rC", "rC", "Cross-Trait rC"),
        ("observed_rE", None, "Cross-Trait rE"),
    )
    fig, axes = plt.subplots(1, 3, figsize=_figsize(ncols=3))
    for panel_ax, (observed_col, expected_col, heading) in zip(axes, panel_specs, strict=True):
        if not expected_col:
            # No configured expectation: plot the points and mark zero by hand.
            stripplot(df, panel_ax, observed_col)
            panel_ax.axhline(y=0, color=COLOR_EXPECTED, linestyle="--", alpha=0.7)
        else:
            stripplot(df, panel_ax, observed_col, expected=expected_col)
        panel_ax.set_title(heading)
        panel_ax.set_ylabel("Correlation")
        enable_value_gridlines(panel_ax)
    save(fig, out / f"cross_trait_correlations.{ext}")

plot_family_size

plot_family_size(df, out, ext='png')

Plot mean family size distribution across scenarios.

Source code in simace/plotting/plot_validation.py
def plot_family_size(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Scatter mean offspring per mother and per father for every scenario.

    Each scenario occupies one integer x position. Mother points sit slightly
    left of it, father points slightly right, and a horizontal dash marks the
    fixed reference value of 2.0.

    Args:
        df: Validation metrics with ``scenario``, ``mother_mean_offspring``
            and ``father_mean_offspring`` columns.
        out: Directory that receives ``family_size.<ext>``.
        ext: Image file extension (without the dot).
    """
    fig, ax = plt.subplots(figsize=_figsize())
    scenarios = df["scenario"].unique()
    n_scenarios = len(scenarios)
    half_gap = 0.3 / 2
    # Reference marker height; the same constant is used for every scenario.
    reference = 2.0
    # (column, horizontal offset from the scenario slot, point colour)
    parent_series = (
        ("mother_mean_offspring", -half_gap, COLOR_OBSERVED),
        ("father_mean_offspring", half_gap, COLOR_AFFECTED),
    )

    for slot, scenario in enumerate(scenarios):
        rows = df[df["scenario"] == scenario]
        for column, offset, colour in parent_series:
            ax.scatter(
                [slot + offset] * len(rows),
                rows[column],
                color=colour,
                alpha=0.9,
                s=30,
                zorder=5,
            )
        ax.scatter(
            slot,
            reference,
            marker="_",
            s=200,
            linewidths=3,
            color=COLOR_EXPECTED,
            zorder=10,
        )

    ax.set_xticks(range(n_scenarios))
    # Tilt the tick labels when there are many scenarios or the names are long.
    has_long_name = max((len(str(name)) for name in scenarios), default=0) > 12
    crowded = n_scenarios > 3 or (n_scenarios > 1 and has_long_name)
    label_style = {"rotation": 30, "ha": "right"} if crowded else {}
    ax.set_xticklabels(scenarios, **label_style)
    if n_scenarios == 1:
        ax.set_xlim(-0.5, 0.5)
    ax.set_ylabel("Mean Offspring per Parent")
    ax.set_title("Family Size: Mean Offspring per Mother and Father (parents with children only)")
    enable_value_gridlines(ax)

    from matplotlib.lines import Line2D

    # Proxy artists so the legend shows one entry per series.
    handles = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=face, markersize=5, label=text)
        for face, text in ((COLOR_OBSERVED, "Mother"), (COLOR_AFFECTED, "Father"))
    ]
    handles.append(
        Line2D([0], [0], marker="_", color=COLOR_EXPECTED, markersize=7, linewidth=1.2, label="Expected (Poisson)")
    )
    ax.legend(handles=handles, loc="best", fontsize="small")
    save(fig, out / f"family_size.{ext}")

plot_summary_bias

plot_summary_bias(df, out, ext='png')

Plot bias heatmap for variance components and correlations.

Source code in simace/plotting/plot_validation.py
def plot_summary_bias(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Draw a 2x3 grid of observed-minus-expected bias strip plots.

    Panels cover the A1/C1/E1 variance components, the twin rate, and the DZ
    sibling and half-sib A1 correlations. Every panel with data gets a dashed
    zero line and a y-range that always includes zero; a panel whose bias
    column is entirely NaN shows a "no data" placeholder instead.

    Args:
        df: Validation metrics; the frame is copied and left unmodified.
        out: Directory that receives ``summary_bias.<ext>``.
        ext: Image file extension (without the dot).
    """
    dp = df.copy()
    # Non-numeric configured values (e.g. per-generation dicts) become NaN,
    # which leaves the corresponding bias undefined.
    for component in ("A1", "C1", "E1"):
        dp[component] = pd.to_numeric(dp[component], errors="coerce")

    # Panel name -> bias series, in grid order.
    bias_terms = {
        "A1 Bias": dp["variance_A1"] - dp["A1"],
        "C1 Bias": dp["variance_C1"] - dp["C1"],
        "E1 Bias": dp["variance_E1"] - dp["E1"],
        "Twin Rate Bias": dp["observed_twin_rate"] - dp["p_mztwin"],
        "DZ A1 Corr Bias": dp["dz_sibling_A1_corr"] - 0.5,
        "Half-sib A1 Bias": dp["half_sib_A1_corr"] - 0.25,
    }
    for name, series in bias_terms.items():
        dp[name] = series

    scenarios = dp["scenario"].unique()
    n = len(scenarios)
    has_long_name = max((len(str(s)) for s in scenarios), default=0) > 12
    tilt_labels = n > 3 or (n > 1 and has_long_name)
    fig, axes = plt.subplots(2, 3, figsize=_figsize(nrows=2, ncols=3))

    for ax, col in zip(axes.flat, bias_terms, strict=True):
        has_data = not dp[col].isna().all()
        if has_data:
            sns.stripplot(data=dp, x="scenario", y=col, ax=ax, alpha=0.9, color=COLOR_OBSERVED, jitter=0.15)
            ax.axhline(y=0, color="red", linestyle="--", alpha=0.5)
        else:
            ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes, fontsize=12, color="0.5")
        # Title and x-label are set after plotting so the blank x-label sticks.
        ax.set_title(col)
        ax.set_xlabel("")
        if not has_data:
            continue

        enable_value_gridlines(ax)
        if tilt_labels:
            ax.tick_params(axis="x", rotation=30)
            for tick_label in ax.get_xticklabels():
                tick_label.set_ha("right")
        if n == 1:
            ax.set_xlim(-0.5, 0.5)

        # Tight y-range that always spans the zero reference line.
        with_zero = np.concatenate([dp[col].dropna().values, [0.0]])
        low = float(with_zero.min())
        high = float(with_zero.max())
        margin = max((high - low) * 0.15, 0.002)
        ax.set_ylim(low - margin, high + margin)
    save(fig, out / f"summary_bias.{ext}")

plot_runtime

plot_runtime(df, out, ext='png')

Plot simulation runtime per scenario.

Source code in simace/plotting/plot_validation.py
def plot_runtime(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Plot how long the simulate step took, per scenario.

    Rows without ``simulate_seconds`` are dropped; if none remain a warning is
    logged and nothing is written. When all remaining rows share one ``N`` the
    times are shown as a strip plot; otherwise they are scattered against
    ``N`` on log-log axes, coloured by scenario.

    Args:
        df: Validation metrics with ``simulate_seconds``, ``N`` and ``scenario``.
        out: Directory that receives ``runtime.<ext>``.
        ext: Image file extension (without the dot).
    """
    timed = df.dropna(subset=["simulate_seconds"])
    if timed.empty:
        logger.warning("No simulate_seconds data; skipping runtime plot")
        return

    destination = out / f"runtime.{ext}"

    if timed["N"].nunique() <= 1:
        # A log-log scatter is meaningless with a single population size.
        fig, ax = plt.subplots(figsize=_figsize())
        stripplot(timed, ax, "simulate_seconds")
        ax.set_ylabel("Simulate Time (seconds)")
        ax.set_title("Simulation Runtime")
        enable_value_gridlines(ax)
        save(fig, destination)
        return

    fig, ax = plt.subplots(figsize=(8, 6))
    scenarios = timed["scenario"].unique()
    colours = sns.color_palette("colorblind", len(scenarios))

    for scenario, colour in zip(scenarios, colours, strict=True):
        rows = timed[timed["scenario"] == scenario]
        ax.scatter(
            rows["N"],
            rows["simulate_seconds"],
            color=colour,
            label=scenario,
            alpha=0.9,
            s=40,
        )

    for set_scale in (ax.set_xscale, ax.set_yscale):
        set_scale("log")
    _format_log_axes(ax)
    ax.set_xlabel("Population Size (N)")
    ax.set_ylabel("Simulate Time (seconds)")
    ax.set_title("Simulation Runtime vs Population Size")
    ax.legend()
    save(fig, destination)

plot_memory

plot_memory(df, out, ext='png')

Plot simulation peak memory usage per scenario.

Source code in simace/plotting/plot_validation.py
def plot_memory(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Plot the peak resident memory of the simulate step, per scenario.

    Rows without ``simulate_max_rss_mb`` are dropped; if none remain a warning
    is logged and nothing is written. With a single distinct ``N`` the values
    are shown as a strip plot; otherwise they are scattered against ``N`` on
    log-log axes with one colour per scenario.

    Args:
        df: Validation metrics with ``simulate_max_rss_mb``, ``N`` and ``scenario``.
        out: Directory that receives ``memory.<ext>``.
        ext: Image file extension (without the dot).
    """
    metric = "simulate_max_rss_mb"
    measured = df.dropna(subset=[metric])
    if measured.empty:
        logger.warning("No simulate_max_rss_mb data; skipping memory plot")
        return

    single_size = measured["N"].nunique() <= 1
    if single_size:
        # One population size only: a strip plot replaces the log-log scatter.
        fig, ax = plt.subplots(figsize=_figsize())
        stripplot(measured, ax, metric)
        ax.set_ylabel("Peak RSS (MB)")
        ax.set_title("Simulation Memory Usage")
        enable_value_gridlines(ax)
        save(fig, out / f"memory.{ext}")
        return

    fig, ax = plt.subplots(figsize=(8, 6))
    scenarios = measured["scenario"].unique()
    colour_of = dict(zip(scenarios, sns.color_palette("colorblind", len(scenarios)), strict=True))

    for name in scenarios:
        subset = measured[measured["scenario"] == name]
        ax.scatter(subset["N"], subset[metric], color=colour_of[name], label=name, alpha=0.9, s=40)

    ax.set_xscale("log")
    ax.set_yscale("log")
    _format_log_axes(ax)
    ax.set_xlabel("Population Size (N)")
    ax.set_ylabel("Peak RSS (MB)")
    ax.set_title("Simulation Memory Usage vs Population Size")
    ax.legend()
    save(fig, out / f"memory.{ext}")

plot_consanguineous_matings

plot_consanguineous_matings(df, out, ext='png')

Plot half-sib mating counts and missing grandparent-link counts, each against an expected value of zero.

Source code in simace/plotting/plot_validation.py
def plot_consanguineous_matings(df: pd.DataFrame, out: Path, ext: str = "png") -> None:
    """Plot half-sib mating counts and missing grandparent-link counts.

    Both panels are count strip plots drawn with an expected value of zero.

    Args:
        df: Validation metrics with ``n_half_sib_matings`` and
            ``missing_gp_links`` columns.
        out: Directory that receives ``consanguineous_matings.<ext>``.
        ext: Image file extension (without the dot).
    """
    # (count column, panel title), left to right.
    panel_specs = (
        ("n_half_sib_matings", "Half-Sib Matings"),
        ("missing_gp_links", "Missing Grandparent Links"),
    )
    fig, axes = plt.subplots(1, 2, figsize=_figsize(ncols=2))
    for idx, (column, heading) in enumerate(panel_specs):
        panel_ax = axes[idx]
        stripplot(df, panel_ax, column, expected=0)
        panel_ax.set_title(heading)
        panel_ax.set_ylabel("Count")
        enable_value_gridlines(panel_ax)
    save(fig, out / f"consanguineous_matings.{ext}")

main

main(tsv_path, output_dir, plot_ext='png')

Generate all validation plots from a gathered metrics TSV.

Source code in simace/plotting/plot_validation.py
def main(tsv_path: str, output_dir: str | Path, plot_ext: str = "png") -> None:
    """Read the gathered metrics TSV, write every validation plot, then build the atlas.

    Args:
        tsv_path: Tab-separated metrics file.
        output_dir: Directory for the plots and ``atlas.pdf``; created if missing.
        plot_ext: Image file extension used for every plot.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    logger.info("Generating validation plots in %s", out)
    from simace.plotting.plot_style import apply_nature_style

    apply_nature_style()
    df = pd.read_csv(tsv_path, sep="\t", encoding="utf-8")

    if "N" in df.columns:
        # Order scenarios by their first recorded N so the categorical x-axes
        # run from small to large populations.
        first_n = df.groupby("scenario")["N"].first()
        ordered_names = first_n.sort_values().index
        df["scenario"] = pd.Categorical(df["scenario"], categories=ordered_names, ordered=True)
        df = df.sort_values("scenario").reset_index(drop=True)

    plotters = (
        plot_variance_components,
        plot_twin_rate,
        plot_A_correlations,
        plot_phenotype_correlations,
        plot_heritability_estimates,
        plot_half_sib_proportions,
        plot_cross_trait_correlations,
        plot_family_size,
        plot_summary_bias,
        plot_runtime,
        plot_memory,
        plot_consanguineous_matings,
    )
    for make_plot in plotters:
        make_plot(df, out, ext=plot_ext)

    # The manifest owns the atlas page order and captions.
    from simace.plotting.atlas_manifest import VALIDATION_ATLAS
    from simace.plotting.plot_atlas import assemble_atlas

    atlas_items = list(VALIDATION_ATLAS)
    assemble_atlas(atlas_items, out, out / "atlas.pdf", plot_ext=plot_ext)

cli

cli()

Command-line interface for generating validation plots.

Source code in simace/plotting/plot_validation.py
def cli() -> None:
    """Parse command-line arguments and run the validation plotting ``main``."""
    from simace.core.cli_base import add_logging_args, init_logging

    arg_parser = argparse.ArgumentParser(description="Plot validation results")
    add_logging_args(arg_parser)
    # Positional arguments, in the order they appear on the command line.
    for dest, help_text in (
        ("tsv", "Validation summary TSV path"),
        ("output_dir", "Output directory"),
    ):
        arg_parser.add_argument(dest, help=help_text)
    arg_parser.add_argument(
        "--plot-format",
        choices=["png", "pdf"],
        default="png",
        help="Output plot format (default: png)",
    )
    namespace = arg_parser.parse_args()

    init_logging(namespace)

    main(namespace.tsv, namespace.output_dir, plot_ext=namespace.plot_format)

plot_pipeline

simace.plotting.plot_pipeline

Pipeline DAG diagram for the atlas title page.

Renders a single-page figure showing the Snakemake pipeline structure with each step's relevant parameters displayed inside its box.

render_pipeline_figure

render_pipeline_figure(params, scenario='')

Build and return the pipeline DAG figure (without saving).

PARAMETER DESCRIPTION
params

Merged scenario parameters dict.

TYPE: dict

scenario

Scenario name for the title.

TYPE: str DEFAULT: ''

RETURNS DESCRIPTION
Figure

The matplotlib Figure object.

Source code in simace/plotting/plot_pipeline.py
def render_pipeline_figure(
    params: dict,
    scenario: str = "",
) -> plt.Figure:
    """Build and return the pipeline DAG figure (without saving).

    The figure shows one box per pipeline step with that step's parameters
    listed inside, arrows for the pipeline edges, and a header block with the
    scenario name plus seed / replicates / standardize metadata.

    Args:
        params: Merged scenario parameters dict. ``phenotype_params1`` and
            ``phenotype_params2`` may be absent or ``None``; both are treated
            as an empty mapping.
        scenario: Scenario name for the title. When empty, the scenario header
            is omitted.

    Returns:
        The matplotlib Figure object. The caller is responsible for saving
        and closing it.
    """
    fig = plt.figure(figsize=(11.69, 8.27))
    ax = fig.add_axes([0.02, 0.06, 0.96, 0.88])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_axis_off()

    # A key that is present but null (e.g. an empty YAML entry) must behave
    # like a missing key; ``dict.get(key, {})`` alone would hand back None.
    pheno_params = {t: params.get(f"phenotype_params{t}") or {} for t in (1, 2)}

    # Scenario area — right column, aligned with simulate box
    if scenario:
        fig.text(
            0.71,
            0.92,
            "Scenario",
            fontsize=_FONT_TABLE,
            fontfamily="sans-serif",
            color="0.4",
            ha="center",
            va="bottom",
        )
        fig.text(
            0.71,
            0.91,
            scenario,
            fontsize=_FONT_TITLE,
            fontweight="bold",
            fontfamily="sans-serif",
            ha="center",
            va="top",
        )
    # Seed + replicates below scenario name
    meta_parts = []
    seed = params.get("seed")
    if seed is not None:
        meta_parts.append(f"seed = {seed}")
    reps = params.get("replicates")
    if reps is not None:
        meta_parts.append(f"replicates = {reps}")
    std = params.get("standardize")
    if std is not None:
        meta_parts.append(f"standardize = {str(std).lower()}")
    # Surface per-trait standardize_hazard overrides when set and distinct
    # from the global standardize value (so plot readers can tell when the
    # hazard step is decoupled from the threshold step).
    for t in (1, 2):
        haz = pheno_params[t].get("standardize_hazard")
        if haz is not None and haz != std:
            meta_parts.append(f"standardize_hazard.t{t} = {str(haz).lower()}")
    if meta_parts:
        fig.text(
            0.71,
            0.82,
            "\n".join(meta_parts),
            fontsize=_FONT_META,
            fontfamily="monospace",
            color="0.4",
            ha="center",
            va="top",
            linespacing=1.5,
        )

    # Build step info lookup: key -> (display title, colour, parameter names)
    step_info = {key: (display, color, pnames) for key, display, color, pnames in _PIPELINE_STEPS}

    # Override phenotype title with model-specific short name
    m1 = str(params.get("phenotype_model1", "frailty"))
    m2 = str(params.get("phenotype_model2", "frailty"))
    pp1 = pheno_params[1]
    pp2 = pheno_params[2]
    same_model = (
        m1 == m2 and pp1.get("distribution") == pp2.get("distribution") and pp1.get("method") == pp2.get("method")
    )
    if same_model:
        short_name = _model_display_name(m1, pp1)[0]
        pheno_title = f"Phenotype ({short_name.lower()})"
    else:
        s1 = _model_display_name(m1, pp1)[0]
        s2 = _model_display_name(m2, pp2)[0]
        pheno_title = f"Phenotype ({s1.lower()} / {s2.lower()})"
    old = step_info["phenotype"]
    step_info["phenotype"] = (pheno_title, old[1], old[2])

    # Build rows for each step and compute box sizes
    step_rows: dict[str, list[tuple[str, str]]] = {}
    box_heights: dict[str, float] = {}
    max_w = 0.22
    for key in _STEP_POSITIONS:
        display, _, pnames = step_info[key]
        rows = _get_param_rows(pnames, params)
        step_rows[key] = rows
        # Width: based on longest name+value pair (0 when the step has no rows)
        max_name = max((len(n) for n, _ in rows), default=0)
        max_val = max((len(v) for _, v in rows), default=0)
        table_w = 0.046 + (max_name + 2 + max_val) * _CHAR_W
        title_w = len(display) * 0.010 + 0.04
        max_w = max(max_w, table_w, title_w)
        # Height: title area + rows (an empty step still reserves one row)
        box_heights[key] = 0.055 + 0.024 * max(len(rows), 1)
    # Uniform width across all boxes, capped to avoid left/right column overlap
    # Left column at x=0.27, right at x=0.73 — gap of 0.46
    max_w = min(max_w, 0.43)
    box_sizes: dict[str, tuple[float, float]] = {key: (max_w, h) for key, h in box_heights.items()}

    # Draw arrows first (behind boxes): bottom edge of source to top edge of destination
    for src, dst in _PIPELINE_EDGES:
        sx, sy = _STEP_POSITIONS[src]
        dx, dy = _STEP_POSITIONS[dst]
        _, sh = box_sizes[src]
        _, dh = box_sizes[dst]
        _draw_pipeline_arrow(ax, sx, sy - sh / 2, dx, dy + dh / 2)

    # Draw boxes
    for key, (cx, cy) in _STEP_POSITIONS.items():
        display, color, _ = step_info[key]
        w, h = box_sizes[key]
        _draw_step_box(ax, cx, cy, w, h, display, step_rows[key], color)

    return fig

plot_pipeline

plot_pipeline(params, output_path, scenario='')

Render the pipeline DAG diagram and save to file.

PARAMETER DESCRIPTION
params

Merged scenario parameters dict.

TYPE: dict

output_path

Where to save the figure.

TYPE: str | Path

scenario

Scenario name for the title.

TYPE: str DEFAULT: ''

Source code in simace/plotting/plot_pipeline.py
def plot_pipeline(
    params: dict,
    output_path: str | Path,
    scenario: str = "",
) -> None:
    """Render the pipeline DAG diagram and save to file.

    The parent directory of ``output_path`` is created if needed. The figure
    is closed even when saving fails, so a failed write does not leave an open
    figure registered with pyplot.

    Args:
        params: Merged scenario parameters dict.
        output_path: Where to save the figure.
        scenario: Scenario name for the title.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig = render_pipeline_figure(params, scenario=scenario)
    try:
        fig.savefig(str(output_path), dpi=150, bbox_inches="tight")
    finally:
        # Release the figure regardless of whether the write succeeded.
        plt.close(fig)
    logger.info("Pipeline diagram saved to %s", output_path)

cli

cli()

Command-line interface for standalone pipeline diagram rendering.

Source code in simace/plotting/plot_pipeline.py
def cli() -> None:
    """Parse command-line arguments and render a standalone pipeline diagram."""
    from simace.core.cli_base import add_logging_args, init_logging

    arg_parser = argparse.ArgumentParser(
        description="Render pipeline DAG diagram with scenario parameters.",
    )
    add_logging_args(arg_parser)
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--params", {"required": True, "help": "Path to params.yaml (merged scenario parameters)."}),
        ("--output", {"required": True, "help": "Output image path (e.g. /tmp/pipeline.png)."}),
        ("--scenario", {"default": "", "help": "Scenario name for the title."}),
    )
    for flag, spec in option_specs:
        arg_parser.add_argument(flag, **spec)
    namespace = arg_parser.parse_args()
    init_logging(namespace)

    scenario_params = load_yaml(namespace.params)

    plot_pipeline(scenario_params, namespace.output, scenario=namespace.scenario)

plot_atlas

simace.plotting.plot_atlas

Assemble individual plots into a multi-page PDF atlas with figure captions.

get_model_equation

get_model_equation(params)

Return mathtext equation lines for the scenario's phenotype model(s).

Source code in simace/plotting/plot_atlas.py
def get_model_equation(params: dict) -> list[str]:
    """Return mathtext equation lines for the scenario's phenotype model(s).

    When both traits share the same model, ``distribution`` and ``method``, a
    single unlabelled set of lines is returned. Otherwise the lines for each
    trait are returned one after the other, labelled "Trait 1" / "Trait 2".

    Args:
        params: Scenario parameters. ``phenotype_model1``/``phenotype_model2``
            default to ``"frailty"``; ``phenotype_params1``/``phenotype_params2``
            may be missing or ``None`` and are then treated as empty.

    Returns:
        The equation lines produced by ``_equation_lines_for_model``.
    """
    m1 = str(params.get("phenotype_model1", "frailty"))
    m2 = str(params.get("phenotype_model2", "frailty"))
    # ``or {}`` also covers a key that is present with a null value, which a
    # plain ``.get(key, {})`` default would pass through as None.
    pp1 = params.get("phenotype_params1") or {}
    pp2 = params.get("phenotype_params2") or {}

    if m1 == m2 and pp1.get("distribution") == pp2.get("distribution") and pp1.get("method") == pp2.get("method"):
        return _equation_lines_for_model(m1, pp1)

    lines: list[str] = []
    lines.extend(_equation_lines_for_model(m1, pp1, label="Trait 1"))
    lines.extend(_equation_lines_for_model(m2, pp2, label="Trait 2"))
    return lines

get_model_family

get_model_family(params)

Return (display_name, description) for the scenario's phenotype model(s).

When both traits use the same model family and sub-type, return that family. When they differ, return a combined description.

Source code in simace/plotting/plot_atlas.py
def get_model_family(params: dict) -> tuple[str, str]:
    """Return (display_name, description) for the scenario's phenotype model(s).

    When both traits use the same model family and sub-type (same model,
    ``distribution`` and ``method``), return that family. When they differ,
    return a combined description.

    Args:
        params: Scenario parameters. ``phenotype_model1``/``phenotype_model2``
            default to ``"frailty"``; ``phenotype_params1``/``phenotype_params2``
            may be missing or ``None`` and are then treated as empty.

    Returns:
        A ``(display_name, description)`` pair built from
        ``_model_display_name``.
    """
    m1 = str(params.get("phenotype_model1", "frailty"))
    m2 = str(params.get("phenotype_model2", "frailty"))
    # ``or {}`` also covers a key that is present with a null value, which a
    # plain ``.get(key, {})`` default would pass through as None.
    pp1 = params.get("phenotype_params1") or {}
    pp2 = params.get("phenotype_params2") or {}

    name1, desc1 = _model_display_name(m1, pp1)
    name2, desc2 = _model_display_name(m2, pp2)

    if m1 == m2 and pp1.get("distribution") == pp2.get("distribution") and pp1.get("method") == pp2.get("method"):
        return name1, desc1

    return (
        f"{name1} / {name2}",
        f"Trait 1: {desc1}; Trait 2: {desc2}",
    )

assemble_atlas

assemble_atlas(items, plot_dir, output_path, *, plot_ext='png', scenario_params=None, stats_data=None)

Combine plots and section breaks into a multi-page PDF with captions.

Walks items linearly. PlotEntry items render as a plot+caption page; the "Figure {N}: " prefix is derived from the running plot index (1-based, counting only :class:PlotEntry items). SectionBreak items render as a section divider page.

PARAMETER DESCRIPTION
items

Ordered atlas manifest, mixing :class:~simace.plotting.atlas_manifest.PlotEntry and :class:~simace.plotting.atlas_manifest.SectionBreak.

TYPE: list[AtlasItem]

plot_dir

Directory containing the plot image files; each PlotEntry.basename resolves to plot_dir / f"{basename}.{plot_ext}".

TYPE: Path

output_path

Path for the combined PDF.

TYPE: Path

plot_ext

Image extension (default "png").

TYPE: str DEFAULT: 'png'

scenario_params

If provided, a dict with key "scenario" and parameter names. A title page with all parameters is rendered first.

TYPE: dict | None DEFAULT: None

stats_data

If provided, a list of phenotype_stats dicts (one per rep). A Table 1 page is rendered after the title page.

TYPE: list[dict] | None DEFAULT: None

Source code in simace/plotting/plot_atlas.py
def assemble_atlas(
    items: list[AtlasItem],
    plot_dir: Path,
    output_path: Path,
    *,
    plot_ext: str = "png",
    scenario_params: dict | None = None,
    stats_data: list[dict] | None = None,
) -> None:
    """Combine plots and section breaks into a multi-page PDF with captions.

    Walks ``items`` linearly. ``PlotEntry`` items render as a plot+caption
    page; the ``"Figure {N}: "`` prefix is derived from the running plot
    index (1-based, counting only :class:`PlotEntry` items). ``SectionBreak``
    items render as a section divider page.

    A plot whose image file is missing is skipped with a warning but still
    consumes its figure number, so the numbering of the remaining figures does
    not depend on which files exist. The final log line reports the number of
    plot pages actually written.

    Args:
        items: Ordered atlas manifest, mixing
            :class:`~simace.plotting.atlas_manifest.PlotEntry` and
            :class:`~simace.plotting.atlas_manifest.SectionBreak`.
        plot_dir: Directory containing the plot image files; each
            ``PlotEntry.basename`` resolves to ``plot_dir / f"{basename}.{plot_ext}"``.
        output_path: Path for the combined PDF.
        plot_ext: Image extension (default ``"png"``).
        scenario_params: If provided, a dict with key ``"scenario"`` and
            parameter names. A title page with all parameters is rendered first.
        stats_data: If provided, a list of phenotype_stats dicts (one per rep).
            A Table 1 page is rendered after the title page.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plot_dir = Path(plot_dir)
    atlas_dir = output_path.parent.resolve()

    # Plot pages actually written to the PDF (missing files are not counted).
    n_rendered = 0

    with PdfPages(str(output_path)) as pdf:
        # Optional title page with scenario parameters
        if scenario_params is not None:
            scenario_name = scenario_params.get("scenario", "unknown")
            _render_params_page(pdf, scenario_name, scenario_params)

            # Table 1 page (requires both params and stats)
            if stats_data:
                _render_table1_page(pdf, stats_data, scenario_name, scenario_params)

        plot_idx = 0
        for item in items:
            if isinstance(item, SectionBreak):
                _render_section_page(
                    pdf,
                    item.title,
                    item.subtitle,
                    equations=list(item.equations) if item.equations else None,
                )
                continue

            # Advance before the existence check so a missing file keeps its
            # figure number reserved.
            plot_idx += 1
            path = plot_dir / f"{item.basename}.{plot_ext}"
            if not path.exists():
                logger.warning("Atlas: skipping missing plot %s", path)
                continue

            # Reference shown in the caption: path relative to the atlas
            # directory when possible, otherwise just the file name.
            try:
                rel = path.resolve().relative_to(atlas_dir)
            except ValueError:
                rel = path.name

            title = f"Figure {plot_idx}: {item.title}"
            body = item.body

            # Longer captions reserve a larger share of the page height.
            caption_len = len(title) + 2 + len(body)
            if caption_len < 300:
                caption_frac = 0.13
            elif caption_len < 500:
                caption_frac = 0.18
            else:
                caption_frac = 0.24
            img_frac = 1.0 - caption_frac - _TOP_MARGIN

            fig = plt.figure(figsize=(_PAGE_W, _PAGE_H))
            try:
                ax = fig.add_axes([0.005, caption_frac + 0.005, 0.99, img_frac - 0.005])
                with Image.open(path) as img:
                    ax.imshow(img)
                ax.axis("off")

                # Thin hairline border around the figure image
                rect = plt.Rectangle(
                    (0, 0),
                    1,
                    1,
                    transform=ax.transAxes,
                    linewidth=0.3,
                    edgecolor="#cccccc",
                    facecolor="none",
                    clip_on=False,
                )
                ax.add_patch(rect)

                # Caption text in the lower portion — inline bold title + body
                caption_y = caption_frac - 0.015
                body_with_ref = f"{body}  [{rel}]" if body else f"[{rel}]"
                _render_inline_caption(
                    fig,
                    0.04,
                    caption_y,
                    title,
                    body_with_ref,
                    fontsize=11,
                    fontfamily="sans-serif",
                )

                pdf.savefig(fig, dpi=150)
            finally:
                # Close the page figure even if reading or drawing the image failed.
                plt.close(fig)
            n_rendered += 1

    logger.info("Atlas saved to %s (%d plots)", output_path, n_rendered)

plot_table1

simace.plotting.plot_table1

Render an epidemiological Table 1 summarising the simulated study population.

render_table1_figure

render_table1_figure(all_stats, scenario_params, scenario='')

Build and return the Table 1 figure (11.69 x 8.27 in, A4 landscape).

PARAMETER DESCRIPTION
all_stats

List of phenotype_stats dicts, one per replicate.

TYPE: list[dict]

scenario_params

Merged scenario config parameters.

TYPE: dict

scenario

Scenario name for the title.

TYPE: str DEFAULT: ''

RETURNS DESCRIPTION
Figure

matplotlib Figure ready for pdf.savefig().

Source code in simace/plotting/plot_table1.py
def render_table1_figure(
    all_stats: list[dict],
    scenario_params: dict,
    scenario: str = "",
) -> plt.Figure:
    """Build and return the Table 1 figure (A4 landscape, 11.69 x 8.27 in).

    The table is drawn top to bottom in three sections -- A. Population,
    B. Disease Characteristics and C. Censoring -- followed by a bottom rule
    and footnotes.  Values that are treated as constant across replicates
    (population sizes, sex counts, parent status, generation sizes and
    windows) are read from the first replicate only; the remaining values are
    summarised across all replicates by the module's formatting helpers.

    Args:
        all_stats: List of phenotype_stats dicts, one per replicate.  Must
            contain at least one entry.
        scenario_params: Merged scenario config parameters.
        scenario: Scenario name for the title.

    Returns:
        matplotlib Figure ready for ``pdf.savefig()``.

    Raises:
        ValueError: If ``all_stats`` is empty.
    """
    # Validate before creating the figure so a bad call does not leave an
    # orphaned figure registered with pyplot.
    if not all_stats:
        raise ValueError("all_stats must contain at least one replicate")

    fig = plt.figure(figsize=(11.69, 8.27))
    # Full-bleed invisible axes; rows are positioned in 0..1 coordinates.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    p = scenario_params
    n_reps = len(all_stats)

    def _mean_pct(values):
        """Format the mean of the non-missing values as a percentage, or an em dash."""
        clean = [v for v in values if v is not None]
        return _fmt_pct(mean(clean)) if clean else "\u2014"

    # ── Title ──────────────────────────────────────────────────────────
    title = f"Table 1.  Study Population Characteristics — {scenario}"
    fig.text(
        0.50,
        0.96,
        title,
        fontsize=_TITLE_SIZE,
        fontweight="bold",
        fontfamily="sans-serif",
        ha="center",
        va="top",
        transform=fig.transFigure,
    )
    # Thin rule below title
    fig.add_artist(
        plt.Line2D([_LEFT, _RIGHT], [0.935, 0.935], color="black", lw=0.8, transform=fig.transFigure, clip_on=False)
    )

    y = 0.89

    # ── A. Population ─────────────────────────────────────────────────
    y = _draw_section_header(fig, y, "A.  Population")
    # Row shading is toggled before every row (same scheme as sections B and
    # C) so the zebra striping stays regular even when optional rows are
    # omitted.
    shade = False

    # Deterministic values (constant across reps) — use first rep directly
    s0 = all_stats[0]
    n_ind = s0.get("n_individuals")
    n_ped = s0.get("n_individuals_ped")
    n_gen = s0.get("n_generations")

    # --- Study size & demographics ---
    r3 = _draw_row3  # shorthand
    shade = not shade
    y = r3(fig, ax, y, "Total phenotyped individuals, n", _fmt_int(n_ind), "", shade)
    shade = not shade
    y = r3(fig, ax, y, "Full pedigree individuals, n", _fmt_int(n_ped), "", shade)
    shade = not shade
    y = r3(fig, ax, y, "Generations observed", str(n_gen) if n_gen else "\u2014", "", shade)

    sex_counts = _sex_n(s0, "trait1")
    f_n = sex_counts[0]
    m_n = sex_counts[1]
    f_pct_str = f"({_fmt_pct(f_n / n_ind)})" if f_n is not None and n_ind else ""
    m_pct_str = f"({_fmt_pct(m_n / n_ind)})" if m_n is not None and n_ind else ""
    shade = not shade
    y = r3(fig, ax, y, "  Female, n (%)", _fmt_int(f_n), f_pct_str, shade)
    shade = not shade
    y = r3(fig, ax, y, "  Male, n (%)", _fmt_int(m_n), m_pct_str, shade)

    # Sampling info — always shown, grayed out when defaults
    n_sample = p.get("N_sample", 0)
    car = p.get("case_ascertainment_ratio", 1.0)
    sample_active = n_sample and n_sample > 0
    car_active = car != 1.0
    sc = "black" if sample_active else "0.55"
    cc = "black" if car_active else "0.55"
    sample_val = _fmt_int(n_sample) if sample_active else "none"
    sample_rng = "" if sample_active else "(full population)"
    car_val = f"{car:.1f}\u00d7" if car_active else "1.0\u00d7"
    car_rng = "" if car_active else "(no enrichment)"
    shade = not shade
    y = r3(fig, ax, y, "Sampled individuals, n", sample_val, sample_rng, shade, color=sc)
    shade = not shade
    y = r3(fig, ax, y, "Case ascertainment ratio", car_val, car_rng, shade, color=cc)

    # --- Family structure ---
    v, rng = _fmt_split_f([_safe_get(s, "family_size", "mean") for s in all_stats], 2)
    shade = not shade
    y = r3(fig, ax, y, "Offspring per mating, mean", v, rng, shade)

    def _dist_line(stats_key, categories):
        """Join the replicate-mean share of each category of a family_size distribution."""
        return "  /  ".join(
            f"{k}: {_mean_pct(_safe_get(s, 'family_size', stats_key, k) for s in all_stats)}" for k in categories
        )

    # Per-mating family size distribution on one line
    shade = not shade
    y = _draw_row(
        fig, ax, y, "  Distribution (1 / 2 / 3 / 4+)", _dist_line("size_dist", ("1", "2", "3", "4+")), shade
    )

    # Per-person offspring distribution (includes 0 = childless)
    shade = not shade
    y = _draw_row(
        fig,
        ax,
        y,
        "Offspring per person\u00b9 (0 / 1 / 2 / 3 / 4+)",
        _dist_line("person_offspring_dist", ("0", "1", "2", "3", "4+")),
        shade,
    )

    # Number of mates by sex
    def _mates_line(sex):
        """Replicate-mean share of parents of *sex* with one mate versus two or more."""
        p1 = _mean_pct(_safe_get(s, "family_size", "mates_by_sex", f"{sex}_1") for s in all_stats)
        p2 = _mean_pct(_safe_get(s, "family_size", "mates_by_sex", f"{sex}_2+") for s in all_stats)
        return f"1: {p1}  /  2+: {p2}"

    shade = not shade
    y = _draw_row(fig, ax, y, "Mates per mother (1 / 2+)", _mates_line("female"), shade)
    shade = not shade
    y = _draw_row(fig, ax, y, "Mates per father (1 / 2+)", _mates_line("male"), shade)

    v, rng = _fmt_split_pct([_safe_get(s, "family_size", "frac_with_full_sib") for s in all_stats])
    shade = not shade
    y = r3(fig, ax, y, "With \u2265 1 full sib phenotyped, %", v, rng, shade)

    # Parent status
    ps_pheno = {str(k): [_safe_get(s, "parent_status", "phenotyped", str(k)) for s in all_stats] for k in [0, 1, 2]}
    ps_ped = {str(k): [_safe_get(s, "parent_status", "in_pedigree", str(k)) for s in all_stats] for k in [0, 1, 2]}

    def _parent_summary(ps_dict):
        """Share of individuals with 0 / 1 / 2 parents, taken from the first replicate only."""
        parts = []
        for k in ("0", "1", "2"):
            count = ps_dict[k][0]
            pct = _fmt_pct(count / n_ind) if count is not None and n_ind else "\u2014"
            parts.append(f"{k}: {pct}")
        return "  /  ".join(parts)

    if any(v is not None for v in ps_pheno["0"]):
        shade = not shade
        y = r3(fig, ax, y, "Parents phenotyped (0 / 1 / 2)", _parent_summary(ps_pheno), "", shade)
    if any(v is not None for v in ps_ped["0"]):
        shade = not shade
        y = r3(fig, ax, y, "Parents in pedigree (0 / 1 / 2)", _parent_summary(ps_ped), "", shade)

    # --- Follow-up ---
    shade = not shade
    y = r3(fig, ax, y, "Maximum follow-up age", f"{p.get('censor_age', '—')} years", "", shade)

    # Reused below for the sex-specific incidence rates in section B.
    total_py = [_safe_get(s, "person_years", "total") for s in all_stats]
    if any(v is not None for v in total_py):
        v, rng = _fmt_split(total_py)
        shade = not shade
        y = r3(fig, ax, y, "Total person-years of follow-up", v, rng, shade)
        mean_fu = [py / n_ind for py in total_py if py is not None] if n_ind else []
        if mean_fu:
            v, rng = _fmt_split_f(mean_fu, 1)
            shade = not shade
            y = r3(fig, ax, y, "Mean follow-up per person, years", v, rng, shade)
    deaths = [_safe_get(s, "person_years", "deaths") for s in all_stats]
    if any(d is not None for d in deaths):
        v, rng = _fmt_split(deaths)
        shade = not shade
        y = r3(fig, ax, y, "Deaths during follow-up, n", v, rng, shade)

    # Extra vertical gap between sections
    y -= _ROW_H * 0.4

    # ── B. Disease Characteristics ────────────────────────────────────
    y = _draw_section_header(fig, y, "B.  Disease Characteristics")
    y = _draw_col4_headers(fig, y)
    shade = False

    # Prevalence — 4-column by sex
    fprev1 = [_safe_get(s, "cumulative_incidence_by_sex", "trait1", "female", "prevalence") for s in all_stats]
    mprev1 = [_safe_get(s, "cumulative_incidence_by_sex", "trait1", "male", "prevalence") for s in all_stats]
    fprev2 = [_safe_get(s, "cumulative_incidence_by_sex", "trait2", "female", "prevalence") for s in all_stats]
    mprev2 = [_safe_get(s, "cumulative_incidence_by_sex", "trait2", "male", "prevalence") for s in all_stats]
    shade = not shade
    y = _draw_row4(
        fig,
        ax,
        y,
        "Observed prevalence",
        _fmt_range_pct(fprev1),
        _fmt_range_pct(mprev1),
        _fmt_range_pct(fprev2),
        _fmt_range_pct(mprev2),
        shade,
    )

    # Affected n — derive from sex-specific prevalence × n
    def _affected_by_sex(prev_list, n_list):
        """Per-replicate affected counts, None where either input is missing."""
        return [
            round(prev * n) if prev is not None and n is not None else None
            for prev, n in zip(prev_list, n_list, strict=True)
        ]

    fn1 = [_safe_get(s, "cumulative_incidence_by_sex", "trait1", "female", "n") for s in all_stats]
    mn1 = [_safe_get(s, "cumulative_incidence_by_sex", "trait1", "male", "n") for s in all_stats]
    fn2 = [_safe_get(s, "cumulative_incidence_by_sex", "trait2", "female", "n") for s in all_stats]
    mn2 = [_safe_get(s, "cumulative_incidence_by_sex", "trait2", "male", "n") for s in all_stats]
    shade = not shade
    y = _draw_row4(
        fig,
        ax,
        y,
        "Affected, n",
        _fmt_range(_affected_by_sex(fprev1, fn1)),
        _fmt_range(_affected_by_sex(mprev1, mn1)),
        _fmt_range(_affected_by_sex(fprev2, fn2)),
        _fmt_range(_affected_by_sex(mprev2, mn2)),
        shade,
    )

    # Incidence rate by sex (per 1,000 PY) — approximate: affected / (total_py * sex_fraction)
    def _incidence_rate(prev_list, n_sex_list, py_total_list, n_total):
        """IR ≈ (prev × n_sex) / (py_total × n_sex/n_total) × 1000 = prev × n_total / py_total × 1000."""
        rates = []
        for prev, n_sex, py in zip(prev_list, n_sex_list, py_total_list, strict=True):
            if prev is not None and n_sex and py and py > 0 and n_total:
                affected = prev * n_sex
                # Person-years are only available in total, so they are
                # apportioned to each sex by its share of individuals.
                py_sex = py * n_sex / n_total
                rates.append(affected / py_sex * 1000 if py_sex > 0 else None)
            else:
                rates.append(None)
        return rates

    shade = not shade
    y = _draw_row4(
        fig,
        ax,
        y,
        "Incidence rate (per 1,000 PY)",
        _fmt_range_f(_incidence_rate(fprev1, fn1, total_py, n_ind), 1),
        _fmt_range_f(_incidence_rate(mprev1, mn1, total_py, n_ind), 1),
        _fmt_range_f(_incidence_rate(fprev2, fn2, total_py, n_ind), 1),
        _fmt_range_f(_incidence_rate(mprev2, mn2, total_py, n_ind), 1),
        shade,
    )

    # Age at onset quartiles
    def _aoo_sex_quartile(trait, sex, key):
        """Format one age-at-onset quartile for a trait/sex, skipping replicates without a value."""
        vals = []
        for s in all_stats:
            ci = _safe_get(s, "cumulative_incidence_by_sex", trait, sex, default={})
            q = _compute_aoo_quartiles(ci)
            if q[key] is not None:
                vals.append(q[key])
        return _fmt_range_f(vals, 1)

    for qkey, qlabel in [("q1", "Q1"), ("median", "Median"), ("q3", "Q3")]:
        shade = not shade
        y = _draw_row4(
            fig,
            ax,
            y,
            f"Age at onset, {qlabel}",
            _aoo_sex_quartile("trait1", "female", qkey),
            _aoo_sex_quartile("trait1", "male", qkey),
            _aoo_sex_quartile("trait2", "female", qkey),
            _aoo_sex_quartile("trait2", "male", qkey),
            shade,
        )

    # Co-affected by sex (not trait-specific, so the trait-2 columns stay blank)
    coaff_f = [_safe_get(s, "joint_affection", "by_sex", "female") for s in all_stats]
    coaff_m = [_safe_get(s, "joint_affection", "by_sex", "male") for s in all_stats]
    shade = not shade
    y = _draw_row4(
        fig,
        ax,
        y,
        "Co-affected, %",
        _fmt_range_pct(coaff_f),
        _fmt_range_pct(coaff_m),
        "",
        "",
        shade,
    )

    y -= _ROW_H * 0.4

    # ── C. Censoring ──────────────────────────────────────────────────
    y = _draw_section_header(fig, y, "C.  Censoring")
    shade = False

    # Per-generation rows: N, window, observed prevalence
    # Use trait1 cascade as reference (windows are the same for both traits)
    cascade0 = _safe_get(s0, "censoring_cascade", "trait1", default={})
    gen_keys = sorted(cascade0.keys()) if cascade0 else []
    y = _draw_col_headers(fig, y)
    for gk in gen_keys:
        # n_gen is deterministic — use first rep
        gen_n = _safe_get(s0, "censoring_cascade", "trait1", gk, "n_gen")
        window = _safe_get(s0, "censoring_cascade", "trait1", gk, "window")
        n_str = _fmt_int(gen_n)
        win_str = f"ages {window[0]:.0f}\u2013{window[1]:.0f}" if window else ""
        # Generation-specific observed prevalence
        obs1 = [_safe_get(s, "censoring_cascade", "trait1", gk, "observed") for s in all_stats]
        obs2 = [_safe_get(s, "censoring_cascade", "trait2", gk, "observed") for s in all_stats]
        gn = [_safe_get(s, "censoring_cascade", "trait1", gk, "n_gen") for s in all_stats]
        prev_g1 = [o / n if o is not None and n else None for o, n in zip(obs1, gn, strict=True)]
        prev_g2 = [o / n if o is not None and n else None for o, n in zip(obs2, gn, strict=True)]
        shade = not shade
        y = _draw_row2(
            fig,
            ax,
            y,
            f"  {gk}:  n={n_str},  {win_str}",
            _fmt_range_pct(prev_g1),
            _fmt_range_pct(prev_g2),
            shade,
        )

    # Overall mortality rate per 1,000 person-years
    mort_per_1k = []
    for s in all_stats:
        rep_deaths = _safe_get(s, "person_years", "deaths")
        rep_py = _safe_get(s, "person_years", "total")
        if rep_deaths is not None and rep_py and rep_py > 0:
            mort_per_1k.append(rep_deaths / rep_py * 1000)
    if mort_per_1k:
        shade = not shade
        y = _draw_row(
            fig,
            ax,
            y,
            "Mortality rate (per 1,000 PY)",
            _fmt_range_f(mort_per_1k, 1),
            shade,
        )

    # Bottom rule
    fig.add_artist(
        plt.Line2D(
            [_LEFT, _RIGHT], [y - 0.005, y - 0.005], color="black", lw=0.8, transform=fig.transFigure, clip_on=False
        )
    )

    # Footnotes — the replicate note only applies when there is more than one
    # replicate; the superscript-1 note is always shown.
    footnotes = []
    if n_reps > 1:
        footnotes.append(f"Values are mean [min\u2013max] across {n_reps} replicates where applicable.")
    footnotes.append(
        "\u00b9 Includes youngest generation, whose offspring are outside the phenotyped cohort"
        " (100% childless by design)."
    )
    fig.text(
        _LEFT,
        y - 0.02,
        "  ".join(footnotes),
        fontsize=6,
        fontfamily=_FONT,
        color="0.4",
        va="top",
        wrap=True,
        transform=fig.transFigure,
    )

    return fig