import argparse
import json
import logging
import math
import os
from typing import Callable
import random
import sys
import time
from dataclasses import dataclass
from typing import Dict, List
import matplotlib.ticker as ticker

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler

# Global matplotlib styling applied to every figure produced by this script.
# Optional font override, currently disabled:
# mpl.rcParams['font.family'] = 'serif'        # or 'sans-serif', 'monospace', etc.
# mpl.rcParams['font.serif'] = ['Times New Roman']  # preferred font list
mpl.rcParams['font.size'] = 9              # default font size
mpl.rcParams['axes.titlesize'] = 11         # axes title font size
mpl.rcParams['axes.labelsize'] = 9         # x and y labels
mpl.rcParams['xtick.labelsize'] = 9
mpl.rcParams['ytick.labelsize'] = 9
mpl.rcParams['legend.fontsize'] = 7

# Okabe-Ito colour-blind-safe palette (its seven non-black colours)
colors = ["#0072B2", "#D55E00", "#009E73", 
          "#F0E442", "#56B4E9", "#E69F00", "#CC79A7"]

# Default colour cycle for plot calls that do not pass an explicit color.
plt.rcParams['axes.prop_cycle'] = cycler(color=colors)


# --------------------------- Logging --------------------------- #

def setup_logging(out_dir: str, filename: str = "eval_synth_tr.log"):
    """Configure root logging to write to a log file and to stdout.

    Args:
        out_dir: Directory that receives the log file; created if missing.
        filename: Name of the log file inside ``out_dir``. The file is
            truncated on every call (mode ``"w"``).

    Returns:
        The full path of the log file.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, filename)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8: messages logged by this script contain non-ASCII
            # characters, which a locale-dependent default encoding can reject.
            logging.FileHandler(path, mode="w", encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
        # basicConfig is a silent no-op when the root logger already has
        # handlers (e.g. installed by an imported library); force=True replaces
        # them so the log file is always actually written.
        force=True,
    )
    logging.info(f"Logging to {path}")
    return path

def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# ---------------------- GP Prior (Synthetic) ------------------- #

def rbf_kernel(x1: torch.Tensor, x2: torch.Tensor, lengthscale: float, variance: float) -> torch.Tensor:
    """Squared-exponential kernel matrix between the rows of ``x1`` and ``x2``.

    Entry (i, j) is ``variance * exp(-0.5 * ||x1[i] - x2[j]||^2 / lengthscale^2)``.
    """
    # Broadcast to all pairs: [N, 1, D] - [1, M, D] -> [N, M, D]
    pairwise = x1[:, None] - x2[None]
    sq_dist = pairwise.pow(2).sum(dim=-1)  # [N, M]
    exponent = -0.5 * sq_dist / (lengthscale ** 2)
    return variance * torch.exp(exponent)

def get_gp_prior(num_datasets: int, num_points: int, num_features: int, hp: dict, device: torch.device):
    """Sample datasets from a GP prior with an RBF kernel.

    Inputs are drawn uniformly from [0, 1); for each dataset the targets are
    one draw from a multivariate normal whose covariance is the RBF kernel
    matrix plus ``output_noise**2`` on the diagonal and whose mean is the
    constant ``mean_shift``.

    Args:
        num_datasets: Number of independent datasets to draw.
        num_points: Points per dataset.
        num_features: Input dimensionality.
        hp: Hyper-parameters; the keys ``lengthscale``, ``kernel_variance``,
            ``output_noise`` and ``mean_shift`` are read, with fallbacks
            0.6, 0.01, 1e-2 and 1.0 respectively.
        device: Device on which all tensors are created.

    Returns:
        ``(xs, ys)`` with shapes ``[num_datasets, num_points, num_features]``
        and ``[num_datasets, num_points]``.
    """
    xs = torch.rand(num_datasets, num_points, num_features, device=device)

    lengthscale = hp.get('lengthscale', 0.6)
    kernel_variance = hp.get('kernel_variance', 0.01)
    noise_variance = hp.get('output_noise', 1e-2) ** 2
    # These do not depend on the dataset, so build them once.
    noise_cov = noise_variance * torch.eye(num_points, device=device)
    prior_mean = torch.full((num_points,), hp.get('mean_shift', 1.0), device=device)

    samples = []
    for inputs in xs:
        cov = rbf_kernel(inputs, inputs, lengthscale, kernel_variance) + noise_cov
        prior = torch.distributions.MultivariateNormal(prior_mean, cov)
        samples.append(prior.sample())
    return xs, torch.stack(samples, dim=0)  # ys: [B, N]

# ----------------- Bucket helpers (as in training) ------------- #

@dataclass
class BucketInfo:
    """Bucket borders together with the per-bucket midpoints and widths."""
    limits: torch.Tensor   # [K+1] bucket borders
    mids: torch.Tensor     # [K]   midpoint of each bucket
    widths: torch.Tensor   # [K]   width of each bucket

    @staticmethod
    def from_limits(limits: torch.Tensor) -> "BucketInfo":
        """Derive midpoints and widths from a 1-D tensor of K+1 borders."""
        lower, upper = limits[:-1], limits[1:]
        return BucketInfo(
            limits=limits,
            mids=(upper + lower) / 2.0,
            widths=upper - lower,
        )

def compute_bucket_nll(logits: torch.Tensor, y: torch.Tensor, bl: torch.Tensor) -> torch.Tensor:
    """Mean negative log-likelihood of ``y`` under a bucketised density.

    The density is piecewise constant inside each bucket (softmax probability
    divided by the bucket width). In the first and last bucket the uniform
    density is replaced by a half-normal that starts at the inner border of
    that bucket and extends outwards.

    Args:
        logits: Unnormalised bucket scores; the last dimension has one entry
            per bucket (``len(bl) - 1``).
        y: Targets, one per leading index of ``logits``.
        bl: 1-D tensor of bucket borders.

    Returns:
        Scalar tensor: the negative mean log-density of ``y``.
    """
    num_buckets = len(bl) - 1

    # Map each target to its bucket index; out-of-range targets go to the
    # first / last bucket.
    target_idx = torch.searchsorted(bl, y) - 1
    target_idx[y <= bl[0]] = 0
    target_idx[y >= bl[-1]] = num_buckets - 1

    widths = bl[1:] - bl[:-1]
    log_probs = torch.log_softmax(logits, dim=-1) - torch.log(widths)

    gathered = log_probs.gather(-1, target_idx[..., None]).squeeze(-1)

    # edge corrections (half-normal)
    def halfnorm(range_max, p=0.5):
        # Scale chosen so that a fraction p of the half-normal's mass lies
        # within range_max of its origin.
        unit = torch.distributions.HalfNormal(torch.tensor(1., device=y.device))
        s = range_max / unit.icdf(torch.tensor(p, device=y.device))
        return torch.distributions.HalfNormal(s)

    hn_left, hn_right = halfnorm(widths[0]), halfnorm(widths[-1])
    # The clamp belongs on the distance (keeping it inside the half-normal's
    # support), not on the log-density: clamping the log-density would replace
    # every negative value with ~0 and corrupt the likelihood.
    min_dist = 1e-9
    first = target_idx == 0
    if first.any():
        dist_left = (bl[1] - y[first]).clamp(min=min_dist)
        # + log(width) undoes the uniform-density term subtracted above.
        gathered[first] += hn_left.log_prob(dist_left) + torch.log(widths[0])
    last = target_idx == num_buckets - 1
    if last.any():
        dist_right = (y[last] - bl[-2]).clamp(min=min_dist)
        gathered[last] += hn_right.log_prob(dist_right) + torch.log(widths[-1])

    return -gathered.mean()

def bucket_expectation(logits: torch.Tensor, bi: BucketInfo) -> torch.Tensor:
    """Predictive mean: bucket midpoints weighted by softmax probabilities.

    The last dimension of ``logits`` (one entry per bucket) is reduced away.
    """
    bucket_probs = torch.softmax(logits, dim=-1)
    weighted_mids = bucket_probs * bi.mids
    return weighted_mids.sum(dim=-1)

def bucket_variance(logits: torch.Tensor, bi: "BucketInfo") -> torch.Tensor:
    """Predictive variance of the bucketised distribution.

    Uses the law of total variance, treating the density inside each bucket
    as uniform: the squared distance of each bucket midpoint from the overall
    mean (between-bucket part) plus the variance of a uniform distribution
    over the bucket (within-bucket part), both weighted by the bucket
    probabilities.

    Args:
        logits: Unnormalised bucket scores; last dimension indexes buckets.
        bi: Bucket description providing ``mids`` and ``widths``.

    Returns:
        Tensor with the bucket dimension reduced away.
    """
    probs = torch.softmax(logits, dim=-1)
    mean = torch.sum(probs * bi.mids, dim=-1)
    spread = (bi.mids - mean.unsqueeze(-1)) ** 2
    # Variance of a uniform distribution of width w is w**2 / 12.
    within = (bi.widths ** 2) / 12.0
    return torch.sum(probs * (spread + within), dim=-1)

# --------------------------- Models --------------------------- #
# Paste the model to be evaluated here (You can copy the codes from model files)

def uniform_normalize(x: torch.Tensor) -> torch.Tensor:
    """Centre inputs at 0.5 and divide by ``sqrt(1/9)``.

    NOTE(review): the standard deviation of U(0, 1) is sqrt(1/12); the 1/9
    used here presumably mirrors the scaling the model was trained with, so
    confirm against the training code before changing it.
    """
    scale = math.sqrt(1/9)
    centred = x - 0.5
    return centred / scale

class Encoder(nn.Module):
    """Transformer encoder layer with a train/test attention split.

    Training (context) positions attend to each other; test (query) positions
    attend only to the training positions. Both use the same attention
    weights, followed by a residual + LayerNorm and a two-layer feed-forward
    block with its own residual + LayerNorm.

    Args:
        d_model: Embedding width.
        n_heads: Number of attention heads.
        n_hidden: Hidden width of the feed-forward block.
        dropout: Dropout probability used in attention and feed-forward.
    """
    def __init__(self, d_model, n_heads, n_hidden, dropout=0.0):
        super().__init__()
        # Referenced through ``nn``: the bare name ``MultiheadAttention`` is
        # not imported anywhere in this file and would raise NameError.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, d_model),
        )

        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, trainset_size):
        """Apply the layer.

        Args:
            src: Batch-first sequence whose first ``trainset_size`` positions
                along dim 1 are training points and the rest test points.
            trainset_size: Number of leading training positions.

        Returns:
            Tensor with the same shape as ``src``.
        """
        src_left, src_right = src[:, :trainset_size], src[:, trainset_size:]
        x_left = self.self_attn(src_left, src_left, src_left)[0]  # all train points to each other
        x_right = self.self_attn(src_right, src_left, src_left)[0]  # test points attend to train points
        x = torch.cat([x_left, x_right], dim=1)
        x = self.norm1(src + self.dropout(x))
        return self.norm2(self.dropout(self.out(x)) + x)


class Transformer(nn.Module):
    """Stack of ``Encoder`` layers mapping (x, y) context plus x queries to per-point outputs."""

    def __init__(self, num_features, n_out, n_layers=2, d_model=512, n_heads=4, n_hidden=1024, dropout=0.0, normalize=lambda x:x):
        super().__init__()

        # Separate linear embeddings for inputs and (scalar) targets.
        self.x_encoder = nn.Linear(num_features, d_model)
        self.y_encoder = nn.Linear(1, d_model)

        layers = [Encoder(d_model, n_heads, n_hidden, dropout) for _ in range(n_layers)]
        self.model = nn.ModuleList(layers)

        # Output head producing n_out values per position.
        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_out)
        )

        self.normalize = normalize
        self.init_weights()

    def forward(self, x, y, trainset_size):
        """Run the model on a batch of datasets.

        Args:
            x: num_datasets x number_of_points x num_features
            y: num_datasets x number_of_points x 1 (fed to a ``Linear(1, d_model)``
                layer); only the first ``trainset_size`` points are used
            trainset_size: int specifying the number of leading points to use
                as the training (context) set

        Returns:
            outputs for each x, with ``n_out`` values in the last dimension
        """
        x_emb = self.x_encoder(self.normalize(x))
        y_emb = self.y_encoder(y)

        # Context positions carry x and y information, query positions x only.
        context = x_emb[:, :trainset_size] + y_emb[:, :trainset_size]
        queries = x_emb[:, trainset_size:]
        hidden = torch.cat([context, queries], dim=1)
        for layer in self.model:
            hidden = layer(hidden, trainset_size)

        return self.out(hidden)

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

# --------------------------- Metrics --------------------------- #

def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return MSE, MAE, MAPE (in percent) and maximum absolute error.

    For MAPE the denominator ``|y_true|`` is clipped from below at 1e-9 to
    avoid division by zero.
    """
    eps = 1e-9
    residual = y_true - y_pred
    abs_residual = np.abs(residual)
    safe_denominator = np.clip(np.abs(y_true), eps, None)
    relative_error = np.abs(residual / safe_denominator)
    return {
        "mse": float(np.mean(residual ** 2)),
        "mae": float(np.mean(abs_residual)),
        "mape_pct": float(np.mean(relative_error) * 100.0),
        "max_error": float(np.max(abs_residual)),
    }

