#!/usr/bin/env python3
"""Evaluate teacher-forcing LL (Mode B: re-encode as context) with permutation MC.

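For each function, generate n_mc random permutations over its K targets (K=32
by default, see --nt), stack them along the batch dimension, and evaluate the
exact LL per step while inserting true targets as context rows (no AR tokens).
Report the mean±std of the per-function LL sums across functions, plus the
average per-function Monte Carlo std.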
"""

import argparse
import numpy as np
import torch

from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


def _augment_state_dict_keys(sd: dict) -> dict:
    out = dict(sd)
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> TabularACE:
    """Construct a ``TabularACE`` and load its weights from ``ckpt_path``.

    The architecture hyperparameters are fixed in this function, so the
    checkpoint must have been produced with the same configuration
    (loading is strict).

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It may be
            a dict holding the weights under ``"model_state_dict"`` or a bare
            state dict.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in eval mode with the checkpoint weights loaded.
    """
    hparams = dict(
        num_features=10,
        embed_dim=128,
        transformer_layers=12,
        nhead=4,
        dim_feedforward=256,
        num_components=20,
        max_buffer_size=32,
        num_target_points=512,
        targets_block_size_for_buffer_attend=32,
        dropout=0.0,
        num_isab_blocks=3,
        num_inducing_points=128,
        row_rope_base=100000,
        col_nhead=4,
        row_nhead=4,
        row_num_blocks=3,
        concat_cls=True,
        num_cls_tokens=2,
        ff_factor=2.0,
    )
    net = TabularACE(**hparams).to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Fall back to treating the whole checkpoint as the state dict, then
    # mirror embedder keys so either naming scheme satisfies the strict load.
    weights = _augment_state_dict_keys(checkpoint.get("model_state_dict", checkpoint))
    net.load_state_dict(weights, strict=True)

    net.eval()
    return net


def _repeat_batch(b1: DataAttr, n: int) -> DataAttr:
    """Tile every tensor field of ``b1`` ``n`` times along the first dimension.

    Fields that are ``None`` (or, for ``xb``/``yb``, missing from ``b1``) are
    passed through as ``None``.
    """
    def tile(tensor):
        # repeat(n, 1, 1) replicates dim 0 and leaves the other two dims alone.
        return None if tensor is None else tensor.repeat(n, 1, 1)

    return DataAttr(
        xc=tile(b1.xc),
        yc=tile(b1.yc),
        xb=tile(getattr(b1, 'xb', None)),
        yb=tile(getattr(b1, 'yb', None)),
        xt=tile(b1.xt),
        yt=tile(b1.yt),
    )


def main():
    """Run the Mode B (teacher-forcing, re-encode as context) LL evaluation.

    Steps:
      1. Parse CLI arguments and load the model from ``--checkpoint``.
      2. Optionally seed the Python, NumPy and torch RNGs (``--seed``).
      3. Load a saved batch (``--batch-path``) or sample one with
         ``GPSampler``, optionally saving it (``--save-batch``).
      4. For each function: fit a ``TabICLScaler`` on its context, normalise
         context and targets, draw ``--n-mc`` random target permutations and
         score them with ``engine.evaluate_ll_reencode_tf``.
      5. Print the mean/std of the per-function LL sums.

    The number of evaluated functions is ``min(--num-fns, batch size)`` so a
    loaded batch that is smaller than ``--num-fns`` is handled gracefully.
    """
    ap = argparse.ArgumentParser(description="Mode B LL (teacher forcing; re-encode as context) with permutation MC")
    ap.add_argument('--checkpoint', required=True)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--num-fns', type=int, default=8)
    ap.add_argument('--nc', type=int, default=64)
    ap.add_argument('--nb', type=int, default=0)
    ap.add_argument('--nt', type=int, default=32)
    ap.add_argument('--noise', type=float, default=0.0)
    ap.add_argument('--n-mc', type=int, default=64)
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'])
    ap.add_argument('--seed', type=int, default=None)
    ap.add_argument('--batch-path', type=str, default=None, help='Load a pre-saved batch (.pt)')
    ap.add_argument('--save-batch', type=str, default=None, help='Path to save generated batch (.pt)')
    args = ap.parse_args()

    # Fail early with a clear message instead of an obscure error later on
    # (torch.stack on an empty permutation list / statistics of an empty array).
    if args.n_mc < 1:
        ap.error("--n-mc must be >= 1")
    if args.num_fns < 1:
        ap.error("--num-fns must be >= 1")

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    # Seed (covers both batch sampling and the permutation draws below)
    if args.seed is not None:
        import random
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Data
    if args.batch_path is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted batches.
        loaded = torch.load(args.batch_path, map_location='cpu')
        full = DataAttr(xc=loaded['xc'], yc=loaded['yc'], xb=loaded['xb'], yb=loaded['yb'], xt=loaded['xt'], yt=loaded['yt'])
    else:
        from src.data.gp_sampler import GPSampler
        sampler = GPSampler(device='cpu', dtype=torch.float32, noise_range=(args.noise, args.noise))
        full: DataAttr = sampler.generate_batch(
            batch_size=args.num_fns,
            num_context=args.nc,
            num_buffer=args.nb,
            num_target=args.nt,
        )
        if args.save_batch is not None:
            from pathlib import Path as _P
            _P(args.save_batch).parent.mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    'xc': full.xc.cpu(), 'yc': full.yc.cpu(),
                    'xb': full.xb.cpu(), 'yb': full.yb.cpu(),
                    'xt': full.xt.cpu(), 'yt': full.yt.cpu(),
                },
                args.save_batch,
            )
            print(f"Saved batch to {args.save_batch}")

    # A loaded batch need not contain exactly --num-fns functions; never index
    # past what is actually there (out-of-range slices would be empty).
    available = full.xc.shape[0]
    if available < 1:
        raise SystemExit("Batch contains no functions to evaluate")
    num_fns = min(args.num_fns, available)
    if num_fns < args.num_fns:
        print(f"Requested {args.num_fns} functions but batch holds {available}; evaluating {num_fns}.")

    # Engine
    engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)
    ll_means, ll_stds = [], []

    for i in range(num_fns):
        # Single function, per-function normalization in 2D
        xc2, yc2 = full.xc[i].cpu().numpy(), full.yc[i].cpu().numpy()
        xt2, yt2 = full.xt[i].cpu().numpy(), full.yt[i].cpu().numpy()
        scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
        # The scaler is fitted on the context only and then applied to targets.
        scaler.fit(xc2, yc2)
        b1n = DataAttr(
            xc=torch.from_numpy(scaler.transform_x(xc2)).unsqueeze(0),
            yc=torch.from_numpy(scaler.transform_y(yc2)).unsqueeze(0),
            # Buffer rows are deliberately empty; any buffer in the batch is ignored.
            xb=torch.empty(1, 0, full.xc.shape[-1]),
            yb=torch.empty(1, 0, full.yc.shape[-1]),
            xt=torch.from_numpy(scaler.transform_x(xt2)).unsqueeze(0),
            yt=torch.from_numpy(scaler.transform_y(yt2)).unsqueeze(0),
        ).to(device)

        # MC via batch expansion and random permutations
        T = b1n.xt.shape[1]
        perms = [torch.randperm(T) for _ in range(args.n_mc)]
        order = torch.stack(perms, dim=0).to(device)  # [n_mc, T]
        bmc = _repeat_batch(b1n, args.n_mc).to(device)

        with torch.no_grad():
            ll_per_step = engine.evaluate_ll_reencode_tf(bmc, order)  # [n_mc, T, 1]
            ll_seq = ll_per_step.sum(dim=1).squeeze(-1)  # [n_mc]
            ll_means.append(ll_seq.mean().item())
            # Sample std; NaN when n_mc == 1 (undefined for a single draw).
            ll_stds.append(ll_seq.std(unbiased=True).item())

    # Aggregate across functions
    ll_means = np.array(ll_means)
    ll_stds = np.array(ll_stds)
    print("Mode B (teacher-forcing, re-encode as context), normalized LL sums:")
    # ddof=1 sample std: reported as nan when only one function was evaluated.
    print(f"  mean = {ll_means.mean():.4f} ± {ll_means.std(ddof=1):.4f} (across functions)")
    print(f"  per-function MC std (mean) = {ll_stds.mean():.4f}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
