"""
A script for testing and visualizing the performance of the Amortized Conditioning Engine (ACE) model.

This script provides two main functionalities:
1. Testing the model on a pre-generated offline dataset of Gaussian Process samples.
2. Testing the model on simple, analytically defined 1D functions (a combination of sin and
   cos, a sawtooth wave, and a triangular wave).

It loads pre-trained model checkpoints, runs autoregressive sampling to generate predictions,
and visualizes the results by plotting the predicted mean and uncertainty (standard deviation
or interquartile range across samples) against the true values and context points.
"""

import math
from typing import Any, Callable, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from src.data.utils import OfflineBatchLoader
from src.models.ace import AmortizedConditioningEngine, InferenceEngine2
from src.models.modules import Embedder, MixtureGaussian, Transformer
from src.utils import DataAttr


def build_ace_model(config) -> AmortizedConditioningEngine:
    """Construct a fresh ACE model from the ``model`` section of ``config``.

    Reads ``dim_x``, ``dim_y``, ``dim_model``, ``max_buffer_size`` and
    ``targets_block_size_for_buffer_attend`` plus the ``embedder``, ``backbone``
    and ``head`` sub-sections. Builds an ``Embedder``, a ``Transformer`` and a
    ``MixtureGaussian`` and passes them to ``AmortizedConditioningEngine`` as
    its embedder, backbone and head. No weights are loaded here.
    """
    mcfg = config.model
    emb_cfg, bb_cfg, head_cfg = mcfg.embedder, mcfg.backbone, mcfg.head

    return AmortizedConditioningEngine(
        embedder=Embedder(
            dim_x=mcfg.dim_x,
            dim_y=mcfg.dim_y,
            hidden_dim=emb_cfg.hidden_dim,
            out_dim=mcfg.dim_model,  # embedder output width == model width
            depth=emb_cfg.depth,
        ),
        backbone=Transformer(
            num_layers=bb_cfg.num_layers,
            dim_model=mcfg.dim_model,
            num_head=bb_cfg.num_heads,
            dim_feedforward=bb_cfg.dim_feedforward,
            dropout=bb_cfg.dropout,
        ),
        head=MixtureGaussian(
            dim_y=mcfg.dim_y,
            dim_model=mcfg.dim_model,
            dim_feedforward=head_cfg.dim_feedforward,
            num_components=head_cfg.num_components,
        ),
        max_buffer_size=mcfg.max_buffer_size,
        targets_block_size_for_buffer_attend=mcfg.targets_block_size_for_buffer_attend,
    )


def load_model(
    checkpoint_path: str,
    device: str = "cpu",
    model_builder: Callable[[Any], AmortizedConditioningEngine] = build_ace_model,
    compile_model: bool = False,
) -> Tuple[AmortizedConditioningEngine, Any]:
    """
    Load a model from a checkpoint file.

    The checkpoint must be a mapping with two entries:
    - ``"config"``: turned into an OmegaConf config and passed to ``model_builder``.
    - ``"model_state_dict"``: the model weights.

    Keys saved from a ``torch.compile``-wrapped model contain an ``_orig_mod.``
    prefix. That prefix is removed before the state dict is loaded.

    Args:
        checkpoint_path: Path to the model checkpoint.
        device: The device to load the model on (also used as ``map_location``).
        model_builder: Callable that builds the (untrained) model architecture
            from the loaded config.
        compile_model: Whether to wrap the model with ``torch.compile``.

    Returns:
        A tuple ``(model, config)``. The model is on ``device`` and in eval mode.
        When ``compile_model`` is true, it is the ``torch.compile`` wrapper
        around the built model.
    """
    # NOTE(review): on PyTorch versions where torch.load defaults to
    # weights_only=True, a non-primitive "config" object would fail to load — confirm.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = OmegaConf.create(checkpoint["config"])
    model = model_builder(config)

    state_dict = checkpoint["model_state_dict"]
    if any(key.startswith("_orig_mod.") for key in state_dict):
        # str.replace strips every occurrence, not only the leading one, so
        # nested "_orig_mod." segments are removed as well.
        state_dict = {
            key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
        }

    model.load_state_dict(state_dict)
    # Move and switch to inference mode before (optionally) wrapping the model.
    model = model.to(device)
    model.eval()

    if compile_model:
        model = torch.compile(model)
        print("model compiled")
    else:
        print("running without compiled model")
    return model, config


