import numpy as np
import torch
import gpytorch
import math
from tqdm import tqdm

from scipy.special import hyp1f1

from utils.transformations import min_max_feature_scaling
from models.ensemble_gp import EnsembleGP  # for QBC


def cross_correlation(data_points, search_space, normalise=False):
    """
    Mean normalised inner product between each unlabeled point and the labeled points.

    For every unlabeled point x_u, the inner product with each labeled point x_l is divided by the product
    of their Euclidean norms, and the results are averaged over the labeled points.

    :param data_points: labeled points x_l (needs ``shape[0]`` and row-wise iteration)
    :param search_space: iterable of unlabeled points x_u
    :param normalise: if True, the scores are passed through ``min_max_feature_scaling``
    :returns: torch.Tensor holding one score per entry of ``search_space``
    """
    num_labeled = data_points.shape[0]

    def _similarity(candidate, reference):
        inner = np.sum(np.array(candidate) * np.array(reference))
        # The small constant keeps the division defined when one of the vectors has zero norm
        scale = np.sqrt(np.linalg.norm(reference) ** 2 * np.linalg.norm(candidate) ** 2) + 1e-8
        return inner / scale

    scores = torch.Tensor([
        sum(_similarity(candidate, reference) for reference in data_points) / num_labeled
        for candidate in search_space
    ])

    if not normalise:
        return scores

    scaled_scores, _, _ = min_max_feature_scaling(scores)
    return scaled_scores


def ensure_change_in_variance(change_in_var, beta):
    """
    Flag data points whose change in variance is below a threshold that at least one point exceeds.

    Starting from ``beta`` (e.g. 10%), the threshold is lowered by 10% at a time until at least one entry of
    ``change_in_var`` is strictly larger than it. If no entry can ever exceed the threshold (empty tensor,
    or the threshold can no longer shrink), the search stops and the mask for the last threshold is returned
    instead of looping/recursing forever.

    :param change_in_var: torch.Tensor with the change in variance
    :param beta: initial threshold for change in variance
    :returns: tuple ``(mask, beta)``; ``mask`` is a boolean tensor that is True where
        ``change_in_var < beta`` and ``beta`` is the threshold that was finally used
    """
    # Iterating instead of recursing avoids hitting Python's recursion limit when many reductions are needed
    while change_in_var.numel() > 0 and not torch.any(change_in_var > beta):
        shrunk_beta = beta * 0.9
        # The threshold is stuck (0, inf, NaN or underflow): lowering it further cannot change the outcome
        if not abs(shrunk_beta) < abs(beta):
            break
        beta = shrunk_beta

    return change_in_var < beta, beta


