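# NOTE(review): Every line of this module is commented out, so it defines
# nothing at import time. If it is ever re-enabled, address these latent
# issues first:
#   - Critic has no forward() and applies nn.BatchNorm2d to the output of
#     nn.Linear (BatchNorm1d is the usual match for (batch, features) input).
#   - q_z_given_x_net outputs 2 * hidden_size features, but get_posterior
#     splits them by latent_size; this only works when hidden_size == latent_size.
#   - rec_criterion is assigned the nn.MSELoss class itself, not an instance.
#   - Tensors are compared with `== None` / `!= None` (get_unsupervised_params*,
#     Hint_GMM_test*, Dist2Proto); prefer `is None` / `is not None`. Also,
#     gaussian_log_prob never returns None, so those None branches are dead.
#   - Dist2Proto's inner `proIdx` loop recomputes the same row nCls times, and
#     it only reads the first len(proto) rows of GMM_mean.
#   - Hint_GMM_test* and Dist2Proto hard-code 10 classes (np.arange(10),
#     torch.arange(10)).
#   - get_unsupervised_prior* return generators, not tuples; the result can
#     be unpacked only once.
# import numpy as np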
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import math
# import pdb
# from models.GMM_module import *
# from functions import *
#
# class Critic(nn.Module):
#
#     def __init__(self, latent_size = 512, mid_size = 256):
#         super(Critic, self).__init__()
#         self.fc1 = nn.Linear(latent_size, mid_size)
#         self.fc2 = nn.Linear(mid_size, 1)
#         self.bn1 = nn.BatchNorm2d(mid_size)
#
#
# class GMM_Model(nn.Module):
#     def __init__(self, input_shape=[1,32,32], unsupervised_em_iters=5, semisupervised_em_iters=5,  fix_pi=False,
#                  hidden_size=64, component_size=20, latent_size=64, train_mc_sample_size=10, test_mc_sample_size=10):
#         super(GMM_Model, self).__init__()
#         self.input_shape = input_shape
#         self.unsupervised_em_iters = unsupervised_em_iters
#         self.semisupervised_em_iters = semisupervised_em_iters
#         self.fix_pi = fix_pi
#         self.hidden_size = hidden_size
#         self.last_hidden_size = 2*2*hidden_size
#         self.component_size = component_size
#         self.latent_size = latent_size
#         self.train_mc_sample_size = train_mc_sample_size
#         self.test_mc_sample_size = test_mc_sample_size
#
#         self.q_z_given_x_net = nn.Sequential(
#             SAB(dim_in=self.last_hidden_size, dim_out=self.last_hidden_size, num_heads=4, ln=False),
#             SAB(dim_in=self.last_hidden_size, dim_out=self.last_hidden_size, num_heads=4, ln=False),
#             nn.Linear(self.last_hidden_size, 2 * self.hidden_size)
#         )
#
#         self.proj = nn.Sequential(
#             nn.Linear(latent_size, self.last_hidden_size),
#             nn.ELU(inplace=True),
#             nn.Linear(self.last_hidden_size, self.last_hidden_size),
#             nn.ELU(inplace=True),
#             nn.Linear(self.last_hidden_size, self.last_hidden_size),
#             nn.ELU(inplace=True),
#         )
#
#         self.decoder = CIFAR10Decoder(hidden_size=hidden_size)
#         # self.rec_criterion = nn.BCELoss(reduction='sum')
#         self.rec_criterion = nn.MSELoss
#         self.register_buffer('log_norm_constant', torch.tensor(-0.5 * np.log(2 * np.pi)))
#         self.register_buffer('uniform_pi', torch.ones(self.component_size)/self.component_size)
#
#     def reparametrize(self, mean, logvar, S=1):
#         mean = mean.unsqueeze(1).repeat(1, S, 1)
#         logvar = logvar.unsqueeze(1).repeat(1, S, 1)
#         std = logvar.mul(0.5).exp()
#         eps = torch.randn_like(mean)
#         return eps.mul(std).add(mean)
#
#     def Easy_reparametrize(self, mean, logvar, S=1):
#         mean = mean.unsqueeze(1).repeat(1, S, 1)
#         logvar = logvar.unsqueeze(1).repeat(1, S, 1)
#         std = logvar.mul(0.5).exp()
#         eps = torch.randn_like(mean)
#         eps = eps / (eps.max()*2)
#         return eps.mul(std).add(mean)
#
#     def Same_reparametrize(self, mean, logvar, S=1):
#         mean = mean.unsqueeze(1).repeat(1, S, 1)
#         return mean
#
#     def gaussian_log_prob(self, x, mean, logvar=None, pi=None, **kwargs):
#         if logvar is None:
#             logvar = torch.zeros_like(mean)
#         a = (x - mean).pow(2)
#         # log_p = -0.5 * (logvar + a / (logvar.exp() + 1e-9))
#         log_p = -0.5 * (logvar + a / (logvar.exp()))
#         log_p = log_p + self.log_norm_constant
#
#         # if False in torch.isfinite(log_p):
#         #     # pdb.set_trace()
#         #     print('log_p becomes None')
#         #     return None
#         if 'meanC' in kwargs and kwargs['meanC']:
#             return log_p.mean(dim=-1)
#         else:
#             return log_p.sum(dim=-1)
#
#     def gaussian_log_prob_safe(self, x, mean, logvar=None, pi=None):
#         if logvar is None:
#             logvar = torch.zeros_like(mean)
#         a = (x - mean).pow(2)
#         # log_p = -0.5 * (logvar + a / (logvar.exp() + 1e-9))
#         log_p = -0.5 * (logvar + a / (logvar.exp()))
#         log_p = log_p + self.log_norm_constant
#
#         if False in torch.isfinite(log_p):
#            return None
#
#         return log_p.sum(dim=-1)
#
#         # return log_p.mean(dim=-1)
#
#     def get_posterior(self, H, mc_sample_size=10):
#         ## q(z|x) ##
#         q_z_given_x_mean, q_z_given_x_logvar = self.q_z_given_x_net(H).split(self.latent_size, dim=-1)
#         q_z_given_x = self.reparametrize(mean=q_z_given_x_mean, logvar=q_z_given_x_logvar, S=mc_sample_size)
#         return q_z_given_x_mean, q_z_given_x_logvar, q_z_given_x
#
#     def get_unsupervised_prior(self, z, init_mean = None, fixvar=True, iter = None):
#         sample_size = z.shape[0]
#         initial_pi = self.uniform_pi
#         idxs = torch.from_numpy(np.random.choice(sample_size, self.component_size, replace=False)).to(z.device)
#         if init_mean is not None:
#             initial_mean = init_mean
#         else:
#             initial_mean = torch.index_select(z, dim=0, index=idxs)
#
#         if fixvar: # Fix covariance matrix to identity matrix #
#             psi = (initial_pi, initial_mean)
#             for _ in range(self.unsupervised_em_iters):
#                 psi = self.get_unsupervised_params(X=z, psi=psi)
#             psi = (param.detach() for param in psi)
#             return psi
#         else: # Does not fix covariance matrix #
#             initial_logvar = torch.zeros_like(initial_mean)
#             psi = (initial_pi, initial_mean, initial_logvar)
#
#             iterNum = iter if iter else self.unsupervised_em_iters
#             for em_idx in range(iterNum):
#                 tmp = self.get_unsupervised_params(X=z, psi=psi, fixvar=False)
#                 if False in torch.isfinite(tmp[2]):
#                     print(f"BREAK: handle exception when there is -inf, +inf, nan in logvar at em_iter {em_idx}.")
#                     break
#                 psi = tmp
#             psi = (param.detach() for param in psi)
#             return psi
#
#     def get_unsupervised_prior_uniform(self, z, init_mean=None, fixvar=True, iter=None):
#         sample_size = z.shape[0]
#         initial_pi = self.uniform_pi
#         idxs = torch.from_numpy(np.random.choice(sample_size, self.component_size, replace=False)).to(z.device)
#         if init_mean is not None:
#             initial_mean = init_mean
#         else:
#             initial_mean = torch.index_select(z, dim=0, index=idxs)
#
#         iterNum = iter if iter else self.unsupervised_em_iters
#         if fixvar:  # Fix covariance matrix to identity matrix #
#             psi = (initial_pi, initial_mean)
#             for _ in range(iterNum):
#                 psi = self.get_unsupervised_params_uniform(X=z, psi=psi)
#             psi = (param.detach() for param in psi)
#             return psi
#         else:  # Does not fix covariance matrix #
#             initial_logvar = torch.zeros_like(initial_mean)
#             psi = (initial_pi, initial_mean, initial_logvar)
#             for em_idx in range(iterNum):
#                 tmp = self.get_unsupervised_params_uniform(X=z, psi=psi, fixvar=False)
#                 if False in torch.isfinite(tmp[2]):  # handle exception when there is -inf, +inf, nan in logvar.
#                     print("BREAK: handle exception when there is -inf, +inf, nan in logvar.")
#                     break
#                 psi = tmp
#             psi = (param.detach() for param in psi)
#             return psi
#
#     def get_unsupervised_params(self, X, psi, fixvar=True):
#         sample_size = X.shape[0]
#
#         if fixvar: # Fix covariance matrix to identity matrix #
#             pi, mean = psi
#             log_likelihoods = self.gaussian_log_prob(
#                 X[:, None, :].repeat(1, self.component_size, 1), mean[None, :, :].repeat(sample_size, 1, 1)
#             ) + torch.log(pi[None, :].repeat(sample_size, 1))
#
#             posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#             N = torch.sum(posteriors, dim=0)
#             if not self.fix_pi:
#                 pi = N / N.sum(dim=-1, keepdim=True)
#
#             denominator = N[:, None].repeat(1, self.latent_size)
#             mean = torch.matmul(posteriors.permute([1, 0]).contiguous(), X) / denominator
#             return pi, mean
#         else: # Does not fix covariance matrix #
#             pi, mean, logvar = psi
#             log_likelihoods = self.gaussian_log_prob( # In original Meta-GMVAE, var is fixed in here..
#                 X[:, None, :].repeat(1, self.component_size, 1), mean[None, :, :].repeat(sample_size, 1, 1),
#                 logvar[None, :, :].repeat(sample_size, 1, 1), pi = pi[None, :].repeat(sample_size, 1)
#             )
#             if log_likelihoods == None:
#                 return pi, mean, logvar
#             else:
#                 log_likelihoods = log_likelihoods + torch.log(pi[None, :].repeat(sample_size, 1))
#
#             posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#             N = torch.sum(posteriors, dim=0)
#             if not self.fix_pi:
#                 pi = N / N.sum(dim=-1, keepdim=True)
#
#             denominator = N[:, None].repeat(1, self.latent_size) # [10,512]
#             mean = torch.matmul(posteriors.permute([1, 0]).contiguous(), X) / denominator
#             L2norm = (X[:,None,:] - mean[None,:,:]).pow(2) # [4096,10,512]
#             weighted_L2norm = posteriors.unsqueeze(dim=-1) * L2norm # [4096,10,512]
#             var = weighted_L2norm.sum(dim=0) / denominator # [10,512]
#             # X2 = torch.matmul(posteriors.permute([1, 0]).contiguous(), X.pow(2.0))
#             # X_mean = torch.matmul(posteriors.permute([1, 0]).contiguous(), X) * mean
#             # mean2 = N[:, None].repeat(1, self.latent_size) * mean.pow(2.0)
#             # L2norm = X2 - 2 * X_mean + mean2
#             # var = L2norm / (denominator + 1e-9)
#             logvar = torch.log(var)
#             return pi, mean, logvar
#
#     def get_unsupervised_params_uniform(self, X, psi, fixvar=True, fixpi = True):
#         sample_size = X.shape[0]
#
#         if fixvar: # Fix covariance matrix to identity matrix #
#             pi, mean = psi
#             log_likelihoods = self.gaussian_log_prob(
#                 X[:, None, :].repeat(1, self.component_size, 1), mean[None, :, :].repeat(sample_size, 1, 1)
#             ) + torch.log(pi[None, :].repeat(sample_size, 1))
#
#             posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#             N = torch.sum(posteriors, dim=0)
#             if not fixpi:
#                 pi = N / N.sum(dim=-1, keepdim=True)
#
#             denominator = N[:, None].repeat(1, self.latent_size)
#             mean = torch.matmul(posteriors.permute([1, 0]).contiguous(), X) / denominator
#             return pi, mean
#         else: # Does not fix covariance matrix #
#             pi, mean, logvar = psi
#             log_likelihoods = self.gaussian_log_prob(
#                 X[:, None, :].repeat(1, self.component_size, 1),
#                 mean[None, :, :].repeat(sample_size, 1, 1),
#                 logvar[None, :, :].repeat(sample_size, 1, 1)
#             )
#             if log_likelihoods == None:
#                 return pi, mean, logvar
#             else:
#                 log_likelihoods = log_likelihoods + torch.log(pi[None, :].repeat(sample_size, 1))
#
#             posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#             N = torch.sum(posteriors, dim=0)
#             if not fixpi:
#                 pi = N / N.sum(dim=-1, keepdim=True)
#
#             denominator = N[:, None].repeat(1, self.latent_size) # [10,512]
#             mean = torch.matmul(posteriors.permute([1, 0]).contiguous(), X) / denominator
#             L2norm = (X[:,None,:] - mean[None,:,:]).pow(2) # [4096,10,512]
#             weighted_L2norm = posteriors.unsqueeze(dim=-1) * L2norm # [4096,10,512]
#             var = weighted_L2norm.sum(dim=0) / denominator # [10,512]
#
#             logvar = torch.log(var)
#             return pi, mean, logvar
#
#     def NoHint_GMM_test(self, embedding, label, sample_size, fixvar=True, _proto = None, iter = 50):
#         batch_size, latent_size = embedding.shape
#         q_z_given_x = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         # q_z_given_x = self.Easy_reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         all_z = q_z_given_x.view(-1, latent_size)
#         p_z_given_psi = self.get_unsupervised_prior_uniform(z=all_z, fixvar=fixvar, iter = iter)
#         if fixvar:
#             p_y_given_psi_pi, p_z_given_y_psi_mean = p_z_given_psi
#             p_z_given_y_psi_logvar = torch.zeros_like(p_z_given_y_psi_mean)
#             logvar = None
#         else:
#             p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar = p_z_given_psi
#             logvar = p_z_given_y_psi_logvar[None, None, :, :].repeat(batch_size, sample_size, 1, 1)
#
#         log_likelihoods = self.gaussian_log_prob(
#             q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
#             p_z_given_y_psi_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
#         ) + torch.log(p_y_given_psi_pi[None, None, :].repeat(batch_size, sample_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.mean(dim=-2).argmax(dim=-1)
#         if _proto is None:
#             prototype = self.GetPrototype(embedding, label)
#         else:
#             prototype = _proto
#         GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label, draw=False)
#         GMM_preds = torch.tensor(list(map(lambda x: GMM2cls[x.item()], preds)))
#         # GMM_preds = predss
#
#         sortedMean, sortedLogvar = torch.zeros_like(p_z_given_y_psi_mean), torch.zeros_like(p_z_given_y_psi_logvar)
#         for gmmIdx, protoIdx in GMM2cls.items():
#             sortedMean[protoIdx] = p_z_given_y_psi_mean[gmmIdx]
#             sortedLogvar[protoIdx] = p_z_given_y_psi_logvar[gmmIdx]
#
#         return GMM_preds, p_y_given_psi_pi, sortedMean, sortedLogvar
#
#     def Hint_GMM_test(self, embedding, label, sample_size, fixvar=True, _prototype=None):
#         batch_size, latent_size = embedding.shape
#         if _prototype == None:
#             prototype = self.GetPrototype(embedding, label)
#         else:
#             prototype = _prototype
#         # q_z_given_x = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         q_z_given_x = self.Easy_reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         all_z = q_z_given_x.view(-1, latent_size)
#         p_z_given_psi = self.get_unsupervised_prior(z=all_z, init_mean = prototype, fixvar=fixvar)
#         if fixvar:
#             p_y_given_psi_pi, p_z_given_y_psi_mean = p_z_given_psi
#             p_z_given_y_psi_logvar = logvar = None
#         else:
#             p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar = p_z_given_psi
#             logvar = p_z_given_y_psi_logvar[None, None, :, :].repeat(batch_size, sample_size, 1, 1)
#
#         log_likelihoods = self.gaussian_log_prob(
#             q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
#             p_z_given_y_psi_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
#         ) + torch.log(p_y_given_psi_pi[None, None, :].repeat(batch_size, sample_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.mean(dim=-2).argmax(dim=-1)
#         GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label, _prototype=_prototype)
#         # GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label)
#         GMM_preds = torch.tensor(list(map(lambda x:GMM2cls[x.item()], preds)))
#         GMM_mean_label = torch.tensor(list(map(lambda x:GMM2cls[x], np.arange(10))))
#
#         return GMM_preds, p_z_given_y_psi_mean, p_z_given_y_psi_logvar, GMM_mean_label, prototype
#
#     def Hint_GMM_test2(self, embedding, label, sample_size, fixvar=True, _prototype=None):
#
#         batch_size, latent_size = embedding.shape
#         if _prototype == None:
#             prototype = self.GetPrototype(embedding, label)
#         else:
#             prototype = _prototype
#         q_z_given_x = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         # q_z_given_x = self.Easy_reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#         all_z = q_z_given_x.view(-1, latent_size)
#         p_z_given_psi = self.get_unsupervised_prior(z=all_z, init_mean = prototype, fixvar=fixvar)
#         if fixvar:
#             p_y_given_psi_pi, p_z_given_y_psi_mean = p_z_given_psi
#             p_z_given_y_psi_logvar = logvar = None
#         else:
#             p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar = p_z_given_psi
#             logvar = p_z_given_y_psi_logvar[None, None, :, :].repeat(batch_size, sample_size, 1, 1)
#
#         log_likelihoods = self.gaussian_log_prob(
#             q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
#             p_z_given_y_psi_mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
#         )
#
#         if log_likelihoods == None:
#             return torch.zeros(batch_size), p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar, None, prototype
#         else:
#             log_likelihoods = log_likelihoods + torch.log(p_y_given_psi_pi[None, None, :].repeat(batch_size, sample_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.mean(dim=-2).argmax(dim=-1)
#         GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label,
#                                   _prototype=_prototype, draw=False)
#         # GMM2cls = self.Dist2Proto(p_z_given_y_psi_mean, prototype, embedding=embedding, label=label)
#         GMM_preds = torch.tensor(list(map(lambda x:GMM2cls[x.item()], preds)))
#         GMM_mean_label = torch.tensor(list(map(lambda x:GMM2cls[x], np.arange(10))))
#
#         return GMM_preds, p_y_given_psi_pi, p_z_given_y_psi_mean, p_z_given_y_psi_logvar, GMM_mean_label, prototype
#
#     def Given_GMM_test(self, embedding, label, sample_size, _pi, _mean, _logvar):
#         batch_size, latent_size = embedding.shape
#         # nC = len(_mean)
#         # pi, mean, logvar = torch.ones(nC).to(embedding.device) / nC, _mean, _logvar
#         pi, mean, logvar = _pi, _mean, _logvar
#         q_z_given_x = self.reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#
#         log_likelihoods = self.gaussian_log_prob(
#             q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
#             mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
#         ) + torch.log(pi[None, None, :].repeat(batch_size, sample_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.mean(dim=-2).argmax(dim=-1)
#         GMM2cls = self.Dist2Proto(mean, mean, embedding=embedding, label=label, _prototype=mean)
#         GMM_preds = torch.tensor(list(map(lambda x:GMM2cls[x.item()], preds)))
#
#         return GMM_preds
#
#     def Given_GMM_test2(self, embedding, label, sample_size, _pi, _mean, _logvar):
#         batch_size, latent_size = embedding.shape
#         nC = len(_mean)
#         pi, mean, logvar = torch.ones(nC).to(embedding.device) / nC, _mean, _logvar
#         pi = _pi
#
#         log_likelihoods = self.gaussian_log_prob(
#             embedding[:, None, :].repeat(1, self.component_size, 1),
#             mean[None, :, :].repeat(batch_size, 1, 1),
#             logvar[None, :, :].repeat(batch_size, 1, 1)
#         ) + torch.log(pi[None, :].repeat(batch_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.argmax(dim=-1)
#         GMM2cls = self.Dist2Proto(mean, mean, embedding=embedding, label=label, _prototype=mean)
#         GMM_preds = torch.tensor(list(map(lambda x: GMM2cls[x.item()], preds)))
#
#         return GMM_preds, posteriors
#
#
#     def Given_GMM_test_V2(self, embedding, label, sample_size, _pi, _mean, _logvar):
#         batch_size, latent_size = embedding.shape
#         nC = len(_mean)
#         pi, mean, logvar = torch.ones(nC).to(embedding.device) / nC, _mean, _logvar
#         q_z_given_x = self.Same_reparametrize(mean=embedding, logvar=torch.ones_like(embedding), S=sample_size)
#
#         log_likelihoods = self.gaussian_log_prob(
#             q_z_given_x[:, :, None, :].repeat(1, 1, self.component_size, 1),
#             mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), logvar
#         ) + torch.log(pi[None, None, :].repeat(batch_size, sample_size, 1))
#
#         posteriors = torch.exp(log_likelihoods - torch.logsumexp(log_likelihoods, dim=-1, keepdim=True))
#         preds = posteriors.mean(dim=-2).argmax(dim=-1)
#         GMM2cls = self.Dist2Proto(mean, mean, embedding=embedding, label=label, _prototype=mean)
#         GMM_preds = torch.tensor(list(map(lambda x:GMM2cls[x.item()], preds)))
#
#         return GMM_preds
#
#     def GetPrototype(self, embedding, label):
#         def supp_idxs(c):
#             return label.eq(c).nonzero().flatten()
#
#         classes = torch.unique(label)
#         support_idxs = list(map(supp_idxs, classes))
#         prototypes = torch.stack([embedding[idx_list].mean(0) for idx_list in support_idxs])
#
#         return prototypes
#
#     def Dist2Proto(self, GMM_mean, proto, embedding=None, label=None, _prototype=None, draw=False):
#         GMM2cls, nCls = {}, len(proto)
#         if _prototype != None:
#             for i in range(nCls):
#                 GMM2cls[i] = i
#             return GMM2cls
#
#         L2dist = torch.zeros(nCls, nCls).to(proto.device)
#         for gmmIdx in range(nCls):
#             for proIdx in range(nCls):
#                 gmm_mean = GMM_mean[gmmIdx].repeat(nCls, 1)
#                 dist = (gmm_mean - proto).pow(2).mean(dim=-1)
#                 L2dist[gmmIdx] = dist
#         # L2dist = (GMM_mean[:,None] - proto[None,:]).pow(2).mean(dim=-1) # [10,10,512] --> [10,10]
#         closest = L2dist.argmin(dim=1)
#         for GMMidx, c in enumerate(closest):
#             GMM2cls[GMMidx] = c.item()
#
#         sortedGMM = torch.zeros_like(proto)
#         for gmmIdx, protoIdx in GMM2cls.items():
#             sortedGMM[protoIdx] = GMM_mean[gmmIdx]
#
#         if draw:
#             print("DrawTSNE in Dist2Proto")
#             # drawTSNE2(GMM_mean, proto, name1='GMM_mean', name2='prototype')
#             drawTSNE3(sortedGMM, proto, embedding, label1 = torch.arange(10), label2 = torch.arange(10), label3=label,
#                       name1='GMM_mean', name2='prototype', name3='embedding')
#
#         return GMM2cls
#
#
#     def L2_GMM(self, pi1, mean1, logvar1, pi2, mean2, logvar2):
#         nC = len(pi1)
#
#         power = 0
#         for i in range(nC):
#             for j in range(nC):
#                 pi_1i, mean_1i, logvar_1i = pi1[i], mean1[i], logvar1[i]
#                 pi_1j, mean_1j, logvar_1j = pi1[j], mean1[j], logvar1[j]
#                 pi_2i, mean_2i, logvar_2i = pi2[i], mean2[i], logvar2[i]
#                 pi_2j, mean_2j, logvar_2j = pi2[j], mean2[j], logvar2[j]
#                 logvar_1i_1j = torch.stack([logvar_1i, logvar_1j], dim=0).logsumexp(dim=0)
#                 logvar_2i_2j = torch.stack([logvar_2i, logvar_2j], dim=0).logsumexp(dim=0)
#                 logvar_1i_2i = torch.stack([logvar_1i, logvar_2i], dim=0).logsumexp(dim=0)
#                 log_likelihoods_1i_1j = self.gaussian_log_prob(mean_1i, mean_1j, logvar_1i_1j)
#                 log_likelihoods_2i_2j = self.gaussian_log_prob(mean_2i, mean_2j, logvar_2i_2j)
#                 log_likelihoods_1i_2i = self.gaussian_log_prob(mean_1i, mean_2i, logvar_1i_2i)
#                 power += pi_1i * pi_1j * log_likelihoods_1i_1j.exp()
#                 power += pi_2i * pi_2j * log_likelihoods_2i_2j.exp()
#                 power -= 2 * (pi_1i * pi_2i * log_likelihoods_1i_2i.exp())
#
#         out = power.sqrt().log()
#         return out
#
#
#
# def Make_GMM_Model(args):
#     GMM_model = GMM_Model(unsupervised_em_iters = args.unsupervised_em_iters,
#                           semisupervised_em_iters = args.semisupervised_em_iters,
#                           fix_pi = args.fix_pi,
#                           hidden_size = args.hidden_size,
#                           component_size = args.component_size,
#                           latent_size = args.latent_size,
#                           train_mc_sample_size = args.train_mc_sample_size,
#                           test_mc_sample_size = args.test_mc_sample_size)
#     return GMM_model
#
