import numpy as np

# Golden ratio phi = (1 + sqrt(5)) / 2, approximately 1.618.
_GOLDEN_RATIO = (1.0 + 5.0 ** 0.5) / 2
# 1 / phi == phi - 1 (approximately 0.618); the default subplot height-to-width ratio.
_INVERSE_GOLDEN_RATIO = _GOLDEN_RATIO - 1.0


def matplotlib_config(
    *,
    rel_width=1.0,
    nrows=1,
    ncols=4,
    height_to_width_ratio=_INVERSE_GOLDEN_RATIO,
    dpi=250,
):
    """Build the complete rcParams dict for this plotting style.

    Merges, in order, the font, font-size, layout and style settings; when two
    groups share a key, the later group wins.

    Args:
        rel_width: Figure width as a fraction of the full text width.
        nrows: Number of subplot rows the figure height is sized for.
        ncols: Number of subplot columns the width is divided into.
        height_to_width_ratio: Height of one subplot relative to its width.
        dpi: Figure resolution.

    Returns:
        A new dict suitable for ``rcParams.update``.
    """
    groups = (
        font_config(),
        fontsize_config(),
        layout_config(rel_width, ncols, nrows, height_to_width_ratio, dpi),
        style_config(),
    )
    merged = {}
    for group in groups:
        merged.update(group)
    return merged


def fontsize_config():
    """Return rcParams font sizes derived from an 11 pt base size.

    Body text and titles use base - 1, axis labels and legends use base - 3,
    and tick labels use base - 4.
    """
    base = 11
    normal, small, tiny = base - 1, base - 3, base - 4
    return {
        "font.size": normal,
        "axes.titlesize": normal,
        "axes.labelsize": small,
        "legend.fontsize": small,
        "xtick.labelsize": tiny,
        "ytick.labelsize": tiny,
    }


def layout_config(rel_width, ncols, nrows, height_to_width_ratio, dpi):
    """Return rcParams for figure size, dpi and layout padding.

    The figure is ``rel_width`` times a 5.5 inch full width. That width is
    split into ``ncols`` equal columns; each subplot is
    ``height_to_width_ratio`` times as tall as one column is wide, and the
    figure is ``nrows`` subplots tall.
    """
    fig_w = 5.5 * rel_width
    col_w = fig_w / ncols
    row_h = height_to_width_ratio * col_w
    fig_h = row_h * nrows

    # Padding around axes objects, in inches: 1 point (matplotlib's default is 3/72).
    pad_inches = 1 / 72
    # Space between subplot groups, as a fraction of the subplot sizes separated.
    group_gap = 0.00

    return {
        "figure.dpi": dpi,
        "figure.figsize": (fig_w, fig_h),
        "figure.constrained_layout.use": False,
        "figure.autolayout": False,
        "figure.constrained_layout.h_pad": pad_inches,
        "figure.constrained_layout.w_pad": pad_inches,
        "figure.constrained_layout.hspace": group_gap,
        "figure.constrained_layout.wspace": group_gap,
    }


def style_config():
    """Return rcParams for a compact, open-frame axes style.

    Hides the top and right spines, turns off automatic data margins, and
    tightens the padding of axis labels, tick labels and titles.
    """
    settings = [
        ("axes.labelpad", 2),
        ("axes.spines.top", False),
        ("axes.spines.right", False),
        ("ytick.major.pad", 1),
        ("xtick.major.pad", 1),
        ("axes.xmargin", 0),
        ("axes.ymargin", 0),
        ("axes.titlepad", 3),
    ]
    return dict(settings)


def font_config():
    """Return rcParams that typeset text through LaTeX with a Times serif font."""
    latex_preamble = "\\usepackage{times} "
    config = {"text.usetex": True, "font.family": "serif"}
    config["text.latex.preamble"] = latex_preamble
    return config


def update_style(
    plt,
    rel_width=1.0,
    nrows=1,
    ncols=4,
    height_to_width_ratio=_INVERSE_GOLDEN_RATIO,
    dpi=250,
):
    """Apply this module's style to ``plt.rcParams`` in place.

    ``plt`` is presumably ``matplotlib.pyplot`` (anything with an ``rcParams``
    mapping works). The remaining arguments are forwarded to
    ``matplotlib_config``.
    """
    settings = matplotlib_config(
        rel_width=rel_width,
        nrows=nrows,
        ncols=ncols,
        height_to_width_ratio=height_to_width_ratio,
        dpi=dpi,
    )
    plt.rcParams.update(settings)


def make_axes(
    plt,
    rel_width=1.0,
    nrows=1,
    ncols=1,
    height_to_width_ratio=_INVERSE_GOLDEN_RATIO,
):
    """Apply the style sized for an ``nrows`` x ``ncols`` grid and create it.

    Returns whatever ``plt.subplots`` returns. Squeezing is enabled only for a
    single subplot, so a 1x1 grid yields a bare axes object while any larger
    grid always yields a 2-D array of axes.
    """
    update_style(
        plt,
        rel_width=rel_width,
        nrows=nrows,
        ncols=ncols,
        height_to_width_ratio=height_to_width_ratio,
    )
    only_one = nrows == 1 and ncols == 1
    return plt.subplots(nrows=nrows, ncols=ncols, squeeze=only_one)


def fmt_pow10(x):
    """Format a tick value as a LaTeX power of ten.

    ``None`` and ``""`` map to ``""``. The values 0, 1 and 10 are returned as
    plain digits. Negative values and values strictly between 1 and 10 are
    returned as ``str(x)``. Every other value is rendered as ``$10^{k}$``,
    where ``k`` is ``log10(x)`` rounded to the nearest integer. Exact powers
    of ten, including ones perturbed by float noise, therefore get the right
    exponent.
    """
    if x is None or x == "":
        return ""
    if x == 1:
        return "1"
    if x == 0:
        return "0"
    if x < 0 or 1 < x < 10:
        return str(x)
    if x == 10:
        return "10"
    # Round rather than truncate. log10 of a power of ten carrying float noise
    # (e.g. 999.999999999999) can land just below the integer, and int() would
    # then drop a whole decade.
    exponent = int(np.rint(np.log10(x)))
    # Plain integer formatting: a ":.2g" spec would print exponents >= 100 as "1e+02".
    return f"$10^{{{exponent}}}$"


def normalize_y_axis(*axes):
    """Sets the y limits to be the same for all the axes

    Ensures that ``ax1.get_ylim() == ax2.get_ylim()`` for all ``ax1, ax2``.

    The shared limits are the smallest and largest limit values found across
    all axes, set in ascending order, so every axis's original y range is
    contained in the result. Both ends of each axis's limits are considered
    regardless of order, so an inverted axis (whose ``get_ylim()`` returns
    ``(top, bottom)``) is covered correctly. Calling it with no axes does nothing.
    """
    if not axes:
        return
    # Pool both ends of every axis so the order of each (lo, hi) pair is irrelevant.
    limits = [value for ax in axes for value in ax.get_ylim()]
    lo, hi = min(limits), max(limits)
    for ax in axes:
        ax.set_ylim([lo, hi])


def hide_frame(*axes, top=True, right=True, left=False, bottom=False):
    """Hide the selected spines (frame edges) on every given axes.

    Each keyword flag says whether that side's spine should be hidden. Sides
    whose flag is False are explicitly made visible.
    """
    hide_by_side = {"top": top, "right": right, "left": left, "bottom": bottom}
    for ax in axes:
        for side, hide in hide_by_side.items():
            ax.spines[side].set_visible(not hide)
