import collections
import copy
import os
from typing import Collection, Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

# import wandb
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader, Subset, TensorDataset, Dataset

# Variable import
from torch.autograd import Variable
from matplotlib import pyplot as plt
from tqdm import tqdm
from torchvision.utils import make_grid

import model
import wgan
import dcgan
import model_from_DF
import model_from_aux
from utils import SubsetWithTargets

# import for generator train
# from GAN.model.wgan import Generator, Discriminator
from functions import loop, cal_grad_penalty

# Classifier architectures, looked up by ``cfg.model.model`` (Device.__init__ and the GKD global model).
MODEL = {
    "cnnmnist": model.CNNMnist,
    "cnnmnist2": model.CNNMnist2,
    "cnncifar": model.CNNCifar,
    "cifar100resnet18": model.cifar100ResNet18,
    "tinyimagenetresnet18": model.tinyimagenetResNet18,
    "resnet18": model.ResNet18,
    "resnet50": model.ResNet50,
    "imagenetresnet18": model.ImagenetResNet18,
    "cnnfmnist": model.CNNFashion_Mnist,
    "iris": model.Iris,
    "resnet8_from_aux": model_from_aux.resnet8,
    "resnet8_from_DF": model_from_DF.ResNet8,
}
# GAN generators, looked up by ``cfg.gan.model``; instantiated without arguments.
GEN_MODEL = {
    # "wgan": model.WGANGenerator,
    # "wgan_res": model.GenServerResNet8,
    "dcgan": dcgan.Generator,
    "wgan": wgan.Generator,
}
# Server-side GAN discriminators, looked up by ``cfg.gan.model``; instantiated with ``nc=<channels>``.
SERVER_DISC_MODEL = {
    # "wgan": model.WGANDiscriminator,
    # "wgan_res": model.DiscServerResNet8,
    "dcgan": dcgan.Discriminator,
    "wgan": wgan.Discriminator,
}
# Client-side discriminators, looked up by ``cfg.gan.c.model``; instantiated with the
# act/eps/norm/bn/clip/nc keyword arguments from ``cfg.gan.c`` (see Device.__init__).
# NOTE: the mixed-case name is kept as-is because other code refers to it by this name.
Client_DISC_MODEL = {
    # "resnet18": model.DiscResNet18,
    # "wgan_res": model.DiscServerResNet8,
    "dcgan": dcgan.Discriminator,
    "wgan": wgan.Discriminator,
    "resnet8": model.ResNet8_disc,
    "resnet18": model.ResNet18_disc,
}

# ``cfg.fl.combine`` modes that are treated as producing soft (2-D, per-class) ensemble labels.
EM_SOFT_SET = {
    "em_soft",
    "em_entropy_soft",
    "df",
    "gan",
    "et",
    "logit_var",
    "df_gkd",
    "gan_dafkd",
}
# Every ``cfg.fl.combine`` mode that TrainableServer.combine routes to combine_em.
EM_SET = {*EM_SOFT_SET, "em"}


class Device:
    """Base federated-learning participant holding a config, a local dataset and its networks.

    Attributes set here:
        cfg, dset: the given config and dataset.
        gpu: ``torch.device`` for ``cuda:{cfg.simul.gpu}``, or the CPU when that value is -1.
        net: classifier ``MODEL[cfg.model.model]()``.
        gan_gen, disc_net: GAN generator and client discriminator; only created when
            ``cfg.simul.gpu_on_off`` is falsy.
    """

    def __init__(self, cfg: DictConfig, dset: SubsetWithTargets):
        self.cfg: DictConfig = cfg
        self.dset: SubsetWithTargets = dset
        _gpu = f"cuda:{cfg.simul.gpu}" if cfg.simul.gpu != -1 else "cpu"
        self.gpu = torch.device(_gpu)

        self.net = MODEL[self.cfg.model.model]()
        # NOTE(review): when gpu_on_off is set, gan_gen / disc_net are never built, although
        # TrainableDevice and TrainableServer use disc_net whenever a GAN timing is configured.
        # Confirm that these options are never combined.
        if not cfg.simul.gpu_on_off:
            self.net.to(self.gpu)

            self.gan_gen = GEN_MODEL[self.cfg.gan.model]().to(self.gpu)
            # "fashionmnist" gets a single input channel; every other dataset is treated as 3-channel.
            nc = 1 if cfg.dset.name == "fashionmnist" else 3
            self.disc_net = Client_DISC_MODEL[self.cfg.gan.c.model](
                act=self.cfg.gan.c.act,
                eps=self.cfg.gan.c.eps,
                norm=self.cfg.gan.c.norm,
                bn=self.cfg.gan.c.bn,
                clip=self.cfg.gan.c.clip,
                nc=nc,
            ).to(self.gpu)
            if cfg.gan.timing in ["s", "a"]:
                # net and disc_net share params up to layer2 (same module objects)
                self.disc_net.conv1 = self.net.conv1
                self.disc_net.bn1 = self.net.bn1
                self.disc_net.layer1 = self.net.layer1
                self.disc_net.layer2 = self.net.layer2

    def update(self, w: collections.OrderedDict):
        """Load the weights ``w`` into ``self.net``.

        ``load_state_dict`` copies the values into the network's existing tensors, so ``w`` is
        never aliased and a defensive deep copy would only cost an extra full copy of the weights.
        """
        self.net.load_state_dict(w)


