import matplotlib.pyplot as plt

import numpy as np
import seaborn as sns
import gpytorch
import torch
import jsonargparse
import pickle
import time
import pandas as pd
import random


from gpytorch.kernels import LinearKernel, RBFKernel
from utils.data_handler import MyDataLoader, sample_initial_dataset
from simulators.simulators2 import oracle_simulator
from utils.active_learning import get_query, compute_sample_strategy
from utils.plotting import do_all_my_plots
from utils.transformations import transform
from utils.gp_utils import get_model, get_likelihood, get_loss, compute_loss

sns.set_style("darkgrid")

import os

# Limit MKL, numexpr and OpenMP to a single thread each.
# NOTE(review): these libraries usually read their thread-count variables when they are first
# loaded. numpy, pandas and torch are imported above this point, so these settings may not reach
# them (child processes still inherit them). Consider setting them before the third-party imports.
os.environ.update({
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
})

# Work from the directory that contains this script, so relative paths (e.g. the plot
# output under output/) resolve there regardless of where the script is launched from.
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)

# PyTorch: single-threaded intra-op and inter-op execution.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == '__main__':
    # Command-line interface. Defaults are read from config.yaml next to this script; --cfg can
    # load an additional config file.
    parser = jsonargparse.ArgumentParser(default_config_files=[dname + '/config.yaml'])

    # (name, keyword arguments) for every argument, in registration order.
    _cli_arguments = [
        ('--cfg', {'action': jsonargparse.ActionConfigFile}),
        ('--path_train_data', {'type': str}),
        ('--path_test_data', {'type': str}),
        ('--seed', {'type': int}),
        ('--model_type', {'type': str}),
        ('--set_prior', {'type': bool}),
        ('--outputs', {'type': int}),
        ('--inducing_points', {'type': int}),
        ('--hidden_dim', {'type': int}),
        ('--n_samples', {'type': int}),
        ('--initial_samples', {'type': int}),
        ('--space_filling_design', {'type': str}),
        ('--test_samples', {'type': int}),
        ('--transformation_x', {'type': str}),
        ('--transformation_y', {'type': str}),
        ('--selection_criteria', {'type': str}),
        ('--min_change_in_var', {'type': float}),
        ('--plot', {'type': bool}),
        ('--active_learning_steps', {'type': int}),
        ('--k_samples', {'type': int}),
        ('--beta_sampling', {'type': float}),
        ('--repeat_sampling', {'type': int}),
        ('--milestones', {}),  # no type given
        ('--initial_lr', {'type': float}),
        ('--n_epochs', {'type': int}),
        ('--n_runs', {'type': int}),
        ('--output_file', {'type': str}),
        ('--simulator', {'type': str}),
        ('--al_type', {'type': str}),
        ('--dataset', {'type': str}),
        # NOTE(review): the last four arguments are registered as *positional* arguments (no
        # leading '--'), unlike every other option. Confirm this is intended and that they can
        # be supplied the way the config / command line expects.
        ('num_chains', {'type': int}),
        ('num_samples', {'type': int}),
        ('warmup_steps', {'type': int}),
        ('predict_mcmc', {'type': str}),
    ]
    for _arg_name, _arg_kwargs in _cli_arguments:
        parser.add_argument(_arg_name, **_arg_kwargs)
    args = parser.parse_args()

    # Data paths in the config are relative to this script's directory. They are joined by plain
    # concatenation (dname has no trailing separator), so the configured paths presumably start
    # with '/'.
    args.path_train_data = dname + args.path_train_data
    args.path_test_data = dname + args.path_test_data

    # "No simulator" may be written as the string 'None' in the config or simply left unset;
    # normalise both to None.
    if args.simulator in (None, 'None'):
        args.simulator = None
    # Population-based active learning gets its labels from a simulator, so it needs one. This is
    # checked for an unset simulator too, not only for the literal string 'None'.
    if args.simulator is None and args.al_type == "population_based":
        raise RuntimeError("You cannot run the population based case without a simulator."
                           "\nEither specify a simulator with args.simulator or change the active learning"
                           " type with args.al_type to be pool_based or pseudo_population_based.")

    # Disabled check, kept for reference:
    #if (args.simulator is None) and (args.space_filling_design != 'random'):
    #    raise RuntimeError("The space filling designs only works with simulators.. Change it to be 'random'.")

    # Active-learning modes (args.al_type):
    # - population_based: data comes from a simulator. Set up the search space, oracle labels,
    #   train/test data and labelled pool; the search space is just the set of possible input values.
    # - pseudo_population_based: data comes from a data set but is used as if it came from a
    #   simulator. For the Mercury data set a query is (presumably) answered with the closest
    #   available data point, so this is not really pool-based.
    # - pool_based: traditional pool-based setting. Labelled points are removed from the search
    #   space so that it only contains unlabelled data.

    # Oracle built from the configured simulator (None when no simulator is used).
    oracle = oracle_simulator(args) if args.simulator is not None else None
    data = MyDataLoader(args=args, oracle=oracle)
    data.get_initial_data()
    if args.al_type == 'population_based':
        data.compute_true_mean_and_stddev()
    # Only evaluation against a toy simulator (with known true mean/std) is implemented.
    toy_simulator = True

    # Re-seed NumPy's global RNG from OS entropy, so the active-learning iterations are not tied to
    # any seed set earlier (presumably args.seed during data loading — confirm).
    # NOTE(review): torch and `random` are not re-seeded here.
    np.random.seed()

    # GP kernel: RBF with a separate lengthscale per input dimension (ARD over data.train.x's
    # second axis).
    kernel = RBFKernel(ard_num_dims=data.train.x.shape[1])

    # Per-iteration metric histories filled by the active-learning loop.
    nmll_losses = []             # training NMLL returned by model.fit
    nmll_losses_valid = []       # validation NMLL on the test loader
    rmse_losses_valid = []
    rrse_losses_valid = []
    mae_losses_valid = []
    rmse_std_losses_valid = []   # RMSE of the predicted std (population-based runs)
    variance_test = []
    variance_test_std = []

    # Acquisition state passed to (and updated from) compute_sample_strategy each iteration.
    var = 0.001                          # variance
    var_t = [args.min_change_in_var]     # threshold: minimum change in variance
    tic = time.time()
    noises = []                          # fitted likelihood noise per iteration
    lengthscales = []                    # fitted kernel lengthscales per iteration
    no_more_data = False
    print(f"Starting active learning for: {args.output_file}")
    # These are only assigned on some branches of the loop below. Initialise them so that plotting
    # (and the final save) cannot hit a NameError, e.g. with selection_criteria == 'lhs' (which
    # never calls compute_sample_strategy) or when the candidate pool is empty on the first iteration.
    sample_strategy_output = None
    selection_array = None
    for i in range(args.active_learning_steps):
        # Progress report on every 10th iteration.
        if i % 10 == 0:
            print("Iteration:", i, "\t", time.time() - tic)
        if no_more_data:
            break

        # Prepare data: apply the configured transformations and rebuild the data loaders.
        data.transform()
        data.make_dataloader()

        # Set up a fresh model and fit it to the current training set.
        model = get_model(args, data, kernel, likelihood=get_likelihood(args))
        nmll_loss, fit_losses, outs = model.fit(data.gp_train_loader, args)
        nmll_losses.append(nmll_loss)

        # Predict on the test set; validated below with the marginal log likelihood (MLL).
        predict_output = model.predict(data.gp_test_loader)
        predictions_mll = predict_output['predictions']

        # Record the fitted hyperparameters. For 'gp_prior'/'exact' they are read from the model
        # itself, otherwise from model.pred_model. The lengthscale is read directly from
        # covar_module; if the kernel is wrapped in a ScaleKernel it lives on
        # covar_module.base_kernel instead.
        gp = model if args.model_type in ['gp_prior', 'exact'] else model.pred_model
        lengthscales.append(gp.covar_module.lengthscale.tolist()[0])
        noises.append(gp.likelihood.noise_covar.noise.item())

        # Validation NMLL on the test loader (presumably in the transformed space); the predictions
        # are not mapped back to the original scale for this metric.
        losses_valid = compute_loss(args, data.gp_test_loader, predictions_mll, lst_metrics=['mll'],
                                    mll=get_loss(args, model, num_data=data.train_trans.y.numel()))
        nmll_losses_valid.append(losses_valid['nmll'])

        # Points of interest for querying. If the discretisation of the population-based data has
        # too many points, a random subsample is taken. In the pool-based setting the labelled points
        # are excluded from the search space (this is also applied in the oracle itself).
        # The code below reads data.candidate_points; presumably this call refreshes it — confirm.
        data.get_candidate_points()

        # Make predictions for validation, querying and plotting.
        if args.al_type == "population_based":
            # Predict on the unique candidate points (transformed like the training inputs). The same
            # predictions serve validation (RMSE of mean and std), plotting and querying.
            # NOTE(review): np.unique(..., axis=0) returns the rows sorted and de-duplicated, while
            # data.candidate_points / data.true_*_cp keep their own order. Unless the candidates are
            # already unique and sorted, predictions and targets may be misaligned — confirm.
            ss_unique = np.unique(data.candidate_points, axis=0)
            pred_x_trans, _, _ = transform(torch.Tensor(ss_unique), data.x_mu, data.x_sigma, method=args.transformation_x)
            predict_output = model.predict(dataloader=(pred_x_trans, None))
            predict_output_querying = predict_output

            # Map inputs, mean and std back to the original scale (the std is only rescaled, hence mu=0).
            pred_x = transform(pred_x_trans, data.x_mu, data.x_sigma, method=args.transformation_x, inverse=True)
            pred_mean = transform(predict_output['mean'], data.y_mu, data.y_sigma, method=args.transformation_y, inverse=True)
            pred_std = transform(predict_output['stddev'], 0, data.y_sigma, method=args.transformation_y, inverse=True)

            # Compute losses against the known true mean/std of the toy simulator.
            if not toy_simulator:
                raise NotImplementedError("Evaluation for a 'true' simulator has not been implemented.")
            losses = compute_loss(args, dataloader=(None, data.true_mean_cp), predictions=pred_mean, lst_metrics=['rmse'])
            rmse_losses_valid.append(losses['rmse'])
            losses = compute_loss(args, dataloader=(None, data.true_std_cp), predictions=pred_std, lst_metrics=['rmse'])
            rmse_std_losses_valid.append(losses['rmse'])
        else:
            # Validate on the test set: RMSE / RRSE / MAE of the back-transformed predictive mean.
            test_pred_mean = transform(predict_output['mean'], data.y_mu, data.y_sigma, method=args.transformation_y, inverse=True)
            losses = compute_loss(args, dataloader=(None, data.test.y), predictions=test_pred_mean,
                                  lst_metrics=['rmse', 'rrse', 'mae'])
            rmse_losses_valid.append(losses['rmse'].numpy())
            rrse_losses_valid.append(losses['rrse'].numpy())
            mae_losses_valid.append(losses['mae'].numpy())

            # Predict on the unique search-space points (used for plotting, and for querying in the
            # pseudo-population-based setting).
            ss_unique = np.unique(data.search_space, axis=0)
            pred_x_trans, _, _ = transform(torch.Tensor(ss_unique), data.x_mu, data.x_sigma, method=args.transformation_x)
            predict_output = model.predict(dataloader=(pred_x_trans, None))

            # Map inputs, mean and std back to the original scale (the std is only rescaled, hence mu=0).
            pred_x = transform(pred_x_trans, data.x_mu, data.x_sigma, method=args.transformation_x, inverse=True)
            pred_mean = transform(predict_output['mean'], data.y_mu, data.y_sigma, method=args.transformation_y, inverse=True)
            pred_std = transform(predict_output['stddev'], 0, data.y_sigma, method=args.transformation_y, inverse=True)

            # Plotting and querying predictions are only the same when a pseudo simulator is used.
            # In the pool-based setting the querying predictions are made on the candidate
            # (unlabelled) points only.
            # NOTE(review): in the pseudo-population-based case the querying predictions follow the
            # sorted np.unique order of the search space, while compute_sample_strategy receives
            # data.candidate_points — confirm the two are aligned.
            x_querying = pred_x
            predict_output_querying = predict_output
            if args.al_type == "pool_based":
                x_querying, _, _ = transform(torch.Tensor(data.candidate_points), data.x_mu, data.x_sigma, method=args.transformation_x)
                predict_output_querying = model.predict(dataloader=(x_querying, None))

        # Query the next point(s).
        if len(data.candidate_points) == 0:
            print(f"No more data points to query. Using all {data.train.x.shape[0]} data points.")
            no_more_data = True
        elif args.selection_criteria == 'lhs':
            # Baseline: re-draw the training set with sample_initial_dataset instead of querying.
            # NOTE(review): the requested size is the *current* size plus k_samples * i. Assuming
            # exactly that many points are returned, the set grows by 0, k, 2k, ... per iteration.
            # If a constant k_samples per step was intended, this should be
            # data.train.x.shape[0] + args.k_samples — confirm.
            data.train.x, data.train.y, _, _, _ = sample_initial_dataset(args, data.search_space, data.oracle_labels,
                                                                         path_test_data=args.path_test_data,
                                                                         initial_samples=data.train.x.shape[0]+args.k_samples*i)
            selection_array = 0
        else:
            # Compute the sampling strategy (evaluate the acquisition function) on the candidates.
            sample_strategy_output = compute_sample_strategy(args, model, data.candidate_points,
                                                             train_x=data.train_trans.x, train_y=data.train_trans.y,
                                                             predictions=predict_output_querying,
                                                             mu_x=data.x_mu, sigma_x=data.x_sigma,
                                                             variance=var, min_change_in_var=var_t,
                                                             iteration=i)
            # var / var_t are fed back into compute_sample_strategy on the next iteration.
            var = sample_strategy_output['variance']
            var_t = sample_strategy_output['min_change_in_var']
            selection_array = sample_strategy_output['selection_array']
            new_points = sample_strategy_output['new_points']

            # Get the query and add it to the data.
            data = get_query(args, new_points, data,
                             k_samples=args.k_samples,
                             beta_sampling=args.beta_sampling,
                             repeat_sampling=args.repeat_sampling,
                             seed=False)  # No seed!!! The dummy simulators will have no variation...

        # Plot results.
        if args.plot:
            p_dct = {}
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                p_dct['train_x'] = data.train.x
                p_dct['train_y'] = data.train.y
                p_dct['test_x'] = data.test.x
                p_dct['test_y'] = data.test.y
                p_dct['pred_x'] = pred_x
                p_dct['mean'] = pred_mean
                p_dct['pred_std'] = pred_std
                p_dct['lower'] = p_dct['mean'] - 2 * p_dct['pred_std']
                p_dct['upper'] = p_dct['mean'] + 2 * p_dct['pred_std']
                p_dct['variance'] = p_dct['pred_std'] ** 2
                p_dct['nmll_losses'] = nmll_losses
                p_dct['nmll_losses_valid'] = nmll_losses_valid
                p_dct['rmse_losses_valid'] = rmse_losses_valid
                p_dct['rrse_losses_valid'] = rrse_losses_valid
                p_dct['mae_losses_valid'] = mae_losses_valid
                p_dct['fit_losses'] = fit_losses
                p_dct['selection_array'] = selection_array
                # None when compute_sample_strategy has not run (e.g. 'lhs').
                # NOTE(review): confirm do_all_my_plots handles None here.
                p_dct['sample_strategy_output'] = sample_strategy_output
                p_dct['rmse_std_losses_valid'] = rmse_std_losses_valid

            do_all_my_plots(args, p_dct, data, title=f'output/test/{args.simulator}_{args.model_type}_{i}.pdf')

        # TODO: when args.set_prior is set, the fitted hyperparameters (model.state_dict()) are meant
        # to be used as priors in the next iteration; this is not implemented in this loop.

    # Save things to file: final data and predictions, per-iteration metric histories, the model
    # state and the run configuration, pickled to args.output_file.
    variance = pred_std ** 2
    p_dct = {
        'data': data,
        'train_x': data.train.x,
        'train_y': data.train.y,
        'test_x': data.test.x,
        'test_y': data.test.y,
        'pred_x': pred_x,
        'mean': pred_mean,
        'pred_std': pred_std,
        'lower': pred_mean - 2 * pred_std,  # mean - 2 std
        'upper': pred_mean + 2 * pred_std,  # mean + 2 std
        # Predictive variance normalised by its sum over dim 0.
        'variance': variance / torch.sum(variance, dim=0),
        'nmll_losses': nmll_losses,
        'nmll_losses_valid': nmll_losses_valid,
        'rmse_losses_valid': rmse_losses_valid,
        'rrse_losses_valid': rrse_losses_valid,
        'mae_losses_valid': mae_losses_valid,
        'rmse_std_losses_valid': rmse_std_losses_valid,
        'new_points': args.k_samples * args.repeat_sampling,
        'running_time': time.time() - tic,
        'selection_array': selection_array,
    }
    if args.model_type in ['ridge_reg', 'xgboost']:
        # For these model types the model object itself is stored.
        p_dct['model_state_dict'] = model
    elif args.model_type in ['fbgp_mcmc']:
        p_dct['mcmc_samples'] = model.mcmc_samples
        p_dct['model_state_dict'] = None  # TODO: Find a way to save the model
    else:
        p_dct['model_state_dict'] = [m.state_dict() for m in model] if isinstance(model, list) else model.state_dict()
    p_dct.update({
        'args': args,
        'kernel': kernel,
        'fit_losses': fit_losses,
        'min_change_in_var': var_t,
        'noises': noises,
        'lengthscales': lengthscales,
    })

    # Convert top-level tensors to (nested) Python lists; key order is preserved.
    p_dct = {key: value.tolist() if torch.is_tensor(value) else value for key, value in p_dct.items()}

    # Create the output directory if needed, so a long run does not fail at the very last step.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_file, 'wb') as fp:
        pickle.dump(p_dct, fp)
    print(f"Wrote results to {args.output_file}. The experiment took {p_dct['running_time']} sec.")
