import shutil
from itertools import cycle
from typing import List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.path as mpath
import matplotlib.patches as mpatches


def checkdep_usetex(s):
    """Tell whether LaTeX rendering (usetex) was requested and is usable.

    Adapted from matplotlib. Returns False as soon as ``s`` is falsy or one of
    the external tools is missing: a ``tex`` binary on PATH, then the
    ``dvipng`` and ``gs`` (ghostscript) executables as located by matplotlib.
    """
    if not s or not shutil.which("tex"):
        return False
    for tool in ("dvipng", "gs"):
        try:
            mpl._get_executable_info(tool)
        except mpl.ExecutableNotFoundError:
            return False
    return True


def set_latex_font(math=True, normal=True, extra_preamble=None):
    """Configure matplotlib's rcParams for a LaTeX-like look.

    Parameters
    ----------
    math : bool
        If True, use the STIX font set for math text; otherwise restore the
        matplotlib default ``mathtext.fontset``.
    normal : bool
        If True, use a serif font family for regular text; otherwise restore
        the default ``font.family``.
    extra_preamble : iterable of str, optional
        Extra LaTeX preamble lines appended after the default
        ``\\usepackage{amsfonts}`` and ``\\usepackage{amsmath}``.

    Notes
    -----
    ``text.usetex`` is switched on only when ``checkdep_usetex`` finds the
    required TeX toolchain, and off otherwise.
    """
    # Avoid a shared mutable default; also accept tuples or other iterables.
    extra_preamble = [] if extra_preamble is None else list(extra_preamble)

    plt.rcParams["mathtext.fontset"] = (
        "stix" if math else plt.rcParamsDefault["mathtext.fontset"]
    )
    plt.rcParams["font.family"] = (
        "serif" if normal else plt.rcParamsDefault["font.family"]
    )

    # Only enable usetex when tex, dvipng and ghostscript are all available.
    plt.rc("text", usetex=checkdep_usetex(True))

    default_preamble = [
        r"\usepackage{amsfonts}",
        r"\usepackage{amsmath}",
    ]
    plt.rc("text.latex", preamble="".join(default_preamble + extra_preamble))


# def get_twoslope_log(vim=0, vmax=1, center=0.5):
#     from matplotlib.colors import SymLogNorm

#     SymLogNorm

from matplotlib import colors

# from utils.math import affine

# class TwoSlopeLogNorm(colors.SymLogNorm):
#     """A two slope LogNorm with high resolution towards the two ends of the
#     range (default [0, 1]). Low resolution in the center of the range
#     (default 0.5)."""
#     def __init__(self, thresh, vcenter=0.5, lower=0, upper=1, clip=False) -> None:
#         assert lower <= vcenter <= upper
#         self.vcenter = vcenter
#         self.lower = lower
#         self.upper = upper
#         vmin = -vcenter
#         vmax = vcenter
#         super().__init__(thresh, vmin=vmin, vmax=vmax, clip=clip, linscale=1e-10)

#     def __call__(self, value, clip: bool | None = None):
#         idx_up = value >= self.vcenter
#         idx_lo = value < self.vcenter
#         val_up = value[idx_up]
#         val_lo = value[idx_lo]

#         value[idx_up] = affine(val_up, self.vcenter, self.upper, -self.vcenter, 0)
#         value[idx_lo] = affine(val_lo, self.lower, self.vcenter, 0, self.vcenter)
#         # if value >= self.vcenter:
#         #     # Map [vcenter, upper] to [-vcenter, 0]
#         #     value = affine(value, self.vcenter, self.upper, -self.vcenter, 0)
#         # else:
#         #     # Map [lower, vcenter[ to [0, vcenter[
#         #     value = affine(value, self.lower, self.vcenter, 0, self.center)
#         normalized = super().__call__(value, clip)

#         normalized[idx_lo] -= 0.5
#         normalized[idx_up] += 0.5

#         return normalized


