#!/usr/bin/env python3
"""Evaluate a trained TabularACE model using the exact training config.

Builds the model with the provided, exact architecture (no inference), then
loads the checkpoint with strict=True. To avoid key-name drift between code
versions, it mirrors embedder parameters under both 'embedder.' and
'tabular_embedder.' prefixes when needed so there are no missing keys.

Generates the same side-by-side plots as eval_tabular_grid.py for either
MLP-SCM prior or synthetic 2D functions.
"""

import argparse
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

from src.models.ace import AmortizedConditioningEngine
from src.models.tabular_embedder import TabularACE
from src.models.masks import create_training_block_mask
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr

MLPSCM = None  # cache slot; filled by the first successful _get_mlpscm() call


def _get_mlpscm():
    """Return the ``MLPSCM`` prior class, importing it on first use.

    The class is cached in the module-level ``MLPSCM`` global so the import
    is attempted at most once per process. If the plain import fails, the
    directory ``<parent of this file's directory>/tabicl/src`` is appended to
    ``sys.path`` and the import is retried; a failure of the retry propagates.
    """
    global MLPSCM
    if MLPSCM is not None:
        return MLPSCM

    try:
        from tabicl.prior.mlp_scm import MLPSCM as prior_cls
    except Exception:
        import sys as _sys
        from pathlib import Path as _P

        # Fall back to a source checkout located next to this script's parent.
        fallback_src = _P(__file__).resolve().parent.parent / "tabicl" / "src"
        _sys.path.append(str(fallback_src))
        from tabicl.prior.mlp_scm import MLPSCM as prior_cls

    MLPSCM = prior_cls
    return MLPSCM


def build_model_exact(device: torch.device) -> AmortizedConditioningEngine:
    """Instantiate TabularACE with the fixed training hyperparameters.

    The values below are hard-coded on purpose (nothing is inferred from the
    checkpoint). Architecture summary, as recorded with the training config:
    - embed_dim=128, concat_cls=True, num_cls_tokens=2 → eff_dim_model=256
    - backbone: 12 layers, nhead=4, ff_factor=2 → dim_ff=512
    - embedder: 3 ISAB blocks (col) with 128 inducing pts; row encoder 3 blocks
    - max_buffer_size=32, num_target_points=512
    - row_rope_base=100000
    - num_features=10 (as trained); evaluation with 2D inputs still works

    Args:
        device: Device the constructed model is moved to.

    Returns:
        The freshly constructed model, with untrained weights, on ``device``.
    """
    hyperparams = {
        "num_features": 10,
        "embed_dim": 128,
        "transformer_layers": 12,
        "nhead": 4,
        # Embedder / row-encoder feed-forward width; the backbone width is
        # governed by ff_factor further down.
        "dim_feedforward": 256,
        "num_components": 20,
        "max_buffer_size": 32,
        "num_target_points": 512,
        "targets_block_size_for_buffer_attend": 32,
        "dropout": 0.0,
        "num_isab_blocks": 3,
        "num_inducing_points": 128,
        "row_rope_base": 100000,
        "col_nhead": 4,
        "row_nhead": 4,
        "row_num_blocks": 3,
        "concat_cls": True,
        "num_cls_tokens": 2,
        "ff_factor": 2.0,
    }
    return TabularACE(**hyperparams).to(device)


def _augment_state_dict_keys(sd: dict) -> dict:
    """Return a copy of ``sd`` with embedder keys present under both prefixes.

    Every ``'embedder.xxx'`` entry gains a ``'tabular_embedder.xxx'`` alias and
    vice versa. Aliases share the original value object, and an alias is only
    added when that key is not already present. ``sd`` itself is not modified.
    """
    alias_prefix = "tabular_"
    mirrored = dict(sd)

    # Pass 1: short form -> long form.
    for name, value in sd.items():
        if name.startswith("embedder."):
            mirrored.setdefault(alias_prefix + name, value)

    # Pass 2: long form -> short form (strip the leading "tabular_").
    for name, value in sd.items():
        if name.startswith("tabular_embedder."):
            mirrored.setdefault(name[len(alias_prefix):], value)

    return mirrored


