import matplotlib
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import seaborn as sns


# Module-level plotting defaults shared by the functions in this module.
# Marker shapes assigned, in order, to successive hue levels.
marker_ls = ('o', 'v', 'D', '^', '<', 'd', '>')
# Subset of seaborn's current default palette: entries 0-4 and 7-8, i.e. entries
# 5 and 6 (and anything past 8) are skipped. The reason for skipping them is not
# recorded here beyond "target color" — presumably an aesthetic choice.
color_ls = [rgb for idx, rgb in enumerate(sns.color_palette())
            if idx < 5 or 7 <= idx < 9]


def plot_setup(tick_size, usetex=False, errorbar_size=4):
    """Apply the global seaborn/matplotlib styling used for figures.

    Args:
        tick_size: font size for x and y tick labels.
        usetex: whether matplotlib should render text through LaTeX.
        errorbar_size: length of error-bar caps (``errorbar.capsize``).
    """
    # set_context adjusts sizing rcParams (tick label sizes included), so it
    # has to run before the explicit overrides below.
    sns.set_context("paper")
    plt.rcParams.update({
        'font.family': ['serif'],
        'font.serif': ['Times New Roman'],
        'xtick.labelsize': tick_size,
        'ytick.labelsize': tick_size,
        'font.weight': 500,
        'text.usetex': usetex,
        'mathtext.fontset': 'stix',
        'errorbar.capsize': errorbar_size,
    })


def create_legend_handles(legend_ls, color_ls, marker_ls, linewidth,
                          markersize,markerfacecolor, markeredgecolor):
    """Build data-less ``Line2D`` proxies to use as legend handles.

    One handle is produced per ``(label, color, marker)`` triple; ``zip``
    stops at the shortest of the three sequences.

    Args:
        legend_ls: legend labels.
        color_ls: line colours, one per label.
        marker_ls: marker symbols, one per label.
        linewidth: line width, also used as the marker edge width.
        markersize: marker size.
        markerfacecolor: marker fill colour (``'none'`` hides the fill).
        markeredgecolor: marker edge colour (``'none'`` hides the edge).

    Returns:
        list of ``matplotlib.lines.Line2D`` objects.
    """
    # Styling that is identical for every handle.
    common_kw = dict(
        linewidth=linewidth,
        markersize=markersize,
        markerfacecolor=markerfacecolor,
        markeredgecolor=markeredgecolor,
        markeredgewidth=linewidth,
    )
    return [
        mlines.Line2D([], [], label=label, color=colour, marker=symbol, **common_kw)
        for label, colour, symbol in zip(legend_ls, color_ls, marker_ls)
    ]


def _marker_colors(marker_mode):
    """Map a marker mode to ``(markerfacecolor, markeredgecolor)``.

    ``'none'`` makes that part of the marker transparent; ``None`` defers to
    matplotlib's default (the line colour under default rcParams).

    Raises:
        ValueError: if ``marker_mode`` is not 'none', 'solid' or 'hollow'.
    """
    if marker_mode == 'none':
        return 'none', 'none'   # markers fully transparent
    if marker_mode == 'solid':
        return None, 'none'     # filled, no outline
    if marker_mode == 'hollow':
        return 'none', None     # outline only
    raise ValueError("Unknown marker mode.")


def _errorbar_kwargs(ci):
    """Translate the legacy ``ci`` argument for the installed seaborn.

    seaborn >= 0.12 deprecates ``ci`` in favour of ``errorbar``; older
    versions only accept ``ci``. The mapping mirrors seaborn's own
    deprecation shim: None -> no interval, 'sd' -> standard deviation,
    anything else -> ``('ci', ci)``.
    """
    import inspect

    if 'errorbar' not in inspect.signature(sns.lineplot).parameters:
        return {'ci': ci}
    if ci is None:
        return {'errorbar': None}
    if ci == 'sd':
        return {'errorbar': 'sd'}
    return {'errorbar': ('ci', ci)}