def load_offline_data(data_path: str, device: str = "cpu") -> DataLoader:
    """Wrap an ``OfflineBatchLoader`` for ``data_path`` in a ``DataLoader``.

    The dataset already yields whole batches, so automatic batching is turned
    off (``batch_size=None``). Iteration order is fixed (no shuffling), and data
    is loaded in the main process (no worker processes).
    """
    pre_batched = OfflineBatchLoader(data_path, device=device)
    loader_options = dict(batch_size=None, shuffle=False, num_workers=0)
    return DataLoader(pre_batched, **loader_options)


def _plot_predictions(
    ax: plt.Axes,
    batch_idx: int,
    seq_samples: list,
    batch: DataAttr,
    is_toy_example: bool,
):
    """Plot context, ground truth and the sampled predictive summary for one task.

    For each draw in ``seq_samples``, ``.xc`` / ``.yc`` are read as predicted
    target inputs / outputs. The per-index mean and standard deviation across
    draws are plotted against the x values of the first draw.

    NOTE(review): this assumes every draw lists targets in the same order;
    otherwise the per-index statistics mix different x locations — confirm.

    Args:
        ax: Axes to draw on.
        batch_idx: Index of the task within the batch.
        seq_samples: Autoregressive draws for the whole batch.
        batch: The batch the draws were generated from.
        is_toy_example: If true, draw the truth as a line and the prediction as
            a mean line with a ±1 std band. Otherwise draw scatter points with
            error bars.
    """

    def to_np(tensor):
        return tensor[batch_idx].cpu().numpy()

    draws = np.array([to_np(sample.yc) for sample in seq_samples])
    pred_mean = draws.mean(axis=0).flatten()
    pred_std = draws.std(axis=0).flatten()
    pred_x = to_np(seq_samples[0].xc).flatten()

    ctx_x, ctx_y = to_np(batch.xc), to_np(batch.yc)
    true_x, true_y = to_np(batch.xt), to_np(batch.yt)

    ax.scatter(ctx_x, ctx_y, label="context", color="C0")

    if not is_toy_example:
        ax.scatter(true_x, true_y, label="true")
        ax.errorbar(
            pred_x,
            pred_mean,
            yerr=pred_std,
            fmt="o",
            ecolor="gray",
            capsize=3,
            linestyle="None",
            label="pred (mean±std)",
            color="gray",
        )
    else:
        ax.plot(true_x, true_y, color="k", label="true")
        ax.plot(pred_x, pred_mean, color="C1", label="pred mean")
        lower, upper = pred_mean - pred_std, pred_mean + pred_std
        ax.fill_between(
            pred_x, lower, upper, alpha=0.3, color="C1", label="pred ±1 std"
        )

    ax.set_title(f"Batch {batch_idx}")
    ax.legend(loc="best", fontsize="small")


def _generate_plots(
    seq_samples: List[DataAttr],
    batch: DataAttr,
    is_toy_example: bool,
    plot_config: Dict,
    filename: str,
    suptitle: str,
):
    """Draw a grid of per-task prediction plots, save it to ``filename`` and show it.

    Args:
        seq_samples: Autoregressive draws, forwarded to ``_plot_predictions``.
        batch: The batch the draws were generated from.
        is_toy_example: Selects line/band plots (True) or scatter/error-bar
            plots (False).
        plot_config: Mapping with two entries:
            - ``"grid_size"``: an ``(n_rows, n_cols)`` pair.
            - ``"batch_indices"``: the task indices to plot, one per subplot in
              row-major order. Indices beyond the number of subplots are
              ignored, and subplots without a task are hidden.
        filename: Output image path.
        suptitle: Figure title; the number of draws is appended to it.
    """
    n_rows, n_cols = plot_config["grid_size"]
    batch_indices = plot_config["batch_indices"]

    # squeeze=False keeps `axes` a 2-D array for every grid shape (including
    # 1x1, where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 4 * n_rows),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes_flat = axes.flatten()

    for ax, batch_idx in zip(axes_flat, batch_indices):
        _plot_predictions(ax, batch_idx, seq_samples, batch, is_toy_example)

    # Hide subplots that did not receive a task.
    for ax in axes_flat[len(batch_indices):]:
        ax.axis("off")

    num_ar_samples = len(seq_samples)
    fig.suptitle(f"{suptitle} (n_ar_samp={num_ar_samples})", fontsize=16)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(filename)
    plt.show()


