import pandas as pd
import numpy as np
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import os
import copy
import argparse

def load_df(file_name):
    """Load a pickled results file into a per-example DataFrame.

    The pickle must contain a dict mapping an example id to a dict of fields
    that includes ``is_correct``, ``dep_length_end`` and ``seq_length``.

    Args:
        file_name: Path of the pickle file to read.

    Returns:
        A DataFrame with one row per example. On top of the stored fields,
        each row gets an ``id`` column (the dict key) and a ``correct``
        column (1 if ``is_correct`` is truthy, else 0); ``dep_length_end``
        and ``seq_length`` are cast to ``int``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager guarantees the handle is closed after reading.
    with open(file_name, 'rb') as handle:
        data = pickle.load(handle)

    records = []
    for example_id, fields in data.items():
        fields['id'] = example_id
        fields['correct'] = 1 if fields['is_correct'] else 0
        fields['dep_length_end'] = int(fields['dep_length_end'])
        fields['seq_length'] = int(fields['seq_length'])
        records.append(fields)
    return pd.DataFrame(records)

def get_model_stats(file_name, info, model_name):
    """Summarise one pickled results file as a flat dict of model statistics.

    Args:
        file_name: Path of a pickle holding a dict of per-example dicts, each
            with an ``is_correct`` and an ``ood`` entry.
        info: Dict of hyper-parameters for the model; it is copied, not
            modified.
        model_name: Display name stored under ``model_name``.

    Returns:
        A copy of ``info`` extended with ``accuracy``, ``out_dist_acc`` and
        ``in_dist_acc`` (fractions of correct examples overall, among
        ``ood`` examples and among non-``ood`` examples), ``model_name``,
        and defaults for missing ``attention`` (False), ``hidden_size``,
        ``ff_dim`` and ``heads`` (0). An accuracy computed over an empty
        subset is NaN instead of raising ``ZeroDivisionError``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(file_name, 'rb') as handle:
        data = pickle.load(handle)

    def _fraction_correct(outcomes):
        # An empty subset has no defined accuracy; report NaN.
        if not outcomes:
            return float('nan')
        return sum(outcomes) / len(outcomes)

    examples = list(data.values())
    model_dict = info.copy()
    model_dict['accuracy'] = _fraction_correct(
        [v['is_correct'] for v in examples])
    model_dict['out_dist_acc'] = _fraction_correct(
        [v['is_correct'] for v in examples if v['ood']])
    model_dict['in_dist_acc'] = _fraction_correct(
        [v['is_correct'] for v in examples if not v['ood']])
    model_dict.setdefault('attention', False)
    for name in ['hidden_size', 'ff_dim', 'heads']:
        model_dict.setdefault(name, 0)
    model_dict['model_name'] = model_name
    return model_dict

def best_params(csv_file):
    """Pick the best hyper-parameter configuration per task, mode and model.

    The CSV must provide the columns ``task``, ``mode``, ``model_type``,
    ``attention``, ``accuracy`` and the hyper-parameter columns listed in
    ``param_columns`` below. Missing values are replaced by 0.0 before any
    comparison.

    Rows are grouped by hyper-parameter configuration within each
    (task, mode, model type, attention) combination. The winner is the
    configuration with the highest mean accuracy; ties are broken by the
    lowest (population) variance, then by the smallest ``parameters`` value,
    then by first appearance in the CSV.

    Args:
        csv_file: Path of the CSV with one row per trained model.

    Returns:
        ``best_val[task][mode][model_key] = {'params': tuple, 'score': mean}``
        where ``model_key`` is the model type, suffixed with
        ``" w/ attention"`` for the attention variants of ``lstm``/``lru``.
        Combinations without any row get no entry.
    """
    results = pd.read_csv(csv_file).fillna(0.0)
    # Only the recurrent models are looked up with and without attention.
    attention_options = {
        "lstm": [0.0, 1.0],
        "lru": [0.0, 1.0],
        "transformer": [0.0],
    }
    param_columns = ['parameters', 'embed_size', 'num_layers', 'lr',
                     'batch_size', 'hidden_size', 'attention', 'num_heads',
                     'ff_dim']

    best_val = {}
    for task in set(results.task):
        best_val[task] = {}
        for mode in set(results['mode']):
            best_val[task][mode] = {}
            for model_type, att_options in attention_options.items():
                for att in att_options:
                    # The mode must be part of the filter: without it every
                    # mode entry would be scored on rows from all modes.
                    limit_df = results.loc[
                        (results['model_type'] == model_type)
                        & (results['task'] == task)
                        & (results['mode'] == mode)
                        & (results['attention'] == att)]
                    if len(limit_df) == 0:
                        continue

                    # Collect the accuracies of every configuration, keeping
                    # the order in which configurations first appear.
                    scores_by_config = {}
                    for _, row in limit_df.iterrows():
                        param_tuple = tuple(row[col] for col in param_columns)
                        scores_by_config.setdefault(param_tuple, []).append(
                            row['accuracy'])

                    summary = {}
                    for param_tuple, scores in scores_by_config.items():
                        mean_score = sum(scores) / len(scores)
                        variance = sum(
                            (s - mean_score) ** 2 for s in scores
                        ) / len(scores)
                        summary[param_tuple] = (mean_score, variance)

                    # min() returns the first minimal key, so remaining ties
                    # resolve to the configuration seen first in the CSV.
                    winner = min(
                        summary,
                        key=lambda p: (-summary[p][0], summary[p][1], p[0]))

                    w_att = " w/ attention" if att == 1.0 else ""
                    best_val[task][mode][model_type + w_att] = {
                        'params': winner,
                        'score': summary[winner][0]}
    return best_val

