from typing import Optional
import matplotlib.pyplot as plt
from cycler import cycler
from pathlib import Path
from matplotlib import font_manager as fm
import matplotlib as mpl
import jax
import jax.numpy as jnp
from matplotlib.colors import LogNorm
from jax.scipy.stats import gaussian_kde
from hfm.simulation.utils import compute_dihedral_batch

# Figure widths for single- and double-column layouts; presumably inches
# (Matplotlib's figsize unit) — TODO confirm against the target journal.
SINGLE_COLUMN_WIDTH = 3.25
DOUBLE_COLUMN_WIDTH = 6.75


def import_scientific_palette():
    """Install the project colour palette as Matplotlib's default colour cycle.

    Mutates the global ``axes.prop_cycle`` rc setting, so every axes created
    afterwards cycles through these colours.

    Returns:
        The list of hex colour strings, in cycle order: intense/soft pairs of
        purple, red, blue and cyan, followed by a soft yellow and a soft grey.
    """
    swatches = (
        ("#926DFF", "purple, intense"),
        ("#C3B0FA", "purple, soft"),
        ("#EF7175", "red, intense"),
        ("#F7B8BA", "red, soft"),
        ("#2876F4", "blue, intense"),
        ("#8EB5F4", "blue, soft"),
        ("#24DBFF", "cyan, intense"),
        ("#91EDFF", "cyan, soft"),
        ("#FDEABE", "yellow, soft"),
        ("#787878", "grey, soft"),
    )
    scientific_palette = [hex_code for hex_code, _label in swatches]
    plt.rc("axes", prop_cycle=cycler(color=scientific_palette))
    return scientific_palette


def register_fonts(font_dir=None):
    """Register every ``.ttf`` file found (recursively) under *font_dir* with Matplotlib.

    Args:
        font_dir: Directory to search, as ``str`` or ``Path``. Defaults to
            ``~/fonts``, resolved at call time rather than at import time.
            A directory that does not exist is silently skipped.

    Fonts that fail to load are reported and skipped, so that one bad file
    does not prevent the remaining fonts from being registered.
    """
    font_dir = Path.home() / "fonts" if font_dir is None else Path(font_dir)
    if not font_dir.exists():
        return

    for font_path in sorted(font_dir.rglob("*.ttf")):
        try:
            fm.fontManager.addfont(str(font_path))
        except Exception as e:  # deliberate best effort: warn and continue
            print(f"[warn] could not add {font_path}: {e}")

    # No rebuild step: ``addfont`` mutates the live ``fm.fontManager`` in
    # place. ``fm._load_fontmanager`` is private and only *returns* a freshly
    # scanned manager without assigning it to ``fm.fontManager``, so calling it
    # here would rescan the system fonts and rewrite the disk cache for nothing.


def setup_plotting_style():
    """Apply compact, publication-oriented Matplotlib rcParams globally.

    The defaults target a single-column figure (``SINGLE_COLUMN_WIDTH`` wide)
    with small fonts, thin lines and tight label/tick padding.
    """
    mpl.rcParams.update(
        {
            # Default figure: single-column width, 2.4 tall (figsize units).
            "figure.figsize": (SINGLE_COLUMN_WIDTH, 2.4),
            "font.size": 7,  # base size
            "axes.labelsize": 7,
            "axes.titlesize": 8,
            "legend.fontsize": 6,
            "xtick.labelsize": 6,
            "ytick.labelsize": 6,
            "font.family": "sans-serif",
            # Candidate sans-serif fonts, tried in order.
            "font.sans-serif": ["Helvetica", "Arial"],
            "lines.linewidth": 1,
            "lines.markersize": 3,
            "axes.linewidth": 0.5,  # thinner black borders around the plot
            "xtick.major.width": 0.4,  # thinner major x tick lines
            "ytick.major.width": 0.4,  # thinner major y tick lines
            "xtick.minor.width": 0.3,  # thinner minor x tick lines
            "ytick.minor.width": 0.3,  # thinner minor y tick lines
            "axes.labelpad": 1,  # minimal (1 pt) padding for x and y axis labels
            "xtick.major.pad": 1,
            "xtick.minor.pad": 1,
            "ytick.major.pad": 1,
            "ytick.minor.pad": 1,
            "patch.linewidth": 0.3,  # default edge linewidth for scatter, bar, etc
            "axes.titlepad": 3,  # reduce the distance from the title to the plot
        }
    )


