import torch
import torch.nn as nn
import torch.nn.functional as F

class contrastive_loss(nn.Module):
    """Contrastive loss with the positive logit excluded from the normaliser.

    For each row of ``x`` the loss is ``logsumexp(x[i, 1:]) - x[i, 0]``, averaged
    over rows. Column 0 must hold the positive logit; ``labels`` is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, labels):
        # Column 0 is assumed to be the positive logit (true for SimCLR-style logits).
        positive_logit = x[:, 0]
        negative_lse = torch.logsumexp(x[:, 1:], dim=1)
        per_row = negative_lse - positive_logit
        return per_row.mean()


class contrastive_loss_2(nn.Module):
    """Multi-positive InfoNCE loss.

    The first ``n_view - 1`` columns of each logit row are positives. The loss per row is
    ``logsumexp(all logits) - logsumexp(positive logits)``, averaged over rows.
    ``labels`` is ignored.
    """

    def __init__(self, n_view=2):
        super().__init__()
        self.n_view = n_view

    def forward(self, x, labels):
        n_pos = self.n_view - 1
        log_pos = torch.logsumexp(x[:, :n_pos], 1)
        log_all = torch.logsumexp(x, 1)
        return (log_all - log_pos).mean()


class SimCLR(nn.Module):
    """SimCLR-style InfoNCE loss on a stacked batch of views.

    Rows of ``X`` are ``n_views`` consecutive blocks of ``len(X) // n_views`` rows.
    Row ``i`` of every block is a view of the same sample. Logits are raw dot products
    divided by ``temperature``.

    With the default ``CrossEntropyLoss`` criterion only ``n_views=2`` is correct. The
    target is always column 0, so for ``n_views > 2`` the remaining positives are
    effectively treated as negatives. ``contrastive=True`` uses ``contrastive_loss_2``,
    which takes the first ``n_views - 1`` columns as positives.
    """

    def __init__(self, temperature=0.5, n_views=2, contrastive=False):
        super().__init__()
        self.temp = temperature
        self.n_views = n_views
        self.criterion = (
            contrastive_loss_2(n_view=self.n_views)
            if contrastive
            else torch.nn.CrossEntropyLoss()
        )

    def info_nce_loss(self, X):
        """Return ``(logits, labels)`` with positives first in each logit row.

        ``logits`` has shape ``(N, N - 1)``, where ``N = X.shape[0]``. Self-similarity is
        dropped, then positive columns are followed by negative columns, and the result
        is divided by the temperature. ``labels`` is an all-zero long tensor of length ``N``.
        """
        n_total = X.shape[0]
        group = int(n_total / self.n_views)
        dev = X.device

        # Sample id of every row; rows sharing an id are views of the same sample.
        ids = torch.arange(group).repeat(self.n_views)
        same = (ids.unsqueeze(0) == ids.unsqueeze(1)).float().to(dev)

        sim = X @ X.T

        # Remove the main diagonal (self-pairs) from both matrices.
        off_diag = ~torch.eye(same.shape[0], dtype=torch.bool, device=dev)
        same = same[off_diag].view(same.shape[0], -1)
        sim = sim[off_diag].view(sim.shape[0], -1)

        pos_mask = same.bool()
        pos = sim[pos_mask].view(same.shape[0], -1)
        neg = sim[~pos_mask].view(sim.shape[0], -1)

        logits = torch.cat([pos, neg], dim=1) / self.temp
        targets = torch.zeros(logits.shape[0], dtype=torch.long, device=dev)
        return logits, targets

    def forward(self, X):
        logits, targets = self.info_nce_loss(X)
        return self.criterion(logits, targets)

class Z_loss(nn.Module):
    """Negative mean cosine similarity between the two halves of a batch.

    ``z`` is split row-wise into two chunks. Row ``i`` of the first chunk is compared with
    row ``i`` of the second. Returns ``(-sim, sim_detached)``, where the second value is a
    detached copy for logging.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z):
        halves = z.chunk(2, dim=0)
        mean_sim = F.cosine_similarity(halves[0], halves[1], dim=1).mean()
        return -mean_sim, mean_sim.clone().detach()

