# Evaluate how well Gaussian-process models with the various kernels regress the benchmark objectives

import argparse
from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
from localglobal.bo.localbo_utils import from_unit_cube, latin_hypercube, to_unit_cube, onehot2ordinal, train_gp
from utils import get_dim_info
import numpy as np
import datetime
import pickle
import torch
import matplotlib.pylab as plt
from utils import spearman, pearson, negative_log_likelihood
from localglobal.test_funcs.random_seed_config import *
import os

# Command-line interface. The parsed options are echoed to stdout so a run's
# configuration can be recovered from its log.
parser = argparse.ArgumentParser(prog='Regression Tester')
parser.add_argument('-p', '--problem', default='pest', type=str)
parser.add_argument('--n_trials', default=1, type=int)
parser.add_argument('--n_train', default=50, type=int, help='Number of training points')
parser.add_argument('--n_test', default=100, type=int, help='Number of testing points')
parser.add_argument('-k', '--kernel', default='type2', type=str, help='Type of kernel to use')
parser.add_argument('--lamda', default=0., type=float,
                    help='the noise variance to add. If 0, the observation of the objective '
                         'functions will be noiseless.')
parser.add_argument('--ard', action='store_true')
parser.add_argument('--random_seed_objective', default=20, type=int,
                    help='Random seed for the objective function.')
parser.add_argument('--save_dir', default=None, type=str,
                    help='Save path of the log files, if specified.')
parser.add_argument('--plot', action='store_true',
                    help='whether to visualise the results.')
parser.add_argument('--data_dir', default='./data/',
                    help='the data directory (for the datasets that require loading from '
                         'persistent data structures)')
args = parser.parse_args()
options = vars(args)
print(options)

# A timestamp is used as the run directory name so repeated runs do not
# overwrite each other's logs.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

if args.save_dir is not None:
    save_dir = os.path.join(args.save_dir, args.problem, time_string)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    # Record the exact options of this run next to its results.
    with open(os.path.join(save_dir, 'command.txt'), 'w+') as option_file:
        option_file.write(str(options))
else:
    save_dir = None


def warp_discrete(X, dims):
    """Snap every categorical column group of ``X`` to a valid one-hot vector.

    For each group of column indices returned by ``get_dim_info(dims)``, the
    column holding the largest value of that group in a row is set to 1 and the
    other columns of the group are set to 0. Columns that belong to no group are
    copied unchanged, and ``X`` itself is not modified.

    :param X: 2-D array with one row per sample.
    :param dims: forwarded to ``get_dim_info``; presumably the number of
        categories of each variable (TODO confirm against utils.get_dim_info).
    :return: a modified copy of ``X``.
    """
    X_ = np.copy(X)
    rows = np.arange(X.shape[0])
    for categorical_groups in get_dim_info(dims):
        group = np.asarray(categorical_groups, dtype=np.intp)
        # argmax reads the original X, not the partially rewritten copy, and is
        # mapped from a position within the group to the absolute column index.
        winners = group[np.argmax(X[:, group], axis=1)]
        X_[:, group] = 0
        X_[rows, winners] = 1
    return X_


def generate_random(n_rand: int, dim):
    """Draw ``n_rand`` random points for a purely categorical problem.

    Points are sampled by ``latin_hypercube`` over the one-hot encoding, whose
    width is ``sum(dim)``, mapped through ``from_unit_cube`` with bounds [0, 1],
    snapped to valid one-hot vectors by ``warp_discrete`` and finally converted
    to ordinal form by ``onehot2ordinal``.

    :param n_rand: number of points to draw.
    :param dim: summed to get the one-hot width and forwarded to
        ``get_dim_info``; presumably the category count of each variable.
    :return: the numpy array obtained from ``onehot2ordinal(...).numpy()``.
    """
    width = int(np.sum(dim))
    unit_samples = latin_hypercube(n_rand, width)
    scaled = from_unit_cube(unit_samples, np.zeros(width), np.ones(width))
    one_hot = warp_discrete(scaled, dim)
    return onehot2ordinal(one_hot, get_dim_info(dim)).numpy()