def greedy_sampling(args, labeled_x, unlabeled_x, labeled_y, unlabeled_pred, criterion='GSx', variance=None):
    """
    Greedy sampling: score unlabeled points by their minimum distance to the labeled points.

    The basic greedy sampling is maximum minimum distance in the input space.
    This function implements the following versions:

    GSx: greedy sampling in the feature space [1]
    GSy: greedy sampling in the output space [1]
    iGS: improved greedy sampling, i.e. the product of the GSx and GSy distances [1]
    iGSr: iGS divided by a representativeness criterion [2]
    iGSvar: like iGS, but the output-space distance is replaced by ``variance``

    These criteria are designed for sequential sampling. Two versions are implemented for batch sampling:
    1) Batch-wise (used when ``args.beta_sampling > 0`` or ``args.min_change_in_var > 0``):
       all distances are computed at once and the minimum distance per unlabeled point is returned.
    2) Sequential (otherwise): every time a point is picked it is added to the labeled set and the distances
       are recomputed. For iGS, iGSr and iGSvar the model prediction is used as the "y-label" of the
       added point. GSy has no sequential variant and is computed as in the batch-wise version.

    NOTE: iGSvar treats ``variance`` differently in the two versions. Batch-wise it is used as-is
    (``variance.view(-1, 1)``); sequentially it is divided by its sum over dim 0 and then summed over dim 1,
    which requires a 2-D ``variance``.

    [1] Wu, D., Lin, C. T., & Huang, J. (2019). Active learning for regression using greedy sampling.
    Information Sciences, 474, 90–105. https://doi.org/10.1016/j.ins.2018.09.060
    [2] Liu, Z., & Wu, D. (2020). Integrating Informativeness, Representativeness and Diversity in Pool-Based
    Sequential Active Learning for Regression. 2020 International Joint Conference on Neural Networks (IJCNN), 1–7.
    https://doi.org/10.1109/IJCNN48605.2020.9206845

    The batch-wise criteria boil down to:
        # Pairwise distance between all unlabeled and labeled data points,
        # s.t. each row contains the distances from one unlabeled data point to all labeled data points
        distance_matrix = torch.cdist(unlabeled_pred, labeled_y, p=2.0)
        # Minimum distance for each row
        min_distance = torch.min(distance_matrix, dim=1)[0]

    :param args: object with the attributes ``beta_sampling``, ``min_change_in_var`` and ``k_samples``
    :param labeled_x: input (x) of labeled data points
    :param unlabeled_x: input (x_tilde) of unlabeled data points
    :param labeled_y: label (y) for the labeled data points
    :param unlabeled_pred: prediction (y_hat) of label for the unlabeled data points
    :param criterion: one of 'GSx', 'GSy', 'iGS', 'iGSr', 'iGSvar'
    :param variance: variance of the unlabeled data points; only used by 'iGSvar'
    :returns: one score per unlabeled point, where a larger score means "query first". Batch-wise (and for
        GSy) the score is the minimum distance. Sequentially the point picked first gets ``args.k_samples``,
        the next one ``args.k_samples - 1`` and so on, while points that were not picked keep 0.
        An unknown criterion yields ``None`` in the sequential version.
    """

    # torch.cdist expects 2-D inputs, so 1-D tensors are turned into column vectors
    if labeled_x is not None and len(labeled_x.shape) == 1:
        labeled_x = labeled_x.view(-1, 1)
        unlabeled_x = unlabeled_x.view(-1, 1)

    if labeled_y is not None and len(labeled_y.shape) == 1:
        labeled_y = labeled_y.view(-1, 1)
        unlabeled_pred = unlabeled_pred.view(-1, 1)

    distance_matrix = None
    min_distance = None
    if args.beta_sampling > 0 or args.min_change_in_var > 0:
        # Implementation 1: Batch-wise
        # Everything is computed in one go, ignoring that the distances should be recalculated once a point
        # has been added.
        # Example: Say we want two points. Let train_x be [0, 3, 6], then the four points [1, 2, 4, 5] all have
        # the same distance. Thus, we risk getting 1 and 2 instead of 1/2 and 4/5.
        if criterion == 'GSx':
            distance_matrix = torch.cdist(unlabeled_x, labeled_x, p=2.0)
        elif criterion == 'GSy':
            distance_matrix = torch.cdist(unlabeled_pred, labeled_y, p=2.0)
        elif criterion in ['iGS', 'iGSr', 'iGSvar']:
            distance_matrix_x = torch.cdist(unlabeled_x, labeled_x, p=2.0)
            if criterion == 'iGSvar':
                distance_matrix_y = variance.view(-1, 1)
            else:
                distance_matrix_y = torch.cdist(unlabeled_pred, labeled_y, p=2.0)
            distance_matrix = distance_matrix_x * distance_matrix_y
            if criterion == 'iGSr':
                # Representativeness: summed distance from each unlabeled point to all unlabeled points
                r = torch.sum(torch.cdist(unlabeled_x, unlabeled_x, p=2.0), dim=1)
                # Divide each row by the representativeness of that data point
                distance_matrix = distance_matrix / r.view(-1, 1)

        min_distance = torch.min(distance_matrix, dim=1)[0]

    else:
        # Implementation 2: Sequential
        # Every time we find a point to query, we add it to the labeled set and recalculate the distances.
        # Example: Say we want two points. Let train_x be [0, 3, 6], then the two points [1, 5] have the
        # same distance. If we add 1, we recompute the distances with train_x = [0, 1, 3, 6], where 5 now is
        # the next point to query.
        if criterion == 'GSx':
            min_distance = torch.zeros(unlabeled_x.shape[0])
            for k_sample in range(args.k_samples):
                tmp_min_distance = torch.min(torch.cdist(unlabeled_x, labeled_x, p=2.0), dim=1)[0]
                # The point with the largest minimum distance is picked
                _, indices = torch.sort(tmp_min_distance, descending=True)
                best_idx = indices[0]
                new_point = unlabeled_x[best_idx, :]
                if len(new_point.shape) == 1:
                    new_point = new_point.view(1, -1)
                # Add point to the labeled data set
                labeled_x = torch.cat([labeled_x, new_point])
                # Earlier picks get larger scores
                min_distance[best_idx] = args.k_samples - k_sample

        elif criterion == 'GSy':
            distance_matrix = torch.cdist(unlabeled_pred, labeled_y, p=2.0)
            min_distance = torch.min(distance_matrix, dim=1)[0]

        elif criterion in ['iGS', 'iGSr', 'iGSvar']:
            min_distance = torch.zeros(unlabeled_x.shape[0])

            # These quantities only depend on the unlabeled set, which never changes, so they are computed
            # once. The k_samples guard keeps them from being evaluated when no point is requested.
            variance_weight = None
            representativeness = None
            if args.k_samples > 0:
                if criterion == 'iGSvar':
                    normalized_variance = variance / torch.sum(variance, dim=0)
                    variance_weight = torch.sum(normalized_variance, dim=1).view(-1, 1)
                elif criterion == 'iGSr':
                    representativeness = torch.sum(torch.cdist(unlabeled_x, unlabeled_x, p=2.0), dim=1)

            for k_sample in range(args.k_samples):
                if criterion == 'iGSvar':
                    tmp_distance_matrix_y = variance_weight
                else:
                    tmp_distance_matrix_y = torch.cdist(unlabeled_pred, labeled_y, p=2.0)
                tmp_distance_matrix_x = torch.cdist(unlabeled_x, labeled_x, p=2.0)
                tmp_distance_matrix = tmp_distance_matrix_x * tmp_distance_matrix_y
                if criterion == 'iGSr':
                    # Divide each row by the representativeness of that data point
                    tmp_distance_matrix = tmp_distance_matrix / representativeness.view(-1, 1)
                tmp_min_distance = torch.min(tmp_distance_matrix, dim=1)[0]
                # The point with the largest minimum distance is picked
                _, indices = torch.sort(tmp_min_distance, descending=True)
                best_idx = indices[0]
                new_point_x = unlabeled_x[best_idx, :]
                new_point_y = unlabeled_pred[best_idx, :]  # We use the prediction of the model
                if len(new_point_x.shape) == 1:
                    new_point_x = new_point_x.view(1, -1)
                if len(new_point_y.shape) == 1:
                    new_point_y = new_point_y.view(1, -1)
                # Add point to the labeled data set
                labeled_x = torch.cat([labeled_x, new_point_x])
                labeled_y = torch.cat([labeled_y, new_point_y])
                # Earlier picks get larger scores
                min_distance[best_idx] = args.k_samples - k_sample

    return min_distance


