import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import itertools

def get_first_epoch(df, target_value, window_size, key):
    """Return the first 'Epoch' at which the rolling mean of ``key`` reaches a target.

    Args:
        df: DataFrame with an 'Epoch' column and a ``key`` column. Its index
            may be arbitrary (e.g. a slice that does not start at 0).
        target_value: threshold the rolling mean must reach (compared with ``>=``).
        window_size: rolling-window length in rows. The first
            ``window_size - 1`` rows have no full window (NaN) and never match.
        key: name of the column to average.

    Returns:
        The 'Epoch' value of the first row whose window mean is
        ``>= target_value``, or None if no row qualifies.
    """
    window_average = df[key].rolling(window=window_size).mean()
    # NaN >= x is False, so incomplete windows never count as reaching the target.
    reached = (window_average >= target_value).to_numpy()
    if not reached.any():
        return None
    # Work positionally: df may carry a non-zero-based index (e.g. the slice
    # returned by read_exp_logs), so an index *label* must never be fed to iloc.
    first_pos = reached.argmax()  # argmax of a bool array = position of first True
    return df['Epoch'].iloc[first_pos]

# Plotting styling: colours and markers that plot_exp_logs cycles through,
# one colour/marker pair per plotted log file.

# Ten hex colour codes. `colorS` is an alias bound to the very same list object
# (presumably kept for older callers — confirm before removing it).
DEFAULT_COLORS_LIST = colorS = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# NOTE(review): '.' appears twice (positions 1 and 7) — confirm that is intended.
DEFAULT_MARKERS_LIST = ['x', '.', '+', '1', 'p', '*', 'D', '.', 's']

def setup_plotting():
    """Apply this module's matplotlib styling to the global rcParams.

    Starts from matplotlib's built-in 'fast' style, then overrides math font,
    embedded font type, line width, font sizes and turns the axes grid on.
    This mutates process-wide matplotlib state, so it affects every figure
    created afterwards.
    """
    plt.style.use('fast')
    # Keys are applied in insertion order, each through rcParams' validating setter.
    mpl.rcParams.update({
        'mathtext.fontset': 'cm',  # Computer Modern math text
        # Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3.
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'lines.linewidth': 2.0,
        'legend.fontsize': 'large',
        'axes.titlesize': 'xx-large',
        # NOTE(review): x tick labels are smaller than y tick labels — confirm intended.
        'xtick.labelsize': 'x-large',
        'ytick.labelsize': 'xx-large',
        'axes.labelsize': 'x-large',
        'figure.titlesize': 'xx-large',
        'axes.grid': True,
    })

def read_exp_logs(file_name):
    """Load a CSV experiment log, keeping rows from the last ``Epoch == 0`` onward.

    Presumably a log file can contain several runs appended one after another,
    and only the most recent run is wanted — confirm with the log writer.

    Raises:
        IndexError: if no row has ``Epoch == 0``.
    """
    logs = pd.read_csv(file_name)
    run_starts = logs.index[logs['Epoch'].eq(0)]
    last_start = run_starts[-1]
    # read_csv yields a default RangeIndex, so this label is also a row position.
    return logs.iloc[last_start:]

def smoothing_average(df, key, window_size):
    """Return the trailing rolling mean of column ``key`` over ``window_size`` rows.

    With pandas' default ``min_periods`` (equal to the window), the first
    ``window_size - 1`` values are NaN; ``window_size=1`` returns the column's
    values unchanged.
    """
    return df[key].rolling(window_size).mean()

def plot_exp_logs(file_names,
                  labels=None,
                  colors_list=DEFAULT_COLORS_LIST,
                  markers_list=DEFAULT_MARKERS_LIST,
                  x_axis='Epoch',
                  x_axis_name='Global Communication Rounds',
                  plot_names=None,
                  plot_params=None,
                  plot_layout=(2,2),
                  figure_size=(10,10),
                  figure_save_location=None,
                  title=None,
                  smoothing_average_window_size=1,
                 ):
    """Plot one or more experiment CSV logs on a grid of subplots.

    Each file is drawn as one line, with the same colour/marker on every
    subplot. Each entry of ``plot_names`` gets its own subplot, filled in
    row-major order.

    Args:
        file_names: CSV log paths, each loaded with ``read_exp_logs``.
        labels: legend label per file; defaults to ``file_names``.
        colors_list: colours cycled per file.
        markers_list: markers cycled per file.
        x_axis: column used for the x values.
        x_axis_name: text of the x-axis label.
        plot_names: columns to plot, one per subplot. Defaults to
            ``['Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy']``.
        plot_params: per-subplot dict of ``Axes.set`` keyword arguments (or
            None to leave that subplot alone). Defaults to None for every
            entry of ``plot_names``.
        plot_layout: (rows, cols) of the subplot grid; must hold all plot_names.
        figure_size: figure size passed to ``plt.subplots``.
        figure_save_location: if given, the figure is saved to this path.
        title: optional figure-level title.
        smoothing_average_window_size: rolling-mean window applied to each
            plotted column (1 means no smoothing).

    Raises:
        AssertionError: if the argument lengths or layout are inconsistent.
    """
    # None sentinels instead of mutable list defaults shared across calls.
    if plot_names is None:
        plot_names = ['Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy']
    if plot_params is None:
        plot_params = [None] * len(plot_names)
    setup_plotting()
    if labels is None:
        labels = file_names
    n_rows, n_cols = plot_layout[0], plot_layout[1]
    assert len(file_names) == len(labels), "need exactly one label per file"
    assert len(plot_names) == len(plot_params), "need exactly one plot_params entry per plot name"
    assert n_rows * n_cols >= len(plot_names), "plot_layout has fewer cells than plot_names"
    # squeeze=False keeps axs 2-D even for 1xN, Nx1 and 1x1 layouts, so the
    # [row, col] indexing below works for every layout.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figure_size, squeeze=False)
    colors = itertools.cycle(colors_list)
    markers = itertools.cycle(markers_list)
    for file_name, label in zip(file_names, labels):
        df = read_exp_logs(file_name)
        color = next(colors)
        marker = next(markers)
        for i, (plot_name, param) in enumerate(zip(plot_names, plot_params)):
            ax = axs[i // n_cols, i % n_cols]
            # Keep the smoothed series local rather than writing it back into
            # df, which is a slice of the loaded frame (SettingWithCopy).
            y_values = smoothing_average(df, plot_name, smoothing_average_window_size)
            ax.plot(df[x_axis], y_values, label=label, color=color, marker=marker)
            ax.set_xlabel(x_axis_name)
            ax.set_ylabel(plot_name)
            ax.legend()
            if param:
                # Apply one property at a time, in the caller's dict order.
                for key, value in param.items():
                    ax.set(**{key: value})
    if title:
        fig.suptitle(title)
    fig.tight_layout()
    if figure_save_location:
        fig.savefig(figure_save_location)