# --------------------- Single-dataset eval --------------------- #

def eval_once_models_and_gp(X: np.ndarray, Y: np.ndarray, n_train: int, device: torch.device,
                            tr: nn.Module, bi: BucketInfo) -> Dict:
    """Evaluate the transformer and a fitted sklearn GP on one dataset.

    The points are shuffled, the first ``n_train`` become the training
    (context) set and the remainder the test set. The transformer predicts
    the test points in a single forward pass; the GP is fitted on the
    training points and then predicts the test points.

    Args:
        X: Inputs, shape ``[N, D]``.
        Y: Targets, shape ``[N]``.
        n_train: Number of training points; must satisfy ``1 <= n_train < N``.
        device: Device on which the transformer runs.
        tr: Transformer called as ``tr(x, y, trainset_size=...)`` and
            returning bucket logits.
        bi: Bucket description matching the transformer's output buckets.

    Returns:
        Dict with the permutation (``order``), ``n_train``, the test targets
        (``Y_test``), both models' test predictions (``pred_tr``,
        ``pred_gp``), per-model ``metrics`` and wall-clock ``timings_ms``.

    Raises:
        ValueError: If ``n_train`` is outside ``[1, N)``.
    """
    N = X.shape[0]
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not 1 <= n_train < N:
        raise ValueError(f"n_train must satisfy 1 <= n_train < N (got n_train={n_train}, N={N})")

    # random order
    order = np.random.permutation(N)
    Xo, Yo = X[order], Y[order]
    X_train, Y_train = Xo[:n_train], Yo[:n_train]
    X_test, Y_test = Xo[n_train:], Yo[n_train:]

    # tensors for NN: add a batch dimension, and a trailing feature dim for y
    x_t = torch.from_numpy(Xo).float().unsqueeze(0).to(device)
    y_t = torch.from_numpy(Yo).float().unsqueeze(0).unsqueeze(-1).to(device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so that the
        # wall-clock timestamps bracket the actual computation.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # TR
    with torch.no_grad():
        _sync()  # do not bill pending host-to-device copies to the model
        t0 = time.perf_counter()
        logits_tr = tr(x_t, y_t, trainset_size=n_train)
        _sync()
        t1 = time.perf_counter()
        time_tr_ms = (t1 - t0) * 1e3
        mean_tr = bucket_expectation(logits_tr, bi)[0, n_train:].detach().cpu().numpy()
    y_test_t = torch.from_numpy(Y_test).float().to(device)
    nll_tr = float(compute_bucket_nll(logits_tr[0, n_train:], y_test_t, bi.limits).item())

    # Initial kernel hyper-parameters for the GP baseline; they are starting
    # points that the optimiser refines within the given bounds during fit.
    true_signal_var = 0.001
    true_noise_var = (1e-4)**2
    kernel = ConstantKernel(constant_value=true_signal_var, constant_value_bounds=(1e-5, 1e1)) \
        * RBF(length_scale=0.6, length_scale_bounds=(1e-2, 1e1)) \
        + WhiteKernel(noise_level=true_noise_var, noise_level_bounds=(1e-9, 1e-1))

    # GP baseline (sklearn)
    gp = GaussianProcessRegressor(kernel=kernel, normalize_y=False, random_state=0, n_restarts_optimizer=15)
    t0 = time.perf_counter()
    gp.fit(X_train, Y_train)
    t1 = time.perf_counter()
    gp_train_ms = (t1 - t0) * 1e3
    t0 = time.perf_counter()
    # The std is requested so the timing covers a full predictive
    # distribution; its value is not used further.
    gp_mean, _gp_std = gp.predict(X_test, return_std=True)
    t1 = time.perf_counter()
    gp_pred_ms = (t1 - t0) * 1e3

    # Metrics
    m_tr = compute_regression_metrics(Y_test, mean_tr)
    m_gp = compute_regression_metrics(Y_test, gp_mean)

    m_tr["bucket_nll"] = nll_tr

    return {
        "order": order,
        "n_train": n_train,
        "Y_test": Y_test,
        "pred_tr": mean_tr,
        "pred_gp": gp_mean,
        "metrics": {"tr": m_tr, "gp": m_gp},
        "timings_ms": {
            "tr_infer": time_tr_ms,
            "gp_train": gp_train_ms,
            "gp_predict": gp_pred_ms,
            "gp_total": gp_train_ms + gp_pred_ms,
        },
    }

# ------------------------- Predefined funcs -------------------- #

def make_nd_functions(d: int):
    """Build the fixed test functions for ``d``-dimensional inputs.

    Returns a dict mapping ``"linear"``, ``"quadratic"`` and ``"trig"`` to
    callables that take ``x`` of shape [N, D] and return ``y`` of shape [N].
    Each output is shifted and rescaled per call to mean 1.0 and standard
    deviation ~0.05.
    """
    def standardise(values, target_std=0.05, target_mean=1.0):
        centred = values - values.mean()
        spread = centred.std() + 1e-9  # guard against a constant function
        return centred / spread * target_std + target_mean

    def linear(x):
        return x.sum(dim=1) / math.sqrt(d)

    def quadratic(x):
        return ((x - 0.5) ** 2).sum(dim=1)

    def trig(x):
        # average of sin & cos terms
        angle = 2*math.pi*x
        return (torch.sin(angle).sum(dim=1) + torch.cos(angle).sum(dim=1)) / (2*d)

    def rescaled(raw_fn):
        # Factory so that each wrapper closes over its own raw function.
        def wrapped(x):
            return standardise(raw_fn(x))
        return wrapped

    raw_functions = (("linear", linear), ("quadratic", quadratic), ("trig", trig))
    return {name: rescaled(fn) for name, fn in raw_functions}

# ------------------------------ Main --------------------------- #

def _plot_predictions(y_true, pred_tr, pred_gp, title, out_dir, stem):
    """Plot true test targets against TR and GP predictions; save PNG and PDF.

    Returns the ``(png_path, pdf_path)`` pair.
    """
    idxs = np.arange(len(y_true))
    plt.figure(figsize=(3.6, 2.6))
    plt.plot(idxs, y_true, label="True Y", linewidth=2)
    plt.plot(idxs, pred_tr, label="TR", linestyle="--")
    plt.plot(idxs, pred_gp, label="Exact GP", linestyle="-.")
    plt.xlabel("Test sample index")
    plt.ylabel("Y")
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    png = os.path.join(out_dir, f"{stem}.png")
    pdf = os.path.join(out_dir, f"{stem}.pdf")
    plt.savefig(png, dpi=200)
    plt.savefig(pdf)
    plt.close()
    return png, pdf


def main():
    """Evaluate the transformer against a GP baseline and write all artifacts.

    Steps: parse the CLI, load the bucket limits and model weights, sweep the
    number of training points on synthetic GP-prior datasets, plot the
    averaged metrics and timings, evaluate the predefined ND functions, and
    write a JSON summary plus a log file into ``--out_dir``.
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    log_path = setup_logging(args.out_dir)
    set_seed(args.seed)
    device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
    logging.info(f"Using device: {device}")

    hp = {
        "lengthscale": args.hp_lengthscale,
        "kernel_variance": args.hp_kernel_variance,
        "output_noise": args.hp_output_noise,
        "mean_shift": args.hp_mean_shift,
    }

    # Load bucket limits.
    # NOTE: torch.load unpickles its input; only point it at trusted files.
    bl = torch.load(args.bucket_path, map_location=device)
    if bl.ndim != 1:
        bl = bl.view(-1)
    bi = BucketInfo.from_limits(bl.to(device))
    K = bi.mids.numel()
    logging.info(f"Loaded bucket limits with K={K}, y-range≈[{float(bl.min()):.4f}, {float(bl.max()):.4f}]")

    # Build the model. The architecture flags must match the values used
    # during training, otherwise load_state_dict fails. The output width is
    # the number of buckets because the logits are combined with bi.mids.
    tr = Transformer(
        args.n_features, n_out=K,
        d_model=args.d_model, n_layers=args.n_layers,
        n_hidden=args.n_hidden, n_heads=args.n_heads,
        normalize=uniform_normalize,
    ).to(device)
    tr.load_state_dict(torch.load(args.model_path, map_location=device))
    tr.eval()

    n_train_grid: List[int] = [int(x) for x in args.n_train_grid.split(",") if x.strip()]
    # Keep only sizes that leave at least one test point.
    n_train_grid = [n for n in n_train_grid if 1 <= n < args.N]
    if not n_train_grid:
        logging.warning("No valid n_train values remain for N=%d; metric curves will be empty", args.N)
    logging.info(f"n_train grid: {n_train_grid} | datasets per n: {args.num_datasets} | N_total={args.N}")

    # Storage: each curve is a list of (n_train, averaged value) pairs.
    curves = {
        "MSE": {"tr": [], "gp": []},
        "MAE": {"tr": [], "gp": []},
        "MAX": {"tr": [], "gp": []},
        "time_ms": {"tr": [], "gp_train": [], "gp_predict": []},
    }

    # One representative plot data
    rep_plot = None
    rep_n = None

    # Loop over n_train
    for idx_n, n_train in enumerate(n_train_grid):
        mses = {"tr": [], "gp": []}
        maes = {"tr": [], "gp": []}
        maxe = {"tr": [], "gp": []}
        tms = {"tr": [], "gp_train": [], "gp_predict": []}

        # Generate synthetic datasets in batches of 1 to keep things simple
        for j in range(args.num_datasets):
            xs, ys = get_gp_prior(1, args.N, args.n_features, hp, device=device)
            X = xs[0].detach().cpu().numpy()
            Y = ys[0].detach().cpu().numpy()

            out = eval_once_models_and_gp(X, Y, n_train, device, tr, bi)
            for m in ("tr", "gp"):
                mses[m].append(out["metrics"][m]["mse"])
                maes[m].append(out["metrics"][m]["mae"])
                maxe[m].append(out["metrics"][m]["max_error"])
            tms["tr"].append(out["timings_ms"]["tr_infer"])
            tms["gp_train"].append(out["timings_ms"]["gp_train"])
            tms["gp_predict"].append(out["timings_ms"]["gp_predict"])

            # save representative plot data from first dataset for first n_train
            if idx_n == 0 and j == 0:
                rep_plot = out
                rep_n = n_train

        # average & store
        for m in ["tr", "gp"]:
            curves["MSE"][m].append((n_train, float(np.mean(mses[m]))))
            curves["MAE"][m].append((n_train, float(np.mean(maes[m]))))
            curves["MAX"][m].append((n_train, float(np.mean(maxe[m]))))
        for k in ["tr", "gp_train", "gp_predict"]:
            curves["time_ms"][k].append((n_train, float(np.mean(tms[k]))))

        logging.info(f"n_train={n_train} | MSE avg: TR={np.mean(mses['tr']):.5e} GP={np.mean(mses['gp']):.5e}")
        logging.info(
            f"Timings (ms) avg @ n_train={n_train}: "
            f"TR infer={np.mean(tms['tr']):.3f}, "
            f"GP train={np.mean(tms['gp_train']):.3f}, "
            f"GP predict={np.mean(tms['gp_predict']):.3f}, "
            f"GP total={np.mean([a+b for a,b in zip(tms['gp_train'], tms['gp_predict'])]):.3f}"
        )

    # ----------------- Plots: metrics vs n_train ----------------- #

    def plot_metric_grid(metric_key: str, outfile_base: str):
        """Plot one averaged metric for TR and GP against n_train; save PNG and PDF."""
        plt.figure(figsize=(3.6, 2.6))
        label_map = {"tr": "TR", "gp": "GP"}
        color_map = {"tr": "red", "gp": "green"}
        for name in ["tr", "gp"]:
            xs = [a for a, _ in curves[metric_key][name]]
            ys = [b for _, b in curves[metric_key][name]]
            plt.plot(
                xs, ys,
                color=color_map[name],
                label=label_map[name],
                linewidth=1.2,
            )
        plt.xlabel(r"$N_{\text{context}}$")
        plt.ylabel(metric_key)
        # Title reflects the actual input dimensionality instead of a fixed "10D".
        plt.title(f"{args.n_features}D")
        plt.legend()
        ax = plt.gca()
        ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
        ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))  # always sci notation
        ax.yaxis.get_offset_text().set_fontsize(8)  # smaller offset text

        plt.tight_layout()
        png = os.path.join(args.out_dir, f"{outfile_base}.png")
        pdf = os.path.join(args.out_dir, f"{outfile_base}.pdf")
        plt.savefig(png, dpi=200)
        plt.savefig(pdf)
        plt.close()
        return png, pdf

    mse_png, mse_pdf = plot_metric_grid("MSE", "mse_vs_n_train")
    mae_png, mae_pdf = plot_metric_grid("MAE", "mae_vs_n_train")
    max_png, max_pdf = plot_metric_grid("MAX", "maxerr_vs_n_train")

    # Optional: timings vs n_train figure (not requested but useful)
    plt.figure(figsize=(3.6, 2.6))
    for key, label in [("tr", "TR infer"), ("gp_train", "GP train"), ("gp_predict", "GP predict")]:
        xs = [a for a, _ in curves["time_ms"][key]]
        ys = [b for _, b in curves["time_ms"][key]]
        plt.plot(xs, ys, marker='o', label=label)
    plt.xlabel("Number of training points (n_train)")
    plt.ylabel("Time (ms)")
    plt.title("Timing vs n_train")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    t_png = os.path.join(args.out_dir, "timings_vs_n_train.png")
    t_pdf = os.path.join(args.out_dir, "timings_vs_n_train.pdf")
    plt.savefig(t_png, dpi=200)
    plt.savefig(t_pdf)
    plt.close()

    # --------------- Representative predictions plot ------------- #
    if rep_plot is not None:
        pv_png, pv_pdf = _plot_predictions(
            rep_plot["Y_test"], rep_plot["pred_tr"], rep_plot["pred_gp"],
            title=f"Predictions vs True (synthetic GP) | n_train={rep_n}",
            out_dir=args.out_dir,
            stem=f"predictions_vs_true_ntrain_{rep_n}",
        )
    else:
        pv_png = pv_pdf = None

    # ------------------- Predefined ND functions ------------------ #
    funcs = make_nd_functions(args.n_features)
    func_artifacts = {}
    # 70/30 split by default with at least 5 training points, but always
    # leaving at least one test point (n_train_func < N).
    n_train_func = min(args.N - 1, max(5, int(0.7 * args.N)))

    for fname, f in funcs.items():
        # Build dataset
        X = torch.rand(args.N, args.n_features)
        Y = f(X).detach().cpu().numpy()
        X = X.detach().cpu().numpy()

        out = eval_once_models_and_gp(X, Y, n_train_func, device, tr, bi)

        f_png, f_pdf = _plot_predictions(
            out["Y_test"], out["pred_tr"], out["pred_gp"],
            title=f"{fname} function | n_train={n_train_func}",
            out_dir=args.out_dir,
            stem=f"{fname}_pred_vs_true",
        )

        func_artifacts[fname] = {
            "metrics": out["metrics"],
            "plot_png": f_png, "plot_pdf": f_pdf
        }

    # ------------------------- Save summary ----------------------- #
    summary = {
        "n_features": args.n_features,
        "N_total": args.N,
        "grid": n_train_grid,
        "datasets_per_n": args.num_datasets,
        "timing_curves_ms": curves["time_ms"],
        "metric_curves": curves,
        "artifacts": {
            "mse_vs_n_train_png": mse_png, "mse_vs_n_train_pdf": mse_pdf,
            "mae_vs_n_train_png": mae_png, "mae_vs_n_train_pdf": mae_pdf,
            "maxerr_vs_n_train_png": max_png, "maxerr_vs_n_train_pdf": max_pdf,
            "timings_vs_n_train_png": t_png, "timings_vs_n_train_pdf": t_pdf,
            "rep_predictions_png": pv_png, "rep_predictions_pdf": pv_pdf,
        },
        "functions": func_artifacts,
        "log_file": log_path
    }
    with open(os.path.join(args.out_dir, "metrics_synth.json"), "w") as fh:
        json.dump(summary, fh, indent=2)

    logging.info("=== Completed synthetic ND evaluation ===")
    logging.info(json.dumps({
        "mse_vs_n_train": curves["MSE"],
        "mae_vs_n_train": curves["MAE"],
        "max_vs_n_train": curves["MAX"]
    }, indent=2))
    logging.info(f"Artifacts in {args.out_dir}")


def parse_args():
    """Parse the command-line options of the evaluation script."""
    p = argparse.ArgumentParser(description="Evaluate Transformer vs Exact GP on synthetic ND data and ND functions")
    p.add_argument("--model_path", type=str, default = "/model_path")
    p.add_argument("--bucket_path", type=str, default = "/buckets_path")
    p.add_argument("--out_dir", type=str, default = "10D_trans_dva_eval")
    p.add_argument("--n_features", type=int, default = 10)
    p.add_argument("--N", type=int, default=500, help="Total points per dataset")
    p.add_argument("--num_datasets", type=int, default=100, help="Datasets to average per n_train")
    p.add_argument("--n_train_grid", type=str, default="100,150,200,250,300,350,400,420,450", help="Comma-separated grid for n_train")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default="cpu")
    # Model architecture (must match the values used during training)
    p.add_argument("--d_model", type=int, default=512, help="Transformer embedding width")
    p.add_argument("--n_layers", type=int, default=2, help="Number of encoder layers")
    p.add_argument("--n_hidden", type=int, default=1024, help="Feed-forward hidden width")
    p.add_argument("--n_heads", type=int, default=4, help="Number of attention heads")
    # GP prior hyperparams
    p.add_argument("--hp_lengthscale", type=float, default=0.6)
    p.add_argument("--hp_kernel_variance", type=float, default=0.001)
    p.add_argument("--hp_output_noise", type=float, default=1e-4)
    p.add_argument("--hp_mean_shift", type=float, default=1.0)
    return p.parse_args()

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()