def cohns_criterion(model, train_x, new_x, xis):
    """
    Cohn's active learning criterion [1]

    Returns the "change" for each value in new_x, averaged over the reference points.

    [1] Seo, S., Wallat, M., Graepel, T., & Obermayer, K. (2000).
    Gaussian process regression: Active data selection and test point rejection.
    Proceedings of the International Joint Conference on Neural Networks, 3(1), 241–246.
    https://doi.org/10.1109/ijcnn.2000.861310

    :param model: GPyTorch model
    :param train_x: training data (torch.Tensor)
    :param new_x: transformed search space (torch.Tensor)
    :param xis: reference points (torch.Tensor), (could be the search_space)
    :returns: torch.Tensor with one value per row of new_x
    """
    num_new = new_x.shape[0]
    num_train = train_x.shape[0]
    num_refs = xis.shape[0]

    def _predict(inputs):
        # Only the forward pass runs inside the contexts; covariances are read by the caller afterwards
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            return model.likelihood(model(inputs))

    # Inverse of C(x, x), obtained with the model in training mode
    model.train()
    train_dist = _predict(train_x)
    inv_train_cov = torch.inverse(train_dist.covariance_matrix).detach()

    # m and C(x_tilde, x_tilde) from the joint covariance of [train_x, new_x] in evaluation mode
    model.eval()
    joint_dist = _predict(torch.cat([train_x, new_x]))
    joint_cov = joint_dist.covariance_matrix.detach()
    cross_cov = joint_cov[:-num_new, -num_new:]
    new_cov = joint_cov[-num_new:, -num_new:]

    # Denominator: diagonal of C(x_tilde, x_tilde) - m^T C^-1 m
    solved = inv_train_cov @ cross_cov
    denominator = torch.diag(new_cov - cross_cov.t() @ solved)

    # k_N: covariance between the reference points and the training data
    ref_train_dist = _predict(torch.cat([train_x, xis]))
    ref_train_cov = ref_train_dist.covariance_matrix[-num_refs:, :num_train].detach()

    # C(x_tilde, xi): covariance between the reference points and the new points
    ref_new_dist = _predict(torch.cat([new_x, xis]))
    ref_new_cov = ref_new_dist.covariance_matrix[-num_refs:, :num_new].detach()

    # Numerator: squared differences summed over the reference points
    numerator = torch.sum((ref_train_cov @ solved - ref_new_cov) ** 2, dim=0)

    # Average over the reference points
    return (numerator / denominator) / len(xis)