def run_gp_offline_example(
    inference_engine: InferenceEngine2,
    device: str,
    k_decoding: int,
    data_path: str = "data/test_gp_data",
    num_ar_samples: int = 100,
    grid_size: Tuple[int, int] = (2, 3),
    first_batch_idx: int = 6,
):
    """Sample predictions on offline GP data and plot a grid of tasks.

    Steps:
    1. Load the offline dataset at ``data_path`` and take its first
       pre-batched element.
    2. Draw ``num_ar_samples`` sequences with
       ``inference_engine.sample_sequence``.
    3. Plot ``n_rows * n_cols`` consecutive tasks, starting at
       ``first_batch_idx``, to ``gp_check_subplots.png``.

    Args:
        inference_engine: Engine used to draw autoregressive samples.
        device: Device the offline data is loaded onto.
        k_decoding: Decoding batch size ``K`` passed to ``sample_sequence``.
        data_path: Location of the offline dataset.
        num_ar_samples: Number of autoregressive draws.
        grid_size: ``(n_rows, n_cols)`` of the subplot grid.
        first_batch_idx: Index of the first task to plot.

    Raises:
        ValueError: If the batch does not contain every task index to be plotted.
    """
    print("Running GP offline data example...")
    data_loader = load_offline_data(data_path, device)

    # Only the first pre-batched element of the dataset is evaluated.
    batch = next(iter(data_loader))

    n_rows, n_cols = grid_size
    batch_indices = np.arange(n_rows * n_cols) + first_batch_idx
    num_tasks = batch.xc.shape[0]
    if batch_indices.size and batch_indices[-1] >= num_tasks:
        raise ValueError(
            f"cannot plot task indices {batch_indices[0]}..{batch_indices[-1]}: "
            f"batch only has {num_tasks} tasks"
        )

    seq_samples = [
        inference_engine.sample_sequence(batch, K=k_decoding)
        for _ in range(num_ar_samples)
    ]

    plot_config = {
        "grid_size": (n_rows, n_cols),
        "batch_indices": batch_indices,
    }

    _generate_plots(
        seq_samples=seq_samples,
        batch=batch,
        is_toy_example=False,
        plot_config=plot_config,
        filename="gp_check_subplots.png",
        suptitle="Prediction by AR samples on GP data",
    )

def _build_mc_ar_batch(batch: DataAttr, num_ar_samples: int) -> DataAttr:
    """Replicate a single-task batch for Monte Carlo autoregressive sampling.

    Context (``xc``/``yc``) and targets (``xt``/``yt``) are tiled
    ``num_ar_samples`` times along the leading (batch) dimension. The targets
    of each replica are then shuffled by an independent random permutation,
    applied identically to ``xt`` and ``yt``, so each replica decodes the
    targets in a different order.

    Args:
        batch: Batch whose tensors have a leading batch dimension of size 1.
        num_ar_samples: Number of replicas to create.

    Returns:
        A new ``DataAttr`` built only from the tiled ``xc``, ``yc``, ``xt`` and
        ``yt``.

    Raises:
        ValueError: If ``batch`` holds more than one task.
    """
    if batch.xc.shape[0] != 1:
        raise ValueError(
            f"_build_mc_ar_batch expects a batch of size 1, got {batch.xc.shape[0]}"
        )

    def _tile(tensor: torch.Tensor) -> torch.Tensor:
        # Repeat along the batch dimension only, whatever the tensor's rank.
        return tensor.repeat(num_ar_samples, *([1] * (tensor.dim() - 1)))

    xc, yc = _tile(batch.xc), _tile(batch.yc)
    xt, yt = _tile(batch.xt), _tile(batch.yt)

    # Shuffle the target order independently for every replica. repeat() returns
    # fresh tensors, so in-place assignment does not touch `batch`.
    num_targets = xt.shape[1]
    for b in range(num_ar_samples):
        perm = torch.randperm(num_targets)
        xt[b] = xt[b, perm]
        yt[b] = yt[b, perm]

    return DataAttr(xc=xc, yc=yc, xt=xt, yt=yt)