def load_checkpoint_strict(model: AmortizedConditioningEngine, ckpt_path: str) -> None:
    """Load a checkpoint into ``model`` with ``strict=True``.

    The checkpoint may be either a raw state dict or a dict holding one under
    ``"model_state_dict"``. Embedder keys are mirrored under both the
    ``embedder.`` and ``tabular_embedder.`` prefixes, but a mirrored alias is
    kept only if the model actually has a parameter/buffer of that name.
    Without this filter, a model that registers a single prefix would be
    handed unexpected keys and the strict load would fail because of the
    aliases this function itself introduced. Keys that come from the
    checkpoint are never dropped, so genuine mismatches are still reported.

    Args:
        model: Model to populate; tensors are mapped to the device of
            ``model.ar_token``.
        ckpt_path: Path to the checkpoint file.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` on missing or
            unexpected keys, or on shape mismatches.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=model.ar_token.device)
    checkpoint_sd = ckpt.get("model_state_dict", ckpt)

    mirrored_sd = _augment_state_dict_keys(checkpoint_sd)
    expected_keys = set(model.state_dict().keys())
    state = {
        key: value
        for key, value in mirrored_sd.items()
        if key in checkpoint_sd or key in expected_keys
    }
    model.load_state_dict(state, strict=True)

def mixture_mean(means: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Return the weighted sum of component means over dimension 2.

    Args:
        means: Component means; the caller passes ``[B, T, K, D]``.
        weights: Mixture weights, broadcastable against ``means``
            (e.g. ``[B, T, K, 1]``).

    Returns:
        ``(weights * means).sum(dim=2)``; for the shapes above this is
        ``[B, T, D]``.
    """
    # The former if/else on weights.shape[-1] had identical branches;
    # broadcasting already covers both the shared and per-dimension cases.
    return (weights * means).sum(dim=2)


