#!/usr/bin/env python3
"""Evaluate a trained TabularACE model on 1D GP functions using the SDPA inference engine.

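This mirrors scripts/eval_tabular_gp.py in data generation, per-function normalization,
and plotting, but uses the SDPA-based TabularInferenceSDPA engine instead of the
training-style forward with a block mask. By default it runs Mode A (all targets at
once, `--decode-mode all`); `--decode-mode ar_kbatch` switches to autoregressive
K-batch decoding.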

Outputs a figure at --out (default: outputs/tabular_eval_gp_engine.png).
"""

import argparse
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.models.ace import AmortizedConditioningEngine
from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.utils import DataAttr
from src.data.preprocess import TabICLScaler


def _augment_state_dict_keys(sd: dict) -> dict:
    """Mirror embedder keys under both prefixes (embedder. and tabular_embedder.)."""
    out = dict(sd)
    # embedder -> tabular_embedder
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    # tabular_embedder -> embedder
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> AmortizedConditioningEngine:
    """Instantiate the fixed TabularACE architecture and load a checkpoint into it.

    The hyperparameters are hard-coded to the architecture used in training
    (mirroring scripts/eval_tabular_grid_exact.py) rather than inferred from the
    checkpoint, so no adaptive resizing can drift the weights. The checkpoint may
    be a raw state dict or a dict wrapping one under ``"model_state_dict"``.
    Embedder weights are mirrored under both ``embedder.`` and
    ``tabular_embedder.`` prefixes, then loaded with ``strict=True`` (missing or
    unexpected keys raise). The returned model is on ``device`` in eval mode.
    """
    arch = dict(
        num_features=10,
        embed_dim=128,
        transformer_layers=12,
        nhead=4,
        dim_feedforward=256,  # embedder/row-encoder FF width; the backbone uses ff_factor instead
        num_components=20,
        max_buffer_size=32,
        num_target_points=512,
        targets_block_size_for_buffer_attend=32,
        dropout=0.0,
        num_isab_blocks=3,
        num_inducing_points=128,
        row_rope_base=100000,
        col_nhead=4,
        row_nhead=4,
        row_num_blocks=3,
        concat_cls=True,
        num_cls_tokens=2,
        ff_factor=2.0,
    )
    net = TabularACE(**arch)
    net = net.to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    raw_state = checkpoint.get("model_state_dict", checkpoint)
    net.load_state_dict(_augment_state_dict_keys(raw_state), strict=True)
    net.eval()
    return net


def plot_gp_batch(xc, yc, xt, yt, ypred, out_path: Path) -> None:
    """Draw one subplot per function and save the figure to ``out_path``.

    Each subplot overlays the true target curve, the predicted curve and the
    context points. Arrays are indexed as ``[function, point, 0]``, so only the
    first column is plotted. Targets are sorted by x so the curves draw
    left-to-right. Parent directories of ``out_path`` are created if needed, and
    the figure is closed after saving.
    """
    n_funcs = xc.shape[0]
    fig, ax_grid = plt.subplots(n_funcs, 1, figsize=(8, 3.0 * n_funcs), squeeze=False)
    for idx, ax in enumerate(ax_grid[:, 0]):
        # Sort targets by x so line plots do not zig-zag.
        by_x = np.argsort(xt[idx, :, 0])
        x_line = xt[idx, by_x, 0]
        true_line = yt[idx, by_x, 0]
        pred_line = ypred[idx, by_x, 0]

        ax.plot(x_line, true_line, color='black', linewidth=1.5, alpha=0.7, label='True f(x)')
        ax.plot(x_line, pred_line, color='tab:blue', linewidth=1.5, alpha=0.9, label='Pred mean')
        ax.scatter(xc[idx, :, 0], yc[idx, :, 0], color='tab:red', s=14, zorder=5, label='Context')

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.legend(loc='best', fontsize=9)
        ax.set_title(f'Function {idx + 1}')
        ax.grid(True, alpha=0.15)

    fig.tight_layout()
    save_dir = out_path.parent
    save_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the 1D GP engine evaluation."""
    ap = argparse.ArgumentParser(description="Eval trained TabularACE on 1D GP (SDPA inference)")
    ap.add_argument('--checkpoint', required=True, help='Path to trained checkpoint .pt')
    ap.add_argument('--device', default='cpu', help='cpu or cuda')
    ap.add_argument('--batch-size', type=int, default=6, help='Number of GP functions to sample')
    ap.add_argument('--nc', type=int, default=64, help='Number of context points')
    ap.add_argument('--nb', type=int, default=0, help='Number of buffer points (ignored by engine)')
    ap.add_argument('--nt', type=int, default=256, help='Number of target points')
    ap.add_argument('--noise', type=float, default=0.0, help='Observation noise std for GP sampler (0 for smooth)')
    ap.add_argument('--mc-samples', type=int, default=64, help='Number of MC samples per target for prediction mean (Mode A)')
    ap.add_argument('--decode-mode', default='all', choices=['all', 'ar_kbatch'], help='Inference mode: all targets at once or AR K-batch')
    ap.add_argument('--k', type=int, default=32, help='K-batch size for AR mode')
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'], help='SDPA backend')
    ap.add_argument('--out', default='outputs/tabular_eval_gp_engine.png', help='Output figure path')
    return ap.parse_args()


def _normalize_per_function(batch: DataAttr, device: torch.device) -> DataAttr:
    """Normalize each function separately: fit on its context, transform all splits.

    A fresh ``TabICLScaler`` is fit on each function's context (x, y) and then
    applied to that function's context, buffer and target splits. The result is
    a new ``DataAttr`` with float32 tensors on ``device``.

    A missing buffer (``None``) or a zero-length one (e.g. ``--nb 0``) is never
    passed to the scaler. It becomes empty ``(0, dim)`` arrays, so the stacked
    buffer tensors still have a batch and feature axis.
    """
    def to_numpy(t: torch.Tensor) -> np.ndarray:
        return t.detach().cpu().numpy()

    def stack(arrays) -> torch.Tensor:
        # Cast each per-function array to float32 before stacking so a scaler
        # that returns mixed dtypes cannot break torch.stack.
        tensors = [torch.from_numpy(np.asarray(a, dtype=np.float32)) for a in arrays]
        return torch.stack(tensors, dim=0).to(device)

    x_dim = batch.xc.size(-1)
    y_dim = batch.yc.size(-1)
    has_buffer = batch.xb is not None and batch.xb.shape[1] > 0

    columns = {key: [] for key in ("xc", "yc", "xb", "yb", "xt", "yt")}
    for i in range(batch.xc.shape[0]):
        xc_i, yc_i = to_numpy(batch.xc[i]), to_numpy(batch.yc[i])
        scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
        scaler.fit(xc_i, yc_i)  # statistics come from the context split only

        xc_n, yc_n = scaler.transform_xy(xc_i, yc_i)
        if has_buffer:
            xb_n, yb_n = scaler.transform_xy(to_numpy(batch.xb[i]), to_numpy(batch.yb[i]))
        else:
            xb_n, yb_n = np.empty((0, x_dim)), np.empty((0, y_dim))
        xt_n, yt_n = scaler.transform_xy(to_numpy(batch.xt[i]), to_numpy(batch.yt[i]))

        for key, arr in (("xc", xc_n), ("yc", yc_n), ("xb", xb_n),
                         ("yb", yb_n), ("xt", xt_n), ("yt", yt_n)):
            columns[key].append(arr)

    return DataAttr(**{key: stack(arrays) for key, arrays in columns.items()})


def _predict_targets(model: AmortizedConditioningEngine, batch: DataAttr,
                     args: argparse.Namespace) -> np.ndarray:
    """Run SDPA-engine inference on ``batch`` and return predictions as NumPy.

    ``--decode-mode all`` predicts every target at once (Mode A) using
    ``--mc-samples`` samples per target. ``ar_kbatch`` decodes autoregressively
    in blocks of ``--k`` targets.
    """
    engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)
    with torch.no_grad():
        engine.precompute_rows(batch)
        engine.build_context_kv()
        if args.decode_mode == 'all':
            preds = engine.predict_all_targets(return_params=False, num_samples=args.mc_samples)
        else:
            preds = engine.decode_ar_kbatch(batch, K=args.k)
        # Predictions are read from `.yc` of the returned object. The expected
        # shape is [B, T, 1]; TODO confirm this also holds for ar_kbatch.
        return preds.yc.cpu().numpy()


def main():
    """Sample 1D GP functions, normalize them per function, predict, and plot.

    The figure shows the normalized space produced by the per-function scalers
    and is written to ``--out``.
    """
    args = _parse_args()

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    # Generate the 1D GP batch (sampler imported lazily, only when needed).
    from src.data.gp_sampler import GPSampler
    # noise_range=(noise, noise) pins observation noise to one value
    # (presumably the sampler draws noise std from this range).
    sampler = GPSampler(device=str(device), dtype=torch.float32, noise_range=(args.noise, args.noise))
    raw_batch: DataAttr = sampler.generate_batch(
        batch_size=args.batch_size,
        num_context=args.nc,
        num_buffer=args.nb,
        num_target=args.nt,
    )

    batch = _normalize_per_function(raw_batch, device)
    ypred = _predict_targets(model, batch, args)

    # Plot in normalized space (like the gp eval script).
    plot_gp_batch(
        batch.xc.cpu().numpy(),
        batch.yc.cpu().numpy(),
        batch.xt.cpu().numpy(),
        batch.yt.cpu().numpy(),
        ypred,
        Path(args.out),
    )
    print(f"Saved 1D GP engine evaluation to {args.out}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
