import numpy as np
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jax import jit, vmap, config

# Enable 64-bit precision for physical accuracy
config.update("jax_enable_x64", True)

from stylesheet import *
from config_loader import cfg

def get_gap_oracle():
    """
    Build the jitted XY-model gap evaluator.

    The parameter ranges (``H_MIN``/``H_MAX``/``GAMMA_MIN``/``GAMMA_MAX``) are
    read from ``cfg()`` once, when this factory is called, and captured by the
    returned closure.

    Returns:
        Jitted function of one argument ``arg``. ``arg[0]`` and ``arg[1]`` are
        mapped affinely onto the configured h and gamma ranges, and the value
        returned is ``2 * sqrt(m)``, where ``m`` is the smallest of the
        candidate squared gaps evaluated below (clamped at zero).
    """
    print("[XY Model] Initializing gap oracle")

    # Parameter ranges come from the project configuration.
    params = cfg()
    h_lo, h_hi = params['H_MIN'], params['H_MAX']
    g_lo, g_hi = params['GAMMA_MIN'], params['GAMMA_MAX']

    def gap_fn(arg):
        # Rescale the normalized inputs onto the configured ranges.
        field = h_lo + (h_hi - h_lo) * arg[0]
        aniso = g_lo + (g_hi - g_lo) * arg[1]

        # Candidates at the two ends x = +1 and x = -1 of the interval
        # (presumably x = cos k, i.e. k = 0 and k = pi -- confirm).
        edge_sq = jnp.minimum((field - 1.0)**2, (field + 1.0)**2)

        # Stationary point of (h - x)^2 + gamma^2 (1 - x^2) in x. A vanishing
        # curvature is swapped for 1.0 so the division stays finite.
        curvature = 1.0 - aniso**2
        guarded = jnp.where(jnp.abs(curvature) < 1e-10, 1.0, curvature)
        x_star = field / guarded

        # The stationary point only counts when |gamma| < 1 and it lies
        # inside [-1, 1].
        interior_ok = (jnp.abs(aniso) < 1.0) & (jnp.abs(x_star) <= 1.0)
        interior_sq = (field - x_star)**2 + aniso**2 * (1.0 - x_star**2)

        best_sq = jnp.where(
            interior_ok,
            jnp.minimum(edge_sq, interior_sq),
            edge_sq,
        )

        # Clamp before the square root so round-off cannot produce NaN.
        return 2.0 * jnp.sqrt(jnp.maximum(0.0, best_sq))

    return jit(gap_fn)


def get_xi_oracle():
    """
    Build the jitted XY-model correlation-length evaluator.

    The parameter ranges (``H_MIN``/``H_MAX``/``GAMMA_MIN``/``GAMMA_MAX``) are
    read from ``cfg()`` once, when this factory is called, and captured by the
    returned closure.

    Returns:
        Jitted function of one argument ``arg``. ``arg[0]`` and ``arg[1]`` are
        mapped affinely onto the configured h and gamma ranges. The result is
        ``1 / (min_i |log|z_i|| + 1e-9)``, where ``z_1, z_2`` are the roots of
        ``(1 + gamma) z^2 - 2 h z + (1 - gamma) = 0``.
    """
    print("[XY Model] Initializing correlation length oracle")

    # Parameter ranges come from the project configuration.
    params = cfg()
    h_lo, h_hi = params['H_MIN'], params['H_MAX']
    g_lo, g_hi = params['GAMMA_MIN'], params['GAMMA_MAX']

    def xi_fn(arg):
        # Rescale the normalized inputs onto the configured ranges.
        field = h_lo + (h_hi - h_lo) * arg[0]
        aniso = g_lo + (g_hi - g_lo) * arg[1]

        # Quadratic coefficients.
        quad_a = 1.0 + aniso
        quad_b = -2.0 * field
        quad_c = 1.0 - aniso

        # Promote to complex so a negative discriminant gives complex roots
        # instead of NaN.
        disc = quad_b * quad_b - 4.0 * quad_a * quad_c
        root_disc = jnp.sqrt(disc + 0j)

        # Guard the division when the leading coefficient vanishes.
        two_a = 2.0 * quad_a
        safe_two_a = jnp.where(jnp.abs(two_a) < 1e-16, 1e-16, two_a)
        roots = (
            (-quad_b + root_disc) / safe_two_a,
            (-quad_b - root_disc) / safe_two_a,
        )

        # |log|z||, with the modulus floored so log never sees zero.
        log_dists = [
            jnp.abs(jnp.log(jnp.maximum(jnp.abs(z), 1e-16))) for z in roots
        ]
        closest = jnp.minimum(log_dists[0], log_dists[1])

        # The small offset keeps the result finite when a root has modulus 1.
        return 1.0 / (closest + 1e-9)

    return jit(xi_fn)

def build_xy_oracles():
    """
    Construct both XY-model oracles in one call.

    Returns:
        Tuple ``(gap_fn, xi_fn)``: the results of ``get_gap_oracle()`` and
        ``get_xi_oracle()``, built in that order.
    """
    print("[XY Model] Building oracles")

    # Tuple elements are evaluated left to right, so the gap oracle is
    # created (and logs) before the correlation-length oracle.
    return get_gap_oracle(), get_xi_oracle()