def simple_triangular_1d(x: torch.Tensor, freq: float = 3.25) -> torch.Tensor:
    """
    Triangular wave with values in [0, 1], intended for the domain [-2, 2].

    Evaluated element-wise as ``1 - |2 * frac(u) - 1|`` with
    ``u = (x + 2) * freq``. The wave is 0 at every period boundary and peaks
    at 1 mid-period. There is no change point: the frequency is constant
    over the whole domain.

    Args:
        x: Input tensor of any shape. For ``x < -2``, ``torch.frac`` returns
            negative fractions and the output leaves the usual pattern.
        freq: Number of periods per unit of x. The default 3.25 gives
            13 periods across [-2, 2].

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Shift the domain start to 0, then scale so one unit of u is one period.
    u = (x + 2.0) * freq
    return 1.0 - torch.abs(2.0 * torch.frac(u) - 1.0)


def simple_sawtooth_1d(x: torch.Tensor, freq: float = 4.0) -> torch.Tensor:
    """
    Sawtooth wave with values in [0, 1), intended for the domain [-2, 2].

    Evaluated element-wise as ``frac((x + 2) * freq)``. The wave rises
    linearly from 0 and drops back to 0 at every period boundary. There is
    no change point: the frequency is constant over the whole domain.

    Args:
        x: Input tensor of any shape. For ``x < -2``, ``torch.frac`` returns
            negative fractions.
        freq: Number of periods per unit of x. The default 4.0 equals the
            previous hard-coded ``52 / 13`` and gives 16 teeth across
            [-2, 2]. Other tested settings followed the same ``f / 13``
            convention (e.g. 39 / 13, 66 / 13).

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Shift the domain start to 0, then scale so one unit of u is one period.
    u = (x + 2.0) * freq
    return torch.frac(u)

def simple_sin_cos_1d(x: torch.Tensor) -> torch.Tensor:
    """Smooth 1D test function ``sin(x) + 0.5*cos(2x) - 0.3*sin(3x)``, element-wise."""
    base = torch.sin(x)
    second_harmonic = 0.5 * torch.cos(2 * x)
    third_harmonic = 0.3 * torch.sin(3 * x)
    return base + second_harmonic - third_harmonic