def setup_plotting():
    """Configure Matplotlib for this project in one call.

    Runs, in order: font registration, installation of the colour palette as
    the default colour cycle, and the compact rcParams style. The palette
    list returned by ``import_scientific_palette`` is not used here.
    """
    for configure in (register_fonts, import_scientific_palette, setup_plotting_style):
        configure()


def compute_phi_psi(xs, phi_indices, psi_indices, subsample=None, key=None):
    """Compute the phi and psi dihedral angles for every frame of a batch.

    Args:
        xs: Coordinates indexed as ``xs[frame, atom]``; presumably shaped
            ``(batch, n_atoms, 3)`` — TODO confirm for direct callers
            (``plot_phi_psi`` asserts this shape before calling).
        phi_indices: The four atom indices defining phi (only ``[0]``..``[3]``
            are read).
        psi_indices: The four atom indices defining psi.
        subsample: Optional number of frames to keep, chosen without
            replacement via ``jax.random.permutation`` before any angle is
            computed.
        key: JAX PRNG key; required when ``subsample`` is given.

    Returns:
        ``(phi, psi)`` in radians, one value per retained frame. Only a unit
        conversion is applied; the angles are not re-wrapped.
    """
    if subsample is not None:
        assert key is not None, "key must be provided if subsample is provided"
        xs = xs[jax.random.permutation(key, xs.shape[0])[:subsample]]

    def _dihedral(indices):
        # Gather the four atoms defining the torsion for every frame.
        return compute_dihedral_batch(*(xs[:, indices[k]] for k in range(4)))

    phi = _dihedral(phi_indices)
    psi = _dihedral(psi_indices)

    # compute_dihedral_batch is assumed to return degrees — TODO confirm.
    # Any (-pi, pi] range of the result is inherited from it; none is enforced here.
    return jnp.deg2rad(phi), jnp.deg2rad(psi)


def compute_histogram_range(
    xs, phi_indices, psi_indices, bins=60, subsample=None, key=None
):
    """Return the smallest non-zero and the largest count of the phi/psi histogram.

    Args:
        xs: Batch of coordinates, as accepted by ``compute_phi_psi``.
        phi_indices, psi_indices: Atom indices of the two dihedrals.
        bins: Histogram bins per axis over ``[-pi, pi]``.
        subsample: Optional number of frames to keep (requires ``key``).
        key: JAX PRNG key used for subsampling.

    Returns:
        ``(nonzero_min, max)`` as computed by
        ``compute_histogram_range_phi_psi``.
    """
    # Reuse compute_phi_psi instead of duplicating the dihedral code. Drawing
    # the subsample from the frames is equivalent to drawing it from the angles
    # (same key, same length, hence the same permutation), but it evaluates
    # fewer dihedrals.
    phi, psi = compute_phi_psi(
        xs, phi_indices, psi_indices, subsample=subsample, key=key
    )
    return compute_histogram_range_phi_psi(phi, psi, bins=bins)


def compute_histogram_range_phi_psi(phi, psi, bins=60, subsample=None, key=None):
    """Return the count range of the 2-D (phi, psi) histogram over ``[-pi, pi]^2``.

    Args:
        phi, psi: Paired 1-D angle arrays in radians.
        bins: Bins per axis.
        subsample, key: Optional random subsampling, see ``subsample_fn``.

    Returns:
        ``(smallest non-zero count, largest count)``. The first element is
        ``float("inf")`` when every bin is empty.
    """
    phi, psi = subsample_fn(phi, psi, subsample, key)
    angle_range = [-jnp.pi, jnp.pi]
    counts, _, _ = jnp.histogram2d(
        phi, psi, bins=bins, range=[angle_range, angle_range]
    )
    occupied = counts > 0
    lowest = counts[occupied].min() if jnp.any(occupied) else float("inf")
    return lowest, counts.max()


