import numpy as np
import json
from scipy import stats
from sklearn.model_selection import train_test_split

# Metric subsets selectable through create_data(kind=...).
# Every list ends with 'val_accuracy': create_data routes that entry into the
# target vector y, and every other entry becomes one feature column of X.
METRIC_AGNOSTIC = [
    'l2_norm', 'params', 'synflow', 'zen',
    'val_accuracy',
]
METRIC_SPECIFIC = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'nwot',
    'plain', 'snip', 'swap', 'meco_opt', 'zico',
    'val_accuracy',
]
# Full list used for kind='all'. Its order defines the column order of X, and
# evaluate() labels column i with METRICS[i] by default, so 'val_accuracy'
# must stay last to keep those positions aligned.
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'l2_norm',
    'nwot', 'params', 'plain', 'snip', 'synflow', 'zen', 'swap', 'meco_opt',
    'zico',
    'val_accuracy',
]


def evaluate(X, y, gp_model=None, rank_metric='kendall', metric_names=None):
    """Print the rank correlation of each feature column (and optionally a model) with ``y``.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Feature matrix; each column is correlated with ``y`` independently.
    y : array-like, shape (n_samples,)
        Target values.
    gp_model : object, optional
        Anything exposing ``predict(X)``; its predictions are scored like an
        extra feature and printed as ``new_feature``.
    rank_metric : str
        ``'spearman'`` selects Spearman's rho; any other value selects
        Kendall's tau.
    metric_names : sequence of str, optional
        Labels printed for the columns of ``X``. Defaults to ``METRICS``,
        which only lines up with an ``X`` built by
        ``create_data(..., kind='all')``. Pass the matching list (without
        ``'val_accuracy'``) for other kinds. Columns past the end of the list
        are labelled ``feature_<i>`` instead of raising IndexError.

    Returns
    -------
    scalar or None
        Correlation statistic of ``gp_model.predict(X)`` with ``y``, or
        ``None`` when no model is given.
    """
    names = METRICS if metric_names is None else metric_names
    # Choose the correlation function once rather than on every column.
    correlate = stats.spearmanr if rank_metric == 'spearman' else stats.kendalltau
    for i in range(X.shape[1]):
        res = correlate(X[:, i], y)
        label = names[i] if i < len(names) else f'feature_{i}'
        print(label, round(res[0], 2))
    our_corr = None
    if gp_model is not None:
        new_feature = gp_model.predict(X)
        our_corr = correlate(new_feature, y)[0]
        print('new_feature', round(our_corr, 2))
    print()
    return our_corr


def modify_y(y, search_space, dataset):
    """Tag each target value with its origin.

    Every element ``v`` of ``y`` becomes the string
    ``'<v>+<search_space>_<dataset>'``; a new list is returned.
    """
    suffix = f'+{search_space}_{dataset}'
    return [f'{value}{suffix}' for value in y]


def create_data(ss, dataset, kind='all', missing_value=-9999999, metrics=None):
    """Build a feature matrix ``X`` and target vector ``y`` for one dataset.

    Parameters
    ----------
    ss : dict
        Nested mapping ``ss[dataset][arch][metric]``. Feature metrics are read
        from ``[...]['score']``; ``'val_accuracy'`` is read directly as the
        target value.
    dataset : str
        Key selecting the dataset inside ``ss``.
    kind : str
        ``'agnostic'`` -> ``METRIC_AGNOSTIC``, ``'specific'`` ->
        ``METRIC_SPECIFIC``, anything else -> ``METRICS``. Ignored when
        ``metrics`` is given.
    missing_value : scalar
        Placeholder written when an architecture has no score for a metric
        (a KeyError on the metric or its ``'score'`` key).
    metrics : list of str, optional
        Explicit metric list that overrides ``kind``. ``'val_accuracy'``
        entries go to ``y``; all others become columns of ``X`` in list order.

    Returns
    -------
    X : ndarray, shape (n_archs, n_features)
        Shape is kept two-dimensional even when the dataset has no architectures.
    y : ndarray, shape (n_archs,)

    Raises
    ------
    KeyError
        If ``dataset`` is missing, or an architecture lacks ``'val_accuracy'``.
    """
    if metrics is not None:
        list_metrics = metrics
    elif kind == 'agnostic':
        list_metrics = METRIC_AGNOSTIC
    elif kind == 'specific':
        list_metrics = METRIC_SPECIFIC
    else:
        list_metrics = METRICS
    n_features = sum(1 for metric in list_metrics if metric != 'val_accuracy')

    rows = []
    y = []
    for arch_data in ss[dataset].values():
        row = []
        for metric in list_metrics:
            if metric == 'val_accuracy':
                y.append(arch_data[metric])
                continue
            try:
                row.append(arch_data[metric]['score'])
            except KeyError:
                row.append(missing_value)
        rows.append(row)

    # np.array([]) would be 1-D, which breaks X.shape[1] downstream; keep 2-D.
    X = np.array(rows) if rows else np.empty((0, n_features))
    y = np.array(y)
    return X, y


