#!/usr/bin/env python3
"""Evaluate exact log-likelihood (Mode A) for 1D GP functions in normalized space.

Per-function normalization (fit on context), context-only conditioning (independent
targets), exact mixture LL via the head, and reporting mean±std across functions
for the sum over the K target LLs (K is set by --nt and defaults to 32).
"""

import argparse
from pathlib import Path
import numpy as np
import torch
import random

from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


def _augment_state_dict_keys(sd: dict) -> dict:
    out = dict(sd)
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> TabularACE:
    """Instantiate a ``TabularACE`` model and restore its weights from disk.

    The architecture hyperparameters are fixed in this function; because the
    state dict is loaded with ``strict=True``, the checkpoint's keys must match
    the constructed model exactly.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. Either a
            dict holding the weights under ``"model_state_dict"`` or the
            state dict itself.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model with loaded weights, switched to eval mode.
    """
    architecture = dict(
        num_features=10,
        embed_dim=128,
        transformer_layers=12,
        nhead=4,
        dim_feedforward=256,
        num_components=20,
        max_buffer_size=32,
        num_target_points=512,
        targets_block_size_for_buffer_attend=32,
        dropout=0.0,
        num_isab_blocks=3,
        num_inducing_points=128,
        row_rope_base=100000,
        col_nhead=4,
        row_nhead=4,
        row_num_blocks=3,
        concat_cls=True,
        num_cls_tokens=2,
        ff_factor=2.0,
    )
    model = TabularACE(**architecture).to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Fall back to treating the whole checkpoint as the state dict when it is
    # not wrapped under "model_state_dict".
    raw_state = checkpoint.get("model_state_dict", checkpoint)
    # Alias embedder keys under both naming schemes before the strict load.
    state = _augment_state_dict_keys(raw_state)
    model.load_state_dict(state, strict=True)

    model.eval()
    return model


def main():
    """Command-line entry point: report summed target log-likelihoods per function.

    Steps:
      1. Parse arguments and restore the model from ``--checkpoint``.
      2. Optionally seed the Python, NumPy and torch RNGs.
      3. Load a saved batch (``--batch-path``) or sample a new one with
         ``GPSampler`` (optionally saving it via ``--save-batch``).
      4. For each function: fit a ``TabICLScaler`` on its context points,
         transform context and targets, run ``predict_ll_independent`` with an
         empty buffer, and sum the returned log-likelihoods.
      5. Print mean and sample standard deviation of the per-function sums.

    The number of evaluated functions is ``--num-fns`` clamped to the number of
    functions actually present in the batch, so a loaded batch that is smaller
    than ``--num-fns`` no longer produces out-of-range (empty) slices.

    Raises:
        SystemExit: If there is no function to evaluate.
    """
    ap = argparse.ArgumentParser(description="Exact LL (Mode A) on 1D GP (normalized)")
    ap.add_argument('--checkpoint', required=True)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--num-fns', type=int, default=16)
    ap.add_argument('--nc', type=int, default=64)
    ap.add_argument('--nb', type=int, default=0)
    ap.add_argument('--nt', type=int, default=32)  # K=32 targets
    ap.add_argument('--noise', type=float, default=0.0)
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'])
    ap.add_argument('--seed', type=int, default=None)
    ap.add_argument('--batch-path', type=str, default=None, help='Load a pre-saved batch (.pt) to ensure same functions across scripts')
    ap.add_argument('--save-batch', type=str, default=None, help='Path to save the generated batch (.pt) for reuse')
    args = ap.parse_args()

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)

    # Seed (optional)
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Generate or load GP batch
    if args.batch_path is not None:
        loaded = torch.load(args.batch_path, map_location='cpu')
        full = DataAttr(
            xc=loaded['xc'], yc=loaded['yc'], xb=loaded['xb'], yb=loaded['yb'], xt=loaded['xt'], yt=loaded['yt']
        )
    else:
        from src.data.gp_sampler import GPSampler
        sampler = GPSampler(device='cpu', dtype=torch.float32, noise_range=(args.noise, args.noise))
        full: DataAttr = sampler.generate_batch(
            batch_size=args.num_fns,
            num_context=args.nc,
            num_buffer=args.nb,
            num_target=args.nt,
        )
        if args.save_batch is not None:
            Path(args.save_batch).parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                'xc': full.xc.cpu(), 'yc': full.yc.cpu(), 'xb': full.xb.cpu(), 'yb': full.yb.cpu(), 'xt': full.xt.cpu(), 'yt': full.yt.cpu()
            }, args.save_batch)
            print(f"Saved batch to {args.save_batch}")

    # A loaded batch may hold a different number of functions than --num-fns;
    # never index past what is actually there.
    available = full.xc.shape[0]
    num_fns = min(args.num_fns, available)
    if num_fns < args.num_fns:
        print(f"Warning: batch holds only {available} functions; evaluating {num_fns} instead of the requested {args.num_fns}")
    if num_fns < 1:
        raise SystemExit("No functions to evaluate (empty batch or --num-fns < 1)")

    # Target count taken from the data (axis 1 is the per-function point axis),
    # so the report stays correct for loaded batches too.
    num_targets = full.xt.shape[1]
    x_dim = full.xc.shape[-1]
    y_dim = full.yc.shape[-1]

    def to_batched(arr):
        # NumPy [N, D] -> torch [1, N, D] on the evaluation device.
        return torch.from_numpy(arr).unsqueeze(0).to(device)

    # For each function: normalize per function, run exact LL, sum, collect
    ll_sums = []
    engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)

    for i in range(num_fns):
        xc2, yc2 = full.xc[i].cpu().numpy(), full.yc[i].cpu().numpy()
        xt2, yt2 = full.xt[i].cpu().numpy(), full.yt[i].cpu().numpy()

        # Per-function normalization (2D fit/transform), fitted on context only.
        scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
        scaler.fit(xc2, yc2)
        b1n = DataAttr(
            xc=to_batched(scaler.transform_x(xc2)),
            yc=to_batched(scaler.transform_y(yc2)),
            # Context-only conditioning: the buffer is always passed empty,
            # whatever the batch contains.
            xb=torch.empty(1, 0, x_dim, device=device),
            yb=torch.empty(1, 0, y_dim, device=device),
            xt=to_batched(scaler.transform_x(xt2)),
            yt=to_batched(scaler.transform_y(yt2)),
        )

        with torch.no_grad():
            ll_per_token = engine.predict_ll_independent(b1n)  # [1,T,1]
            ll_sums.append(ll_per_token.sum().item())

    ll_sums = np.array(ll_sums)
    # The sample std (ddof=1) is undefined for a single function; print nan
    # directly instead of triggering NumPy's degrees-of-freedom RuntimeWarning.
    ll_std = ll_sums.std(ddof=1) if len(ll_sums) > 1 else float('nan')
    print(f"Mode A (exact), normalized LL sum over {num_targets} targets:")
    print(f"  mean = {ll_sums.mean():.4f}, std = {ll_std:.4f}, N = {len(ll_sums)}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
