import numpy as np
import torch
import random

from models.exact_gp import ExactGPModel  # EMOC with multiple outputs
import gpytorch

from utils.transformations import transform
from utils.sampling_strategies import greedy_sampling, cohns_criterion, cross_correlation, ensure_change_in_variance, \
    mutual_information, expected_model_change_output, query_by_committee, mcmc_variance_of_means, mcmc_mean_variance, \
    mcmc_bald, mcmc_query_by_gmm


# Oracles
def _find_matching_rows(candidates, point):
    """
    Return the positions in `candidates` whose entry equals `point`.

    For a 1-d `candidates` array this is an element-wise comparison. For a 2-d array a row
    matches when the number of matching features equals the number of features in `point`.
    """
    hits = candidates == point
    if candidates.ndim > 1:
        # Reduce the per-feature comparison to one flag per row. This also covers the case of a
        # single remaining row, which must still end up as a 1-d mask.
        hits = np.count_nonzero(hits, axis=1) == np.size(point)
    return np.flatnonzero(hits)


def oracle(args, query, search_space, oracle_labels, pool_labeled, single_point=False):
    """
    Look up the label of the first query point that is still available in the search space.

    :param args: arguments; only args.al_type is read. With "pool_based" the already labeled
        indices are removed from the candidates, otherwise they are overwritten with -1 so that
        they can no longer match a query point
    :param query: sequence of points to query, ordered by priority (or a single point if
        single_point is True)
    :param search_space: np.array with the search space (i.e. x)
    :param oracle_labels: labels corresponding to search space (i.e. y)
    :param pool_labeled: list of already queried search-space indices; the chosen index is
        appended in place
    :param single_point: set to True if `query` is one point rather than a sequence of points
    :returns: (y_new, pool_labeled)
    :raises AssertionError: if none of the query points is available in the search space
    """
    if args.al_type == "pool_based":
        mask = np.full(len(search_space), True, dtype=bool)
        mask[pool_labeled] = False
        tmp_ss = search_space[mask]  # boolean indexing already returns a copy
    else:
        mask = None
        tmp_ss = search_space.copy()
        tmp_ss[pool_labeled] = -1

    # is query a single point?
    if single_point:
        query = [query]

    # Take first element in points to query:
    x_new = query[0]
    ids = _find_matching_rows(tmp_ss, x_new)

    # if no elements found, try the next points in the query (in order of priority)
    # TODO: Fix this for initial querying, because here we cannot just take the next data point
    # since that data point is not related to the one we actually wanted to query
    if len(ids) < 1:
        print(f"The data point {x_new} is not found in the search space:\n{tmp_ss}.")
        for candidate in query[1:]:
            ids = _find_matching_rows(tmp_ss, candidate)
            if len(ids) > 0:
                x_new = candidate
                print(f"Using {x_new} instead.")
                break
    if len(ids) < 1:
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        raise AssertionError("ERROR: No more points to query within assumed range.. See oracle()")

    # choose a random id among the matches to mimic a random simulation
    idx = ids[random.randint(0, len(ids) - 1)]

    if mask is not None:
        # map the index within the candidate points back to an index of the original search_space
        idx = np.flatnonzero(mask)[idx]
    pool_labeled.append(idx)
    y_new = oracle_labels[idx]

    return y_new, pool_labeled


