"""Paper 28194
# ICML Paper 28194: Function-Valued Causal Influence in Nonlinear Time Series

## Code for all tables and figures in the paper

This code runs NAVAR on multiple synthetic datasets (linear / threshold / saturating / sign-changing),
repeats each system R times, and collects the scalar causal score for X->Y. It also contains code for the "Varieties of Democracy" dataset.

How to use:
1) The working folder must contain train_NAVAR.py and its dependencies (NAVAR.py, dataloader.py, evaluate.py, edge_ablation.py, requirements.txt).

2) Run this script in the same environment where NAVAR runs.

Outputs:
- synthetic_navar_scores.csv (all runs), NAVAR Causal Matrix
- printed summary table of X->Y scores per system (Table 1)
- Figure 2
- Figure 3
"""

#Install dependencies
#!pip -q install numpy pandas scikit-learn torch - required libraries or
!pip -q install -r requirements.txt # from requirements file

#Set global seed for reproducibility
def set_global_seed(seed: int):
    """Seed every random number generator this script relies on.

    Writes ``PYTHONHASHSEED`` into ``os.environ`` and seeds Python's
    ``random`` module, NumPy's global RNG and PyTorch (the CPU seed plus both
    CUDA seeding calls), then sets the cuDNN ``deterministic`` flag and
    clears the ``benchmark`` flag.

    Args:
        seed: Integer seed handed to every generator.
    """
    import os
    import random
    import numpy as np
    import torch

    os.environ["PYTHONHASHSEED"] = str(seed)

    # One seed for all library-level generators, applied in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Make CuDNN deterministic (important)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

"""##Code for Table 1"""

# 1) Import necessary libraries and train_NAVAR
import os
import sys
import math
import numpy as np
import pandas as pd
from train_NAVAR import train_NAVAR

# 2) Synthetic data generator
import numpy as np
import pandas as pd

def _standardize_cols(arr: np.ndarray) -> np.ndarray:
    mu = arr.mean(axis=0)
    sd = arr.std(axis=0)
    sd = np.where(sd < 1e-12, 1.0, sd)
    return (arr - mu) / sd

# Variance-normalize any g(x) under X~N(0,1) ---
def make_scaled(g_raw, seed=0, n_mc=300_000):
    """Return ``g_raw`` divided by its Monte-Carlo standard deviation under N(0, 1).

    Draws ``n_mc`` standard-normal samples from a generator seeded with
    ``seed``, evaluates ``g_raw`` on each sample individually (so ``g_raw``
    may be a scalar-only function), and uses the population standard
    deviation of those values as the scale.

    Raises:
        ValueError: If the estimated scale is non-finite or below 1e-12.
    """
    samples = np.random.default_rng(seed).normal(size=n_mc)
    outputs = np.array([g_raw(float(s)) for s in samples], dtype=float)
    scale = np.sqrt(np.var(outputs))
    if scale < 1e-12 or not np.isfinite(scale):
        raise ValueError("Bad scale in make_scaled()")

    def scaled(x):
        return g_raw(x) / scale

    return scaled

# 1) Linear (anchor)
def g_linear(x):
    """Return ``x`` clipped to the interval [-1.2, 1.2]."""
    lower, upper = -1.2, 1.2
    return np.clip(x, lower, upper)

# 2) Threshold (piecewise constant sign)
def g_threshold(x, c=0.6, a=1.6):
    """Return ``a * sign(x)`` when ``|x| > c`` and zero otherwise.

    The branch is a plain truth test on ``abs(x) > c``, so ``x`` is expected
    to be a scalar.
    """
    if abs(x) > c:
        return a * np.sign(x)
    return a * 0.0

# 3) Saturating (hard saturation)
def g_saturating(x):
    """Return ``x`` clipped to the interval [-1.0, 1.0]."""
    floor, ceiling = -1.0, 1.0
    return np.clip(x, floor, ceiling)

# 4) Sign-changing (clean reversal)
def g_sign_changing(x, c=0.6):
    """Return ``-x`` inside the band ``|x| < c`` and ``x`` itself outside it."""
    if abs(x) < c:
        return -x
    return x

# Global multiplier used by the `g` wrapper below.
SCALE = 1.0
# g(f) builds a new function x -> SCALE * f(x); SCALE is looked up at call time.
# The SYSTEMS entries below reference the raw mechanisms, not g-wrapped ones.
g = lambda f: lambda x: SCALE * f(x)


# Mapping from system name to the mechanism function passed to
# generate_dataset as `g_func` (the X[t-1] -> Y[t] link).
SYSTEMS = {
    "linear": g_linear,
    "threshold": g_threshold,
    "saturating": g_saturating,
    "sign_changing": g_sign_changing,
}

