import os
import pickle as p
import json
import argparse
import numpy as np

from GPModel import CloneSymbolicRegressor, get_fitness_function
from utils import modify_y
from utils import load_search_spaces
from sklearn.model_selection import train_test_split
from copy import deepcopy

# Fraction of each (search space, dataset) pair used for training in train_test_split.
TRAIN_SIZE = 0.7

# Per-architecture metric names read by `create_data`. Every entry except
# 'val_accuracy' becomes an input feature; 'val_accuracy' is collected as the target.
METRICS = [
    'epe_nas', 'fisher', 'flops', 'grad_norm', 'grasp', 'jacov', 'l2_norm',
    'nwot', 'params', 'plain', 'snip', 'synflow', 'zen',
    'swap', 'meco_opt', 'zico',
    'val_accuracy',
]

# Reduced metric list used by the 'D-' setting: METRICS without the three
# entries 'swap', 'meco_opt' and 'zico' (order otherwise preserved).
METRICS_v0 = [metric for metric in METRICS if metric not in ('swap', 'meco_opt', 'zico')]

# Correlation measure name handed to get_fitness_function.
MEASURE = 'kendall'

# Primitive operators available to the GP trees.
FUNCTION_SET = ['add', 'sub', 'mul', 'neg', 'div', 'log', 'sqrt']

# Directory handed to load_search_spaces.
DATA_PATH = 'database'

# GP hyperparameters forwarded to CloneSymbolicRegressor in run_evolution.
POP_SIZE = 100
TOURNAMENT_SIZE = 4
N_GENS = 50
P_CROSSOVER = 0.7
P_SUBTREE_MUTATION = 0.1
P_HOIST_MUTATION = 0.05
P_POINT_MUTATION = 0.1

# Number of independent runs; run ids (and seeds) are 1..N_RUNS.
N_RUNS = 31
# Verbosity level forwarded to the regressor.
VERBOSE = 1
# When truthy, fitted models are pickled and per-run results are written as JSON.
LOGGING = 1


def create_data(ss, dataset, list_metrics):
    """Turn the per-architecture records of one dataset into ``(X, y)`` arrays.

    ``ss[dataset]`` is iterated in key order. For each architecture record,
    every name in ``list_metrics`` except ``'val_accuracy'`` contributes one
    feature, ``record[metric]['score']``. If that lookup raises ``KeyError``,
    the placeholder ``-9999999`` is used instead. ``'val_accuracy'`` is not a
    feature: ``record['val_accuracy']`` is appended to the target list.

    Returns:
        ``(X, y)`` as ``np.ndarray``. X has one row per architecture. y gets one
        element per architecture for each occurrence of ``'val_accuracy'`` in
        ``list_metrics``, so it is empty if that name is absent.
    """
    archs = ss[dataset]
    feature_rows = []
    targets = []
    for arch_name in archs:
        record = archs[arch_name]
        feature_row = []
        for metric in list_metrics:
            if metric == 'val_accuracy':
                targets.append(record[metric])
                continue
            try:
                feature_row.append(record[metric]['score'])
            except KeyError:
                # Missing metric (or missing 'score') -> sentinel feature value.
                feature_row.append(-9999999)
        feature_rows.append(feature_row)
    return np.array(feature_rows), np.array(targets)

def run_evolution(X_train, y_train,
                  list_metrics, fitness_function,
                  run_id, name, model_path,
                  multiple=True):
    """Fit one GP symbolic-regression model, print its best program and optionally pickle it.

    Args:
        X_train, y_train: training data passed straight to ``gp_model.fit``.
        list_metrics: metric names in feature-column order. Used only to
            replace the ``X<i>`` placeholders in the printed program.
        fitness_function: passed to the regressor as ``metric``.
        run_id: used as the regressor's ``random_state``.
        name: file stem of the pickle (``GP-Model_<name>.p``).
        model_path: directory the pickle is written to when ``LOGGING`` is truthy.
        multiple: forwarded to the regressor's ``multiple`` option.

    Returns:
        The fitted regressor. When ``LOGGING`` is truthy, its ``_programs``
        attribute has been deleted.
    """
    gp_model = CloneSymbolicRegressor(
        population_size=POP_SIZE,
        tournament_size=TOURNAMENT_SIZE,
        generations=N_GENS,
        p_crossover=P_CROSSOVER, p_subtree_mutation=P_SUBTREE_MUTATION,
        p_hoist_mutation=P_HOIST_MUTATION, p_point_mutation=P_POINT_MUTATION,
        metric=fitness_function,
        function_set=FUNCTION_SET, const_range=None,
        stopping_criteria=9999999999,
        init_depth=(2, 10), verbose=VERBOSE,
        parsimony_coefficient=0.01, random_state=run_id, low_memory=True, multiple=multiple
    )
    gp_model.fit(X_train, y_train)

    best_program = gp_model.our_program['program']
    readable = str(best_program)  # str is immutable, so no copy is needed
    # Substitute from the highest index down so that e.g. 'X1' cannot
    # clobber the prefix of 'X10' before 'X10' itself is replaced.
    for i in range(len(list_metrics) - 1, -1, -1):
        readable = readable.replace(f'X{i}', f'{list_metrics[i]}')
    print('+ Best model:', readable)
    print('+ Depth:', best_program.depth_)
    print()

    if LOGGING:
        # The population attribute is dropped before pickling (presumably to keep the file small).
        delattr(gp_model, '_programs')
        # Context manager guarantees the file is flushed and closed.
        with open(f'{model_path}/GP-Model_{name}.p', 'wb') as fh:
            p.dump(gp_model, fh)
    return gp_model


