import os
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from pylab import rcParams
import seaborn as sns

from .utils import create_dir

# Global plotting defaults, applied once when this module is imported.
rcParams['figure.figsize'] = (9, 8)
sns.set_style("whitegrid")
sns.set_context("talk", font_scale=1.6)

# Paper compatibility: serif text rendered through LaTeX, and only the
# standard core/AFM fonts used for PS/PDF output. Kept after the seaborn
# calls above, which presumably set their own font family — confirm before
# reordering.
matplotlib.rcParams.update({
    'font.family': "serif",
    'ps.useafm': True,
    'pdf.use14corefonts': True,
    'text.usetex': True,
})

# Legend mode string. NOTE(review): not referenced in the code shown here.
legend_bool = "brief"

# TODO: consider regex-based experiment selection (though it might not be user-friendly)
# TODO: need function to download all plots from server
# TODO: Add learning curve visualization that plots training and validation loss across epochs

# TODO: set this only once
DATAPATH = Path("/media/sdb/path_learning/plots")


class MathTextSciFormatter(mticker.Formatter):
    r"""Tick formatter that renders values in scientific notation as math text.

    With the default ``fmt``, ``12345`` becomes ``$1.23{\times}10^{4}$``
    instead of the plain ``1.23e+04``.
    """

    def __init__(self, fmt="%1.2e"):
        """
        :param fmt: printf-style format string applied to each value. It is
            expected to produce exponent notation (``%e``); output without an
            ``e`` is passed through unchanged.
        """
        self.fmt = fmt

    def __call__(self, x, pos=None):
        r"""
        Format ``x`` as a ``$...$`` math-text string.

        :param x: Value to format.
        :param pos: Tick position (unused; part of the Formatter interface).
        :return: e.g. ``"$1.23{\times}10^{4}$"``. The ``10^{...}`` factor is
            omitted when the exponent is zero, e.g. ``"$0.00$"``.
        """
        s = self.fmt % x
        significand, has_exponent, exponent_part = s.partition('e')
        if not has_exponent:
            # e.g. "inf", "-inf", "nan", or a non-%e fmt: there is no
            # exponent to rewrite (the original indexing raised IndexError).
            return "${}$".format(s)
        # Drop a dangling decimal point, e.g. "1." -> "1".
        significand = significand.rstrip('.')
        # An explicit "+" sign is noise in the exponent; keep only "-".
        sign = exponent_part[0].replace('+', '')
        # Strip zero padding ("04" -> "4"); an all-zero exponent becomes "".
        exponent = exponent_part[1:].lstrip('0')
        if exponent:
            exponent = '10^{%s%s}' % (sign, exponent)
        if significand and exponent:
            s = r'%s{\times}%s' % (significand, exponent)
        else:
            s = r'%s%s' % (significand, exponent)
        return "${}$".format(s)


def visualize(title: str, exp_names: list, plot_type: str):
    """
    Visualize the results of a selected set of experiments and save the plot.

    :param title: Plot title. With spaces removed, it is also the file name of
        the PDF written into ``DATAPATH``.
    :param exp_names: Names of the experiments whose results are passed to
        ``gather_results``.
    :param plot_type: Kind of plot to draw; only ``"barplot"`` is supported.
    :return: None. The figure is saved to disk as a side effect.
    :raises NotImplementedError: If ``plot_type`` is not supported.
    """
    # Reject unsupported plot types before gathering results or touching
    # the filesystem.
    if plot_type != "barplot":
        raise NotImplementedError(f"The plot type {plot_type} is not implemented")

    # NOTE(review): gather_results is neither defined nor imported in the
    # module as shown — confirm where it is supposed to come from.
    df = gather_results(exp_names)
    # Positional access: the index is not guaranteed to contain label 0
    # (e.g. after filtering or concatenation), and an empty frame has no row.
    if len(df) > 0:
        print(f"Example row: {df.iloc[0]}")
    else:
        print("No results were gathered for the selected experiments.")
    create_dir(DATAPATH)

    # TODO: Pick final results of path

    # TODO: Add plot that visualizes path

    fig = plt.figure()
    try:
        sns.barplot(x="Experiment", y="Loss-test", hue="Dataset", data=df,
                    palette="muted")
        # NOTE(review): text.usetex is enabled globally, so LaTeX special
        # characters in the title (e.g. "_", "%", "&") may break rendering.
        plt.title(title)
        out_file = DATAPATH / f"{title.replace(' ', '')}.pdf"
        plt.savefig(str(out_file), bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Manual smoke run: a single bar plot comparing two blur experiments.
    visualize("Test plot", ["Blur N=2", "Blur N=4"], plot_type="barplot")