class TwoSlopeLogNorm(colors.LogNorm):
    """Two-sided logarithmic normalization on ``[lower, upper]``.

    Colour resolution is high towards both ends of the range (default
    ``[0, 1]``) and low around ``vcenter`` (default 0.5).

    Values below ``vcenter`` are mapped linearly from
    ``[lower + thresh, vcenter]`` onto ``[thresh, vcenter]``. Values at or
    above ``vcenter`` are mapped, mirrored, from ``[vcenter, upper - thresh]``
    onto ``[vcenter, thresh]``. The result is log-normalized with
    ``LogNorm(vmin=thresh, vmax=vcenter)``. Finally the lower half is folded
    into ``[0, 0.5]`` and the upper half into ``[0.5, 1]``, so:

    * ``lower + thresh`` maps to 0;
    * ``vcenter`` maps to 0.5;
    * ``upper - thresh`` maps to 1.

    Parameters
    ----------
    thresh : float
        Offset from ``lower``/``upper`` at which the output reaches 0/1. It is
        also the ``vmin`` of the underlying ``LogNorm`` (whose ``vmax`` is
        ``vcenter``), so it should be positive and smaller than ``vcenter``.
    vcenter, lower, upper : float
        Centre and bounds of the data range; must satisfy
        ``lower <= vcenter <= upper``.
    clip : bool
        Forwarded to ``LogNorm``.
    """

    def __init__(self, thresh, vcenter=0.5, lower=0, upper=1, clip=False) -> None:
        assert lower <= vcenter <= upper
        self.vcenter = vcenter
        self.lower = lower
        self.upper = upper
        self.thresh = thresh
        # Both halves are folded into [thresh, vcenter] before the log step.
        super().__init__(vmin=thresh, vmax=vcenter, clip=clip)

    @staticmethod
    def _affine(x, a, b, c, d):
        """Map ``x`` linearly so that ``a -> c`` and ``b -> d``."""
        return c + (x - a) * (d - c) / (b - a)

    def __call__(self, value, clip: bool | None = None):
        is_scalar = np.ndim(value) == 0
        # Work on a float copy:
        # - the caller's data is never modified;
        # - integer input is not truncated by the affine step;
        # - masked entries become NaN and propagate as invalid values.
        data = np.atleast_1d(np.ma.array(value, dtype=float, copy=True).filled(np.nan))
        upper_half = data >= self.vcenter

        folded = np.where(
            upper_half,
            # [vcenter, upper - thresh] -> [vcenter, thresh] (mirrored)
            self._affine(
                data, self.vcenter, self.upper - self.thresh, self.vcenter, self.thresh
            ),
            # [lower + thresh, vcenter] -> [thresh, vcenter]
            self._affine(
                data, self.lower + self.thresh, self.vcenter, self.thresh, self.vcenter
            ),
        )
        normalized = super().__call__(folded, clip)

        # Lower half -> [0, 0.5]; upper half mirrored back -> [0.5, 1].
        normalized = np.ma.where(upper_half, 1 - normalized / 2, normalized / 2)
        return normalized[0] if is_scalar else normalized

    def inverse(self, value):
        is_scalar = np.ndim(value) == 0
        u = np.atleast_1d(np.array(value, dtype=float))
        upper_half = u >= 0.5

        # Undo the final fold: both halves back to LogNorm's [0, 1] range.
        unfolded = np.where(upper_half, 2 * (1 - u), 2 * u)
        # Undo the log step: back to [thresh, vcenter].
        log_inv = np.asarray(super().inverse(unfolded), dtype=float)

        # Undo the affine maps of __call__ on the *inverted* values.
        result = np.where(
            upper_half,
            self._affine(
                log_inv, self.thresh, self.vcenter, self.upper - self.thresh, self.vcenter
            ),
            self._affine(
                log_inv, self.thresh, self.vcenter, self.lower + self.thresh, self.vcenter
            ),
        )
        return result[0] if is_scalar else result


