import os
import yaml
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import sys
sys.path.insert(0, '/raid/home/q615005/xfac/build/python')
import xfacpy

from stylesheet import *
from hamiltonian_xy_transverse import build_xy_oracles, batch_eval_jax
from config_loader import cfg
from gpr import run_passive_gpr, run_active_gpr

# ----------------------------
# 1) TCI Wrapper (Active Learning)
# ----------------------------
def run_tci(f_noisy, qgrid, budget: int, nBit: int, dim: int, bondDim: int, reltol: float):
    """Run TCI approximation with given budget and parameters.

    Builds an ``xfacpy.QTensorCI`` on ``qgrid`` and sweeps it until at least
    ``budget`` distinct points have been sampled or the solver reports that it
    is done. The budget is only checked between sweeps, so the final number of
    samples can exceed it.

    Parameters
    ----------
    f_noisy : callable
        Oracle called with one grid point as a 1-D float array. Its return
        value is handed back to xfacpy unchanged. Results are memoized per
        point (coordinates rounded to 14 decimals), so each distinct point is
        evaluated only once.
    qgrid : xfacpy.QuanticsGrid
        Grid on which the tensor cross interpolation is built.
    budget : int
        Soft cap on the number of distinct sampled points.
    nBit, dim : int
        Quantics bits per dimension and number of dimensions. The initial
        pivot is the all-ones vector of length ``nBit * dim``.
    bondDim : int
        Passed to ``TensorCI2Param.bondDim``.
    reltol : float
        Passed to ``TensorCI2Param.reltol``.

    Returns
    -------
    predictor : callable
        Maps an iterable of points to a float ndarray of QTT values.
    sampled_pts : ndarray
        Distinct points at which ``f_noisy`` was evaluated, one per row. The
        shape is ``(0, dim)`` if nothing was sampled.
    """
    print(f"[TCI] Starting TCI with budget={budget}, bondDim={bondDim}")

    args = xfacpy.TensorCI2Param()
    args.bondDim = int(bondDim)
    args.reltol = float(reltol)
    args.ncheckhistory = 500

    # xfacpy uses quantics bits; pivot1 length is nBit*dim
    args.pivot1 = [1] * (nBit * dim)

    # Rounded point -> oracle value. Also serves as the distinct-sample counter.
    sampled = {}

    def fcaching(arg):
        a = np.asarray(arg, dtype=float)
        key = tuple(np.round(a, 14))  # stable hashing
        if key not in sampled:
            # Only hit the (possibly expensive) oracle for unseen points.
            sampled[key] = f_noisy(a)
        return sampled[key]

    print(f"[TCI] Initializing (nBit={nBit}, dim={dim}, bondDim={args.bondDim}, budget={budget}) ...")
    ci = xfacpy.QTensorCI(f=fcaching, qgrid=qgrid, args=args)

    report_step = 100
    next_report = report_step
    while len(sampled) < budget and not ci.isDone():
        ci.iterate()
        n_sampled = len(sampled)
        # One sweep may add many samples. Report whenever a 100-boundary is
        # crossed, rather than only when the count lands exactly on one.
        if n_sampled >= next_report:
            print(f"[TCI] Progress: {n_sampled}/{budget} samples collected")
            next_report = (n_sampled // report_step + 1) * report_step

    print(f"[TCI] Collected {len(sampled)} samples.")
    qtt = ci.get_qtt()

    def predictor(Xq):
        return np.array([qtt.eval(x) for x in Xq], dtype=float)

    sampled_pts = np.array(list(sampled), dtype=float) if sampled else np.zeros((0, dim))
    return predictor, sampled_pts


# ----------------------------
# 2) Evaluation grid utilities
# ----------------------------
def make_eval_grid(nBit: int):
    """
    Build the uniform 2**nBit x 2**nBit evaluation grid on [0, 1) x [0, 1).

    The upper endpoint 1.0 is excluded (``endpoint=False``), so each axis holds
    k / 2**nBit for k = 0 .. 2**nBit - 1. This avoids boundary effects and
    lines up with quantics grids.

    Returns
    -------
    axis : ndarray, shape (N,)
        1-D grid coordinates, with N = 2**nBit.
    H, G : ndarray, shape (N, N)
        ``np.meshgrid(axis, axis, indexing="xy")``: H varies along columns and
        G varies along rows.
    pts : ndarray, shape (N*N, 2)
        Row k is ``(H.flat[k], G.flat[k])`` in C order, so the first
        coordinate runs fastest.
    """
    print(f"[Grid] Creating evaluation grid with {2**nBit}x{2**nBit} points")

    side = 2 ** int(nBit)
    coords = np.linspace(0.0, 1.0, side, endpoint=False)
    h_mesh, g_mesh = np.meshgrid(coords, coords, indexing="xy")

    # (N, N, 2) -> one (h, g) pair per row, row-major order
    points = np.stack((h_mesh, g_mesh), axis=-1).reshape(-1, 2)
    return coords, h_mesh, g_mesh, points


# ----------------------------
# 3) Publishable plotting
# ----------------------------
def plot_publishable_figure_compact(
    axis, Z_true, Z_tci, Z_gp, xi,
    tci_samples, gp_samples,
    res_dir: str,
    gpr_mode: str = "passive",
):
    """Create Figure 2 - TCI vs GPR benchmark comparison.

    Draws four panels with two colorbars and saves ``<res_dir>/tci_vs_gpr.pdf``:

    (a) true gap;
    (b) TCI prediction, with the TCI sample locations overlaid;
    (c) mean absolute error of TCI and GPR, binned by correlation length xi;
    (d) ``|Z_gp - Z_true| - |Z_tci - Z_true|`` (> 0 where TCI is more accurate).

    Parameters
    ----------
    axis : sequence
        1-D grid axis. Only ``len(axis)`` (= N) is used, to reshape the flat
        fields below to (N, N).
    Z_true, Z_tci, Z_gp, xi : array-like
        Fields with N*N entries. They are reshaped row-major and drawn with
        ``origin="lower"``, so rows map to gamma and columns to h. This
        presumably matches the point order of ``make_eval_grid`` (TODO confirm).
    tci_samples : ndarray, shape (n, 2)
        TCI sample points in the unit square. Columns are mapped linearly to
        (h, gamma) using the H_MIN/H_MAX/GAMMA_MIN/GAMMA_MAX ranges from ``cfg()``.
    gp_samples : ndarray
        Not drawn in this figure. Kept so the signature parallels
        ``plot_error_comparison``.
    res_dir : str
        Output directory. It is not created here.
    gpr_mode : str
        Only used in the panel (c) legend label.
    """
    print("[Figure 2] Creating TCI vs GPR benchmark plot")

    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    # Load parameter ranges from config
    config = cfg()
    h_min, h_max = config['H_MIN'], config['H_MAX']
    gamma_min, gamma_max = config['GAMMA_MIN'], config['GAMMA_MAX']
    extent = [h_min, h_max, gamma_min, gamma_max]

    # -------------------------
    # Reshape flat fields into (N, N) images
    # -------------------------
    N = len(axis)
    Z_true = np.asarray(Z_true, dtype=float).reshape(N, N)
    Z_tci = np.asarray(Z_tci, dtype=float).reshape(N, N)
    Z_gp = np.asarray(Z_gp, dtype=float).reshape(N, N)
    xi_arr = np.asarray(xi, dtype=float).reshape(N, N)

    err_tci = np.abs(Z_true - Z_tci)
    err_gp = np.abs(Z_true - Z_gp)
    err_diff = err_gp - err_tci   # >0 => TCI better

    # -------------------------
    # Color ranges
    # -------------------------
    gap_vmin = max(0.0, float(np.min(Z_true)))
    gap_vmax = float(np.percentile(Z_true, 99.5))

    diff_vmax = float(np.percentile(np.abs(err_diff), 99.0))
    if diff_vmax <= 0.0:
        # Degenerate case (the errors agree almost everywhere): fall back to
        # the full range so the diverging colormap still gets vmin < vmax.
        diff_vmax = max(float(np.max(np.abs(err_diff))), np.finfo(float).eps)

    # Reference curve h = 1 - gamma^2, restricted to the plotted h-range
    gamma_parab = np.linspace(gamma_min, gamma_max, 200)
    h_parab = 1 - gamma_parab**2
    in_view = (h_parab >= h_min) & (h_parab <= h_max)

    def _decorate_phase_panel(ax, title, **parab_kwargs):
        """Title/label an (h, gamma) map and overlay the h=1 line and h=1-gamma^2 curve."""
        ax.set_title(title)
        ax.set_xlabel("Field $h$")
        ax.set_ylabel("Anisotropy $\\gamma$")
        ax.axvline(1.0, ls="--", lw=line_width/2, color="white")
        ax.plot(h_parab[in_view], gamma_parab[in_view],
                color='black', linewidth=1.5, linestyle=':', **parab_kwargs)

    def _binned_mean(x, y, edges):
        """Mean of y over x-bins [e_i, e_{i+1}); the last bin includes its right edge. NaN if empty."""
        means = np.full(len(edges) - 1, np.nan)
        last = len(edges) - 2
        for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
            upper_ok = (x <= hi) if i == last else (x < hi)
            in_bin = (x >= lo) & upper_ok
            if np.any(in_bin):
                means[i] = float(np.mean(y[in_bin]))
        return means

    # -------------------------
    # Layout
    # -------------------------
    fig = plt.figure(figsize=(12.8, 7))
    gs = fig.add_gridspec(
        2, 3,
        width_ratios=[1.0, 1.0, 0.05],
        wspace=0.32,
        hspace=0.35
    )

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    cax_gap = fig.add_subplot(gs[0, 2])

    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    cax_err = fig.add_subplot(gs[1, 2])

    # -------------------------
    # (a) True gap
    # -------------------------
    im1 = ax1.imshow(
        Z_true, origin="lower", extent=extent,
        vmin=gap_vmin, vmax=gap_vmax, aspect="auto"
    )
    _decorate_phase_panel(ax1, "(a) True XY gap", label=r'$h=1-\gamma^2$')

    cb1 = fig.colorbar(im1, cax=cax_gap)
    cb1.set_label("Gap $\\Delta(h,\\gamma)$")

    # -------------------------
    # (b) TCI prediction + samples
    # -------------------------
    ax2.imshow(
        Z_tci, origin="lower", extent=extent,
        vmin=gap_vmin, vmax=gap_vmax, aspect="auto"
    )
    _decorate_phase_panel(ax2, "(b) TCI prediction")

    if len(tci_samples) > 0:
        # Transform samples from [0,1]x[0,1] to physical parameter space
        h_samples = h_min + (h_max - h_min) * tci_samples[:, 0]
        gamma_samples = gamma_min + (gamma_max - gamma_min) * tci_samples[:, 1]
        ax2.scatter(
            h_samples,
            gamma_samples,
            s=12, color=color_palette['red'],
            edgecolors="white",
            linewidths=0.4,
            zorder=3
        )

    # -------------------------
    # (c) Error vs xi (correlation length)
    # -------------------------
    xi_flat = xi_arr.ravel()
    # Log-spaced bins need strictly positive, finite xi. Diverging (inf) or
    # undefined values would turn the bin edges into inf/nan, so drop them.
    valid = np.isfinite(xi_flat) & (xi_flat > 0)
    xi_valid = xi_flat[valid]
    e_tci = err_tci.ravel()[valid]
    e_gp = err_gp.ravel()[valid]

    if xi_valid.size > 0:
        xi_clipped = np.minimum(xi_valid, np.percentile(xi_valid, 99.5))
        bins = np.logspace(
            np.log10(np.percentile(xi_clipped, 1)),
            np.log10(np.percentile(xi_clipped, 99)),
            6,
        )
        centers = np.sqrt(bins[:-1] * bins[1:])  # geometric bin centres

        bt_mean = _binned_mean(xi_clipped, e_tci, bins)
        bg_mean = _binned_mean(xi_clipped, e_gp, bins)

        ax3.plot(centers, bt_mean, marker=marker["circle"], linestyle='None',
                 color=color_palette['blue'], markersize=marker_size,
                 markeredgewidth=markeredgewidth, label="TCI")
        ax3.plot(centers, bg_mean, marker=marker["triangle"], linestyle='None',
                 color=color_palette['red'], markersize=marker_size,
                 markeredgewidth=markeredgewidth, label=f"GPR ({gpr_mode})")
        ax3.set_xscale("log")
        ax3.legend(handlelength=handlelength)

    ax3.set_xlabel(r"Correlation length $\xi$")
    ax3.set_ylabel("Mean absolute error")
    ax3.set_title("(c) Error vs Correlation length")
    ax3.grid(True, which="both", alpha=0.3)

    # -------------------------
    # (d) Error difference map
    # -------------------------
    im4 = ax4.imshow(
        err_diff,
        origin="lower",
        extent=extent,
        vmin=-diff_vmax,
        vmax=+diff_vmax,
        cmap="coolwarm",
        aspect="auto",
    )
    _decorate_phase_panel(ax4, r"(d) Error difference")

    cb2 = fig.colorbar(im4, cax=cax_err)
    cb2.set_label(r"$|\Delta_{\rm GPR}-\Delta|-|\Delta_{\rm TCI}-\Delta|$")

    # -------------------------
    # Global legend
    # -------------------------
    legend_handles = [
        Line2D([0], [0], marker=marker["circle"], color=color_palette['red'], ls="None",
               markersize=marker_size, markeredgewidth=markeredgewidth,
               markerfacecolor=color_palette['red'], markeredgecolor="white",
               label="TCI samples"),
        Line2D([0], [0], color="black", ls=":", lw=1.5, label=r"$h=1-\gamma^2$"),
        Line2D([0], [0], color="white", ls="--", lw=line_width, label="$h=1$"),
    ]
    ax1.legend(
        handles=legend_handles,
        loc="upper left",
        handlelength=handlelength,
    )

    fig.subplots_adjust(top=0.90)

    out_pdf = os.path.join(res_dir, "tci_vs_gpr.pdf")
    print(f"[Figure 2] Saving benchmark plot as {out_pdf}")
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)


