import torchvision.transforms as T
from collections import OrderedDict
import copy
from functions import *
from scipy.optimize import linear_sum_assignment

def set_requires_grad(model, requires_grad=True):
    """Freeze or unfreeze every parameter of ``model`` in place.

    Args:
        model: module whose parameters (including those of submodules) are updated.
        requires_grad: value assigned to each parameter's ``requires_grad`` flag.
    """
    for _name, weight in model.named_parameters():
        weight.requires_grad = requires_grad

def test(models, epoch, method, dataloaders, mode='val'):
    """Return the top-1 accuracy (in percent) of ``models['backbone']`` on one split.

    Args:
        models: dict of modules. ``'backbone'`` must return a 3-tuple whose first
            item holds per-class scores. ``'module'`` is also switched to eval
            mode when ``method == 'lloss'``.
        epoch: not used.
        method: training-method tag; only ``'lloss'`` is special-cased.
        dataloaders: dict mapping split names to iterables of ``(inputs, labels)``.
        mode: split to evaluate, ``'val'`` or ``'test'``.

    The evaluated modules are left in eval mode. Raises ``ZeroDivisionError``
    if the split yields no samples.
    """
    assert mode in ('val', 'test')
    backbone = models['backbone']
    backbone.eval()
    if method == 'lloss':
        models['module'].eval()

    n_correct = n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloaders[mode]:
            batch_inputs, batch_labels = batch_inputs.cuda(), batch_labels.cuda()
            logits, _, _ = backbone(batch_inputs)
            predicted = torch.max(logits.data, 1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_seen

def GMM_train(models, GMM_model, optimizers, schedulers, dataloaders, num_epochs, **kwargs):
    """Train the backbone for ``num_epochs`` epochs, then report test accuracy.

    Each epoch runs ``GMM_Semi_train_epoch_NIID`` and steps
    ``schedulers['backbone']`` once. After the final epoch the backbone is
    evaluated on ``dataloaders['test']`` in eval mode and the sample-weighted
    accuracy (a fraction in [0, 1]) is printed. The backbone is put back into
    training mode afterwards.

    Args:
        models: dict with ``'backbone'`` and ``'critic'`` modules.
        GMM_model: GMM helper forwarded to the epoch function.
        optimizers: dict holding the ``'backbone'`` optimizer.
        schedulers: dict holding the ``'backbone'`` LR scheduler.
        dataloaders: must provide ``'test'``, plus whatever the epoch function
            reads (``'train'``).
        num_epochs: number of training epochs.
        **kwargs: forwarded to ``GMM_Semi_train_epoch_NIID``.
    """
    print('>> Train a Model.')
    backbone = models['backbone']
    with tqdm(total=num_epochs) as pbar:
        for epoch in range(num_epochs):
            backbone.train()
            models['critic'].train()

            lossDict, tr_acc, _ = GMM_Semi_train_epoch_NIID(models, GMM_model, optimizers, dataloaders, **kwargs)
            schedulers['backbone'].step()
            # Read the LR after the scheduler step, i.e. the LR of the next epoch.
            lr = optimizers['backbone'].param_groups[0]['lr']
            lossPostfix = OrderedDict((k, torch.tensor(v).float().mean()) for k, v in lossDict.items())
            trAccPostfix = OrderedDict({'trAcc': '{0:.4f}'.format(tr_acc), 'LR': '{0:.4f}'.format(lr)})
            postfix = OrderedDict(list(lossPostfix.items()) + list(trAccPostfix.items()))
            pbar.set_postfix(**postfix)
            pbar.update(1)

            if epoch == num_epochs - 1:
                ############################ TEST ##############################
                # Eval mode: BatchNorm must use (not update) its running statistics
                # and dropout must be off while scoring test batches.
                backbone.eval()
                correct, total = 0, 0
                with torch.no_grad():
                    for data in dataloaders['test']:
                        inputs, label = data[0].cuda(), data[1].cuda()
                        scores, _, _ = backbone(inputs)
                        correct += (scores.argmax(dim=1) == label).sum().item()
                        total += label.size(0)
                backbone.train()  # restore the mode the model was trained in
                # Sample-weighted: a mean of per-batch accuracies over-weights a short last batch.
                mean_acc = correct / total if total else float('nan')
                print(f'Mean_acc : {mean_acc}')

    print('>> Finished.')

def MakeClsPreLoader(dataset, data_train, labeled_set):
    """Build the per-class sampler used to summarize the labeled set.

    Args:
        dataset: dataset *name* string. Names containing ``'im'`` select
            ``ClsPreLoaderV3``; all others select ``ClsPreLoader``.
        data_train: dataset object exposing ``.data`` (indexable by
            ``labeled_set``) and ``.targets``.
        labeled_set: indices of the labeled samples within ``data_train``.

    Returns:
        A ``ClsPreLoaderV3`` or ``ClsPreLoader`` instance.
    """
    loader_cls = ClsPreLoaderV3 if 'im' in dataset else ClsPreLoader
    # The loaders receive the dataset *object* positionally, so the name has to be
    # forwarded explicitly for them to choose matching normalization statistics.
    return loader_cls(data_train, labeled_set, dataset_name=dataset)


def _cifar_train_transform(dataset_name):
    """Return the train-time flip/crop/normalize pipeline for 32x32 RGB images.

    CIFAR-10 statistics are used when ``dataset_name == 'cifar10'``; any other
    value (including ``None``) gets CIFAR-100 statistics.
    """
    if dataset_name == 'cifar10':
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)  # CIFAR-10
    else:
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)  # CIFAR-100
    return T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(size=32, padding=4),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])


