#!/usr/bin/env python3
"""Unified sequence sampling for TabularACE on 1D GP functions.

Modes (decode-mode):
- all         : Context-only (independent targets). Plots the analytic mixture mean with a mean +/- 1.96*std band; --n-mc is not used in this mode.
- ar_reencode : Teacher-forcing autoregression with re-encode (true -> context). MC via batch expansion.
- ar_buffer   : Teacher-forcing autoregression with buffer tokens (true -> buffer, AR tokens). MC via batch expansion; K-chunked.

Per-function normalization (fit on context) is applied. Outputs a per-function
plot in normalized space showing true function and the sampled predictive mean.
"""

import argparse
from pathlib import Path
import numpy as np
import random
import torch
import matplotlib.pyplot as plt

from src.models.tabular_embedder import TabularACE
from src.models.tabular_inference import TabularInferenceSDPA
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


def _augment_state_dict_keys(sd: dict) -> dict:
    out = dict(sd)
    for k, v in list(sd.items()):
        if k.startswith("embedder."):
            alt = "tabular_" + k
            if alt not in out:
                out[alt] = v
    for k, v in list(sd.items()):
        if k.startswith("tabular_embedder."):
            alt = k.replace("tabular_", "", 1)
            if alt not in out:
                out[alt] = v
    return out


def build_model_from_ckpt(ckpt_path: str, device: torch.device) -> TabularACE:
    """Build a TabularACE with the fixed architecture below and load its weights.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It may be
            either a dict holding the weights under ``"model_state_dict"`` or
            the state dict itself.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model with weights loaded (strict key matching), in eval mode.
    """
    # Architecture hyperparameters are hard-coded here, not read from the checkpoint.
    arch = dict(
        num_features=10,
        embed_dim=128,
        transformer_layers=12,
        nhead=4,
        dim_feedforward=256,
        num_components=20,
        max_buffer_size=32,
        num_target_points=512,
        targets_block_size_for_buffer_attend=32,
        dropout=0.0,
        num_isab_blocks=3,
        num_inducing_points=128,
        row_rope_base=100000,
        col_nhead=4,
        row_nhead=4,
        row_num_blocks=3,
        concat_cls=True,
        num_cls_tokens=2,
        ff_factor=2.0,
    )
    model = TabularACE(**arch).to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Fall back to treating the whole checkpoint as the state dict.
    state = checkpoint.get("model_state_dict", checkpoint)
    # Add embedder/tabular_embedder key aliases before the strict load.
    state = _augment_state_dict_keys(state)
    model.load_state_dict(state, strict=True)

    model.eval()
    return model


def _repeat_batch(b: DataAttr, n: int) -> DataAttr:
    """Return a new DataAttr whose fields are tiled *n* times along dim 0.

    Each non-None field is expanded with ``repeat(n, 1, 1)``; fields that are
    None stay None. ``xb``/``yb`` are read with ``getattr`` so a batch lacking
    those attributes is treated as having no buffer.
    """
    def tile(field):
        return None if field is None else field.repeat(n, 1, 1)

    return DataAttr(
        xc=tile(b.xc),
        yc=tile(b.yc),
        xb=tile(getattr(b, 'xb', None)),
        yb=tile(getattr(b, 'yb', None)),
        xt=tile(b.xt),
        yt=tile(b.yt),
    )