def _format_log_xaxis(ax, x_series, scalar_max=10**5):
    """Log-scale the x axis and put a major tick at every observed x value.

    Reference: https://stackoverflow.com/questions/14530113/set-ticks-with-logarithmic-scale
    """
    ax.set_xscale('log')
    ax.set_xticks(x_series.dropna().drop_duplicates().sort_values().tolist())
    # Hard-coded heuristic: when all x values are below ``scalar_max``, label
    # ticks with plain numbers; otherwise keep the default log labels and just
    # blank the minor-tick labels. Applying both branches unconditionally
    # would remove the heuristic.
    if x_series.max() < scalar_max:
        ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    else:
        ax.get_xaxis().set_minor_formatter(matplotlib.ticker.NullFormatter())
    # Zero-size minor ticks so only the observed x values are marked.
    ax.get_xaxis().set_tick_params(which='minor', size=0, width=0)


def single_line_plot(result_df, hue, x_type, x_label, y_type, y_label,
                     x_logscale=False, y_logscale=False,
                     figsize=(6, 3.3), err_style=None, ci=68,
                     marker_mode='none', markersize=10, linewidth=2.2, label_font=17,
                     use_legend=True, legend_font=16, legend_space=0.0, legend_ncol=None,
                     save_name=None):
    """Draw one line per ``hue`` group of ``result_df`` and optionally save it.

    Args:
        result_df: DataFrame holding the ``hue``, ``x_type`` and ``y_type`` columns.
        hue: column whose distinct non-null values become separate lines.
        x_type, y_type: columns plotted on the x / y axes.
        x_label, y_label: axis label texts.
        x_logscale, y_logscale: use log scale on that axis. With a log x axis
            every distinct x value gets a major tick.
        figsize: figure size in inches.
        err_style: seaborn ``err_style`` ('band', 'bars' or None for none drawn).
        ci: confidence level / 'sd' / None, translated for the installed seaborn.
        marker_mode: 'none', 'solid' or 'hollow'.
        markersize, linewidth: line and marker styling.
        label_font: axis label font size.
        use_legend: draw a legend centred below the axes.
        legend_font: legend font size.
        legend_space: extra downward offset of the legend (axes fraction).
        legend_ncol: legend columns; defaults to the number of hue levels.
        save_name: if given, path passed to ``fig.savefig``.

    Raises:
        ValueError: for an unknown ``marker_mode`` or more hue levels than
            there are colours/markers available.
    """
    # Non-null hue levels in order of first appearance. They are passed to
    # seaborn as hue_order/style_order so each level's colour and marker match
    # the hand-built legend below (seaborn would otherwise sort numeric levels,
    # use category order for categoricals, and drop nulls, desynchronising
    # the legend and the marker list length).
    legend_ls = result_df[hue].dropna().drop_duplicates().tolist()
    num_models = len(legend_ls)
    max_levels = min(len(color_ls), len(marker_ls))
    if num_models > max_levels:
        raise ValueError(
            f"At most {max_levels} distinct '{hue}' values are supported, got {num_models}.")
    if legend_ncol is None:
        legend_ncol = num_models

    markerfacecolor, markeredgecolor = _marker_colors(marker_mode)

    fig, ax = plt.subplots(figsize=figsize)
    sns.lineplot(data=result_df, ax=ax,
                 x=x_type, y=y_type,
                 hue=hue, hue_order=legend_ls, style=hue, style_order=legend_ls,
                 linewidth=linewidth, palette=color_ls[:num_models],
                 markers=marker_ls[:num_models], dashes=False,
                 markersize=markersize, markerfacecolor=markerfacecolor,
                 markeredgecolor=markeredgecolor, markeredgewidth=linewidth,
                 err_style=err_style, legend=False, **_errorbar_kwargs(ci))

    ax.grid(linestyle='--', color='lightgrey')
    ax.set_xlabel(x_label, fontsize=label_font, labelpad=7.5)
    ax.set_ylabel(y_label, fontsize=label_font, labelpad=3.4)
    if y_logscale:
        ax.set_yscale('log')
    if x_logscale:
        _format_log_xaxis(ax, result_df[x_type])

    if use_legend:
        # Legend handles are built with the same colours, markers and marker
        # styling as the plotted lines, then centred below the axes.
        handles = create_legend_handles(legend_ls, color_ls, marker_ls, linewidth,
                                        markersize, markerfacecolor, markeredgecolor)
        ax.legend(handles=handles,
                  loc='upper center', bbox_to_anchor=(0.5, -0.25 - legend_space),
                  ncol=legend_ncol, fontsize=legend_font)

    if save_name is not None:
        fig.savefig(save_name, bbox_inches='tight')