import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pdb
from models.GMM_module import *
from functions import *
from scipy.optimize import linear_sum_assignment

class Critic(nn.Module):
    """Container for the layers of a small critic head.

    Only the layers are declared here; this class defines no ``forward``,
    so how the layers are chained is left to the code that uses it.
    """

    def __init__(self, latent_size=512, mid_size=256):
        """Create the critic layers.

        Args:
            latent_size: Input feature count of the first linear layer.
            mid_size: Output width of the first linear layer, input width of
                the second one, and feature count of the batch-norm layer.
        """
        super().__init__()
        # Creation order (fc1, fc2, bn1) is kept so that parameter
        # initialisation and ``state_dict`` ordering stay the same.
        self.fc1 = nn.Linear(in_features=latent_size, out_features=mid_size)
        self.fc2 = nn.Linear(in_features=mid_size, out_features=1)
        # NOTE(review): BatchNorm2d expects 4-D input; if this layer is meant
        # to follow fc1 on 2-D activations, BatchNorm1d may be intended.
        self.bn1 = nn.BatchNorm2d(num_features=mid_size)


class GMM_Model(nn.Module):
    def __init__(self, input_shape=None, unsupervised_em_iters=5, semisupervised_em_iters=5, fix_pi=False,
                 hidden_size=64, component_size=20, latent_size=64, train_mc_sample_size=10, test_mc_sample_size=10):
        """Store the hyper-parameters and build the sub-networks and buffers.

        Args:
            input_shape: Stored unchanged on ``self.input_shape``. When
                omitted, a fresh ``[1, 32, 32]`` list is created per instance.
            unsupervised_em_iters: Default EM iteration count used by
                ``get_unsupervised_prior``.
            semisupervised_em_iters: Stored on the instance; no method of this
                class reads it.
            fix_pi: Consulted by ``get_unsupervised_params`` when deciding
                whether the mixing weights are re-estimated.
            hidden_size: Width passed to the decoder; ``2 * 2 * hidden_size``
                is the feature width of the encoder and projection layers.
            component_size: Number of mixture components.
            latent_size: Input width of ``self.proj`` and the chunk size used
                by ``get_posterior`` when splitting the encoder output.
            train_mc_sample_size: Stored on the instance.
            test_mc_sample_size: Stored on the instance.
        """
        super().__init__()
        # A list literal as a default argument would be one object shared by
        # every instance; build a new list for each model instead.
        self.input_shape = [1, 32, 32] if input_shape is None else input_shape
        self.unsupervised_em_iters = unsupervised_em_iters
        self.semisupervised_em_iters = semisupervised_em_iters
        self.fix_pi = fix_pi
        self.hidden_size = hidden_size
        self.last_hidden_size = 2*2*hidden_size
        self.component_size = component_size
        self.latent_size = latent_size
        self.train_mc_sample_size = train_mc_sample_size
        self.test_mc_sample_size = test_mc_sample_size

        # NOTE(review): the final layer emits 2 * hidden_size features while
        # get_posterior splits by latent_size, so the split only yields a
        # (mean, logvar) pair when hidden_size == latent_size - confirm.
        self.q_z_given_x_net = nn.Sequential(
            SAB(dim_in=self.last_hidden_size, dim_out=self.last_hidden_size, num_heads=4, ln=False),
            SAB(dim_in=self.last_hidden_size, dim_out=self.last_hidden_size, num_heads=4, ln=False),
            nn.Linear(self.last_hidden_size, 2 * self.hidden_size)
        )

        self.proj = nn.Sequential(
            nn.Linear(latent_size, self.last_hidden_size),
            nn.ELU(inplace=True),
            nn.Linear(self.last_hidden_size, self.last_hidden_size),
            nn.ELU(inplace=True),
            nn.Linear(self.last_hidden_size, self.last_hidden_size),
            nn.ELU(inplace=True),
        )

        self.decoder = CIFAR10Decoder(hidden_size=hidden_size)
        # NOTE(review): this stores the MSELoss *class*, not an instance, so a
        # caller must instantiate it before computing a loss - confirm intent.
        self.rec_criterion = nn.MSELoss
        # Per-dimension Gaussian normalisation term, -0.5 * log(2 * pi).
        self.register_buffer('log_norm_constant', torch.tensor(-0.5 * np.log(2 * np.pi)))
        # Uniform mixing weights, used as the starting point of the EM fits.
        self.register_buffer('uniform_pi', torch.ones(self.component_size)/self.component_size)

    def reparametrize(self, mean, logvar, S=1):
        """Draw ``S`` Gaussian samples per row with the reparametrisation trick.

        A sample axis is inserted at position 1 of ``mean`` and ``logvar`` and
        tiled ``S`` times, so a 2-D ``(B, D)`` input yields ``(B, S, D)``.

        Args:
            mean: Mean of the Gaussian.
            logvar: Log-variance of the Gaussian, same shape as ``mean``.
            S: Number of samples per row.

        Returns:
            ``mean + eps * exp(0.5 * logvar)`` with ``eps`` standard normal.
        """
        tiled_mean = torch.unsqueeze(mean, 1).repeat(1, S, 1)
        tiled_logvar = torch.unsqueeze(logvar, 1).repeat(1, S, 1)
        tiled_std = torch.exp(tiled_logvar * 0.5)
        noise = torch.randn_like(tiled_mean)
        return noise * tiled_std + tiled_mean

    def Easy_reparametrize(self, mean, logvar, S=1):
        """Sample like ``reparametrize`` but with rescaled noise.

        The standard-normal noise is divided by twice its largest element
        before being scaled by the standard deviation and shifted by the mean.

        Args:
            mean: Mean of the Gaussian; a sample axis is inserted at position 1.
            logvar: Log-variance, same shape as ``mean``.
            S: Number of samples per row.

        Returns:
            Tensor of samples; ``(B, S, D)`` for a 2-D ``(B, D)`` input.
        """
        base_mean = torch.unsqueeze(mean, 1).repeat(1, S, 1)
        base_std = torch.exp(torch.unsqueeze(logvar, 1).repeat(1, S, 1) * 0.5)
        noise = torch.randn_like(base_mean)
        # NOTE(review): if the largest draw is <= 0 this rescaling flips signs
        # or divides by zero; looks accepted as negligible - confirm.
        noise = noise / (noise.max() * 2)
        return noise * base_std + base_mean

    def Same_reparametrize(self, mean, logvar, S=1):
        """Return ``mean`` tiled ``S`` times along a new axis 1, without noise.

        ``logvar`` is accepted so the signature matches the other
        reparametrisation helpers, but it is not used.
        """
        return torch.unsqueeze(mean, 1).repeat(1, S, 1)

    def gaussian_log_prob(self, x, mean, logvar=None, pi=None, **kwargs):
        """Log-density of ``x`` under a diagonal Gaussian, reduced over the last axis.

        Args:
            x: Points to evaluate.
            mean: Gaussian means, broadcastable against ``x``.
            logvar: Log-variances; zeros (unit variance) when ``None``.
            pi: Accepted but not used by this method.
            **kwargs: ``meanC`` - when truthy the per-dimension terms are
                averaged over the last axis instead of summed.

        Returns:
            Tensor with the last axis reduced. No finiteness check is done.
        """
        log_variance = torch.zeros_like(mean) if logvar is None else logvar
        squared_error = (x - mean).pow(2)
        per_dim = -0.5 * (log_variance + squared_error / log_variance.exp())
        per_dim = per_dim + self.log_norm_constant

        if kwargs.get('meanC'):
            return per_dim.mean(dim=-1)
        return per_dim.sum(dim=-1)

    def gaussian_log_prob_safe(self, x, mean, logvar=None, pi=None):
        if logvar is None:
            logvar = torch.zeros_like(mean)
        a = (x - mean).pow(2)
        # log_p = -0.5 * (logvar + a / (logvar.exp() + 1e-9))
        log_p = -0.5 * (logvar + a / (logvar.exp()))
        log_p = log_p + self.log_norm_constant

        if False in torch.isfinite(log_p):
           return None

        return log_p.sum(dim=-1)

        # return log_p.mean(dim=-1)

    def get_posterior(self, H, mc_sample_size=10):
        """Encode ``H`` into q(z|x) and draw Monte-Carlo samples from it.

        The output of ``self.q_z_given_x_net`` is cut along its last axis into
        chunks of ``self.latent_size``; exactly two chunks (mean, log-variance)
        are expected, otherwise the unpacking fails.

        Args:
            H: Input passed to ``self.q_z_given_x_net``.
            mc_sample_size: Number of samples drawn per row.

        Returns:
            Tuple ``(mean, logvar, samples)``.
        """
        encoded = self.q_z_given_x_net(H)
        q_z_given_x_mean, q_z_given_x_logvar = torch.split(encoded, self.latent_size, dim=-1)
        samples = self.reparametrize(
            mean=q_z_given_x_mean,
            logvar=q_z_given_x_logvar,
            S=mc_sample_size,
        )
        return q_z_given_x_mean, q_z_given_x_logvar, samples

    def get_supervised_onecls_prior(self, z):
        """Fit a single diagonal Gaussian to the rows of ``z``.

        Args:
            z: Samples stacked along axis 0.

        Returns:
            Tuple ``(mean, logvar)``: the mean over axis 0 and the log of the
            mean squared deviation from it (no Bessel correction).
        """
        center = z.mean(dim=0)
        spread = (z - center[None, :]).pow(2).mean(dim=0)
        return center, torch.log(spread)

    def get_supervised_prior(self, z, label, init_mean = None, fixvar=True, iter = None, **kwargs):
        """Estimate mixture parameters from labelled samples in a single pass.

        Args:
            z: Samples handed to ``get_supervised_params`` as ``X``.
            label: Class index of each sample.
            init_mean: Starting component means. Despite the ``None`` default
                a tensor is required, because the starting log-variance is
                built with ``torch.zeros_like(init_mean)``.
            fixvar: Accepted but not used by this method.
            iter: Accepted but not used by this method.
            **kwargs: Forwarded unchanged to ``get_supervised_params``.

        Returns:
            List ``[pi, mean, logvar]`` of detached tensors.
        """
        # Start from uniform weights and zero log-variance (unit variance).
        starting_psi = (self.uniform_pi, init_mean, torch.zeros_like(init_mean))
        fitted = self.get_supervised_params(X=z, label=label, psi=starting_psi, **kwargs)
        return [param.detach() for param in fitted]

    def get_unsupervised_prior(self, z, init_mean = None, fixvar=True, iter = None, **kwargs):
        """Fit the mixture to ``z`` with EM and return detached parameters.

        Args:
            z: Samples stacked along axis 0.
            init_mean: Starting component means. When ``None``,
                ``self.component_size`` distinct rows of ``z`` are drawn with
                ``np.random.choice`` and used instead.
            fixvar: When true the log-variance stays at zero and
                ``self.unsupervised_em_iters`` EM steps are run. Otherwise the
                variance is re-estimated at every step.
            iter: Number of EM steps for the ``fixvar=False`` path; falls back
                to ``self.unsupervised_em_iters`` when falsy. Ignored when
                ``fixvar`` is true.
            **kwargs: Forwarded to ``get_unsupervised_params`` on the
                ``fixvar=False`` path. ``safeUpdate`` (read through
                ``inDict``) restores the previous mean/log-variance of any
                component whose new log-variance is not finite.

        Returns:
            List ``[pi, mean, logvar]`` of detached tensors, in both modes.

        Raises:
            ValueError: From ``np.random.choice`` when ``init_mean`` is
                ``None`` and ``z`` has fewer rows than there are components.
        """
        if init_mean is not None:
            initial_mean = init_mean
        else:
            # Draw seeds only when needed: sampling without replacement raises
            # when there are fewer samples than components, which used to
            # break calls that supplied their own ``init_mean``.
            sample_size = z.shape[0]
            idxs = torch.from_numpy(
                np.random.choice(sample_size, self.component_size, replace=False)
            ).to(z.device)
            initial_mean = torch.index_select(z, dim=0, index=idxs)

        # Both modes start from uniform weights and unit variance.
        psi = (self.uniform_pi, initial_mean, torch.zeros_like(initial_mean))

        if fixvar:  # Covariance fixed to the identity matrix
            for _ in range(self.unsupervised_em_iters):
                psi = self.get_unsupervised_params(X=z, psi=psi)
            return [param.detach() for param in psi]

        # Covariance is re-estimated at every step
        iterNum = iter if iter else self.unsupervised_em_iters
        for em_idx in range(iterNum):
            tmp = self.get_unsupervised_params(X=z, psi=psi, fixvar=False, **kwargs)
            if inDict(kwargs, 'safeUpdate'):
                # Roll back only the components whose variance degenerated.
                diversedModes = torch.where(~torch.isfinite(tmp[2]))[0].unique()
                for diversedMode in diversedModes:
                    tmp[1][diversedMode] = psi[1][diversedMode]
                    tmp[2][diversedMode] = psi[2][diversedMode]
            if not torch.isfinite(tmp[2]).all():
                # Keep the last finite estimate rather than propagating inf/nan.
                print(f"BREAK: handle exception when there is -inf, +inf, nan in logvar at em_iter {em_idx}.")
                break
            psi = tmp

        return [param.detach() for param in psi]

    def get_unsupervised_params(self, X, psi, fixvar=True, **kwargs):
        """Run one EM step of the mixture on ``X``.

        Args:
            X: Samples stacked along axis 0.
            psi: Current ``(pi, mean, logvar)`` triple.
            fixvar: When true, responsibilities are computed with unit
                variance and the returned log-variance is all zeros. Otherwise
                ``logvar`` is used and re-estimated.
            **kwargs: ``meanScaler`` - when truthy, multiplies the component
                log-likelihoods (``fixvar=False`` only). ``update_pi`` is read
                through ``inDict`` and only when ``self.fix_pi`` is true.

        Returns:
            Updated ``(pi, mean, logvar)``.
        """
        num_samples = X.shape[0]
        pi, mean, logvar = psi
        tiled_X = X[:, None, :].repeat(1, self.component_size, 1)
        tiled_mean = mean[None, :, :].repeat(num_samples, 1, 1)

        if fixvar:  # Covariance fixed to the identity matrix
            log_joint = self.gaussian_log_prob(tiled_X, tiled_mean)
            log_joint = log_joint + torch.log(pi[None, :].repeat(num_samples, 1))

            # E-step: responsibility of every component for every sample.
            resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))
            N = resp.sum(dim=0)
            if (not self.fix_pi) or inDict(kwargs, 'update_pi'):
                pi = N / N.sum(dim=-1, keepdim=True)

            # M-step: responsibility-weighted mean.
            denominator = N[:, None].repeat(1, self.latent_size)
            mean = torch.matmul(resp.permute([1, 0]).contiguous(), X) / denominator
            return pi, mean, torch.zeros_like(mean)

        # Covariance is re-estimated
        log_joint = self.gaussian_log_prob(
            tiled_X,
            tiled_mean,
            logvar[None, :, :].repeat(num_samples, 1, 1),
            pi=pi[None, :].repeat(num_samples, 1),
            meanC=False,
        )
        scaler = kwargs.get('meanScaler')
        if scaler:
            log_joint = log_joint * scaler
        # NOTE(review): gaussian_log_prob in this class never returns None, so
        # this guard looks unreachable - kept to preserve behaviour.
        if log_joint is None:
            return pi, mean, logvar
        log_joint = log_joint + torch.log(pi[None, :].repeat(num_samples, 1))

        resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))
        N = resp.sum(dim=0)
        if (not self.fix_pi) or inDict(kwargs, 'update_pi'):
            pi = N / N.sum(dim=-1, keepdim=True)

        denominator = N[:, None].repeat(1, self.latent_size)
        # NOTE(review): the epsilon guards the mean's denominator but not the
        # variance's; a component with zero responsibility gives non-finite
        # log-variance, which get_unsupervised_prior checks for.
        mean = torch.matmul(resp.permute([1, 0]).contiguous(), X) / (denominator + 1e-9)
        squared_dev = (X[:, None, :] - mean[None, :, :]).pow(2) + 1e-9
        weighted_dev = resp.unsqueeze(dim=-1) * squared_dev
        var = weighted_dev.sum(dim=0) / denominator

        return pi, mean, torch.log(var)

    def get_supervised_params(self, X, label, psi, **kwargs):
        """Estimate per-class Gaussian parameters from labelled samples.

        Args:
            X: Samples stacked along axis 0.
            label: Integer class index of each sample.
            psi: Incoming ``(pi, mean, logvar)`` triple. The incoming mean is
                the centre used for the spread; the incoming log-variance is
                not used.
            **kwargs: Flags read through ``inDict``: ``PIMOVE`` replaces the
                mixing weights with the label frequencies, ``L1var`` uses the
                absolute deviation instead of the squared one.

        Returns:
            ``(pi, mean, logvar)``. A message is printed when the
            log-variance contains non-finite entries.
        """
        prior_pi, prior_mean, _ = psi
        pi = prior_pi
        if inDict(kwargs, 'PIMOVE'):
            pi = label.bincount() / len(label)

        # One indicator row per class: 1.0 where the sample has that label.
        membership = torch.stack([(label == cls).float() for cls in range(len(pi))])
        counts = membership.sum(dim=1, keepdim=True)
        mean = (membership @ X) / counts

        # NOTE(review): the spread is measured around the *incoming* mean, not
        # the class mean computed just above - confirm this is intended.
        use_abs = inDict(kwargs, 'L1var')
        deviation = X[:, None, :] - prior_mean[None, :, :]
        spread = deviation.abs() if use_abs else deviation.pow(2)
        weighted = membership.permute([1, 0]).unsqueeze(dim=-1) * spread
        logvar = torch.log(weighted.sum(dim=0) / counts)

        if not torch.isfinite(logvar).all():
            print(f'weighted_L2norm min : {weighted.min()}')
            print(f'denominator min : {counts.min()}')
            print(f"ALERT: handle exception when there is -inf, +inf, nan in logvar.")
        return pi, mean, logvar

    def get_unsupervised_params_uniform(self, X, psi, fixvar=True, fixpi = True):
        """Run one EM step that leaves the mixing weights alone by default.

        Args:
            X: Samples stacked along axis 0.
            psi: ``(pi, mean)`` when ``fixvar`` is true, otherwise
                ``(pi, mean, logvar)``.
            fixvar: When true, unit variance is used and no log-variance is
                returned.
            fixpi: When true (default) ``pi`` is returned unchanged; otherwise
                it is re-estimated from the responsibilities.

        Returns:
            ``(pi, mean)`` when ``fixvar`` is true, else ``(pi, mean, logvar)``.
        """
        num_samples = X.shape[0]
        tiled_X = X[:, None, :].repeat(1, self.component_size, 1)

        if fixvar:  # Covariance fixed to the identity matrix
            pi, mean = psi
            log_joint = self.gaussian_log_prob(tiled_X, mean[None, :, :].repeat(num_samples, 1, 1))
            log_joint = log_joint + torch.log(pi[None, :].repeat(num_samples, 1))

            resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))
            N = resp.sum(dim=0)
            if not fixpi:
                pi = N / N.sum(dim=-1, keepdim=True)

            denominator = N[:, None].repeat(1, self.latent_size)
            mean = torch.matmul(resp.permute([1, 0]).contiguous(), X) / denominator
            return pi, mean

        # Covariance is re-estimated
        pi, mean, logvar = psi
        log_joint = self.gaussian_log_prob(
            tiled_X,
            mean[None, :, :].repeat(num_samples, 1, 1),
            logvar[None, :, :].repeat(num_samples, 1, 1),
        )
        # NOTE(review): gaussian_log_prob in this class never returns None, so
        # this guard looks unreachable - kept to preserve behaviour.
        if log_joint is None:
            return pi, mean, logvar
        log_joint = log_joint + torch.log(pi[None, :].repeat(num_samples, 1))

        resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))
        N = resp.sum(dim=0)
        if not fixpi:
            pi = N / N.sum(dim=-1, keepdim=True)

        denominator = N[:, None].repeat(1, self.latent_size)
        mean = torch.matmul(resp.permute([1, 0]).contiguous(), X) / denominator
        squared_dev = (X[:, None, :] - mean[None, :, :]).pow(2)
        var = (resp.unsqueeze(dim=-1) * squared_dev).sum(dim=0) / denominator

        return pi, mean, torch.log(var)

    def NoHint_GMM_test(self, embedding, label, sample_size, fixvar=True, _proto = None, iter = 50):
        """Cluster noisy copies of ``embedding`` without a prototype hint.

        Samples are drawn around every embedding (log-variance fixed to one),
        a mixture is fitted to them from a random start, each embedding is
        assigned to the component with the highest sample-averaged posterior,
        and components are mapped to classes through ``Dist2Proto``.

        Args:
            embedding: 2-D tensor of embeddings (unpacked as batch, latent).
            label: Labels used to build class prototypes when ``_proto`` is
                ``None``.
            sample_size: Number of noisy samples per embedding.
            fixvar: Keep unit variance during the fit when true.
            _proto: Optional ready-made prototypes.
            iter: Number of EM steps.

        Returns:
            ``(GMM_preds, pi, sortedMean, sortedLogvar)``; the last two hold
            the component mean/log-variance at the row of the class each
            component was mapped to (rows no component maps to stay zero).
        """
        batch_size, latent_size = embedding.shape
        q_z_given_x = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
        all_z = q_z_given_x.view(-1, latent_size)
        p_z_given_psi = self.get_unsupervised_prior_uniform(z=all_z, fixvar=fixvar, iter=iter)
        if fixvar:
            p_y_given_psi_pi, p_z_given_y_psi_mean = p_z_given_psi
            p_z_given_y_psi_logvar = torch.zeros_like(p_z_given_y_psi_mean)
            logvar = None  # gaussian_log_prob then assumes unit variance
        else:
            p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar = p_z_given_psi
            logvar = p_z_given_y_psi_logvar[None, None, :, :].repeat(batch_size, sample_size, 1, 1)

        log_likelihoods = self.gaussian_log_prob(
            q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
            p_z_given_y_psi_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
        ) + torch.log(p_y_given_psi_pi[None, None, :].repeat(batch_size, sample_size, 1))

        posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
        # Average the responsibilities over the samples of each embedding.
        preds = posteriors.mean(dim=-2).argmax(dim=-1)

        prototype = self.GetPrototype(embedding, label) if _proto is None else _proto
        GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label, draw=False)
        GMM_preds = torch.tensor([GMM2cls[idx.item()] for idx in preds])

        sortedMean = torch.zeros_like(p_z_given_y_psi_mean)
        sortedLogvar = torch.zeros_like(p_z_given_y_psi_logvar)
        for gmmIdx, protoIdx in GMM2cls.items():
            sortedMean[protoIdx] = p_z_given_y_psi_mean[gmmIdx]
            sortedLogvar[protoIdx] = p_z_given_y_psi_logvar[gmmIdx]

        return GMM_preds, p_y_given_psi_pi, sortedMean, sortedLogvar

    def get_unsupervised_prior_uniform(self, z, fixvar=True, iter=None):
        """Fit the mixture to ``z`` with EM while keeping uniform mixing weights.

        NOTE(review): ``NoHint_GMM_test`` called this method but the class did
        not define it. This version is reconstructed from
        ``get_unsupervised_params_uniform`` and the call site - confirm it
        matches the intended algorithm.

        Args:
            z: Samples stacked along axis 0; ``self.component_size`` distinct
                rows are drawn with ``np.random.choice`` as starting means.
            fixvar: Keep unit variance when true.
            iter: Number of EM steps; falls back to
                ``self.unsupervised_em_iters`` when falsy.

        Returns:
            Detached ``[pi, mean]`` when ``fixvar`` is true, otherwise
            ``[pi, mean, logvar]``.
        """
        num_samples = z.shape[0]
        seed_idx = torch.from_numpy(
            np.random.choice(num_samples, self.component_size, replace=False)
        ).to(z.device)
        initial_mean = torch.index_select(z, dim=0, index=seed_idx)
        num_iters = iter if iter else self.unsupervised_em_iters

        if fixvar:
            psi = (self.uniform_pi, initial_mean)
            for _ in range(num_iters):
                psi = self.get_unsupervised_params_uniform(X=z, psi=psi, fixvar=True)
        else:
            psi = (self.uniform_pi, initial_mean, torch.zeros_like(initial_mean))
            for _ in range(num_iters):
                candidate = self.get_unsupervised_params_uniform(X=z, psi=psi, fixvar=False)
                # Keep the last finite estimate if the variance degenerates.
                if not torch.isfinite(candidate[2]).all():
                    break
                psi = candidate

        return [param.detach() for param in psi]

    def Hint_GMM_test(self, embedding, label, sample_size, fixvar=True, _prototype=None):
        """Cluster noisy copies of ``embedding`` with EM started at class prototypes.

        Args:
            embedding: 2-D tensor of embeddings (unpacked as batch, latent).
            label: Labels used to build the prototypes when ``_prototype`` is
                ``None``.
            sample_size: Number of noisy samples per embedding.
            fixvar: Keep unit variance during the fit when true.
            _prototype: Optional ready-made prototypes. When given they also
                make ``Dist2Proto`` return the identity mapping.

        Returns:
            ``(GMM_preds, mean, logvar, GMM_mean_label, prototype)``.
            ``logvar`` is ``None`` when ``fixvar`` is true. ``GMM_mean_label``
            has one entry per key of the component-to-class mapping.
        """
        batch_size, latent_size = embedding.shape
        prototype = self.GetPrototype(embedding, label) if _prototype is None else _prototype

        q_z_given_x = self.Easy_reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
        all_z = q_z_given_x.view(-1, latent_size)
        # get_unsupervised_prior returns (pi, mean, logvar) in both modes;
        # unpacking only two values raised ValueError whenever fixvar was true.
        p_y_given_psi_pi, p_z_given_y_psi_mean, fitted_logvar = self.get_unsupervised_prior(
            z=all_z, init_mean=prototype, fixvar=fixvar)
        if fixvar:
            p_z_given_y_psi_logvar = logvar = None  # unit variance
        else:
            p_z_given_y_psi_logvar = fitted_logvar
            logvar = fitted_logvar[None, None, :, :].repeat(batch_size, sample_size, 1, 1)

        log_likelihoods = self.gaussian_log_prob(
            q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
            p_z_given_y_psi_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
        ) + torch.log(p_y_given_psi_pi[None, None, :].repeat(batch_size, sample_size, 1))

        posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
        # Average the responsibilities over the samples of each embedding.
        preds = posteriors.mean(dim=-2).argmax(dim=-1)

        GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label,
                                  _prototype=_prototype)
        GMM_preds = torch.tensor([GMM2cls[idx.item()] for idx in preds])
        # One label per mapped component instead of a hard-coded ten.
        GMM_mean_label = torch.tensor([GMM2cls[k] for k in range(len(GMM2cls))])

        return GMM_preds, p_z_given_y_psi_mean, p_z_given_y_psi_logvar, GMM_mean_label, prototype

    def Hint_GMM_test2(self, embedding, label, sample_size, fixvar=True, _prototype=None, **kwargs):
        """Fit a mixture to ``embedding`` and align its components to a reference.

        The mixture is fitted directly on the embeddings (no sampling noise),
        its components are matched one-to-one to the reference components in
        ``kwargs`` by solving an assignment problem on ``BDMatrix``, and the
        predictions are made with the re-ordered mixture.

        Args:
            embedding: 2-D tensor of embeddings (unpacked as batch, latent).
            label: Labels used to build prototypes when ``_prototype`` is
                ``None``; may itself be ``None`` for a random start.
            sample_size: Repeat factor for the tiled parameters; the
                embeddings carry a sample axis of length one.
            fixvar: Keep unit variance during the fit when true.
            _prototype: Optional ready-made prototypes.
            **kwargs: ``L_proto`` and ``L_logvar`` are required (reference
                means / log-variances). ``ITER`` (read through ``inDict``)
                sets the number of EM steps, default 10. Everything is also
                forwarded to ``get_unsupervised_prior``.

        Returns:
            ``(align_preds, aligned_pi, aligned_mean, aligned_logvar,
            GMM_mean_label, prototype)``.

        Raises:
            KeyError: If ``L_proto`` or ``L_logvar`` is missing.
        """
        batch_size, latent_size = embedding.shape
        reference_mean = kwargs['L_proto']
        reference_logvar = kwargs['L_logvar']

        if _prototype is not None:
            prototype = _prototype
        elif label is not None:
            prototype = self.GetPrototype(embedding, label)
        else:
            prototype = None

        q_z_given_x = embedding.unsqueeze(dim=1)
        all_z = q_z_given_x.view(-1, latent_size)
        em_iters = inDict(kwargs, 'ITER') or 10
        # get_unsupervised_prior returns (pi, mean, logvar) in both modes
        # (all-zero logvar when fixvar is true); unpacking only two values
        # raised ValueError whenever fixvar was true.
        p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar = self.get_unsupervised_prior(
            z=all_z, init_mean=prototype, fixvar=fixvar, iter=em_iters, **kwargs)

        # Match fitted components (rows) to reference components (columns).
        BDmat = BDMatrix(p_z_given_y_psi_mean, p_z_given_y_psi_logvar, reference_mean, reference_logvar)
        row_match, col_match = linear_sum_assignment(BDmat.cpu().numpy())
        alignIdx = {int(col): int(row) for row, col in zip(row_match, col_match)}
        order = [alignIdx[x] for x in range(self.component_size)]
        aligned_pi = torch.stack([p_y_given_psi_pi[i] for i in order])
        aligned_mean = torch.stack([p_z_given_y_psi_mean[i] for i in order])
        aligned_logvar = torch.stack([p_z_given_y_psi_logvar[i] for i in order])

        GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, reference_mean, embedding=embedding, label=label,
                                  _prototype=_prototype, draw=False)
        # One label per mapped component instead of a hard-coded ten.
        GMM_mean_label = torch.tensor([GMM2cls[k] for k in range(len(GMM2cls))])

        algin_LL = self.gaussian_log_prob(
            q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
            aligned_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), aligned_logvar
        )
        algin_LL = algin_LL + torch.log(aligned_pi[None, None, :].repeat(batch_size, sample_size, 1))
        posteriors = torch.exp(algin_LL - torch.logsumexp(algin_LL, dim=-1, keepdim=True))
        align_preds = posteriors.mean(dim=-2).argmax(dim=-1)

        return align_preds, aligned_pi, aligned_mean, aligned_logvar, GMM_mean_label, prototype

    def Given_GMM_test(self, embedding, label, sample_size, _pi, _mean, _logvar):
        """Classify embeddings with a given mixture, averaging over noisy samples.

        Args:
            embedding: 2-D tensor of embeddings (unpacked as batch, latent).
            label: Only forwarded to ``Dist2Proto``.
            sample_size: Number of noisy samples drawn per embedding
                (log-variance fixed to one).
            _pi: Mixing weights of the mixture.
            _mean: Component means.
            _logvar: Component log-variances, broadcast against the samples.

        Returns:
            Tensor with one predicted class index per embedding.
        """
        batch_size, _ = embedding.shape
        samples = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)

        tiled_samples = samples[:, :, None, :].repeat(1, 1, self.component_size, 1)
        tiled_mean = _mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1)
        log_joint = self.gaussian_log_prob(tiled_samples, tiled_mean, _logvar)
        log_joint = log_joint + torch.log(_pi[None, None, :].repeat(batch_size, sample_size, 1))

        resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))
        # Average the responsibilities over the samples of each embedding.
        component_of = resp.mean(dim=-2).argmax(dim=-1)

        # Passing ``_prototype`` makes Dist2Proto return the identity mapping.
        GMM2cls = self.Dist2Proto(_mean, _mean, embedding=embedding, label=label, _prototype=_mean)
        return torch.tensor([GMM2cls[idx.item()] for idx in component_of])

    def Given_GMM_test2(self, embedding, label, sample_size, _pi, _mean, _logvar):
        """Classify embeddings with a given mixture, without sampling noise.

        Args:
            embedding: 2-D tensor of embeddings (unpacked as batch, latent).
            label: Only forwarded to ``Dist2Proto``.
            sample_size: Accepted but not used by this method.
            _pi: Mixing weights of the mixture.
            _mean: Component means.
            _logvar: Component log-variances.

        Returns:
            ``(GMM_preds, posteriors)``: the predicted class index per
            embedding and the per-component responsibilities.
        """
        # A uniform prior and a second, mean-reduced likelihood pass used to
        # be computed here and then discarded; only ``_pi`` and the
        # sum-reduced likelihood ever reached the result.
        batch_size, _ = embedding.shape

        log_likelihoods = self.gaussian_log_prob(
            embedding[:, None, :].repeat(1, self.component_size, 1),
            _mean[None, :, :].repeat(batch_size, 1, 1),
            _logvar[None, :, :].repeat(batch_size, 1, 1)
        ) + torch.log(_pi[None, :].repeat(batch_size, 1))

        posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
        preds = posteriors.argmax(dim=-1)
        # Passing ``_prototype`` makes Dist2Proto return the identity mapping.
        GMM2cls = self.Dist2Proto(_mean, _mean, embedding=embedding, label=label, _prototype=_mean)
        GMM_preds = torch.tensor([GMM2cls[idx.item()] for idx in preds])

        return GMM_preds, posteriors

    def GetPrototype(self, embedding, label):
        """Return the mean embedding of every class present in ``label``.

        Args:
            embedding: Embeddings, indexed along axis 0 by sample.
            label: Class label of each sample.

        Returns:
            Tensor with one row per distinct label, in the order produced by
            ``torch.unique(label)``.
        """
        class_means = []
        for cls in torch.unique(label):
            member_rows = torch.nonzero(label.eq(cls)).flatten()
            class_means.append(embedding[member_rows].mean(0))
        return torch.stack(class_means)

    def Dist2Proto(self, GMM_mean, proto, embedding=None, label=None, _prototype=None, draw=False):
        """Map each mixture component to the class of its nearest prototype.

        Args:
            GMM_mean: Component means; only the first ``len(proto)`` rows are
                considered.
            proto: Class prototypes, one row per class.
            embedding: Only used for plotting when ``draw`` is true.
            label: Only used for plotting when ``draw`` is true.
            _prototype: When not ``None`` the identity mapping over
                ``range(len(proto))`` is returned and no distance is computed.
            draw: When true, also plot through ``drawTSNE3``.

        Returns:
            Dict ``{component_index: class_index}`` with Python ints. Several
            components may map to the same class.
        """
        nCls = len(proto)
        if _prototype is not None:
            return {i: i for i in range(nCls)}

        # Mean squared distance between every component and every prototype,
        # computed in one broadcast instead of a nested loop that rebuilt each
        # row ``nCls`` times.
        L2dist = (GMM_mean[:nCls, None, :] - proto[None, :, :]).pow(2).mean(dim=-1)
        closest = L2dist.argmin(dim=1)
        GMM2cls = {gmmIdx: cls.item() for gmmIdx, cls in enumerate(closest)}

        if draw:
            # Re-order the component means by class, for the plot only.
            sortedGMM = torch.zeros_like(proto)
            for gmmIdx, protoIdx in GMM2cls.items():
                sortedGMM[protoIdx] = GMM_mean[gmmIdx]
            print("DrawTSNE in Dist2Proto")
            # One plot label per prototype row instead of a hard-coded ten.
            drawTSNE3(sortedGMM, proto, embedding, label1 = torch.arange(nCls), label2 = torch.arange(nCls), label3=label,
                      name1='GMM_mean', name2='prototype', name3='embedding')

        return GMM2cls

    def GMM_prediction(self, embedding, pi, mean, logvar):
        """Assign every embedding to its most responsible mixture component.

        Args:
            embedding: Embeddings stacked along axis 0.
            pi: Mixing weights.
            mean: Component means.
            logvar: Component log-variances, broadcast against the tiled
                embeddings.

        Returns:
            Tensor with the index of the highest-posterior component per row.
        """
        num_rows = embedding.size(0)
        tiled_embedding = embedding[:, None, :].repeat(1, self.component_size, 1)
        tiled_mean = mean[None, :, :].repeat(num_rows, 1, 1)

        log_joint = self.gaussian_log_prob(tiled_embedding, tiled_mean, logvar)
        log_joint = log_joint + torch.log(pi[None, :].repeat(num_rows, 1))
        resp = torch.exp(log_joint - torch.logsumexp(log_joint, dim=-1, keepdim=True))

        return resp.argmax(dim=-1)


def Make_GMM_Model(args):
    """Build a ``GMM_Model`` from an argument namespace.

    Args:
        args: Object exposing the attributes listed in ``option_names`` below
            (e.g. a parsed ``argparse`` namespace).

    Returns:
        A new ``GMM_Model``; ``input_shape`` keeps the constructor default.
    """
    option_names = (
        'unsupervised_em_iters',
        'semisupervised_em_iters',
        'fix_pi',
        'hidden_size',
        'component_size',
        'latent_size',
        'train_mc_sample_size',
        'test_mc_sample_size',
    )
    # Each option is forwarded under the same keyword name.
    options = {name: getattr(args, name) for name in option_names}
    return GMM_Model(**options)

