import os
import json
import pickle as p
import numpy as np
import argparse

from sklearn.model_selection import train_test_split
from copy import deepcopy

from utils import create_data, load_search_spaces
from utils import METRICS
from GPModel import CloneSymbolicRegressor, get_fitness_function

# Fraction of the data used for training in each run's train/test split.
TRAIN_SIZE = 0.7

# Primitive functions the GP regressor may use when building programs.
FUNCTION_SET = ['add', 'sub', 'mul', 'neg', 'div', 'log', 'sqrt']

# Directory passed to load_search_spaces().
DATA_PATH = 'database'

# Genetic-programming hyper-parameters forwarded to CloneSymbolicRegressor.
POP_SIZE = 100
TOURNAMENT_SIZE = 4
N_GENS = 50
# Genetic-operator probabilities; together they sum to 0.95.
# NOTE(review): if CloneSymbolicRegressor follows gplearn, the remaining 0.05 is
# the reproduction probability — confirm.
P_CROSSOVER = 0.7
P_SUBTREE_MUTATION = 0.1
P_HOIST_MUTATION = 0.05
P_POINT_MUTATION = 0.1

# Number of independent runs; run i (1..N_RUNS) uses i as its random seed for
# both the train/test split and the regressor.
N_RUNS = 31
# Verbosity level forwarded to the regressor.
VERBOSE = 1
# When truthy, fitted models are pickled and per-run results are dumped as JSON.
LOGGING = 1


def run_evolution(X_train, y_train,
                  list_metrics, fitness_function,
                  run_id, name, model_path, multiple=False):
    """Evolve one GP symbolic-regression model, print it and optionally pickle it.

    Args:
        X_train: training features. In the evolved program, feature ``i``
            appears as the token ``X{i}``.
        y_train: training targets.
        list_metrics: feature names. They are used only to make the printed
            program readable (``X{i}`` -> ``list_metrics[i]``).
        fitness_function: metric handed to the regressor as ``metric``.
        run_id: used as the regressor's ``random_state``.
        name: suffix of the pickle file name (``GP-Model_{name}.p``).
        model_path: directory the pickle is written to when LOGGING is set.
        multiple: forwarded unchanged to CloneSymbolicRegressor.

    Returns:
        The fitted CloneSymbolicRegressor. When LOGGING is enabled, its
        ``_programs`` attribute has been deleted before pickling (presumably
        to keep the pickle small).
    """
    gp_model = CloneSymbolicRegressor(
        population_size=POP_SIZE,
        tournament_size=TOURNAMENT_SIZE,
        generations=N_GENS,
        p_crossover=P_CROSSOVER, p_subtree_mutation=P_SUBTREE_MUTATION,
        p_hoist_mutation=P_HOIST_MUTATION, p_point_mutation=P_POINT_MUTATION,
        metric=fitness_function,
        function_set=FUNCTION_SET, const_range=None,
        stopping_criteria=9999999999,
        init_depth=(2, 10), verbose=VERBOSE,
        parsimony_coefficient=0.01, random_state=run_id, low_memory=True, multiple=multiple
    )
    gp_model.fit(X_train, y_train)

    best_program = gp_model.our_program['program']

    # Substitute feature names from the highest index down. Otherwise "X1"
    # would be rewritten inside "X10" before "X10" itself is handled.
    # Strings are immutable, so no copy is needed.
    model = str(best_program)
    for i in reversed(range(len(list_metrics))):
        model = model.replace(f'X{i}', f'{list_metrics[i]}')
    print('+ Best model:', model)
    print('+ Depth:', best_program.depth_)
    print()

    if LOGGING:
        delattr(gp_model, '_programs')
        # Context manager guarantees the file is flushed and closed.
        with open(f'{model_path}/GP-Model_{name}.p', 'wb') as model_file:
            p.dump(gp_model, model_file)
    return gp_model


def _score_split(gp_model, fitness_function, X, y, label):
    """Score ``gp_model`` on (X, y), print the rounded score and return the raw score."""
    corr = fitness_function(y, gp_model.predict(X))
    print(f'+ Correlation on {label}:', round(corr, 4))
    return corr