def subsample_fn(x, y, subsample, key):
    """Draw the same random subset of rows from two paired arrays.

    Args:
        x, y: Arrays with the same leading dimension.
        subsample: Number of rows to keep, or ``None`` to keep everything.
        key: JAX PRNG key; required when ``subsample`` is not ``None``.

    Returns:
        ``(x, y)`` unchanged when ``subsample`` is ``None``. Otherwise both
        arrays are indexed with the first ``subsample`` entries of one random
        permutation, so pairs stay aligned.
    """
    assert x.shape[0] == y.shape[0], "x and y must have the same number of samples"
    if subsample is None:
        return x, y
    assert key is not None, "key must be provided if subsample is provided"
    chosen = jax.random.permutation(key, x.shape[0])[:subsample]
    return x[chosen], y[chosen]


def plot_2d(
    samples_x,
    samples_y,
    title=None,
    range=None,
    highlight=None,
    bins=60,
    vmin=None,
    vmax=None,
    cmap=None,
    free_energy_bar=False,
    attach_cbar=True,
    fig=None,
    ax=None,
    subsample=None,
    key=None,
    **kwargs,
):
    """Plot a log-scaled 2-D histogram of paired samples.

    Optionally adds a colour bar expressed as free energy in units of k_B T.

    Args:
        samples_x, samples_y: Paired 1-D arrays with the same length.
        title: Optional axes title.
        range: ``[[xmin, xmax], [ymin, ymax]]``. Forwarded to ``hist2d`` and
            also used for the axis limits. The name shadows the builtin and is
            kept for API compatibility.
        highlight: Optional iterable of indices into the *original*
            (pre-subsampling) samples. Each one is marked with a star.
        bins: Forwarded to ``hist2d``.
        vmin, vmax: Count limits of the ``LogNorm``. When ``None``, the
            free-energy bar uses the smallest non-empty bin and the largest bin.
        cmap: Colormap name, ``Colormap`` instance, or ``None`` for the
            Matplotlib default.
        free_energy_bar: If True, build a colour bar where a bin with count
            ``c`` has energy ``log(vmax / c)``, so the most populated bin is 0.
        attach_cbar: If True, attach that colour bar to ``ax``. Otherwise
            return the bare ``ScalarMappable`` so the caller can place it.
        fig, ax: Either both ``None`` (a new figure is created) or both given.
        subsample, key: Optional random subsampling, see ``subsample_fn``.
        **kwargs: Forwarded to ``ax.hist2d``.

    Returns:
        ``(fig, ax, cbar_obj)``. ``cbar_obj`` is a ``Colorbar``, a
        ``ScalarMappable`` (when ``attach_cbar`` is False) or ``None``.

    Raises:
        ValueError: If ``free_energy_bar`` is requested, ``vmin`` is not given,
            and every histogram bin is empty.
    """
    assert (fig is None and ax is None) or (fig is not None and ax is not None), (
        "Both fig and ax must be either both None or both provided."
    )
    if fig is None:
        fig, ax = plt.subplots()

    # Resolve highlight indices against the caller's full arrays *before*
    # subsampling. Afterwards the arrays are randomly permuted, and the same
    # indices would mark arbitrary points.
    highlight_points = (
        []
        if highlight is None
        else [(samples_x[h], samples_y[h]) for h in highlight]
    )

    samples_x, samples_y = subsample_fn(samples_x, samples_y, subsample, key)

    if title is not None:
        ax.set_title(title)
    H, _, _, _ = ax.hist2d(
        samples_x,
        samples_y,
        bins=bins,
        norm=LogNorm(vmin=vmin, vmax=vmax),
        range=range,
        rasterized=True,
        cmap=cmap,
        **kwargs,
    )

    # Infer missing colour limits from the histogram. Bins excluded by
    # hist2d's cmin/cmax are NaN, so use NaN-aware reductions.
    if vmin is None:
        positive = H[H > 0]
        vmin = positive.min() if positive.size else None
    if vmax is None:
        vmax = jnp.nanmax(H)

    cbar_obj = None
    if free_energy_bar:
        if vmin is None:
            raise ValueError(
                "free_energy_bar requires at least one non-empty histogram bin"
            )
        # Assuming counts ~ exp(-E / k_B T), energies relative to the most
        # populated bin span [0, log(vmax / vmin)]. Normalising both counts by
        # the total cancels out.
        e_min, e_max = 0.0, float(jnp.log(vmax / vmin))
        # The reversed colour map gives low energies the colour that high
        # counts have in the LogNorm histogram.
        sm = mpl.cm.ScalarMappable(
            norm=mpl.colors.Normalize(vmin=e_min, vmax=e_max),
            cmap=plt.get_cmap(cmap).reversed(),
        )

        if attach_cbar:
            # attach the colorbar to the axes
            cbar_obj = fig.colorbar(sm, ax=ax, extend="max")
            cbar_obj.set_label(r"Energy / $k_BT$")
        else:
            cbar_obj = sm

    if range is not None:
        ax.set_xlim(*range[0])
        ax.set_ylim(*range[1])

    ax.set_box_aspect(1)

    # Plot a star for each highlighted sample.
    for hx, hy in highlight_points:
        ax.plot(hx, hy, "*", markersize=8, alpha=0.8)

    return fig, ax, cbar_obj