def plot_error_comparison(
    axis, Z_true, Z_tci, Z_gp,
    tci_samples, gp_samples,
    res_dir: str,
    gpr_mode: str = "passive",
    error_scale: str = "linear",
):
    """Create side-by-side comparison of TCI and GPR errors.

    Panel (a) shows ``|Z_true - Z_tci|`` with the TCI samples. Panel (b) shows
    ``|Z_true - Z_gp|`` with the GPR samples. Both panels share one color
    scale and one colorbar. The figure is saved as
    ``<res_dir>/error_comparison_<error_scale>.pdf``.

    Parameters
    ----------
    axis : sequence
        1-D grid axis. Only ``len(axis)`` (= N) is used, to reshape the flat
        fields to (N, N).
    Z_true, Z_tci, Z_gp : array-like
        Fields with N*N entries. They are drawn with ``origin="lower"``, so
        rows map to gamma and columns to h (layout assumed, TODO confirm).
    tci_samples, gp_samples : ndarray, shape (n, 2)
        Sample points in the unit square. They are mapped linearly to
        (h, gamma) via the ranges from ``cfg()``.
    res_dir : str
        Output directory. It is not created here.
    gpr_mode : str
        Only used in the panel (b) title.
    error_scale : {"linear", "log"}
        Controls the shared color scale:

        - "linear": the range is [0, 95th percentile of the pointwise max error].
        - "log": errors are floored at 1e-10, then a LogNorm spans their min..max.

    Raises
    ------
    ValueError
        If ``error_scale`` is neither "linear" nor "log".
    """
    from matplotlib.lines import Line2D

    print(f"[Error Plot] Creating TCI vs GPR error comparison ({error_scale} scale)")

    if error_scale not in ("linear", "log"):
        raise ValueError("error_scale must be 'linear' or 'log'")

    # Load parameter ranges from config
    config = cfg()
    h_min, h_max = config['H_MIN'], config['H_MAX']
    gamma_min, gamma_max = config['GAMMA_MIN'], config['GAMMA_MAX']
    extent = [h_min, h_max, gamma_min, gamma_max]

    # -------------------------
    # Reshape and compute errors
    # -------------------------
    N = len(axis)
    Z_true = np.asarray(Z_true, dtype=float).reshape(N, N)
    err_tci = np.abs(Z_true - np.asarray(Z_tci, dtype=float).reshape(N, N))
    err_gp = np.abs(Z_true - np.asarray(Z_gp, dtype=float).reshape(N, N))

    # -------------------------
    # Color scaling: imshow kwargs shared by both panels. A norm and an
    # explicit vmin/vmax must not be passed together, so pick one per mode.
    # -------------------------
    if error_scale == "log":
        eps = 1e-10  # floor so LogNorm never sees zeros
        err_tci = np.maximum(err_tci, eps)
        err_gp = np.maximum(err_gp, eps)
        color_kwargs = {
            "norm": LogNorm(
                vmin=float(min(np.min(err_tci), np.min(err_gp))),
                vmax=float(max(np.max(err_tci), np.max(err_gp))),
            )
        }
    else:
        pointwise_max = np.maximum(err_tci, err_gp)
        err_vmax = float(np.percentile(pointwise_max, 95.0))
        if err_vmax <= 0.0:
            # Errors vanish on >95% of the grid: avoid a degenerate [0, 0] range.
            err_vmax = max(float(np.max(pointwise_max)), np.finfo(float).eps)
        color_kwargs = {"vmin": 0.0, "vmax": err_vmax}

    def _draw_error_panel(ax, err, title, samples):
        """Draw one error map with its sample overlay and return the image."""
        im = ax.imshow(
            err,
            origin="lower",
            extent=extent,
            cmap="viridis",
            aspect="auto",
            **color_kwargs,
        )
        ax.set_title(title)
        ax.set_xlabel("Field $h$")
        ax.set_ylabel("Anisotropy $\\gamma$")
        ax.axvline(1.0, ls="--", lw=line_width/2, color="white")
        if len(samples) > 0:
            # Transform samples from [0,1]x[0,1] to physical parameter space
            ax.scatter(
                h_min + (h_max - h_min) * samples[:, 0],
                gamma_min + (gamma_max - gamma_min) * samples[:, 1],
                s=12, color=color_palette['red'],
                edgecolors="white",
                linewidths=0.4,
                zorder=3
            )
        return im

    # -------------------------
    # Layout
    # -------------------------
    fig = plt.figure(figsize=(12.8, 4.8))
    gs = fig.add_gridspec(
        1, 3,
        width_ratios=[1.0, 1.0, 0.05],
        wspace=0.32,
    )

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    cax = fig.add_subplot(gs[0, 2])

    # (a) TCI error, (b) GPR error
    _draw_error_panel(ax1, err_tci, "(a) TCI absolute error", tci_samples)
    im2 = _draw_error_panel(ax2, err_gp, f"(b) GPR ({gpr_mode}) absolute error", gp_samples)

    # -------------------------
    # Legend for samples
    # -------------------------
    if len(tci_samples) > 0 or len(gp_samples) > 0:
        legend_handles = [
            Line2D([0], [0], marker=marker["circle"], color=color_palette['red'], ls="None",
                   markersize=marker_size, markeredgewidth=markeredgewidth,
                   markerfacecolor=color_palette['red'], markeredgecolor="white",
                   label="Training samples"),
        ]
        ax1.legend(
            handles=legend_handles,
            loc="upper left",
            handlelength=handlelength,
        )

    # -------------------------
    # Colorbar (shared). Both panels use the same color_kwargs, so im2's scale
    # is valid for both. A LogNorm colorbar is already log-scaled by matplotlib.
    # -------------------------
    cb = fig.colorbar(im2, cax=cax)
    cb.set_label("Absolute error $|\\Delta_{\\rm true} - \\Delta_{\\rm pred}|$")

    fig.subplots_adjust(top=0.85)

    out_pdf = os.path.join(res_dir, f"error_comparison_{error_scale}.pdf")
    print(f"[Error Plot] Saving error comparison plot as {out_pdf}")
    fig.savefig(out_pdf, bbox_inches="tight")
    plt.close(fig)