def plot_batch(xc, yc, xt, yt, ypred, out_path: Path, y_samples=None, ci_lower=None, ci_upper=None) -> None:
    """Draw one panel per function and save the figure to *out_path*.

    All array arguments are indexed as ``[function, point, 0]``. Target
    points are sorted by x before drawing so that lines are monotone in x.

    Args:
        xc, yc: Context inputs/outputs, drawn as black plus markers.
        xt, yt: Target inputs and true outputs, drawn as a dashed black line.
        ypred: Predicted mean at the targets, drawn as a solid blue line.
        out_path: Destination image file; parent directories are created.
        y_samples: Optional per-function sequence; each entry is either None
            or an array indexed ``[sample, point, 0]``. At most the first 10
            samples of each entry are drawn.
        ci_lower, ci_upper: Optional band bounds; drawn only when both are given.
    """
    n_rows = xc.shape[0]
    fig, grid = plt.subplots(n_rows, 1, figsize=(8, 3.0 * n_rows), squeeze=False)
    fig.patch.set_facecolor('white')
    have_band = not (ci_lower is None or ci_upper is None)

    for row, ax in enumerate(grid[:, 0]):
        perm = np.argsort(xt[row, :, 0])
        x_sorted = xt[row, perm, 0]

        # Ground truth: thick dashed black line.
        ax.plot(
            x_sorted, yt[row, perm, 0],
            color='black', linewidth=3.0, alpha=1.0, linestyle='--', label='True f(x)',
        )

        # Sampled trajectories (if supplied for this row): thin dark lines, capped at 10.
        trajectories = None
        if y_samples is not None and row < len(y_samples):
            trajectories = y_samples[row]
        if trajectories is not None:
            n_drawn = min(10, trajectories.shape[0])
            for k in range(n_drawn):
                ax.plot(x_sorted, trajectories[k, perm, 0], color='#1f2a44', linewidth=1.4, alpha=0.35)

        # Band: light blue fill with dotted blue bounds.
        if have_band:
            band_lo = ci_lower[row, perm, 0]
            band_hi = ci_upper[row, perm, 0]
            ax.fill_between(x_sorted, band_lo, band_hi, color='#5aa9ff', alpha=0.18, linewidth=0)
            ax.plot(x_sorted, band_lo, color='tab:blue', linewidth=2.6, linestyle=':')
            ax.plot(x_sorted, band_hi, color='tab:blue', linewidth=2.6, linestyle=':')

        # Predicted mean: thick solid blue line.
        ax.plot(x_sorted, ypred[row, perm, 0], color='tab:blue', linewidth=3.0, alpha=0.95, label='Pred mean')

        # Context points: large black plus markers drawn above the lines.
        ax.plot(
            xc[row, :, 0], yc[row, :, 0],
            linestyle='None', marker='+', markersize=16, markeredgewidth=3.5,
            color='black', zorder=6,
        )

        # Strip labels, ticks, grid and the top/right spines for a clean panel.
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.tick_params(axis='both', which='both', length=0)
        ax.grid(False)
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)

    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _quantile_band(replicates):
    """Return the empirical 2.5% / 97.5% quantiles over dim 0 as NumPy arrays.

    The leading (replicate) dimension is kept with size 1 so the results can
    be concatenated across functions.
    """
    values = replicates.cpu().numpy()
    lower = np.quantile(values, 0.025, axis=0, keepdims=True)
    upper = np.quantile(values, 0.975, axis=0, keepdims=True)
    return lower, upper


def _smooth_with_augmented_context(engine, bmc, y_sampled, device):
    """Re-run the context-only mean with the sampled targets added to the context.

    The context becomes ``(xc, yc) + (xt, y_sampled)`` with an empty buffer,
    and the mixture mean is evaluated at ``xt`` for every replicate.
    """
    augmented = DataAttr(
        xc=torch.cat([bmc.xc, bmc.xt], dim=1),
        yc=torch.cat([bmc.yc, y_sampled], dim=1),
        xb=torch.empty(bmc.xc.shape[0], 0, bmc.xc.shape[-1], device=device, dtype=bmc.xc.dtype),
        yb=torch.empty(bmc.yc.shape[0], 0, bmc.yc.shape[-1], device=device, dtype=bmc.yc.dtype),
        xt=bmc.xt, yt=None,
    )
    engine.precompute_rows(augmented)
    engine.build_context_kv()
    params = engine.predict_all_targets(return_params=True)
    # Mixture mean: weighted sum over the component axis (dim 2).
    return (params.weights * params.means).sum(dim=2)