def generate_dataset(
    g_func,
    T: int = 2000,
    noise_std: float = 0.3,
    seed: int = 0,
) -> pd.DataFrame:
    """Simulate the three-variable system (X, Y, Z) and return it column-standardized.

    All series start at 0. For each later step:
      * X follows an AR(1) recursion (coefficient 0.6) with Gaussian noise and,
        with probability 0.15, an additional jump of -1.5 or +1.5.
      * Y follows an AR(1) recursion (coefficient 0.3) driven by
        ``g_func(X[t-1])`` plus Gaussian noise.
      * Z follows an AR(1) recursion (coefficient 0.6) with Gaussian noise and
        no input from X or Y.

    Args:
        g_func: Scalar mechanism applied to the previous X value.
        T: Number of time steps.
        noise_std: Standard deviation of every Gaussian noise term.
        seed: Seed for the NumPy generator.

    Returns:
        DataFrame with columns "X", "Y", "Z" after ``_standardize_cols``.
    """
    rng = np.random.default_rng(seed)
    series = np.zeros((T, 3))  # columns: X, Y, Z

    for t in range(1, T):
        prev_x, prev_y, prev_z = series[t - 1]

        x_noise = rng.normal(0.0, noise_std)
        # The uniform draw comes first; a jump size is only sampled when a
        # jump actually occurs, which keeps the random stream unchanged.
        if rng.random() < 0.15:
            jump = rng.choice([-1.5, 1.5])
        else:
            jump = 0.0

        series[t, 0] = 0.6 * prev_x + x_noise + jump
        series[t, 1] = 0.3 * prev_y + g_func(prev_x) + rng.normal(0.0, noise_std)
        series[t, 2] = 0.6 * prev_z + rng.normal(0.0, noise_std)

    return pd.DataFrame(_standardize_cols(series), columns=["X", "Y", "Z"])

# 3) NAVAR training config
# Default keyword arguments for NAVAR training in the synthetic experiments;
# run_experiment falls back to this dict when no navar_cfg is supplied.
NAVAR_CONFIG = dict(
    maxlags=1,
    hidden_nodes=16,
    hidden_layers=1,
    dropout=0.10,
    epochs=2000,
    batch_size=128,
    learning_rate=3e-4,
    lambda1=0.15,          # sparsity penalty on contributions
    weight_decay=0.001,
    val_proportion=0.10,   # hold-out tail as validation (per NAVAR's split method)
    check_every=2000,
    normalize=False,       # we already z-score standardized
    lstm=False,
    split_timeseries=None,
    use_cuda=False,        # set True if your environment supports CUDA in NAVAR
)

def _parse_train_navar_output(out):
    # Case 1: dict-like return
    if isinstance(out, dict):
        causal_matrix = out.get("causal_matrix", None)
        contributions = out.get("contributions", None)
        val_loss = out.get("val_loss", out.get("validation_loss", np.nan))
        if causal_matrix is None:
            raise ValueError("train_NAVAR returned a dict but missing key 'causal_matrix'.")
        return causal_matrix, contributions, val_loss

    # Case 2: tuple/list return with >= 1 element
    if isinstance(out, (tuple, list)):
        if len(out) < 1:
            raise ValueError("train_NAVAR returned an empty tuple/list.")

        causal_matrix = out[0]
        contributions = out[1] if len(out) >= 2 else None
        val_loss = out[2] if len(out) >= 3 else np.nan

        return causal_matrix, contributions, val_loss

    # Case 3: unexpected return type
    raise ValueError(f"Unexpected train_NAVAR return type: {type(out)}")


def train_navar_safe(data_np: np.ndarray, cfg: dict):
    """Train NAVAR on ``data_np`` using ``cfg`` and return a uniform 3-tuple.

    NOTE(review): the original body of this function was missing from the
    file (only a fragment of its "Returns" line survived), so this is a
    reconstruction from how it is used by ``run_experiment`` — confirm it
    matches the intended behaviour.

    ``cfg`` may contain keys the installed ``train_NAVAR`` does not accept
    (``NAVAR_CONFIG`` carries ``use_cuda``, which the direct ``train_NAVAR``
    calls elsewhere in this file never pass). Unless ``train_NAVAR`` takes
    ``**kwargs``, such keys are dropped before the call.

    Args:
        data_np: Data matrix passed as the first positional argument.
        cfg: Keyword arguments for ``train_NAVAR``.

    Returns: causal_matrix, contributions, val_loss
    """
    import inspect

    try:
        params = inspect.signature(train_NAVAR).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: forward the config unchanged.
        kwargs = dict(cfg)
    else:
        takes_var_kw = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if takes_var_kw:
            kwargs = dict(cfg)
        else:
            kwargs = {k: v for k, v in cfg.items() if k in params}

    out = train_NAVAR(data_np, **kwargs)
    return _parse_train_navar_output(out)

# 4) Run experiments

# Re-declared with the same four entries as the mapping defined before
# generate_dataset; this is the dict handed to run_experiment below.
SYSTEMS = {
    "linear": g_linear,
    "threshold": g_threshold,
    "saturating": g_saturating,
    "sign_changing": g_sign_changing,
}