def _get_ax(ax: plt.Axes = None):
    if ax is None:
        return plt.gca()
    return ax


def set_lim_bounds(ax, axis, vmin=None, vmax=None):
    """Shrink the current x- or y-limits of ``ax`` to lie within bounds.

    ``axis`` is ``"x"`` or ``"y"``. The lower limit is raised to at least
    ``vmin`` and the upper limit lowered to at most ``vmax``; a bound left
    as None is not applied.
    """
    assert axis in ["x", "y"]
    if axis == "x":
        read_lim, write_lim = ax.get_xlim, ax.set_xlim
    else:
        read_lim, write_lim = ax.get_ylim, ax.set_ylim

    low, high = read_lim()
    if vmin is not None:
        low = np.maximum(low, vmin)
    if vmax is not None:
        high = np.minimum(high, vmax)
    write_lim((low, high))


def plot_2D_hyperplane(w, b, ax: plt.Axes = None, plot_vector: bool = True, **kwargs):
    """Plot the line of equation ``w^T x + b = 0`` within the current view.

    The current x/y-limits are captured first and restored at the end.

    Parameters
    ----------
    w : array-like
        Normal vector of the line; after squeezing it must have shape (2,).
    b : float
        Offset term.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes.
    plot_vector : bool
        If True, draw ``w`` as an arrow starting at the point of the line
        closest to the origin, in the same colour and width as the line.
    **kwargs
        Forwarded to ``ax.axhline`` / ``ax.axvline`` / ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D
        The drawn line (in every case, including horizontal/vertical lines).

    Raises
    ------
    ValueError
        If ``w`` is the zero vector.
    """
    w = np.squeeze(np.array(w))
    assert w.shape == (2,)

    w1, w2 = w
    if w1 == 0 and w2 == 0:
        raise ValueError("Both w1 and w2 cannot be 0.")

    ax = _get_ax(ax)
    ymin, ymax = ax.get_ylim()
    xmin, xmax = ax.get_xlim()

    if w1 == 0:  # horizontal line y = -b / w2
        line = ax.axhline(-b / w2, **kwargs)
    elif w2 == 0:  # vertical line x = -b / w1
        line = ax.axvline(-b / w1, **kwargs)
    else:
        # x where the line crosses the bottom/top of the view, clipped to the
        # horizontal extent of the view.
        xlow = np.clip(-b / w1 - w2 / w1 * ymin, xmin, xmax)
        xup = np.clip(-b / w1 - w2 / w1 * ymax, xmin, xmax)
        (line,) = ax.plot(
            [xlow, xup], [-b / w2 - w1 / w2 * xlow, -b / w2 - w1 / w2 * xup], **kwargs
        )

    if plot_vector:
        # Orthogonal projection of the origin onto the line.
        xy_start = -b * w / np.square(np.linalg.norm(w))
        ax.arrow(
            xy_start[0],
            xy_start[1],
            dx=w1,
            dy=w2,
            head_width=0.03,
            length_includes_head=True,
            zorder=1,
            color=line.get_color(),
            lw=line.get_linewidth(),
        )

    # Drawing may have autoscaled the view; put the original limits back.
    ax.set_xlim((xmin, xmax))
    ax.set_ylim((ymin, ymax))

    return line