def _draw_from_pool(pool, k, take_all):
    """Pick indices for one summary chunk from the set ``pool``.

    With ``take_all`` every remaining index is returned (final chunk) and
    ``pool`` is left untouched. Otherwise up to ``k`` indices are drawn without
    replacement and removed from ``pool`` in place. ``k`` is clamped to the pool
    size, so a pool that empties early yields an empty draw instead of a NumPy
    ``ValueError``. Always returns an int64 array, which stays a valid tensor
    index even when empty.
    """
    if take_all or not pool:
        return np.array(list(pool), dtype=np.int64)
    picked = np.random.choice(list(pool), min(k, len(pool)), replace=False).astype(np.int64)
    pool -= set(picked.tolist())
    return picked


class ClsPreLoaderV3:
    """Class-balanced sampler over pre-augmented labeled images.

    Selected by ``MakeClsPreLoader`` for dataset names containing ``'im'``.
    Every labeled image is augmented once at construction, so each image's
    augmentation stays fixed for the lifetime of the object. Classes with
    fewer than ``minNum`` images are padded with extra, independently augmented
    copies of randomly chosen images of that class.

    Assumes class labels are ``0..nC-1`` and every class occurs in
    ``labeled_set`` — TODO confirm against callers.
    """

    def __init__(self, dataset, labeled_set, minNum=100, dataset_name=None):
        self.transform = _cifar_train_transform(dataset_name)

        self.dataset = dataset
        self.labeled_set = labeled_set
        self.data = dataset.data[labeled_set]
        self.targets = torch.tensor(dataset.targets)[labeled_set]
        self.nC = len(self.targets.unique())
        self.PILed = list(map(Image.fromarray, self.data))
        self.transed = torch.stack(list(map(self.transform, self.PILed)))
        self.clsIdxList = [(self.targets == i).nonzero()[:, 0] for i in range(self.nC)]

        # Per-class pool of augmented images, padded up to minNum.
        self.augTransed = {cls: self.transed[self.clsIdxList[cls]] for cls in range(self.nC)}
        for cls in range(self.nC):
            deficit = minNum - len(self.augTransed[cls])
            if deficit > 0:
                repeated = np.random.choice(self.clsIdxList[cls], deficit, replace=True)
                extra = torch.stack([self.transform(self.PILed[x]) for x in repeated])
                # Padding copies are appended after the real images of the class.
                self.augTransed[cls] = torch.cat((self.augTransed[cls], extra))

        self.augSize = sum(len(v) for v in self.augTransed.values())
        # Share of the padded pool per class; augRatio holds the same values.
        self.clsRatio = [len(self.augTransed[cls]) / self.augSize for cls in range(self.nC)]
        self.augRatio = list(self.clsRatio)

    def Sample(self, n):
        """Draw ``n`` distinct pool images from every class.

        Returns images of shape ``(nC * n, 3, 32, 32)`` grouped by class and the
        matching labels of shape ``(nC * n,)``. NumPy raises ``ValueError`` if a
        class pool holds fewer than ``n`` images.
        """
        imgBuffer, tarBuffer = [], []
        for cls in range(self.nC):
            clsIdx = np.random.choice(len(self.augTransed[cls]), n, replace=False)
            imgBuffer.append(self.augTransed[cls][clsIdx])
            tarBuffer.append(torch.tensor([cls] * len(clsIdx)))
        return torch.stack(imgBuffer).view(-1, 3, 32, 32), torch.stack(tarBuffer).view(-1)

    def SummaryInit(self, quotient=1000):
        """Prepare one exhaustive pass over the padded pools in chunks of about ``quotient`` images."""
        self.nSample = (torch.tensor(self.clsRatio) * quotient).int()
        total = sum(len(v) for v in self.augTransed.values())
        # At least one chunk: with fewer than `quotient` images int() gives 0 and the
        # final chunk (curIter == summaryMaxIter - 1) would never be reached.
        self.summaryMaxIter = max(1, int(total / quotient))
        self.curIter = 0
        self.unSampled = [set(range(len(v))) for v in self.augTransed.values()]

    def SummarySample(self):
        """Return the next ``(images, labels, isLast)`` chunk of the summary pass.

        Non-final chunks take ``nSample[cls]`` not-yet-seen images per class; the
        final chunk returns every image still unseen.
        """
        isLast = self.curIter == self.summaryMaxIter - 1
        imgBuffer, tarBuffer = [], []
        for cls in range(self.nC):
            clsIdx = _draw_from_pool(self.unSampled[cls], int(self.nSample[cls]), isLast)
            imgBuffer.append(self.augTransed[cls][clsIdx])
            tarBuffer.append(torch.full((len(clsIdx),), cls, dtype=torch.long))
        self.curIter += 1
        return torch.cat(imgBuffer, 0), torch.cat(tarBuffer, 0), isLast