def generate_random_mixed(n_rand: int, objective_func):
    """Draw ``n_rand`` random points for an objective with categorical and continuous inputs.

    The categorical part comes from ``generate_random`` on ``objective_func.config``.
    The continuous part is a ``latin_hypercube`` sample scaled into
    ``[objective_func.lb, objective_func.ub]``. The two are stacked column-wise,
    categorical columns first.

    NOTE(review): this assumes the objective orders its inputs as all categorical
    columns followed by all continuous ones, and that ``config`` describes only
    the categorical variables — confirm against the objective classes.

    :param n_rand: number of points to draw.
    :param objective_func: objective exposing ``config``, ``continuous_dims``,
        ``lb`` and ``ub``.
    :return: array of shape (n_rand, n_categorical + n_continuous).
    """
    categorical_part = generate_random(n_rand, objective_func.config)
    n_cont = len(objective_func.continuous_dims)
    unit_cont = latin_hypercube(n_rand, n_cont)
    continuous_part = from_unit_cube(unit_cont, objective_func.lb, objective_func.ub)
    return np.hstack((categorical_part, continuous_part))


def _objective_seed_index(raw_seed):
    """Convert the 1-based ``--random_seed_objective`` value into a 0-based index.

    :param raw_seed: value of ``--random_seed_objective``; may be None.
    :return: ``raw_seed - 1``, or None if ``raw_seed`` is None.
    :raises ValueError: if ``raw_seed`` is outside [1, 25].
    """
    if raw_seed is None:
        return None
    raw_seed = int(raw_seed)
    if not 1 <= raw_seed <= 25:
        raise ValueError('--random_seed_objective must be in [1, 25], got %d' % raw_seed)
    return raw_seed - 1


def _select_seed_pair(seed_pairs, seed_idx):
    """Pick ``(case_seed, init_seed)`` from a ``{case_seed: init_seeds}`` mapping.

    ``seed_idx // 5`` indexes the sorted case seeds and ``seed_idx % 5`` the
    sorted init seeds of that case, so indices 0..24 address 5 cases x 5 inits.
    """
    case_seed = sorted(seed_pairs.keys())[seed_idx // 5]
    init_seed = sorted(seed_pairs[case_seed])[seed_idx % 5]
    return case_seed, init_seed


def _build_objective(problem, seed_idx, lamda, data_dir):
    """Instantiate the benchmark objective named by ``problem``.

    :param problem: value of ``--problem``.
    :param seed_idx: 0-based objective seed index; used by 'contamination',
        'ising' and 'pest'.
    :param lamda: noise variance forwarded to the objectives that accept it.
    :param data_dir: directory holding the NASBench101 data / cached pickle.
    :raises ValueError: if ``problem`` is not recognised.
    """
    if problem == 'contamination':
        seed_pair = _select_seed_pair(generate_random_seed_pair_contamination(), seed_idx)
        return Contamination(lamda=lamda, random_seed_pair=seed_pair)
    if problem == 'pest':
        return PestControl(random_seed=sorted(generate_random_seed_pestcontrol())[seed_idx])
    if problem == 'branin':
        return Branin(normalize=False)
    if problem == 'ising':
        seed_pair = _select_seed_pair(generate_random_seed_pair_ising(), seed_idx)
        return Ising(lamda=lamda, random_seed_pair=seed_pair)
    if problem == 'func2C':
        return Func2C(lamda=lamda)
    if problem == 'func3C':
        return Func3C(lamda=lamda)
    if problem == 'ackley53':
        return Ackley53(lamda=lamda)
    if problem == 'MaxSAT60':
        return MaxSAT60()
    if problem == 'xgboost-mnist':
        return XGBoostOptTask(lamda=lamda, task='mnist')
    if problem == 'svm-boston':
        return SVMOptTask(lamda=lamda, task='boston')
    if problem == 'nasbench101':
        # Prefer a previously pickled instance; build from data_dir if the cache
        # is missing or cannot be unpickled.
        # NOTE: pickle.load can execute arbitrary code — only use trusted cache files.
        cache_path = os.path.join(data_dir, 'nasbench101.pickle')
        try:
            with open(cache_path, 'rb') as cache_file:
                return pickle.load(cache_file)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError):
            return NASBench101(data_dir=data_dir)
    raise ValueError('Unknown problem: %s' % problem)


def _fit_metrics(y_true, y_pred, y_std):
    """Compute Spearman, Pearson and negative log-likelihood of GP predictions."""
    return {
        'spearman': spearman(y_true, y_pred),
        'pearson': pearson(y_true, y_pred),
        'nll': negative_log_likelihood(y_pred, y_std, y_true),
    }