def add_legend(
    ax: plt.Axes = None,
    pos: str = "top",
    ax_extra: List[plt.Axes] = None,
    ncol: int = 2,
    dx: float = 0,
    dy: float = 0,
    handles: list = None,
    labels: list = None,
    handles_extra: list = None,
    labels_extra: list = None,
    framealpha: float = 0,
    **kwargs,
):
    """Add a legend anchored around ``ax``, typically outside the axes.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes that receives the legend; defaults to the current axes.
    pos : {"top", "right", "center", "topdown", "leftright"}
        Anchor preset (legend ``loc`` plus anchor point in axes coordinates).
    ax_extra : list of Axes, optional
        Further axes whose legend entries are appended.
    ncol : int
        Number of legend columns.
    dx, dy : float
        Offsets added to the preset anchor point.
    handles, labels : list, optional
        Explicit entries; both must be given, otherwise the entries of ``ax``
        are used. The caller's lists are not modified.
    handles_extra, labels_extra : list, optional
        Extra entries appended last (only when both are given).
    framealpha : float
        Legend frame transparency.
    **kwargs
        Forwarded to ``ax.legend``.

    Returns
    -------
    matplotlib.legend.Legend
        The created legend.

    Raises
    ------
    ValueError
        If ``pos`` is not one of the presets.
    """
    # pos -> (legend loc, anchor point in axes coordinates before dx/dy)
    presets = {
        "top": ("lower center", (0.5, 1.0)),
        "right": ("center left", (1, 0.5)),
        "center": ("center", (0.5, 0.5)),
        "topdown": ("upper center", (0.5, 1.0)),
        "leftright": ("center left", (0, 0.5)),
    }
    if pos not in presets:
        raise ValueError(f'Unknown pos "{pos}".')
    loc, (x0, y0) = presets[pos]

    ax = _get_ax(ax)
    if handles is None or labels is None:
        handles, labels = ax.get_legend_handles_labels()
    else:
        # Copy so the extend() calls below never mutate the caller's lists.
        handles, labels = list(handles), list(labels)

    for other_ax in [] if ax_extra is None else ax_extra:
        other_handles, other_labels = other_ax.get_legend_handles_labels()
        handles.extend(other_handles)
        labels.extend(other_labels)
    if handles_extra is not None and labels_extra is not None:
        handles.extend(handles_extra)
        labels.extend(labels_extra)

    return ax.legend(
        framealpha=framealpha,
        ncol=ncol,
        bbox_to_anchor=(x0 + dx, y0 + dy),
        loc=loc,
        handles=handles,
        labels=labels,
        **kwargs,
    )


def add_axis(
    position,
    ax: plt.Axes = None,
    size: str = "10%",
    pad: float = 0.15,
    align_lim: bool = True,
    transfer: bool = True,
) -> plt.Axes:
    """Append a new axes on one side of ``ax`` and return it.

    The new axes is created with ``make_axes_locatable(ax).append_axes``.
    With ``align_lim``, its limits along the shared direction are copied
    from ``ax``: x for "top"/"bottom", y otherwise. With ``transfer`` as
    well, the ticks along that direction are removed from whichever axes is
    not on the outer side: ``ax`` for "bottom"/"left", the new axes
    otherwise.
    """
    host = _get_ax(ax)
    new_ax = make_axes_locatable(host).append_axes(position, size=size, pad=pad)

    if not align_lim:
        return new_ax

    shares_x = position in ["top", "bottom"]
    align_axes(new_ax, ax_ref=host, align_x=shares_x, align_y=(not shares_x))

    if transfer:
        no_ticks = {"xticks": []} if shares_x else {"yticks": []}
        inner = host if position in ["bottom", "left"] else new_ax
        inner.set(**no_ticks)

    return new_ax


def align_axes(
    *axes: plt.Axes,
    ax_ref: plt.Axes = None,
    align_x: bool = True,
    align_y: bool = True,
):
    """Give several axes the same x- and/or y-limits (modified in place).

    Parameters
    ----------
    *axes : matplotlib Axes
        Axes whose limits are set.
    ax_ref : matplotlib Axes, optional
        If given, its limits are copied to every axes in ``axes``. Otherwise
        the common range goes from the smallest lower limit to the largest
        upper limit over ``axes``.
    align_x, align_y : bool
        Which direction(s) to align.
    """
    if not axes:
        return

    def _common_lim(lims):
        # lims: one (low, high) pair per axes -> array of shape (n_axes, 2)
        lims = np.asarray(lims, dtype=float)
        return (np.min(lims[:, 0]), np.max(lims[:, 1]))

    if align_x:
        if ax_ref is None:
            lim = _common_lim([ax.get_xlim() for ax in axes])
        else:
            lim = ax_ref.get_xlim()
        for ax in axes:
            ax.set_xlim(lim)

    if align_y:
        if ax_ref is None:
            lim = _common_lim([ax.get_ylim() for ax in axes])
        else:
            lim = ax_ref.get_ylim()
        for ax in axes:
            ax.set_ylim(lim)