def run_toy_function_grid(
    inference_engine: InferenceEngine2,
    device: str,
    k_decoding: int,
    num_ctx_range: Tuple[int, int] = (8, 23),
    grid_dim: int = 4,  # number of rows and cols in the subplot grid
    n_mc: int = 100,  # Monte Carlo samples per subplot
    B: int = 1,  # batch size (must be 1: _build_mc_ar_batch replicates one task)
    N: int = 100,  # number of query points
    x_range: tuple = (-2, 2),  # (start, end) for both context and query x's
    true_func=simple_sawtooth_1d,
    name="default",
    plot_example_sample=False,
):
    """Run toy-function examples in a ``grid_dim`` x ``grid_dim`` grid of subplots.

    For each subplot:
    - A fresh random context of ``num_ctx_range[0] <= n < num_ctx_range[1]``
      points is drawn uniformly from ``x_range`` and labelled by ``true_func``.
    - The model decodes ``n_mc`` randomly ordered autoregressive trajectories
      over ``N`` evenly spaced query points.
    - The subplot shows the true function, the predictive mean, the
      interquartile range across trajectories, and the context points.

    A single shared legend is drawn. The figure is saved to
    ``f"{name}_prediction_plots.png"`` and shown.

    Args:
        inference_engine: Engine providing ``sample_sequence``.
        device: Device the tensors are created on.
        k_decoding: Decoding batch size ``K`` passed to ``sample_sequence``.
        num_ctx_range: ``(low, high)`` bounds for the context size; ``high``
            is exclusive.
        plot_example_sample: If true, also draw the first trajectory in red.
    """
    print(f"Running {grid_dim}×{grid_dim} {name} toy-function examples...")
    start, end = x_range
    low_ctx, high_ctx = num_ctx_range

    # Fixed query grid shared by all subplots.
    Xq = torch.linspace(start, end, N, device=device).view(1, N, 1).repeat(B, 1, 1)
    Yq_true = true_func(Xq)
    xq_np = Xq.cpu().numpy().flatten()
    yq_np = Yq_true.cpu().numpy().flatten()

    # squeeze=False keeps `axes` 2-D even when grid_dim == 1.
    fig, axes = plt.subplots(
        grid_dim,
        grid_dim,
        figsize=(4 * grid_dim, 3 * grid_dim),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes = axes.flatten()

    for idx, ax in enumerate(axes):
        # Context size drawn uniformly from [low_ctx, high_ctx).
        num_ctx_i = torch.randint(low_ctx, high_ctx, (1,)).item()

        # 1) Sample a new random context for this subplot.
        Xctx = torch.rand(B, num_ctx_i, 1, device=device) * (end - start) + start
        Yctx = true_func(Xctx)

        # 2) Build the batch and replicate it for MC sampling.
        batch = DataAttr(xc=Xctx, yc=Yctx, xt=Xq, yt=Yq_true)
        batch_mc = _build_mc_ar_batch(batch, n_mc)

        # 3) Run inference. The returned .xc/.yc are read as the decoded
        #    target inputs/outputs.
        seq = inference_engine.sample_sequence(batch_mc, K=k_decoding)

        # 4) Sort every trajectory by x so per-index statistics align.
        idx_sort = torch.argsort(seq.xc, dim=1)
        seq.xc = torch.gather(seq.xc, 1, idx_sort)
        seq.yc = torch.gather(seq.yc, 1, idx_sort)

        # 5) Mean and interquartile range across trajectories.
        mean_pred = seq.yc.mean(dim=0).cpu().numpy().flatten()
        q1 = seq.yc.quantile(0.25, dim=0).cpu().numpy().flatten()
        q3 = seq.yc.quantile(0.75, dim=0).cpu().numpy().flatten()
        x_plot = seq.xc[0, :, 0].cpu().numpy()

        if plot_example_sample:
            ax.plot(x_plot, seq.yc[0, :, 0].cpu().numpy(), color="red")

        # 6) Plot the true function, predictive mean, IQR, and contexts.
        ax.plot(xq_np, yq_np, linestyle="--", label="true function")
        ax.plot(x_plot, mean_pred, label="pred mean")
        ax.fill_between(x_plot, q1, q3, alpha=0.3, label="pred IQR")
        ax.scatter(
            Xctx.cpu().numpy().squeeze(),
            Yctx.cpu().numpy().squeeze(),
            color="black", s=30, label="context"
        )
        ax.set_title(f"Ex {idx+1} numctx {num_ctx_i}")
        if idx % grid_dim == 0:
            ax.set_ylabel("y")
        if idx // grid_dim == grid_dim - 1:
            ax.set_xlabel("x")

    # Title on top, shared legend just below it, subplots below both.
    fig.suptitle(f"Prediction on {name}, K={k_decoding}", y=0.99)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.96), ncol=4)

    fig.tight_layout(rect=[0, 0, 1, 0.92])
    fig.savefig(f"{name}_prediction_plots.png")
    plt.show()


