#!/usr/bin/env python3
"""Evaluate a trained TabularACE model on 1D GP functions and visualize.

Generates a batch of 1D GP functions via the repo's GPSampler, runs a
training-style forward pass with an attention BlockMask, and plots the model's
predictions against the ground truth. Normalization matches eval_tabular_grid:
a scaler is fit per function on its context points and then applied to all
splits.

Saves the figure to outputs/tabular_eval_gp.png by default (override with --out).
"""

import argparse
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.models.ace import AmortizedConditioningEngine
from src.models.tabular_embedder import TabularACE
from src.utils import DataAttr
from src.data.preprocess import TabICLScaler
from src.models.masks import create_training_block_mask
from torch.nn.attention.flex_attention import create_block_mask, or_masks


def _infer_arch_from_state_dict(sd: dict) -> dict:
    """Infer TabularACE architecture hyperparameters from a state_dict.

    Copied from eval_tabular_grid to adaptively load checkpoints that may have
    drifted names or slightly different shapes. Only key names and the
    ``.shape`` of the stored values are inspected.

    Args:
        sd: Mapping from parameter name to a value exposing ``.shape``.

    Returns:
        A dict with the keys ``embed_dim``, ``num_cls_tokens``, ``concat_cls``,
        ``transformer_layers``, ``dim_model_eff``, ``dim_feedforward``,
        ``num_components``, ``max_buffer_size``, ``num_isab_blocks``,
        ``num_inducing_points`` and ``row_num_blocks``.

    Raises:
        RuntimeError: If the layer-0 attention projection is missing, or if
            neither CLS tokens nor AR tokens are present to infer the embed dim.
    """
    ar_candidates = (
        "tabular_embedder.ar_tokens",
        "ar_token",
        "embedder.ar_tokens",
    )

    def first_present(candidates):
        """Return the first candidate name found in ``sd``, or None."""
        return next((name for name in candidates if name in sd), None)

    def indexed_block_count(prefix, suffix, position, floor):
        """Return max(floor, highest numeric index at ``position`` + 1)."""
        count = floor
        for name in sd:
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            try:
                count = max(count, int(name.split(".")[position]) + 1)
            except (ValueError, IndexError):
                # Key matches the pattern but carries no numeric index; skip it.
                continue
        return count

    q_w_key = first_present(("backbone.layers.0.attn.q_proj.weight",))
    if q_w_key is None:
        raise RuntimeError("Could not locate backbone attention weights to infer dims")
    dim_model_eff = sd[q_w_key].shape[0]

    ff1_w_key = first_present(("backbone.layers.0.ff1.weight",))
    if ff1_w_key is not None:
        dim_ff_eff = sd[ff1_w_key].shape[0]
    else:
        dim_ff_eff = max(256, 2 * dim_model_eff)

    num_layers = indexed_block_count("backbone.layers.", ".attn.q_proj.weight", 2, 0)

    # The AR-token key is needed both as an embed-dim fallback and for the
    # buffer size, so resolve it once.
    ar_key = first_present(ar_candidates)
    cls_key = first_present((
        "tabular_embedder.cls_tokens",
        "embedder.cls_tokens",
    ))
    if cls_key is not None:
        num_cls_tokens, embed_dim = sd[cls_key].shape
    elif ar_key is not None:
        embed_dim = sd[ar_key].shape[1]
        num_cls_tokens = 1
    else:
        raise RuntimeError("Could not locate embedder tokens to infer embed dim")

    concat_cls = (dim_model_eff == embed_dim * num_cls_tokens)
    max_buffer_size = sd[ar_key].shape[0] if ar_key is not None else 32

    w1_key = first_present((
        "head.head.w1",
        "head.w1",
    ))
    if w1_key is not None and hasattr(sd[w1_key], "shape") and len(sd[w1_key].shape) == 3:
        num_components = sd[w1_key].shape[0]
        # A 3-D head weight overrides the feed-forward width read from ff1.
        dim_ff_eff = sd[w1_key].shape[1]
    else:
        num_components = 20

    # ISAB/row depth hints (optional)
    num_isab_blocks = 1
    num_inducing_points = 64
    isab_prefix = "tabular_embedder.column_processor.blocks."
    for name, value in sd.items():
        if not (name.startswith(isab_prefix) and name.endswith(".inducing_points")):
            continue
        try:
            num_isab_blocks = max(num_isab_blocks, int(name.split(".")[3]) + 1)
            num_inducing_points = value.shape[0]
        except (ValueError, IndexError, AttributeError):
            # Best-effort hint: ignore entries with a non-numeric index or no usable shape.
            continue

    row_num_blocks = indexed_block_count(
        "tabular_embedder.row_encoder.layers.",
        ".self_attn.in_proj_weight",
        3,
        1,
    )

    return {
        "embed_dim": embed_dim,
        "num_cls_tokens": num_cls_tokens,
        "concat_cls": concat_cls,
        "transformer_layers": num_layers,
        "dim_model_eff": dim_model_eff,
        "dim_feedforward": dim_ff_eff,
        "num_components": num_components,
        "max_buffer_size": max_buffer_size,
        "num_isab_blocks": num_isab_blocks,
        "num_inducing_points": num_inducing_points,
        "row_num_blocks": row_num_blocks,
    }