class TotalCodingRate(nn.Module):
    """Negative total coding rate of a batch of features.

    For ``X`` of shape ``(m, p)`` (rows are samples), returns
    ``-0.5 * logdet(I_p + p / (m * eps) * X^T X)``.
    """

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def compute_discrimn_loss(self, W):
        """Coding rate ``0.5 * logdet(I + p / (m * eps) * W W^T)`` for ``W`` of shape ``(p, m)``."""
        dim, n_samples = W.shape
        eye = torch.eye(dim, device=W.device)
        alpha = dim / (n_samples * self.eps)
        gram = W.matmul(W.T)
        return torch.logdet(eye + alpha * gram) / 2.

    def forward(self, X):
        rate = self.compute_discrimn_loss(X.T)
        return -rate
    

class TotalCodingRate_MI(nn.Module):
    """Negative log-det of the ridge-regularised sample second-moment matrix.

    For ``X`` of shape ``(m, p)`` (rows are samples), returns
    ``-0.5 * logdet((eps / p) * I_p + X^T X / m)``.
    """

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def compute_discrimn_loss(self, W):
        """Return ``0.5 * logdet((eps / p) * I + W W^T / m)`` for ``W`` of shape ``(p, m)``."""
        dim, n_samples = W.shape
        eye = torch.eye(dim, device=W.device)
        ridge = self.eps / dim
        second_moment = W.matmul(W.T) / n_samples
        return torch.logdet(ridge * eye + second_moment) / 2.

    def forward(self, X):
        return -self.compute_discrimn_loss(X.T)