def run_experiment(
    systems: dict,
    repeats: int = 15,
    T: int = 2000,
    noise_std: float = 0.3,
    base_seed: int = 1000,
    navar_cfg: dict = None,
) -> pd.DataFrame:
    """Train NAVAR repeatedly on each synthetic system and collect X->Y scores.

    For every ``(name, g_func)`` in ``systems`` and every repeat ``r``, a
    dataset is generated with seed ``base_seed + 37 * r`` (the same seeds are
    reused across systems), NAVAR is trained through ``train_navar_safe`` and
    the entry ``[0, 1]`` of the returned causal matrix is recorded as the
    X->Y score.

    Args:
        systems: Mapping of system name to mechanism function.
        repeats: Number of runs per system.
        T: Length of each generated series.
        noise_std: Noise standard deviation for data generation.
        base_seed: Seed of run 0.
        navar_cfg: NAVAR keyword arguments; defaults to ``NAVAR_CONFIG``.

    Returns:
        One row per run with columns system, run, seed, T, noise_std,
        score_X_to_Y, rank_offdiag and val_loss.
    """
    if navar_cfg is None:
        navar_cfg = NAVAR_CONFIG

    rows = []
    for system_name, g_func in systems.items():
        for r in range(repeats):
            seed = base_seed + 37 * r
            df = generate_dataset(g_func, T=T, noise_std=noise_std, seed=seed)

            data_np = df.values

            causal_matrix, _contributions, val_loss = train_navar_safe(data_np, navar_cfg)

            # Convert once so indexing works for any array-like return value.
            cm = np.array(causal_matrix, dtype=float)

            # indices: X=0, Y=1, Z=2
            score_xy = float(cm[0, 1])

            # Rank among off-diagonal edges only (1 means top). The diagonal is
            # masked out rather than zeroed, so self-effect slots can never be
            # counted as competing edges, even when the score is negative.
            off_diag = cm[~np.eye(cm.shape[0], cm.shape[1], dtype=bool)]
            # rank as 1 + count of edges strictly greater than score_xy
            rank = int(1 + np.sum(off_diag > score_xy))

            rows.append({
                "system": system_name,
                "run": r,
                "seed": seed,
                "T": T,
                "noise_std": noise_std,
                "score_X_to_Y": score_xy,
                "rank_offdiag": rank,
                "val_loss": float(val_loss) if val_loss is not None else np.nan,
            })

            print(f"[{system_name:12s}] run={r:02d} score_X->Y={score_xy:.6f} rank={rank}")

    return pd.DataFrame(rows)

# 5) Execute + summarize

SEED = 42 #For reproducibility
set_global_seed(SEED)
if __name__ == "__main__":
    # Run all systems, persist the per-run scores, then print the summaries.
    df_res = run_experiment(SYSTEMS, repeats=15, T=2000, noise_std=0.5)

    out_csv = "synthetic_navar_scores.csv"
    df_res.to_csv(out_csv, index=False)
    print(f"\nSaved results to: {out_csv}\n")

    summary = df_res.groupby("system")["score_X_to_Y"].agg(
        ["count", "mean", "std", "min", "max"]
    ).sort_index()
    print("Summary of NAVAR scalar scores for X->Y:")
    print(summary)

    # Rank sanity check
    rank_summary = df_res.groupby("system")["rank_offdiag"].value_counts().sort_index()
    print("\nRank distribution (off-diagonal; 1 = top edge):")
    print(rank_summary)

    # Summary by system (adds quartiles). Kept inside the guard because it
    # needs df_res, which only exists when the experiment has been run.
    summary = df_res.groupby("system")["score_X_to_Y"].agg(
        n="count",
        mean="mean",
        std="std",
        q25=lambda s: s.quantile(0.25),
        q50="median",
        q75=lambda s: s.quantile(0.75),
        min="min",
        max="max"
    ).sort_index()
    print(summary)

    # Run counts per (system, rank) as a wide table.
    rank_summary = (
        df_res.groupby(["system", "rank_offdiag"])
        .size()
        .unstack(fill_value=0)
        .sort_index()
    )
    print(rank_summary)

