from tqdm import tqdm
import logging
import numpy as np
from sklearn.decomposition import PCA
from sklearn.covariance import MinCovDet as MCD
import torch

logger = logging.getLogger(__name__)


def cov_matrix(X, robust=False):
    """Compute the covariance matrix of the rows (observations) of ``X``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Observations in rows, variables in columns.
    robust : bool, default=False
        If True, use the Minimum Covariance Determinant estimator
        (``sklearn.covariance.MinCovDet``); otherwise use the classical
        sample covariance computed by ``np.cov``.

    Returns
    -------
    ndarray of shape (n_features, n_features)
        Always 2-D. With a single feature ``np.cov`` alone returns a 0-d
        array, which ``np.linalg.svd`` rejects, so it is promoted to (1, 1)
        to match the robust branch.
    """
    if robust:
        sigma = MCD().fit(X).covariance_
    else:
        # rowvar=False: columns are variables (same as np.cov(X.T) for 2-D
        # arrays, but also accepts plain nested lists).
        sigma = np.cov(X, rowvar=False)

    return np.atleast_2d(sigma)


def standardize(X, robust=False):
    """Whiten ``X`` with the inverse square root of its covariance matrix.

    When the centred data are rank deficient (so their covariance is
    singular), ``X`` is first projected with PCA onto its ``rank`` leading
    components and the covariance is computed in that reduced space.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data to standardize.
    robust : bool, default=False
        Forwarded to ``cov_matrix``: use the MinCovDet estimator instead of
        the classical sample covariance (on both the full-rank and PCA paths).

    Returns
    -------
    X_white : ndarray of shape (n_samples, r)
        ``X_transf @ square_inv_matrix`` where ``X_transf`` is ``X`` (full
        rank) or its PCA projection (rank deficient); ``r`` is the number of
        columns of ``X_transf``.
    square_inv_matrix : ndarray of shape (r, r)
        ``W = u / sqrt(s)`` from the SVD ``sigma = u diag(s) v^T``; for a
        symmetric ``sigma`` this gives ``W^T sigma W = I``.
    pca : sklearn PCA or None
        The fitted PCA used for the reduction, or None if no reduction.

    Raises
    ------
    ValueError
        If all rows of ``X`` are identical (zero variance in every direction).
    """
    X = np.asarray(X)
    n_features = X.shape[1]

    # Rank of the *centred* data: points on an affine subspace that misses
    # the origin have full uncentred rank but a singular covariance, which
    # would make the division by sqrt(s) below blow up.
    rank = np.linalg.matrix_rank(X - X.mean(axis=0))
    if rank == 0:
        raise ValueError("X has zero variance; it cannot be standardized.")

    if rank < n_features:
        pca = PCA(rank)
        pca.fit(X)
        X_transf = pca.transform(X)
    else:
        pca = None
        X_transf = X

    # Promote to 2-D so a single remaining component still works with svd.
    sigma = np.atleast_2d(cov_matrix(X_transf, robust))

    u, s, _ = np.linalg.svd(sigma)
    square_inv_matrix = u / np.sqrt(s)

    return X_transf @ square_inv_matrix, square_inv_matrix, pca


def sampled_sphere(K, d, seed=0):
    """Sample ``K`` directions uniformly on the unit sphere of R^d.

    Standard Gaussian vectors are drawn and normalised to unit length.
    Sampling is deterministic for a given ``seed`` (default 0), so repeated
    calls return the same directions. The global torch RNG state is saved
    and restored around the draw, so the caller's random streams are left
    untouched.

    Parameters
    ----------
    K : int
        Number of directions.
    d : int
        Dimension of the ambient space.
    seed : int, default=0
        Seed of the CPU generator used for the draw.

    Returns
    -------
    torch.Tensor of shape (K, d)
        Rows of unit Euclidean norm.
    """
    import torch
    import torch.nn.functional as F

    # fork_rng restores the CPU RNG state on exit; devices=[] skips CUDA.
    # Seeding only the CPU default generator (instead of torch.manual_seed,
    # which also reseeds CUDA) keeps every other device's stream intact
    # while producing exactly the same CPU samples as before.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(seed)
        mean = torch.zeros(d)
        identity = torch.eye(d)
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            loc=mean, scale_tril=identity
        )
        U = dist.rsample(sample_shape=(K, ))
    return F.normalize(U)