def get_all_from_folder(folder_name, best_params):
    """Collect statistics for every results file in a folder.

    Each file name is decoded with ``parse_model_params``. Its
    hyper-parameters are compared (tolerance 0.001 per field) with the best
    configuration recorded for the same task, mode and model in
    ``best_params``; only matching files contribute per-example data.

    Args:
        folder_name: Directory whose entries are all pickled results files.
        best_params: Nested dict ``[task][mode][model key]['params']`` as
            returned by the module-level ``best_params`` function; the model
            key is the lower-cased display name.

    Returns:
        A tuple ``(task_dict, model_df, best_model_df)``:
        ``task_dict[task][mode][model_name]`` is the concatenated per-example
        DataFrame of the best-configuration files, ``model_df`` has one
        statistics row per file, and ``best_model_df`` has one row per
        best-configuration file.

    Side effects:
        Every processed file is deleted from ``folder_name``.
    """
    task_dict = {}
    model_level_df_dict = []
    best_model_df_dict = []
    for filename in os.listdir(folder_name):
        file_path = os.path.join(folder_name, filename)
        info = parse_model_params(filename)
        task = info['task']
        mode = info['mode']
        model_name = info['model']
        # Optional fields fall back to fixed defaults.
        # NOTE(review): these defaults presumably have to match what the
        # results CSV records for models lacking the field -- confirm.
        attention = int(info.get('attention', 0))
        param_tuple = (info['parameters'], info['embed'],
                       info['layers'], info['lr'], info['batch_size'],
                       info.get('hidden_size', 0), attention,
                       info.get('heads', 1), info.get('ff_dim', 0))
        if model_name in ["lstm", "lru"]:
            model_name = model_name.upper()
            # Use the defaulted flag so a file name without an attention
            # field is treated as "no attention" instead of raising KeyError.
            if attention:
                model_name += " w/ attention"
        else:
            model_name = model_name.capitalize()

        best_params_tuple = best_params[task][mode][model_name.lower()]['params']
        is_best = not any(abs(float(p1) - float(p2)) > 0.001
                          for p1, p2 in zip(param_tuple, best_params_tuple))

        # Read the statistics once and reuse them for both result tables.
        stats = get_model_stats(file_path, info, model_name)
        if is_best:
            per_model = task_dict.setdefault(task, {}).setdefault(mode, {})
            df = load_df(file_path)
            if model_name not in per_model:
                per_model[model_name] = df
            else:
                per_model[model_name] = pd.concat(
                    [per_model[model_name], df], axis=0)
            best_model_df_dict.append(stats.copy())
        model_level_df_dict.append(stats)
        # NOTE(review): this permanently deletes the results file after it
        # has been read; run on a copy if the raw results must be kept.
        os.remove(file_path)
    model_df = pd.DataFrame(model_level_df_dict)
    best_model_df = pd.DataFrame(best_model_df_dict)
    return task_dict, model_df, best_model_df