def plot_phi_psi(
    xs,
    phi_indices,
    psi_indices,
    title=None,
    highlight=None,
    bins=100,
    vmin=None,
    vmax=None,
    subsample=None,
    key=None,
    cmap="turbo_r",
    fig=None,
    ax=None,
    **kwargs,
):
    """Draw a (phi, psi) log-density histogram over ``[-pi, pi]^2``.

    Args:
        xs: Coordinates of shape ``(batch, n_atoms, 3)`` (asserted).
        phi_indices, psi_indices: Atom indices of the two dihedrals.
        title, highlight, bins, vmin, vmax, cmap, subsample, key: Forwarded to
            ``plot_2d``.
        fig, ax: Target figure and axes. A missing ``ax`` becomes ``fig``'s
            current axes (or the current figure's axes). A missing ``fig``
            becomes ``ax``'s figure.
        **kwargs: Forwarded to ``plot_2d`` (e.g. ``free_energy_bar``).

    Returns:
        The ``(fig, ax, cbar_obj)`` tuple returned by ``plot_2d``.
    """
    assert xs.ndim == 3, "xs must be a Batch x N x 3 array"
    assert xs.shape[-1] == 3, "xs must have 3 coordinates per atom"

    # Angles are computed for all frames; plot_2d performs any subsampling.
    phi, psi = compute_phi_psi(xs, phi_indices, psi_indices)

    # Keep fig and ax consistent: never pair an axes with an unrelated figure.
    if ax is None:
        fig = fig if fig is not None else plt.gcf()
        ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    angle_range = [-jnp.pi, jnp.pi]
    ret = plot_2d(
        phi,
        psi,
        title=title,
        range=[angle_range, angle_range],
        highlight=highlight,
        bins=bins,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        fig=fig,
        ax=ax,
        subsample=subsample,
        key=key,
        **kwargs,
    )

    ticks = [-jnp.pi, -jnp.pi / 2, 0, jnp.pi / 2, jnp.pi]
    tick_labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", "0", r"$\frac{\pi}{2}$", r"$\pi$"]
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(ticks)
    ax.set_yticklabels(tick_labels)

    ax.set_xlabel(r"$\varphi$")
    ax.set_ylabel(r"$\psi$")

    return ret