class DataDepth:
    """Monte-Carlo approximation of the (Affine-Invariant) IRW data depth.

    For each of ``K`` unit directions, the training set is projected and the
    depth of a point is ``min(F, 1 - F)`` where ``F`` is the empirical CDF of
    the projected training set at the point's projection; the final depth is
    the average over directions.
    """

    # Progress-bar settings shared by both projection loops.
    _TQDM_KWARGS = dict(position=0, leave=True, colour='green', ascii=True, ncols=100)

    def __init__(self, K):
        # Number of random projection directions.
        self.K = K

    def _directions(self, U, dim):
        """Return the (K, dim) projection directions as a NumPy array.

        ``U=None``, or a ``U`` whose shape is not ``(self.K, dim)``, is
        replaced by directions drawn with ``sampled_sphere``.
        """
        if U is not None and tuple(U.shape) != (self.K, dim):
            logger.warning(
                "Mismatch in shape: got U of shape %s, expected (%d, %d); "
                "sampling new directions instead.",
                tuple(U.shape), self.K, dim,
            )
            U = None
        if U is None:
            U = sampled_sphere(self.K, dim)
        if isinstance(U, torch.Tensor):
            # Mixing a Tensor into np.matmul makes NumPy wrap the result back
            # into a Tensor, whose .sort() is not in-place and which
            # ndarray methods reject; work purely in NumPy instead.
            U = U.detach().cpu().numpy()
        return np.asarray(U)

    def _training_depth(self, Y, U):
        """Depth of every row of ``Y`` w.r.t. ``Y`` itself, via projection ranks."""
        n_samples = Y.shape[0]
        Z = Y @ U.T  # (n_samples, K) projections
        order = np.argsort(Z, axis=0)
        z = np.arange(1, n_samples + 1)
        ranks = np.zeros((n_samples, self.K))

        for k in tqdm(range(self.K), "Projection", **self._TQDM_KWARGS):
            ranks[order[:, k], k] = z

        Depth = ranks / n_samples
        return np.mean(np.minimum(Depth, 1 - Depth), axis=1)

    def _test_depth(self, Y, Y_test, U):
        """Depth of every row of ``Y_test`` w.r.t. the training rows ``Y``."""
        n_samples = Y.shape[0]
        # np.sort returns a sorted copy (one sorted column per direction).
        Z = np.sort(Y @ U.T, axis=0)
        Z_test = Y_test @ U.T
        counts = np.zeros((Y_test.shape[0], self.K))

        for k in tqdm(range(self.K), "Projection", **self._TQDM_KWARGS):
            # Number of training projections strictly below each test projection.
            counts[:, k] = np.searchsorted(a=Z[:, k], v=Z_test[:, k], side="left")

        Depth = counts / n_samples
        return np.mean(np.minimum(Depth, 1 - Depth), axis=1)

    def AI_IRW(
            self, X, AI=False, estimator="classic", method="cholesky", X_test=None, U=None
    ):
        """Compute the (Affine-Invariant) Integrated Rank-Weighted depth.

        Parameters
        ----------
        X : array-like of shape (n_samples, dimension)
            The training set.
        AI : bool, default=False
            If True, use the affine-invariant formulation: X (and X_test) are
            mapped through the same ``standardize`` transform (optional PCA,
            then whitening with the covariance of X) before projecting.
            Otherwise the original IRW formulation is used.
        estimator : str
            Currently ignored; kept for backward compatibility.
        method : str
            Currently ignored; kept for backward compatibility.
        X_test : array-like of shape (n_test, dimension), optional
            Points whose depth is computed w.r.t. X. If None, the depth of
            each point of X w.r.t. X is computed with a faster rank-based path.
        U : array-like of shape (K, dim), optional
            Projection directions, where ``dim`` is the dimension after the
            optional PCA reduction. If None or of the wrong shape, ``K``
            directions are sampled from the unit sphere.

        Returns
        -------
        ndarray of shape (n_test,), or (n_samples,) when X_test is None
            Depth score of each point, in [0, 0.5].

        Raises
        ------
        ValueError
            If X and X_test do not have the same number of columns.
        """
        X = np.asarray(X)
        if X_test is not None:
            X_test = np.asarray(X_test)
            if X_test.shape[1] != X.shape[1]:
                raise ValueError("error: dimension of X and X_test must be the same")

        if AI:
            Y, sigma_square_inv, pca = standardize(X, False)
        else:
            Y = X

        U = self._directions(U, Y.shape[1])

        if X_test is None:
            return self._training_depth(Y, U)

        if AI:
            # X_test must go through exactly the transform applied to X:
            # the optional PCA reduction, then the whitening matrix.
            if pca is not None:
                X_test = pca.transform(X_test)
            Y_test = X_test @ sigma_square_inv
        else:
            Y_test = X_test

        return self._test_depth(Y, Y_test, U)

    def compute_depths(self, X, X_test, depth_choice):
        """Dispatch to the depth selected by ``depth_choice``.

        Only ``"int_w_halfs_pace"`` (the IRW depth, non affine-invariant,
        computed by ``AI_IRW``) is supported.

        Raises
        ------
        NotImplementedError
            For any other ``depth_choice``.
        """
        logger.info("Choosen depths %s", depth_choice)
        if depth_choice == "int_w_halfs_pace":
            depth_method = self.AI_IRW
        else:
            raise NotImplementedError(f"Unsupported depth choice: {depth_choice!r}")
        return depth_method(X=X, X_test=X_test)