def _stack_rows(acc, part):
    """Append ``part`` below ``acc`` along the sample axis and return the result.

    ``acc`` is ``None`` before the first (search space, dataset) pair. The first
    part is copied so the caller's data is never aliased. Python lists are
    extended in place, which is what ``+=`` does for lists. Anything else goes
    through ``np.concatenate(axis=0)``, because ``+=`` on NumPy arrays adds
    element-wise (or fails on a shape mismatch) instead of appending rows.
    """
    if acc is None:
        return part.copy()
    if isinstance(acc, list):
        acc.extend(part)
        return acc
    return np.concatenate((acc, part), axis=0)


def _build_splits(all_ss, list_ss_dataset, list_metrics, train_ratio, seed):
    """Create, split and stack the data of every (search space, dataset) pair.

    Each pair is split independently with ``train_test_split(random_state=seed)``.
    The pieces are stacked in the order of ``list_ss_dataset``.

    Returns:
        ``(X, y, X_train, y_train, X_test, y_test, labels)``, where ``labels[i]``
        is the ``'(<search space> <dataset>)'`` tag of the i-th pair.
    """
    X = y = X_train = y_train = X_test = y_test = None
    labels = []
    for search_space, dataset in list_ss_dataset:
        _X, _y = create_data(all_ss[search_space], dataset, list_metrics=list_metrics)
        _y = modify_y(_y, search_space, dataset)
        _X_train, _X_test, _y_train, _y_test = train_test_split(
            _X, _y, train_size=train_ratio, random_state=seed)
        X = _stack_rows(X, _X)
        y = _stack_rows(y, _y)
        X_train = _stack_rows(X_train, _X_train)
        y_train = _stack_rows(y_train, _y_train)
        X_test = _stack_rows(X_test, _X_test)
        y_test = _stack_rows(y_test, _y_test)
        labels.append(f'({search_space} {dataset})')
    return X, y, X_train, y_train, X_test, y_test, labels


def _report_correlation(gp_model, fitness_function, X, y, labels, header, tag):
    """Print per-pair correlations of ``gp_model`` on ``(X, y)`` and return them rounded to 4 decimals.

    The result of ``fitness_function(y, y_pred)`` is indexed by position, one
    entry per element of ``labels``.
    """
    print(header)
    y_pred = gp_model.predict(X)
    corr = fitness_function(y, y_pred)
    for i, label in enumerate(labels):
        print(f'{label}: {corr[i]}')
    print(f'{tag} - Mean:', np.round(np.mean(corr), 6))
    print()
    return np.round(corr, 4)


def search(all_ss, list_ss_dataset=(None, None), train_ratio=0.7,
           list_metrics=None, fitness_function=None, model_path='exp', name=None):
    """Run ``N_RUNS`` seeded GP searches jointly over several (search space, dataset) pairs.

    For each run id (1..N_RUNS, also used as the split and GP seed), the data of
    every pair is split and stacked. A model is fitted with ``run_evolution``.
    Per-pair correlations are then printed for the train, test and full data.

    Args:
        all_ss: mapping from search-space name to the structure consumed by
            ``create_data`` (``ss[dataset][arch][metric]``).
        list_ss_dataset: iterable of ``(search_space, dataset)`` pairs.
        train_ratio: train fraction passed to ``train_test_split``.
        list_metrics: metric names. See ``create_data``.
        fitness_function: callable ``(y_true, y_pred)`` returning one score per
            pair. It is also used as the GP fitness.
        model_path: directory for pickled models and the results JSON.
        name: optional tag inserted into output file names
            (``multiple_<name>_run<id>`` / ``multiple_<name>_results.json``).

    Raises:
        ValueError: if ``list_ss_dataset`` is empty.
    """
    if not list_ss_dataset:
        raise ValueError('list_ss_dataset must contain at least one (search_space, dataset) pair')
    print('List Metrics:', list_metrics)
    prefix = 'multiple' if name is None else f'multiple_{name}'
    res_json = {}
    all_scores = {'Train': [], 'Test': [], 'Full': []}
    for run_id in range(1, N_RUNS + 1):
        run_key = f'{run_id}'
        res_json[run_key] = {'seed': run_id}
        print('# Run ID:', run_id)
        X, y, X_train, y_train, X_test, y_test, labels = _build_splits(
            all_ss, list_ss_dataset, list_metrics, train_ratio, seed=run_id)
        gp_model = run_evolution(X_train, y_train, list_metrics, run_id=run_id,
                                 fitness_function=fitness_function,
                                 model_path=model_path,
                                 name=f'{prefix}_run{run_id}', multiple=True)

        evaluations = (
            ('train', 'Train', '+ Correlation on Training Set:', X_train, y_train),
            ('test', 'Test', '+ Correlation on Test Set:', X_test, y_test),
            ('full', 'Full', '+ Correlation on Full Data:', X, y),
        )
        for key, tag, header, split_X, split_y in evaluations:
            rounded = _report_correlation(gp_model, fitness_function, split_X, split_y,
                                          labels, header, tag)
            all_scores[tag].append(rounded)
            # tolist() yields plain Python floats, which json can always serialize.
            res_json[run_key][key] = rounded.tolist()

        print('---' * 40)
    if LOGGING:
        # Context manager guarantees the JSON file is flushed and closed.
        with open(f'{model_path}/{prefix}_results.json', 'w') as fh:
            json.dump(res_json, fh, indent=6)
    for tag, scores in all_scores.items():
        print(f'- {tag} - Mean: {np.round(np.mean(scores, axis=0), 4)}, '
              f'Std: {np.round(np.std(scores, axis=0), 4)}')
    print()