def mutual_information(model, train_x, new_x, search_space=None):
    """
    Calculate the mutual information (MI) [1]

    It is implemented by only doing one forward pass (model prediction).
    The covariance matrix covar_ss can be divided into

    covar_ss = [ cov(train_x, train_x), cov(train_x, search_space);
                 cov(search_space, train_x), cov(search_space, search_space)]

    The covariance matrix cov(search_space, search_space) is denoted covar_uu.

    For every candidate y the returned value is

        (var(y) - cov(y, A) cov(A, A)^-1 cov(A, y)) / (var(y) - cov(y, A_bar) cov(A_bar, A_bar)^-1 cov(A_bar, y))

    where A is the training set and A_bar contains the points of the search space that are neither in the
    training set nor the candidate itself.

    [1] Guestrin, C., Krause, A., & Singh, A. P. (2005). Near-optimal sensor placements in Gaussian processes.
    Proceedings of the 22nd International Conference on Machine Learning - ICML ’05, 1, 265–272.
    https://doi.org/10.1145/1102351.1102385

    :param model: gpytorch model
    :param train_x: labeled data
    :param new_x: must be a subspace of the search_space. The i-th candidate is looked up at position i of
        the search space, so new_x has to be the leading rows of search_space.
    :param search_space: the search space, if none the new_x is also used as the search space.
    :returns: 1-D torch.Tensor with one MI value per entry of new_x
    """

    if search_space is None:
        search_space = new_x

    # Covariance matrix between all labeled and unlabeled data points,
    # i.e. the training set and the search space
    model.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        predictions_ss = model.likelihood(model(torch.cat([train_x, search_space])))
    covar_ss = predictions_ss.covariance_matrix

    # Covariance matrix for train_x
    n = train_x.shape[0]
    covar_aa = covar_ss[:n, :n]

    # Covariance matrix for the search space (unlabeled data)
    covar_uu = covar_ss[n:, n:]

    # Mask used to construct A_bar: a search-space point is dropped if it equals any training point.
    # any() stops at the first match instead of scanning the remaining training points.
    indices = [not any(bool((ux == lx).all()) for lx in train_x) for ux in search_space]

    # Variance of the points in the search space
    var_y = torch.diag(covar_uu)

    # cov(A, A)^-1 is the same for every candidate, so it is inverted once
    covar_aa_inv = torch.inverse(covar_aa)

    # Mutual Information (MI) for each data point in the candidate set (could be regular search space)
    mis = []
    for idx in range(len(new_x)):
        # nominator
        covar_ya = covar_ss[(n + idx), :n]
        nominator = var_y[idx] - torch.matmul(covar_ya.unsqueeze(0),
                                              torch.matmul(covar_aa_inv, covar_ya.unsqueeze(1)))

        # denominator
        tmp_indices = indices.copy()
        tmp_indices[idx] = False  # Exclude current new point
        covar_yabar = covar_uu[idx, :][tmp_indices]
        covar_abarabar = covar_uu[tmp_indices, :][:, tmp_indices]
        covar_abarabar_inv = torch.inverse(covar_abarabar)
        denominator = var_y[idx] - torch.matmul(covar_yabar.unsqueeze(0),
                                                torch.matmul(covar_abarabar_inv, covar_yabar.unsqueeze(1)))

        mi = nominator / denominator
        mis.append(mi[0, 0].item())

    return torch.tensor(mis)