def get_query(args, new_points, data, k_samples=1, beta_sampling=0, repeat_sampling=1, seed=0):
    """
    Query the oracle for the next point(s) and add the result to the training data.

    :param args: arguments
    :param new_points: np.array of possible new points (starts querying from left)
    :param data: data container; data.search_space, data.oracle_labels and data.pool_labeled are read,
        data.oracle.query is called when data.oracle_labels is None, and data.pool_labeled and
        data.train are updated
    :param k_samples: no. of distinct input values
    :param beta_sampling: the distance between distinct input values [0,1) (relative distance, e.g 10% of the range)
    :param repeat_sampling: no. of simulations for each input value
    :param seed: seed; a falsy value (e.g. the default 0) leaves the random generators untouched
    :returns: the updated data container
    """
    search_space, oracle_labels, queried = data.search_space, data.oracle_labels, data.pool_labeled

    if seed:
        random.seed(seed)  # for the data sets
        np.random.seed(seed)  # for the simulator

    is_single_query = k_samples == 1 and repeat_sampling == 1
    if not is_single_query:
        new_y, queried, new_points = batch_sampling(args, data, new_points, search_space, oracle_labels, queried,
                                                    k_samples=k_samples,
                                                    beta_sampling=beta_sampling,
                                                    repeat_sampling=repeat_sampling)
    elif oracle_labels is None:
        # No stored labels: ask the simulator for the first proposed point
        queried = new_points[0, :].reshape(-1, new_points.shape[1])
        new_y = data.oracle.query(queried)
    else:
        label, queried = oracle(args, new_points, search_space, oracle_labels, queried)
        new_y = [label]

    if queried is None:
        # Dirty hack to get the order of the new points...
        queried = new_points
    elif isinstance(queried, (np.int64, np.float64)):
        # A bare numpy scalar is wrapped so that it can be sliced later on
        queried = [queried]

    # Update data
    data.pool_labeled = queried
    return add_queried_datapoints(args, new_y, data, k_samples, repeat_sampling)


def add_queried_datapoints(args, new_y, data, k_samples, repeat_sampling):
    """
    Append the newly queried inputs and their labels to data.train.x and data.train.y.

    :param args: arguments (not used here)
    :param new_y: labels of the newly queried points
    :param data: data container; with a labeled data set (data.oracle_labels is not None)
        data.pool_labeled holds search-space indices, otherwise it holds the queried inputs
        themselves and is reset to None afterwards
    :param k_samples: no. of distinct input values
    :param repeat_sampling: no. of simulations for each input value
    :returns: the updated data container
    """
    if data.oracle_labels is not None:
        # The most recent entries of pool_labeled index the newly queried points
        n_new = k_samples * repeat_sampling
        new_x = data.search_space[data.pool_labeled[-n_new:]]
    else:
        # Stupid special case due to missing batch size on 1d problems
        single_query_with_batch_dim = len(data.train.x.shape) > 1 and k_samples == 1 and repeat_sampling == 1
        if single_query_with_batch_dim:
            new_x = np.array(data.pool_labeled).reshape(1, -1)
        else:
            new_x = np.repeat(data.pool_labeled[:k_samples], repeat_sampling, axis=0)
        data.pool_labeled = None

    data.train.x = torch.cat((data.train.x, torch.FloatTensor(new_x)))
    data.train.y = torch.cat((data.train.y, torch.FloatTensor(new_y)))
    return data


def batch_sampling(args, data, new_points, search_space, oracle_labels, pool_labeled,
                   k_samples=1, beta_sampling=0, repeat_sampling=1):
    """
    Sample multiple points at once, i.e. batch sampling

    :param args: arguments
    :param data: data container, passed on to do_repeat_sampling
    :param new_points: np.array of possible new points (starts querying from left)
    :param search_space: search space (i.e. x)
    :param oracle_labels: oracle labels corresponding to search space (i.e. y)
    :param pool_labeled: list of indices of inputs from the search space, which have already been used (queried)
    :param k_samples: no. of distinct input values
    :param beta_sampling: the distance between distinct input values, relative to the range of the
        search space (e.g. 0.1 is 10% of the range)
    :param repeat_sampling: no. of simulation for each input value
    :returns: (new_y, pool_labeled, new_points), where new_points has the beta-spaced points
        prepended when beta sampling is active
    """
    # Require beta distance in batch sampling?  NB: This could be vectorized
    if k_samples > 1 and beta_sampling > 0:
        # Minimum spacing is a fraction of the range (max - min) of the search space. The range is
        # taken over all entries of search_space, i.e. it is not computed per input dimension.
        min_dist = beta_sampling * (np.max(search_space) - np.min(search_space))
        new_point_beta_dist = [new_points[0]]
        for _ in range(1, k_samples):
            # Take the first candidate that is at least min_dist away from all points picked so far
            for p in new_points:
                if all(np.min(np.abs(p - p_new)) >= min_dist for p_new in new_point_beta_dist):
                    new_point_beta_dist.append(p)
                    break
        # The spaced points go first; the original ranking is kept behind them as a fallback in
        # case fewer than k_samples points satisfied the spacing requirement.
        new_points = np.concatenate((np.array(new_point_beta_dist), new_points))

    new_y = []
    for k_sample in range(k_samples):
        tmp_y, pool_labeled = do_repeat_sampling(args, data, repeat_sampling, new_points[k_sample:],
                                                 search_space, oracle_labels, pool_labeled)
        new_y.extend(tmp_y)

    return new_y, pool_labeled, new_points