def sample_mixture(means: torch.Tensor, stds: torch.Tensor, weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Draw Monte Carlo samples from a single-output Gaussian mixture.

    For every (batch, target) position a component index is drawn
    ``num_samples`` times (with replacement) according to ``weights``, and a
    normal sample is generated from the selected component's mean and std.

    Args:
        means: Component means, ``[B, T, K, 1]``.
        stds: Component standard deviations, ``[B, T, K, 1]``.
        weights: Non-negative component weights, ``[B, T, K, 1]``; need not be
            contiguous in memory.
        num_samples: Number of samples ``S`` per position.

    Returns:
        Samples of shape ``[B, T, S, 1]``.
    """
    B, T, K, D = means.shape
    assert D == 1, "This sampler assumes dim_y=1"

    # reshape (not view) so that non-contiguous weights, e.g. the result of a
    # transpose upstream, do not raise.
    probs = weights.reshape(B * T, K)
    comp_idx = torch.multinomial(probs, num_samples, replacement=True)  # [B*T, S]
    comp_idx = comp_idx.view(B, T, num_samples)

    # Pick the sampled component's parameters directly along the K axis.
    means_sel = torch.gather(means.squeeze(-1), 2, comp_idx)  # [B, T, S]
    stds_sel = torch.gather(stds.squeeze(-1), 2, comp_idx)    # [B, T, S]

    eps = torch.randn_like(means_sel)
    samples = means_sel + stds_sel * eps  # [B, T, S]
    return samples.unsqueeze(-1)


def _choose_uniform_context(X_np: np.ndarray, nc: int) -> np.ndarray:
    """Pick up to ``nc`` row indices spread roughly evenly over the first two columns.

    The bounding box of columns 0 and 1 is cut into a ``g x g`` grid with
    ``g = ceil(sqrt(nc))``. Rows are visited in a fixed pseudo-random order
    (generator seeded with 0) and the first row seen in each grid cell is
    kept, until ``nc`` rows are collected. If the occupied cells run out
    first, the shortfall is filled with rows drawn without replacement from
    those not yet selected (or all of them if there are too few).

    Returns:
        Integer index array of length at most ``nc``.
    """
    assert X_np.shape[1] >= 2, "Uniform context selection expects 2D inputs"
    num_rows = X_np.shape[0]
    cells_per_axis = int(np.ceil(np.sqrt(nc)))
    lower = X_np[:, :2].min(axis=0)
    upper = X_np[:, :2].max(axis=0)

    def _cell_of(axis: int) -> np.ndarray:
        # Cell index of every row along one axis, clamped into the grid so the
        # maximum value lands in the last cell.
        edges = np.linspace(lower[axis], upper[axis], cells_per_axis + 1)
        return np.clip(np.digitize(X_np[:, axis], edges) - 1, 0, cells_per_axis - 1)

    cell0 = _cell_of(0)
    cell1 = _cell_of(1)

    rng = np.random.default_rng(0)
    picked = []
    occupied = set()
    for row in rng.permutation(num_rows):
        cell = (cell0[row], cell1[row])
        if cell in occupied:
            continue
        occupied.add(cell)
        picked.append(row)
        if len(picked) >= nc:
            break

    shortfall = nc - len(picked)
    if shortfall > 0:
        leftover = np.setdiff1d(np.arange(num_rows), np.array(picked, dtype=int), assume_unique=False)
        if leftover.size >= shortfall:
            extra = rng.choice(leftover, size=shortfall, replace=False)
        else:
            extra = leftover
        picked = list(picked) + list(extra)

    return np.array(picked[:nc], dtype=int)


def _generate_mlpscm():
    """Build an MLP-SCM prior instance with freshly randomised hyperparameters.

    The prior is configured for 5000 rows, 2 features and 1 output on CPU.
    The number of causes, number of layers, init std and dropout probability
    are drawn from NumPy's global RNG on every call; everything else is fixed.
    """
    prior_cls = _get_mlpscm()

    # Drawn in this exact order so results under a fixed NumPy seed are stable.
    n_causes = np.random.randint(2, 5)
    n_layers = np.random.randint(3, 6)
    weight_init_std = np.random.uniform(0.8, 2.0)
    dropout_prob = np.random.uniform(0.05, 0.2)

    config = {
        "seq_len": 5000,
        "num_features": 2,
        "num_outputs": 1,
        "is_causal": True,
        "num_causes": n_causes,
        "y_is_effect": True,
        "in_clique": False,
        "sort_features": True,
        "num_layers": n_layers,
        "hidden_dim": 32,
        "mlp_activations": torch.nn.Tanh,
        "init_std": weight_init_std,
        "block_wise_dropout": True,
        "mlp_dropout_prob": dropout_prob,
        "sampling": "mixed",
        "pre_sample_cause_stats": True,
        "noise_std": 0.01,
        "pre_sample_noise_std": False,
        "device": "cpu",
    }
    return prior_cls(**config)


def _generate_synth_grid(grid_size: int = 64, noise_std: float = 0.01) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random synthetic 2D function on a regular grid over [0, 1]^2.

    One of eight function families is chosen at random and evaluated on
    warped coordinates (centred at (0.5, 0.5), rotated by a random angle and
    scaled per axis). The values may then be squashed, and Gaussian noise is
    added. All randomness comes from NumPy's global RNG, so the result depends
    on the order of the draws below.

    Args:
        grid_size: Number of grid points per axis (the grid is G x G).
        noise_std: Standard deviation of the additive Gaussian noise.

    Returns:
        ``(X_all, y_all)`` as float32 tensors of shape ``[G*G, 2]`` and
        ``[G*G, 1]``. ``X_all`` holds the unwarped grid coordinates.
    """
    g = grid_size
    x = np.linspace(0.0, 1.0, g)
    y = np.linspace(0.0, 1.0, g)
    Xg, Yg = np.meshgrid(x, y)

    # Warped coordinates: shift to the grid centre, rotate by theta, then
    # scale each rotated axis independently.
    cx, cy = 0.5, 0.5
    X0, Y0 = Xg - cx, Yg - cy
    theta = np.random.uniform(0, 2 * np.pi)
    s1 = np.random.uniform(0.6, 1.4)
    s2 = np.random.uniform(0.6, 1.4)
    cth, sth = np.cos(theta), np.sin(theta)
    Xw = s1 * (cth * X0 - sth * Y0)
    Yw = s2 * (sth * X0 + cth * Y0)

    kinds = ["polywave", "rbf_mix", "ripples", "checker", "saddle", "spiral", "fourier", "ridge"]
    kind = np.random.choice(kinds)

    if kind == "polywave":
        # Two plane waves plus a quadratic term in Xw and a linear term in Yw.
        a = np.random.uniform(0.3, 0.8)
        b = np.random.uniform(0.2, 0.6)
        f = (
            a * np.sin(2 * np.pi * (np.random.uniform(1.0, 3.0) * Xw + np.random.uniform(1.0, 3.0) * Yw))
            + b * np.cos(2 * np.pi * (np.random.uniform(0.5, 2.0) * Xw - np.random.uniform(0.5, 2.0) * Yw))
            + np.random.uniform(0.2, 0.5) * (Xw) ** 2
            - np.random.uniform(0.1, 0.3) * (Yw)
        )
    elif kind == "rbf_mix":
        # Sum of 2-4 rotated anisotropic Gaussian bumps with signed amplitudes.
        k = np.random.randint(2, 5)
        cx0 = np.random.uniform(-0.3, 0.3, size=k)
        cy0 = np.random.uniform(-0.3, 0.3, size=k)
        sigx = np.random.uniform(0.05, 0.2, size=k)
        sigy = np.random.uniform(0.05, 0.2, size=k)
        amp = np.random.uniform(-1.0, 1.0, size=k)
        f = np.zeros_like(Xw)
        for i in range(k):
            rot = np.random.uniform(0, 2 * np.pi)
            cr, sr = np.cos(rot), np.sin(rot)
            Xr = cr * (Xw - cx0[i]) - sr * (Yw - cy0[i])
            Yr = sr * (Xw - cx0[i]) + cr * (Yw - cy0[i])
            r2 = (Xr / sigx[i]) ** 2 + (Yr / sigy[i]) ** 2
            f += amp[i] * np.exp(-0.5 * r2)
    elif kind == "ripples":
        # Product of two high-frequency sines plus one low-frequency wave.
        f = (
            np.random.uniform(0.3, 0.7) * np.sin(2 * np.pi * np.random.uniform(3, 8) * Xw)
            * np.sin(2 * np.pi * np.random.uniform(3, 8) * Yw)
            + np.random.uniform(0.2, 0.5) * np.sin(2 * np.pi * (Xw + np.random.uniform(0.2, 0.8) * Yw))
        )
    elif kind == "checker":
        # Product of sines with integer frequencies and random phases.
        kx = np.random.randint(2, 8)
        ky = np.random.randint(2, 8)
        phx = np.random.uniform(0, 2 * np.pi)
        phy = np.random.uniform(0, 2 * np.pi)
        f = np.sin(2 * np.pi * kx * Xw + phx) * np.sin(2 * np.pi * ky * Yw + phy)
    elif kind == "saddle":
        # Quadratic saddle with a cross term and a small sinusoidal perturbation.
        a = np.random.uniform(0.4, 1.0)
        b = np.random.uniform(-0.5, 0.5)
        f = a * (Xw ** 2 - Yw ** 2) + b * (Xw * Yw) + 0.3 * np.sin(2 * np.pi * (Xw + 0.5 * Yw))
    elif kind == "spiral":
        # Sine of a linear combination of radius and polar angle.
        R = np.sqrt(Xw ** 2 + Yw ** 2)
        T = np.arctan2(Yw, Xw)
        a = np.random.uniform(6, 14)
        b = np.random.uniform(3, 8)
        f = np.sin(a * R + b * T)
    elif kind == "fourier":
        # Sum of 3-5 random plane waves.
        n = np.random.randint(3, 6)
        f = np.zeros_like(Xw)
        for _ in range(n):
            ux, uy = np.random.uniform(-4, 4), np.random.uniform(-4, 4)
            phase = np.random.uniform(0, 2 * np.pi)
            amp = np.random.uniform(-0.7, 0.7)
            f += amp * np.sin(2 * np.pi * (ux * Xw + uy * Yw) + phase)
    else:  # ridge
        # tanh step along a random direction plus a cosine perturbation.
        theta_r = np.random.uniform(0, 2 * np.pi)
        cr, sr = np.cos(theta_r), np.sin(theta_r)
        proj = cr * Xw + sr * Yw
        m = np.random.uniform(4, 12)
        c0 = np.random.uniform(-0.3, 0.3)
        f = np.tanh(m * (proj - c0)) + 0.3 * np.cos(2 * np.pi * (Xw + 0.3 * Yw))

    # Optional squashing. The two conditions use separate random draws, so
    # tanh is applied with probability 0.3 and the signed square root with
    # probability 0.7 * 0.3 = 0.21; otherwise f is left unchanged.
    if np.random.rand() < 0.3:
        f = np.tanh(f)
    elif np.random.rand() < 0.3:
        f = np.sign(f) * np.sqrt(np.abs(f))

    f += noise_std * np.random.randn(*f.shape)
    # Inputs are the original (unwarped) grid coordinates, flattened row-major.
    X_all = np.stack([Xg.ravel(), Yg.ravel()], axis=1)
    y_all = f.reshape(-1, 1)
    return torch.from_numpy(X_all).float(), torch.from_numpy(y_all).float()


def eval_one_function(
    model: AmortizedConditioningEngine,
    device: torch.device,
    nc: int,
    nb: int,
    nt: int,
    *,
    norm_method: str = "power",
    uniform_context: bool = False,
    source: str = "mlpscm",
    grid_size: int = 64,
    noise_std: float = 0.01,
    attending_chunks: int = 16,
    include_diagonal: bool = False,
    q_block_size: int = 128,
    kv_block_size: int = 128,
    mc_samples: int = 0,
):
    """Sample one function, split it into context/buffer/target, and predict the targets.

    Args:
        model: Model called as ``model(batch, mask)``; its output must expose
            ``means`` and ``weights`` (and ``sds`` when ``mc_samples > 0``).
        device: Device the batch tensors and the attention mask are placed on.
        nc: Number of context points.
        nb: Number of buffer points.
        nt: Number of target points.
        norm_method: Normalisation method handed to ``TabICLScaler``.
        uniform_context: If True, pick context points spread over the input
            space instead of uniformly at random.
        source: ``"mlpscm"`` samples from the MLP-SCM prior; any other value
            uses the synthetic grid generator.
        grid_size: Grid resolution for the synthetic source.
        noise_std: Noise level for the synthetic source.
        attending_chunks: Forwarded to ``create_training_block_mask``.
        include_diagonal: Forwarded to ``create_training_block_mask``.
        q_block_size: Forwarded to ``create_training_block_mask``.
        kv_block_size: Forwarded to ``create_training_block_mask``.
        mc_samples: If > 0, predict with the average of this many mixture
            samples; otherwise use the analytic mixture mean.

    Returns:
        ``(xc_s, yc_s, xt_s, yt_s, pred)`` as NumPy arrays. The first four
        are the scaler-transformed context and target data; ``pred`` is the
        model prediction for the targets with the batch axis removed.

    Raises:
        ValueError: If the sampled function has fewer than ``nc + nb + nt``
            points.
    """
    if source == "mlpscm":
        prior = _generate_mlpscm()
        with torch.no_grad():
            X_all, y_all = prior()
    else:
        X_all, y_all = _generate_synth_grid(grid_size=grid_size, noise_std=noise_std)
    if y_all.dim() == 1:
        y_all = y_all.unsqueeze(-1)

    N = X_all.shape[0]
    R = nc + nb + nt
    if R > N:
        # Without this check the index slices below silently come up short
        # while the attention mask is still built for the requested sizes.
        raise ValueError(
            f"Requested nc + nb + nt = {R} points but the sampled function only has {N}"
        )

    if uniform_context:
        idx_c = _choose_uniform_context(X_all.cpu().numpy(), nc)
        available = np.ones(N, dtype=bool)
        available[idx_c] = False
        # Buffer and target points come from a shuffle of the non-context rows.
        rem_perm = np.random.permutation(np.nonzero(available)[0])
        idx_c_t = torch.from_numpy(idx_c)
        idx_b_t = torch.from_numpy(rem_perm[:nb])
        idx_t_t = torch.from_numpy(rem_perm[nb:nb + nt])
    else:
        perm = torch.randperm(N)
        idx_c_t = perm[:nc]
        idx_b_t = perm[nc:nc + nb]
        idx_t_t = perm[nc + nb:nc + nb + nt]

    xc, yc = X_all[idx_c_t], y_all[idx_c_t]
    xb, yb = X_all[idx_b_t], y_all[idx_b_t]
    xt, yt = X_all[idx_t_t], y_all[idx_t_t]

    # The scaler is fitted on the context set only and then applied to all splits.
    scaler = TabICLScaler(normalization_method=norm_method, outlier_threshold=4.0, random_state=0)
    scaler.fit(xc.numpy(), yc.numpy())
    xc_s, yc_s = scaler.transform_xy(xc.numpy(), yc.numpy())
    xb_s, yb_s = scaler.transform_xy(xb.numpy(), yb.numpy())
    xt_s, yt_s = scaler.transform_xy(xt.numpy(), yt.numpy())

    def _to_batch(arr):
        # Add a leading batch axis of size 1 and move to the target device.
        return torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=torch.float32)

    batch = DataAttr(
        xc=_to_batch(xc_s),
        yc=_to_batch(yc_s),
        xb=_to_batch(xb_s),
        yb=_to_batch(yb_s),
        xt=_to_batch(xt_s),
        yt=_to_batch(yt_s),
    )

    attn_mask = create_training_block_mask(
        current_total_q_len=R,
        current_total_kv_len=R,
        current_context_section_len=nc,
        current_buffer_section_len=nb,
        attending_chunks=attending_chunks,
        include_diagonal=include_diagonal,
        q_block_size=q_block_size,
        kv_block_size=kv_block_size,
        device=device,
    )

    with torch.no_grad():
        out = model(batch, attn_mask)
        means = out.means
        weights = out.weights
        if mc_samples and mc_samples > 0:
            smp = sample_mixture(means, out.sds, weights, num_samples=mc_samples)
            pred = smp.mean(dim=2)  # [B, T, 1]
        else:
            pred = mixture_mean(means, weights)  # [B, T, 1]

    return (xc_s, yc_s, xt_s, yt_s, pred.squeeze(0).cpu().numpy())


