import pickle as p

from scipy import stats
import numpy as np
import json
import pandas as pd
from copy import deepcopy
from utils import create_data, load_search_spaces


# Metric names in column order. 'val_accuracy' (last) is the target value,
# not a feature. The names are also used to translate the X{i} placeholders
# of printed GP programs back into readable metric names.
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'l2_norm',
    'nwot', 'params', 'plain', 'snip', 'synflow', 'zen',
    'swap', 'meco_opt', 'zico',
    'val_accuracy',
]

# Reduced metric list: METRICS without 'swap', 'meco_opt' and 'zico'
# (order preserved). Used with create_data_clone for the D- model inputs.
METRICS_v0 = [name for name in METRICS if name not in ('swap', 'meco_opt', 'zico')]

def create_data_clone(ss, dataset, list_metrics, missing_value=-9999999):
    """Build a feature matrix ``X`` and target vector ``y`` for one dataset.

    Args:
        ss: nested mapping indexed as ``ss[dataset][arch][metric]``. Feature
            metrics are read from their ``'score'`` key; ``'val_accuracy'`` is
            read directly.
        dataset: key selecting the dataset inside ``ss``.
        list_metrics: metric names in column order. ``'val_accuracy'`` is
            appended to ``y``; every other name becomes one column of ``X``.
        missing_value: placeholder stored in ``X`` when an architecture has no
            score for a metric (defaults to the historic ``-9999999``).

    Returns:
        ``(X, y)`` as numpy arrays. ``X`` has one row per architecture, in the
        iteration order of ``ss[dataset]``; ``y`` has one entry per
        ``'val_accuracy'`` occurrence in ``list_metrics`` per architecture.

    Raises:
        KeyError: if ``'val_accuracy'`` is requested but an architecture lacks it.
    """
    archs = ss[dataset]
    X = []
    y = []
    for arch_info in archs.values():
        row = []
        for metric in list_metrics:
            if metric == 'val_accuracy':
                y.append(arch_info[metric])
                continue
            try:
                row.append(arch_info[metric]['score'])
            except KeyError:
                # Metric (or its score) not recorded for this architecture.
                row.append(missing_value)
        X.append(row)
    return np.array(X), np.array(y)

def calculate_rank(X):
    """Return, per row of ``X``, the sum of its min-max normalised columns.

    Each column is rescaled to ``[0, 1]`` across rows, and the rescaled values
    are summed along each row. A higher value means the row is better across
    all columns (used to pick the best experiment run).

    A column whose values are all equal carries no ranking information.
    Rescaling it would divide 0 by 0 and produce NaN in every row, which
    would make ``np.argmax`` return 0 regardless of the other columns. Such a
    column therefore contributes 0.

    Args:
        X: 2-D array-like of shape ``(n_rows, n_columns)``.

    Returns:
        1-D float ndarray of length ``n_rows``.
    """
    X = np.asarray(X, dtype=float)
    lo_bound = X.min(axis=0)
    span = X.max(axis=0) - lo_bound
    # Avoid 0/0 for constant columns: their numerator is 0, so dividing by 1 yields 0.
    safe_span = np.where(span > 0, span, 1.0)
    return ((X - lo_bound) / safe_span).sum(axis=1)


def _rank_correlation(y, y_pred, measure):
    """Return the rank correlation coefficient between ``y`` and ``y_pred``.

    ``measure`` is ``'kendall'`` (Kendall's tau) or ``'spearman'``
    (Spearman's rho); anything else raises ``NotImplementedError``.
    """
    if measure == 'kendall':
        return stats.kendalltau(y, y_pred)[0]
    if measure == 'spearman':
        return stats.spearmanr(y, y_pred)[0]
    raise NotImplementedError(measure)


def _program_with_metric_names(program, metric_names):
    """Replace every ``X{i}`` placeholder in ``program`` with ``metric_names[i]``."""
    # Substitute from the highest index down. Otherwise 'X1' would rewrite the
    # prefix of 'X10'..'X19' before those placeholders are handled.
    for i in reversed(range(len(metric_names))):
        program = program.replace(f'X{i}', metric_names[i])
    return program


def _select_best_multiple_run(tag, metric_names):
    """Load the best GP model of the ``multiple_{tag}`` experiment.

    The best run maximises ``calculate_rank`` over the ``'full'`` entries of
    ``exp/multiple_{tag}_results.json``. Prints that run's results and the
    model's program (with ``X{i}`` mapped through ``metric_names``), and
    returns the unpickled model.
    """
    with open(f'exp/multiple_{tag}_results.json') as f:
        res = json.load(f)

    run_ids = list(res)
    scores = calculate_rank(np.array([res[rid]['full'] for rid in run_ids]))
    # Use the scored key itself rather than ``argmax + 1``, so the printed
    # results and the loaded pickle always belong to the row that won.
    best_rid = run_ids[int(np.argmax(scores))]
    print(res[best_rid])

    # pickle can execute arbitrary code: only load trusted, locally produced files.
    with open(f'exp/GP-Model_multiple_{tag}_run{best_rid}.p', 'rb') as f:
        gp_model = p.load(f)

    model = _program_with_metric_names(str(gp_model.our_program['program']), metric_names)
    print(f'+ Model (Dataset {tag}):', model, '\n')
    return gp_model