class ClsPreLoader:
    """Class-balanced sampler over labeled images augmented once at construction.

    Classes with fewer than ``minNum`` images are padded by repeating indices,
    reusing the same augmented tensors. ``Sample`` can therefore return
    duplicate images. ``clsRatio`` is relative to the *unpadded* labeled-set
    size and may sum to more than 1.

    Assumes class labels are ``0..nC-1`` and every class occurs in
    ``labeled_set`` — TODO confirm against callers.
    """

    def __init__(self, dataset, labeled_set, minNum=10, dataset_name=None):
        self.transform = _cifar_train_transform(dataset_name)

        self.dataset = dataset
        self.labeled_set = labeled_set
        self.data = dataset.data[labeled_set]
        self.targets = torch.tensor(dataset.targets)[labeled_set]
        self.nC = len(self.targets.unique())
        self.PILed = list(map(Image.fromarray, self.data))
        self.transed = torch.stack(list(map(self.transform, self.PILed)))
        self.clsIdxList = [(self.targets == i).nonzero()[:, 0] for i in range(self.nC)]

        for cls in range(self.nC):
            deficit = minNum - len(self.clsIdxList[cls])
            if deficit > 0:
                repeated = torch.tensor(np.random.choice(self.clsIdxList[cls], deficit, replace=True))
                self.clsIdxList[cls] = torch.cat((self.clsIdxList[cls], repeated))

        self.clsRatio = [len(idx) / len(labeled_set) for idx in self.clsIdxList]

    def Sample(self, n):
        """Draw ``n`` entries from every class's (possibly padded) index list.

        Returns images ``(nC * n, 3, 32, 32)`` and labels ``(nC * n,)``.
        """
        imgBuffer, tarBuffer = [], []
        for cls in range(self.nC):
            clsIdx = np.random.choice(self.clsIdxList[cls], n, replace=False)
            imgBuffer.append(self.transed[clsIdx])
            tarBuffer.append(self.targets[clsIdx])
        return torch.stack(imgBuffer).view(-1, 3, 32, 32), torch.stack(tarBuffer).view(-1)

    def SummaryInit(self, quotient=1000):
        """Prepare one exhaustive pass over the labeled images in chunks of about ``quotient``."""
        self.nSample = (torch.tensor(self.clsRatio) * quotient).int()
        # At least one chunk, otherwise a labeled set smaller than `quotient` never ends.
        self.summaryMaxIter = max(1, int(len(self.labeled_set) / quotient))
        self.curIter = 0
        # set() collapses the padding duplicates: each labeled image is summarized once.
        self.unSampled = [set(idx.tolist()) for idx in self.clsIdxList]

    def SummarySample(self):
        """Return the next ``(images, labels, isLast)`` chunk of the summary pass.

        Non-final chunks take up to ``nSample[cls]`` unseen images per class.
        Padded classes may run dry early and contribute nothing. The final
        chunk returns every image still unseen.
        """
        isLast = self.curIter == self.summaryMaxIter - 1
        imgBuffer, tarBuffer = [], []
        for cls in range(self.nC):
            clsIdx = _draw_from_pool(self.unSampled[cls], int(self.nSample[cls]), isLast)
            imgBuffer.append(self.transed[clsIdx])
            tarBuffer.append(self.targets[clsIdx])
        self.curIter += 1
        return torch.cat(imgBuffer, 0), torch.cat(tarBuffer, 0), isLast