def get_scalar_mappable(
    norm: str = "linear",
    cmap: str = "Reds",
    values: List[float] = None,
    vmin: float = None,
    vmax: float = None,
    linthresh: float = 1.0,
    linscale: float = 1.0,
) -> ScalarMappable:
    """Build a ScalarMappable (e.g. to drive a colorbar) from a norm name.

    Parameters
    ----------
    norm : {"linear", "log", "symlog"}
        Normalization to use (``Normalize``, ``LogNorm`` or ``SymLogNorm``).
    cmap : str
        Colormap name.
    values : list of float, optional
        If given, set as the mappable's array. It is then used to fill in
        whichever of ``vmin``/``vmax`` was left as None; explicit limits are
        kept.
    vmin, vmax : float, optional
        Data limits of the norm.
    linthresh, linscale : float
        ``SymLogNorm`` parameters; only used when ``norm == "symlog"``.

    Raises
    ------
    ValueError
        If ``norm`` is not a supported name.
    """
    if norm == "linear":
        norm_obj = Normalize(vmin=vmin, vmax=vmax)
    elif norm == "log":
        norm_obj = LogNorm(vmin=vmin, vmax=vmax)
    elif norm == "symlog":
        # SymLogNorm's first positional parameters are linthresh and
        # linscale, so the limits must be passed by keyword.
        norm_obj = SymLogNorm(linthresh, linscale=linscale, vmin=vmin, vmax=vmax)
    else:
        raise ValueError(f'norm must be "linear", "log" or "symlog". Got "{norm}".')

    sm = ScalarMappable(norm=norm_obj, cmap=cmap)

    if values is not None:
        sm.set_array(values)
        # Only fill in limits the caller did not provide.
        sm.autoscale_None()

    return sm


def add_colorbar(
    sm: ScalarMappable,
    ax: plt.Axes = None,
    pad: float = 0.05,
    cax: plt.Axes = None,
) -> plt.Axes:
    """Draw a colorbar for ``sm`` and return the axes it was drawn into.

    When ``cax`` is None, a new axes of width 5% is appended to the right
    of ``ax`` (default: current axes) with ``add_axis``, without limit
    alignment or tick transfer.
    """
    target_ax = _get_ax(ax)
    if cax is None:
        cax = add_axis(
            "right", target_ax, size="5%", pad=pad, align_lim=False, transfer=False
        )
    target_ax.figure.colorbar(sm, cax=cax)
    return cax


def axvline_err(x, ax: plt.Axes = None, alpha: float = 0.3, n_std: int = 2, **kwargs):
    """Draw a vertical line at ``mean(x)`` with a shaded uncertainty band.

    The band covers ``mean +/- n_std * std(x, ddof=1)`` horizontally and the
    y-limits of the axes at call time vertically. It shares the line's
    colour and zorder. ``kwargs`` go to ``ax.axvline``. Returns the line.
    """
    ax = _get_ax(ax)
    center = np.mean(x)
    half_width = n_std * np.std(x, ddof=1)

    bottom, top = ax.get_ylim()
    line = ax.axvline(center, **kwargs)
    ax.fill_betweenx(
        [bottom, top],
        center - half_width,
        center + half_width,
        alpha=alpha,
        color=line.get_color(),
        edgecolor="none",
        zorder=line.get_zorder(),
    )
    return line