class MI_LogDet_Loss(nn.Module):
    """Gaussian log-det entropy terms over patch-grouped features.

    ``X`` is split row-wise into ``num_patches`` chunks.

    - ``marginal_entropy`` averages the full-covariance Gaussian entropy of each chunk
      (negative samples).
    - ``conditional_entropy`` stacks the chunks and averages the diagonal-covariance
      Gaussian entropy of each row position taken across chunks (positive samples).
      It requires equally sized chunks.

    All intermediate tensors live on the input's device, so the loss runs on CPU or any GPU.
    """

    def __init__(self, beta=1e-2, num_patches=20):
        super(MI_LogDet_Loss, self).__init__()
        self.beta = beta
        self.num_patches = num_patches

    @staticmethod
    def covariance2entropy_singleGaussian(x, beta=0.01):
        """Entropy of a Gaussian fitted to the rows of ``x``, using covariance ``cov(x) + beta * I``.

        Returns ``d/2 * (log(2*pi) + 1) + 0.5 * logdet(beta * I + cov)``, with ``d = x.shape[1]``.
        """
        d = x.shape[1]
        cov = torch.cov(x.T)
        # Build the identity on x's device/dtype (previously hard-coded to .cuda()).
        eye = torch.eye(d, device=x.device, dtype=cov.dtype)
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(beta * eye + cov)

    @staticmethod
    def covariance2entropy_singleGaussian_diag(x, beta=0.0):
        """Same as ``covariance2entropy_singleGaussian`` but keeps only the covariance diagonal."""
        d = x.shape[1]
        cov = torch.diag(torch.diag(torch.cov(x.T)))
        eye = torch.eye(d, device=x.device, dtype=cov.dtype)
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(beta * eye + cov)

    def compute_discrimn_loss(self, W):
        """Coding rate ``0.5 * logdet(I + p / (m * beta) * W W^T)`` for ``W`` of shape ``(p, m)``."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.beta)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2

    def marginal_entropy(self, X):
        """Mean Gaussian entropy of the row chunks of ``X`` (negative samples)."""
        # Iterate the chunks actually produced: chunk() can return fewer than num_patches.
        chunks = X.chunk(self.num_patches, dim=0)
        total = sum(self.covariance2entropy_singleGaussian(zi, self.beta) for zi in chunks)
        return total / len(chunks)

    def conditional_entropy(self, X):
        """Mean diagonal-Gaussian entropy of each row position across chunks (positive samples)."""
        # (num_chunks, rows_per_chunk, d); requires equally sized chunks.
        stacked = torch.stack(X.chunk(self.num_patches, dim=0), dim=0)
        n_pos = stacked.shape[1]
        total = sum(
            self.covariance2entropy_singleGaussian_diag(stacked[:, i, :], self.beta)
            for i in range(n_pos)
        )
        return total / n_pos

    def pair(self, X):
        """Return all ordered pairs ``concat(X[j], X[i])`` for ``i != j`` as rows.

        The result has shape ``(n * (n - 1), 2 * d)``.
        """
        bs, n_dim = X.shape
        device = X.device

        mask = torch.eye(bs, dtype=torch.bool).to(device)

        X_matrix = torch.cat([torch.stack([X] * X.shape[0], dim=0), torch.stack([X] * X.shape[0], dim=1)], dim=2)

        X_matrix = X_matrix[~mask].view(X_matrix.shape[0], -1, X_matrix.shape[2])

        pair = X_matrix.view(-1, X_matrix.shape[-1])
        return pair

    def forward(self, X):
        """Return ``(marginal_entropy(X), conditional_entropy(X))``."""
        marginal_loss = self.marginal_entropy(X)
        conditional_loss = self.conditional_entropy(X)
        return marginal_loss, conditional_loss
    
    
class MaximalCodingRateReduction(torch.nn.Module):
    """Maximal Coding Rate Reduction (MCR^2) objective.

    ``total_loss = -R(Z) + gamma * R_c(Z, Pi)``:

    - ``R`` is the coding rate of all features.
    - ``R_c`` is the membership-weighted sum of per-class coding rates.
    """

    def __init__(self, eps=0.01, gamma=1):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps
        self.gamma = gamma

    def compute_discrimn_loss(self, W):
        """Coding rate ``0.5 * logdet(I + p / (m * eps) * W W^T)`` for ``W`` of shape ``(p, m)``."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss(self, W, Pi):
        """Membership-weighted per-class coding rate.

        Args:
            W: features of shape ``(p, m)``.
            Pi: memberships of shape ``(k, 1, m)``; ``Pi[j, 0, i]`` is sample ``i``'s weight in class ``j``.
        """
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device).expand((k, p, p))
        # 1e-8 keeps the scale finite for classes with no members.
        trPi = Pi.sum(2) + 1e-8
        scale = (p / (trPi * self.eps)).view(k, 1, 1)

        W = W.view((1, p, m))
        # W.mul(Pi) scales column i of W by Pi[j, 0, i], i.e. W @ diag(Pi_j).
        log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
        compress_loss = (trPi.view(k) * log_det / (2 * m)).sum()
        return compress_loss

    def forward(self, X, Y, num_classes=None):
        """Return ``(total_loss, [discrimn_loss, compress_loss])``.

        Args:
            X: features of shape ``(m, p)`` (rows are samples).
            Y: either integer labels of shape ``(m,)`` or a membership-probability
                matrix of shape ``(m, k)``.
            num_classes: number of classes; inferred from ``Y`` when ``None``.
        """
        if Y.dim() == 1:
            # Integer labels -> one-hot memberships, built with one vectorised scatter
            # instead of a per-sample Python loop.
            if num_classes is None:
                num_classes = int(Y.max().item()) + 1
            n = Y.shape[0]
            Pi = torch.zeros((num_classes, 1, n), device=Y.device)
            Pi[Y.long(), 0, torch.arange(n, device=Y.device)] = 1
        else:
            # Y is a (samples, classes) membership-probability matrix.
            if num_classes is None:
                num_classes = Y.shape[1]
            Pi = Y.T.reshape((num_classes, 1, -1))

        W = X.T
        discrimn_loss = self.compute_discrimn_loss(W)
        compress_loss = self.compute_compress_loss(W, Pi)

        total_loss = -discrimn_loss + self.gamma * compress_loss
        return total_loss, [discrimn_loss.item(), compress_loss.item()]
    
    
import math
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import numpy as np


