import pickle as p
from scipy import stats
import numpy as np
import json
import pandas as pd
from copy import deepcopy
from utils import create_data, load_search_spaces
# from cal_rank import get_scoreboard, get_rankboard
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

# Proxy/metric names in feature order: evaluate() rewrites the GP program's
# `X{i}` placeholders as METRICS[i], so this order must not change.
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm',
    'grasp', 'jacov', 'l2_norm', 'nwot',
    'params', 'plain', 'snip', 'synflow',
    'zen', 'swap', 'meco_opt', 'zico',
    'val_accuracy',
]

# Benchmark tasks ("<search space>-<dataset>") in the row order of the scoreboard.
TASKS = [
    'nb101-cifar10',
    'nb201-cifar10', 'nb201-cifar100', 'nb201-ImageNet16-120',
    'nb301-cifar10',
] + [
    f'tnb101-{space}-{task}'
    for space in ('micro', 'macro')
    for task in ('class_scene', 'class_object', 'jigsaw', 'autoencoder')
]

# Methods compared in the scoreboard / heatmap ('ours' is the evolved GP model).
COMPETITORS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm',
    'grasp', 'jacov', 'l2_norm', 'nwot',
    'params', 'plain', 'snip', 'synflow',
    'zen', 'zico', 'meco_opt', 'swap',
    'ours',
]

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and NumPy scalars.

    Arrays become (nested) lists via ``tolist()``. NumPy scalars such as
    ``np.int64`` or ``np.bool_`` -- which, unlike ``np.float64``, are not
    subclasses of a Python builtin and are therefore rejected by the stock
    encoder -- become the matching Python scalar via ``item()``. Any other
    object is delegated to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


def get_scoreboard(database, tasks=None, competitors=None):
    """Tabulate each competitor's score on each task.

    Args:
        database: nested mapping ``database[task][competitor] -> score``.
        tasks: row order; defaults to the module-level ``TASKS``.
        competitors: column order; defaults to the module-level ``COMPETITORS``.

    Returns:
        ``pd.DataFrame`` indexed by task with one column per competitor.
        Scores are rounded to 2 decimals. Missing entries (absent task or
        competitor, JSON ``null``, NaN) and the value ``-1`` become 0.0, and
        ``-0.0`` produced by rounding is normalised to 0.0.
    """
    tasks = TASKS if tasks is None else list(tasks)
    competitors = COMPETITORS if competitors is None else list(competitors)

    # Missing task/competitor -> NaN (zeroed below), matching create_database's
    # handling instead of raising KeyError. dtype=float also maps None to NaN.
    rows = [[database.get(task, {}).get(metric, np.nan) for metric in competitors]
            for task in tasks]
    scoreboard = np.array(rows, dtype=float).reshape(len(tasks), len(competitors))

    scoreboard = np.round(scoreboard, 2)
    scoreboard[scoreboard == -1] = 0.0  # -1 is treated as a "missing" sentinel
    scoreboard[scoreboard == 0] = 0.0   # also turns -0.0 into +0.0
    scoreboard[np.isnan(scoreboard)] = 0.0
    return pd.DataFrame(scoreboard, index=tasks, columns=competitors)


def get_rankboard(scoreboard):
    """Rank competitors on each task and aggregate the ranks.

    Args:
        scoreboard: ``pd.DataFrame`` with tasks as rows and competitors as
            columns (as returned by ``get_scoreboard``); higher is better.

    Returns:
        ``pd.DataFrame`` with the scoreboard's columns and rows:
          * one row per task: dense rank of each competitor (1 = highest score);
          * ``'Sum'``: the column-wise sum of those ranks;
          * ``'Final Rank'``: dense rank of ``'Sum'`` (1 = smallest sum = best).
    """
    # Negate so that the highest score receives rank 1.
    per_task = stats.rankdata(-scoreboard.values, method='dense', axis=1)
    rank_sum = per_task.sum(axis=0)
    final_rank = stats.rankdata(rank_sum, method='dense')
    table = np.vstack([per_task, rank_sum, final_rank])
    # Label columns from the scoreboard itself so any competitor set is supported.
    return pd.DataFrame(table, index=list(scoreboard.index) + ['Sum', 'Final Rank'],
                        columns=scoreboard.columns)