def get_ls_cycler():
    """Return an endless iterator over matplotlib dash-tuple linestyles.

    The patterns (taken from the matplotlib linestyle gallery) are yielded
    in the order listed below, then repeat.
    """
    dash_patterns = {  # name -> (offset, (on, off, ...)) as in matplotlib
        "loosely dotted": (0, (1, 10)),
        "dotted": (0, (1, 1)),
        "long dash with offset": (5, (10, 3)),
        "loosely dashed": (0, (5, 10)),
        "dashed": (0, (5, 5)),
        "densely dashed": (0, (5, 1)),
        "loosely dashdotted": (0, (3, 10, 1, 10)),
        "dashdotted": (0, (3, 5, 1, 5)),
        "densely dashdotted": (0, (3, 1, 1, 1)),
        "dashdotdotted": (0, (3, 5, 1, 5, 1, 5)),
        "loosely dashdotdotted": (0, (3, 10, 1, 10, 1, 10)),
        "densely dashdotdotted": (0, (3, 1, 1, 1, 1, 1)),
    }
    return cycle(list(dash_patterns.values()))


class KeepLim:
    """Context manager that restores an axes' limits when the block exits.

    Both x- and y-limits are recorded on entry. On exit, only the limits
    selected by ``axis`` (``"x"``, ``"y"`` or ``"both"``) are written back.
    Exceptions raised inside the block are not suppressed.

    Example
    -------
    >>> with KeepLim(ax, axis="y") as keeper:   # doctest: +SKIP
    ...     ax.plot(...)
    """

    def __init__(self, ax: plt.Axes, axis="both") -> None:
        assert axis in ["x", "y", "both"]
        self.ax = ax
        self.axis = axis

    def __enter__(self):
        self.xlim = self.ax.get_xlim()
        self.ylim = self.ax.get_ylim()
        # Return self so ``with KeepLim(ax) as k`` binds the manager.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.axis in ["x", "both"]:
            self.ax.set_xlim(self.xlim)
        if self.axis in ["y", "both"]:
            self.ax.set_ylim(self.ylim)