def _augment_state_dict_keys(sd: dict) -> dict:
    """Return a copy of ``sd`` whose embedder entries exist under both prefixes.

    Every ``embedder.*`` key gains a ``tabular_embedder.*`` alias and vice
    versa. An alias is only added when that name is not already present, so
    existing entries are never overwritten. The input mapping is not modified.
    """
    short_prefix = "embedder."
    long_prefix = "tabular_embedder."
    extra = "tabular_"

    mirrored = dict(sd)

    # Pass 1: embedder.* -> tabular_embedder.*
    for name, value in sd.items():
        if name.startswith(short_prefix):
            mirrored.setdefault(extra + name, value)

    # Pass 2: tabular_embedder.* -> embedder.*
    for name, value in sd.items():
        if name.startswith(long_prefix):
            mirrored.setdefault(name[len(extra):], value)

    return mirrored


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> AmortizedConditioningEngine:
    """Instantiate TabularACE with fixed hyperparameters and strictly load a checkpoint.

    Mirrors scripts/eval_tabular_grid_exact.py to avoid adaptive resizing that can drift weights.

    Args:
        ckpt_path: Path to the checkpoint file passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model in eval mode with weights loaded (``strict=True``).
    """
    hparams = dict(
        num_features=10,
        embed_dim=128,
        transformer_layers=12,
        nhead=4,
        dim_feedforward=256,  # embedder/row encoder FF; backbone uses ff_factor
        num_components=20,
        max_buffer_size=32,
        num_target_points=512,
        targets_block_size_for_buffer_attend=32,
        dropout=0.0,
        num_isab_blocks=3,
        num_inducing_points=128,
        row_rope_base=100000,
        col_nhead=4,
        row_nhead=4,
        row_num_blocks=3,
        concat_cls=True,
        num_cls_tokens=2,
        ff_factor=2.0,
    )
    model = TabularACE(**hparams).to(device)

    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects; only point this at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Accept either a wrapped checkpoint or a bare state_dict.
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model.load_state_dict(_augment_state_dict_keys(state_dict), strict=True)
    model.eval()
    return model