def plot_correlation_length_xy_2d(res=300):
    """
    Plot the analytic correlation length xi(h, gamma) as a 2D phase diagram.
    This creates Figure 3 from the paper.

    Args:
        res: Number of grid samples along each of the h and gamma axes.

    Side effects:
        Creates the ``results/`` directory if it is missing and writes
        ``results/xy_correlation_length.pdf``. The figure is closed before
        returning, so repeated calls do not accumulate open pyplot figures.
    """
    import os

    print(f"[Figure 3] Creating 2D correlation length plot with resolution {res}")

    # Load parameter ranges from config
    config_params = cfg()
    h_min, h_max = config_params['H_MIN'], config_params['H_MAX']
    gamma_min, gamma_max = config_params['GAMMA_MIN'], config_params['GAMMA_MAX']

    # The grid starts 0.01 above each lower bound (presumably to stay off a
    # singular edge of the parameter range -- TODO confirm).
    h_vals = np.linspace(h_min + 0.01, h_max, res)
    g_vals = np.linspace(gamma_min + 0.01, gamma_max, res)
    H, G = np.meshgrid(h_vals, g_vals)

    # Roots of (1 + gamma) z^2 - 2 h z + (1 - gamma) = 0 on the whole grid.
    A = 1 + G
    B = -2 * H
    C = 1 - G

    # Complex square root so a negative discriminant yields complex roots.
    Discriminant = B**2 - 4*A*C
    sqrt_D = np.sqrt(Discriminant.astype(complex))

    z1 = (-B + sqrt_D) / (2*A)
    z2 = (-B - sqrt_D) / (2*A)

    # Floor the moduli before the log (same 1e-16 floor as the jitted xi
    # oracle): a root that is exactly zero (e.g. gamma = 1) would otherwise
    # evaluate log(0) and emit a RuntimeWarning.
    log_mod_z1 = np.abs(np.log(np.maximum(np.abs(z1), 1e-16)))
    log_mod_z2 = np.abs(np.log(np.maximum(np.abs(z2), 1e-16)))

    min_log_mod = np.minimum(log_mod_z1, log_mod_z2)
    xi = 1.0 / (min_log_mod + 1e-10)

    fig = plt.figure(figsize=(6.4, 4.8))
    try:
        # NOTE(review): extent uses the configured bounds while the samples
        # start 0.01 above the lower bounds, so the image is placed very
        # slightly off its true coordinates -- confirm this is intended.
        im = plt.imshow(
            np.log10(xi),
            origin="lower",
            extent=[h_min, h_max, gamma_min, gamma_max],
            cmap="viridis",
            aspect="auto",
        )
        plt.colorbar(im, label="$\\log_{10} \\, \\xi$")
        plt.xlabel("Transverse field $h$")
        plt.ylabel("Anisotropy $\\gamma$")
        plt.title("XY model correlation length")

        # Critical lines
        plt.axvline(1.0, color="black", linestyle="--", alpha=0.5, label="$h=1$")
        theta = np.linspace(0, np.pi / 2, 200)
        plt.plot(np.cos(theta), np.sin(theta), "w:", alpha=0.5, label="$h^2 + \\gamma^2 = 1$")
        plt.legend(loc="upper right")
        plt.tight_layout()

        # Ensure results directory exists
        os.makedirs("results", exist_ok=True)

        print("[Figure 3] Saving correlation length plot in results/")
        fig.savefig("results/xy_correlation_length.pdf")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def batch_eval_jax(f_jit, pts: np.ndarray, batch: int = 8192):
    """
    Vectorized evaluation of a JAX scalar function on many points.

    The function is wrapped in ``vmap`` and ``jit`` and applied to ``pts`` in
    chunks of at most ``batch`` rows, so no single call sees the full array.

    Args:
        f_jit: Function of a single point (one row of ``pts``). It is
            vmapped and jitted here, so it does not have to be jitted already.
        pts: Points to evaluate, indexed along the first axis.
        batch: Maximum number of rows passed to JAX per call; must be positive.

    Returns:
        NumPy array holding the per-chunk results concatenated along axis 0.
        For empty ``pts`` an empty float array of shape ``(0,)`` is returned,
        which assumes ``f_jit`` is scalar-valued.

    Raises:
        ValueError: If ``batch`` is not positive.
    """
    print(f"[Batch Eval] Processing {len(pts)} points in batches of {batch}")

    if batch <= 0:
        raise ValueError(f"batch must be a positive integer, got {batch}")

    # np.concatenate rejects an empty list, so handle "no points" explicitly.
    if len(pts) == 0:
        return np.empty((0,))

    f_vm = jit(vmap(f_jit))
    # A shorter final chunk has a different shape from the full-size chunks.
    out = [
        np.asarray(f_vm(jnp.asarray(pts[start:start + batch])))
        for start in range(0, len(pts), batch)
    ]
    return np.concatenate(out, axis=0)

if __name__ == "__main__":
    # Script entry point: running this module directly renders the
    # correlation-length figure with the default resolution.
    print("[Main] Starting XY model analysis")
    plot_correlation_length_xy_2d()
    print("[Main] XY model analysis completed")