def search(all_ss, ss_dataset=(None, None), train_ratio=0.7, fitness_function=None, model_path='exp'):
    """Run N_RUNS independent GP evolutions on one search space / dataset pair.

    For each run ``r`` in 1..N_RUNS:
      - the data returned by ``create_data`` is split with
        ``train_test_split(..., random_state=r)``;
      - a model is evolved with ``run_evolution`` (seeded with ``r``);
      - ``fitness_function`` is evaluated on the training split, the test
        split and the full data.

    Args:
        all_ss: container indexed by search-space name; the selected entry is
            handed to ``create_data``.
        ss_dataset: ``(search_space_name, dataset_name)``.
        train_ratio: training fraction passed to ``train_test_split``.
        fitness_function: callable ``(y_true, y_pred) -> score``. It is
            required in practice despite the ``None`` default.
        model_path: directory where models and
            ``{ss}-{dataset}_results.json`` are written when LOGGING is set.

    Prints per-run scores and the mean / std over runs for each split. All
    three aggregates use the unrounded scores; rounding to 4 decimals is
    applied only for display and for the JSON file.
    """
    list_metrics = METRICS
    ss_name, dataset = ss_dataset[0], ss_dataset[1]
    prefix = f'{ss_name}-{dataset}'

    ss = all_ss[ss_name]
    scores = {'train': [], 'test': [], 'full': []}
    res_json = {}
    for run_id in range(1, N_RUNS + 1):
        run_key = f'{run_id}'
        res_json[run_key] = {'seed': run_id}
        print('# Run ID:', run_id)
        X, y = create_data(ss, dataset)
        X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_ratio, random_state=run_id)
        gp_model = run_evolution(X_train, y_train, list_metrics, fitness_function=fitness_function,
                                 run_id=run_id,
                                 model_path=model_path,
                                 name=f'{prefix}_run{run_id}')

        # (result key, printed label, features, targets), in the original reporting order.
        splits = (
            ('train', 'Training Set', X_train, y_train),
            ('test', 'Testing Set', X_test, y_test),
            ('full', 'Full data', X, y),
        )
        for key, label, X_split, y_split in splits:
            corr = _score_split(gp_model, fitness_function, X_split, y_split, label)
            scores[key].append(corr)  # keep raw values so mean/std are not biased by rounding
            res_json[run_key][key] = round(corr, 4)
        print('---' * 40)
        print()

    if LOGGING:
        with open(f'{model_path}/{prefix}_results.json', 'w') as results_file:
            json.dump(res_json, results_file, indent=6)
    for key, title in (('train', 'Train'), ('test', 'Test'), ('full', 'Full')):
        values = scores[key]
        print(f'- {title} - Mean: {round(np.mean(values), 4)}, Std: {round(np.std(values), 4)}')


def run(kwargs, measure='kendall', model_path='exp'):
    """Run the GP search for the search space / dataset selected on the command line.

    Args:
        kwargs: parsed command-line arguments; ``kwargs.ss`` and
            ``kwargs.dataset`` are read. Despite the name, this is the
            ``argparse.Namespace`` produced in ``__main__``, not a dict.
        measure: name of the measure given to ``get_fitness_function``.
        model_path: output directory (created if missing) for pickled models
            and the results JSON.
    """
    fitness_function = get_fitness_function(measure=measure, multiple_dataset=False)

    search_spaces = load_search_spaces(DATA_PATH)
    print('- Training ratio:', TRAIN_SIZE)
    print('- Measure:', measure)

    os.makedirs(model_path, exist_ok=True)

    ss_dataset = (kwargs.ss, kwargs.dataset)
    print('- Search Space - Dataset:', ss_dataset)
    print('----' * 40)

    search(search_spaces, ss_dataset, TRAIN_SIZE, fitness_function=fitness_function, model_path=model_path)
    print('====' * 40)


if __name__ == '__main__':
    cli = argparse.ArgumentParser()

    # PROBLEM: which search space and dataset the GP search runs on.
    known_search_spaces = ['nb101', 'nb201', 'nb301', 'transnb101_macro', 'transnb101_micro']
    cli.add_argument('--ss', type=str, default='nb201', help='the search space',
                     choices=known_search_spaces)
    cli.add_argument('--dataset', type=str, default='cifar10')

    run(cli.parse_args())