def search_2(all_ss, list_ss_dataset=(None, None), train_ratio=0.7,
             fitness_function=None, model_path='exp'):
    """Run the multi-pair GP search with the full ``METRICS`` list under the name 'D+'.

    This is exactly ``search(..., list_metrics=METRICS, name='D+')``:
    - the printed metric list is ``METRICS``;
    - models are pickled as ``GP-Model_multiple_D+_run<id>.p``;
    - results go to ``<model_path>/multiple_D+_results.json``.

    It is kept as a thin wrapper so existing callers continue to work without
    maintaining a second copy of the search loop.
    """
    return search(all_ss, list_ss_dataset=list_ss_dataset, train_ratio=train_ratio,
                  list_metrics=METRICS, fitness_function=fitness_function,
                  model_path=model_path, name='D+')


def run(kwargs):
    """Run the GP search for the dataset collection selected by ``kwargs.SR_dataset``.

    Args:
        kwargs: parsed arguments (e.g. an ``argparse.Namespace``) exposing
            ``SR_dataset``, one of ``'D'``, ``'D+'`` or ``'D-'``.
            - ``'D'``: three ``nb*`` search spaces on 'cifar10' with ``METRICS``.
            - ``'D+'``: those plus the two ``transnb101_*`` 'class_scene' pairs.
            - ``'D-'``: the ``'D'`` pairs with ``METRICS_v0``.

    Raises:
        ValueError: if ``SR_dataset`` is not one of the supported choices.
            Previously an unknown value silently did nothing, after loading all data.
    """
    sr_dataset = kwargs.SR_dataset
    valid_choices = ('D', 'D+', 'D-')
    if sr_dataset not in valid_choices:
        raise ValueError(f'Unknown SR_dataset {sr_dataset!r}; expected one of {valid_choices}')

    fitness_function = get_fitness_function(measure=MEASURE, multiple_dataset=True)

    search_spaces = load_search_spaces(DATA_PATH)

    print('- Training ratio:', TRAIN_SIZE)
    print('- Measure:', MEASURE)

    model_path = 'exp'
    os.makedirs(model_path, exist_ok=True)

    cifar10_pairs = [('nb101', 'cifar10'), ('nb201', 'cifar10'), ('nb301', 'cifar10')]
    # SR_dataset -> ((search space, dataset) pairs, metric list, output-name tag)
    configs = {
        'D': (cifar10_pairs, METRICS, None),
        'D+': (cifar10_pairs + [('transnb101_micro', 'class_scene'),
                                ('transnb101_macro', 'class_scene')], METRICS, 'D+'),
        'D-': (cifar10_pairs, METRICS_v0, 'D-'),
    }
    list_ss_dataset, list_metrics, name = configs[sr_dataset]

    print('- Search Space - Dataset:', list_ss_dataset)
    print('----' * 40)
    search(all_ss=search_spaces, list_ss_dataset=list_ss_dataset,
           train_ratio=TRAIN_SIZE, list_metrics=list_metrics,
           fitness_function=fitness_function, model_path=model_path, name=name)
    print('====' * 40)


if __name__ == '__main__':
    # Command-line entry point: pick the dataset collection and run the search.
    cli = argparse.ArgumentParser()
    # --- problem definition ---
    cli.add_argument(
        '--SR_dataset',
        type=str,
        default='D',
        help='dataset for SR',
        choices=['D', 'D+', 'D-'],
    )
    run(cli.parse_args())

