import os
import numpy as np
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, ConstantKernel

from stylesheet import *
from hamiltonian_xy_transverse import get_gap_oracle, get_xi_oracle
from config_loader import cfg


# ============================================================
# 1) GP learning: active (variance-based) and passive (random sampling)
# ============================================================
def run_active_gpr(
    gap_fn,
    eval_pts,
    Z_true_flat,
    budget,
    batch_size=1,
    init_n=20,
    pool_size=5000,
    seed=0,
):
    """
    Run active learning using a Gaussian Process with variance-based acquisition.

    Starts from ``init_n`` points drawn uniformly from the unit square. Each
    iteration refits the GP, records its MAE on ``eval_pts``, draws a fresh pool
    of ``pool_size`` uniform candidates, and evaluates ``gap_fn`` on the
    ``batch_size`` candidates with the largest predictive standard deviation.
    This repeats until ``budget`` samples have been collected; the final batch
    is trimmed so the budget is never exceeded. If ``init_n >= budget``, no
    acquisition happens and ``losses`` is empty.

    Args:
        gap_fn: Function to evaluate true values; called on single rows of
            shape (2,) in [0, 1]^2 (presumably returning a scalar).
        eval_pts: Evaluation points for computing MAE.
        Z_true_flat: Ground truth values at ``eval_pts`` for MAE computation.
        budget: Total number of samples to collect.
        batch_size: Number of points to acquire per iteration (must be >= 1).
        init_n: Initial number of random samples.
        pool_size: Size of the random candidate pool drawn each iteration.
        seed: Random seed for sampling and for the GP optimizer restarts.

    Returns:
        gpr: GaussianProcessRegressor fitted on all collected samples.
        X: Training points, shape (n_samples, 2), in the unit square.
        losses: Array of MAE values, one per iteration (recorded before each
            acquisition step).

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        # batch_size == 0 would turn argsort(...)[-0:] into "select everything".
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"[Active GPR] Starting active learning with budget={budget}")

    # Load kernel parameters from config
    config_params = cfg()

    rng = np.random.default_rng(seed)

    X = rng.uniform(0, 1, size=(init_n, 2))
    y = np.array([gap_fn(xi) for xi in X])

    kernel = ConstantKernel() * Matern(length_scale=0.2, nu=config_params['MATERN_NU'])
    gpr = GaussianProcessRegressor(
        kernel=kernel,
        normalize_y=True,
        n_restarts_optimizer=5,
        random_state=seed
    )

    losses = []

    print("[Active GPR] Starting active learning loop")
    while len(X) < budget:
        gpr.fit(X, y)

        Z_pred = gpr.predict(eval_pts)
        mae = np.mean(np.abs(Z_pred - Z_true_flat))
        losses.append(mae)

        print(f"[Active GPR] Samples {len(X):4d}/{budget} | MAE = {mae:.5f}")

        # Fresh random candidate pool each iteration; acquire the points where
        # the GP is least certain (largest predictive std).
        X_pool = rng.uniform(0, 1, size=(pool_size, 2))
        _, sigma = gpr.predict(X_pool, return_std=True)

        # Never overshoot the budget with the last batch.
        n_take = min(batch_size, budget - len(X))
        idx = np.argsort(sigma)[-n_take:]
        X_next = X_pool[idx]
        # Evaluate every acquired point so that X and y stay aligned.
        y_next = np.array([gap_fn(x) for x in X_next])

        n_before = len(X)
        X = np.vstack([X, X_next])
        y = np.concatenate([y, y_next])

        # Report whenever the sample count crosses a multiple of 50
        # (identical to `len(X) % 50 == 0` for batch_size == 1).
        if len(X) // 50 > n_before // 50:
            print(f"[Active GPR] Progress: {len(X)}/{budget} samples collected")

    gpr.fit(X, y)
    print("[Active GPR] Training finished\n")
    return gpr, X, np.array(losses)


def run_passive_gpr(gap_fn, budget, seed=0):
    """
    Run passive learning: fit a GP to uniformly random samples.

    Draws ``budget`` points uniformly from the unit square, evaluates
    ``gap_fn`` on each, and fits a GaussianProcessRegressor with a
    ``ConstantKernel() * Matern`` kernel. The Matern kernel uses an initial
    length scale of 0.2 and takes ``nu`` from the config key ``MATERN_NU``.

    Args:
        gap_fn: Function evaluated on each sampled row of shape (2,).
        budget: Number of random samples to draw.
        seed: Seed for the sampler and for the GP optimizer restarts.

    Returns:
        gpr: The fitted GaussianProcessRegressor.
        X: Sampled training points, shape (budget, 2).
    """
    print(f"[Passive GPR] Starting passive learning with budget={budget}")

    settings = cfg()
    generator = np.random.default_rng(seed)

    print("[Passive GPR] Sampling random points")
    samples = generator.uniform(0, 1, size=(budget, 2))
    targets = np.array([gap_fn(point) for point in samples])

    matern_nu = settings['MATERN_NU']
    model = GaussianProcessRegressor(
        kernel=ConstantKernel() * Matern(length_scale=0.2, nu=matern_nu),
        normalize_y=True,
        n_restarts_optimizer=5,
        random_state=seed,
    )

    print("[Passive GPR] Fitting GP")
    model.fit(samples, targets)
    print("[Passive GPR] Training finished\n")
    return model, samples


def binned_mean(x, y, bins):
    """
    Compute the mean of ``y`` within consecutive bins of ``x``.

    Bin ``i`` is the half-open interval ``[bins[i], bins[i + 1])``, so a value
    equal to the last edge falls outside every bin.

    Args:
        x: Values used for binning (array-like).
        y: Values to average (array-like, same shape as ``x``).
        bins: Bin edges (array-like, assumed increasing), length ``n_bins + 1``.

    Returns:
        np.ndarray of length ``max(len(bins) - 1, 0)`` holding the per-bin mean
        of ``y``; bins with no members are ``np.nan``.
    """
    # Accept plain lists as well as arrays (boolean masking needs ndarrays).
    x = np.asarray(x)
    y = np.asarray(y)
    edges = np.asarray(bins)

    means = np.full(max(len(edges) - 1, 0), np.nan)
    for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (x >= lo) & (x < hi)
        # Skip empty bins: np.mean of an empty selection warns and yields nan.
        if np.any(in_bin):
            means[i] = np.mean(y[in_bin])
    return means


def _unit_to_physical(X_unit, h_min, h_max, gamma_min, gamma_max):
    """Map (n, 2) points from the unit square to (h, gamma) parameter coordinates."""
    h = h_min + (h_max - h_min) * X_unit[:, 0]
    gamma = gamma_min + (gamma_max - gamma_min) * X_unit[:, 1]
    return h, gamma


def create_gpr_comparison_figure(
    gap_fn, xi_fun, budget=500, res=120
):
    """
    Create Figure 5 - GPR active vs passive learning comparison.

    Trains a passive (uniform random) GP and an active (variance-based) GP on
    ``gap_fn`` with the same sample budget, then draws four panels:

    (a), (b) The passive and active predictions on a res x res grid, with
        their sample locations overlaid.
    (c) The active learning curve against the passive MAE.
    (d) Binned active-GP absolute error and predictive std vs correlation
        length.

    The figure is saved to ``results/gpr_passive_vs_active.pdf``.

    Args:
        gap_fn: Callable evaluated on points of the unit square (target gap).
        xi_fun: Callable evaluated on points of the unit square (correlation
            length used for the diagnostics panel).
        budget: Number of samples given to each GP.
        res: Evaluation grid resolution per axis.

    Returns:
        The matplotlib Figure.
    """
    print("[Figure 5] Creating GPR active vs passive comparison")

    # Load parameter ranges from config. The GPs work in the unit square; these
    # bounds map it onto the physical (h, gamma) axes for plotting.
    config_params = cfg()
    h_min, h_max = config_params['H_MIN'], config_params['H_MAX']
    gamma_min, gamma_max = config_params['GAMMA_MIN'], config_params['GAMMA_MAX']
    extent = [h_min, h_max, gamma_min, gamma_max]

    # Create evaluation grid (after reshape: rows follow gamma, columns follow h)
    print("[Figure 5] Creating evaluation grid")
    grid = np.linspace(0, 1, res)
    H, G = np.meshgrid(grid, grid)
    eval_pts = np.stack([H.ravel(), G.ravel()], axis=1)

    print("[Figure 5] Evaluating ground truth")
    Z_true_flat = np.array([gap_fn(x) for x in eval_pts])
    Z_true = Z_true_flat.reshape(res, res)

    # Train models
    print("[Figure 5] Training passive GPR model...")
    gpr_passive, X_passive = run_passive_gpr(gap_fn, budget, seed=0)

    print("[Figure 5] Training active GPR model...")
    gpr_active, X_active, losses = run_active_gpr(
        gap_fn, eval_pts, Z_true_flat, budget, seed=1
    )

    # Predict once per model on the full grid and reuse the results below.
    # The active model's mean and std come from a single call.
    print("[Figure 5] Evaluating predictions")
    Zp_passive_flat = gpr_passive.predict(eval_pts)
    Zp_active_flat, sigma = gpr_active.predict(eval_pts, return_std=True)
    Zp_passive = Zp_passive_flat.reshape(res, res)
    Zp_active = Zp_active_flat.reshape(res, res)
    mae_passive = np.mean(np.abs(Zp_passive_flat - Z_true_flat))

    print("[Figure 5] Computing correlation-length diagnostics")
    xi = np.array([xi_fun(x) for x in eval_pts])
    err = np.abs(Zp_active_flat - Z_true_flat)

    # Clip xi at its 99.5th percentile, then build 6 log-spaced bins spanning
    # the 1st..99th percentiles.
    # NOTE(review): log10 binning assumes xi is positive and finite at those
    # percentiles - confirm for the xi oracle.
    xi_clip = np.minimum(xi, np.percentile(xi, 99.5))
    bins = np.logspace(
        np.log10(np.percentile(xi_clip, 1)),
        np.log10(np.percentile(xi_clip, 99)),
        7,
    )
    centers = np.sqrt(bins[:-1] * bins[1:])  # geometric bin centres

    err_b = binned_mean(xi_clip, err, bins)
    sig_b = binned_mean(xi_clip, sigma, bins)

    # ============================================================
    # Create Figure 5 (4-panel layout)
    # ============================================================
    print("[Figure 5] Creating figure layout")
    fig = plt.figure(figsize=(12.8, 7))

    gs = fig.add_gridspec(
        2, 3,
        width_ratios=[1.0, 1.0, 0.05],   # last column = colorbar
        wspace=0.32,
        hspace=0.35
    )

    # Axes
    ax_ul = fig.add_subplot(gs[0, 0])   # passive surface
    ax_ur = fig.add_subplot(gs[0, 1])   # active surface
    cax   = fig.add_subplot(gs[0, 2])   # colorbar

    ax_ll = fig.add_subplot(gs[1, 0])   # learning curve
    ax_lr = fig.add_subplot(gs[1, 1])   # error vs xi

    # -------------------------
    # Upper left: Passive GPR
    # -------------------------
    # Both surfaces share the ground-truth colour range so they are comparable.
    im = ax_ul.imshow(
        Zp_passive,
        origin="lower",
        extent=extent,
        cmap="viridis",
        vmin=Z_true.min(),
        vmax=Z_true.max(),
        aspect="auto"
    )
    h_passive, gamma_passive = _unit_to_physical(
        X_passive, h_min, h_max, gamma_min, gamma_max
    )
    ax_ul.scatter(
        h_passive,
        gamma_passive,
        s=12, color=color_palette['red'],
        edgecolors="white",
        linewidths=0.4
    )
    ax_ul.set_title("(a) Passive GPR")
    ax_ul.set_xlabel("Field $h$")
    ax_ul.set_ylabel("Anisotropy $\\gamma$")

    # -------------------------
    # Upper right: Active GPR
    # -------------------------
    ax_ur.imshow(
        Zp_active,
        origin="lower",
        extent=extent,
        cmap="viridis",
        vmin=Z_true.min(),
        vmax=Z_true.max(),
        aspect="auto"
    )
    h_active, gamma_active = _unit_to_physical(
        X_active, h_min, h_max, gamma_min, gamma_max
    )
    ax_ur.scatter(
        h_active,
        gamma_active,
        s=12, color=color_palette['red'],
        edgecolors="white",
        linewidths=0.4
    )
    ax_ur.set_title("(b) Active GPR")
    ax_ur.set_xlabel("Field $h$")

    # -------------------------
    # Colorbar (shared, external)
    # -------------------------
    cb = fig.colorbar(im, cax=cax)
    cb.set_label("Spectral gap $\\Delta$")

    # -------------------------
    # Lower left: Learning curve
    # -------------------------
    ax_ll.plot(losses, lw=line_width, color=color_palette['blue'], label="Active GPR")
    ax_ll.axhline(
        mae_passive,
        ls="--", lw=line_width, color=color_palette['red'], label="Passive GPR"
    )
    ax_ll.set_title("(c) Learning curve")
    ax_ll.set_xlabel("Active learning iteration")
    ax_ll.set_ylabel("Mean absolute error")
    ax_ll.legend(frameon=False)

    # -------------------------
    # Lower right: Error vs correlation length
    # -------------------------
    ax_lr.plot(centers, err_b, marker["circle"], lw=line_width,
              color=color_palette['blue'], label="Error", markersize=marker_size)
    ax_lr.plot(centers, sig_b, marker["triangle"], lw=line_width,
              color=color_palette['red'], label="Predictive std. dev.", markersize=marker_size)
    ax_lr.set_xscale("log")
    ax_lr.set_title("(d) Active GPR: error vs uncertainty")
    ax_lr.set_xlabel("Correlation length $\\xi$")
    ax_lr.set_ylabel("Gap scale ($\\Delta$ units)")
    ax_lr.grid(True, which="both", alpha=0.3)
    ax_lr.legend(frameon=False)

    # -------------------------
    # Finalize
    # -------------------------
    fig.tight_layout()

    # Ensure results directory exists
    os.makedirs("results", exist_ok=True)

    print("[Figure 5] Saving GPR comparison figure")
    fig.savefig("results/gpr_passive_vs_active.pdf")

    return fig


# ============================================================
# Main function for Figure 5
# ============================================================
def main():
    """
    Main function to create Figure 5 - Active vs Passive GPR comparison.

    Builds the gap and correlation-length oracles and reads ``GPR_BUDGET`` and
    ``N_BIT`` from the config; the grid resolution is ``2 ** N_BIT`` per axis.
    It then calls ``create_gpr_comparison_figure``, which saves the figure.
    """
    print("=" * 70)
    print("CREATING FIGURE 5: GPR ACTIVE VS PASSIVE COMPARISON")
    print("=" * 70)

    print("[Main] Initializing gap_fns")
    gap_fn = get_gap_oracle()
    xi_fun = get_xi_oracle()

    # Load configuration
    config = cfg()
    budget = config['GPR_BUDGET']
    res = 2 ** config['N_BIT']

    # Create the comparison figure (saved to disk inside the call; the
    # returned Figure handle is not needed here).
    create_gpr_comparison_figure(
        gap_fn=gap_fn,
        xi_fun=xi_fun,
        budget=budget,
        res=res
    )

    print("[Main] Figure 5 creation completed successfully")


# Build Figure 5 only when executed as a script, not on import.
if __name__ == "__main__":
    main()