def percentage_split(ss, dataset, kind='all'):
    """Sort rows by descending ``y`` and cut them into ten labelled buckets.

    With ``step = len(X) // 10``, bucket ``'0-10'`` holds the first ``step``
    rows after sorting (highest ``y``), ``'10-20'`` the next ``step``, and so
    on. The last bucket, ``'90-100'``, takes every remaining row, including the
    remainder of the integer division. When there are fewer than ten rows,
    ``step`` is 0 and everything lands in ``'90-100'``. Each bucket is a dict
    ``{'X': ..., 'y': ...}`` of independent copies.
    """
    features, target = create_data(ss, dataset, kind)
    order = np.argsort(-target)
    features = features[order]
    target = target[order]
    step = len(features) // 10

    buckets = {}
    for k in range(10):
        lo = k * step
        hi = None if k == 9 else lo + step  # final bucket runs to the end
        buckets[f'{k * 10}-{(k + 1) * 10}'] = {
            'X': features[lo:hi].copy(),
            'y': target[lo:hi].copy(),
        }
    return buckets


def create_single_data(ss, dataset, seed, train_size=100, kind='all'):
    """Randomly split one dataset's ``(X, y)`` into train and test parts.

    Returns the tuple ``(X_train, X_test, y_train, y_test)`` produced by
    ``train_test_split`` with ``train_size=train_size`` and
    ``random_state=seed``.
    """
    features, target = create_data(ss, dataset, kind)
    parts = train_test_split(features, target, train_size=train_size, random_state=seed)
    return tuple(parts)


def create_multiple_data(all_ss, list_ss, list_dataset, kind='all'):
    """Stack ``(X, y)`` from several search-space / dataset pairs.

    Parameters
    ----------
    all_ss : dict
        Mapping from search-space name to its nested score dict (see
        ``create_data``).
    list_ss : list of str
        Search-space names to include, in order.
    list_dataset : list of str
        ``list_dataset[i]`` is the dataset used for ``list_ss[i]``.
    kind : str
        Metric subset forwarded to ``create_data`` (default ``'all'``, matching
        the previous behaviour). Mirrors ``create_multiple_data_2``.

    Returns
    -------
    X : ndarray
        Row-wise concatenation of every pair's features, or ``[]`` when
        ``list_ss`` is empty.
    y : list of str
        Targets tagged as ``'<value>+<search_space>_<dataset>'`` via
        ``modify_y``, in the same row order as ``X``.
    """
    X_parts = []
    y = []
    for i, _ss in enumerate(list_ss):
        dataset = list_dataset[i]
        _X, _y = create_data(all_ss[_ss], dataset, kind)
        X_parts.append(_X)
        y.extend(modify_y(_y, _ss, dataset))
    if not X_parts:
        return [], []
    # Concatenate once instead of re-copying the growing array on every pass.
    return np.concatenate(X_parts, axis=0), y


def create_multiple_data_2(all_ss, list_ss, list_dataset, seed, train_size=100, kind='all'):
    """Split each search-space / dataset pair separately, then stack the splits.

    Each pair ``(list_ss[i], list_dataset[i])`` is turned into ``(X, y)`` with
    ``create_data``. Its targets are tagged via ``modify_y``, and the pair is
    split with ``train_test_split(train_size=train_size, random_state=seed)``.
    The per-pair train parts and test parts are then concatenated in order.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``. The ``X`` parts are ndarrays and
        the ``y`` parts are lists of tagged strings; all four are empty lists
        when ``list_ss`` is empty.
    """
    train_blocks, test_blocks = [], []
    train_labels, test_labels = [], []
    for idx, space in enumerate(list_ss):
        dataset = list_dataset[idx]
        feats, target = create_data(all_ss[space], dataset, kind)
        tagged = modify_y(target, space, dataset)
        f_tr, f_te, t_tr, t_te = train_test_split(
            feats, tagged, train_size=train_size, random_state=seed)
        train_blocks.append(f_tr)
        test_blocks.append(f_te)
        train_labels.extend(t_tr)
        test_labels.extend(t_te)

    if not train_blocks:
        return [], [], [], []
    return (
        np.concatenate(train_blocks, axis=0),
        np.concatenate(test_blocks, axis=0),
        train_labels,
        test_labels,
    )


def load_search_spaces(data_path):
    """Load the per-search-space ``zc_*.json`` score files under ``data_path``.

    Expected layout::

        <data_path>/NB101/zc_nasbench101.json
        <data_path>/NB201/zc_nasbench201.json
        <data_path>/NB301/zc_nasbench301.json
        <data_path>/TransNAS101/zc_transbench101_macro.json
        <data_path>/TransNAS101/zc_transbench101_micro.json

    Returns
    -------
    dict
        Keys ``'nb101'``, ``'nb201'``, ``'nb301'``, ``'transnb101_micro'`` and
        ``'transnb101_macro'``, each mapped to the parsed JSON content.

    Raises
    ------
    FileNotFoundError
        If any of the files is missing.
    json.JSONDecodeError
        If a file is not valid JSON.
    """
    def _load(relative_path):
        # The context manager closes each handle as soon as it has been parsed.
        # UTF-8 is set explicitly so results don't depend on the platform locale.
        with open(f'{data_path}/{relative_path}', encoding='utf-8') as fh:
            return json.load(fh)

    nb101 = _load('NB101/zc_nasbench101.json')
    nb201 = _load('NB201/zc_nasbench201.json')
    nb301 = _load('NB301/zc_nasbench301.json')
    transnb101_macro = _load('TransNAS101/zc_transbench101_macro.json')
    transnb101_micro = _load('TransNAS101/zc_transbench101_micro.json')

    search_spaces = {
        'nb101': nb101,
        'nb201': nb201,
        'nb301': nb301,
        'transnb101_micro': transnb101_micro,
        'transnb101_macro': transnb101_macro
    }
    return search_spaces