def run_toy_function_rows(
    inference_engine: InferenceEngine2,
    device: str,
    k_decoding: int,
    num_ctx_range: Tuple[int, int] = (8, 32),
    num_rows: int = 4,  # number of rows in the subplot grid
    n_mc: int = 32,  # Monte Carlo samples per row
    B: int = 1,  # batch size (must be 1: _build_mc_ar_batch replicates one task)
    N: int = 100,  # number of query points
    x_range: tuple = (-2, 2),  # (start, end) for both context and query x's
    true_func=simple_sawtooth_1d,
    name="default",
    n_AR_samples_to_plot=3,
):
    """Run toy-function examples laid out as ``num_rows`` rows of 3 columns.

    For every row, a fresh random context of
    ``num_ctx_range[0] <= n < num_ctx_range[1]`` points is drawn from
    ``true_func``. The model then decodes ``n_mc`` randomly ordered
    autoregressive trajectories over ``N`` query points. The columns show:
    - column 1: the true function, predictive mean, interquartile range across
      trajectories, and the context points;
    - column 2: the first ``n_AR_samples_to_plot`` individual trajectories;
    - column 3: a placeholder (axis turned off).

    The figure is saved to ``f"{name}_rows_prediction_plots.png"`` (300 dpi)
    and shown.

    Args:
        inference_engine: Engine providing ``sample_sequence`` and
            ``evaluate_joint_loglikelihood``.
        device: Device the tensors are created on.
        k_decoding: Decoding batch size ``K``.
        num_ctx_range: ``(low, high)`` bounds for the context size; ``high``
            is exclusive.
    """
    print(f"Running {num_rows} rows × 3 columns {name} toy-function examples...")
    start, end = x_range
    low_ctx, high_ctx = num_ctx_range

    # Fixed query grid shared by all rows.
    Xq = torch.linspace(start, end, N, device=device).view(1, N, 1).repeat(B, 1, 1)
    Yq_true = true_func(Xq)
    xq_np = Xq.cpu().numpy().flatten()
    yq_np = Yq_true.cpu().numpy().flatten()

    # Fixed 3 columns; squeeze=False keeps `axes` 2-D even when num_rows == 1.
    fig, axes = plt.subplots(
        num_rows, 3,
        figsize=(4 * 3, 3 * num_rows),
        sharex=True, sharey=True,
        squeeze=False,
    )

    # Supertitle (larger font) above everything.
    fig.suptitle(f"Prediction rows: {name}, K={k_decoding}", y=0.99, fontsize=16, fontweight='bold')

    # Column titles just below the supertitle.
    col_titles = ["Pred mean(IQR)", "AR MC samples", "Marginal TBD"]
    for col_idx, ct in enumerate(col_titles):
        fig.text((col_idx + 0.5) / 3, 0.90, ct, ha='center', va='bottom', fontsize=12, fontweight='bold')

    for row in range(num_rows):
        # -- First column: predictive summary --
        ax = axes[row, 0]

        # Context size drawn uniformly from [low_ctx, high_ctx).
        num_ctx_i = torch.randint(low_ctx, high_ctx, (1,)).item()

        Xctx = torch.rand(B, num_ctx_i, 1, device=device) * (end - start) + start
        Yctx = true_func(Xctx)

        # Build the batch and replicate it for MC sampling.
        batch = DataAttr(xc=Xctx, yc=Yctx, xt=Xq, yt=Yq_true)
        batch_mc = _build_mc_ar_batch(batch, n_mc)

        # Sampled trajectories; .xc/.yc are read as decoded target inputs/outputs.
        seq = inference_engine.sample_sequence(batch_mc, K=k_decoding)

        # Kept separate from `seq`: the sorting and plotting below need the
        # sampled trajectories. Presumably a log-likelihood value (confirm);
        # not plotted yet.
        _joint_loglik = inference_engine.evaluate_joint_loglikelihood(batch_mc, K=k_decoding)

        # Sort every trajectory by x so per-index statistics align.
        idx_sort = torch.argsort(seq.xc, dim=1)
        seq.xc = torch.gather(seq.xc, 1, idx_sort)
        seq.yc = torch.gather(seq.yc, 1, idx_sort)

        # Mean and interquartile range across trajectories.
        mean_pred = seq.yc.mean(dim=0).cpu().numpy().flatten()
        q1 = seq.yc.quantile(0.25, dim=0).cpu().numpy().flatten()
        q3 = seq.yc.quantile(0.75, dim=0).cpu().numpy().flatten()
        x_plot = seq.xc[0, :, 0].cpu().numpy()

        ax.plot(xq_np, yq_np, linestyle="--", label="true function", color="blue")
        ax.plot(x_plot, mean_pred, label="pred mean", color="orange")
        ax.fill_between(x_plot, q1, q3, alpha=0.3, label="pred IQR", color="orange")
        ax.scatter(Xctx.cpu().numpy().squeeze(), Yctx.cpu().numpy().squeeze(), color="black", s=30, label="context")
        ax.set_ylabel("y")
        if row == num_rows - 1:
            ax.set_xlabel("x")
        ax.set_title(f"Row {row+1}: numctx {num_ctx_i}")

        # -- Second column: the first n_AR_samples_to_plot trajectories --
        ax2 = axes[row, 1]
        ax2.plot(xq_np, yq_np, linestyle="--", label="true function", color="blue")
        for trajectory in seq.yc[:n_AR_samples_to_plot, :, 0].cpu().numpy():
            ax2.plot(x_plot, trajectory, alpha=0.5)

        # -- Third column: placeholder --
        ax3 = axes[row, 2]
        ax3.axis('off')  # TODO: implement this subplot in column 3

    # A single legend for the first-column plots, placed below the rows.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.02))

    # Leave room for the legend (bottom) and the titles (top).
    fig.tight_layout(rect=[0, 0.03, 1, 0.92])
    fig.savefig(f"{name}_rows_prediction_plots.png", dpi=300)
    plt.show()



