#!/usr/bin/env python3
"""Unified LL evaluation for TabularACE with three modes:

- evaluate_ll_independent: context-only (independent targets), exact mixture LL
- evaluate_ll_ar: teacher forcing with re-encode-as-context (true -> context)
- evaluate_ll_ar_buffer: teacher forcing with buffer AR + AR tokens (K-chunk style)

Per-function normalization is applied (fit on context). Results are reported in
normalized space as the sum of per-target LLs (K targets) per function, and the
mean ± std across functions. Optionally reuse identical batches across runs.
"""

import argparse
from pathlib import Path
import numpy as np
import random
import torch

from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


def _augment_state_dict_keys(sd: dict) -> dict:
    out = dict(sd)
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> TabularACE:
    """Build a TabularACE with the fixed evaluation architecture and load weights.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. The loaded
            object must be a mapping: either one that stores the weights under
            ``"model_state_dict"`` or the state dict itself.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model in eval mode with the checkpoint weights loaded.
    """
    # The architecture is hard-coded; it has to match the checkpoint because
    # the weights are loaded with strict=True below.
    hparams = {
        "num_features": 10,
        "embed_dim": 128,
        "transformer_layers": 12,
        "nhead": 4,
        "dim_feedforward": 256,
        "num_components": 20,
        "max_buffer_size": 32,
        "num_target_points": 512,
        "targets_block_size_for_buffer_attend": 32,
        "dropout": 0.0,
        "num_isab_blocks": 3,
        "num_inducing_points": 128,
        "row_rope_base": 100000,
        "col_nhead": 4,
        "row_nhead": 4,
        "row_num_blocks": 3,
        "concat_cls": True,
        "num_cls_tokens": 2,
        "ff_factor": 2.0,
    }
    model = TabularACE(**hparams).to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Fall back to treating the whole checkpoint as the state dict.
    state = checkpoint.get("model_state_dict", checkpoint)
    # Mirror embedder.* / tabular_embedder.* keys so either naming is present.
    # NOTE(review): strict=True rejects unexpected keys, so this presumably
    # relies on the model exposing both prefixes - confirm against TabularACE.
    state = _augment_state_dict_keys(state)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def _repeat_batch(b1: DataAttr, n: int) -> DataAttr:
    """Tile each field of ``b1`` ``n`` times along its first dimension.

    Fields are repeated with ``repeat(n, 1, 1)``; the two trailing dimensions
    keep their size. ``xb`` / ``yb`` are looked up with ``getattr`` and may be
    missing or ``None``, in which case the result carries ``None`` for them.
    """
    def tile(tensor):
        # Absent fields are passed through unchanged.
        return None if tensor is None else tensor.repeat(n, 1, 1)

    return DataAttr(
        xc=tile(b1.xc),
        yc=tile(b1.yc),
        xb=tile(getattr(b1, 'xb', None)),
        yb=tile(getattr(b1, 'yb', None)),
        xt=tile(b1.xt),
        yt=tile(b1.yt),
    )


def _load_or_generate_batch(args: argparse.Namespace) -> DataAttr:
    """Return the evaluation batch: read from ``--batch-path`` or freshly sampled."""
    if args.batch_path:
        loaded = torch.load(args.batch_path, map_location='cpu')
        return DataAttr(
            xc=loaded['xc'], yc=loaded['yc'],
            xb=loaded['xb'], yb=loaded['yb'],
            xt=loaded['xt'], yt=loaded['yt'],
        )

    from src.data.gp_sampler import GPSampler
    sampler = GPSampler(device='cpu', dtype=torch.float32, noise_range=(args.noise, args.noise))
    full: DataAttr = sampler.generate_batch(
        batch_size=args.num_fns,
        num_context=args.nc,
        num_buffer=args.nb,
        num_target=args.nt,
    )
    if args.save_batch:
        Path(args.save_batch).parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                'xc': full.xc.cpu(), 'yc': full.yc.cpu(),
                'xb': full.xb.cpu(), 'yb': full.yb.cpu(),
                'xt': full.xt.cpu(), 'yt': full.yt.cpu(),
            },
            args.save_batch,
        )
        print(f"Saved batch to {args.save_batch}")
    return full


def _normalize_function(full: DataAttr, i: int, device: torch.device) -> DataAttr:
    """Normalize function ``i`` of ``full`` with a scaler fit on its own context."""
    xc2, yc2 = full.xc[i].cpu().numpy(), full.yc[i].cpu().numpy()
    xt2, yt2 = full.xt[i].cpu().numpy(), full.yt[i].cpu().numpy()
    scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
    scaler.fit(xc2, yc2)
    return DataAttr(
        xc=torch.from_numpy(scaler.transform_x(xc2)).unsqueeze(0),
        yc=torch.from_numpy(scaler.transform_y(yc2)).unsqueeze(0),
        # The buffer is always empty here: any buffer points in ``full`` are
        # not forwarded to the engine.
        xb=torch.empty(1, 0, xc2.shape[-1]),
        yb=torch.empty(1, 0, yc2.shape[-1]),
        xt=torch.from_numpy(scaler.transform_x(xt2)).unsqueeze(0),
        yt=torch.from_numpy(scaler.transform_y(yt2)).unsqueeze(0),
    ).to(device)


