#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

from email.mime import image
import enum
from operator import itemgetter
import re
from typing import Counter
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import copy
import numpy as np
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable
from sklearn.mixture import GaussianMixture
import torch.nn as nn


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the sample indices in ``idxs``.

    Every item is ``(image, label, original_index)``: the first two are
    tensors, the last is the position of the sample in the wrapped dataset.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise to plain ints so numpy / tensor index types are accepted.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _as_new_tensor(value):
        """Return a detached tensor copy of ``value``.

        ``torch.tensor(t)`` on an existing tensor emits a UserWarning on
        every call; ``t.clone().detach()`` is the equivalent form that
        PyTorch recommends, so tensors take that path.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        original_index = self.idxs[item]
        image, label = self.dataset[original_index]
        return (self._as_new_tensor(image),
                self._as_new_tensor(label),
                original_index)


class DatasetSplit_DivideMix(Dataset):
    """Labeled or unlabeled half of a DivideMix split.

    Args:
        dataset: wrapped dataset returning ``(image, label)``.
        idxs: dataset indices, aligned position-by-position with ``prob``
            and ``pred``.
        prob: per-sample value that is returned as the fourth item in
            ``'label'`` mode.
        pred: boolean numpy array; ``True`` entries go to the ``'label'``
            split, the rest to the ``'unlabel'`` split.
        mode: ``'label'`` or ``'unlabel'``.

    Raises:
        ValueError: if ``mode`` is neither ``'label'`` nor ``'unlabel'``.
    """

    def __init__(self, dataset, idxs, prob, pred, mode):
        self.dataset = dataset
        self.mode = mode
        if mode == 'label':
            pred_idx = pred.nonzero()[0]
        elif mode == 'unlabel':
            pred_idx = (1-pred).nonzero()[0]
        else:
            # Previously fell through and failed later with an unbound local.
            raise ValueError(
                "mode must be 'label' or 'unlabel', got %r" % (mode,))
        print(mode, len(pred_idx))
        if len(pred_idx) == 0:
            # Keep one placeholder sample so a DataLoader over this split is
            # never empty.
            pred_idx = [0]
        # int() accepts python ints, numpy integers and 0-d tensors alike.
        self.idxs = [int(idxs[i]) for i in pred_idx]
        self.prob = [prob[i] for i in pred_idx]

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _as_new_tensor(value):
        """Detached tensor copy of ``value`` without the copy-construct warning."""
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        sample_index = self.idxs[item]
        # The dataset is read twice on purpose: each read is returned as a
        # separate view (presumably two random augmentations when the dataset
        # applies a stochastic transform -- TODO confirm).
        image1, label = self.dataset[sample_index]
        image2, label = self.dataset[sample_index]
        view1 = self._as_new_tensor(image1)
        view2 = self._as_new_tensor(image2)
        if self.mode == 'label':
            return (view1, view2, self._as_new_tensor(label),
                    self._as_new_tensor(self.prob[item]))
        return view1, view2


class SemiLoss(object):
    """Semi-supervised loss pair used by the DivideMix training step.

    Calling the object returns the supervised term, the unsupervised term
    and the current weight for the unsupervised term; combining them is
    left to the caller.
    """

    def linear_rampup(self, cur_e, warm_up, rampup_length=16, lambda_u=25):
        """Linearly ramp the unsupervised weight from 0 to ``lambda_u``.

        Args:
            cur_e: current (possibly fractional) epoch.
            warm_up: epoch at which the ramp starts.
            rampup_length: number of epochs the ramp lasts.
            lambda_u: weight reached at the end of the ramp (default 25,
                the value that used to be hard-coded).

        Returns:
            ``lambda_u`` times the ramp progress clipped to ``[0, 1]``.
        """
        progress = np.clip((cur_e - warm_up) / rampup_length, 0.0, 1.0)
        return lambda_u * float(progress)

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        """Return ``(Lx, Lu, weight)``.

        ``Lx`` is the mean soft-target cross-entropy of ``outputs_x``
        against ``targets_x``; ``Lu`` is the mean squared error between
        ``softmax(outputs_u)`` and ``targets_u``; ``weight`` comes from
        :meth:`linear_rampup`.
        """
        probs_u = torch.softmax(outputs_u, dim=1)

        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u)**2)

        return Lx, Lu, self.linear_rampup(epoch, warm_up)

# DivideMix dataset
# class Label_dataset(Dataset):
#     def __init__(self) -> None:
#         super().__init__()

#     def __getitem__(self, index: Any) -> T_co:
#         return super().__getitem__(index)

#     def __len__(self):
#         pass


# class Unlabel_datset(Dataset):
#     def __init__(self) -> None:
#         super().__init__()

#     def __getitem__(self, index: Any) -> T_co:
#         return super().__getitem__(index)

#     def __len__(self):
#         pass


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, num_classes, label_cache_dist=None):
        """Build the local data loader and the per-method hyper-parameters.

        Args:
            args: run configuration. This constructor reads ``dataset``,
                ``local_bs``, ``epochs``, ``method``, ``lr``,
                ``noise_ratio`` and ``alpha``.
            dataset: full training set. Must expose ``noise_or_not`` and,
                when ``args.dataset == 'clothing1m'``, ``targets``.
            idxs: indices of ``dataset`` that belong to this client.
            num_classes: number of label classes.
            label_cache_dist: label distribution(s) used by the 'peer'
                method; stored as-is.
        """
        self.args = args
        # np.random.seed(args.seed)
        if args.dataset == 'clothing1m':
            # Class-balance the client data by down-sampling every class to
            # the size of the rarest class present.
            targets = np.array(dataset.targets)[idxs]
            idxs_np = np.array(idxs)
            targets_dist = Counter(targets)
            min_count = min(targets_dist.values())
            balanced = []
            for cls in targets_dist:
                cls_idxs = idxs_np[np.where(targets == cls)[0]]
                np.random.shuffle(cls_idxs)
                balanced.extend(cls_idxs[:min_count])
            idxs = balanced
        self.trainloader = DataLoader(DatasetSplit(dataset, idxs),
                                      batch_size=self.args.local_bs, shuffle=True, num_workers=4)
        self.total_epoch = args.epochs
        self.num_classes = num_classes
        self.method = args.method

        # co-teaching initialization
        self.mom1 = 0.9
        self.mom2 = 0.1
        self.alpha_plan = [self.args.lr] * self.total_epoch
        self.beta1_plan = [self.mom1] * self.total_epoch
        # From epoch 80 on, decay the lr linearly to 0 and switch beta1.
        for i in range(80, self.total_epoch):
            self.alpha_plan[i] = float(
                self.total_epoch - i) / (self.total_epoch - 80) * args.lr
            self.beta1_plan[i] = self.mom2
        self.forget_rate = (args.noise_ratio + 0.2) % 1
        self.rate_schedule = np.ones(self.total_epoch) * self.forget_rate
        # Ramp the forget rate up over the first 10 epochs. The ramp is
        # truncated for shorter runs; the unguarded slice assignment used to
        # raise a broadcast error whenever args.epochs < 10.
        num_gradual = min(10, self.total_epoch)
        self.rate_schedule[:num_gradual] = np.linspace(
            0, self.forget_rate, 10)[:num_gradual]
        self.noise_or_not = dataset.noise_or_not

        # peer
        self.label_cache_dist = label_cache_dist
        self.alpha = self.args.alpha

        # T-revision
        self.revision_lr = 5e-7
        self.n_epoch_estimate = 1
        # NOTE(review): this overwrites the args.epochs value stored above.
        # The schedules were already built from args.epochs; the 'correction'
        # phase switch and the co-teaching progress print read this value
        # (4). Kept as-is -- confirm that it is intended.
        self.total_epoch = 4

        # dividemix
        self.p_thre = 0.5
        self.CE = nn.CrossEntropyLoss(reduction='none')
        self.CEloss = nn.CrossEntropyLoss()
        self.dataset = dataset
        if self.num_classes == 10:
            self.warmup = 10
        else:
            self.warmup = 30

    def update_weights(self, model, model2, cur_epoch, T):
        model.train()
        model2.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.9)
        if self.method == 'peer':
            assert self.label_cache_dist is not None
            if cur_epoch < 50:
                self.alpha = 0
            elif cur_epoch >= 50 and cur_epoch < 100:
                self.alpha = (cur_epoch - 50) * 0.02
            print('alpha', self.alpha)
            for iter in range(self.args.local_ep):
                for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                    images, labels = images.cuda(), labels.cuda()
                    model.zero_grad()
                    output = model(images)
                    # peer_item = self.label_cache_dist[0][0] * \
                    #     self.smooth_softmax(output, torch.tensor(0).cuda())
                    # for i in range(1, self.num_classes):
                    #     peer_item += self.label_cache_dist[0][i] * self.smooth_softmax(
                    #         output, torch.tensor(i).cuda())
                    peer_label = np.random.choice(
                        self.num_classes, len(images), p=self.label_cache_dist[0])
                    peer_label = torch.from_numpy(peer_label).cuda()
                    loss = self.smooth_softmax(
                        output, labels) - self.alpha * self.smooth_softmax(output, peer_label)
                    loss.backward()
                    optimizer.step()
                    # print(f'Loss: {loss.item()}')
            return model.state_dict(), model2.state_dict()
        elif self.method == 'correction':
            if cur_epoch < int(self.total_epoch / 2):
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model.zero_grad()
                        output = model(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer.step()
            # elif cur_epoch == int(self.total_epoch / 2):
            #     T = self.estimate_T(model)
            if cur_epoch >= int(self.total_epoch / 2):
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model.zero_grad()
                        output = model(images)
                        loss = self.smooth_softmax_correction(
                            output, labels, T)
                        # loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer.step()
            return model.state_dict(), model2.state_dict()
        elif self.method == 'coteaching':
            # model2 = copy.deepcopy(model)
            # Co-teaching uses Adam as optimizer
            optimizer1 = torch.optim.Adam(model.parameters(), lr=self.args.lr)
            optimizer2 = torch.optim.Adam(model2.parameters(), lr=self.args.lr)
            for local_epoch in range(self.args.local_ep):
                model.train()
                self.adjust_learning_rate(optimizer1, cur_epoch)
                model2.train()
                self.adjust_learning_rate(optimizer2, cur_epoch)
                model, model2 = self.co_teaching_train(
                    model, optimizer1, model2, optimizer2, cur_epoch)
            return model.state_dict(), model2.state_dict()
        elif self.method == 'revision':
            optimizer_es = torch.optim.SGD(
                model.parameters(), lr=self.args.lr, weight_decay=1e-4)
            optimizer = torch.optim.SGD(
                model.parameters(), lr=self.args.lr, weight_decay=1e-4, momentum=0.9)
            optimizer_revision = torch.optim.Adam(
                model.parameters(), lr=self.revision_lr, weight_decay=1e-4)
            # scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)
            if cur_epoch < 150:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer_es.zero_grad()
                        output = model(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer_es.step()
            # elif cur_epoch == self.n_epoch_estimate:
            #     T = self.revision_fit(model, True)
            #     T = self.revieion_norm(T)
            elif cur_epoch >= 150 and cur_epoch < 290:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer.zero_grad()
                        output = model(images)
                        loss = self.loss_func_reweight(output, T, labels)
                        loss.backward()
                        optimizer.step()
            elif cur_epoch >= 290:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer_revision.zero_grad()
                        output, correction = model(images, revision=True)
                        loss = self.loss_function_revision(
                            output, T, correction, labels)
                        loss.backward()
                        optimizer_revision.step()
            return model.state_dict(), model2.state_dict()
        elif self.method == 'baseline':
            for iter in range(self.args.local_ep):
                for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                    images, labels = images.cuda(), labels.cuda()
                    model.zero_grad()
                    output = model(images)
                    loss = F.cross_entropy(output, labels)
                    loss.backward()
                    optimizer.step()
            return model.state_dict(), model2.state_dict()
        elif self.method == 'dividemix':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.9, weight_decay=5e-4)
            optimizer2 = torch.optim.SGD(model2.parameters(), lr=self.args.lr,
                                         momentum=0.9, weight_decay=5e-4)
            self.criterion = SemiLoss()
            if cur_epoch < self.warmup:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model.zero_grad()
                        output = model(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer.step()

                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model2.zero_grad()
                        output = model2(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer2.step()
                return model.state_dict(), model2.state_dict()
            else:
                for iter in range(self.args.local_ep):
                    size = len(self.trainloader.dataset)
                    losses1 = torch.zeros(size)
                    losses2 = torch.zeros(size)
                    model.eval()
                    model2.eval()
                    i = 0
                    idxs_ = []
                    with torch.no_grad():
                        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                            images, labels = images.cuda(), labels.cuda()
                            output = model(images)
                            loss = F.cross_entropy(
                                output, labels, reduction="none")
                            idxs_.extend(indexes)
                            for l in loss:
                                losses1[i] = l
                                i += 1
                    losses1 = (losses1 - losses1.min()) / \
                        (losses1.max() - losses1.min())
                    input_loss1 = losses1.reshape(-1, 1)
                    gmm1 = GaussianMixture(
                        n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
                    gmm1.fit(input_loss1)
                    prob1 = gmm1.predict_proba(input_loss1)
                    prob1 = prob1[:, gmm1.means_.argmin()]

                    i = 0
                    with torch.no_grad():
                        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                            images, labels = images.cuda(), labels.cuda()
                            output = model2(images)
                            loss = F.cross_entropy(
                                output, labels, reduction="none")
                            for l in loss:
                                losses2[i] = l
                                i += 1
                    losses2 = (losses2 - losses2.min()) / \
                        (losses2.max() - losses2.min())
                    input_loss2 = losses2.reshape(-1, 1)
                    gmm2 = GaussianMixture(
                        n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
                    gmm2.fit(input_loss2)
                    prob2 = gmm2.predict_proba(input_loss2)
                    prob2 = prob2[:, gmm2.means_.argmin()]
                    # prob1, prob2

                    pred1 = (prob1 > self.p_thre)
                    pred2 = (prob2 > self.p_thre)

                    labeled_trainloader = DataLoader(DatasetSplit_DivideMix(
                        self.dataset, idxs_, prob2, pred2, 'label'), batch_size=self.args.local_bs, shuffle=True, num_workers=4)
                    unlabeled_trainloader = DataLoader(DatasetSplit_DivideMix(
                        self.dataset, idxs_, prob2, pred2, 'unlabel'), batch_size=self.args.local_bs, shuffle=True, num_workers=4)
                    self.dm_train(cur_epoch, model, model2, optimizer,
                                labeled_trainloader, unlabeled_trainloader)
                    # print(len(self.dataset), len(idxs_), prob1.shape, pred1.shape)
                    labeled_trainloader = DataLoader(DatasetSplit_DivideMix(
                        self.dataset, idxs_, prob1, pred1, 'label'), batch_size=self.args.local_bs, shuffle=True, num_workers=4)
                    unlabeled_trainloader = DataLoader(DatasetSplit_DivideMix(
                        self.dataset, idxs_, prob1, pred1, 'unlabel'), batch_size=self.args.local_bs, shuffle=True, num_workers=4)
                    self.dm_train(cur_epoch, model, model2, optimizer2,
                                labeled_trainloader, unlabeled_trainloader)
                    # for batch_idx, (images, images1, labels) in enumerate(labeled_trainloader):
                    #     print(images.shape)
                return model.state_dict(), model2.state_dict()

    def dm_train(self, epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader):
        """Train ``net`` for one pass over the labeled loader, with ``net2`` fixed.

        Args:
            epoch: current epoch; with the batch position it drives the
                ramp-up of the unsupervised loss weight.
            net: network being trained; put in train mode.
            net2: fixed network used only for label co-guessing; put in
                eval mode.
            optimizer: optimizer that is stepped after every batch.
            labeled_trainloader: yields ``(view1, view2, label, w)``.
            unlabeled_trainloader: yields ``(view1, view2)``; restarted
                whenever it runs out before the labeled loader does.

        Requires ``self.criterion`` to have been set by the caller.
        """
        net.train()
        net2.eval()  # fix one network and train the other

        unlabeled_train_iter = iter(unlabeled_trainloader)
        num_iter = (len(labeled_trainloader.dataset)//self.args.local_bs)+1
        # Uniform class prior for the regulariser; the same for every batch.
        prior = (torch.ones(self.num_classes)/self.num_classes).cuda()
        for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
            # Use the built-in next(): iterators have no .next() method on
            # Python 3, and the old bare `except` hid that failure.
            try:
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            batch_size = inputs_x.size(0)

            # Transform label to one-hot
            labels_x = torch.zeros(batch_size, self.num_classes).scatter_(
                1, labels_x.view(-1, 1), 1)
            w_x = w_x.view(-1, 1).type(torch.FloatTensor)

            inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(
            ), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
            inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

            with torch.no_grad():
                # label co-guessing of unlabeled samples
                outputs_u11 = net(inputs_u)
                outputs_u12 = net(inputs_u2)
                outputs_u21 = net2(inputs_u)
                outputs_u22 = net2(inputs_u2)

                pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) +
                      torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
                ptu = pu**(1/0.5)  # temperature sharpening

                targets_u = ptu / ptu.sum(dim=1, keepdim=True)  # normalize
                targets_u = targets_u.detach()

                # label refinement of labeled samples
                outputs_x = net(inputs_x)
                outputs_x2 = net(inputs_x2)

                px = (torch.softmax(outputs_x, dim=1) +
                      torch.softmax(outputs_x2, dim=1)) / 2
                # Blend the given label with the prediction, weighted by w_x.
                px = w_x*labels_x + (1-w_x)*px
                ptx = px**(1/0.5)  # temperature sharpening

                targets_x = ptx / ptx.sum(dim=1, keepdim=True)  # normalize
                targets_x = targets_x.detach()

            # mixmatch: mixing coefficient kept >= 0.5 so each mixed sample
            # stays closer to its first component.
            lam = np.random.beta(4, 4)
            lam = max(lam, 1-lam)

            all_inputs = torch.cat(
                [inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat(
                [targets_x, targets_x, targets_u, targets_u], dim=0)

            idx = torch.randperm(all_inputs.size(0))

            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]

            mixed_input = lam * input_a + (1 - lam) * input_b
            mixed_target = lam * target_a + (1 - lam) * target_b

            logits = net(mixed_input)
            # The first 2*batch_size rows are the (mixed) labeled samples.
            logits_x = logits[:batch_size*2]
            logits_u = logits[batch_size*2:]

            Lx, Lu, lamb = self.criterion(
                logits_x, mixed_target[:batch_size*2], logits_u, mixed_target[batch_size*2:], epoch+batch_idx/num_iter, self.warmup)

            # regularization: KL(prior || mean prediction)
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior*torch.log(prior/pred_mean))

            loss = Lx + lamb * Lu + penalty
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def adjust_learning_rate(self, optimizer, epoch):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.alpha_plan[epoch]
            param_group['betas'] = (
                self.beta1_plan[epoch], 0.999)  # Only change beta1

    def co_teaching_accuracy(self, logit, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        output = F.softmax(logit, dim=1)
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous(
            ).view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def smooth_softmax(self, pred, label):
        softmax_pred = F.softmax(pred) + 1e-7
        # return F.nll_loss(softmax_pred, label)
        # one_hot = F.one_hot(label, num_classes=self.num_classes)
        soft_log_pred = torch.log(softmax_pred) * 1.0
        # loss = one_hot * soft_log_pred
        return F.nll_loss(soft_log_pred, label)

    # TODO: F.nll_loss
    def smooth_softmax_correction(self, pred, label, T):
        # T = torch.from_numpy(T)
        softmax_pred = F.softmax(pred)
        softmax_pred = torch.mm(softmax_pred, T.cuda()) + 1e-7
        # one_hot = F.one_hot(label, num_classes=self.num_classes)
        soft_log_pred = torch.log(softmax_pred) * 1.0
        # loss = one_hot * soft_log_pred
        # return torch.mean(loss)
        return F.nll_loss(soft_log_pred, label)

        # Loss functions
    def loss_coteaching(self, y_1, y_2, t, forget_rate, ind, noise_or_not):
        loss_1 = F.cross_entropy(y_1, t, reduce=False)
        ind_1_sorted = torch.argsort(loss_1.data).cuda()
        loss_1_sorted = loss_1[ind_1_sorted]

        loss_2 = F.cross_entropy(y_2, t, reduce=False)
        ind_2_sorted = torch.argsort(loss_2.data).cuda()
        loss_2_sorted = loss_2[ind_2_sorted]

        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_1_sorted))

        # pure_ratio_1 = torch.sum(noise_or_not[ind[ind_1_sorted[:num_remember].cpu()]])/float(num_remember)
        # pure_ratio_2 = torch.sum(noise_or_not[ind[ind_2_sorted[:num_remember].cpu()]])/float(num_remember)

        pure_ratio_1 = np.sum(
            noise_or_not[ind[ind_1_sorted[:num_remember].cpu()]])/float(num_remember)
        pure_ratio_2 = np.sum(
            noise_or_not[ind[ind_2_sorted[:num_remember].cpu()]])/float(num_remember)

        ind_1_update = ind_1_sorted[:num_remember]
        ind_2_update = ind_2_sorted[:num_remember]
        # exchange
        loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
        loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])

        return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2

    def co_teaching_train(self, model1, optimizer1, model2, optimizer2, epoch):
        """Train both networks for one pass with the co-teaching loss.

        Args:
            model1, model2: the two peer networks.
            optimizer1, optimizer2: their optimizers; both are stepped on
                every batch.
            epoch: selects the forget rate from ``self.rate_schedule``.

        Returns:
            ``(model1, model2)`` when model1 had the strictly higher mean
            top-1 training accuracy, otherwise ``(model2, model1)``. With
            an empty loader the networks are returned unchanged in the
            order given (this used to raise ZeroDivisionError).
        """
        pure_ratio_1_list = []
        pure_ratio_2_list = []

        batches_seen = 0
        train_correct = 0
        train_correct2 = 0

        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
            ind = indexes.cpu().numpy().transpose()

            images = images.cuda()
            labels = labels.cuda()

            # Forward + Backward + Optimize
            # Only top-1 is used, so top-5 is not requested: it was discarded
            # and made topk() fail for tasks with fewer than 5 classes.
            logits1 = model1(images)
            prec1, = self.co_teaching_accuracy(logits1, labels, topk=(1,))
            train_correct += prec1

            logits2 = model2(images)
            prec2, = self.co_teaching_accuracy(logits2, labels, topk=(1,))
            train_correct2 += prec2
            batches_seen += 1

            loss_1, loss_2, pure_ratio_1, pure_ratio_2 = self.loss_coteaching(
                logits1, logits2, labels, self.rate_schedule[epoch], ind, self.noise_or_not)
            pure_ratio_1_list.append(100*pure_ratio_1)
            pure_ratio_2_list.append(100*pure_ratio_2)

            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss_1.backward(retain_graph=True)
            loss_2.backward()
            optimizer1.step()
            optimizer2.step()
            if (batch_idx + 1) % 50 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy1: %.4F, Training Accuracy2: %.4f, Loss1: %.4f, Loss2: %.4f, Pure Ratio1: %.4f, Pure Ratio2 %.4f'
                      % (epoch+1, self.total_epoch, batch_idx + 1, len(self.trainloader.dataset)//self.args.local_bs, prec1, prec2, loss_1.item(), loss_2.item(), np.sum(pure_ratio_1_list)/len(pure_ratio_1_list), np.sum(pure_ratio_2_list)/len(pure_ratio_2_list)))

        if batches_seen == 0:
            return model1, model2

        train_acc1 = float(train_correct)/float(batches_seen)
        train_acc2 = float(train_correct2)/float(batches_seen)
        if train_acc1 > train_acc2:
            return model1, model2
        else:
            return model2, model1

    def estimate_T(self, model):
        conf_score_ = None
        pred_ = None
        targets_ = None
        pred_gt = []
        error = False
        cor_error = torch.zeros((self.num_classes, self.num_classes))
        cls_flag = np.array([0] * self.num_classes)
        # Forward
        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
            images, labels = images.cuda(), labels.cuda()
            output = model(images)
            outputs = F.softmax(output)
            conf_score, predicted = outputs.max(1)

            if conf_score_ is None:
                conf_score_ = conf_score.detach().cpu()
            else:
                conf_score_ = torch.hstack(
                    (conf_score_, conf_score.detach().cpu()))

            # pred_, 50000x1
            if pred_ is None:
                pred_ = predicted.detach().cpu()
            else:
                pred_ = torch.hstack((pred_, predicted.detach().cpu()))

            # targets_, 50000x1
            if targets_ is None:
                targets_ = labels.detach().cpu()
            else:
                targets_ = torch.hstack((targets_, labels.detach().cpu()))

        for i, k in enumerate(conf_score_):
            if k > 0.95:
                pred_gt.append((pred_[i], targets_[i]))
                row = pred_[i]
                col = targets_[i]
                cor_error[row][col] += 1
                cls_flag[row] += 1
                if pred_[i] != targets_[i]:
                    error = True

        for i in range(self.num_classes):
            if cls_flag[i] == 0:
                cor_error[i][i] = 1.0
            else:
                cor_error[i, :] /= cls_flag[i]

        if error:
            print('T', cor_error)
        return cor_error

    def revision_fit(self, model, filter_outlier=False):
        """Build a transition matrix from per-class anchor samples.

        For each class ``i`` the sample with the highest softmax score for
        ``i`` is taken as anchor, and row ``i`` of the result is that
        sample's full softmax vector.

        Args:
            model: network evaluated on ``self.trainloader``.
            filter_outlier: when true, scores at or above the class's 97th
                percentile are ignored while picking the anchor.

        Returns:
            ``(num_classes, num_classes)`` numpy array (rows are not
            re-normalised here).

        Raises:
            ValueError: if the training loader yields no batch.
        """
        chunks = []
        with torch.no_grad():
            for images, _labels, _indexes in self.trainloader:
                probs = F.softmax(model(images.cuda()), dim=1)
                chunks.append(probs.cpu().numpy())
        if not chunks:
            raise ValueError(
                'revision_fit needs at least one batch of training data')
        eta_corr = np.vstack(chunks)

        T = np.empty((self.num_classes, self.num_classes))
        for i in range(self.num_classes):
            scores = eta_corr[:, i]
            if filter_outlier:
                try:
                    eta_thresh = np.percentile(scores, 97, method='higher')
                except TypeError:
                    # NumPy < 1.22 only knows the old keyword name.
                    eta_thresh = np.percentile(
                        scores, 97, interpolation='higher')
                # np.where returns a new array. The previous in-place masking
                # wrote zeros into eta_corr through a view and corrupted the
                # rows read for later classes.
                scores = np.where(scores >= eta_thresh, 0.0, scores)
            idx_best = np.argmax(scores)
            T[i, :] = eta_corr[idx_best, :]
        return T

    def revieion_norm(self, T):
        row_sum = np.sum(T, 1)
        T_norm = T / row_sum
        return T_norm

    def loss_function_revision(self, out, T, correction, target):
        loss = 0.
        out_softmax = F.softmax(out, dim=1)
        if 'torch' not in str(type(T)):
            T = torch.from_numpy(T)
        T = T.cuda()
        for i in range(len(target)):
            temp_softmax = out_softmax[i]
            temp = out[i]
            temp = torch.unsqueeze(temp, 0)
            temp_softmax = torch.unsqueeze(temp_softmax, 0)
            temp_target = target[i]
            temp_target = torch.unsqueeze(temp_target, 0)
            pro1 = temp_softmax[:, target[i]]
            T = T + correction
            T_result = T
            out_T = torch.matmul(T_result.t(), temp_softmax.t().double())
            out_T = out_T.t()
            pro2 = out_T[:, target[i]]
            beta = (pro1 / pro2)
            cross_loss = F.cross_entropy(temp, temp_target)
            _loss = beta * cross_loss
            loss += _loss
        return loss / len(target)

    def loss_func_reweight(self, out, T, target):
        loss = 0.
        out_softmax = F.softmax(out, dim=1)
        if 'torch' not in str(type(T)):
            T = torch.from_numpy(T)
        T = T.cuda()
        for i in range(len(target)):
            temp_softmax = out_softmax[i]
            temp = out[i]
            temp = torch.unsqueeze(temp, 0)
            temp_softmax = torch.unsqueeze(temp_softmax, 0)
            temp_target = target[i]
            temp_target = torch.unsqueeze(temp_target, 0)
            pro1 = temp_softmax[:, target[i]]
            out_T = torch.matmul(T.t(), temp_softmax.t().double())
            out_T = out_T.t()
            pro2 = out_T[:, target[i]]
            beta = pro1 / pro2
            beta = Variable(beta, requires_grad=True)
            cross_loss = F.cross_entropy(temp, temp_target)
            _loss = beta * cross_loss
            loss += _loss
        return loss / len(target)