def do_repeat_sampling(args, data, repeat_sampling, new_points, search_space, oracle_labels, pool_labeled):
    """
    Draw `repeat_sampling` labels for the first point in `new_points`

    :param args: arguments
    :param data: data container; data.oracle.query is called when oracle_labels is None
    :param repeat_sampling: the number of simulations to do at the input value
    :param new_points: np.array of possible new points (starts querying from left)
    :param search_space: search space (i.e. x)
    :param oracle_labels: oracle labels corresponding to search space (i.e. y), or None when a simulator is used
    :param pool_labeled: list of indices of inputs from the search space, which have already been used (queried)
    :returns: (new_y, pool_labeled); new_y is a list, or an np.array once the simulator has been queried
    """
    if oracle_labels is None:
        # Using a simulator: query the same input repeatedly and concatenate the outcomes
        simulated = []
        for _ in range(repeat_sampling):
            x_query = new_points[0].reshape(-1, new_points.shape[1])
            outcome = data.oracle.query(x_query)
            simulated = np.concatenate((np.array(simulated), outcome))
        return simulated, pool_labeled

    # Using a data set: every call to the oracle also records the chosen index in pool_labeled
    looked_up = []
    for _ in range(repeat_sampling):
        label, pool_labeled = oracle(args, new_points, search_space, oracle_labels, pool_labeled)
        looked_up.append(label)
    return looked_up, pool_labeled


def index_descending(tensor):
    """
    Returns the indices that sort the entries from highest to lowest value, i.e. the order of
    points to query next. Tensors with more than one dimension are averaged over dim 1 first.
    """
    is_vector = len(tensor.shape) == 1
    scores = tensor if is_vector else torch.mean(tensor, dim=1)
    return torch.argsort(scores, descending=True)


def get_sorted_unlabeled_data_points(args, selection_array, search_space):
    """
    Order the candidate points by their selection value, highest value first.

    :param args: arguments; al_type, outputs, simulator and selection_criteria are read
    :param selection_array: tensor with a selection value for each candidate point (and for each
        output in the multi-output case)
    :param search_space: np.array with the search space (i.e. x)
    :returns: np.array with the points to query, best first. With several outputs and a
        per-output criterion, the ranking of the first output comes first, followed by the
        ranking of the second output, and so on.
    :raises NotImplementedError: if the criterion is not supported for multiple outputs
    """
    if args.al_type in ['psuedo_population_based', 'population_based']:
        ss_unique = np.unique(search_space, axis=0)
    else:
        ss_unique = search_space

    if args.outputs == 1:
        # Get the sequence of possible new data points
        new_points = ss_unique[index_descending(selection_array)]

        if args.simulator == "motorcycle":
            # Replace every point by the index of its first occurrence in the search space
            new_points = np.array([np.where(search_space == tmp_x)[0][0] for tmp_x in new_points]).reshape(-1, 1)
        return new_points

    if args.selection_criteria in ["variance", "random", "emoc", "cohns", "sequential",
                                   "sequential_relevant_variance", "parallel_relevant_variance"]:
        # One ranking of the points for each task
        # NOTE(review): this branch indexes search_space rather than ss_unique -- confirm that
        # selection_array is aligned with search_space for the population based settings.
        n_points, n_tasks = selection_array.shape[0], selection_array.shape[1]
        n_features = search_space.shape[1]
        ranked = torch.empty((n_points, n_tasks, n_features))
        for task_idx in range(n_tasks):
            new_points_idx = torch.argsort(selection_array[:, task_idx], descending=True)
            ranked[:, task_idx, :] = torch.tensor(search_space[new_points_idx])
        # Stack the per-task rankings below each other. The number of columns follows the
        # search space instead of being fixed to three input features.
        return torch.transpose(ranked, 0, 1).reshape(-1, n_features).numpy()

    if args.selection_criteria in ['iGSr', 'iGS', 'iGSvar', 'GSx']:
        return ss_unique[index_descending(selection_array)]

    raise NotImplementedError(f"The sampling strategy {args.selection_criteria} is not implemented for "
                              f"multiple outputs.")