def GMM_Semi_train_epoch_NIID(models, GMM_model, optimizers, dataloaders, **kwargs):
    """Train ``models['backbone']`` for one pass over ``dataloaders['train']``.

    For every batch, ``kwargs['ClsIdx'].Sample(10)`` provides reference images.
    Their features, computed without gradients, fit a supervised GMM prior via
    ``GMM_model``. The loss is ``CE + 1e-5 * NLL`` of the batch labels under the
    GMM posterior. If any GMM log-likelihood is non-finite, a message is printed
    and only the CE loss is back-propagated for that batch.

    Returns:
        ``(lossDict, acc, lr)``:
        - ``lossDict`` holds per-batch ``'CEloss'``, ``'GMMloss'`` and ``'GMMACC'``
          values; ``'KLloss'`` is never filled here.
        - ``acc`` is the mean per-batch accuracy in percent, as a 0-dim tensor.
        - ``lr`` is the backbone optimizer's current learning rate.
    """
    backbone = models['backbone']
    optimizer = optimizers['backbone']
    sampler = kwargs['ClsIdx']
    history = {'CEloss': [], 'GMMloss': [], 'KLloss': [], 'GMMACC': []}
    batch_accs = []
    backbone.train()

    for batch in dataloaders['train']:
        images = batch[0].cuda()
        targets = batch[1].cuda()
        n_batch, n_draws = len(images), 1
        optimizer.zero_grad()

        ref_imgs, ref_labels = sampler.Sample(10)
        with torch.no_grad():
            ref_imgs = ref_imgs.cuda()
            ref_labels = ref_labels.cuda()
            _, ref_feat, _ = backbone(ref_imgs)
            ref_proto = GMM_model.GetPrototype(ref_feat, ref_labels)
            _pi, prior_mean, prior_logvar = GMM_model.get_supervised_prior(
                z=ref_feat, label=ref_labels, init_mean=ref_proto,
                fixvar=kwargs['fix_var'], L1var=False)
            logvar_grid = prior_logvar[None, None, :, :].repeat(n_batch, n_draws, 1, 1)

        scores, embedding, _ = backbone(images)
        n_comp = GMM_model.component_size
        z_grid = embedding.unsqueeze(dim=1)[:, :, None, :].repeat(1, 1, n_comp, 1)
        mean_grid = prior_mean[None, None, :, :].repeat(n_batch, n_draws, 1, 1)
        log_lik = GMM_model.gaussian_log_prob(z_grid, mean_grid, logvar_grid, meanC=False)

        # Log posterior over components, averaged over the (single) sample axis.
        log_post = (log_lik - torch.logsumexp(log_lik, dim=-1, keepdim=True)).mean(dim=1)
        gmm_loss = F.nll_loss(log_post, targets)
        ce_loss = F.cross_entropy(scores, targets)
        total_loss = ce_loss + 0.00001 * gmm_loss

        if torch.isfinite(log_lik).all():
            total_loss.backward()
        else:
            print(f"DURING TRAINING: handle exception when there is -inf, +inf, nan in log_likelihoods.")
            ce_loss.backward()
        optimizer.step()

        history['CEloss'].append(ce_loss.item())
        history['GMMloss'].append(gmm_loss.item())
        batch_accs.append((scores.argmax(dim=-1) == targets).float().mean())
        history['GMMACC'].append((log_post.argmax(dim=-1) == targets).float().mean().item())

    return history, torch.tensor(batch_accs).mean() * 100, optimizer.param_groups[0]['lr']

