#!/usr/bin/env python3
"""Eval TabularACE on 1D GP with SDPA Mode B (re-encode) + MC via batch expansion.

Per request: support one function at a time, expand batch to n_mc replicas and
run decode in parallel; repeat for N functions, then plot normalized results.
"""

import argparse
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


def _augment_state_dict_keys(sd: dict) -> dict:
    out = dict(sd)
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> TabularACE:
    """Instantiate a TabularACE with fixed hyperparameters and load weights.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. The state
            dict is taken from its ``"model_state_dict"`` entry when present,
            otherwise the loaded object itself is used as the state dict.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in eval mode with the checkpoint weights loaded (strict).
    """
    # Architecture is hard-coded here; it has to match the checkpoint because
    # the weights are loaded with strict=True below.
    hparams = {
        "num_features": 10,
        "embed_dim": 128,
        "transformer_layers": 12,
        "nhead": 4,
        "dim_feedforward": 256,
        "num_components": 20,
        "max_buffer_size": 32,
        "num_target_points": 512,
        "targets_block_size_for_buffer_attend": 32,
        "dropout": 0.0,
        "num_isab_blocks": 3,
        "num_inducing_points": 128,
        "row_rope_base": 100000,
        "col_nhead": 4,
        "row_nhead": 4,
        "row_num_blocks": 3,
        "concat_cls": True,
        "num_cls_tokens": 2,
        "ff_factor": 2.0,
    }
    net = TabularACE(**hparams).to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Mirror "embedder."/"tabular_embedder." key names before loading.
    weights = _augment_state_dict_keys(checkpoint.get("model_state_dict", checkpoint))
    net.load_state_dict(weights, strict=True)

    net.eval()
    return net


def plot_batch(xc, yc, xt, yt, ypred, out_path: Path) -> None:
    """Draw one panel per function and save the figure to *out_path*.

    All arrays are indexed as ``[function, point, 0]``. Each panel shows the
    true targets and the predicted mean as lines (sorted by target x) plus the
    context points as a scatter. Parent directories of *out_path* are created
    if missing.
    """
    n_funcs = xc.shape[0]
    fig, grid = plt.subplots(n_funcs, 1, figsize=(8, 3.0 * n_funcs), squeeze=False)

    for idx, ax in enumerate(grid[:, 0]):
        # Sort targets along x so the line plots are drawn left to right.
        perm = np.argsort(xt[idx, :, 0])
        x_sorted = xt[idx, perm, 0]

        ax.plot(
            x_sorted, yt[idx, perm, 0],
            color='black', linewidth=1.5, alpha=0.7, label='True f(x)',
        )
        ax.plot(
            x_sorted, ypred[idx, perm, 0],
            color='tab:blue', linewidth=1.5, alpha=0.9, label='Pred mean (MC)',
        )
        ax.scatter(
            xc[idx, :, 0], yc[idx, :, 0],
            color='tab:red', s=14, zorder=5, label='Context',
        )

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.legend(loc='best', fontsize=9)
        ax.set_title(f'Function {idx+1}')
        ax.grid(True, alpha=0.15)

    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _repeat_batch(b1: DataAttr, n: int) -> DataAttr:
    """Tile every tensor of *b1* ``n`` times along the first dimension.

    Fields that are ``None`` stay ``None``; the buffer fields ``xb``/``yb``
    are also treated as ``None`` when the attribute is missing. Tensors are
    repeated with ``repeat(n, 1, 1)``, i.e. they are expected to be 3-D.
    """
    sources = {
        'xc': b1.xc,
        'yc': b1.yc,
        'xb': getattr(b1, 'xb', None),
        'yb': getattr(b1, 'yb', None),
        'xt': b1.xt,
        'yt': b1.yt,
    }
    tiled = {
        field: (tensor.repeat(n, 1, 1) if tensor is not None else None)
        for field, tensor in sources.items()
    }
    return DataAttr(**tiled)


def _slice_function(full: "DataAttr", i: int) -> "DataAttr":
    """Extract function *i* of *full* as a CPU batch of size one.

    The buffer fields ``xb``/``yb`` may be ``None`` (or absent); they are then
    passed through as ``None`` instead of being sliced.
    """
    def one(t):
        return t[i:i + 1].cpu()

    def one_or_none(t):
        return None if t is None else one(t)

    return DataAttr(
        xc=one(full.xc), yc=one(full.yc),
        xb=one_or_none(getattr(full, 'xb', None)),
        yb=one_or_none(getattr(full, 'yb', None)),
        xt=one(full.xt), yt=one(full.yt),
    )