def MNorm_log_pdf(x, mu, cov):
    """Log-density of ``x`` under ``N(mu, cov)``, returned as a column of shape ``(N, 1)``."""
    dist = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov)
    log_p = dist.log_prob(x)
    return log_p.view(-1, 1)


def soft_assignment(x, w, mu, cov):
    """Posterior responsibilities of each mixture component for each row of ``x``.

    Column ``k`` of the log-score buffer is ``log w[k] + log N(x; mu[k], cov[k])``.
    A row-wise softmax turns these scores into soft assignments of shape
    ``(x.shape[0], w.shape[0])``. The buffer is allocated with default device/dtype.
    """
    n_comp = w.shape[0]
    log_resp = torch.zeros((x.shape[0], n_comp))
    for k in range(n_comp):
        log_resp[:, k] = torch.log(w[k]) + MNorm_log_pdf(x, mu[k, :], cov[k, :, :]).squeeze()
    return torch.softmax(log_resp, dim=1)


class MI_LogDet_RobustEstimator(torch.nn.Module):
    """Mutual-information estimate between concatenated feature pairs via log-det entropies.

    ``forward(X)`` builds (positive, negative) feature pairs with ``pos_neg`` and returns
    ``H(x1) + H(x2) - H(x1, x2)``. The marginals are estimated from negative pairs and
    the joint from positive pairs.

    Args:
        beta: ridge added to every covariance before taking ``logdet``.
        Kmax: number of mixture components; ``Kmax <= 1`` uses a single Gaussian.
        method: ``'kmeans'`` (case-insensitive) selects the hard K-means estimator;
            any other value selects the GMM soft-assignment estimator.
        n_views: number of views per sample. Rows of ``X`` are ``n_views`` consecutive
            blocks of ``len(X) // n_views`` rows, and row ``i`` of each block belongs
            to sample ``i``.
    """

    def __init__(self, beta, Kmax, method='kmeans', n_views=20):
        super(MI_LogDet_RobustEstimator, self).__init__()

        self.beta = beta
        self.Kmax = Kmax
        self.method = method
        self.n_views = n_views

    def _use_kmeans(self):
        """Return True when the hard K-means entropy estimator is selected.

        The comparison is case-insensitive. The constructor default is ``'kmeans'``,
        but dispatch used to test for ``'Kmeans'``, so the default silently fell
        through to the GMM soft-assignment path.
        """
        return isinstance(self.method, str) and self.method.lower() == 'kmeans'

    @staticmethod
    def covariance2entropy_singleGaussian(x):
        """Closed-form entropy of a single Gaussian fitted to the rows of ``x`` (no ridge)."""
        d = x.shape[1]
        Sigma = torch_cov(x)
        Ent = d / 2 * (torch.log(torch.tensor(2 * math.pi)) + 1) + 0.5 * torch.logdet(Sigma)
        return Ent

    @staticmethod
    def H_lower(w, mu, cov):
        """Pairwise-overlap lower bound on mixture entropy.

        Computes ``-sum_i w_i * log(sum_j w_j * N(mu_i; mu_j, cov_i + cov_j))``.
        """
        H_l = torch.tensor([0.0])
        K = w.shape[0]
        for i in range(K):
            pdf_mix = torch.tensor([0.0])
            for j in range(K):
                pdf_mix += w[j] * torch.exp(MNorm_log_pdf(mu[i, :], mu[j, :], cov[i, :, :] + cov[j, :, :]).squeeze())
            H_l += w[i] * torch.log(pdf_mix)
        return -H_l

    @staticmethod
    def H_upper(w, mu, cov, beta):
        """Upper bound on mixture entropy.

        Computes ``sum_j w_j * (-log w_j + H(N(., beta*I + cov_j)))``.
        ``mu`` is not used. The result is a CPU tensor of shape ``(1,)``.
        """
        d = cov.shape[-1]
        K = w.shape[0]
        H_u = torch.tensor([0.0])
        for j in range(K):
            H_u += w[j] * (
                    -torch.log(w[j]) + d / 2 * (torch.log(torch.tensor(2 * math.pi)) + 1) + 0.5 * torch.logdet(
                beta * torch.eye(d) + cov[j, :, :]))
        return H_u

    def covariance2entropy_estimator(self, x):
        """Entropy of the rows of ``x``.

        For ``Kmax > 1`` this is the K-means mixture upper bound, computed on CPU tensors.
        Otherwise it is a single ridge-regularised Gaussian entropy on ``x``'s device.
        """
        d = x.shape[-1]
        Ns = x.shape[0]
        K = self.Kmax
        if K > 1:
            x_np = x.cpu().detach().numpy()
            cluster = KMeans(n_clusters=K, random_state=0).fit(x_np)
            y_pred = cluster.predict(x_np)
            # Mixture parameters stay on the CPU, matching H_upper's CPU tensors.
            w = torch.zeros(K)
            mu = torch.zeros((K, d))
            cov = torch.zeros((K, d, d))
            for j in range(K):
                members = np.where(y_pred == j)[0]
                w[j] = float(np.sum(y_pred == j) / Ns)
                mu[j, :] = torch.mean(x[members, :], dim=0)
                cov[j, :, :] = torch_cov(x[members, :])
            return self.H_upper(w, mu, cov, self.beta)

        cov = torch_cov(x)
        # Identity built on x's device (previously hard-coded to .cuda()).
        eye = torch.eye(d, device=x.device, dtype=cov.dtype)
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(self.beta * eye + cov)

    def covariance2entropy_estimator_SoftAssign(self, x):
        """Entropy of the rows of ``x`` from a GMM with soft (posterior-weighted) covariances.

        The GMM is fitted on a CPU copy of ``x``. Its parameters are moved back to
        ``x.device``, so the estimate is computed on the input's device.
        """
        K = self.Kmax
        d = x.shape[1]
        Ns = x.shape[0]
        device = x.device
        if K > 1:
            GMM = GaussianMixture(n_components=K, init_params='kmeans')
            GMM.fit(X=x.detach().cpu().numpy())
            w = torch.from_numpy(GMM.weights_).float().to(device)
            mu = torch.from_numpy(GMM.means_).float().to(device)
            Sigma = torch.from_numpy(GMM.covariances_).float().to(device)
            Pi = soft_assignment(x, w, mu, Sigma).to(device)
            eye = torch.eye(d, device=device)
            Ent = 0.0
            for j in range(K):
                diff = x - mu[j, :]
                resp = Pi[:, j]
                # diff^T diag(resp) diff / sum(resp), without materialising an (Ns, Ns) diagonal.
                Sigma_est = (diff * resp.unsqueeze(1)).T.matmul(diff) / torch.sum(resp)
                Ent += w[j] * (-torch.log(w[j]) + d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(
                    self.beta * eye + Sigma_est))
            return Ent

        Sigma = torch_cov(x)
        eye = torch.eye(d, device=device, dtype=Sigma.dtype)
        assert Ns > 1 or True  # single-sample input yields a non-finite covariance, as before
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(self.beta * eye + Sigma)

    def _entropy_fn(self):
        """Return the entropy estimator selected by ``self.method``."""
        if self._use_kmeans():
            return self.covariance2entropy_estimator
        return self.covariance2entropy_estimator_SoftAssign

    def MILE_estimate(self, x, y):
        """Estimate ``MI(x, y) = H(x) + H(y) - H(x, y)``.

        Noise-based bias correction is disabled, so its terms are not computed.
        """
        estimate = self._entropy_fn()
        Hx = estimate(x)
        Hy = estimate(y)
        Hxy = estimate(torch.cat((x, y), dim=1))
        return Hx + Hy - Hxy

    def MILE_estimate_pairs(self, x_pos, x_neg):
        """Estimate MI between the two halves of concatenated pair features.

        Marginals come from the first/second half of ``x_neg``; the joint comes from ``x_pos``.
        """
        dim = x_pos.shape[1] // 2
        x1_margin_sample = x_neg[:, 0:dim]
        x2_margin_sample = x_neg[:, dim:]
        x1x2_joint_sample = x_pos

        estimate = self._entropy_fn()
        Hx = estimate(x1_margin_sample)
        Hy = estimate(x2_margin_sample)
        Hxy = estimate(x1x2_joint_sample)
        # MI(x, y) = H(x) + H(y) - H(x, y); bias correction is disabled.
        return Hx + Hy - Hxy

    def pos_neg(self, X):
        """Split all ordered row pairs of ``X`` into positives (same sample) and negatives.

        Each pair row is ``concat(X[j], X[i])`` for ``i != j``. Positives have shape
        ``(N, n_views - 1, 2d)`` and negatives have shape ``(N, N - n_views, 2d)``.

        Raises:
            ValueError: if ``len(X)`` is not divisible by ``n_views``.
        """
        n_rows = X.shape[0]
        if n_rows % self.n_views != 0:
            raise ValueError(
                f"number of rows ({n_rows}) must be divisible by n_views ({self.n_views})")
        bs = n_rows // self.n_views
        device = X.device

        sample_ids = torch.cat([torch.arange(bs) for _ in range(self.n_views)], dim=0)
        labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).to(device)

        # Drop self-pairs (the main diagonal).
        mask = torch.eye(n_rows, dtype=torch.bool, device=device)
        labels = labels[~mask].view(n_rows, -1)

        # X_matrix[i, j] = concat(X[j], X[i]); expand avoids building two full copies first.
        X_matrix = torch.cat(
            [X.unsqueeze(0).expand(n_rows, -1, -1), X.unsqueeze(1).expand(-1, n_rows, -1)], dim=2)
        X_matrix = X_matrix[~mask].view(n_rows, -1, X_matrix.shape[2])

        positives = X_matrix[labels].view(n_rows, -1, X_matrix.shape[2])
        negatives = X_matrix[~labels].view(n_rows, -1, X_matrix.shape[2])
        return positives, negatives

    def forward(self, X: torch.tensor):
        pos, neg = self.pos_neg(X)
        loss = self.MILE_estimate_pairs(pos.view(-1, pos.shape[-1]), neg.view(-1, neg.shape[-1]))
        return loss