def plot_heatmap(df, figsize=(16, 14), rotation=0, title='', cmap='viridis_r', savetitle='zcp_corr',
                 cbar=True, square=False, fmt='.2f'):
    """Draw ``df`` as an annotated heatmap and save it as ``{savetitle}.png``.

    The last column of ``df`` is treated as a rank column. It is written as plain
    black text on a transparent cell, and every cell in it is boxed in black.
    All other columns are colour-mapped with ``cmap`` and annotated with ``fmt``.
    In each of those columns, the cell(s) holding the column maximum are outlined
    in red.

    The x tick-label colours and the white vertical separators (at x = 3, 5, 9)
    are hard-coded for the 13-task + rank column layout built by
    ``create_database``. Colour indices beyond the frame's width are skipped.

    Args:
        df: table to plot; index -> y ticks, columns -> x ticks, last column = rank.
        figsize: figure size in inches.
        rotation: rotation of the x tick labels (the rank label stays horizontal).
        title: figure title.
        cmap: colormap for the non-rank cells.
        savetitle: output path without the ``.png`` extension.
        cbar: whether to draw a colour bar for the non-rank cells.
        square: forwarded to ``seaborn.heatmap``.
        fmt: format string for the non-rank annotations.
    """
    fig, ax = plt.subplots(figsize=figsize)
    n_rows, n_cols = df.shape

    # True on the rank (last) column; seaborn hides cells where the mask is True.
    rank_mask = np.zeros(df.shape, dtype=bool)
    rank_mask[:, -1] = True

    # Base layer: colour-mapped, annotated scores. The rank column is hidden so
    # its large values do not stretch the colour scale.
    sns.heatmap(df, mask=rank_mask, cmap=cmap, cbar=cbar, annot=True, fmt=fmt, square=square,
                annot_kws={'size': 24}, ax=ax)
    # Overlay: only the rank column, transparent cells with black integer text.
    sns.heatmap(df, mask=~rank_mask, alpha=0.0, linecolor='k', cbar=False, annot=True, fmt='.0f',
                square=square, annot_kws={'size': 24, 'color': 'k'}, ax=ax)
    ax.set_facecolor('white')

    ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation, fontsize=20, va='top', ha='right',
                       rotation_mode='anchor')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=20)
    xticks = ax.get_xticklabels()
    # Keep the rank column's header upright and centred.
    xticks[-1].set_rotation(0)
    xticks[-1].set_ha('center')
    ax.set(xlabel=None, ylabel=None)

    # White separators between groups of task columns.
    for x in (3, 5, 9):
        ax.axvline(x, color='white', lw=5)
    # Colour task labels by group (column layout of create_database).
    tick_colors = {
        1: 'red', 3: 'red', 4: 'red',
        2: 'green',
        5: 'blue', 6: 'blue', 7: 'blue', 8: 'blue',
        9: 'tab:orange', 10: 'tab:orange', 11: 'tab:orange', 12: 'tab:orange',
    }
    for idx, color in tick_colors.items():
        if idx < len(xticks):
            xticks[idx].set_color(color)

    # Outline the best score(s) of every non-rank column. Positions come from the
    # dataframe itself, not from tick-label text (seaborn may thin tick labels).
    scores = df.values[:, :-1]
    for col in range(scores.shape[1]):
        column = scores[:, col]
        for row in np.flatnonzero(column == column.max()):
            ax.add_patch(Rectangle((col, row), 1, 1, fill=False, edgecolor='red', lw=3))

    # Box every cell of the rank column.
    for row in range(n_rows):
        ax.add_patch(Rectangle((n_cols - 1, row), 1, 1, fill=False, edgecolor='k', lw=2))

    ax.set_title(title, fontsize=24)
    fig.tight_layout()
    fig.savefig('{}.png'.format(savetitle), bbox_inches='tight', dpi=600)
    # The figure is only written to disk; release it so repeated calls do not leak.
    plt.close(fig)