"""## Code for Figure 2
Self-contained; repeats some of the code above, combined with additional lines, for ease of reproducibility
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from train_NAVAR import train_NAVAR

# Re-seed all global RNGs before the response-curve figure is computed.
SEED = 42
set_global_seed(SEED)

# 1) Data generation (as in the Table 1 section, but without the random ±1.5 jump term in X)

def _standardize_cols(arr: np.ndarray) -> np.ndarray:
    mu = arr.mean(axis=0)
    sd = arr.std(axis=0)
    sd = np.where(sd < 1e-12, 1.0, sd)
    return (arr - mu) / sd

def generate_dataset(g_func, T=2000, noise_std=0.3, seed=0) -> pd.DataFrame:
    """Simulate X, Y, Z without jump shocks and return them column-standardized.

    All series start at 0. X and Z are AR(1) recursions with coefficient 0.6
    and Gaussian noise; Y is an AR(1) recursion with coefficient 0.3 driven by
    ``g_func(X[t-1])`` plus Gaussian noise. Every noise term has standard
    deviation ``noise_std`` and comes from a generator seeded with ``seed``.

    Returns:
        DataFrame with columns "X", "Y", "Z" after ``_standardize_cols``.
    """
    rng = np.random.default_rng(seed)
    x, y, z = (np.zeros(T) for _ in range(3))

    for now in range(1, T):
        before = now - 1
        # Draw order (X noise, Y noise, Z noise) fixes the random stream.
        x[now] = 0.6 * x[before] + rng.normal(0.0, noise_std)
        drive = g_func(x[before])
        y[now] = 0.3 * y[before] + drive + rng.normal(0.0, noise_std)
        z[now] = 0.6 * z[before] + rng.normal(0.0, noise_std)

    stacked = np.column_stack((x, y, z))
    return pd.DataFrame(_standardize_cols(stacked), columns=["X", "Y", "Z"])

# 2) Four mechanisms

def make_scaled(g_raw, seed=0, n_mc=300_000):
    """Wrap ``g_raw`` so its output is divided by a Monte-Carlo scale estimate.

    The scale is the population standard deviation of ``g_raw`` evaluated
    one sample at a time on ``n_mc`` standard-normal draws (generator seeded
    with ``seed``).

    Raises:
        ValueError: If that scale is non-finite or smaller than 1e-12.
    """
    rng = np.random.default_rng(seed)
    draws = rng.normal(size=n_mc)
    evaluated = np.array(list(map(lambda v: g_raw(float(v)), draws)), dtype=float)
    scale = np.sqrt(evaluated.var())

    usable = np.isfinite(scale) and scale >= 1e-12
    if not usable:
        raise ValueError("Bad scale in make_scaled()")
    return lambda x: g_raw(x) / scale

# 1) Linear (anchor)
def g_linear(x):
    """Clip ``x`` to the range [-1.2, 1.2] and return it."""
    return np.clip(x, a_min=-1.2, a_max=1.2)

# 2) Threshold (piecewise constant sign)
def g_threshold(x, c=0.6, a=1.6):
    """Return ``a`` times the sign of ``x`` if ``|x|`` exceeds ``c``, else zero.

    Uses a plain truth test on ``abs(x) > c``, so ``x`` is expected to be a
    scalar.
    """
    direction = np.sign(x) if abs(x) > c else 0.0
    return a * direction

# 3) Saturating (hard saturation)
def g_saturating(x):
    """Clip ``x`` to the range [-1.0, 1.0] and return it."""
    return np.clip(x, a_min=-1.0, a_max=1.0)

# 4) Sign-changing (clean reversal)
def g_sign_changing(x, c=0.6):
    """Flip the sign of ``x`` when ``|x| < c``; otherwise return ``x`` unchanged."""
    inside_band = abs(x) < c
    return -x if inside_band else x

# Global multiplier used by the `g` wrapper below.
SCALE = 1.0
# g(f) builds a new function x -> SCALE * f(x); the SYSTEMS entries below
# reference the raw mechanisms, not g-wrapped ones.
g = lambda f: lambda x: SCALE * f(x)

# System name -> mechanism function; iterated by run_figure3_ice.
SYSTEMS = {
    "linear": g_linear,
    "threshold": g_threshold,
    "saturating": g_saturating,
    "sign_changing": g_sign_changing,
}


# 3) Torch-safe conversion + model forward helper

def to_numpy(x):
    """Convert ``x`` to a NumPy array, detaching torch tensors first.

    Returns ``None`` for ``None``. A ``torch.Tensor`` is detached, moved to
    the CPU and converted; any other value goes through ``np.asarray``.
    When torch cannot be imported, ``x`` cannot be a tensor and is converted
    with ``np.asarray`` directly.

    Only the import is guarded: a failure while converting a tensor is
    raised to the caller instead of being silently swallowed.
    """
    if x is None:
        return None
    try:
        import torch
    except ImportError:
        return np.asarray(x)
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def model_predict(model, Xwin_np):
    """
    Run one forward pass of ``model`` in eval mode with gradients disabled.

    Xwin_np is converted to a float32 tensor and moved to the device of the
    model's first parameter; callers in this file pass shape (W, N, maxlags).

    If the model returns a tuple/list, its first element is taken as the
    predictions. Returns them as a NumPy array; a 3-D result whose last axis
    has length 1 is reduced to 2-D.
    """
    import torch

    model.eval()

    # Inputs must live on the same device as the model.
    target_device = next(model.parameters()).device
    inputs = torch.tensor(Xwin_np, dtype=torch.float32).to(target_device)

    with torch.no_grad():
        result = model(inputs)

    # Many NAVAR forks return (preds, contribs) or just preds
    preds = result[0] if isinstance(result, (tuple, list)) else result

    preds_np = preds.detach().cpu().numpy()
    has_trailing_singleton = preds_np.ndim == 3 and preds_np.shape[-1] == 1
    if has_trailing_singleton:
        preds_np = preds_np[..., 0]
    return preds_np


# 4) Build windows: (W, N, maxlags)
def build_lag_windows(data_np, maxlags):
    """
    Slice a (T, N) series into overlapping lag windows.

    Window i holds rows i .. i+maxlags-1 of data_np, transposed to
    (N, maxlags); it is the input for target time i+maxlags.

    Returns a float array of shape (W, N, maxlags) with W = T - maxlags.
    """
    n_steps, n_vars = data_np.shape
    n_windows = n_steps - maxlags
    windows = np.zeros((n_windows, n_vars, maxlags), dtype=float)
    for first in range(n_windows):
        windows[first] = data_np[first:first + maxlags, :].T
    return windows



# 5) ICE / partial dependence response curve for edge X->Y

def response_curve_ice(model, Xwin, src_idx, tgt_idx, grid, lag_pos=-1, baseline="observed"):
    """
    Average change in the target prediction when one source lag is overwritten.

    model: trained NAVAR model (called through model_predict)
    Xwin: (W, N, maxlags) lag windows
    src_idx / tgt_idx: indices of the source and target variables
    grid: values written into Xwin[:, src_idx, lag_pos] in every window
    lag_pos: lag slot to overwrite; -1 is the last slot of the window
    baseline: "observed" compares against predictions on the untouched
              windows; any other value is cast to float and written into the
              same slot to form the baseline input

    Returns: (grid, mean_delta), where mean_delta[g] is the mean over windows
    of (prediction with grid[g]) - (baseline prediction).
    """
    def predict_with_source(value):
        # Overwrite the chosen source slot in a copy, then predict the target.
        forced = Xwin.copy()
        forced[:, src_idx, lag_pos] = float(value)
        return model_predict(model, forced)[:, tgt_idx]

    if baseline == "observed":
        pred_base = model_predict(model, Xwin)[:, tgt_idx]
    else:
        pred_base = predict_with_source(baseline)

    per_value = [predict_with_source(xval) - pred_base for xval in grid]
    deltas = np.stack(per_value, axis=0)  # (G, W)
    return grid, deltas.mean(axis=1)



# 6) Run one paired-seed per system and plot Figure 3

def run_figure3_ice(
    systems,
    seed=1000,
    T=2000,
    noise_std=0.3,
    maxlags=1,
    grid=None
):
    """Train NAVAR once per system and plot its X->Y response curve.

    Each (name, g_func) pair in ``systems`` is matched with one panel of a
    2x2 figure via ``zip``, so at most four systems are plotted. Every system
    uses the same data seed. For each one: generate data, train NAVAR, read
    the score at causal_matrix[0, 1], build lag windows, and plot the mean
    prediction change from ``response_curve_ice`` with the observed baseline.

    Args:
        systems: Mapping of system name to mechanism function.
        seed: Seed passed to ``generate_dataset`` for every system.
        T: Series length.
        noise_std: Noise standard deviation for data generation.
        maxlags: Number of lags for NAVAR and for the windows.
        grid: Values to set the source to; defaults to 81 points in [-2.5, 2.5].

    Returns:
        (fig, DataFrame) where the frame has columns system, score_X_to_Y and
        val_loss, sorted by system name.
    """
    if grid is None:
        grid = np.linspace(-2.5, 2.5, 81)

    # X is the source (column 0), Y the target (column 1).
    src_idx, tgt_idx = 0, 1
    fig, axes = plt.subplots(2, 2, figsize=(11, 8), constrained_layout=True)
    axes = axes.ravel()

    rows = []

    for ax, (name, g_func) in zip(axes, systems.items()):
        df = generate_dataset(g_func, T=T, noise_std=noise_std, seed=seed)
        data = df.values

        # Train NAVAR (your fork signature)
        # NOTE(review): expects a fork returning four values, the last being
        # the trained model — confirm against train_NAVAR.py.
        causal_matrix, contributions, val_loss, model = train_NAVAR(
            data,
            maxlags=maxlags,
            hidden_nodes=16,
            dropout=0.1,
            epochs=2000,
            learning_rate=3e-4,
            batch_size=128,
            lambda1=0.15,
            val_proportion=0.10,
            weight_decay=0.001,
            check_every=2000,
            hidden_layers=1,
            normalize=False,
            split_timeseries=None,
            lstm=False
        )

        cm = to_numpy(causal_matrix)
        score = float(cm[src_idx, tgt_idx])

        # Build windows and compute ICE-style response
        Xwin = build_lag_windows(data, maxlags=maxlags)
        gx, mean_delta = response_curve_ice(
            model=model,
            Xwin=Xwin,
            src_idx=src_idx,
            tgt_idx=tgt_idx,
            grid=grid,
            lag_pos=-1,
            baseline="observed",
        )

        ax.plot(gx, mean_delta, linewidth=2)
        ax.axhline(0.0, linewidth=1)
        ax.set_title(f"{name} | score(X→Y)={score:.3f}")
        ax.set_xlabel("Set value for X[t−1] (standardized)")
        ax.set_ylabel("Δ predicted Y (avg over windows)")
        ax.grid(True, alpha=0.3)
        # Fixed y-range so the four panels share one scale.
        ax.set_ylim(-2.5, 2.5)

        rows.append({"system": name, "score_X_to_Y": score, "val_loss": float(val_loss)})

    return fig, pd.DataFrame(rows).sort_values("system")


# ---- Execute ----
# Train on all four systems, print the per-system scores, then save and show
# the response-curve figure.
fig, score_df = run_figure3_ice(SYSTEMS, seed=1000, maxlags=1)
print(score_df)


fig.savefig("figure3_ice_response.png", dpi=200)
plt.show()

"""# Case Study: Causality in Democracy