def expected_model_change_output(model, train_x, new_x, p=1):
    """
    Expected Model Change Output (EMOC) [1]

    The official numpy implementation [2] has here been modified to work with gpytorch.

    [1] Kading, C., Rodner, E., Freytag, A., Mothes, O., Barz, B., & Denzler, J. (2019)
    Active learning for regression tasks with expected model output changes.
    British Machine Vision Conference 2018, BMVC 2018.
    [2] http://triton.inf-cv.uni-jena.de/LifelongLearning/gpEMOCreg/src/master/activeLearning/activeLearningGPemoc.py

    :param model: gpytorch model
    :param train_x: labeled data
    :param new_x: must be a subspace of the search_space
    :param p: using the p-th non-central moment
    :returns: torch.Tensor with one EMOC score per entry of new_x
    """

    # Mapping from the official implementation to this one:
    # self.norm = p
    # self.X = train_x
    # self.sigmaN = model.likelihood.noise.detach()
    # self.K = kAll[:train_x.shape[0], :train_x.shape[0]]

    def gaussianAbsoluteMoment(muTilde, predVar):
        # NOTE(review): the caller passes the predictive standard deviation as predVar — confirm that
        # dividing by it (rather than by the variance) in the hyp1f1 argument is intended.
        f11 = hyp1f1(-0.5 * p, 0.5, -0.5 * np.divide(muTilde ** 2, predVar))
        prefactors = ((2 * predVar ** 2) ** (p / 2.0) * math.gamma((1 + p) / 2.0)) / np.sqrt(np.pi)
        return np.multiply(prefactors, f11)

    # Get distributions from GP
    model.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        predictions = model.likelihood(model(torch.cat([train_x, new_x])))

    # convert from PyTorch to numpy
    x = new_x.numpy()

    # Modified official implementation
    # The builtin float replaces the np.float alias, which was removed in NumPy 1.24
    emocScores = np.asmatrix(np.empty([x.shape[0], 1], dtype=float))
    muTilde = predictions.mean[train_x.shape[0]:].detach().numpy()

    # Joint covariance of [train_x, new_x]; k is the train/new cross-covariance block
    kAll = predictions.lazy_covariance_matrix.detach()
    k = kAll[0:train_x.shape[0], train_x.shape[0]:].numpy()

    sigmaF = predictions.stddev[train_x.shape[0]:].detach().numpy()
    moments = np.asmatrix(gaussianAbsoluteMoment(np.asarray(muTilde), np.asarray(sigmaF)))

    self_sigmaN = model.likelihood.noise.detach().numpy()
    term1 = 1.0 / (sigmaF + self_sigmaN)

    # term2 holds one column per new point: the solution of (K + sigmaN * I) a = k, with -1 appended
    term2 = np.asmatrix(np.ones((train_x.shape[0] + 1, x.shape[0])), dtype=float) * (-1.0)
    self_K = kAll[:train_x.shape[0], :train_x.shape[0]].numpy()
    term2[0:train_x.shape[0], :] = np.linalg.solve(self_K + np.identity(train_x.shape[0], dtype=float) * self_sigmaN,
                                                   k)

    if len(term1.shape) == 1:
        term1 = term1.reshape(-1, 1)

    tmp_term = kAll[0:train_x.shape[0], :].numpy().astype(float)
    tmp_term2 = kAll[train_x.shape[0]:, :]
    for idx in range(x.shape[0]):
        preCalcMult = np.dot(term2[:-1, idx].T.astype(float), tmp_term)
        vAll = term1[idx, :] * (preCalcMult + np.dot(term2[-1, idx].T.astype(float), tmp_term2[idx, :].detach().numpy().astype(float)))
        emocScores[idx, :] = np.mean(np.power(np.abs(vAll), p))

    output = np.multiply(emocScores, moments)

    # np.asmatrix turns a 1-D array into a row matrix, so (assuming the predictive mean is 1-D) moments is
    # 1 x n while emocScores is n x 1. Their product therefore broadcasts to n x n, and the diagonal holds
    # emocScores[i] * moments[i], i.e. the score of each new point.
    return torch.diag(torch.tensor(output))


