#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

from email.mime import image
import enum
from operator import itemgetter
import re
from typing import Counter
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import copy
import numpy as np
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable


class DatasetSplit(Dataset):
    """View of a dataset restricted to a fixed list of sample indices.

    Each item is returned as ``(image, label, source_index)`` where
    ``source_index`` is the position of the sample in the wrapped dataset.
    """

    def __init__(self, dataset, idxs):
        """Store the wrapped dataset and the selected indices.

        Args:
            dataset: Indexable object whose items are ``(image, label)`` pairs.
            idxs: Iterable of indices into ``dataset``; each is cast to ``int``.
        """
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _to_tensor(value):
        """Return a tensor copy of ``value``.

        ``torch.tensor(existing_tensor)`` emits a UserWarning on every call;
        ``clone().detach()`` is the equivalent, warning-free copy.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        source_idx = self.idxs[item]
        image, label = self.dataset[source_idx]
        return self._to_tensor(image), self._to_tensor(label), source_idx


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, num_classes, label_cache_dist=None):
        """Build the client's data loader and the per-method training state.

        Args:
            args: Run configuration. This constructor reads ``seed``,
                ``dataset``, ``local_bs``, ``epochs``, ``method``, ``lr``,
                ``noise_ratio`` and ``alpha``.
            dataset: Full training dataset. Must expose ``noise_or_not`` and,
                when ``args.dataset == 'clothing1m'``, ``targets``.
            idxs: Indices of ``dataset`` that belong to this client.
            num_classes: Number of label classes.
            label_cache_dist: Label distribution used to sample peer labels;
                ``update_weights`` requires it when ``method == 'peer'``.
        """
        self.args = args
        # NOTE: this reseeds NumPy's global RNG, not a private generator.
        np.random.seed(args.seed)
        if args.dataset == 'clothing1m':
            # Balance the client's classes: keep a random subset of every
            # class, sized by the least frequent class.
            targets = np.array(dataset.targets)[idxs]
            idxs_np = np.array(idxs)
            targets_dist = Counter(targets)
            min_count = min(targets_dist.values())
            balanced_idxs = []
            for cls in targets_dist:
                cls_idxs = idxs_np[np.where(targets == cls)[0]]
                np.random.shuffle(cls_idxs)
                balanced_idxs.extend(cls_idxs[:min_count])
            idxs = balanced_idxs
        self.trainloader = DataLoader(DatasetSplit(dataset, idxs),
                                      batch_size=self.args.local_bs, shuffle=True, num_workers=4)
        self.total_epoch = args.epochs
        self.num_classes = num_classes
        self.method = args.method

        # co-teaching initialization
        self.mom1 = 0.9
        self.mom2 = 0.1
        self.alpha_plan = [self.args.lr] * self.total_epoch
        self.beta1_plan = [self.mom1] * self.total_epoch
        # From epoch 80 on, decay the learning rate linearly to zero and
        # switch Adam's beta1 to the smaller momentum.
        for i in range(80, self.total_epoch):
            self.alpha_plan[i] = float(
                self.total_epoch - i) / (self.total_epoch - 80) * args.lr
            self.beta1_plan[i] = self.mom2
        self.forget_rate = args.noise_ratio
        self.rate_schedule = np.ones(self.total_epoch) * self.forget_rate
        # Ramp the forget rate up linearly over the first 10 epochs. The ramp
        # is truncated when fewer than 10 epochs are configured; assigning all
        # 10 values to a shorter schedule would raise a shape error.
        num_gradual = 10
        ramp = np.linspace(0, self.forget_rate, num_gradual)
        self.rate_schedule[:num_gradual] = ramp[:self.total_epoch]
        self.noise_or_not = dataset.noise_or_not

        # peer
        self.label_cache_dist = label_cache_dist
        self.alpha = self.args.alpha

        # T-revision
        self.revision_lr = 5e-7
        self.n_epoch_estimate = 1
        # NOTE(review): this overwrites the ``args.epochs`` value assigned
        # above. The schedules built earlier keep their ``args.epochs`` length,
        # but the phase switches in ``update_weights`` and the progress print
        # in ``co_teaching_train`` see 4. Looks like a debugging leftover --
        # confirm before relying on it.
        self.total_epoch = 4

    def update_weights(self, model, model2, cur_epoch, T):
        model.train()
        model2.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.9)
        if self.method == 'peer':
            assert self.label_cache_dist is not None
            if cur_epoch < 50:
                self.alpha = 0
            elif cur_epoch >= 50 and cur_epoch < 100:
                self.alpha = (cur_epoch - 50) * 0.02
            print('alpha', self.alpha)
            for iter in range(self.args.local_ep):
                for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                    images, labels = images.cuda(), labels.cuda()
                    model.zero_grad()
                    output = model(images)
                    peer_label = np.random.choice(
                        self.num_classes, len(images), p=self.label_cache_dist[0])
                    peer_label = torch.from_numpy(peer_label).cuda()
                    loss = self.smooth_softmax(
                        output, labels) - self.alpha * self.smooth_softmax(output, peer_label)
                    loss.backward()
                    optimizer.step()
                    # print(f'Loss: {loss.item()}')
            return model.state_dict(), model2.state_dict()
        elif self.method == 'correction':
            if cur_epoch < int(self.total_epoch / 2):
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model.zero_grad()
                        output = model(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer.step()
            # elif cur_epoch == int(self.total_epoch / 2):
            #     T = self.estimate_T(model)
            if cur_epoch >= int(self.total_epoch / 2):
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        model.zero_grad()
                        output = model(images)
                        loss = self.smooth_softmax_correction(
                            output, labels, T)
                        # loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer.step()
            return model.state_dict(), model2.state_dict()
        elif self.method == 'coteaching':
            # model2 = copy.deepcopy(model)
            # Co-teaching uses Adam as optimizer
            optimizer1 = torch.optim.Adam(model.parameters(), lr=self.args.lr)
            optimizer2 = torch.optim.Adam(model2.parameters(), lr=self.args.lr)
            for local_epoch in range(self.args.local_ep):
                model.train()
                self.adjust_learning_rate(optimizer1, cur_epoch)
                model2.train()
                self.adjust_learning_rate(optimizer2, cur_epoch)
                model, model2 = self.co_teaching_train(
                    model, optimizer1, model2, optimizer2, cur_epoch)
            return model.state_dict(), model2.state_dict()
        elif self.method == 'revision':
            optimizer_es = torch.optim.SGD(
                model.parameters(), lr=self.args.lr, weight_decay=1e-4)
            optimizer = torch.optim.SGD(
                model.parameters(), lr=self.args.lr, weight_decay=1e-4, momentum=0.9)
            optimizer_revision = torch.optim.Adam(
                model.parameters(), lr=self.revision_lr, weight_decay=1e-4)
            # scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)
            if cur_epoch < int(self.total_epoch / 2):
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer_es.zero_grad()
                        output = model(images)
                        loss = F.cross_entropy(output, labels)
                        loss.backward()
                        optimizer_es.step()
            # elif cur_epoch == self.n_epoch_estimate:
            #     T = self.revision_fit(model, True)
            #     T = self.revieion_norm(T)
            elif cur_epoch >= self.total_epoch and cur_epoch < 290:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer.zero_grad()
                        output = model(images)
                        loss = self.loss_func_reweight(output, T, labels)
                        loss.backward()
                        optimizer.step()
            elif cur_epoch >= 290:
                for iter in range(self.args.local_ep):
                    for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                        images, labels = images.cuda(), labels.cuda()
                        optimizer_revision.zero_grad()
                        output, correction = model(images, revision=True)
                        loss = self.loss_function_revision(
                            output, T, correction, labels)
                        loss.backward()
                        optimizer_revision.step()
            return model.state_dict(), model2.state_dict()
        elif self.method == 'baseline':
            for iter in range(self.args.local_ep):
                for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
                    images, labels = images.cuda(), labels.cuda()
                    model.zero_grad()
                    output = model(images)
                    loss = F.cross_entropy(output, labels)
                    loss.backward()
                    optimizer.step()
            return model.state_dict(), model2.state_dict()

    def adjust_learning_rate(self, optimizer, epoch):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.alpha_plan[epoch]
            param_group['betas'] = (
                self.beta1_plan[epoch], 0.999)  # Only change beta1

    def co_teaching_accuracy(self, logit, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        output = F.softmax(logit, dim=1)
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous(
            ).view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def smooth_softmax(self, pred, label):
        softmax_pred = F.softmax(pred) + 1e-7
        # return F.nll_loss(softmax_pred, label)
        # one_hot = F.one_hot(label, num_classes=self.num_classes)
        soft_log_pred = torch.log(softmax_pred) * 1.0
        # loss = one_hot * soft_log_pred
        return F.nll_loss(soft_log_pred, label)

    # TODO: F.nll_loss
    def smooth_softmax_correction(self, pred, label, T):
        # T = torch.from_numpy(T)
        softmax_pred = F.softmax(pred)
        softmax_pred = torch.mm(softmax_pred, T.cuda()) + 1e-7
        # one_hot = F.one_hot(label, num_classes=self.num_classes)
        soft_log_pred = torch.log(softmax_pred) * 1.0
        # loss = one_hot * soft_log_pred
        # return torch.mean(loss)
        return F.nll_loss(soft_log_pred, label)

        # Loss functions
    def loss_coteaching(self, y_1, y_2, t, forget_rate, ind, noise_or_not):
        loss_1 = F.cross_entropy(y_1, t, reduce=False)
        ind_1_sorted = torch.argsort(loss_1.data).cuda()
        loss_1_sorted = loss_1[ind_1_sorted]

        loss_2 = F.cross_entropy(y_2, t, reduce=False)
        ind_2_sorted = torch.argsort(loss_2.data).cuda()
        loss_2_sorted = loss_2[ind_2_sorted]

        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_1_sorted))

        # pure_ratio_1 = torch.sum(noise_or_not[ind[ind_1_sorted[:num_remember].cpu()]])/float(num_remember)
        # pure_ratio_2 = torch.sum(noise_or_not[ind[ind_2_sorted[:num_remember].cpu()]])/float(num_remember)

        # pure_ratio_1 = np.sum(
        #     noise_or_not[ind[ind_1_sorted[:num_remember].cpu()]])/float(num_remember)
        # pure_ratio_2 = np.sum(
        #     noise_or_not[ind[ind_2_sorted[:num_remember].cpu()]])/float(num_remember)

        ind_1_update = ind_1_sorted[:num_remember]
        ind_2_update = ind_2_sorted[:num_remember]
        # exchange
        loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
        loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])

        return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, 0, 0

    def co_teaching_train(self, model1, optimizer1, model2, optimizer2, epoch):
        # print('Training %s...' % model_str)
        pure_ratio_list = []
        pure_ratio_1_list = []
        pure_ratio_2_list = []

        train_total = 0
        train_correct = 0
        train_total2 = 0
        train_correct2 = 0

        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
            ind = indexes.cpu().numpy().transpose()
            # if i > args.num_iter_per_epoch:
            #     break

            images = images.cuda()
            labels = labels.cuda()

            # Forward + Backward + Optimize
            logits1 = model1(images)
            prec1, _ = self.co_teaching_accuracy(logits1, labels, topk=(1, 5))
            train_total += 1
            train_correct += prec1

            logits2 = model2(images)
            prec2, _ = self.co_teaching_accuracy(logits2, labels, topk=(1, 5))
            train_total2 += 1
            train_correct2 += prec2
            loss_1, loss_2, pure_ratio_1, pure_ratio_2 = self.loss_coteaching(
                logits1, logits2, labels, self.rate_schedule[epoch], ind, self.noise_or_not)
            pure_ratio_1_list.append(100*pure_ratio_1)
            pure_ratio_2_list.append(100*pure_ratio_2)

            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss_1.backward(retain_graph=True)
            loss_2.backward()
            optimizer1.step()
            optimizer2.step()
            if (batch_idx + 1) % 50 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Training Accuracy1: %.4F, Training Accuracy2: %.4f, Loss1: %.4f, Loss2: %.4f, Pure Ratio1: %.4f, Pure Ratio2 %.4f'
                      % (epoch+1, self.total_epoch, batch_idx + 1, len(self.trainloader.dataset)//self.args.local_bs, prec1, prec2, loss_1.item(), loss_2.item(), np.sum(pure_ratio_1_list)/len(pure_ratio_1_list), np.sum(pure_ratio_2_list)/len(pure_ratio_2_list)))

        train_acc1 = float(train_correct)/float(train_total)
        train_acc2 = float(train_correct2)/float(train_total2)
        if train_acc1 > train_acc2:
            return model1, model2
        else:
            return model2, model1

    def estimate_T(self, model):
        conf_score_ = None
        pred_ = None
        targets_ = None
        pred_gt = []
        error = False
        cor_error = torch.zeros((self.num_classes, self.num_classes))
        cls_flag = np.array([0] * self.num_classes)
        # Forward
        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
            images, labels = images.cuda(), labels.cuda()
            output = model(images)
            outputs = F.softmax(output)
            conf_score, predicted = outputs.max(1)

            if conf_score_ is None:
                conf_score_ = conf_score.detach().cpu()
            else:
                conf_score_ = torch.hstack(
                    (conf_score_, conf_score.detach().cpu()))

            # pred_, 50000x1
            if pred_ is None:
                pred_ = predicted.detach().cpu()
            else:
                pred_ = torch.hstack((pred_, predicted.detach().cpu()))

            # targets_, 50000x1
            if targets_ is None:
                targets_ = labels.detach().cpu()
            else:
                targets_ = torch.hstack((targets_, labels.detach().cpu()))

        for i, k in enumerate(conf_score_):
            if k > 0.95:
                pred_gt.append((pred_[i], targets_[i]))
                row = pred_[i]
                col = targets_[i]
                cor_error[row][col] += 1
                cls_flag[row] += 1
                if pred_[i] != targets_[i]:
                    error = True

        for i in range(self.num_classes):
            if cls_flag[i] == 0:
                cor_error[i][i] = 1.0
            else:
                cor_error[i, :] /= cls_flag[i]

        if error:
            print('T', cor_error)
        # T = cor_error.detach().cpu().numpy()
        return T

    def revision_fit(self, model, filter_outlier=False):
        T = np.empty((self.num_classes, self.num_classes))
        eta_corr = None
        for batch_idx, (images, labels, indexes) in enumerate(self.trainloader):
            images, labels = images.cuda(), labels.cuda()
            output = model(images)
            outputs = F.softmax(output)
            outputs = outputs.detach().cpu().numpy()
            if eta_corr is None:
                eta_corr = outputs
            else:
                eta_corr = np.vstack(
                    (eta_corr, outputs)
                )
        for i in np.arange(self.num_classes):
            if not filter_outlier:
                idx_best = np.argmax(eta_corr[:, i])
            else:
                eta_thresh = np.percentile(
                    eta_corr[:, i], 97, interpolation='higher')
                robust_eta = eta_corr[:, i]
                robust_eta[robust_eta >= eta_thresh] = 0.0
                idx_best = np.argmax(robust_eta)
            for j in np.arange(self.num_classes):
                T[i, j] = eta_corr[idx_best, j]
        return T

    def revieion_norm(self, T):
        row_sum = np.sum(T, 1)
        T_norm = T / row_sum
        return T_norm

    def loss_function_revision(self, out, T, correction, target):
        loss = 0.
        out_softmax = F.softmax(out, dim=1)
        # T = torch.from_numpy(T).cuda()
        if 'torch' not in str(type(T)):
            T = torch.from_numpy(T)
        T = T.cuda()
        for i in range(len(target)):
            temp_softmax = out_softmax[i]
            temp = out[i]
            temp = torch.unsqueeze(temp, 0)
            temp_softmax = torch.unsqueeze(temp_softmax, 0)
            temp_target = target[i]
            temp_target = torch.unsqueeze(temp_target, 0)
            pro1 = temp_softmax[:, target[i]]
            T = T + correction
            T_result = T
            out_T = torch.matmul(T_result.t(), temp_softmax.t().double())
            out_T = out_T.t()
            pro2 = out_T[:, target[i]]
            beta = (pro1 / pro2)
            cross_loss = F.cross_entropy(temp, temp_target)
            _loss = beta * cross_loss
            loss += _loss
        return loss / len(target)

    def loss_func_reweight(self, out, T, target):
        loss = 0.
        out_softmax = F.softmax(out, dim=1)
        if 'torch' not in str(type(T)):
            T = torch.from_numpy(T)
        T = T.cuda()
        for i in range(len(target)):
            temp_softmax = out_softmax[i]
            temp = out[i]
            temp = torch.unsqueeze(temp, 0)
            temp_softmax = torch.unsqueeze(temp_softmax, 0)
            temp_target = target[i]
            temp_target = torch.unsqueeze(temp_target, 0)
            pro1 = temp_softmax[:, target[i]]
            out_T = torch.matmul(T.t(), temp_softmax.t().double())
            out_T = out_T.t()
            pro2 = out_T[:, target[i]]
            beta = pro1 / pro2
            beta = Variable(beta, requires_grad=True)
            cross_loss = F.cross_entropy(temp, temp_target)
            _loss = beta * cross_loss
            loss += _loss
        return loss / len(target)