def main():
    """CLI entry point: sample predictions per function and save a summary plot.

    Parses the command line, builds the model and inference engine, obtains a
    batch (loaded from ``--batch-path`` or freshly generated), and then for
    each function fits a scaler on its context, runs the selected decode mode
    and collects the mean, band and sample trajectories. The result is drawn
    with ``plot_batch`` and written to ``--out``.

    Raises:
        SystemExit: via ``argparse`` when ``--K`` or ``--n-mc`` is invalid.
        ValueError: when there is no function to process.
    """
    # Must match max_buffer_size used when the model is built.
    max_buffer_size = 32

    ap = argparse.ArgumentParser(description="Unified sequence sampling for TabularACE (1D GP)")
    ap.add_argument('--decode-mode', required=True, choices=['all', 'ar_reencode', 'ar_buffer'])
    ap.add_argument('--checkpoint', required=True)
    ap.add_argument('--device', default='cpu')
    ap.add_argument('--num-fns', type=int, default=6)
    ap.add_argument('--nc', type=int, default=64)
    ap.add_argument('--nb', type=int, default=0)
    ap.add_argument('--nt', type=int, default=256)
    ap.add_argument('--noise', type=float, default=0.0)
    ap.add_argument('--n-mc', type=int, default=64,
                    help='MC replicates (batch expansion) for AR modes; unused for decode-mode all, which uses analytic mixture moments')
    ap.add_argument('--K', type=int, default=32, help='Chunk size for ar_buffer; must be <= max_buffer_size')
    ap.add_argument('--refresh-mode', default='none', choices=['none', 'context_chunk'],
                    help='AR-buffer only: context_chunk re-encodes a mini-prefix [context + last K targets] each step; none keeps pure buffer.')
    ap.add_argument('--attn-backend', default='auto', choices=['auto', 'flash', 'mem', 'math'])
    ap.add_argument('--seed', type=int, default=None)
    ap.add_argument('--batch-path', type=str, default=None, help='Load a pre-saved batch (.pt)')
    ap.add_argument('--save-batch', type=str, default=None, help='Save generated batch (.pt)')
    ap.add_argument('--smooth-mean', action='store_true',
                    help='For AR modes: after sampling, re-run context-only mean on (xc,yc)+(xt, y_sample) as augmented context to produce smoothed trajectories; plot these and average their mean.')
    ap.add_argument('--out', type=str, default='outputs/tabular_sample_sequences.png')
    args = ap.parse_args()

    # Validate CLI values up front (not with assert, which -O strips) so that
    # bad input fails before the checkpoint is loaded.
    ar_modes = ('ar_reencode', 'ar_buffer')
    if args.decode_mode in ar_modes and args.n_mc < 1:
        ap.error("--n-mc must be >= 1 for AR decode modes")
    if args.decode_mode == 'ar_buffer' and args.K > max_buffer_size:
        ap.error(f"K must be <= max_buffer_size ({max_buffer_size})")

    device = torch.device(args.device)
    model = build_model_from_ckpt(args.checkpoint, device)
    engine = TabularInferenceSDPA.from_trained_model(model, backend=args.attn_backend)

    # Seed
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Data
    if args.batch_path:
        loaded = torch.load(args.batch_path, map_location='cpu')
        full = DataAttr(xc=loaded['xc'], yc=loaded['yc'], xb=loaded['xb'], yb=loaded['yb'], xt=loaded['xt'], yt=loaded['yt'])
    else:
        from src.data.gp_sampler import GPSampler
        sampler = GPSampler(device='cpu', dtype=torch.float32, noise_range=(args.noise, args.noise))
        full: DataAttr = sampler.generate_batch(
            batch_size=args.num_fns,
            num_context=args.nc,
            num_buffer=args.nb,
            num_target=args.nt,
        )
        if args.save_batch:
            Path(args.save_batch).parent.mkdir(parents=True, exist_ok=True)
            torch.save({'xc': full.xc.cpu(), 'yc': full.yc.cpu(), 'xb': full.xb.cpu(), 'yb': full.yb.cpu(), 'xt': full.xt.cpu(), 'yt': full.yt.cpu()}, args.save_batch)
            print(f"Saved batch to {args.save_batch}")

    # A loaded batch may hold fewer functions than --num-fns; never index past it.
    available = full.xc.shape[0]
    num_fns = min(args.num_fns, available)
    if num_fns < 1:
        raise ValueError(f"No functions to process (--num-fns={args.num_fns}, batch size={available})")
    if num_fns < args.num_fns:
        print(f"Batch holds only {available} functions; processing {num_fns} instead of {args.num_fns}")

    xc_all, yc_all, xt_all, yt_all, ypred_all = [], [], [], [], []
    ysamples_all = []  # per-function sample arrays [S,T,1], or None in 'all' mode
    ci_lower_list, ci_upper_list = [], []

    for i in range(num_fns):
        # Per-function normalization: the scaler is fit on the context only.
        xc2, yc2 = full.xc[i].cpu().numpy(), full.yc[i].cpu().numpy()
        xt2, yt2 = full.xt[i].cpu().numpy(), full.yt[i].cpu().numpy()
        scaler = TabICLScaler(normalization_method='power', outlier_threshold=4.0, random_state=0)
        scaler.fit(xc2, yc2)
        # The normalized single-function batch always starts with an empty buffer.
        b1n = DataAttr(
            xc=torch.from_numpy(scaler.transform_x(xc2)).unsqueeze(0),
            yc=torch.from_numpy(scaler.transform_y(yc2)).unsqueeze(0),
            xb=torch.empty(1, 0, full.xc.shape[-1]),
            yb=torch.empty(1, 0, full.yc.shape[-1]),
            xt=torch.from_numpy(scaler.transform_x(xt2)).unsqueeze(0),
            yt=torch.from_numpy(scaler.transform_y(yt2)).unsqueeze(0),
        ).to(device)

        if args.decode_mode == 'all':
            with torch.no_grad():
                engine.precompute_rows(b1n)
                engine.build_context_kv()
                params = engine.predict_all_targets(return_params=True)
                # Analytic mixture moments: E[y] and Var[y] = E[y^2] - E[y]^2.
                mu = (params.weights * params.means).sum(dim=2)  # [1,T,1]
                second = (params.weights * (params.sds ** 2 + params.means ** 2)).sum(dim=2)
                std = (second - mu ** 2).clamp_min(0.0).sqrt()
                ypred = mu
                ci_lo = (mu - 1.96 * std).cpu().numpy()
                ci_hi = (mu + 1.96 * std).cpu().numpy()
            y_samps = None
        elif args.decode_mode in ar_modes:
            # Batch expansion: every MC replicate is one row of the batch.
            bmc = _repeat_batch(b1n, args.n_mc).to(device)
            order = None
            if args.decode_mode == 'ar_buffer':
                # Independent random target ordering for each replicate.
                T = b1n.xt.shape[1]
                order = torch.stack([torch.randperm(T) for _ in range(args.n_mc)], dim=0).to(device)
            with torch.no_grad():
                engine.precompute_rows(bmc)
                engine.build_context_kv()
                if args.decode_mode == 'ar_reencode':
                    preds = engine.decode_reencode(bmc)  # [n_mc,T,1]
                else:
                    preds = engine.decode_ar_buffer_kchunk(bmc, K=args.K, order=order, refresh_mode=args.refresh_mode)  # [n_mc,T,1]
                trajectories = preds.yc
                if args.smooth_mean:
                    trajectories = _smooth_with_augmented_context(engine, bmc, trajectories, device)
                # Mean and empirical 95% band across replicates; keep up to 10 for plotting.
                ypred = trajectories.mean(dim=0, keepdim=True)  # [1,T,1]
                ci_lo, ci_hi = _quantile_band(trajectories)
                y_samps = trajectories[: min(10, args.n_mc)].cpu().numpy()
        else:
            raise ValueError(f"Unknown decode-mode {args.decode_mode}")

        ysamples_all.append(y_samps)
        ci_lower_list.append(ci_lo)
        ci_upper_list.append(ci_hi)
        xc_all.append(b1n.xc.cpu().numpy())
        yc_all.append(b1n.yc.cpu().numpy())
        xt_all.append(b1n.xt.cpu().numpy())
        yt_all.append(b1n.yt.cpu().numpy())
        ypred_all.append(ypred.cpu().numpy())

    xc = np.concatenate(xc_all, axis=0)
    yc = np.concatenate(yc_all, axis=0)
    xt = np.concatenate(xt_all, axis=0)
    yt = np.concatenate(yt_all, axis=0)
    ypred = np.concatenate(ypred_all, axis=0)
    ci_lower = np.concatenate(ci_lower_list, axis=0)
    ci_upper = np.concatenate(ci_upper_list, axis=0)

    plot_batch(xc, yc, xt, yt, ypred, Path(args.out), y_samples=ysamples_all, ci_lower=ci_lower, ci_upper=ci_upper)
    print(f"Saved sequence sampling plot to {args.out}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