def query_by_committee(args, model, search_space, predictions, train_x, train_y, sampling_crit=None):
    """
    Query-by-committee (QBC)

    If ``predictions`` does not contain individual committee predictions, an ``EnsembleGP`` is fitted on the
    training data and used to predict the search space; it then replaces both ``model`` and ``predictions``.

    :param args: arguments
    :param model: gpytorch model; with ``sampling_crit="emoc"`` its ``models`` attribute is used as committee
    :param search_space: unique search space in the transformed space
    :param predictions: dictionary with predictions, means, stddevs and individual model predictions
    :param train_x: labeled transformed input features
    :param train_y: labeled transformed output features
    :param sampling_crit: extra sample criterion. ``None`` gives the standard deviation between the committee
        means; ``"emoc"`` gives the mean EMOC score of the committee divided by its standard deviation.
    :returns: tuple ``(qbc_selection_array, predictions)`` with one score per point in the search space and
        the (possibly newly created) predictions dictionary
    :raises ValueError: if ``sampling_crit`` is neither ``None`` nor ``"emoc"``
    """
    # Fail before (possibly) training a whole ensemble instead of with an UnboundLocalError afterwards
    if sampling_crit not in (None, "emoc"):
        raise ValueError("Unknown sampling_crit {!r}; expected None or 'emoc'".format(sampling_crit))

    # if we not already have a committee, we must make one
    if 'individual_preds' not in predictions:
        print("Making an ensemble to use Query-by-Committee")
        model = EnsembleGP(train_x, train_y,
                           kernel=gpytorch.kernels.RBFKernel(),
                           likelihood=gpytorch.likelihoods.GaussianLikelihood())
        model.fit(train_data=(train_x, train_y), args=args)
        predictions = model.predict(dataloader=(search_space, None))
    individual_preds = predictions['individual_preds']
    n_models = len(individual_preds)

    if sampling_crit is None:
        # Regular QBC (standard deviation between the predictions)
        ensemble_pred = torch.zeros(n_models, search_space.shape[0])
        for idx in range(n_models):
            ensemble_pred[idx, :] = individual_preds[idx].mean
        qbc_selection_array = torch.std(ensemble_pred, dim=0)
    else:
        # sampling_crit == "emoc"
        ensemble_emoc = torch.zeros(n_models, search_space.shape[0])
        for idx in range(n_models):
            ensemble_emoc[idx, :] = expected_model_change_output(model.models[idx], train_x, search_space, p=1)
        emoc_mean = torch.mean(ensemble_emoc, dim=0)
        emoc_stddev = torch.std(ensemble_emoc, dim=0)

        # take the mean emoc and weight with the uncertainty
        qbc_selection_array = emoc_mean / emoc_stddev

    return qbc_selection_array, predictions


def mcmc_variance_of_means(args, model, search_space):
    """
    Variance, across the batch of models, of the predicted means.

    The search space is repeated ``args.num_samples * args.num_chains`` times along a new leading dimension
    and passed through ``model.batch_model``, which is created via ``model.set_batch_model()`` if missing.

    :param args: arguments providing ``num_samples`` and ``num_chains``
    :param model: model exposing ``batch_model`` and ``set_batch_model()``
    :param search_space: points to evaluate; a 1-D tensor is treated as a single feature column
    :returns: tuple ``(variance_of_means, output)`` where ``output`` is the raw batch model output
    """
    if model.batch_model is None:
        model.set_batch_model()

    if len(search_space.shape) == 1:
        search_space = search_space.view(-1, 1)
    n_copies = args.num_samples * args.num_chains
    tiled_inputs = search_space.unsqueeze(0).repeat(n_copies, 1, 1)

    output = model.batch_model(tiled_inputs)
    batch_means = output.mean.detach()
    # Squared standard deviation over the leading (batch) dimension
    variance_of_means = batch_means.std(dim=0).pow(2)
    return variance_of_means, output