def _independent_output_scores(args, model, train_x, train_y, score_fn):
    """
    Score the candidates separately for each output of a multi-output model (works for indep_exact).

    For every output a single-output ExactGPModel is built and given the noise, mean constant,
    outputscale and lengthscale that `model` holds for that output. `score_fn(tmp_model)` must
    return one value per candidate point.

    :returns: tensor of shape (no. of candidates, args.outputs); column i holds the scores of output i
    """
    per_output = []
    for i in range(args.outputs):
        # First set up a model for this output
        tmp_likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(1e-4),
            noise_prior=gpytorch.priors.NormalPrior(0, 1)
        )
        tmp_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=train_x.shape[1])
        tmp_model = ExactGPModel(train_x, train_y[:, i], tmp_kernel, tmp_likelihood)
        # Copy the hyperparameters of output i from the multi-output model
        tmp_model.likelihood.noise_covar.noise = model.likelihood.task_noises[i].detach()
        tmp_model.mean_module.constant = torch.nn.Parameter(model.mean_module.constant[i, 0].detach())
        tmp_model.covar_module.outputscale = model.covar_module.outputscale[i].detach()
        tmp_model.covar_module.base_kernel.lengthscale = model.covar_module.base_kernel.lengthscale[i, 0, :].detach()
        per_output.append(score_fn(tmp_model).reshape(-1))
    # One column per output. Concatenating the per-output scores and viewing them as (-1, outputs)
    # would mix the scores of different points and outputs within a row.
    return torch.stack(per_output, dim=1)