def _semi_summarize_labeled(backbone, GMM_model, classSampler, fix_var):
    """Fit the supervised GMM prior chunk-by-chunk over the labeled sampler.

    Prints the sample-weighted linear-head accuracy over all chunks and returns
    the chunk-averaged ``(pi, mean, var)``. The caller is responsible for eval
    mode and ``torch.no_grad()``.
    """
    classSampler.SummaryInit(quotient=1000)
    correct, total = 0, 0
    pis, means, variances = [], [], []
    while True:
        imgs, labels, isLast = classSampler.SummarySample()
        imgs, labels = imgs.cuda(), labels.cuda()
        scores, feats, _ = backbone(imgs)
        correct += (scores.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
        protos = GMM_model.GetPrototype(feats, labels)
        pi, mean, var = GMM_model.get_supervised_prior(z=feats, label=labels, init_mean=protos,
                                                       fixvar=fix_var, L1var=False)
        pis.append(pi)
        means.append(mean)
        variances.append(var)
        if isLast:
            break
    acc = correct / total if total else float('nan')
    print(f'\nSummarized labeled set.. Mean_acc : {acc}')
    return torch.stack(pis).mean(dim=0), torch.stack(means).mean(dim=0), torch.stack(variances).mean(dim=0)


def _semi_summarize_unlabeled(backbone, GMM_model, loader, L_mean, L_var, kwargs):
    """Fit an unsupervised GMM per batch and align it to the labeled prior.

    Components are matched to the labeled prior by a minimum-cost assignment on
    ``BDMatrix``. Prints the sample-weighted accuracy and returns the
    batch-averaged aligned ``(pi, mean, var)``. The caller is responsible for
    eval mode and ``torch.no_grad()``.
    """
    K = GMM_model.component_size
    correct, total = 0, 0
    pis, means, variances = [], [], []
    for data in loader:
        inputs, label = data[0].cuda(), data[1].cuda()
        scores, embedding, _ = backbone(inputs)
        correct += (scores.argmax(dim=1) == label).sum().item()
        total += label.size(0)
        pi, mean, var = GMM_model.get_unsupervised_prior(z=embedding, init_mean=L_mean,
                                                         fixvar=kwargs['fix_var'], iter=10, **kwargs)
        _, col_match = linear_sum_assignment(BDMatrix(mean, var, L_mean, L_var).cpu().numpy())
        # Reorder the batch components so that position x holds the one matched to
        # labeled component x.
        alignIdx = {col_match[idx]: idx for idx in range(K)}
        order = [alignIdx[x] for x in range(K)]
        pis.append(torch.stack([pi[i] for i in order]))
        means.append(torch.stack([mean[i] for i in order]))
        variances.append(torch.stack([var[i] for i in order]))
    acc = correct / total if total else float('nan')
    print(f'Summarized Unlabeled set.. Mean_acc : {acc}')
    return torch.stack(pis).mean(dim=0), torch.stack(means).mean(dim=0), torch.stack(variances).mean(dim=0)


def _semi_report_test_accuracy(backbone, loader):
    """Print the sample-weighted test accuracy. The caller handles eval mode and no_grad."""
    correct, total = 0, 0
    for data in loader:
        inputs, label = data[0].cuda(), data[1].cuda()
        scores, _, _ = backbone(inputs)
        correct += (scores.argmax(dim=1) == label).sum().item()
        total += label.size(0)
    acc = correct / total if total else float('nan')
    print(f'Test Mean_acc : {acc}')


def GMM_Semi_train(models, GMM_model, method, optimizers, schedulers, dataloaders, num_epochs, **kwargs):
    """Adversarial semi-supervised training guided by labeled/unlabeled GMM priors.

    Runs a fixed 10000 iterations; ``method`` and ``num_epochs`` are accepted for
    interface compatibility but unused. Every 500 iterations, and on the final
    one, the labeled prior ``psi_xl`` and the aligned unlabeled prior
    ``psi_xul`` are re-estimated and test accuracy is printed. These passes run
    with the backbone in eval mode. Each iteration then performs one critic +
    backbone update via ``GMM_Semi_train_epoch`` and steps both LR schedulers.

    Returns:
        ``(L_pi, L_mean, L_var, UL_pi, UL_mean, UL_var)`` from the last refresh.
    """
    print('>> Train a Model.')
    labeled_data = read_data2(dataloaders['train'])
    unlabeled_data = read_data2(dataloaders['train_UL'])
    classSampler = kwargs['ClsIdx']
    backbone = models['backbone']
    train_iterations = 10000
    refresh_every = 500
    trAccList, trCAccList, trUAccList, labelGList, unlabelGList = [], [], [], [], []

    with tqdm(total=train_iterations) as pbar:
        for iter_count in range(train_iterations):
            if iter_count % refresh_every == 0 or iter_count == train_iterations - 1:
                # Inference-only passes: eval mode so BatchNorm uses (and does not update)
                # its running statistics -- in particular not from test batches.
                backbone.eval()
                with torch.no_grad():
                    psi_xl = _semi_summarize_labeled(backbone, GMM_model, classSampler, kwargs['fix_var'])
                    psi_xul = _semi_summarize_unlabeled(backbone, GMM_model, dataloaders['summary_UL'],
                                                        psi_xl[1], psi_xl[2], kwargs)
                    _semi_report_test_accuracy(backbone, dataloaders['test'])
                backbone.train()

            if iter_count % refresh_every == 0:
                trAccList, trCAccList, trUAccList, labelGList, unlabelGList = [], [], [], [], []

            labeled_imgs, labeled_labels = next(labeled_data)
            unlabeled_imgs, unlabeled_labels = next(unlabeled_data)

            lossDict, tr_acc, tr_cAcc, tr_uAcc, label_g, unlabel_g = GMM_Semi_train_epoch(
                models, GMM_model, optimizers, labeled_imgs, labeled_labels, unlabeled_imgs,
                unlabeled_labels, psi_xl, psi_xul, **kwargs)
            # Step the schedulers after the optimizers (PyTorch >= 1.1 ordering);
            # stepping them first skips the first value of each LR schedule.
            schedulers["backbone"].step()
            schedulers["critic"].step()

            trAccList.append(tr_acc); trCAccList.append(tr_cAcc); trUAccList.append(tr_uAcc)
            labelGList.append(label_g); unlabelGList.append(unlabel_g)

            lossPostfix = OrderedDict((k, torch.tensor(v).float().mean()) for k, v in lossDict.items())
            trAccPostfix = OrderedDict({'trAcc': Lmean(trAccList), 'trCAcc': Lmean(trCAccList), 'trUAcc': Lmean(trUAccList)})
            gPostfix = OrderedDict({'lb_g': '{0:.4f}'.format(label_g), 'unlb_g': '{0:.4f}'.format(unlabel_g)})
            lrPostfix = OrderedDict({'bkbn': optimizers["backbone"].param_groups[0]["lr"],
                                     'critic': optimizers["critic"].param_groups[0]["lr"]})
            postfix = OrderedDict(list(lossPostfix.items()) + list(trAccPostfix.items())
                                  + list(gPostfix.items()) + list(lrPostfix.items()))
            pbar.set_postfix(**postfix)
            pbar.update(1)

    return psi_xl + psi_xul

def EXTRACT_GMM_prev(models, GMM_model, dataloaders, **kwargs):
    """Estimate labeled and unlabeled GMM priors and report test accuracy.

    Steps, all with the backbone and critic in eval mode (left that way on return):

    1. Labeled prior. For ``'im'`` datasets it comes from the class sampler
       ``kwargs['ClsIdx']``; otherwise from ``dataloaders['summary_L']``. It is
       averaged over chunks/batches.
    2. Unlabeled prior. Fitted per ``summary_UL`` batch, optionally
       initialized from the labeled means (``HintUL``). Each batch's components
       are reordered to match the labeled components by minimum-cost
       assignment on ``BDMatrix``, then averaged.
    3. Test accuracy on ``dataloaders['test']``.

    Returns:
        ``(L_pi, L_mean, L_var, UL_pi, UL_mean, UL_var, test_acc)`` where
        ``test_acc`` is the sample-weighted fraction of correctly classified
        test samples (NaN for an empty loader).
    """
    print('>> EXTRACT GMM PARAMETERS')
    print(f">> update_Lpi = {kwargs['update_Lpi']}, update_ULpi = {kwargs['update_ULpi']}, HintUL = {kwargs['HintUL']}")

    backbone = models['backbone']
    backbone.eval()
    models['critic'].eval()

    def accuracy(correct, total):
        # Sample-weighted, so a short final batch is not over-weighted.
        return correct / total if total else float('nan')

    with torch.no_grad():
        ######################## Summarize labeled set ########################
        pis, means, variances = [], [], []
        correct, total = 0, 0
        if 'im' in kwargs['dataset']:
            classSampler = kwargs['ClsIdx']
            classSampler.SummaryInit(quotient=1000)
            while True:
                SampledImg, SampledLabel, isLast = classSampler.SummarySample()
                SampledImg, SampledLabel = SampledImg.cuda(), SampledLabel.cuda()
                scores, SampledFeat, _ = backbone(SampledImg)
                correct += (scores.argmax(dim=1) == SampledLabel).sum().item()
                total += SampledLabel.size(0)
                SampledProto = GMM_model.GetPrototype(SampledFeat, SampledLabel)
                GMM_pi, GMM_mean, GMM_var = GMM_model.get_supervised_prior(
                    z=SampledFeat, label=SampledLabel, init_mean=SampledProto,
                    fixvar=kwargs['fix_var'], L1var=False)
                pis.append(GMM_pi)
                means.append(GMM_mean)
                variances.append(GMM_var)
                if isLast:
                    break
        else:
            for data in dataloaders['summary_L']:
                inputs, label = data[0].cuda(), data[1].cuda()
                scores, embedding, _ = backbone(inputs)
                correct += (scores.argmax(dim=1) == label).sum().item()
                total += label.size(0)
                SampledProto = GMM_model.GetPrototype(embedding, label)
                GMM_pi, GMM_mean, GMM_var = GMM_model.get_supervised_prior(
                    z=embedding, label=label, init_mean=SampledProto, fixvar=kwargs['fix_var'])
                pis.append(GMM_pi)
                means.append(GMM_mean)
                variances.append(GMM_var)
        print(f'Summarized labeled set.. Mean_acc : {accuracy(correct, total)}')
        L_ALL_GMM_pi = torch.stack(pis).mean(dim=0)
        L_ALL_GMM_mean = torch.stack(means).mean(dim=0)
        L_ALL_GMM_var = torch.stack(variances).mean(dim=0)

        ######################## Summarize unlabeled set ########################
        # Loop invariants, evaluated once instead of per batch.
        hint = L_ALL_GMM_mean if inDict(kwargs, 'HintUL') else None
        update_ULpi = bool(inDict(kwargs, 'update_ULpi'))
        K = GMM_model.component_size
        pis, means, variances = [], [], []
        correct, total = 0, 0
        for data in dataloaders['summary_UL']:
            inputs, label = data[0].cuda(), data[1].cuda()
            scores, embedding, _ = backbone(inputs)
            correct += (scores.argmax(dim=1) == label).sum().item()
            total += label.size(0)

            GMM_pi, GMM_mean, GMM_var = GMM_model.get_unsupervised_prior(
                z=embedding, init_mean=hint, fixvar=kwargs['fix_var'],
                update_pi=update_ULpi, iter=20, **kwargs)
            # Rows of BDMatrix index this batch's components, columns the labeled ones.
            _, col_match = linear_sum_assignment(
                BDMatrix(GMM_mean, GMM_var, L_ALL_GMM_mean, L_ALL_GMM_var).cpu().numpy())
            # Position x of the aligned tensors holds the component matched to labeled component x.
            alignIdx = {col_match[idx]: idx for idx in range(K)}
            order = [alignIdx[x] for x in range(K)]
            pis.append(torch.stack([GMM_pi[i] for i in order]))
            means.append(torch.stack([GMM_mean[i] for i in order]))
            variances.append(torch.stack([GMM_var[i] for i in order]))
        print(f'Summarized Unlabeled set.. Mean_acc : {accuracy(correct, total)}')
        Ul_ALL_GMM_pi = torch.stack(pis).mean(dim=0)
        Ul_ALL_GMM_mean = torch.stack(means).mean(dim=0)
        Ul_ALL_GMM_var = torch.stack(variances).mean(dim=0)

        ######################## Test accuracy ########################
        correct, total = 0, 0
        for data in dataloaders['test']:
            inputs, label = data[0].cuda(), data[1].cuda()
            scores, _, _ = backbone(inputs)
            correct += (scores.argmax(dim=1) == label).sum().item()
            total += label.size(0)
        mean_linear_acc = accuracy(correct, total)
        print(f'Test Mean_acc : {mean_linear_acc}')

    print('>> EXTRACTION FINISHED.')
    return L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var, Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var, mean_linear_acc

def GMM_Semi_train_epoch(models, GMM_model, optimizers, labeled_imgs, labels, unlabeled_imgs,
                               unlabeled_labels, psi_xl, psi_xul, **kwargs):
    """Run one adversarial update on a single labeled + unlabeled batch pair.

    1. Critic step (backbone frozen): minimise
       ``mean(critic(labeled)) - mean(critic(unlabeled)) + gradient_penalty``.
    2. Backbone step (critic frozen): minimise
       ``CE(labeled) - (mean(critic(labeled)) - mean(critic(unlabeled)))``.

    Critic inputs are built by ``GetDistInfo`` from the GMM priors
    ``psi_xl`` / ``psi_xul`` and the embeddings of the current forward pass.
    ``kwargs`` is accepted for call-site compatibility and unused.

    Returns:
        ``(lossDict, acc, cAcc, uAcc, label_g, unlabel_g)``:
        - ``lossDict`` maps ``'CEloss'`` and ``'Wloss'`` (the critic-step
          estimate) to one-element lists.
        - ``acc`` / ``uAcc`` are labeled / unlabeled batch accuracy in percent.
        - ``cAcc`` is the critic's accuracy in percent: labeled scored < 0.5,
          unlabeled >= 0.5.
        - ``label_g`` / ``unlabel_g`` are the mean critic outputs from the
          backbone step.
        All numeric values are 0-dim CPU tensors.
    """
    backbone, critic = models['backbone'], models['critic']
    backbone.train()
    critic.train()

    label_x, label_y = labeled_imgs.cuda(), labels.cuda()
    unlabel_x, unlabel_y = unlabeled_imgs.cuda(), unlabeled_labels.cuda()

    def forward_for_critic():
        """Forward both batches and build the critic inputs from these same embeddings."""
        l_score, l_emb, _ = backbone(label_x)
        u_score, u_emb, _ = backbone(unlabel_x)
        l_cat = GetDistInfo(psi_xl, psi_xul, l_emb.unsqueeze(dim=1), GMM_model, l_emb, u_emb)
        u_cat = GetDistInfo(psi_xl, psi_xul, u_emb.unsqueeze(dim=1), GMM_model, l_emb, u_emb)
        return l_score, u_score, l_cat, u_cat

    # Update step for critic #
    set_requires_grad(backbone, requires_grad=False)
    set_requires_grad(critic, requires_grad=True)
    optimizers['critic'].zero_grad()

    _, _, label_cat, unlabel_cat = forward_for_critic()
    label_g, unlabel_g = critic(label_cat), critic(unlabel_cat)
    W_loss = label_g.mean() - unlabel_g.mean()
    GP = gradient_penalty(critic, label_cat, unlabel_cat)
    (W_loss + GP).backward()
    optimizers['critic'].step()

    # Update step for feature extractor #
    set_requires_grad(backbone, requires_grad=True)
    set_requires_grad(critic, requires_grad=False)
    optimizers['backbone'].zero_grad()

    # Rebuild the z-dependent critic inputs from the fresh embeddings: reusing the
    # z computed in the frozen-backbone pass would cut the gradient path through z.
    label_score, unlabel_score, label_cat, unlabel_cat = forward_for_critic()
    label_g, unlabel_g = critic(label_cat), critic(unlabel_cat)
    target_loss = F.cross_entropy(label_score, label_y)
    W_loss1 = label_g.mean() - unlabel_g.mean()
    (target_loss - W_loss1).backward()
    optimizers['backbone'].step()

    # Record Informations #
    lossDict = {'CEloss': [target_loss.item()], 'Wloss': [W_loss.item()]}
    acc = (label_score.argmax(dim=-1) == label_y).float().mean().item()
    uAcc = (unlabel_score.argmax(dim=-1) == unlabel_y).float().mean().item()
    n_critic_correct = int((label_g < 0.5).sum()) + int((unlabel_g >= 0.5).sum())
    cAcc = n_critic_correct / (len(label_g) + len(unlabel_g))

    return (lossDict, torch.tensor(acc) * 100, torch.tensor(cAcc) * 100, torch.tensor(uAcc) * 100,
            torch.tensor(label_g.mean().item()), torch.tensor(unlabel_g.mean().item()))