## NAVAR Causal Matrix (for Appendix)
"""

#Install dependencies
#!pip -q install numpy pandas scikit-learn torch - required libraries or
!pip -q install -r requirements.txt # from requirements file

#Extract NAVAR-ready data for NAVAR-only processing
import pandas as pd
import numpy as np

df = pd.read_csv("my_data.csv")

# Identifier columns that must not be fed to NAVAR as variables.
DROP_COLS = ["country", "country_id", "year"]


# Only drop the identifier columns that are actually present in the file.
cols_to_drop = [c for c in DROP_COLS if c in df.columns]

df_navar = df.drop(columns=cols_to_drop)


# Column order defines the variable order of the causal matrix built later.
variable_names = df_navar.columns.tolist()
data = df_navar.to_numpy(dtype=np.float32)

#Starting configuration
from train_NAVAR import train_NAVAR
import pandas as pd

# Train NAVAR on the democracy panel.
# NOTE(review): expects a fork whose train_NAVAR returns four values, the last
# being the trained model — confirm against train_NAVAR.py.
# NOTE(review): split_timeseries=35 presumably is the number of rows per
# country (the figure code uses 35 years per country) — confirm.
causal_matrix, contributions, val_loss, model = train_NAVAR(
    data,
    maxlags=8,
    hidden_nodes=32,
    dropout=0.1,
    epochs=1000,
    learning_rate=3e-4,
    batch_size=128,
    lambda1=0.15,
    val_proportion=0.10,
    weight_decay=0.001,
    check_every=200,
    hidden_layers=1,
    normalize=True,
    split_timeseries=35,
    lstm=False
)

#NAVAR Causal Matrix
import pandas as pd


# Label the causal matrix with variable names and write it to CSV.
# NOTE(review): assumes causal_matrix is an array-like of shape
# (n_variables, n_variables) that pandas can ingest directly — confirm.
causal_df = pd.DataFrame(
    causal_matrix,
    index=variable_names,    # rows = source variables
    columns=variable_names   # columns = target variables
)

causal_df.to_csv("navar_causal_matrix.csv")

#Heatmaps with variable names
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Zero the diagonal on an explicit, writable copy of the scores. Writing
# through `DataFrame.values` is not reliable: under pandas Copy-on-Write the
# returned array is read-only and np.fill_diagonal would raise.
_masked_scores = causal_df.to_numpy(copy=True)
np.fill_diagonal(_masked_scores, 0)   # optional: hide self-effects
S = pd.DataFrame(_masked_scores, index=causal_df.index, columns=causal_df.columns)

plt.figure(figsize=(8, 6))
sns.heatmap(
    S,
    cmap="RdBu_r", #Red to blue, could try "viridis",
    square=True,
    cbar_kws={"label": "NAVAR causal score"}
)
plt.title("NAVAR causal score matrix")
plt.xlabel("Target variable")
plt.ylabel("Source variable")
plt.tight_layout()
plt.show()

def set_global_seed(seed: int):
    """Apply ``seed`` to all global random number generators used here.

    Covers the ``PYTHONHASHSEED`` environment variable, Python's ``random``,
    NumPy's global RNG and PyTorch's CPU and CUDA seeding calls, and sets the
    cuDNN flags ``deterministic=True`` / ``benchmark=False``.
    """
    import os
    import random
    import numpy as np
    import torch

    os.environ.update({"PYTHONHASHSEED": str(seed)})
    random.seed(seed)
    np.random.seed(seed)

    for torch_seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        torch_seeder(seed)

    # Make CuDNN deterministic
    for flag, value in (("deterministic", True), ("benchmark", False)):
        setattr(torch.backends.cudnn, flag, value)

"""## Code for Figure 3"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from train_NAVAR import train_NAVAR