def plot_fes(
    samples: jnp.ndarray,
    kBT: float,
    grid: Optional[jnp.ndarray] = None,
    weights: Optional[jnp.ndarray] = None,
    bw_method: float = 0.05,
    bins: int = 100,
    subsample: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
    fig=None,
    ax=None,
    *args,
    **kwargs,
):
    """Plot a free-energy profile estimated with a Gaussian KDE.

    The profile is ``F(x) = -kBT * log p(x)``, shifted so that its minimum on
    *grid* is zero.

    Args:
        samples: Sample array, assumed 1-D (the grid and plot are 1-D).
        kBT: Thermal energy; pass 1 to plot in units of k_B T.
        grid: Evaluation points. Defaults to ``bins`` evenly spaced points
            between the sample min and max.
        weights: Optional per-sample KDE weights, aligned with ``samples``.
        bw_method: Bandwidth argument forwarded to ``gaussian_kde``.
        bins: Number of grid points when ``grid`` is ``None``.
        subsample: Optional number of samples to keep (requires ``key``).
            ``weights`` are subsampled with the same indices.
        key: JAX PRNG key used for subsampling.
        fig, ax: Target figure/axes. Defaults to ``fig``'s (or the current
            figure's) current axes.
        *args, **kwargs: Forwarded to ``ax.plot``.

    Returns:
        ``(grid, fes)``.
    """
    if subsample is not None:
        assert key is not None, "key must be provided if subsample is provided"
        chosen = jax.random.permutation(key, samples.shape[0])[:subsample]
        samples = samples[chosen]
        # Keep the weights aligned with the retained samples; otherwise the
        # KDE receives mismatched lengths.
        if weights is not None:
            weights = weights[chosen]
    if grid is None:
        grid = jnp.linspace(samples.min(), samples.max(), int(bins))

    if ax is None:
        ax = (fig if fig is not None else plt.gcf()).gca()

    fes = -kBT * gaussian_kde(samples, bw_method, weights).logpdf(grid)
    fes = fes - fes.min()

    ax.plot(grid, fes, *args, **kwargs)
    ax.set_xlim(grid.min(), grid.max())
    ax.set_ylabel(r"Energy")  # / $k_BT$
    ax.set_ylim(0, None)

    return grid, fes


def plot_fes_angles(
    angles: jnp.ndarray,
    kBT: float,
    weights: Optional[jnp.ndarray] = None,
    bw_method: float = 0.05,
    bins: int = 100,
    fig=None,
    ax=None,
    *args,
    **kwargs,
):
    """Plot the free-energy profile of a periodic angle on ``[-pi, pi]``.

    Specify 1 for kBT if you want to plot Energy / kBT.

    Args:
        angles: 1-D array of angles in radians.
        kBT: Thermal energy scale.
        weights: Optional per-angle KDE weights, aligned with ``angles``.
        bw_method: Bandwidth argument forwarded to ``gaussian_kde``.
        bins: Number of evaluation points on ``[-pi, pi]``.
        fig, ax: Target figure/axes (defaults to the current ones).
        *args, **kwargs: Forwarded to ``ax.plot`` (via ``plot_fes``).

    Returns:
        ``(grid, fes)`` as returned by ``plot_fes``.
    """
    if ax is None:
        fig = fig if fig is not None else plt.gcf()
        ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    grid = jnp.linspace(-jnp.pi, jnp.pi, int(bins))
    # Replicate the samples one period to each side so the KDE sees the
    # periodic neighbourhood and the profile is continuous across +-pi.
    extended_angles = jnp.concatenate(
        [angles, angles + 2 * jnp.pi, angles - 2 * jnp.pi]
    )
    # Each replica carries the weight of its original sample; without this
    # the KDE receives 3N samples but only N weights.
    extended_weights = (
        None if weights is None else jnp.concatenate([weights, weights, weights])
    )
    # NOTE(review): a scalar bw_method is a factor on the data spread, so the
    # (wider) extended data set gets a wider kernel than the raw angles would.
    # Confirm that 0.05 was tuned with this in mind.

    # Pass everything positionally up to ``ax`` so that extra positional
    # *args (e.g. a format string) reach ax.plot. Otherwise they would be bound
    # to plot_fes's bins/subsample/key parameters.
    grid, fes = plot_fes(
        extended_angles,
        kBT,
        grid,
        extended_weights,
        bw_method,
        bins,
        None,  # subsample
        None,  # key
        fig,
        ax,
        *args,
        **kwargs,
    )
    ax.set_xticks(
        [-jnp.pi, -jnp.pi / 2, 0, jnp.pi / 2, jnp.pi],
        [r"$-\pi$", r"$-\frac{\pi}{2}$", "0", r"$\frac{\pi}{2}$", r"$\pi$"],
    )
    return grid, fes