def run_A1(measure='kendall'):
    """Compare the D-, D+ and D GP models on a set of benchmark problems.

    Selects and prints the best runs of the ``multiple_D+`` and
    ``multiple_D-`` experiments, loads the D model from run 12, and prints a
    table (rows: problems, columns: ``D-``, ``D+``, ``D``). Each cell is the
    correlation between a model's predictions and the targets from
    ``create_data``, rounded to 2 decimals.

    Args:
        measure: ``'kendall'`` or ``'spearman'``.

    Raises:
        NotImplementedError: for any other ``measure`` (checked before any I/O).
    """
    if measure not in ('kendall', 'spearman'):
        raise NotImplementedError(measure)

    all_ss = load_search_spaces('database')

    gp_model_Dp = _select_best_multiple_run('D+', METRICS)
    # D- inputs are built from METRICS_v0 (see create_data_clone below), so its
    # X{i} placeholders are translated with that list.
    gp_model_Ds = _select_best_multiple_run('D-', METRICS_v0)
    # NOTE(review): the D model is pinned to run 12 — confirm this is intentional.
    with open('exp/GP-Model_multiple_run12.p', 'rb') as f:
        gp_model_D = p.load(f)

    # (key in all_ss, row-label prefix, datasets evaluated)
    problems = [
        ('nb101', 'nb101', ['cifar10']),
        # disabled: 'cifar100', 'ImageNet16-120'
        ('nb201', 'nb201', ['cifar10']),
        ('nb301', 'nb301', ['cifar10']),
        # disabled: 'class_object', 'jigsaw', 'autoencoder'
        ('transnb101_micro', 'tnb101-micro', ['class_scene']),
        ('transnb101_macro', 'tnb101-macro', ['class_scene']),
    ]

    list_problems = []
    data = []
    for ss, prefix, list_dataset in problems:
        for dataset in list_dataset:
            list_problems.append(f'{prefix}-{dataset}')

            X, y = create_data(all_ss[ss], dataset)
            y_pred_Dp = gp_model_Dp.predict(X)
            y_pred_D = gp_model_D.predict(X)

            # NOTE(review): assumes create_data orders architectures like
            # create_data_clone (ss[dataset] iteration order), so that y aligns
            # with these predictions — confirm in utils.
            _X, _ = create_data_clone(all_ss[ss], dataset, list_metrics=METRICS_v0)
            y_pred_Ds = gp_model_Ds.predict(_X)

            data.append([
                _rank_correlation(y, y_pred_Ds, measure),
                _rank_correlation(y, y_pred_Dp, measure),
                _rank_correlation(y, y_pred_D, measure),
            ])
    data = np.round(np.array(data), 2)
    df = pd.DataFrame(data, index=list_problems, columns=['D-', 'D+', 'D'])
    print(df)

def _load_best_single_problem_model(problem):
    """Load the GP model of the best run of the ``problem`` experiment.

    ``problem`` is an experiment tag such as ``'nb101-cifar10'``. The best run
    maximises the ``'full'`` value in ``exp/{problem}_results.json``.
    """
    with open(f'exp/{problem}_results.json') as f:
        res = json.load(f)
    run_ids = list(res)
    # NOTE(review): assumes each run's 'full' entry is a single scalar; if it
    # were a vector, np.argmax would index the flattened values — confirm.
    best_rid = run_ids[int(np.argmax([res[rid]['full'] for rid in run_ids]))]
    # pickle can execute arbitrary code: only load trusted, locally produced files.
    with open(f'exp/GP-Model_{problem}_run{best_rid}.p', 'rb') as f:
        return p.load(f)


def run_A2(measure='kendall'):
    """Cross-evaluate the best nb101/nb201/nb301 (cifar10) GP models.

    Each model is applied to every one of the three problems. The printed
    table has one row per model (labelled by the experiment it was selected
    from) and one column per evaluated problem. Values are the correlations
    between predictions and ``create_data`` targets, rounded to 2 decimals.
    The table ends with a hard-coded ``'All 3 problems'`` row.

    Args:
        measure: ``'kendall'`` or ``'spearman'``.

    Raises:
        NotImplementedError: for any other ``measure`` (checked before any I/O).
    """
    if measure == 'kendall':
        correlation = stats.kendalltau
    elif measure == 'spearman':
        correlation = stats.spearmanr
    else:
        raise NotImplementedError(measure)

    all_ss = load_search_spaces('database')

    search_spaces = ['nb101', 'nb201', 'nb301']
    gp_models = [_load_best_single_problem_model(f'{ss}-cifar10') for ss in search_spaces]

    data = []
    for ss in search_spaces:
        X, y = create_data(all_ss[ss], 'cifar10')
        data.append([correlation(y, model.predict(X))[0] for model in gp_models])
    # Transpose: rows become the model, columns the evaluated problem.
    data = np.round(np.array(data), 2).T
    # NOTE(review): this row is hard-coded, not computed here — presumably the
    # results of a model fitted on all three problems; confirm its source.
    data = np.concatenate((data, np.array([[0.61, 0.76, 0.40]])), axis=0)

    labels = ['NB101-CF10', 'NB201-CF10', 'NB301-CF10']
    df = pd.DataFrame(data, index=labels + ['All 3 problems'], columns=labels)
    print(df)

if __name__ == '__main__':
    # Correlation measure ('kendall' or 'spearman'). It is bound as a module
    # global before the experiments run, so any experiment function that looks
    # up the global name `measure` finds it.
    measure = 'kendall'
    run_A1()
    # Experiment A2 is currently disabled; uncomment to run it.
    # run_A2()