def compute_sample_strategy(args, model, search_space, train_x, train_y, predictions, mu_x, sigma_x,
                            variance, min_change_in_var, iteration):
    """
    Applies a sampling strategy/criteria and returns a tensor with a value for each point in the search space.
    All strategies are defined such that next point to query has the highest value, e.g. argmax(selection_array).

    :param args: arguments
    :param model: gpytorch model
    :param search_space: the unlabeled data set of possible data points to query
    :param train_x: labeled transformed input features
    :param train_y: labeled transformed output features
    :param predictions: dictionary with predictions, means and stddevs
    :param mu_x: mean value of non-transformed input features
    :param sigma_x: standard deviation of non-transformed input features
    :param variance: tensor with previous estimates of variance in last active learning iteration
    :param min_change_in_var: a list of parameters for the minimum change in variance (book keeping)
    :param iteration: current active learning iteration (used by the multi-output strategies)
    :returns: dictionary with the keys 'selection_array', 'variance', 'min_change_in_var' and
        'new_points', plus 'ensemble_pred' (qbc criteria) or 'batch_model_output' (mcmc criteria)
    :raises NotImplementedError: if the criterion is unknown, or an mcmc criterion is used with a
        model type other than 'fbgp_mcmc'
    """
    # TODO: a previous (disabled) attempt masked the already labeled points out of the variance when
    # args.min_change_in_var > 0; it only works if the candidate points equal the search space.

    output = {}

    selection_criteria = args.selection_criteria

    if selection_criteria == "topk5":
        # Only consider the top points for each output (yes, there might be replicates!)
        # This should then be used with a non-model-based criteria, e.g., random or GSx
        # NOTE(review): the criterion is called "topk5" but k=10 points are kept -- confirm which is intended.
        _, topk_idx = torch.topk(predictions['stddev'], k=10, dim=0)
        search_space = search_space[topk_idx.view(-1)]
        selection_criteria = "random"

    model_fbgp = None
    if args.model_type in ['fbgp_mcmc']:
        # The mcmc criteria need the full model; all other criteria use its prediction model
        model_fbgp = model
        model = model.pred_model

    pred_mean = predictions['mean']  # predicted mean for each sample in the search space
    pred_std = predictions['stddev']  # predicted standard deviation for each sample in the search space

    # Get the unique values to query and transform them according to the input transformation
    if args.al_type == 'pool_based' and selection_criteria in ['iGSr', 'iGS', 'iGSvar']:
        ss_unique = search_space
    else:
        ss_unique = np.unique(search_space, axis=0)
    ss_unique_trans, _, _ = transform(torch.Tensor(ss_unique), mu_x, sigma_x, method=args.transformation_x)

    multi_output_criteria = ["sequential", "sequential_relevant_variance", "parallel_relevant_variance"]
    variance_criteria = ['variance'] + multi_output_criteria
    mcmc_criteria = {
        "mcmc_qbc": mcmc_variance_of_means,
        "mcmc_mean_variance": mcmc_mean_variance,
        "mcmc_gmm": mcmc_query_by_gmm,
        "mcmc_bald": mcmc_bald,
    }

    selection_array, last_variance = None, None
    if selection_criteria in variance_criteria or selection_criteria == 'combi':
        last_variance = variance
        variance = pred_std**2

    # Choose sampling strategy
    if selection_criteria in variance_criteria:
        selection_array = variance
    elif selection_criteria == 'cross_corr':
        selection_array = -cross_correlation(train_x, ss_unique_trans, normalise=False)
    elif selection_criteria == 'combi':
        negative_rho = -cross_correlation(train_x, ss_unique_trans, normalise=False)
        selection_array = variance + negative_rho
    elif selection_criteria == "random":
        selection_array = torch.rand([ss_unique.shape[0], args.outputs])
    elif selection_criteria == "GSx":
        selection_array = greedy_sampling(args, labeled_x=train_x, unlabeled_x=ss_unique_trans,
                                          labeled_y=None, unlabeled_pred=None, criterion='GSx')
    elif selection_criteria == "GSy":
        selection_array = greedy_sampling(args, labeled_x=None, unlabeled_x=None,
                                          labeled_y=train_y, unlabeled_pred=pred_mean, criterion='GSy')
    elif selection_criteria == "iGS":
        selection_array = greedy_sampling(args, labeled_x=train_x, unlabeled_x=ss_unique_trans,
                                          labeled_y=train_y, unlabeled_pred=pred_mean, criterion='iGS')
    elif selection_criteria == "iGSvar":
        # The variance is the squared standard deviation (as for the variance based criteria above)
        selection_array = greedy_sampling(args, labeled_x=train_x, unlabeled_x=ss_unique_trans,
                                          labeled_y=train_y, unlabeled_pred=pred_mean, criterion='iGSvar',
                                          variance=pred_std**2)
    elif selection_criteria == "iGSr":
        selection_array = greedy_sampling(args, labeled_x=train_x, unlabeled_x=ss_unique_trans,
                                          labeled_y=train_y, unlabeled_pred=pred_mean, criterion='iGSr')
    elif selection_criteria == "cohns":
        if args.outputs == 1:
            selection_array = cohns_criterion(model, train_x, ss_unique_trans, xis=ss_unique_trans)
        else:
            selection_array = _independent_output_scores(
                args, model, train_x, train_y,
                lambda tmp_model: cohns_criterion(tmp_model, train_x, ss_unique_trans, xis=ss_unique_trans))
    elif selection_criteria == "mi":
        selection_array = mutual_information(model, train_x, ss_unique_trans)
    elif selection_criteria == "emoc":
        if args.outputs == 1:
            selection_array = expected_model_change_output(model, train_x, ss_unique_trans, p=1)
        else:
            selection_array = _independent_output_scores(
                args, model, train_x, train_y,
                lambda tmp_model: expected_model_change_output(tmp_model, train_x, ss_unique_trans, p=1))
    elif selection_criteria == "qbc":
        selection_array, ensemble_pred = query_by_committee(args, model, ss_unique_trans, predictions, train_x, train_y)
        output['ensemble_pred'] = ensemble_pred
    elif selection_criteria == "qbc_emoc":
        selection_array, ensemble_pred = query_by_committee(args, model, ss_unique_trans, predictions, train_x, train_y,
                                                            sampling_crit="emoc")
        output['ensemble_pred'] = ensemble_pred
    elif selection_criteria in mcmc_criteria:
        if args.model_type not in ['fbgp_mcmc']:
            raise NotImplementedError(f"Trying to use {selection_criteria} with {args.model_type}.")
        selection_array, batch_model_output = mcmc_criteria[selection_criteria](args, model_fbgp, ss_unique_trans)
        output['batch_model_output'] = batch_model_output
    else:
        raise NotImplementedError(f"The sampling strategy {selection_criteria} is not implemented. Change the "
                                  f"sampling strategy with args.selection_criteria.")

    # Minimum change in variance? Right now, this is only used for 'variance' and 'combi'
    if selection_criteria in ('variance', 'combi') and args.min_change_in_var > 0:
        change_in_var = torch.abs(last_variance - variance) / last_variance
        ids, tmp_min_change_in_var = ensure_change_in_variance(change_in_var, min_change_in_var[-1])
        min_change_in_var.append(tmp_min_change_in_var)
        # NOTE(review): for 'variance' selection_array is the same tensor as variance, so the
        # returned variance is zeroed at these indices as well -- confirm this is intended.
        selection_array[ids] = 0

    if selection_criteria in multi_output_criteria:
        new_points = compute_sample_strategy_multi_output(args, selection_array, model, search_space, predictions,
                                                          iteration)
    else:
        new_points = get_sorted_unlabeled_data_points(args, selection_array, search_space)

    output['selection_array'] = selection_array
    output['variance'] = variance
    output['min_change_in_var'] = min_change_in_var
    output['new_points'] = new_points

    return output