def torch_cov(input_vec:torch.tensor):
    """Unbiased sample covariance of the columns of a 2-D tensor.

    Rows are observations and columns are variables, which is the transpose of
    ``torch.cov``'s convention. The result has shape ``(d, d)`` and is normalised by
    ``n - 1``.
    """
    n_obs = input_vec.shape[0]
    centered = input_vec - input_vec.mean(dim=0)
    return centered.T.matmul(centered) / (n_obs - 1)


class MI_LogDet_RobustEstimator_improved(torch.nn.Module):
    """Bias-corrected log-det mutual-information estimator over feature pairs.

    Like ``MI_LogDet_RobustEstimator``, but each entropy estimate is corrected by the
    gap between the closed-form single-Gaussian entropy and the mixture estimator,
    measured on standard normal noise of the same shape.

    Args:
        beta: ridge added to every covariance before taking ``logdet``.
        Kmax: number of mixture components; ``Kmax <= 1`` uses a single Gaussian.
        method: ``'Kmeans'`` or ``'GMM'`` (case-insensitive).
        n_views: number of views per sample. Rows of ``X`` are ``n_views`` consecutive
            blocks of ``len(X) // n_views`` rows.
    """

    def __init__(self, beta, Kmax, method='GMM', n_views=20):
        super(MI_LogDet_RobustEstimator_improved, self).__init__()

        self.beta = beta
        self.Kmax = Kmax
        self.method = method
        self.n_views = n_views

    @staticmethod
    def covariance2entropy_singleGaussian(x):
        """Closed-form entropy of a single Gaussian fitted to the rows of ``x`` (no ridge)."""
        d = x.shape[1]
        Sigma = torch_cov(x)
        Ent = d / 2 * (torch.log(torch.tensor(2 * math.pi)) + 1) + 0.5 * torch.logdet(Sigma)
        return Ent

    @staticmethod
    def H_upper(w, mu, cov, beta):
        """Upper bound on mixture entropy.

        Computes ``sum_j w_j * (-log w_j + H(N(., beta*I + cov_j)))``.
        ``mu`` is not used. The result is a CPU tensor of shape ``(1,)``.
        """
        d = cov.shape[-1]
        K = w.shape[0]
        H_u = torch.tensor([0.0])
        for j in range(K):
            H_u += w[j] * (
                    -torch.log(w[j]) + d / 2 * (torch.log(torch.tensor(2 * math.pi)) + 1) + 0.5 * torch.logdet(
                beta * torch.eye(d) + cov[j, :, :]))
        return H_u

    @staticmethod
    def _gmm_covariances(covariances, cov_type, K, d):
        """Expand sklearn's ``covariances_`` into a ``(K, d, d)`` CPU tensor.

        The layout of ``covariances_`` depends on ``cov_type``:
        ``'full'`` is ``(K, d, d)``, ``'diag'`` is ``(K, d)``, ``'spherical'`` is ``(K,)``,
        and ``'tied'`` is ``(d, d)``, shared by all components.
        """
        c = torch.from_numpy(np.asarray(covariances)).float()
        if cov_type == 'full':
            return c
        if cov_type == 'diag':
            return torch.diag_embed(c)
        if cov_type == 'spherical':
            return c.view(K, 1, 1) * torch.eye(d)
        if cov_type == 'tied':
            return c.unsqueeze(0).expand(K, d, d).clone()
        raise ValueError(f"unsupported covariance type {cov_type!r}")

    def covariance2entropy_estimator(self, x):
        """Entropy of the rows of ``x``.

        For ``Kmax > 1`` this is the K-means mixture upper bound, computed on CPU tensors.
        Otherwise it is a single ridge-regularised Gaussian entropy on ``x``'s device.
        """
        d = x.shape[-1]
        Ns = x.shape[0]
        K = self.Kmax
        if K > 1:
            x_np = x.cpu().detach().numpy()
            cluster = KMeans(n_clusters=K, random_state=0).fit(x_np)
            y_pred = cluster.predict(x_np)
            w = torch.zeros(K)
            mu = torch.zeros((K, d))
            cov = torch.zeros((K, d, d))
            for j in range(K):
                members = np.where(y_pred == j)[0]
                w[j] = float(np.sum(y_pred == j) / Ns)
                mu[j, :] = torch.mean(x[members, :], dim=0)
                cov[j, :, :] = torch_cov(x[members, :])
            return self.H_upper(w, mu, cov, self.beta)

        cov = torch_cov(x)
        # Identity built on x's device (previously hard-coded to .cuda()).
        eye = torch.eye(d, device=x.device, dtype=cov.dtype)
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(self.beta * eye + cov)

    def covariance2entropy_estimator_GMM(self, x, cov_type='full'):
        """Entropy of the rows of ``x``.

        For ``Kmax > 1`` this is the GMM mixture upper bound. The GMM is fitted on a CPU
        copy of ``x`` and the result is a CPU tensor. Otherwise it is a single
        ridge-regularised Gaussian entropy on ``x``'s device.
        """
        K = self.Kmax
        d = x.shape[1]
        if K > 1:
            GMM = GaussianMixture(n_components=K, init_params='kmeans', covariance_type=cov_type)
            GMM.fit(X=x.detach().cpu().numpy())
            w = torch.from_numpy(GMM.weights_).float()
            mu = torch.from_numpy(GMM.means_).float()
            cov = self._gmm_covariances(GMM.covariances_, cov_type, K, d)
            return self.H_upper(w, mu, cov, self.beta)

        Sigma = torch_cov(x)
        eye = torch.eye(d, device=x.device, dtype=Sigma.dtype)
        return d / 2 * (math.log(2 * math.pi) + 1) + 0.5 * torch.logdet(self.beta * eye + Sigma)

    def _entropy_fn(self):
        """Return the entropy estimator selected by ``self.method``.

        Raises:
            ValueError: for an unknown method. Previously this surfaced later as an
                UnboundLocalError.
        """
        method = self.method.lower() if isinstance(self.method, str) else self.method
        if method == 'kmeans':
            return self.covariance2entropy_estimator
        if method == 'gmm':
            return self.covariance2entropy_estimator_GMM
        raise ValueError(f"unknown method {self.method!r}; expected 'Kmeans' or 'GMM'")

    def _noise_bias(self, estimate, ref):
        """Single-Gaussian entropy minus ``estimate``, both measured on standard normal noise shaped like ``ref``."""
        noise = torch.randn_like(ref)
        mixture = estimate(noise).to(ref.device)
        return self.covariance2entropy_singleGaussian(noise) - mixture

    def MILE_estimate(self, x, y):
        """Bias-corrected ``MI(x, y) = H(x) + H(y) - H(x, y)``."""
        estimate = self._entropy_fn()
        device = x.device
        xy = torch.cat((x, y), dim=1)

        bias_est_1x = self._noise_bias(estimate, x)
        bias_est_1y = self._noise_bias(estimate, y)
        bias_est_2xy = self._noise_bias(estimate, xy)

        Hx = estimate(x).to(device)
        Hy = estimate(y).to(device)
        Hxy = estimate(xy).to(device)
        return Hx + Hy - Hxy + bias_est_1x + bias_est_1y - bias_est_2xy

    def MILE_estimate_pairs(self, x_pos, x_neg):
        """Bias-corrected MI between the two halves of concatenated pair features.

        Marginals come from the first/second half of ``x_neg``; the joint comes from ``x_pos``.
        The result is placed on ``x_pos``'s device.
        """
        estimate = self._entropy_fn()
        device = x_pos.device
        dim = x_pos.shape[1] // 2
        x1_margin_sample = x_neg[:, 0:dim]
        x2_margin_sample = x_neg[:, dim:]
        x1x2_joint_sample = x_pos

        bias_est_1a = self._noise_bias(estimate, x1_margin_sample)
        bias_est_1b = self._noise_bias(estimate, x2_margin_sample)
        bias_est_2 = self._noise_bias(estimate, x1x2_joint_sample)

        Hx = estimate(x1_margin_sample).to(device)
        Hy = estimate(x2_margin_sample).to(device)
        Hxy = estimate(x1x2_joint_sample).to(device)
        # MI(x, y) = H(x) + H(y) - H(x, y), plus the noise-measured bias terms.
        return Hx + Hy - Hxy + bias_est_1a + bias_est_1b - bias_est_2

    def pos_neg(self, X):
        """Split all ordered row pairs of ``X`` into positives (same sample) and negatives.

        Each pair row is ``concat(X[j], X[i])`` for ``i != j``. Positives have shape
        ``(N, n_views - 1, 2d)`` and negatives have shape ``(N, N - n_views, 2d)``.

        Raises:
            ValueError: if ``len(X)`` is not divisible by ``n_views``.
        """
        n_rows = X.shape[0]
        if n_rows % self.n_views != 0:
            raise ValueError(
                f"number of rows ({n_rows}) must be divisible by n_views ({self.n_views})")
        bs = n_rows // self.n_views
        device = X.device

        sample_ids = torch.cat([torch.arange(bs) for _ in range(self.n_views)], dim=0)
        labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).to(device)

        # Drop self-pairs (the main diagonal).
        mask = torch.eye(n_rows, dtype=torch.bool, device=device)
        labels = labels[~mask].view(n_rows, -1)

        # X_matrix[i, j] = concat(X[j], X[i]); expand avoids building two full copies first.
        X_matrix = torch.cat(
            [X.unsqueeze(0).expand(n_rows, -1, -1), X.unsqueeze(1).expand(-1, n_rows, -1)], dim=2)
        X_matrix = X_matrix[~mask].view(n_rows, -1, X_matrix.shape[2])

        positives = X_matrix[labels].view(n_rows, -1, X_matrix.shape[2])
        negatives = X_matrix[~labels].view(n_rows, -1, X_matrix.shape[2])
        return positives, negatives

    def forward(self, X: torch.tensor):
        pos, neg = self.pos_neg(X)
        loss = self.MILE_estimate_pairs(pos.view(-1, pos.shape[-1]), neg.view(-1, neg.shape[-1]))
        return loss