def parse_model_params(filename):
    """Decode the hyper-parameters encoded in a results file name.

    The last two dot-separated components of ``filename`` are dropped and
    the remainder is split on ``"___"``. Each piece is read as
    ``<name>_<value>``, split at its last underscore:

    * a piece without a name (no underscore) is stored under ``'model'``;
    * ``no_attention`` / ``with_attention`` become ``'attention'`` set to
      False / True (any other prefix keeps the string ``'attention'``);
    * ``<task>_distract`` / ``<task>_pad`` set ``'task'`` to the prefix and
      ``'mode'`` to the suffix;
    * purely numeric values are converted to ``int``; everything else stays
      a string.

    Returns:
        Dict mapping parameter names to the decoded values.
    """
    dot_parts = filename.split(".")
    stem = ".".join(dot_parts[:-2])

    parsed = {}
    for chunk in stem.split("___"):
        # Everything before the last underscore is the name, the rest the
        # value; a chunk without underscore yields an empty name.
        key, _, raw = chunk.rpartition("_")
        value = int(raw) if raw.isnumeric() else raw

        if not key:
            key = 'model'
        elif value == 'attention':
            if key == 'no':
                value = False
            elif key == 'with':
                value = True
            key = 'attention'
        elif value in ('distract', 'pad'):
            parsed['task'] = key
            key = 'mode'
        parsed[key] = value
    return parsed

def plot_tasks_restrict(tasks, x, xlabel, c, plot_name, restrict_field, 
                        threshold_value, axline=0, plot_type="point", 
                        width=50):
    """Plot per-example accuracy against ``x`` for every task and save it.

    One subplot row is drawn per base task (task name without its trailing
    ``short``/``long``), with the ``short`` variant in the left column and
    the ``long`` variant in the right one. Only the ``"distract"`` mode is
    plotted.

    Args:
        tasks: Nested dict ``tasks[task][mode][model label] -> DataFrame``
            holding at least the columns ``x``, ``restrict_field`` and
            ``correct``.
        x: Column plotted on the x axis.
        xlabel: Label of the x axis.
        c: Sequence of colours; each model label maps to a fixed position.
        plot_name: Path the figure is saved to.
        restrict_field: Column used to filter rows.
        threshold_value: Rows with ``restrict_field`` above it are dropped.
        axline: If positive, x position of a dashed vertical marker line.
        plot_type: ``"point"`` for ``sns.pointplot`` or ``"reg"`` for
            ``sns.regplot``; any other value draws no data.
        width: Figure width in inches.
    """
    # Strip only a trailing length marker; sort so the row order is stable
    # from run to run.
    simp_tasks = set()
    for task in tasks:
        for suffix in ('long', 'short'):
            if task.endswith(suffix):
                task = task[:-len(suffix)]
                break
        simp_tasks.add(task)
    simp_tasks = sorted(simp_tasks)
    num_tasks = len(simp_tasks)
    # squeeze=False keeps ``axes`` two-dimensional even for a single task.
    fig, axes = plt.subplots(num_tasks, 2, sharey=True, squeeze=False,
                             figsize=(width,12 * num_tasks))
    index = {'LSTM':0, 'Transformer':1, 'LRU':3,'LSTM w/ attention':2, 
             "LRU w/ attention":4}
    for i, task in enumerate(simp_tasks):
        for j, length in enumerate(['short','long']):
            # A task may lack one length variant or the "distract" mode;
            # its panel is then left without data.
            model_dict = tasks.get(task+length, {}).get("distract", {})
            ax = axes[i,j]
            if axline > 0:
                ax.axvline(axline, 0,1.1, color='gray', linestyle='--', 
                        label="Max length in training")
            for label, df in model_dict.items():
                if df is None:
                    continue
                df = df.loc[df[restrict_field]<=threshold_value]
                color_idx = index[label]
                if plot_type=="point":
                    sns.pointplot(ax=ax, data=df, 
                                x=x, y="correct", estimator=np.mean,
                                color=c[color_idx], label=label)
                elif plot_type =="reg":
                    sns.regplot(ax=ax, data=df, 
                            x=x, y="correct", x_estimator=np.mean,
                            color=c[color_idx], label=label, x_bins=32)

            ax.set_ylim(-0.1, 1.1)
            ax.set_xlabel(xlabel)
            # The y axis is shared, so only the left column is labelled.
            ax.set_ylabel("Accuracy" if j == 0 else '')
            ax.legend()
            ax.set_title(
                "Task: {}".format(task+length))
    fig.subplots_adjust(hspace=0.25,wspace=0.05)
    fig.savefig(plot_name, dpi=300,bbox_inches = 'tight')