def _load_inference_engine(checkpoint_path: str, device: str, **engine_kwargs) -> InferenceEngine2:
    """Load a checkpoint and wrap it in an ``InferenceEngine2`` on ``device``.

    Extra keyword arguments are forwarded to ``InferenceEngine2.from_trained_model``.
    """
    model, _ = load_model(checkpoint_path, device)
    engine = InferenceEngine2.from_trained_model(model, **engine_kwargs)
    return engine.to(device)


def main():
    """Run the enabled ACE evaluation examples.

    A checkpoint is loaded only if an enabled example needs it, so the
    checkpoint files of disabled examples do not have to exist.
    """
    # --- Configuration ---
    DEVICE = "cpu"
    K_DECODING = 5  # decoding batch size
    RUN_GP_OFFLINE_EXAMPLE = False
    RUN_TOY_GP_FUNCTION_EXAMPLE = False
    RUN_TOY_SAWTOOTH_FUNCTION_EXAMPLE = True
    RUN_TOY_TRIANGLE_FUNCTION_EXAMPLE = False
    GP_CHECKPOINT_PATH = "checkpoints/toy_test/best_model_no_noise.pt"
    SAWTOOTH_CHECKPOINT_PATH = "checkpoints/sawtooth/best_model.pt"
    TRIANGLE_CHECKPOINT_PATH = "checkpoints/toy_test/best_model_triangle.pt"

    # --- Run Examples ---
    if RUN_GP_OFFLINE_EXAMPLE or RUN_TOY_GP_FUNCTION_EXAMPLE:
        gp_inference_engine = _load_inference_engine(GP_CHECKPOINT_PATH, DEVICE)

        if RUN_GP_OFFLINE_EXAMPLE:
            # Reads its own offline dataset; it takes no `true_func`.
            run_gp_offline_example(gp_inference_engine, DEVICE, K_DECODING)

        if RUN_TOY_GP_FUNCTION_EXAMPLE:
            run_toy_function_grid(gp_inference_engine, DEVICE, K_DECODING, true_func=simple_sin_cos_1d, name="gp")

    if RUN_TOY_SAWTOOTH_FUNCTION_EXAMPLE:
        sawtooth_inference_engine = _load_inference_engine(
            SAWTOOTH_CHECKPOINT_PATH, DEVICE, q_block_size=128, kv_block_size=128
        )
        run_toy_function_rows(sawtooth_inference_engine, DEVICE, K_DECODING, true_func=simple_sawtooth_1d, name="sawtooth", N=200, num_rows=5, num_ctx_range=[3, 8])

    if RUN_TOY_TRIANGLE_FUNCTION_EXAMPLE:
        triangle_inference_engine = _load_inference_engine(TRIANGLE_CHECKPOINT_PATH, DEVICE)
        run_toy_function_rows(triangle_inference_engine, DEVICE, K_DECODING, true_func=simple_triangular_1d, name="triangle", N=200, num_rows=5)


if __name__ == "__main__":
    main()