# Plotting Defaults

# Small fonts and high DPI applied globally to every figure created after
# this point.
plt.rcParams.update({
    "font.size": 6,
    "axes.titlesize": 6,
    "axes.labelsize": 6,
    "legend.fontsize": 5,
    "xtick.labelsize": 5,
    "ytick.labelsize": 5,
    "figure.dpi": 200,
    "savefig.dpi": 300,
})

def to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU.

    ``None`` is passed through unchanged. If torch is not importable, ``x``
    cannot be a tensor and is handed to ``np.asarray``.

    Only a failed torch import is tolerated; an error raised while converting
    an actual tensor propagates instead of being hidden.
    """
    if x is None:
        return None
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def model_predict_batched(model, Xwin_np, batch_size=512):
    """Run ``model`` over ``Xwin_np`` in chunks and stack the predictions.

    The model is put in eval mode and called without gradients on
    consecutive slices of ``batch_size`` rows, each converted to a float32
    tensor on the device of the model's first parameter. If the model returns
    a tuple/list, its first element is used. A 3-D chunk result whose last
    axis has length 1 is reduced to 2-D before the chunks are stacked
    vertically.
    """
    import torch

    model.eval()
    device = next(model.parameters()).device
    n_windows = Xwin_np.shape[0]

    def forward_chunk(chunk_np):
        chunk = torch.tensor(chunk_np, dtype=torch.float32, device=device)
        result = model(chunk)
        # Common NAVAR forks: out = (preds, contribs) or preds only
        preds = result[0] if isinstance(result, (tuple, list)) else result
        preds = preds.detach().cpu().numpy()
        if preds.ndim == 3 and preds.shape[-1] == 1:
            preds = preds[:, :, 0]
        return preds

    with torch.no_grad():
        chunks = [
            forward_chunk(Xwin_np[lo:lo + batch_size])
            for lo in range(0, n_windows, batch_size)
        ]

    return np.vstack(chunks)

def build_panel_lag_windows(data_np, L, K):
    """Build K-lag windows separately inside each consecutive block of L rows.

    ``data_np`` (T_total, N) is treated as T_total // L stacked segments of
    length L. For every target row t at offset K..L-1 within a segment, the
    window is rows t-K..t-1 transposed to (N, K), so no window crosses a
    segment boundary.

    Returns:
        (Xwin, meta): Xwin has shape (U * (L - K), N, K); meta holds
        "t_idx" (global target row of each window) and "unit_idx" (segment
        number of each window).

    Raises:
        AssertionError: If T_total is not a multiple of L.
    """
    T_total, N = data_np.shape
    assert T_total % L == 0, "Total rows must be multiple of segment length L."
    n_units = T_total // L
    n_windows = n_units * (L - K)

    Xwin = np.zeros((n_windows, N, K), dtype=float)
    t_idx = np.zeros(n_windows, dtype=int)
    unit_idx = np.zeros(n_windows, dtype=int)

    # (segment number, global target row) for every window, in storage order.
    targets = (
        (unit, unit * L + offset)
        for unit in range(n_units)
        for offset in range(K, L)
    )
    for w, (unit, t) in enumerate(targets):
        Xwin[w] = data_np[t - K:t, :].T
        t_idx[w] = t
        unit_idx[w] = unit

    return Xwin, {"t_idx": t_idx, "unit_idx": unit_idx}

# Regime bins based on target lag value
def make_tertiles(values):
    """Assign each value to a tertile bin (0, 1 or 2).

    The cut points are the 1/3 and 2/3 quantiles of ``values``. A value gets
    bin 1 if it is strictly above the lower cut and bin 2 if it is strictly
    above the upper cut; everything else stays in bin 0.

    Returns:
        (bins, (q1, q2)): integer bins shaped like ``values`` and the two
        cut points.
    """
    lower_cut, upper_cut = np.quantile(values, [1/3, 2/3])
    bins = np.zeros_like(values, dtype=int)
    # Each exceeded cut point raises the bin index by one.
    bins += values > lower_cut
    bins += values > upper_cut
    return bins, (lower_cut, upper_cut)

# ICE computation (lag-aggregated intervention)

def lag_aggregated_ice(
    model,
    Xwin,
    src_idx,
    tgt_idx,
    grid,
    regime_bins,
    batch_size=512
):
    """Mean change in the target prediction per regime when a source is fixed.

    For every value in ``grid``, all lags of variable ``src_idx`` are set to
    that value in a copy of ``Xwin``; the target prediction is compared with
    the prediction on the untouched windows, and the difference is averaged
    separately over the windows of each regime (bins 0, 1, 2 in
    ``regime_bins``).

    Returns:
        Array of shape (3, len(grid)); an entry is NaN when its regime has
        no windows.
    """
    # Unpacking also enforces that Xwin is 3-D.
    _n_windows, _n_vars, _n_lags = Xwin.shape

    # Baseline prediction for target
    baseline = model_predict_batched(model, Xwin, batch_size=batch_size)[:, tgt_idx]

    # Window indices belonging to each regime, computed once.
    members = [np.nonzero(regime_bins == regime)[0] for regime in range(3)]

    curves = np.zeros((3, len(grid)), dtype=float)
    for col, level in enumerate(grid):
        forced = Xwin.copy()
        # lag-aggregated: set ALL lags of src to the grid value
        forced[:, src_idx, :] = float(level)

        shifted = model_predict_batched(model, forced, batch_size=batch_size)[:, tgt_idx]
        shift = shifted - baseline

        curves[:, col] = [
            float(shift[rows].mean()) if len(rows) else np.nan
            for rows in members
        ]

    return curves

# Main: run NAVAR + produce Figure 3 (Section 5.2)
# Re-seed all global RNGs before the case-study model is trained.
SEED = 42
set_global_seed(SEED)

def main():
    """Train NAVAR on the democracy panel and plot regime-wise ICE curves.

    Reads ``my_data.csv``, trains NAVAR on all non-identifier columns, then
    for four fixed source variables plots how the predicted target
    ("Clean_elections") changes when all lags of the source are set to a grid
    value, separately for the low/mid/high tertile of the target's most
    recent lag. Saves the figure as PNG and PDF and prints the scalar scores.

    NOTE(review): windows are cut from consecutive blocks of L rows, which
    presumably requires the CSV to be sorted by country and year with exactly
    L rows per country — confirm.
    """
    # ---- Load data ----
    df = pd.read_csv("my_data.csv")
    assert "country_id" in df.columns and "year" in df.columns

    # Keep column order for NAVAR exactly as in file.
    # Identifier columns are never model inputs. "country" is optional and is
    # excluded when present, consistent with the causal-matrix section, so
    # that variable indices agree and the float conversion below cannot fail
    # on a name column.
    meta_cols = ["country", "country_id", "year"]
    var_cols = [c for c in df.columns if c not in meta_cols]

    # Panel settings
    L = 35  # 35 years per country
    K = 8   # maxlags per your hyperparameters

    # Extract numeric matrix for NAVAR (T_total x N)
    data = df[var_cols].to_numpy(dtype=float)

    # ---- Train NAVAR
    # NOTE(review): expects a fork returning four values, the last being the
    # trained model — confirm against train_NAVAR.py.
    causal_matrix, contributions, val_loss, model = train_NAVAR(
        data,
        maxlags=K,
        hidden_nodes=32,
        dropout=0.1,
        epochs=1000,
        learning_rate=3e-4,
        batch_size=128,
        lambda1=0.15,
        val_proportion=0.10,
        weight_decay=0.001,
        check_every=200,
        hidden_layers=1,
        normalize=True,
        split_timeseries=L,
        lstm=False
    )

    cm = to_numpy(causal_matrix)

    # Hard-code target and indices
    target_name = "Clean_elections"
    assert target_name in var_cols, f"Target {target_name} not found."

    edges = [
        ("Freedom_of_expression", target_name),
        ("Judicial_constraints", target_name),
        ("Legislative_constraints", target_name),
        ("Freedom_of_association", target_name),
    ]

    name_to_idx = {name: i for i, name in enumerate(var_cols)}
    tgt_idx = name_to_idx[target_name]

    # ---- Build panel windows for ICE ----
    # The model is queried on globally z-scored inputs.
    # NOTE(review): assumes this matches what train_NAVAR does internally
    # with normalize=True — confirm.
    data_norm = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-12)

    Xwin, meta = build_panel_lag_windows(data_norm, L=L, K=K)

    # Regimes are defined by the target's own most recent lag.
    y_lag1 = Xwin[:, tgt_idx, -1]
    regime_bins, (q1, q2) = make_tertiles(y_lag1)

    # ---- ICE grid ----
    # NOTE(review): the grid range comes from the target's lag-1 values but is
    # applied to the source variables; all columns are standardized, so the
    # ranges are presumably comparable — confirm this is intended.
    grid = np.linspace(np.quantile(y_lag1, 0.02), np.quantile(y_lag1, 0.98), 81)

    # ---- Plot settings ----
    fig, axes = plt.subplots(2, 2, figsize=(7.2, 5.6), constrained_layout=True)
    axes = axes.ravel()

    labels = ["Low (bottom tercile)", "Mid (middle tercile)", "High (top tercile)"]

    # Compute global y-limits across panels for comparability
    all_curves = []

    # ---- Compute and plot each edge ----
    for ax, (src_name, tgt_name) in zip(axes, edges):
        assert src_name in name_to_idx, f"Source {src_name} not found."
        src_idx = name_to_idx[src_name]

        # Scalar score for this edge
        score = float(cm[src_idx, tgt_idx])

        mean_delta = lag_aggregated_ice(
            model=model,
            Xwin=Xwin,
            src_idx=src_idx,
            tgt_idx=tgt_idx,
            grid=grid,
            regime_bins=regime_bins,
            batch_size=512
        )

        # Store for global y-limits
        all_curves.append(mean_delta)

        # Plot curves
        for b in range(3):
            ax.plot(grid, mean_delta[b], linewidth=1.8, label=labels[b])

        ax.axhline(0.0, linewidth=1.0)
        ax.set_title(f"{src_name} → {tgt_name}\nscalar score={score:.3f}", pad=12)
        ax.set_xlabel(f"Set {src_name}[t−1..t−{K}] to x (standardized)")
        ax.set_ylabel("Δ predicted target (mean)")

        ax.grid(True, alpha=0.25)

    # Global y-limits: symmetric around zero, at the 98th percentile of the
    # absolute curve values (floored at 1e-3) so outliers do not set the scale.
    all_vals = np.concatenate([c.reshape(-1) for c in all_curves])
    all_vals = all_vals[np.isfinite(all_vals)]
    if len(all_vals) > 0:
        lim = np.quantile(np.abs(all_vals), 0.98)
        lim = max(lim, 1e-3)
        for ax in axes:
            ax.set_ylim(-lim, lim)

    # One shared legend
    handles, labels_ = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels_, loc="lower center", ncol=3, frameon=False)

    fig.savefig("Figure4_Democracy_ICE.png", bbox_inches="tight")
    fig.savefig("Figure4_Democracy_ICE.pdf", bbox_inches="tight")

    print("Saved: Figure4_Democracy_ICE.png and Figure4_Democracy_ICE.pdf")
    print(f"Validation loss: {val_loss}")
    print("Edge scalar scores:")
    for src_name, _ in edges:
        src_idx = name_to_idx[src_name]
        print(f"  {src_name} -> {target_name}: {float(cm[src_idx, tgt_idx]):.4f}")


# Run the democracy case study only when this file is executed as a script.
if __name__ == "__main__":
    main()