def mixture_mean(means: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Return the weighted sum of component means over the component axis.

    means: [B, T, K, Dy]; weights: [B, T, K, Dy] or [B, T, K, 1] -> returns [B, T, Dy]
    """
    weighted_components = weights * means
    # dim=2 is the mixture-component axis K.
    return torch.sum(weighted_components, dim=2)


def sample_mixture(means: torch.Tensor, stds: torch.Tensor, weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Monte Carlo samples from a single-output mixture.

    For every (batch, target) position, ``num_samples`` component indices are
    drawn with replacement according to ``weights``; each draw is then
    ``mean + std * eps`` with ``eps`` standard normal.

    means/stds: [B, T, K, 1]; weights: [B, T, K, 1]
    returns samples: [B, T, num_samples, 1]
    """
    B, T, K, D = means.shape
    assert D == 1, "This sampler assumes dim_y=1"

    # reshape (not view) so non-contiguous weights, e.g. a transposed tensor,
    # are flattened instead of raising a stride-compatibility error.
    flat_weights = weights.squeeze(-1).reshape(-1, K)  # [B*T, K]
    comp_idx = torch.multinomial(flat_weights, num_samples, replacement=True)  # [B*T, S]
    comp_idx = comp_idx.view(B, T, num_samples)

    # Gather straight along the component axis: out[b, t, s] = x[b, t, idx[b, t, s]].
    means_sel = torch.gather(means.squeeze(-1), 2, comp_idx)  # [B, T, S]
    stds_sel = torch.gather(stds.squeeze(-1), 2, comp_idx)    # [B, T, S]

    eps = torch.randn_like(means_sel)
    samples = means_sel + stds_sel * eps  # [B, T, S]
    return samples.unsqueeze(-1)


def plot_gp_batch(xc, yc, xt, yt, ypred, out_path: Path) -> None:
    """Draw one panel per function (truth, prediction, context) and save the figure.

    Inputs are indexed as ``arr[i, :, 0]``, i.e. expected to be numpy-compatible
    arrays of shape [B, N, 1]. The parent directory of ``out_path`` is created
    if it does not exist.
    """
    n_funcs = xc.shape[0]
    fig, axes = plt.subplots(n_funcs, 1, figsize=(8, 3.0 * n_funcs), squeeze=False)

    # squeeze=False yields a 2-D grid of axes; only the single column is used.
    for i, ax in enumerate(axes[:, 0]):
        # Order targets by x so the line plots are drawn left to right.
        order = np.argsort(xt[i, :, 0])
        x_sorted = xt[i, order, 0]
        ax.plot(x_sorted, yt[i, order, 0], color='black', linewidth=1.5, alpha=0.7, label='True f(x)')
        ax.plot(x_sorted, ypred[i, order, 0], color='tab:blue', linewidth=1.5, alpha=0.9, label='Pred mean')
        ax.scatter(xc[i, :, 0], yc[i, :, 0], color='tab:red', s=14, zorder=5, label='Context')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.legend(loc='best', fontsize=9)
        ax.set_title(f'Function {i+1}')
        ax.grid(True, alpha=0.15)

    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the GP evaluation script."""
    ap = argparse.ArgumentParser(description="Eval trained TabularACE on 1D GP functions (training-style forward)")
    ap.add_argument('--checkpoint', required=True, help='Path to trained checkpoint .pt')
    ap.add_argument('--device', default='cpu', help='cpu or cuda')
    ap.add_argument('--batch-size', type=int, default=6, help='Number of GP functions to sample')
    ap.add_argument('--nc', type=int, default=64, help='Number of context points')
    ap.add_argument('--nb', type=int, default=0, help='Number of buffer points (inference ignores)')
    ap.add_argument('--nt', type=int, default=256, help='Number of target points')
    # Accepted on the command line but not read anywhere in main() below.
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'], help='SDPA backend')
    ap.add_argument('--noise', type=float, default=0.0, help='Observation noise std for GP sampler (0 for smooth functions)')
    ap.add_argument('--out', default='outputs/tabular_eval_gp.png', help='Output figure path')
    ap.add_argument('--mc-samples', type=int, default=64, help='Number of MC samples per target for prediction mean')
    return ap.parse_args()


def _normalize_per_function(batch: DataAttr, device: torch.device) -> DataAttr:
    """Fit a TabICLScaler on each function's context and transform all of its splits."""
    def as_numpy(t: torch.Tensor) -> np.ndarray:
        return t.cpu().numpy()

    split_names = ("xc", "yc", "xb", "yb", "xt", "yt")
    columns = {name: [] for name in split_names}

    for i in range(batch.xc.shape[0]):
        scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
        # Fit on this function's context only.
        scaler.fit(as_numpy(batch.xc[i]), as_numpy(batch.yc[i]))

        xci, yci = scaler.transform_xy(as_numpy(batch.xc[i]), as_numpy(batch.yc[i]))
        if batch.xb is not None:
            xbi, ybi = scaler.transform_xy(as_numpy(batch.xb[i]), as_numpy(batch.yb[i]))
        else:
            # No buffer split: use zero-row placeholders with matching feature width.
            xbi = np.empty((0, batch.xc.size(-1)))
            ybi = np.empty((0, batch.yc.size(-1)))
        xti, yti = scaler.transform_xy(as_numpy(batch.xt[i]), as_numpy(batch.yt[i]))

        for name, arr in zip(split_names, (xci, yci, xbi, ybi, xti, yti)):
            columns[name].append(torch.from_numpy(arr))

    # Stack back into a batch on the requested device as float32.
    return DataAttr(**{
        name: torch.stack(parts, dim=0).to(device=device, dtype=torch.float32)
        for name, parts in columns.items()
    })


def _build_block_mask(nc: int, nb: int, nt: int, device: torch.device):
    """Build the attention block mask for a sequence of nc + nb + nt rows."""
    total_len = nc + nb + nt
    if nb != 0:
        return create_training_block_mask(
            current_total_q_len=total_len,
            current_total_kv_len=total_len,
            current_context_section_len=nc,
            current_buffer_section_len=nb,
            attending_chunks=4,
            include_diagonal=False,
            device=device,
        )

    # No buffer section: compose the mask by hand, without a diagonal term.
    def prefix_mask(b, h, q_idx, kv_idx):
        return kv_idx < nc

    # Every pair allowed here also satisfies kv_idx < nc, so the OR below
    # admits exactly the pairs prefix_mask admits; kept as in the original composition.
    def localized_causal_ctx(b, h, q_idx, kv_idx):
        return (q_idx < nc) & (kv_idx < nc) & (q_idx >= kv_idx)

    return create_block_mask(
        or_masks(prefix_mask, localized_causal_ctx),
        Q_LEN=total_len,
        KV_LEN=total_len,
        B=None,
        H=None,
        BLOCK_SIZE=(128, 128),
        device=device,
    )


def main():
    """Sample 1D GP functions, run the model on them, and save a comparison plot.

    Steps: parse CLI options, load the checkpoint, draw a GP batch, normalize
    each function on its context, run a forward pass under ``torch.no_grad``
    with a block mask, average Monte Carlo mixture samples into a prediction,
    and write the figure to ``--out``.
    """
    args = _parse_args()

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    # Generate 1D GP batch
    from src.data.gp_sampler import GPSampler
    sampler = GPSampler(device=str(device), dtype=torch.float32, noise_range=(args.noise, args.noise))
    raw_batch: DataAttr = sampler.generate_batch(
        batch_size=args.batch_size,
        num_context=args.nc,
        num_buffer=args.nb,
        num_target=args.nt,
    )

    # Normalize like eval_tabular_grid: per-function fit on context, transform all splits
    batch = _normalize_per_function(raw_batch, device)

    # Training-style forward with block mask (same as eval_tabular_grid)
    mask = _build_block_mask(args.nc, args.nb, args.nt, device)
    with torch.no_grad():
        out = model(batch, mask)
        # MC samples, then mean across samples; at least one sample is always drawn.
        num_mc = max(1, int(args.mc_samples))
        y_samp = sample_mixture(out.means, out.sds, out.weights, num_samples=num_mc)  # [B, T, S, 1]
        ypred = y_samp.mean(dim=2).cpu().numpy()  # [B, T, 1]

    # Plot (convert tensors to numpy)
    plot_gp_batch(
        batch.xc.cpu().numpy(),
        batch.yc.cpu().numpy(),
        batch.xt.cpu().numpy(),
        batch.yt.cpu().numpy(),
        ypred,
        Path(args.out),
    )
    print(f"Saved 1D GP evaluation to {args.out}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
