from typing import Union

import matplotlib.pyplot as plt
import seaborn as sns

# Text widths, in LaTeX points (pt), of known venue templates, keyed by venue name.
# Every venue ships its own LaTeX class, so the width varies from one to another.
# To add a venue, put \showthe\textwidth in its LaTeX document and record the value here.
DOCUMENT_WIDTHS = {
    "neurips": 397,
}


def set_style_params(
    font_size: float,
    label_font_size: float,
    legend_font_size: float,
    tick_font_size: float,
    title_size: float,
    axes_line_width: float,
    *,
    use_tex: bool = True,
):
    """
    Stylizes all subsequently created plots.

    This mutates global state: it applies seaborn's "white" axes style and then
    updates ``matplotlib.pyplot.rcParams`` with the sizes below. It also makes PDF
    output use Type 42 (TrueType) fonts and, optionally, LaTeX text rendering.

    :param font_size: global size of all fonts.
    :param label_font_size: size of the font in the axes labels.
    :param legend_font_size: size of the font in the legends.
    :param tick_font_size: size of the font in the ticks (applied to both x and y).
    :param title_size: size of the font in the titles.
    :param axes_line_width: the width of the line in each axis.
    :param use_tex: whether to render all text with LaTeX (``text.usetex``). This
        requires a working LaTeX installation. Pass ``False`` to use matplotlib's
        built-in text renderer instead.
    """
    # The seaborn style is applied first so that the explicit rcParams below
    # take precedence over anything it sets.
    sns.set_style("white")
    plt.rcParams.update(
        {
            "axes.labelsize": label_font_size,
            "font.size": font_size,
            "legend.fontsize": legend_font_size,
            "xtick.labelsize": tick_font_size,
            "ytick.labelsize": tick_font_size,
            "axes.titlesize": title_size,
            "axes.linewidth": axes_line_width,
            # Embed fonts as TrueType (Type 42) rather than matplotlib's default
            # Type 3; submission checkers commonly flag Type 3 fonts.
            "pdf.fonttype": 42,
            "text.usetex": use_tex,
        }
    )


def calculate_best_figure_dimensions(
    document_width: Union[str, float] = "neurips",
    scale: float = 1,
    subplots: tuple[int, int] = (1, 1),
    *,
    aspect_ratio: Union[float, None] = None,
) -> tuple[float, float]:
    """
    Computes figure dimensions that avoid rescaling when the figure is included in LaTeX.
    Based on the code in https://jwalton.info/Embed-Publication-Matplotlib-Latex/

    The figure width is ``scale`` times the document width. The height is chosen
    so that each subplot has a height/width ratio of ``aspect_ratio``, ignoring
    the spacing between subplots. The default ratio is the golden ratio.

    :param document_width: the document textwidth or columnwidth in pts. One of the following
        strings are also acceptable: "neurips".
    :param scale: the fraction of the width which we wish the figure to occupy.
    :param subplots: an optional (rows, columns) pair giving the subplot grid.
    :param aspect_ratio: height-to-width ratio of a single subplot. Defaults to the
        golden ratio ``(sqrt(5) - 1) / 2``.

    :return: a tuple containing the width and height of the figure in inches.
    :raises KeyError: if ``document_width`` is a string that is not a known venue name.
    """
    if isinstance(document_width, str):
        try:
            width_pt = DOCUMENT_WIDTHS[document_width]
        except KeyError:
            # Keep KeyError for existing callers, but say what is accepted.
            raise KeyError(
                f"Unknown document {document_width!r}; expected one of "
                f"{sorted(DOCUMENT_WIDTHS)} or a numeric width in pt."
            ) from None
    else:
        width_pt = document_width

    # Width of figure (in pts)
    fig_width_pt = width_pt * scale

    # Convert from TeX points to inches (72.27 pt per inch)
    inches_per_pt = 1 / 72.27

    if aspect_ratio is None:
        # Golden ratio to set aesthetic figure height
        # https://disq.us/p/2940ij3
        aspect_ratio = (5**0.5 - 1) / 2

    # Figure width in inches
    fig_width_in = fig_width_pt * inches_per_pt

    # Figure height in inches. Scaling by rows/columns keeps each subplot at
    # roughly `aspect_ratio`.
    fig_height_in = fig_width_in * aspect_ratio * (subplots[0] / subplots[1])

    return fig_width_in, fig_height_in


def save_plot(filepath: str, fig: plt.Figure, format: str = "pdf"):
    """
    Writes ``fig`` to disk, cropped tightly around its contents.

    :param filepath: destination path of the saved image.
    :param fig: the matplotlib figure to save.
    :param format: output file format passed to ``Figure.savefig`` (e.g. "pdf", "png").
    """
    # A tight bounding box with zero padding leaves no blank margin around the
    # drawn content.
    savefig_options = dict(format=format, bbox_inches="tight", pad_inches=0)
    fig.savefig(filepath, **savefig_options)