def mcmc_mean_variance(args, model, search_space):
    """
    Mean, across the batch of models, of the predicted variances.

    The search space is repeated ``args.num_samples * args.num_chains`` times along a new leading dimension
    and passed through ``model.batch_model``, which is created via ``model.set_batch_model()`` if missing.

    :param args: arguments providing ``num_samples`` and ``num_chains``
    :param model: model exposing ``batch_model`` and ``set_batch_model()``
    :param search_space: points to evaluate; a 1-D tensor is treated as a single feature column
    :returns: tuple ``(mean_variance, output)`` where ``output`` is the raw batch model output
    """
    if model.batch_model is None:
        model.set_batch_model()

    if len(search_space.shape) == 1:
        search_space = search_space.view(-1, 1)
    n_copies = args.num_samples * args.num_chains
    tiled_inputs = search_space.unsqueeze(0).repeat(n_copies, 1, 1)

    output = model.batch_model(tiled_inputs)
    # Squared standard deviations, averaged over the leading (batch) dimension
    batch_variances = output.stddev.detach().pow(2)
    mean_variance = batch_variances.mean(dim=0)
    return mean_variance, output


def mcmc_query_by_gmm(args, model, search_space):
    """
    Sum of the mean predicted variance and the variance of the predicted means across the batch of models.

    The search space is repeated ``args.num_samples * args.num_chains`` times along a new leading dimension
    and passed through ``model.batch_model``, which is created via ``model.set_batch_model()`` if missing.

    :param args: arguments providing ``num_samples`` and ``num_chains``
    :param model: model exposing ``batch_model`` and ``set_batch_model()``
    :param search_space: points to evaluate; a 1-D tensor is treated as a single feature column
    :returns: tuple ``(total_variance, output)`` where ``output`` is the raw batch model output
    """
    if model.batch_model is None:
        model.set_batch_model()

    if len(search_space.shape) == 1:
        search_space = search_space.view(-1, 1)
    n_copies = args.num_samples * args.num_chains
    tiled_inputs = search_space.unsqueeze(0).repeat(n_copies, 1, 1)

    output = model.batch_model(tiled_inputs)
    # Average of the per-model variances
    mean_variance = output.stddev.detach().pow(2).mean(dim=0)
    # Spread of the per-model means
    variance_of_means = output.mean.detach().std(dim=0).pow(2)
    return mean_variance + variance_of_means, output


def mcmc_bald(args, model, search_space):
    """
    BALD-style score: log of the (scaled) mean standard deviation minus the mean log standard deviation.

    The search space is repeated ``args.num_samples * args.num_chains`` times along a new leading dimension
    and passed through ``model.batch_model``, which is created via ``model.set_batch_model()`` if missing.

    :param args: arguments providing ``num_samples`` and ``num_chains``
    :param model: model exposing ``batch_model`` and ``set_batch_model()``
    :param search_space: points to evaluate; a 1-D tensor is treated as a single feature column
    :returns: tuple ``(bald, output)`` where ``output`` is the raw batch model output
    """
    if model.batch_model is None:
        model.set_batch_model()

    if len(search_space.shape) == 1:
        search_space = search_space.view(-1, 1)
    n_copies = args.num_samples * args.num_chains
    tiled_inputs = search_space.unsqueeze(0).repeat(n_copies, 1, 1)

    output = model.batch_model(tiled_inputs)
    batch_stddev = output.stddev.detach()

    # NOTE(review): the divisor 1500 is not explained anywhere in this file — confirm its meaning
    entropy_expected = torch.log(batch_stddev.mean(dim=0) / 1500)
    expected_entropy = torch.log(batch_stddev).mean(dim=0)
    return entropy_expected - expected_entropy, output