def calculate_rank(X):
    """Score each row of ``X`` by the sum of its min-max-normalised columns.

    Each column is rescaled to [0, 1] across the rows (column minimum -> 0,
    maximum -> 1). The per-row sum is returned, so larger values mean a row is
    closer to the column-wise maxima. A constant column contributes 0 to every
    row instead of producing 0/0 = NaN, which would otherwise poison every sum.

    Args:
        X: 2-D array-like of shape (n_rows, n_columns).

    Returns:
        1-D float ``np.ndarray`` of length n_rows.
    """
    X = np.asarray(X, dtype=float)
    lo_bound = X.min(axis=0)
    span = X.max(axis=0) - lo_bound
    span = np.where(span == 0, 1.0, span)
    return ((X - lo_bound) / span).sum(axis=1)


def create_database():
    """Assemble the metric-by-task correlation table that is plotted as a heatmap.

    Reads ``result/{filename}.json``, where ``filename`` is the module-level
    global set in ``__main__``. The file is indexed as ``data[task][metric]``.

    Returns:
        ``pd.DataFrame`` with:
          * one row per metric (display names);
          * one column per task (abbreviated upper-case labels);
          * a final bold ``Rank`` column holding each metric's ``'Final Rank'``
            from ``get_rankboard``.
        Correlations are rounded to 2 decimals. Missing values, JSON nulls and
        the value ``-1`` are shown as 0.0. Rows are sorted by rank in descending
        order, so rank 1 ends up in the last row.
    """
    # Row order must equal COMPETITORS: the rank vector from get_rankboard is
    # attached to the rows by position below.
    list_metrics = list(COMPETITORS)
    display_names = {
        'epe_nas': 'EPE-NAS', 'fisher': 'Fisher', 'flops': 'FLOPs', 'grad_norm': 'Grad-norm',
        'grasp': 'Grasp', 'jacov': 'Jacov', 'l2_norm': 'L2-norm', 'nwot': 'NWOT',
        'params': 'Params', 'plain': 'Plain', 'snip': 'Snip', 'synflow': 'Synflow',
        'zen': 'Zen', 'zico': 'ZiCo', 'meco_opt': 'MeCo', 'swap': 'SWAP',
        'ours': r'$\bf{SR}$' + '-' + r'$\bf{NAS}$',
    }
    rank_col = r'$\bf{Rank}$'

    with open(f'result/{filename}.json') as fp:
        all_corr = json.load(fp)
    x_labels = ['nb101-cifar10', 'nb201-cifar10', 'nb301-cifar10',
                'nb201-cifar100', 'nb201-ImageNet16-120',
                'tnb101-micro-class_scene', 'tnb101-micro-class_object', 'tnb101-micro-autoencoder', 'tnb101-micro-jigsaw',
                'tnb101-macro-class_scene', 'tnb101-macro-class_object', 'tnb101-macro-autoencoder', 'tnb101-macro-jigsaw']

    rows = []
    for metric in list_metrics:
        row = []
        for ss_dataset in x_labels:
            try:
                row.append(all_corr[ss_dataset][metric])
            except KeyError:
                row.append(np.nan)
        rows.append(row)
    # dtype=float also maps JSON nulls (None) to NaN.
    database = np.array(rows, dtype=float)
    database[database == -1] = 0.0  # -1 is treated as a "missing" sentinel
    database[np.isnan(database)] = 0.0
    database = np.round(database, 2)
    # Rounding tiny negatives yields -0.0, which would render as "-0.00".
    database[database == 0] = 0.0

    def _abbreviate(label):
        """Upper-case a task name and shorten dataset/task words for the axis."""
        label = label.upper()
        for old, new in (('CIFAR10', 'CF10'), ('IMAGENET16-120', 'IMGNT'),
                         ('AUTOENCODER', 'AUTOENC'), ('CLASS_', '')):
            label = label.replace(old, new)
        return label

    scoreboard = get_scoreboard(all_corr)
    rank = get_rankboard(scoreboard).values[-1]
    database = np.concatenate((database, np.asarray(rank, dtype=float)[:, None]), axis=1)
    df = pd.DataFrame(database,
                      columns=[_abbreviate(x) for x in x_labels] + [rank_col],
                      index=[display_names.get(m, m) for m in list_metrics])
    return df.sort_values(by=rank_col, ascending=False)