def plot_models(model_df, x, y, xlabel, ylabel, c, plot_name):
    """Plot a model-level metric against ``x`` for every task and save it.

    One subplot row is drawn per base task (task name without its trailing
    ``short``/``long``), with the ``short`` variant on the left and the
    ``long`` variant on the right. Models are distinguished by hue.

    Args:
        model_df: DataFrame with one row per model and at least the columns
            ``task``, ``model_name``, ``x`` and ``y``.
        x: Column plotted on the x axis.
        y: Column plotted on the y axis (averaged per x value).
        xlabel: Label of the x axis.
        ylabel: Label of the y axis (left column only).
        c: Sequence of colours; each model name maps to a fixed position.
        plot_name: Path the figure is saved to.
    """
    # Strip only a trailing length marker; sort so the row order is stable
    # from run to run.
    simp_tasks = set()
    for task in set(model_df.task):
        for suffix in ('long', 'short'):
            if task.endswith(suffix):
                task = task[:-len(suffix)]
                break
        simp_tasks.add(task)
    simp_tasks = sorted(simp_tasks)
    num_tasks = len(simp_tasks)
    # squeeze=False keeps ``axes`` two-dimensional even for a single task.
    fig, axes = plt.subplots(num_tasks, 2, figsize=(18,8 * num_tasks), 
        sharey=True, squeeze=False)
    index = {'LSTM':0, 'Transformer':1, 'LRU':3,'LSTM w/ attention':2, 
             "LRU w/ attention":4}
    palette = {m:c[idx] for m, idx in index.items()}
    for i, task in enumerate(simp_tasks):
        for j, length in enumerate(["short", "long"]):
            ax = axes[i, j]

            task_mode_df = model_df.loc[model_df['task']==task+length]
            if len(task_mode_df) > 0:
                sns.pointplot(ax=ax, data=task_mode_df, x=x, y=y, 
                    estimator=np.mean, hue='model_name', palette=palette)

            ax.set_xlabel(xlabel)
            # The y axis is shared, so only the left column is labelled.
            ax.set_ylabel(ylabel if j == 0 else '')
            ax.set_ylim(-0.1, 1.1)
            # Rotate the tick labels without re-assigning their texts.
            ax.tick_params(axis='x', labelrotation=20)
            ax.legend()
            ax.set_title(
                "Task: {}".format(task+length))
    fig.subplots_adjust(hspace=0.3,wspace=0.05)
    fig.savefig(plot_name, dpi=300,bbox_inches = 'tight')

def plot_in_out(model_df, c, plot_name):
    """Plot in- versus out-of-distribution accuracy per model and save it.

    One subplot row is drawn per base task (task name without its trailing
    ``short``/``long``), with the ``short`` variant on the left and the
    ``long`` variant on the right. In-distribution accuracy uses circle
    markers, out-of-distribution accuracy uses star markers.

    Args:
        model_df: DataFrame with one row per model and the columns ``task``,
            ``model_name``, ``in_dist_acc`` and ``out_dist_acc``.
        c: Sequence of colours; each model name maps to a fixed position.
        plot_name: Path the figure is saved to.
    """
    # Strip only a trailing length marker; sort so the row order is stable
    # from run to run.
    simp_tasks = set()
    for task in set(model_df.task):
        for suffix in ('long', 'short'):
            if task.endswith(suffix):
                task = task[:-len(suffix)]
                break
        simp_tasks.add(task)
    simp_tasks = sorted(simp_tasks)
    # squeeze=False keeps ``axes`` two-dimensional even for a single task.
    fig, axes = plt.subplots(len(simp_tasks), 2, figsize=(20,
        10 * len(simp_tasks)), sharey=True, squeeze=False)
    
    index = {'LSTM':0, 'Transformer':1, 'LRU':3,'LSTM w/ attention':2, 
             "LRU w/ attention":4}
    palette = {m:c[idx] for m, idx in index.items()}
    order=["LSTM", "LSTM w/ attention", "Transformer", "LRU","LRU w/ attention"]
    for i, task in enumerate(simp_tasks):
        for j, length in enumerate(["short", "long"]):
            ax = axes[i,j]
            df = model_df.loc[model_df['task'] == task+length]
            # A task may lack one length variant; leave that panel empty.
            if len(df) > 0:
                sns.pointplot(ax=ax,data=df, x='model_name', y='in_dist_acc', 
                              estimator=np.mean, palette=palette, hue='model_name',
                              markers='o', label='In-Distribution Accuracy',
                              join=False,dodge=True, scale=2, order=order)
                sns.pointplot(ax=ax, data=df, x='model_name', y='out_dist_acc',
                              estimator=np.mean, palette=palette, hue='model_name',
                              markers='*', label='Out-of-Distribution Accuracy',
                              join=False, dodge=True, scale=2, order=order)
            ax.set_xlabel('')
            # The y axis is shared, so only the left column is labelled.
            ax.set_ylabel('Accuracy' if j == 0 else '')
            handles = ax.get_legend_handles_labels()[0]
            # Rotate the tick labels without re-assigning their texts.
            ax.tick_params(axis='x', labelrotation=20)
            ax.set_ylim(-0.1, 1.1)
            if j == 0 and handles:
                # Reuse the first and last plotted handles as gray marker
                # samples so the legend shows marker shape, not model colour.
                in_dist = copy.copy(handles[0])
                out_dist = copy.copy(handles[-1])
                in_dist.set_color("gray")
                out_dist.set_color("gray")
                ax.legend((in_dist, out_dist), ['In-Distribution Accuracy', 
                    'Out-of-Distribution Accuracy'], loc='lower left')
            else:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
            ax.set_title("{}".format(task+length))

    fig.savefig(plot_name, dpi=300,bbox_inches = 'tight')