def _normalize_function(raw: "DataAttr") -> "DataAttr":
    """Fit a TabICLScaler on the context of *raw* and transform every field.

    *raw* must be a CPU batch of size one. A missing or empty buffer becomes a
    pair of empty ``[1, 0, D]`` tensors.
    """
    scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
    # The scaler works on 2-D arrays, so drop the batch dimension for fitting.
    scaler.fit(raw.xc.squeeze(0).numpy(), raw.yc.squeeze(0).numpy())

    def scaled(t, transform):
        arr = transform(t.squeeze(0).numpy())
        # The scaler's output dtype is not visible here; cast so the inputs
        # match the float32 data produced by the sampler (no-op if already so).
        return torch.from_numpy(arr).unsqueeze(0).to(torch.float32)

    xci = scaled(raw.xc, scaler.transform_x)
    yci = scaled(raw.yc, scaler.transform_y)
    xti = scaled(raw.xt, scaler.transform_x)
    yti = scaled(raw.yt, scaler.transform_y)

    if raw.xb is not None and raw.xb.shape[1] > 0:
        xbi = scaled(raw.xb, scaler.transform_x)
        ybi = scaled(raw.yb, scaler.transform_y)
    else:
        xbi = torch.empty(1, 0, raw.xc.shape[-1])
        ybi = torch.empty(1, 0, raw.yc.shape[-1])

    return DataAttr(xc=xci, yc=yci, xb=xbi, yb=ybi, xt=xti, yt=yti)


def main():
    """Command-line entry point: evaluate a checkpoint on sampled 1D GP functions.

    For each sampled function the data is normalized per function, replicated
    ``--n-mc`` times along the batch dimension, decoded with the re-encode
    path of ``TabularInferenceSDPA`` and averaged over the replicas. The
    normalized context, targets and mean prediction are plotted to ``--out``.
    """
    ap = argparse.ArgumentParser(description="Eval TabularACE on 1D GP (Mode B re-encode, MC via batch expand)")
    ap.add_argument('--checkpoint', required=True)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--num-fns', type=int, default=4)
    ap.add_argument('--nc', type=int, default=64)
    ap.add_argument('--nb', type=int, default=0)
    ap.add_argument('--nt', type=int, default=256)
    ap.add_argument('--noise', type=float, default=0.0)
    ap.add_argument('--n-mc', type=int, default=64)
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'])
    ap.add_argument('--out', default='outputs/tabular_eval_gp_engine_modeb_mc.png')
    args = ap.parse_args()

    # Zero functions would leave nothing to concatenate/plot, and zero MC
    # replicas would make the mean over the batch dimension undefined.
    if args.num_fns < 1:
        ap.error("--num-fns must be >= 1")
    if args.n_mc < 1:
        ap.error("--n-mc must be >= 1")

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    # Generate GP batch of size num-fns
    from src.data.gp_sampler import GPSampler
    sampler = GPSampler(device=str(device), dtype=torch.float32, noise_range=(args.noise, args.noise))
    full: DataAttr = sampler.generate_batch(
        batch_size=args.num_fns,
        num_context=args.nc,
        num_buffer=args.nb,
        num_target=args.nt,
    )

    # Prepare storage for outputs per function
    xc_all, yc_all, xt_all, yt_all, ypred_all = [], [], [], [], []

    for i in range(args.num_fns):
        # Slice one function and normalize it on its own statistics
        b1 = _normalize_function(_slice_function(full, i))

        # Expand batch by n_mc for parallel MC
        bmc = _repeat_batch(b1, args.n_mc).to(device)

        # New engine instance for clean state
        engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)
        with torch.no_grad():
            engine.precompute_rows(bmc)
            engine.build_context_kv()
            preds = engine.decode_reencode(bmc)
            # Average across MC replicates (batch dimension).
            # NOTE(review): predictions are read from `preds.yc`, presumably
            # shaped [n_mc, T, 1] -- confirm against decode_reencode.
            y_mean = preds.yc.mean(dim=0, keepdim=True)

        # Collect for plotting (normalized)
        xc_all.append(b1.xc.numpy())
        yc_all.append(b1.yc.numpy())
        xt_all.append(b1.xt.numpy())
        yt_all.append(b1.yt.numpy())
        ypred_all.append(y_mean.cpu().numpy())

    xc = np.concatenate(xc_all, axis=0)
    yc = np.concatenate(yc_all, axis=0)
    xt = np.concatenate(xt_all, axis=0)
    yt = np.concatenate(yt_all, axis=0)
    ypred = np.concatenate(ypred_all, axis=0)

    plot_batch(xc, yc, xt, yt, ypred, Path(args.out))
    print(f"Saved Mode B (re-encode) MC evaluation to {args.out}")


if __name__ == '__main__':
    main()