class TrainableDevice(Device):
    """A Device that can train its classifier (and, with a GAN timing, its discriminator).

    Hyper-parameters (``ep``, ``bs``, optimizer settings) come from ``cfg.s`` for a
    ``TrainableServer`` and from ``cfg.c`` otherwise. Non-server devices also build ``self.dl``
    over their local data. When ``cfg.env.ratio_c_noise`` is set, that fraction of every class is
    relabelled to other classes first (see ``_get_noisy_labels``).
    """

    def __init__(self, cfg: DictConfig, dset: SubsetWithTargets):
        super().__init__(cfg, dset)

        # hyper params: server and clients are configured separately
        cfg_ = cfg.s if isinstance(self, TrainableServer) else cfg.c

        self.ep = cfg_.ep
        self.bs = cfg_.bs
        self.optimizer = self._get_optimizer(cfg_, self.net)
        if cfg.gan.timing:
            self.disc_optimizer = self._get_optimizer(cfg.gan, self.disc_net)

        # loss input example: logit(unnormalized value) tensor(64, 10), label int tensor(64)
        self.loss_fn = nn.CrossEntropyLoss()
        if not isinstance(self, TrainableServer):
            if self.cfg.env.ratio_c_noise:
                self.dl = self.create_noisy_dl(self.cfg.env.ratio_c_noise)
            else:
                self.dl = DataLoader(self.dset, batch_size=self.bs, shuffle=True)

    def get_state_dicts(self) -> Dict:
        """Return a checkpoint dict with the config, the dataset and the net/optimizer state dicts."""
        states = {
            "cfg": self.cfg,
            "dset": self.dset,
            "net": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        return states

    def create_noisy_dl(self, ratio_c_noise) -> DataLoader:
        """Return a shuffled loader over the local data paired with noisy labels.

        The noisy labels are also stored in ``self.noisy_labels``.
        """
        self.noisy_labels = self._get_noisy_labels(ratio_c_noise)
        # Load the whole local set as one batch (unshuffled) so images line up with noisy_labels.
        full_dl = DataLoader(self.dset, batch_size=len(self.dset), shuffle=False)
        data, _ = next(iter(full_dl))
        noisy_dset = TensorDataset(data, torch.tensor(self.noisy_labels))
        return DataLoader(noisy_dset, batch_size=self.bs, shuffle=True)

    def _get_noisy_labels(self, ratio_c_noise) -> np.ndarray:
        """Return ``dset.targets`` with the first ``ratio_c_noise`` fraction of every class relabelled.

        Replacement labels walk round-robin over the other classes. The walk continues across
        classes: each class starts right after the last label used for the previous class.
        """
        dset = self.dset
        noisy_labels = copy.copy(dset.targets)
        before_noise = 0
        for cls in range(self.cfg.dset.n_cls):
            idxs = np.where(dset.targets == cls)[0]
            n_noisy_idxs = int(len(idxs) * ratio_c_noise)
            noisy_idxs = idxs[:n_noisy_idxs]
            noisy_clses = self._get_clses_except(cls, self.cfg.dset.n_cls, before_noise + 1)
            noisy_cls = before_noise  # keeps the walk position unchanged when n_noisy_idxs == 0
            for noisy_idx, noisy_cls in zip(noisy_idxs, noisy_clses):
                noisy_labels[noisy_idx] = noisy_cls
            before_noise = noisy_cls
        return np.array(noisy_labels)

    @staticmethod
    def _get_clses_except(cls: int, n_cls: int, start_idx: int):
        """Yield the labels ``0 .. n_cls-1`` except ``cls``, cyclically and forever, from ``start_idx``.

        Examples:
            cls=1, n_cls=10, start_idx=0
            (0, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, ...)

        Yields nothing when no label other than ``cls`` exists (e.g. ``n_cls <= 1``), instead of
        looping forever. A ``start_idx`` at or beyond ``n_cls`` wraps to 0.
        """
        if not any(k != cls for k in range(n_cls)):
            return
        i = start_idx
        while True:
            if i >= n_cls:
                i = 0
            elif i == cls:
                i += 1
            else:
                yield i
                i += 1

    def _get_gan_loss(self, real, fake, disc_net):
        """Return the BCE discriminator loss: ``disc_net`` logits should be 1 on ``real``, 0 on ``fake``."""
        criterion = nn.BCEWithLogitsLoss()
        real_logits = disc_net(real).squeeze()
        fake_logits = disc_net(fake).squeeze()
        real_loss = criterion(real_logits, torch.ones_like(real_logits))
        fake_loss = criterion(fake_logits, torch.zeros_like(fake_logits))
        return real_loss + fake_loss

    def train(self) -> float:
        """Train ``self.net`` for ``self.ep`` epochs over ``self.dl``; return the mean batch loss.

        Optional terms, both relative to a frozen copy of the weights taken at the start of the call:
          * ``cfg.fl.combine == "prox"``: FedProx proximal term ``mu / 2 * ||w - w_global||^2``.
          * ``cfg.gkd.is_gkd``: ``gamma / 2 * KL(global || local)`` on temperature-``tau`` softmaxes.

        Returns 0.0 when no batch was processed.
        """
        self.net.train()
        if self.cfg.gan.timing:
            self.disc_net.train()
        use_prox = self.cfg.fl.combine == "prox"
        use_gkd = bool(self.cfg.gkd.is_gkd)

        global_model = None
        if use_prox or use_gkd:
            # Frozen snapshot of the incoming weights. It is never optimised, so it needs no gradients.
            # NOTE(review): the copy stays in train mode, so any BatchNorm layers use batch statistics
            # in the GKD forward pass. Confirm this is intended.
            global_model = copy.deepcopy(self.net)
            global_model.requires_grad_(False)

        tot_loss = 0
        loss_cnt = 0
        for _ in range(self.ep):
            for images, labels in self.dl:
                images, labels = images.to(self.gpu), labels.to(self.gpu)
                self.optimizer.zero_grad()
                log_probs = self.net(images)
                loss = torch.nan_to_num(self.loss_fn(log_probs, labels))

                if use_prox:
                    # FedProx uses the *squared* L2 distance to the global weights.
                    proximal_term = sum(
                        (w - w_t).pow(2).sum() for w, w_t in zip(self.net.parameters(), global_model.parameters())
                    )
                    loss += (self.cfg.fl.mu / 2) * proximal_term

                if use_gkd:
                    tau = self.cfg.gkd.tau
                    global_log_probs = global_model(images)
                    gkd_loss = F.kl_div(
                        F.log_softmax(log_probs / tau, dim=1),
                        F.softmax(global_log_probs / tau, dim=1),
                        reduction="batchmean",
                    )
                    loss += self.cfg.gkd.gamma / 2 * gkd_loss
                loss.backward()
                self.optimizer.step()
                tot_loss += loss.item()
                loss_cnt += 1
        return tot_loss / loss_cnt if loss_cnt else 0.0

    def disc_train(self, dset: SubsetWithTargets) -> float:
        """Run one BCE pass of ``self.disc_net`` on ``dset`` (real) vs ``self.gan_gen`` samples (fake).

        Returns the mean batch loss, or 0.0 when ``dset`` yields no batch.
        """
        self.disc_net.train()
        tot_loss = 0
        loss_cnt = 0
        dl = DataLoader(dset, self.cfg.gan.bs)
        for images, _ in dl:
            real = images.to(self.gpu)
            # Only disc_net is optimised here: detach so no generator gradients are built or accumulated.
            fake = self.gan_gen(torch.randn(images.shape[0], 128).to(self.gpu)).detach()
            loss = self._get_gan_loss(real, fake, self.disc_net)
            self.disc_optimizer.zero_grad()
            loss.backward()
            self.disc_optimizer.step()
            tot_loss += loss.item()
            loss_cnt += 1
        return tot_loss / loss_cnt if loss_cnt else 0.0

    @torch.no_grad()
    def disc_val(self, dset: SubsetWithTargets) -> Tuple[float, float]:
        """Evaluate ``self.disc_net`` on ``dset`` (real) and as many ``self.gan_gen`` samples (fake).

        Returns:
            ``(accuracy %, mean per-sample loss)``. A real sample is correct when its logit is > 0,
            a fake one when its logit is < 0 (sigmoid probability above / below 0.5).
            ``(0.0, 0.0)`` for an empty ``dset``.
        """
        self.disc_net.eval()
        correct = 0
        tot_loss = 0.0
        dl = DataLoader(dset, self.cfg.gan.bs)
        for images, _ in dl:
            real = images.to(self.gpu)
            fake = self.gan_gen(torch.randn(images.shape[0], 128).to(self.gpu))

            real_logits = self.disc_net(real)
            fake_logits = self.disc_net(fake)

            loss = self._get_gan_loss(real, fake, self.disc_net)

            # disc_net emits logits (it is trained with BCEWithLogitsLoss), so the decision threshold is 0.
            correct += (real_logits > 0).sum().item() + (fake_logits < 0).sum().item()
            # loss is a batch mean; weight it by the batch size to average per sample at the end.
            tot_loss += loss.item() * images.shape[0]

        n_samples = len(dset)
        if n_samples == 0:
            return 0.0, 0.0
        accuracy = 100.00 * correct / n_samples / 2
        return accuracy, tot_loss / n_samples

    @torch.no_grad()
    def val(self, dset_test: SubsetWithTargets) -> float:
        """Return the top-1 accuracy (in %) of ``self.net`` on ``dset_test``.

        Args:
            dset_test (SubsetWithTargets): test set

        Returns:
            float: accuracy in percent, or 0.0 for an empty test set.
        """
        self.net.eval()
        if self.cfg.simul.gpu_on_off:
            self.net.to(self.gpu)
        correct = 0
        dl = DataLoader(dset_test, batch_size=self.cfg.s.bs)
        for images, labels in dl:
            images, labels = images.to(self.gpu), labels.to(self.gpu)
            y_pred = self.net(images).argmax(dim=1)
            correct += labels.eq(y_pred).sum().item()
        # NOTE(review): under gpu_on_off the net is left on self.gpu here, whereas
        # _create_ensemble_labels offloads client nets to "cpu" after use. Confirm whether offloading
        # is wanted before changing it, since later server training may rely on the net staying on the GPU.
        if len(dset_test) == 0:
            return 0.0
        return 100.00 * correct / len(dset_test)

    @staticmethod
    def _get_optimizer(
        cfg_,
        net: nn.Module,
        rot_net: Optional[nn.Module] = None,
        disc_net: Optional[nn.Module] = None,
    ) -> torch.optim.Optimizer:
        """Return an SGD or Adam optimizer configured from ``cfg_``.

        Reads ``optim``, ``lr``, ``weight_decay``, plus ``momentum`` (SGD) or optional
        ``beta1``/``beta2`` (Adam). If ``rot_net`` (or, failing that, ``disc_net``) is given, its
        front-end is assumed to be shared with ``net``. The optimised parameters are then all of
        ``net`` plus only the back-end layers of the auxiliary network.

        Returns None for an unknown ``cfg_.optim`` (kept for callers whose config never uses it).

        Raises:
            ValueError: if ``disc_net`` is given but is not a ``model.ResNet``.
        """
        if rot_net is not None:
            if isinstance(rot_net, model.ResNet) or isinstance(rot_net, model.ImagenetResNet):
                backend = [rot_net.layer3, rot_net.layer4, rot_net.linear]
            else:
                backend = [rot_net.layer3, rot_net.classifier]
            params = [{"params": net.parameters()}] + [{"params": m.parameters()} for m in backend]
        elif disc_net is not None:
            if not isinstance(disc_net, model.ResNet):
                raise ValueError(f"unsupported disc_net for a shared optimizer: {type(disc_net).__name__}")
            backend = [disc_net.layer3, disc_net.layer4, disc_net.linear]
            params = [{"params": net.parameters()}] + [{"params": m.parameters()} for m in backend]
        else:
            params = net.parameters()

        if cfg_.optim == "sgd":
            return torch.optim.SGD(
                params=params,
                lr=cfg_.lr,
                momentum=cfg_.momentum,
                weight_decay=cfg_.weight_decay,
            )
        if cfg_.optim == "adam":
            beta1 = getattr(cfg_, "beta1", None)
            beta2 = getattr(cfg_, "beta2", None)
            # Compare against None: beta1 == 0 (common for WGAN training) is a valid setting.
            if beta1 is not None and beta2 is not None:
                return torch.optim.Adam(
                    params=params,
                    lr=cfg_.lr,
                    weight_decay=cfg_.weight_decay,
                    betas=(beta1, beta2),
                )
            return torch.optim.Adam(params=params, lr=cfg_.lr, weight_decay=cfg_.weight_decay)
        return None


# used in TrainableServer-combine_em
class CustomDataset(Dataset):
    """Index-aligned dataset returning ``(image, label, reg_logit)`` triples.

    All three containers are indexed with the same ``idx``; the length is taken from ``data``.
    """

    def __init__(self, data, label, reg_logit):
        self.data, self.label, self.reg_logit = data, label, reg_logit

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch image, label and regularisation logit (in that order) for the same index.
        return tuple(column[idx] for column in (self.data, self.label, self.reg_logit))


class TrainableServer(TrainableDevice):
    """Server Device

    Attributes:
        combine (self, devices): Update the server network using input devices.
    """

    def __init__(self, cfg: DictConfig, dset: SubsetWithTargets):
        """Build the server.

        Sets up the labeled-data stream, the cached server images, the GAN parts, the GKD state
        and the LR schedule, each as enabled by ``cfg``.
        """
        super().__init__(cfg, dset)
        self.round = 0

        if self.cfg.env.ratio_labeled:
            # Endless stream of labeled batches that _concat_with_labeled mixes into server training.
            self.labeled_iter = iter(
                self.labeled_data_generator(self._create_labeled_subset(), self.cfg.env.labeled_bs)
            )
        # Every server image stacked once; used as the inputs of the distillation datasets.
        self.images = torch.stack([sample[0] for sample in self.dset])
        if self.cfg.fl.model_save:
            self.c_models = {}  # client snapshots accumulated across rounds by combine()

        if cfg.gan.timing:
            # "fashionmnist" gets a single channel, every other server dataset 3.
            self.nc = n_channels = 1 if self.cfg.dset_s.name == "fashionmnist" else 3
            self.gan_disc_net = SERVER_DISC_MODEL[self.cfg.gan.model](nc=n_channels).to(self.gpu)
            self.gen_optimizer = self._get_optimizer(cfg.gan, self.gan_gen)
            self.s_disc_optimizer = self._get_optimizer(cfg.gan, self.gan_disc_net)
            if cfg.gan.dp.is_dp:
                pass  # no-op placeholder for a differential-privacy option

        if cfg.gan.s.timing:
            # Server classifier and discriminator share the same modules up to layer2.
            for shared_name in ("conv1", "bn1", "layer1", "layer2"):
                setattr(self.disc_net, shared_name, getattr(self.net, shared_name))
            # NOTE(review): disc_net is passed positionally, i.e. as ``rot_net``, so the rot_net
            # branch of _get_optimizer chooses the parameters. Confirm this is intended.
            self.optimizer = self._get_optimizer(cfg.s, self.net, self.disc_net)

        if self.cfg.gkd.is_gkd:
            self.global_model = MODEL[self.cfg.model.model]().to(self.gpu)
            self.global_model_buffer = []  # past server state_dicts used to refresh global_model

        if self.cfg.s.anneal:
            # T_max = max_round * s.ep; combine_em steps the scheduler once per server epoch.
            self.max_anneal = self.cfg.simul.max_round * self.cfg.s.ep
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.max_anneal)

    def compute_gradient_penalty(self, real, fake, disc_net):
        """Return the WGAN-GP penalty ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` is a random per-sample interpolation between ``real`` and ``fake``. Inputs may
        have any rank as long as dimension 0 is the batch.
        """
        batch_size = real.shape[0]
        # One mixing coefficient per sample, broadcast over all remaining dimensions.
        alpha = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
        # Detach the endpoints so the penalty only back-propagates into disc_net.
        interpolates = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
        d_interpolates = disc_net(interpolates)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            # A ones vector makes the vector-Jacobian product the per-sample input gradient
            # (assuming disc_net treats samples independently, e.g. no train-mode BatchNorm).
            # A zeros vector would always give zero gradients.
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def _get_wgan_loss(self, real, fake, disc_net):
        """Return the WGAN-GP critic loss ``mean D(fake) - mean D(real) + gp_lambda * GP``."""
        critic_on_real = disc_net(real).mean()
        critic_on_fake = disc_net(fake).mean()
        penalty = self.compute_gradient_penalty(real, fake, disc_net)
        return critic_on_fake - critic_on_real + self.cfg.gan.gp_lambda * penalty

    def gen_train(self, dset: SubsetWithTargets) -> None:
        """Train the server GAN (``self.gan_gen`` vs ``self.gan_disc_net``) with WGAN-GP on ``dset``.

        Runs ``cfg.gan.load_ep`` generator steps, each preceded by ``cfg.gan.d_iter`` critic steps.
        Fresh Adam optimizers (lr 2e-4, betas (0, 0.9)) are created on every call, so
        ``self.gen_optimizer`` / ``self.s_disc_optimizer`` are not used. With ``cfg.gan.lr_decay``
        both learning rates decay linearly to 0. A generator checkpoint is written after every
        25000 completed generator steps. All tensors go to ``self.gpu``.
        """
        dataloader = DataLoader(dset, batch_size=self.cfg.gan.bs, shuffle=True)
        loop_dataloader = loop(dataloader)
        n_steps = self.cfg.gan.load_ep

        optimG = torch.optim.Adam(self.gan_gen.parameters(), lr=0.0002, betas=(0, 0.9))
        optimD = torch.optim.Adam(self.gan_disc_net.parameters(), lr=0.0002, betas=(0, 0.9))

        if self.cfg.gan.lr_decay:

            def lr_lambda(current_iter):
                # linear decay of the LR factor from 1 to 0 over n_steps scheduler steps
                return 1 - (current_iter / n_steps)

            schedulerD = torch.optim.lr_scheduler.LambdaLR(optimD, lr_lambda)
            schedulerG = torch.optim.lr_scheduler.LambdaLR(optimG, lr_lambda)

        for step in tqdm(range(1, n_steps + 1)):
            # ---- critic: cfg.gan.d_iter WGAN-GP updates ----
            self.gan_disc_net.requires_grad_(True)
            for _ in range(self.cfg.gan.d_iter):
                self.gan_disc_net.zero_grad()
                real = next(loop_dataloader)[0].to(self.gpu)
                batch_size = real.size(0)
                out_real = self.gan_disc_net(real)

                # The generator is not updated in this phase, so no graph is needed for it.
                with torch.no_grad():
                    fake = self.gan_gen(torch.randn(batch_size, 128).to(self.gpu))
                out_fake = self.gan_disc_net(fake)

                gp = cal_grad_penalty(self.gan_disc_net, real.data, fake.data, batch_size, nc=self.nc)

                errD = out_fake.mean() - out_real.mean() + gp
                errD.backward()
                optimD.step()

            # ---- generator: one update against the frozen critic ----
            self.gan_disc_net.requires_grad_(False)
            self.gan_gen.zero_grad()
            fake = self.gan_gen(torch.randn(self.cfg.gan.bs, 128).to(self.gpu))
            errG = -self.gan_disc_net(fake).mean()
            errG.backward()
            optimG.step()

            if self.cfg.gan.lr_decay:
                schedulerD.step()
                schedulerG.step()

            # `step` counts completed generator updates, so the file name matches the saved weights.
            if step % 25000 == 0:
                self._save_gen_checkpoint(step)

        # Leave the critic trainable again for any later use.
        self.gan_disc_net.requires_grad_(True)

    def _save_gen_checkpoint(self, step: int) -> None:
        """Save ``self.gan_gen``'s state dict after ``step`` generator updates under /workspace/model."""
        save_dir = "/workspace/model"
        os.makedirs(save_dir, exist_ok=True)
        torch.save(
            self.gan_gen.state_dict(),
            os.path.join(
                save_dir,
                f"{self.cfg.dset.name}_{self.cfg.fl.n_c_ratio}_{self.cfg.dset_s.name}_{self.cfg.fl.n_s_ratio}_diri{self.cfg.fl.diri_alpha}_gan_gen_{step}_s{self.cfg.simul.seed}.pt",
            ),
        )

    def _create_labeled_subset(self) -> Subset:
        """Return a Subset of ``self.dset`` with the leading labeled part of every class.

        For each class, the first ``int(count * cfg.env.ratio_labeled)`` of its indices are kept.
        """
        ratio = self.cfg.env.ratio_labeled
        per_cls_idxs = (np.where(self.dset.targets == cls)[0] for cls in range(self.cfg.dset.n_cls))
        kept = [idxs[: int(len(idxs) * ratio)] for idxs in per_cls_idxs]
        return Subset(self.dset, np.concatenate(kept))

    def combine(self, devices: Dict[object, TrainableDevice]):
        """Aggregate the client ``devices`` (a mapping from client key to device) into the server.

        ``"avg"`` / ``"prox"`` average the client weights. Modes in ``EM_SET`` distil the client
        ensemble through ``combine_em``, optionally starting from the weight average when
        ``cfg.fl.combine_from == "avg"``. With ``cfg.fl.model_save``, a deep copy of every client seen
        so far is kept in ``self.c_models`` and all of them take part in the ensemble.
        Increments ``self.round``.

        Raises:
            ValueError: for an unknown ``cfg.fl.combine``.
        """
        combine = self.cfg.fl.combine
        if combine in ("avg", "prox"):
            self.combine_avg(devices.values())
        elif combine in EM_SET:
            if self.cfg.fl.combine_from == "avg":
                self.combine_avg(devices.values())

            if self.cfg.fl.model_save:
                for c, device in devices.items():
                    self.c_models[c] = copy.deepcopy(device)
                ensemble = self.c_models.values()
            else:
                ensemble = devices.values()
            self.combine_em(ensemble)
        else:
            raise ValueError(f"combine must be 'avg', 'prox', or one of {sorted(EM_SET)}, but got {combine}")
        self.round += 1

    def combine_avg(self, devices: Iterable[TrainableDevice]):
        """Set the server network to the element-wise mean of the devices' network weights.

        Integer buffers are averaged too. Their float mean is cast back to the buffer dtype when
        ``load_state_dict`` copies it into the network.

        Raises:
            ValueError: if ``devices`` is empty.
        """
        state_dicts = [d.net.state_dict() for d in devices]
        if not state_dicts:
            raise ValueError("combine_avg needs at least one device")
        # Only the accumulator must be a copy; the other state dicts are read, never written.
        w_avg = copy.deepcopy(state_dicts.pop())
        n_models = float(len(state_dicts) + 1)

        for k in w_avg.keys():
            for w in state_dicts:
                w_avg[k] += w[k]
            w_avg[k] = torch.div(w_avg[k], n_models)

        self.net.load_state_dict(w_avg)

    def labeled_data_generator(self, labeled_subset: Subset, bs: int):
        """Yield ``(images, labels)`` batches of size ``bs`` from ``labeled_subset`` forever.

        A new shuffled pass starts whenever the previous one is exhausted.
        """
        while True:
            for batch_images, batch_labels in DataLoader(labeled_subset, batch_size=bs, shuffle=True):
                yield batch_images, batch_labels

    def combine_em(self, devices: Collection[TrainableDevice]):
        """Distil the devices' ensemble into the server network on the server dataset.

        1. GKD with ``cfg.gkd.avg_first``: refresh ``self.global_model`` before distillation.
        2. Build ensemble targets for every server sample with ``_create_ensemble_labels``.
        3. Train ``self.net`` for ``max(1, s.ep - s.ep_decay * round)`` epochs on those targets.
           ``"et"`` adds a KL regulariser; the other modes can mix in labeled batches.
        4. GKD without ``avg_first``: refresh ``self.global_model`` after distillation.
        """
        dl = DataLoader(self.dset, batch_size=self.bs, shuffle=False)

        if self.cfg.gkd.is_gkd and self.cfg.gkd.avg_first:
            self._refresh_global_model()

        if self.cfg.fl.combine == "et":
            new_labels, reg_en_logit_arr = self._create_ensemble_labels(dl=dl, devices=devices)
            train_dset = CustomDataset(self.images, new_labels, reg_en_logit_arr)
        else:
            # Built for every other mode, including plain "em".
            new_labels = self._create_ensemble_labels(dl=dl, devices=devices)
            train_dset = TensorDataset(self.images, new_labels)
        train_dl = DataLoader(train_dset, batch_size=self.cfg.s.bs, shuffle=True)

        self.net.train()
        ep = max(1, self.cfg.s.ep - self.cfg.s.ep_decay * self.round)
        if self.cfg.fl.combine == "et":
            self._distill_et(train_dl, ep)
        else:
            self._distill(train_dl, ep)

        if self.cfg.gkd.is_gkd and not self.cfg.gkd.avg_first:
            self._refresh_global_model()

    def _distill_et(self, train_dl: DataLoader, ep: int) -> None:
        """Run the "et" distillation loop.

        The loss is cross-entropy to the ensemble targets plus ``et_lambda`` times a KL term
        towards the regularisation logits.
        """
        loss_func = nn.CrossEntropyLoss()
        kl_loss_func = nn.KLDivLoss(reduction="batchmean")
        for _ in range(ep):
            for images, labels, reg_logits in train_dl:
                images, labels, reg_logits = (
                    images.to(self.gpu),
                    labels.to(self.gpu),
                    reg_logits.to(self.gpu),
                )
                self.optimizer.zero_grad()
                log_probs = self.net(images)
                loss = loss_func(log_probs, labels)

                # Rows of reg_logits that sum to 0 carry no regularisation target and are skipped.
                if reg_logits.sum() != 0:
                    has_reg = reg_logits.sum(dim=1) != 0
                    probs = F.softmax(log_probs[has_reg], dim=1)
                    reg_log_probs = F.log_softmax(reg_logits[has_reg], dim=1)
                    # NOTE(review): KLDivLoss(input=log reg, target=student probs) is KL(student || reg),
                    # with gradients flowing through the target. Confirm this direction is intended.
                    loss += self.cfg.fl.et_lambda * kl_loss_func(reg_log_probs, probs)

                loss.backward()
                self.optimizer.step()
            if self.cfg.s.anneal:
                self.scheduler.step()

    def _distill(self, train_dl: DataLoader, ep: int) -> None:
        """Run cross-entropy distillation to the ensemble targets.

        Each batch is extended with labeled server samples when ``cfg.env.ratio_labeled`` is set.
        """
        loss_func = nn.CrossEntropyLoss()
        for _ in range(ep):
            for images, labels in train_dl:
                if self.cfg.env.ratio_labeled:
                    images, labels = self._concat_with_labeled(images, labels)
                images, labels = images.to(self.gpu), labels.to(self.gpu)
                self.optimizer.zero_grad()
                loss = torch.nan_to_num(loss_func(self.net(images), labels))
                loss.backward()
                self.optimizer.step()
            if self.cfg.s.anneal:
                self.scheduler.step()

    def _refresh_global_model(self) -> None:
        """Buffer the current server weights and reload ``self.global_model`` (GKD).

        Once the buffer holds more than ``cfg.gkd.M`` snapshots, the oldest is dropped and
        ``global_model`` becomes the mean of the remaining ones. Before that it mirrors the current
        server weights.
        """
        self.global_model_buffer.append(copy.deepcopy(self.net.state_dict()))
        if len(self.global_model_buffer) > self.cfg.gkd.M:
            del self.global_model_buffer[0]
            n_models = float(len(self.global_model_buffer))
            # Shallow copy keeps the state-dict metadata; new tensors are assigned per key, so the
            # buffered snapshots are never modified in place.
            w_avg = copy.copy(self.global_model_buffer[-1])
            for k in list(w_avg.keys()):
                total = sum(w[k] for w in self.global_model_buffer)
                w_avg[k] = torch.div(total, n_models)
            self.global_model.load_state_dict(w_avg)
        else:
            self.global_model.load_state_dict(copy.deepcopy(self.net.state_dict()))

    def _concat_with_labeled(self, images: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append the next batch from ``self.labeled_iter`` to ``images`` and ``labels``.

        When ``labels`` are soft (2-D) and the labeled targets are class indices (1-D), the latter are
        one-hot encoded over ``cfg.dset.n_cls`` classes and cast to ``labels.dtype`` so that both
        halves can be concatenated.
        """
        labeled_images, labeled_labels = next(self.labeled_iter)
        if labels.ndim == 2 and labeled_labels.ndim == 1:
            # in case of EM_SOFT_SET
            labeled_labels = F.one_hot(labeled_labels, num_classes=self.cfg.dset.n_cls).to(labels.dtype)
        images = torch.cat([images, labeled_images], dim=0)
        labels = torch.cat([labels, labeled_labels], dim=0)
        return images, labels

    @torch.no_grad()
    def _create_ensemble_labels(self, dl: DataLoader, devices: TrainableDevice) -> torch.tensor:
        label_arr = torch.zeros((len(self.dset), self.cfg.dset.n_cls))
        odds_arr = torch.zeros((len(self.dset), 1))
        loss_func = nn.CrossEntropyLoss(reduction="none")
        if self.cfg.fl.combine == "logit_var":
            vars_arr = torch.zeros((len(self.dset), 1))
        if self.cfg.fl.combine == "et":
            var_arr = []  # expected: torch.Size([n_device, l_dset])
            var_times_logit_arr = []  # expected: torch.Size([n_device, l_dset, 10])
            probable_label_arr = []  # expected: torch.Size([n_device, l_dset])
        true_label_arr = []

        losses_arr = []  # expected: torch.Size([n_device, l_dset])
        labels_arr = []  # expected: torch.Size([n_device, l_dset, self.cfg.dset.n_cls])

        ensemble_criterion = nn.NLLLoss()
        ensemble_loss = 0

        for c, d in enumerate(devices):
            if self.cfg.fl.combine == "true_ens":
                loss_list = []  # torch.Size([l_dset])

            if self.cfg.fl.combine == "et":
                var_list = []
                var_times_logit_list = []
                probable_label_list = []

            if self.cfg.fl.combine[:3] == "gan":
                disc_odds_arr = []
                cls_loss_arr = []

            if self.cfg.fl.combine == "logit_var":
                var_list = []
            if self.cfg.simul.gpu_on_off:
                d.net.to(self.gpu)
            d.net.eval()
            tmp_label_arr = []
            for images, true in dl:
                if c == len(devices) - 1:
                    true = true.to(self.gpu)
                    true_label_arr.append(true)
                images = images.to(self.gpu)
                logits = d.net(images).detach()  # torch.Size([bs, 10])
                if self.cfg.fl.combine == "em_entropy_soft":
                    softmax = F.softmax(logits, dim=1)  # torch.Size([bs, 10])
                    log_softmax = F.log_softmax(logits, dim=-1)  # torch.Size([bs, 10])
                    H = -(softmax * log_softmax).sum(dim=-1)  # torch.Size([bs])
                    w = torch.exp(-self.cfg.ood.entropy.temp * H)  # torch.Size([bs])
                    labels = w.unsqueeze(-1) * softmax  # torch.Size([bs, 10])
                elif self.cfg.fl.combine[:2] == "df":
                    labels = logits  # torch.Size([bs, 10])
                elif self.cfg.fl.combine == "logit_var":
                    w = torch.var(logits, dim=1)  # torch.Size([bs])
                    var_list.append(w)
                    labels = w.unsqueeze(-1) * logits  # torch.Size([bs, 10])
                elif self.cfg.fl.combine[:3] == "gan":
                    softmax = F.softmax(logits, dim=1)  # torch.Size([bs, 10])
                    disc_outputs = d.disc_net(images)  # torch.Size([bs, 1])
                    if self.cfg.fl.combine == "gan":
                        disc_odds = torch.exp(disc_outputs)

                        if self.cfg.fl.norm_odd == "std":
                            disc_odds = torch.exp((torch.log(disc_odds + 1e-12) - d.shift) / d.std)
                        elif self.cfg.fl.norm_odd == "mean":
                            disc_odds = disc_odds / d.mean
                        disc_odds_arr.append(disc_odds)
                    else:
                        disc_odds = F.sigmoid(disc_outputs)
                        disc_odds_arr.append(disc_odds)

                    if self.cfg.fl.logit_combine:
                        labels = logits * disc_odds.view(-1, 1)  # torch.Size([bs, 10])
                    else:
                        labels = softmax * disc_odds.view(-1, 1)  # torch.Size([bs, 10])

                elif self.cfg.fl.combine == "et":
                    w = torch.var(logits, dim=1)  # torch.Size([bs])
                    var_list.append(w)

                    wl = w.unsqueeze(-1) * logits  # torch.Size([bs, 10])
                    var_times_logit_list.append(wl)

                    label = torch.argmax(logits, dim=1)  # torch.Size([bs])
                    probable_label_list.append(label)

                elif self.cfg.fl.combine == "true_ens":
                    loss_list.append(loss_func(logits, true))
                if self.cfg.fl.combine != "et":
                    tmp_label_arr.append(labels)

            if self.cfg.fl.combine == "logit_var":
                var_arr = torch.cat(var_list, dim=0).to("cpu")
            if self.cfg.fl.combine == "true_ens":
                losses_arr.append(torch.cat(loss_list, dim=0).to("cpu"))  # torch.Size([l_dset])
                labels_arr.append(torch.cat(tmp_label_arr, dim=0).to("cpu"))  # torch.Size([l_dset, 10]
            if self.cfg.fl.combine[:3] == "gan":
                disc_odds_arr = torch.cat(disc_odds_arr, dim=0).to("cpu")  # .tolist()
                odds_arr += disc_odds_arr.view(-1, 1)

            if self.cfg.fl.combine == "et":
                var = torch.cat(var_list, dim=0)  # torch.Size([l_dset])
                var_times_logit = torch.cat(var_times_logit_list, dim=0)  # torch.Size([l_dset, 10])
                probable_label = torch.cat(probable_label_list, dim=0)  # torch.Size([l_dset])

                var_arr.append(var.unsqueeze(0))
                var_times_logit_arr.append(var_times_logit.unsqueeze(0))
                probable_label_arr.append(probable_label.unsqueeze(0))
            if self.cfg.fl.combine != "et":
                tmp_label_arr = torch.cat(tmp_label_arr, dim=0).to("cpu")
                label_arr += tmp_label_arr
            if self.cfg.fl.combine == "logit_var":
                vars_arr += var_arr.view(-1, 1)
            torch.cuda.empty_cache()
            if self.cfg.simul.gpu_on_off:
                d.net.to("cpu")
        if self.cfg.fl.combine == "em_entropy_soft":
            label_arr = F.normalize(label_arr, p=1, dim=1)
        elif self.cfg.fl.combine[:2] == "df":
            label_arr = F.softmax(label_arr / len(devices), dim=1)
        elif self.cfg.fl.combine[:3] == "gan":
            if self.cfg.fl.norm_odd == "one_norm":
                label_arr = label_arr / odds_arr * len(devices)
            if self.cfg.fl.logit_combine:
                label_arr = F.softmax(label_arr / len(devices), dim=1)
            else:
                label_arr = F.normalize(label_arr, p=1, dim=1)
        elif self.cfg.fl.combine == "logit_var":
            label_arr = F.softmax(label_arr / vars_arr)

        elif self.cfg.fl.combine == "et":
            var_arr = torch.cat(var_arr, dim=0)  # torch.Size([n_device, l_dset])
            var_times_logit_arr = torch.cat(var_times_logit_arr, dim=0)  # torch.Size([n_device, l_dset, 10])
            probable_label_arr = torch.cat(probable_label_arr, dim=0)  # torch.Size([n_device, l_dset])

            # ensemble logit
            var_sum = torch.sum(var_arr, dim=0).unsqueeze(1)  # torch.Size([l_dset, 1])
            var_times_logit_sum = torch.sum(var_times_logit_arr, dim=0)  # torch.Size([l_dset, 10])
            en_logit_arr = var_times_logit_sum / var_sum  # torch.Size([l_dset, 10])

            # ensemble label
            en_label_arr = torch.argmax(en_logit_arr, dim=1)  # torch.Size([l_dset, 10])

            # ensemble logit for regularization term
            reg_arr = probable_label_arr != en_label_arr  # torch.Size([n_device, l_dset])
            reg_var_sum = (var_arr * reg_arr).sum(dim=0).unsqueeze(1)  # torch.Size([l_dset, 1])
            reg_var_sum[reg_var_sum == 0] = 1
            reg_var_times_logit_sum = (var_times_logit_arr * reg_arr.unsqueeze(2)).sum(
                dim=0
            )  # torch.Size([l_dset, 10])
            reg_en_logit_arr = reg_var_times_logit_sum / reg_var_sum  # torch.Size([l_dset, 10])

            label_arr = F.softmax(en_logit_arr, dim=1)

        # calculate ensembloe loss between true label and ensemble label
        true_label_arr = torch.cat(true_label_arr, dim=0).cpu()
        if self.cfg.dset.name == self.cfg.dset_s.name:
            ensemble_loss = ensemble_criterion(torch.log(label_arr), true_label_arr)
            # wandb.log({"ensemble_loss": ensemble_loss}, step=self.round)

            acc = true_label_arr.eq(label_arr.argmax(dim=1)).sum().item() / len(true_label_arr) * 100
            # wandb.log({"ensemble_acc": acc}, step=self.round)

        if self.cfg.fl.combine == "et":
            return en_label_arr, reg_en_logit_arr

        return label_arr

    @torch.no_grad()
    def _create_augmented_dataset(self, devices: "TrainableDevice") -> TensorDataset:
        """Build a synthetic, soft-labelled dataset from the server generator.

        Generates ``len(self.dset)`` images with ``self.gan_gen`` (in chunks of
        ``self.cfg.gan.bs``). Each image gets a label built from every device's
        classifier logits, weighted per sample by the odds ``p / (1 - p)``
        computed from that device's ``disc_net`` output ``p``. When
        ``cfg.fl.norm_odd == "one_norm"``, the weighted sum is divided by the
        summed odds. A softmax then turns the result into a probability vector.

        Args:
            devices: Iterable of trained devices, each exposing ``net`` and
                ``disc_net``. It is iterated despite the single-device
                annotation.

        Returns:
            A ``TensorDataset`` of ``(images, soft_labels)`` left on ``self.gpu``.

        Raises:
            ValueError: If ``self.dset`` is empty (there is nothing to generate).
        """
        dset_len = len(self.dset)
        if dset_len == 0:
            raise ValueError("cannot create an augmented dataset from an empty dset")
        batch_size = self.cfg.gan.bs
        n_cls = self.cfg.dset.n_cls
        latent_dim = 128  # latent size fed to self.gan_gen (same hard-coded value as before)

        # Materialise once: devices is walked again for every generated batch.
        devices = list(devices)
        # Mode switches do not change inside the loop, so set them once up front.
        for d in devices:
            d.net.eval()
            d.disc_net.eval()

        images_list = []
        labels_list = []
        for start in range(0, dset_len, batch_size):
            # The final chunk may be shorter. The former `dset_len // batch_size + 1`
            # loop also ran a zero-size batch whenever dset_len was divisible by
            # batch_size.
            n = min(batch_size, dset_len - start)
            # Keep drawing the noise on the default (CPU) generator, then move it,
            # so seeded runs reproduce the same samples as before.
            images = self.gan_gen(torch.randn(n, latent_dim).to(self.gpu))
            labels = torch.zeros(n, n_cls).to(self.gpu)
            odds = torch.zeros(n, 1).to(self.gpu)
            for d in devices:
                logits = d.net(images)
                # NOTE(review): this treats the disc_net output as a probability p in [0, 1];
                # val_em instead applies exp/sigmoid to disc_net outputs. Confirm which is intended.
                p_real = d.disc_net(images)
                # The 1e-12 term keeps the division finite when p == 1.
                odd = (p_real / (1 - p_real + 1e-12)).view(-1, 1)
                odds += odd
                labels += logits * odd
            if self.cfg.fl.norm_odd == "one_norm":
                labels = labels / odds
            labels = F.softmax(labels, dim=1)
            images_list.append(images)
            labels_list.append(labels)

        return TensorDataset(torch.cat(images_list, dim=0), torch.cat(labels_list, dim=0))

    @torch.no_grad()
    def val_em(self, dset, devices):
        """Evaluate the merged (ensemble) prediction of ``devices`` on ``dset``.

        Each device's ``net`` is run over ``dset``. The per-sample outputs are
        summed across devices and merged according to ``self.cfg.fl.combine``:

        * ``"em_entropy_soft"``: softmax weighted by ``exp(-temp * entropy)``,
          then L1-normalised.
        * ``"df*"``: softmax of the mean logit.
        * ``"logit_var"`` / ``"et"``: softmax of the logits averaged with the
          per-sample logit variance as weight.
        * ``"gan*"``: softmax (or logits when ``fl.logit_combine``) weighted by
          discriminator odds. The odds are ``exp(disc_net(x))`` for ``"gan"``
          (optionally rescaled by ``fl.norm_odd``) and
          ``sigmoid(disc_net(x))`` otherwise.

        The method then computes the NLL loss (``ensemble_loss``) and the
        accuracy in percent (``acc``) of the merged prediction against the true
        labels. Their ``wandb`` logging is currently commented out.

        Args:
            dset: Dataset yielding ``(images, true_label)`` pairs.
            devices: Non-empty sequence of devices exposing ``net`` (and
                ``disc_net`` for ``gan*`` modes).

        Raises:
            ValueError: If ``devices`` is empty or ``self.cfg.fl.combine`` is
                not one of the modes listed above.
        """
        combine = self.cfg.fl.combine
        is_df = combine.startswith("df")
        is_gan = combine.startswith("gan")
        is_var_weighted = combine in ("logit_var", "et")
        if not (combine == "em_entropy_soft" or is_df or is_gan or is_var_weighted):
            # Previously this surfaced later as an UnboundLocalError on `labels`.
            raise ValueError(f"val_em does not support fl.combine={combine!r}")
        if len(devices) == 0:
            raise ValueError("val_em needs at least one device")

        dl = DataLoader(dset, batch_size=self.cfg.s.bs)
        n_samples = len(dset)
        label_arr = torch.zeros((n_samples, self.cfg.dset.n_cls))  # summed per-device contributions
        odds_arr = torch.zeros((n_samples, 1))  # summed discriminator odds (gan* only)
        vars_arr = torch.zeros((n_samples, 1))  # summed logit variances (logit_var / et only)

        true_label_arr = []
        ensemble_criterion = nn.NLLLoss()

        for c, d in enumerate(devices):
            # Per-device buffers, filled batch by batch.
            var_list = []
            disc_odds_arr = []
            tmp_label_arr = []

            if self.cfg.simul.gpu_on_off:
                d.net.to(self.gpu)
            # NOTE(review): disc_net is not moved when gpu_on_off is set; this assumes it
            # already lives on self.gpu. Confirm.
            d.net.eval()
            if is_gan:
                d.disc_net.eval()
            for images, true in dl:
                # The loader is re-iterated for every device; record the ground truth only once.
                if c == len(devices) - 1:
                    true_label_arr.append(true)
                images = images.to(self.gpu)
                logits = d.net(images).detach()  # torch.Size([bs, n_cls])
                if combine == "em_entropy_soft":
                    softmax = F.softmax(logits, dim=1)
                    log_softmax = F.log_softmax(logits, dim=-1)
                    entropy = -(softmax * log_softmax).sum(dim=-1)  # torch.Size([bs])
                    # For temp > 0, lower-entropy predictions get a larger weight.
                    w = torch.exp(-self.cfg.ood.entropy.temp * entropy)
                    labels = w.unsqueeze(-1) * softmax
                elif is_df:
                    labels = logits
                elif is_var_weighted:
                    w = torch.var(logits, dim=1)  # torch.Size([bs])
                    labels = w.unsqueeze(-1) * logits
                    var_list.append(w)
                else:  # gan*
                    softmax = F.softmax(logits, dim=1)
                    disc_outputs = d.disc_net(images)  # torch.Size([bs, 1])
                    if combine == "gan":
                        disc_odds = torch.exp(disc_outputs)
                        if self.cfg.fl.norm_odd == "std":
                            disc_odds = torch.exp((torch.log(disc_odds + 1e-12) - d.shift) / d.std)
                        elif self.cfg.fl.norm_odd == "mean":
                            disc_odds = disc_odds / d.mean
                    else:
                        disc_odds = torch.sigmoid(disc_outputs)
                    # Record each batch's odds exactly once. They used to be appended twice,
                    # doubling the concatenated length and breaking `odds_arr +=` below.
                    disc_odds_arr.append(disc_odds)

                    if self.cfg.fl.logit_combine:
                        labels = logits * disc_odds.view(-1, 1)
                    else:
                        labels = softmax * disc_odds.view(-1, 1)

                tmp_label_arr.append(labels)

            if is_var_weighted:
                var_arr = torch.cat(var_list, dim=0).to("cpu")
                vars_arr += var_arr.view(-1, 1)
            if is_gan:
                disc_odds_arr = torch.cat(disc_odds_arr, dim=0).to("cpu")
                odds_arr += disc_odds_arr.view(-1, 1)
            tmp_label_arr = torch.cat(tmp_label_arr, dim=0).to("cpu")
            label_arr += tmp_label_arr
            torch.cuda.empty_cache()
            if self.cfg.simul.gpu_on_off:
                d.net.to("cpu")

        n_devices = len(devices)
        if combine == "em_entropy_soft":
            label_arr = F.normalize(label_arr, p=1, dim=1)
        elif is_df:
            label_arr = F.softmax(label_arr / n_devices, dim=1)
        elif is_var_weighted:
            # Variance-weighted mean logit. dim=1 is what the former implicit-dim call chose
            # for this 2-D tensor; it is spelled out to avoid the deprecation warning.
            label_arr = F.softmax(label_arr / vars_arr, dim=1)
        else:  # gan*
            if self.cfg.fl.norm_odd == "one_norm":
                label_arr = label_arr / odds_arr * n_devices
            if self.cfg.fl.logit_combine:
                label_arr = F.softmax(label_arr / n_devices, dim=1)
            else:
                label_arr = F.normalize(label_arr, p=1, dim=1)

        # Ensemble loss and accuracy of the merged prediction against the true labels.
        true_label_arr = torch.cat(true_label_arr, dim=0)
        ensemble_loss = ensemble_criterion(torch.log(label_arr), true_label_arr)
        # wandb.log({"test_ensemble_loss": ensemble_loss}, step=self.round - 1)

        acc = true_label_arr.eq(label_arr.argmax(dim=1)).sum().item() / len(true_label_arr) * 100
        # wandb.log({"test_acc_s_ens": acc}, step=self.round - 1)