def plot_side_by_side(ax_true, ax_pred, Xc, Yc, Xt, Yt, Ypred, title_idx: int):
    """Draw true and predicted target surfaces on two axes with a shared colour range.

    Target values (``Yt`` on ``ax_true``, ``Ypred`` on ``ax_pred``) are
    linearly interpolated from the target locations onto an 80 x 80 grid that
    spans the bounding box of context and target inputs, then rendered as
    filled contours. Context points are overlaid on both panels, coloured by
    their observed value. Only the first two input columns and the first
    output column are used.
    """
    resolution = 80

    def _extent(col):
        # Joint (min, max) of one input column over context and target points.
        return (
            min(Xc[:, col].min(), Xt[:, col].min()),
            max(Xc[:, col].max(), Xt[:, col].max()),
        )

    x_lo, x_hi = _extent(0)
    y_lo, y_hi = _extent(1)
    Xi, Yi = np.meshgrid(
        np.linspace(x_lo, x_hi, resolution),
        np.linspace(y_lo, y_hi, resolution),
    )

    # Grid cells outside the convex hull of the targets are left as NaN.
    surfaces = [
        griddata(Xt[:, :2], values[:, 0], (Xi, Yi), method='linear', fill_value=np.nan)
        for values in (Yt, Ypred)
    ]
    vmin = np.nanmin(surfaces)
    vmax = np.nanmax(surfaces)

    panels = (
        (ax_true, surfaces[0], f'Function {title_idx} (true)'),
        (ax_pred, surfaces[1], f'Function {title_idx} (predicted mean)'),
    )
    for ax, surface, title in panels:
        ax.contourf(Xi, Yi, surface, levels=15, cmap='RdBu_r', alpha=0.9, vmin=vmin, vmax=vmax)
        ax.scatter(
            Xc[:, 0], Xc[:, 1], c=Yc[:, 0], cmap='RdBu_r', s=30,
            edgecolors='black', linewidth=1.0, vmin=vmin, vmax=vmax, zorder=10,
        )
        ax.set_title(title, fontsize=10)
        ax.set_aspect('equal')
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')