# ----------------------------
# 4) Main
# ----------------------------
def main():
    """Main function to run TCI vs GPR benchmark (Figure 2).

    Steps:

    1. Read all settings from ``cfg()``.
    2. Fit a TCI surrogate and a GPR surrogate to the exact XY gap oracle with
       the same sample ``BUDGET``. GPR_MODE selects passive or active GPR.
    3. Evaluate both surrogates on the 2**N_BIT x 2**N_BIT grid from
       ``make_eval_grid``.
    4. Write the following into ``<RES_DIR>/xy_benchmark/run_<timestamp>/``:

       - tci_vs_gpr.pdf
       - error_comparison_log.pdf and error_comparison_linear.pdf
       - config.yaml (the loaded config, plus ``gpr_learning_curve`` in
         active mode)

    Raises
    ------
    ValueError
        If GPR_MODE is neither "passive" nor "active". This is checked before
        any training.
    """
    # Load configuration
    config = cfg()

    # Validate up front so a bad GPR_MODE does not cost a full TCI run.
    gpr_mode = config.get('GPR_MODE', 'passive')
    if gpr_mode not in ("passive", "active"):
        raise ValueError(f"Invalid GPR_MODE: {gpr_mode}. Must be 'passive' or 'active'.")

    print("=" * 70)
    print("STARTING TCI vs GPR BENCHMARK (Figure 2)")
    print("=" * 70)

    # Output directory
    stamp = time.strftime("%Y%m%d-%H%M%S")
    res_dir = os.path.join(config['RES_DIR'], "xy_benchmark", f"run_{stamp}")
    os.makedirs(res_dir, exist_ok=True)

    # Build exact XY oracles
    gap_fn, xi_fn = build_xy_oracles()

    # Evaluation grid and exact reference values. They are computed once and
    # reused both by active GPR and by the final comparison.
    print("\n[Evaluation] Creating evaluation grid and computing ground truth...")
    axis, _, _, pts = make_eval_grid(config['N_BIT'])
    Z_true = batch_eval_jax(gap_fn, pts, batch=8192)
    xi = np.asarray(batch_eval_jax(xi_fn, pts, batch=8192), dtype=float)

    # Clip extreme xi values (percentile-based). The cap is taken over finite
    # values only: if more than 0.5% of xi diverges, the plain 99.5th
    # percentile would itself be inf and clip nothing. np.minimum maps +inf to
    # the cap and leaves NaN untouched.
    print(f"[Evaluation] Raw xi range: [{np.min(xi):.2e}, {np.max(xi):.2e}]")
    finite_xi = xi[np.isfinite(xi)]
    if finite_xi.size > 0:
        xi = np.minimum(xi, np.percentile(finite_xi, 99.5))
    print(f"[Evaluation] Clipped xi range: [{np.min(xi):.2e}, {np.max(xi):.2e}]")

    # Quantics grid for TCI
    qgrid = xfacpy.QuanticsGrid(a=0.0, b=1.0, dim=config['DIM'], nBit=config['N_BIT'])

    # Train predictors
    print("\n[Training] Starting TCI training...")
    tci_pred, tci_samples = run_tci(
        f_noisy=gap_fn,
        qgrid=qgrid,
        budget=config['BUDGET'],
        nBit=config['N_BIT'],
        dim=config['DIM'],
        bondDim=config['BOND_DIM'],
        reltol=config['REL_TOL']
    )

    print(f"\n[Training] Starting GPR training ({gpr_mode} mode)...")
    gpr_losses = None
    if gpr_mode == "active":
        gpr, gp_samples, gpr_losses = run_active_gpr(
            gap_fn=gap_fn,
            eval_pts=pts,
            Z_true_flat=Z_true,
            budget=config['BUDGET'],
            seed=config['GP_SEED']
        )
        if len(gpr_losses) > 0:
            print(f"[Training] Active learning completed. Final MAE: {gpr_losses[-1]:.5f}")
    else:
        gpr, gp_samples = run_passive_gpr(
            gap_fn=gap_fn,
            budget=config['BUDGET'],
            seed=config['GP_SEED']
        )

    def gp_pred(Xq):
        return gpr.predict(np.asarray(Xq), return_std=False).astype(float)

    print("[Evaluation] Computing TCI and GPR predictions...")
    Z_tci = tci_pred(pts)
    Z_gp = gp_pred(pts)

    # Record the run settings on a copy so the dict returned by cfg() is not
    # mutated. The learning curve may be a list or an ndarray.
    run_config = dict(config)
    if gpr_losses is not None:
        run_config["gpr_learning_curve"] = np.asarray(gpr_losses, dtype=float).tolist()

    Z_ref = np.asarray(Z_true, dtype=float).ravel()
    mae_tci = float(np.mean(np.abs(Z_ref - np.ravel(Z_tci))))
    mae_gp = float(np.mean(np.abs(Z_ref - np.ravel(Z_gp))))

    print("\n" + "=" * 72)
    print("XY GAP BENCHMARK RESULTS - using Correlation Length xi")
    print(f"GPR Mode: {gpr_mode.upper()}")
    print(f"TCI MAE: {mae_tci:.5f} ({len(tci_samples)} samples)")
    print(f"GPR MAE: {mae_gp:.5f} ({len(gp_samples)} samples)")
    print("=" * 72 + "\n")

    with open(os.path.join(res_dir, "config.yaml"), "w") as f:
        yaml.dump(run_config, f, default_flow_style=False, indent=2)

    # Publishable figure
    plot_publishable_figure_compact(
        axis=axis,
        Z_true=Z_true, Z_tci=Z_tci, Z_gp=Z_gp,
        xi=xi,
        tci_samples=tci_samples,
        gp_samples=gp_samples,
        res_dir=res_dir,
        gpr_mode=gpr_mode,
    )

    # Error comparison plots (one per color scale)
    for scale in ("log", "linear"):
        plot_error_comparison(
            axis=axis,
            Z_true=Z_true, Z_tci=Z_tci, Z_gp=Z_gp,
            tci_samples=tci_samples,
            gp_samples=gp_samples,
            res_dir=res_dir,
            gpr_mode=gpr_mode,
            error_scale=scale,
        )

    print(f"\n[COMPLETED] Results saved to: {res_dir}")
    for name in (
        "tci_vs_gpr.pdf",
        "error_comparison_log.pdf",
        "error_comparison_linear.pdf",
        "config.yaml",
    ):
        print(f"            - {name}")


if __name__ == "__main__":
    # Run the benchmark only when this file is executed as a script, not when it is imported.
    main()