def _unique_rows_keep_order(points):
    """Drop duplicated rows, keeping the first occurrence of every row in its original position."""
    _, first_idx = np.unique(points, return_index=True, axis=0)
    return points[np.sort(first_idx), :]


def _sorted_points_as_single_output(args, selection_array, search_space):
    """Rank the points with args.outputs temporarily set to 1; the previous value is always restored."""
    n_outputs = args.outputs
    args.outputs = 1
    try:
        return get_sorted_unlabeled_data_points(args, selection_array, search_space)
    finally:
        args.outputs = n_outputs


def compute_sample_strategy_multi_output(args, selection_array, model, search_space, predictions, iteration):
    """
    Active learning with multiple outputs

    The number of outputs is taken from args.outputs.

    :param args: arguments; selection_criteria and outputs are read
    :param selection_array: tensor with a selection value for each point (rows) and output (columns)
    :param model: gpytorch model; model.likelihood.task_noises is read for "parallel_relevant_variance"
    :param search_space: the unlabeled data set of possible data points to query
    :param predictions: dictionary with predictions; predictions['stddev'] is read
    :param iteration: current active learning iteration; "sequential" cycles through the outputs with it
    :returns: np.array with the points to query next, best first
    :raises NotImplementedError: if args.selection_criteria is not a multi-output strategy
    """
    n_outputs = args.outputs
    criterion = args.selection_criteria

    if criterion == "sequential":
        # Loop over the outputs, one output per iteration (duplicated points are kept)
        selection_array = selection_array[:, iteration % n_outputs]
        return _sorted_points_as_single_output(args, selection_array, search_space)

    if criterion == "sequential_relevant_variance":
        # Choose a single output at random, weighted by its "relevant variance"
        rel_var = torch.mean(predictions['stddev'], dim=0) * torch.std(predictions['stddev'], dim=0)
        rel_var = rel_var / torch.sum(rel_var)
        chosen_outputs = random.choices(np.arange(n_outputs), weights=rel_var, k=1)
        selection_array = selection_array[:, chosen_outputs]
        new_points = _sorted_points_as_single_output(args, selection_array, search_space)
        return _unique_rows_keep_order(new_points)

    if criterion == "parallel_relevant_variance":
        # Draw one output per task (with replacement), weighted by the mean predicted standard
        # deviation relative to the task noise
        rel_var = torch.mean(predictions['stddev'], dim=0) / model.likelihood.task_noises
        rel_var = rel_var / torch.sum(rel_var)
        chosen_outputs = random.choices(np.arange(n_outputs), weights=rel_var, k=n_outputs)
        selection_array = selection_array[:, chosen_outputs]
        new_points = get_sorted_unlabeled_data_points(args, selection_array, search_space)
        return _unique_rows_keep_order(new_points)

    raise NotImplementedError(f"The sampling strategy {criterion} is not a multi-output sampling strategy.")