def get_grid_axes(
    n_axis: int,
    ncols: int = 3,
    flatten_axes: bool = True,
    **kwargs,
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """Create a figure with a grid of ``n_axis`` subplots, ``ncols`` per row.

    Unused cells of the last row are turned off.

    Parameters
    ----------
    n_axis : int
        Number of axes needed.
    ncols : int
        Number of columns.
    flatten_axes : bool
        If True and ``ncols > 1``, return the axes as a flat row-major list.
        Otherwise return them the way ``plt.subplots`` would.
    **kwargs
        Forwarded to ``plt.subplots``. ``squeeze`` is honoured only for the
        non-flattened return value.

    Returns
    -------
    fig, axes
    """
    nrows = int(np.ceil(n_axis / ncols))
    # Always build a 2-D grid so row/column indexing also works when
    # nrows == 1; the caller's ``squeeze`` is applied when returning.
    squeeze = kwargs.pop("squeeze", True)
    fig, grid = plt.subplots(nrows, ncols, squeeze=False, **kwargs)

    n_last_row = n_axis % ncols
    if n_last_row:
        for ax in grid[-1, n_last_row:]:
            ax.set_axis_off()

    if ncols > 1 and flatten_axes:
        return fig, list(grid.ravel())

    # Reproduce plt.subplots' squeeze semantics.
    if not squeeze:
        return fig, grid
    if grid.size == 1:
        return fig, grid[0, 0]
    return fig, grid.squeeze()


def remove_inner_labels(
    axes: List[plt.Axes],
    ncols: int,
    hide_tick_labels: str = None,
):
    """Remove axis labels from the inner axes of a grid of subplots.

    y-labels are removed from every column except the first, and x-labels
    from every row except the last.

    Parameters
    ----------
    axes : flat sequence of Axes, or 2-D grid (list of rows / 2-D array)
        A flat sequence is read row by row, ``ncols`` axes per row; its
        length does not have to be a multiple of ``ncols``.
    ncols : int
        Number of columns of the grid.
    hide_tick_labels : {"x", "y", "both"}, optional
        Also clear the tick labels on the axes whose x- and/or y-labels are
        removed.
    """
    if len(axes) == 0:
        return

    # Detect whether ``axes`` is flat or a grid of rows.
    try:
        axes[0][0]
    except TypeError:
        is_flat = True
    else:
        is_flat = False

    if is_flat:
        nrows = int(np.ceil(len(axes) / ncols))
        cells = ((k // ncols, k % ncols, ax) for k, ax in enumerate(axes))
    else:
        # For a grid, len(axes) is already the number of rows.
        nrows = len(axes)
        cells = ((i, j, row[j]) for i, row in enumerate(axes) for j in range(ncols))

    for i, j, ax in cells:
        if j > 0:  # not in the first column: turn off ylabel
            ax.set(ylabel=None)
            if hide_tick_labels in ["y", "both"]:
                ax.set(yticklabels=[])
        if i < nrows - 1:  # not in the last row: turn off xlabel
            ax.set(xlabel=None)
            if hide_tick_labels in ["x", "both"]:
                ax.set(xticklabels=[])


def apply_scientific_notation(a: List[float]):
    """Format every number of ``a`` in scientific notation with one decimal."""
    return list(map("{:.1e}".format, a))


def add_label_band(
    ax: plt.Axes, top, bottom, label, spine_pos=-0.05, tip_pos=-0.02, **kwargs
):
    """
    Draw a bracket spanning part of the y-axis, with a vertical label.
    Adapted from https://stackoverflow.com/a/67237650/18429836

    Parameters
    ----------
    ax : matplotlib.Axes
        The axes to add the bracket to.

    top, bottom : floats
        The positions in *data* space to bracket on the y-axis.

    label : str
        The label placed next to the bracket.

    spine_pos, tip_pos : float, optional
        The positions in *axes fraction* of the spine and of the tips of the
        bracket. These are typically negative (left of the axes).

    **kwargs
        Forwarded to ``ax.text`` for the label.

    Returns
    -------
    bracket : matplotlib.patches.PathPatch
        The bracket Artist. Modify it to change the color etc. of the
        bracket from the defaults.

    txt : matplotlib.text.Text
        The label Artist. Modify it to change the color etc. of the label
        from the defaults.
    """
    # Blended transform: x in axes fraction, y in data coordinates.
    blended = ax.get_yaxis_transform()

    corners = [
        [tip_pos, top],
        [spine_pos, top],
        [spine_pos, bottom],
        [tip_pos, bottom],
    ]
    bracket = mpatches.PathPatch(
        mpath.Path(corners),
        transform=blended,
        clip_on=False,
        facecolor="none",
        edgecolor="k",
        linewidth=0.5,
    )
    ax.add_artist(bracket)

    midpoint = (top + bottom) / 2
    txt = ax.text(
        spine_pos,
        midpoint,
        label,
        ha="right",
        va="center",
        rotation="vertical",
        clip_on=False,
        transform=blended,
        **kwargs,
    )
    return bracket, txt


def add_horizontal_bands(ax: plt.Axes, n: int):
    """Shade every other row of ``n`` rows centred at y = 0, 1, ..., n - 1.

    Light grey bands of height 1 are drawn behind even rows (0, 2, 4, ...,
    including row ``n - 1`` when ``n`` is odd). The y-limits are then set to
    ``(-0.5, n - 0.5)`` so that every row is fully visible.
    """
    with KeepLim(ax):
        # Stop at n (not n - 1) so the last row is shaded when n is odd.
        for i in range(0, n, 2):
            # facecolor (not color): a Patch given both color and edgecolor
            # warns that color overrides edgecolor.
            ax.axhspan(
                i - 0.5, i + 0.5, facecolor=".93", zorder=-1, edgecolor="none", lw=0
            )
    ax.set_ylim(-0.5, n - 0.5)