def main():
    """Command-line entry point.

    Parses the arguments, builds the fixed-architecture model, loads the
    checkpoint, evaluates ``--nfuncs`` sampled functions and saves a figure
    with one row (true vs. predicted) per function to ``--out``.
    """
    parser = argparse.ArgumentParser(description="Eval TabularACE (exact arch) on 2D functions")
    add = parser.add_argument
    add('--checkpoint', required=True, help='Path to trained checkpoint .pt')
    add('--device', default='cpu', help='cpu or cuda')
    add('--nfuncs', type=int, default=6)
    add('--nc', type=int, default=400)
    add('--nb', type=int, default=32)
    add('--nt', type=int, default=512)
    add('--norm', default='power', choices=['power', 'quantile', 'quantile_rtdl', 'none'])
    add('--source', default='mlpscm', choices=['mlpscm', 'synth'], help='Function source')
    add('--grid-size', type=int, default=64, help='Grid size for synth source (GxG)')
    add('--noise-std', type=float, default=0.01, help='Noise std for synth source')
    add('--out', default='outputs/tabular_eval_grid_exact.png')
    add('--mc-samples', type=int, default=0, help='Monte Carlo samples per target (0 uses analytic mean)')
    # Mask hyperparameters (match training defaults)
    add('--attending-chunks', type=int, default=16)
    add('--include-diagonal', action='store_true', help='Include diagonal self-attention (default False)')
    add('--q-block-size', type=int, default=128)
    add('--kv-block-size', type=int, default=128)
    add('--uniform-context', action='store_true', help='Choose context points approximately uniformly over x-space')
    args = parser.parse_args()

    device = torch.device(args.device)
    model = build_model_exact(device)
    load_checkpoint_strict(model, args.checkpoint)

    n_rows = args.nfuncs
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows), squeeze=False)

    eval_kwargs = {
        "nc": args.nc,
        "nb": args.nb,
        "nt": args.nt,
        "norm_method": args.norm,
        "uniform_context": args.uniform_context,
        "source": args.source,
        "grid_size": args.grid_size,
        "noise_std": args.noise_std,
        "attending_chunks": args.attending_chunks,
        "include_diagonal": args.include_diagonal,
        "q_block_size": args.q_block_size,
        "kv_block_size": args.kv_block_size,
        "mc_samples": args.mc_samples,
    }
    for number, (ax_true, ax_pred) in enumerate(axes, start=1):
        xc, yc, xt, yt, ypred = eval_one_function(model, device, **eval_kwargs)
        plot_side_by_side(ax_true, ax_pred, xc, yc, xt, yt, ypred, number)

    plt.tight_layout()
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(args.out, dpi=150, bbox_inches='tight')
    print(f"Saved evaluation grid to {args.out}")

if __name__ == '__main__':
    main()
