import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

from matplotlib import font_manager as fm, rcParams

def main(results_dir='../results'):
    """Plot top-1 test accuracy per epoch for every saved run plus an ADDA line.

    Every ``*.pkl`` file in *results_dir* is loaded with ``pd.read_pickle`` and
    the frames are stacked into one DataFrame. Each file's legend label is its
    filename with the ``-results.pkl`` suffix removed. A constant ADDA
    reference series is appended, then the data is drawn with
    ``seaborn.lineplot`` (one line per ``Algorithm``) and the figure is shown.

    Args:
        results_dir: Directory holding the pickled result DataFrames. The
            default ``'../results'`` is resolved relative to the current
            working directory.
    """
    # Sorted so the file order -- and with it the hue order, the palette
    # assignment and the legend labels -- is deterministic; os.listdir()
    # returns entries in arbitrary order.
    pkl_files = sorted(f for f in os.listdir(results_dir) if f.endswith('.pkl'))
    # NOTE: unpickling can execute arbitrary code; only load trusted output.
    dfs = [pd.read_pickle(os.path.join(results_dir, f)) for f in pkl_files]
    # Derived from the same filtered list as `dfs`, so each label lines up
    # with its frame.
    names = [f.replace('-results.pkl', '') for f in pkl_files]
    names.append('ADDA')

    # ADDA baseline: a flat line at a hard-coded top-1 accuracy (presumably a
    # reported reference figure -- confirm its source). Top-5 is a 0
    # placeholder and is not plotted here.
    num_epochs = 400
    adda_top1 = 57.108
    columns = ['Epoch', 'Top1-Accuracy', 'Top5-Accuracy', 'Algorithm']
    domain_apt_df = pd.DataFrame(
        {
            'Epoch': list(range(num_epochs)),
            'Top1-Accuracy': [adda_top1] * num_epochs,
            'Top5-Accuracy': [0] * num_epochs,
            'Algorithm': ['ADDA'] * num_epochs,
        },
        columns=columns,
    )
    # One concat call also avoids pandas' "No objects to concatenate" error
    # when the directory contains no .pkl files.
    df = pd.concat(dfs + [domain_apt_df], ignore_index=True)

    sns.set(style='whitegrid')
    # x/y are passed as keywords: positional use was deprecated in seaborn
    # 0.11, and since 0.12 they are keyword-only.
    ax = sns.lineplot(
        x='Epoch',
        y='Top1-Accuracy',
        hue='Algorithm',
        data=df,
        palette=get_palette(names),
        lw=3,
        alpha=0.9,
    )
    ax.set_ylabel('Test accuracy')
    set_labels(ax, names, add_legend=True, ncols=2)
    plt.tight_layout()
    plt.show()

def get_palette(unique_names, colors=None):
    """Map each name to a hex colour for use as a seaborn ``palette`` dict.

    Args:
        unique_names: Names (hue levels) to colour, in order.
        colors: Optional sequence of colours. Defaults to the script's
            seven-colour house palette.

    Returns:
        dict mapping each name to a colour. Colours are assigned in order and
        cycle when there are more names than colours, so every name gets an
        entry. Seaborn rejects a palette dict with missing keys.

    Raises:
        ValueError: if names are given but ``colors`` is empty.
    """
    if colors is None:
        colors = [
            '#F97F1F', '#A83271', '#EA3546', '#662E9B', '#5187BB', '#43BCCD', '#1BA06D'
        ]
    colors = list(colors)
    names = list(unique_names)
    if names and not colors:
        raise ValueError('colors must contain at least one colour')
    return {name: colors[i % len(colors)] for i, name in enumerate(names)}

def set_labels(ax, names, add_legend=True, ncols=3):
    """Restyle an axes' labels, integer tick labels, title and optional legend.

    Axis labels and the title are redrawn with the Palatino font at size 20.
    Tick labels use size 15 and show each tick value as an int via `_to_int`.

    Args:
        ax: matplotlib Axes to modify in place.
        names: Legend labels, applied in the order of the axes' legend handles.
        add_legend: When truthy, (re)build the legend in the lower-right corner.
        ncols: Number of legend columns.
    """
    label_kwargs = {'fontproperties': _get_font(), 'fontsize': 20}
    tick_kwargs = {'fontproperties': _get_font(), 'fontsize': 15}

    ax.set_xlabel(ax.get_xlabel(), **label_kwargs)
    ax.set_ylabel(ax.get_ylabel(), **label_kwargs)

    # Pin the current tick positions before relabelling them. With an
    # automatic locator the fixed labels can land on the wrong ticks if the
    # ticks are recomputed, and matplotlib >= 3.3 warns about it. set_*ticks()
    # may widen the view to include ticks just outside it, so the original
    # limits are restored afterwards.
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    xticks, yticks = ax.get_xticks(), ax.get_yticks()
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticklabels(_to_int(xticks), **tick_kwargs)
    ax.set_yticklabels(_to_int(yticks), **tick_kwargs)

    ax.set_title(ax.get_title(), **label_kwargs)

    if add_legend:
        ax.legend(labels=names, prop=_get_font(), ncol=ncols, loc='lower right')
        plt.setp(ax.get_legend().get_texts(), fontsize=10)

def _get_font():
    """Return a FontProperties for the Palatino font.

    Prefers a ``Palatino.ttf`` placed in matplotlib's bundled ``fonts/ttf``
    data directory. A stock matplotlib install does not include that file,
    and a ``FontProperties`` pointing at a missing path fails when text is
    drawn. In that case the font is looked up by family name instead, and
    matplotlib falls back to its default font, with a warning, if Palatino
    is not installed.

    Returns:
        matplotlib.font_manager.FontProperties
    """
    # Local import: only get_data_path() is needed. The old rcParams['datapath']
    # key was deprecated in matplotlib 3.2 and later removed.
    import matplotlib

    fpath = os.path.join(matplotlib.get_data_path(), 'fonts', 'ttf', 'Palatino.ttf')
    if os.path.isfile(fpath):
        return fm.FontProperties(fname=fpath)
    return fm.FontProperties(family='Palatino')

def _to_int(ls):
    return [int(k) for k in ls]

if __name__ == '__main__':
    # Only build and show the plot when executed as a script, not on import.
    main()