def evaluate():
    """Select the best GP run and report its rank correlation on every benchmark task.

    Steps:
      1. Load the search spaces (``load_search_spaces('database')``), the
         correlations in ``result/{filename}.json`` and the per-run summaries in
         ``exp/multiple_results.json``.
      2. Score every run's ``'full'`` entry with ``calculate_rank`` and print the
         summary of the highest-scoring run.
      3. Unpickle that run's model and print its program with each ``X{i}``
         placeholder replaced by ``METRICS[i]``.
      4. For each task, correlate ``gp_model.predict(X)`` with ``y`` using the
         ``measure`` global, print the value and store it in
         ``all_corr[task]['ours']``.

    Relies on the module-level globals ``measure`` and ``filename`` set in
    ``__main__``. ``all_corr`` is only updated in memory: the write-back code at
    the end is commented out, so the JSON file on disk is left unchanged.

    Raises:
        NotImplementedError: if ``measure`` is neither ``'kendall'`` nor ``'spearman'``.
    """
    # Fail fast on an unsupported measure instead of after loading all data.
    if measure == 'kendall':
        correlate = stats.kendalltau
    elif measure == 'spearman':
        correlate = stats.spearmanr
    else:
        raise NotImplementedError

    # Search-space key in ``all_ss`` -> (task prefix used in the results JSON, datasets).
    problems = {
        'nb101': ('nb101', ['cifar10']),
        'nb201': ('nb201', ['cifar10', 'cifar100', 'ImageNet16-120']),
        'nb301': ('nb301', ['cifar10']),
        'transnb101_micro': ('tnb101-micro', ['class_scene', 'class_object', 'jigsaw', 'autoencoder']),
        'transnb101_macro': ('tnb101-macro', ['class_scene', 'class_object', 'jigsaw', 'autoencoder']),
    }

    all_ss = load_search_spaces('database')
    with open(f'result/{filename}.json') as fp:
        all_corr = json.load(fp)
    with open('exp/multiple_results.json') as fp:
        res = json.load(fp)

    # Keep the run ids so the winner is looked up by its real key, not by its
    # position + 1.
    run_ids = list(res)
    scores = calculate_rank(np.array([res[rid]['full'] for rid in run_ids]))
    best_rid = run_ids[int(np.argmax(scores))]
    print(res[best_rid])

    # NOTE(review): assumes the run id is also the suffix of its model file - TODO confirm.
    # pickle.load can execute arbitrary code; only load locally produced model files.
    with open(f'exp/GP-Model_multiple_run{best_rid}.p', 'rb') as fp:
        gp_model = p.load(fp)

    model = str(gp_model.our_program['program'])
    # Replace from the highest index down so that 'X1' does not clobber the
    # prefix of 'X10'..'X16'.
    for i in range(len(METRICS) - 1, -1, -1):
        model = model.replace(f'X{i}', METRICS[i])
    print('+ Model:', model)

    for ss, (prefix, datasets) in problems.items():
        for dataset in datasets:
            X, y = create_data(all_ss[ss], dataset)
            y_pred = gp_model.predict(X)
            corr = correlate(y, y_pred)[0]

            problem = f'{prefix}-{dataset}'
            all_corr.setdefault(problem, {})['ours'] = corr
            print(f'{problem}: {corr:.2f}')

    # with open(f'result/{filename}.json', 'w') as fp:
    #     json.dump(all_corr, fp, indent=4, cls=NumpyEncoder)
    #     print('- Results are saved on result/')
    #     fp.close()

def visualize():
    """Build the correlation table and save it as a heatmap to ``fig/{filename}.png``.

    Uses the module-level globals ``measure`` (for the title) and ``filename``,
    both set in ``__main__``.
    """
    from pathlib import Path

    # savefig does not create parent directories; make sure fig/ exists.
    Path('fig').mkdir(exist_ok=True)
    df = create_database()
    plot_heatmap(df, figsize=(16, 10), rotation=20, title=f'{measure.upper()}', fmt='.2f', cbar=False,
                 savetitle=f'fig/{filename}')
    print('- Figures are saved on fig/')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Problem setup: which rank-correlation measure to compute and plot.
    parser.add_argument(
        '--measure',
        type=str,
        default='kendall',
        choices=['kendall', 'spearman'],
        help='the rank correlation measure',
    )
    args = parser.parse_args()

    # evaluate(), create_database() and visualize() read these module-level globals.
    measure = args.measure
    filename = '{}-correlation'.format(measure)

    evaluate()
    visualize()