def _heat_map_grid(model_df, max_dep_len=64, max_seq_len=32):
    """Build the accuracy grid indexed as ``[dep_length_end, seq_length]``.

    Cells with data hold the mean of ``correct``. Cells without data are
    filled with the average of the cell to the left (previous sequence
    length) and the cell above (previous dependency length). Column 0 is
    never written and stays at 1.0.
    """
    grid = np.ones((max_dep_len + 1, max_seq_len + 1))
    # One grouped pass replaces a full boolean-mask scan per grid cell.
    cell_means = model_df.groupby(
        ['seq_length', 'dep_length_end'])['correct'].mean().to_dict()
    for seq_len in range(1, max_seq_len + 1):
        for dep_len in range(max_dep_len + 1):
            if (seq_len, dep_len) in cell_means:
                grid[dep_len, seq_len] = cell_means[(seq_len, dep_len)]
            else:
                # For dep_len == 0 the index -1 wraps to the last row, which
                # has not been written for this column yet and so is still
                # 1.0, the same value as the untouched column 0.
                grid[dep_len, seq_len] = 0.5 * (
                    grid[dep_len, seq_len - 1] + grid[dep_len - 1, seq_len])
    return grid


def heat_map(model_df, ax, cbar):
    """Draw a heat map of accuracy by dependency and sequence length.

    Args:
        model_df: Per-example DataFrame with the columns ``seq_length``,
            ``dep_length_end`` and ``correct``.
        ax: Matplotlib axes to draw on.
        cbar: Whether to draw the colour bar.
    """
    to_plot = _heat_map_grid(model_df)
    sns.heatmap(to_plot, ax=ax, cbar=bool(cbar))

def all_heat_maps(data):
    """Draw one accuracy heat map per model, side by side.

    Args:
        data: Dict mapping a model name to its per-example DataFrame, in the
            form accepted by ``heat_map``.

    Returns:
        The created matplotlib figure, so the caller can save or show it.
    """
    # squeeze=False keeps ``axes`` indexable even when there is one model.
    fig, axes = plt.subplots(1, len(data), figsize=(len(data)*10,15),
                             sharey=True, layout='constrained',
                             squeeze=False)
    axes = axes[0]
    last = len(data) - 1
    for i, (model_name, df) in enumerate(data.items()):
        # Only the right-most panel carries the colour bar.
        heat_map(df, axes[i], i == last)
        axes[i].set_title(model_name)
        axes[i].set_xlabel('Target Sequence Length')

    axes[0].set_ylabel('Dependency Length')
    return fig

# Command line: folder with the pickled per-model results and the CSV that
# summarises every trained model.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--result_folder', type=str, default='',
    help="Folder containing results")
parser.add_argument(
    '--results_csv', type=str, default='',
    help="CSV with all model results")
args = parser.parse_args()

# Global seaborn look shared by every figure below.
sns.set(rc={'figure.figsize': (20, 15)})
sns.set_style('whitegrid')
sns.set_context("paper", font_scale=2)

# Select the best configuration per task/mode/model, then gather the
# statistics of every results file.
tasks, model_df, best_model_df = get_all_from_folder(
    args.result_folder, best_params(args.results_csv))

# Every figure uses the same pastel palette.
pastel = sns.color_palette("pastel")

plot_in_out(best_model_df, pastel, "in_out.png")
plot_models(model_df, 'parameters', 'accuracy', 'Number of Parameters',
            'Accuracy', pastel, "params.png")
plot_tasks_restrict(tasks, 'seq_length', 'Target Sequence Length', pastel,
                    'seq_length.png', restrict_field='dep_length_end',
                    threshold_value=32, axline=15, plot_type="point",
                    width=32)
plot_tasks_restrict(tasks, 'dep_length_end', 'Dependency Length', pastel,
                    'dep_length.png', restrict_field='seq_length',
                    threshold_value=16, axline=32, plot_type="reg",
                    width=30)