def _plot_fit(f, test_series, train_series):
    """Plot predicted vs. true values: test set on the left, train set on the right.

    Each series is ``(y_true, y_pred, y_std)``. If the objective normalizes its
    outputs, the values are mapped back with ``f.std``/``f.mean`` on copies, so
    the caller's arrays (used for the reported metrics) are not modified.
    """
    denormalize = getattr(f, 'normalize', False)
    plt.figure(figsize=(8, 4))
    for position, (y_true, y_pred, y_std) in zip((121, 122), (test_series, train_series)):
        plt.subplot(position)
        if denormalize:
            y_true = y_true * f.std + f.mean
            y_pred = y_pred * f.std + f.mean
            y_std = y_std * f.std
        plt.plot(y_true, y_true, ".")
        plt.errorbar(y_true, y_pred, fmt='.', yerr=y_std, capsize=2, color='gray', markerfacecolor='blue',
                     markersize=8, alpha=0.3)
    plt.show()


# Converted once, outside the loop, so every trial uses the same objective and
# args is never mutated.
seed_idx = _objective_seed_index(args.random_seed_objective)

# Each trial: build the objective, draw random train/test inputs, fit a GP on
# the training set and report how well it predicts both sets.
for t in range(args.n_trials):
    f = _build_objective(args.problem, seed_idx, args.lamda, args.data_dir)

    n_categories = f.config
    problem_type = f.problem_type
    if problem_type == 'mixed':
        x_train = generate_random_mixed(args.n_train, f)
        x_test = generate_random_mixed(args.n_test, f)
    else:
        x_train = generate_random(args.n_train, n_categories)
        x_test = generate_random(args.n_test, n_categories)

    y_train = np.array([f.compute(x) for x in x_train]).flatten()
    y_test = np.array([f.compute(x) for x in x_test]).flatten()

    # Initialise and train a GP; mixed problems also need the column roles.
    x_train_torch = torch.tensor(x_train)
    y_train_torch = torch.tensor(y_train)
    x_test_torch = torch.tensor(x_test)
    mixed_kwargs = {}
    if problem_type == 'mixed':
        mixed_kwargs = dict(int_constrained_dims=f.int_constrained_dims,
                            cat_dims=f.categorical_dims, cont_dims=f.continuous_dims)
    gp = train_gp(train_x=x_train_torch, train_y=y_train_torch,
                  kern=args.kernel,
                  use_ard=args.ard, num_steps=300, hypers={}, **mixed_kwargs)

    # Use the GP to predict. No gradients are needed here.
    # NOTE(review): test predictions use gp(x) directly while train predictions go
    # through gp.likelihood(...), which presumably adds observation noise to the
    # predictive std — confirm this asymmetry is intended.
    with torch.no_grad():
        pred = gp(x_test_torch)
        y_test_pred = pred.mean.detach().numpy()
        y_test_pred_std = pred.stddev.detach().numpy()
        train_pred = gp.likelihood(gp(x_train_torch))
        y_train_pred = train_pred.mean.detach().numpy()
        y_train_std = train_pred.stddev.detach().numpy()

    if args.plot:
        _plot_fit(f, (y_test, y_test_pred, y_test_pred_std), (y_train, y_train_pred, y_train_std))

    # Metrics are computed once on the un-plotted arrays, so --plot never changes them.
    train_metrics = _fit_metrics(y_train, y_train_pred, y_train_std)
    test_metrics = _fit_metrics(y_test, y_test_pred, y_test_pred_std)

    print('-----Train-----')
    print('Spearman coefficient: ', train_metrics['spearman'])
    print('Pearson coefficient: ', train_metrics['pearson'])
    print('Negative Log likelihood: ', train_metrics['nll'])

    print('-----Validation-----')
    print('Spearman coefficient: ', test_metrics['spearman'])
    print('Pearson coefficient: ', test_metrics['pearson'])
    print('Negative Log likelihood: ', test_metrics['nll'])
    if save_dir is not None:
        res = {
            'y_train': y_train,
            'y_train_pred': y_train_pred,
            'y_train_pred_std': y_train_std,
            'y_test': y_test,
            'y_test_pred': y_test_pred,
            'y_test_pred_std': y_test_pred_std,
            'spearman_train': train_metrics['spearman'],
            'pearson_train': train_metrics['pearson'],
            'nll_train': train_metrics['nll'],
            'spearman_test': test_metrics['spearman'],
            'pearson_test': test_metrics['pearson'],
            'nll_test': test_metrics['nll'],
        }
        with open(os.path.join(save_dir, 'trial-%d.pickle' % t), 'wb') as result_file:
            pickle.dump(res, result_file)