def _teacher_forced_ll(engine, b1n: DataAttr, mode: str, n_mc: int, device: torch.device):
    """Return (mean, unbiased std) of the sequence LL over ``n_mc`` random target orders."""
    num_targets = b1n.xt.shape[1]
    order = torch.stack([torch.randperm(num_targets) for _ in range(n_mc)], dim=0).to(device)  # [n_mc, T]
    bmc = _repeat_batch(b1n, n_mc).to(device)
    with torch.no_grad():
        if mode == 'evaluate_ll_ar':
            ll_per_step = engine.evaluate_ll_reencode_tf(bmc, order)
        else:
            ll_per_step = engine.evaluate_ll_buffer_tf_kchunk(bmc, order, ar_indexing='chunk')
        ll_seq = ll_per_step.sum(dim=1).squeeze(-1)  # one summed LL per permutation
    return ll_seq.mean().item(), ll_seq.std(unbiased=True).item()


def main():
    """Command-line entry point: evaluate normalized log-likelihoods for one mode.

    Parses the CLI arguments, builds the model and inference engine, loads or
    samples a batch of functions, normalizes each function on its own context,
    and prints the mean and standard deviation (across functions) of the
    summed per-target log-likelihood. The AR modes additionally average over
    ``--n-mc`` random target orderings and report the mean per-function Monte
    Carlo standard deviation.

    Raises:
        ValueError: If the batch holds no functions, or if ``--K`` or the
            number of targets exceeds the buffer limit in
            ``evaluate_ll_ar_buffer`` mode.
    """
    ap = argparse.ArgumentParser(description="Unified LL eval: independent | ar | ar_buffer")
    ap.add_argument('--mode', required=True, choices=['evaluate_ll_independent', 'evaluate_ll_ar', 'evaluate_ll_ar_buffer'])
    ap.add_argument('--checkpoint', required=True)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--num-fns', type=int, default=8)
    ap.add_argument('--nc', type=int, default=64)
    ap.add_argument('--nb', type=int, default=0)
    ap.add_argument('--nt', type=int, default=32)
    ap.add_argument('--noise', type=float, default=0.0)
    ap.add_argument('--n-mc', type=int, default=64, help='Only used for AR modes')
    ap.add_argument('--K', type=int, default=32, help='Only used for evaluate_ll_ar_buffer')
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'])
    ap.add_argument('--seed', type=int, default=None)
    ap.add_argument('--batch-path', type=str, default=None, help='Load a pre-saved batch (.pt)')
    ap.add_argument('--save-batch', type=str, default=None, help='Save generated batch (.pt)')
    args = ap.parse_args()

    ar_modes = ('evaluate_ll_ar', 'evaluate_ll_ar_buffer')
    # Reject degenerate counts up front rather than failing deep in the run.
    if args.num_fns < 1:
        ap.error('--num-fns must be >= 1')
    if args.mode in ar_modes and args.n_mc < 1:
        ap.error('--n-mc must be >= 1 for AR modes')

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)
    engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)

    # Seed deterministically if requested
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    full = _load_or_generate_batch(args)

    # A loaded batch may hold a different number of functions / targets than
    # the CLI values, so take the real sizes from the tensors.
    available = full.xc.shape[0]
    if available < 1:
        raise ValueError("Batch contains no functions to evaluate")
    num_fns = min(args.num_fns, available)
    if num_fns < args.num_fns:
        print(f"Warning: batch holds only {available} functions; evaluating {num_fns} instead of {args.num_fns}")
    num_targets = full.xt.shape[1]

    # Mirrors max_buffer_size in build_model_from_ckpt.
    max_buffer = 32
    # NOTE: --K is only range-checked; it is not forwarded to the engine call.
    if args.mode == 'evaluate_ll_ar_buffer' and (args.K > max_buffer or num_targets > max_buffer):
        raise ValueError("K and T must be <= max_buffer_size (32)")

    ll_sums = []
    per_func_mc_std = []  # for AR modes

    for i in range(num_fns):
        # Per-function normalization (2D fit/transform)
        b1n = _normalize_function(full, i, device)

        if args.mode == 'evaluate_ll_independent':
            with torch.no_grad():
                ll_per_token = engine.predict_ll_independent(b1n)  # [1,T,1]
                ll_sums.append(ll_per_token.sum().item())
        elif args.mode in ar_modes:
            ll_mean, ll_std = _teacher_forced_ll(engine, b1n, args.mode, args.n_mc, device)
            ll_sums.append(ll_mean)
            per_func_mc_std.append(ll_std)
        else:
            raise ValueError(f"Unknown mode: {args.mode}")

    ll_sums = np.array(ll_sums)
    # The sample std is undefined for a single function; report nan without
    # triggering numpy's degrees-of-freedom warning.
    across_fn_std = ll_sums.std(ddof=1) if ll_sums.size > 1 else float('nan')
    print(f"{args.mode}: normalized LL sum over {num_targets} targets:")
    print(f"  mean = {ll_sums.mean():.4f}, std = {across_fn_std:.4f}, N = {len(ll_sums)}")
    if per_func_mc_std:
        mc_std = np.array(per_func_mc_std)
        print(f"  per-function MC std (mean) = {mc_std.mean():.4f}")

if __name__ == '__main__':
    main()
