import collections
import math
import random
from copy import deepcopy

import numpy as np
from tqdm import tqdm
import torch
import os
import wandb
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn

from utils.models import MultiRandomCrop, RandomMixup, RandomCutmix
from torch.utils.data import DataLoader, default_collate, TensorDataset
from fedsd2c.fedsd2c_utils import *
from fedsd2c.img_mix import feat_match_comp, mse_sim, cos_sim, feat_mean_comp, feat_extract, OutputHook
from utils.logger import Logger
from collections import defaultdict
from utils import AverageMeter
from utils.fed_utils import assign_dataset, init_model
from utils.models_gan import LargeGenerator
from utils.models import ConvNet
from torchvision.models import ResNet
from diffusers import AutoencoderKL


def set_bn_eval(m):
    """
    Put a BatchNorm layer into eval mode and freeze its affine parameters.

    Meant to be passed to ``nn.Module.apply``; any module whose class name does
    not contain ``BatchNorm`` is left untouched.
    :param m: module visited by ``apply``
    """
    if 'BatchNorm' not in m.__class__.__name__:
        return
    m.eval()
    # BatchNorm layers created with affine=False expose weight/bias as None,
    # so only freeze the parameters that actually exist.
    for param in (getattr(m, 'weight', None), getattr(m, 'bias', None)):
        if param is not None:
            param.requires_grad = False


def distance_wb(gwr, gws):
    """
    Row-wise cosine distance between two gradient tensors of one layer.

    3-D and 4-D tensors are flattened to ``(shape[0], -1)`` first, 2-D tensors
    are used as they are, and the result is ``sum(1 - cos)`` over the rows.
    1-D tensors (norm-layer parameters and biases) are ignored and contribute 0.
    :param gwr: gradient tensor (called "real" by the caller)
    :param gws: gradient tensor (called "syn" by the caller), same shape as gwr
    :return: scalar tensor with the accumulated distance
    """
    shape = gwr.shape
    n_dim = len(shape)

    if n_dim == 1:  # batchnorm/instancenorm, C; groupnorm x, bias
        return torch.tensor(0, dtype=torch.float, device=gwr.device)

    if n_dim == 4:  # conv, out*in*h*w
        flat_shape = (shape[0], shape[1] * shape[2] * shape[3])
    elif n_dim == 3:  # layernorm, C*h*w
        flat_shape = (shape[0], shape[1] * shape[2])
    else:  # linear, out*in: already one row per output unit
        flat_shape = None

    if flat_shape is not None:
        gwr = gwr.reshape(flat_shape)
        gws = gws.reshape(flat_shape)

    # The small constant keeps the division finite when a row has zero norm.
    cosine = torch.sum(gwr * gws, dim=-1) / (torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)
    return torch.sum(1 - cosine)


def match_loss(gw_syn, gw_real, metric="cos"):
    """
    Distance between two lists of per-layer gradients.

    :param gw_syn: sequence of gradient tensors of the synthetic data
    :param gw_real: sequence of gradient tensors of the real data; its length
        decides how many layers are compared
    :param metric: 'ours' sums ``distance_wb`` layer by layer, 'mse' is the
        squared error of the concatenated gradients and 'cos' is one minus the
        cosine similarity of the concatenated gradients
    :return: the distance (a scalar tensor; the int 0 for 'ours' on empty input)
    :raises ValueError: if ``metric`` is not one of the names above
    """
    n_layers = len(gw_real)

    if metric == 'ours':
        dis = 0
        for ig in range(n_layers):
            dis += distance_wb(gw_real[ig], gw_syn[ig])
        return dis

    if metric in ('mse', 'cos'):
        # Both metrics work on all gradients flattened into one long vector.
        gw_real_vec = torch.cat([gw_real[ig].reshape((-1)) for ig in range(n_layers)], dim=0)
        gw_syn_vec = torch.cat([gw_syn[ig].reshape((-1)) for ig in range(n_layers)], dim=0)
        if metric == 'mse':
            return torch.sum((gw_syn_vec - gw_real_vec) ** 2)
        return 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (
                torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001)

    # Raise instead of calling exit(): a helper must not terminate the process.
    raise ValueError('unknown distance function: %s' % metric)


class FedClient(object):

    def __init__(self, args, client_id, dataset_id='MNIST'):
        """
        Client in the federated learning for FedD3
        :param args: run configuration; the ``client_instance_*``, ``fedsd2c_*``,
            and ``gpu_id`` attributes are read here
        :param client_id: Id of the client
        :param dataset_id: Dataset name for the application scenario
        :raises ValueError: if ``dataset_id`` is not one of the supported datasets
        """
        # Metadata
        self._id = client_id
        self._dataset_id = dataset_id
        self.args = args

        # (length, width, channel) of the images of every supported dataset.
        image_shapes = {
            'MNIST': (28, 28, 1),
            'FashionMNIST': (28, 28, 1),
            'CIFAR10': (32, 32, 3),
            'CIFAR100': (32, 32, 3),
            'SVHN': (32, 32, 3),
            'TINYIMAGENET': (64, 64, 3),
            'Imagenette': (128, 128, 3),
            'openImg': (256, 256, 3),
        }
        self._n_class = assign_dataset(dataset_id)[0]
        if dataset_id not in image_shapes:
            # Fail loudly: the previous exit(0) reported success to the shell.
            raise ValueError('unexpected dataset: {!r}'.format(dataset_id))
        self._image_length, self._image_width, self._image_channel = image_shapes[dataset_id]

        # Initialize the parameters in the local client
        self._epoch = args.client_instance_n_epoch
        self._batch_size = args.client_instance_bs
        self._lr = args.client_instance_lr
        self._momentum = 0.9
        self.num_workers = 2
        self.loss_rec = []
        self.n_data = 0
        self.mixup_alpha = args.client_instance_mixup_alpha
        self.cutmix_alpha = args.client_instance_cutmix_alpha

        # Local dataset
        self._train_data = None
        self._test_data = None
        self._sd_data = None

        # Local distilled dataset
        self._distill_data = {'x': [], 'y': []}
        self._rest_data = {'x': [], 'y': [], 'dist': [], 'pred': []}
        self.shuffle_distill_data = True

        # FedSD2C parameters
        self.input_size = self._image_width
        self.num_crop = self.args.fedsd2c_num_crop
        self.factor = self.args.fedsd2c_factor
        self.mipc = self.args.fedsd2c_mipc
        self.ipc = self.args.fedsd2c_ipc
        self.temperature = self.args.fedsd2c_temperature
        self.use_ld = self.args.fedsd2c_use_ld
        self.iter_mode = self.args.fedsd2c_iter_mode

        self.iterations_per_layer = self.args.fedsd2c_iteration
        self.jitter = self.args.fedsd2c_jitter
        self.sre2l_lr = self.args.fedsd2c_lr
        self.l2_scale = self.args.fedsd2c_l2_scale
        self.tv_l2 = self.args.fedsd2c_tv_l2
        self.r_bn = self.args.fedsd2c_r_bn
        self.r_c = self.args.fedsd2c_r_c
        self.first_bn_multiplier = 10.
        self.inputs_init = self.args.fedsd2c_inputs_init
        self.hard_label = self.args.fedsd2c_hard_label
        self.zero_init_scaler = self.args.fedsd2c_zero_init_scaler
        self.mask_ratio = self.args.fedsd2c_mask_ratio
        self.patch_size = self.args.fedsd2c_patch_size
        self.filling_methods = self.args.fedsd2c_filling_methods
        self.compress = self.args.fedsd2c_compress

        self.gm_iteration = self.args.fedsd2c_gm_iter
        self.gm_metric = self.args.fedsd2c_gm_metric

        self.sd_train_interval = self.args.fedsd2c_sd_trn_interval
        self.sd_alpha = self.args.fedsd2c_sd_alpha

        self.noise_type = self.args.fedsd2c_noise_type
        self.noise_s = self.args.fedsd2c_noise_s
        self.noise_p = self.args.fedsd2c_noise_p

        # Lazily filled caches / statistics
        self._cls_record = None
        self._img_mean = None
        self._feat_mean = None

        self.sd_step = 0
        self.local_step = 0

        # Gan
        self.nz = 512
        if self.inputs_init == "gan":
            # Only the "gan" initialisation owns a generator; the VAE used by
            # the "vae" initialisations is created inside feat_mixup_stage.
            self.generator = LargeGenerator(nz=self.nz, ngf=64, img_size=self._image_width, nc=self._image_channel)
        self.normalizer = transforms.Normalize(means[self._dataset_id], stds[self._dataset_id])
        self.reset_gan = False
        self.r_adv = self.args.fedsd2c_r_adv

        # Training on GPU
        gpu = args.gpu_id
        self._device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() and gpu != -1 else "cpu")

    @property
    def img_mean(self):
        """
        Mean image of the local training set, computed on first access and cached.
        :return: tensor of shape (image_channel, image_length, image_width)
        """
        if self._img_mean is not None:
            return self._img_mean

        batches = DataLoader(self._train_data, batch_size=self._batch_size, shuffle=False, drop_last=False)
        self._img_mean = torch.zeros(self._image_channel, self._image_length, self._image_width)
        # Accumulate the per-pixel sum batch by batch, then normalise once.
        for images, _ in batches:
            self._img_mean += images.sum(dim=[0])
        self._img_mean /= len(self._train_data)
        return self._img_mean

    def feat_mean(self, model, pos="conv1"):
        """
        Return the result of ``feat_mean_comp`` on the local training data, cached.

        The cache remembers which ``pos`` it was computed for, so asking for a
        different position recomputes instead of returning a stale tensor.
        :param model: model handed to ``feat_mean_comp``
        :param pos: position identifier forwarded to ``feat_mean_comp``
        :return: the cached tensor, moved to the CPU
        """
        # A value stored without a recorded position is treated as valid for
        # whatever position is requested, as it was before.
        cached_pos = getattr(self, "_feat_mean_pos", pos)
        if self._feat_mean is None or cached_pos != pos:
            self._feat_mean = feat_mean_comp(model, self._train_data, self._device, pos).cpu()
            self._feat_mean_pos = pos

        return self._feat_mean

    def load_train(self, data):
        """
        Attach the client's local (possibly Non-IID) training dataset.

        The dataset is stored by reference, not copied.
        :param data: Local dataset for training.
        """
        self._train_data = data
        self.n_data = len(data)

    def load_test(self, data):
        """
        Attach a private deep copy of the test dataset to the client.
        :param data: Dataset for testing.
        """
        self._test_data = deepcopy(data)

    def load_cls_record(self, cls_record):
        """
        Store the per-class sample counts of the local data.

        Keys are converted to ``int`` so that labels can be looked up directly.
        :param cls_record: class number record
        """
        normalized = {}
        for label, count in cls_record.items():
            normalized[int(label)] = count
        self._cls_record = normalized

    def rec_distill_data(self, data):
        """
        Keep a reference to the given data in ``self._sd_data``.
        :param data: object to store (not copied)
        """
        self._sd_data = data

    def train(self, model: nn.Module):
        """
        Client trains the model on local dataset

        SGD with momentum and weight decay 1e-4 is run for ``self._epoch``
        epochs; the learning rate is set at the start of every epoch by the
        callable returned from ``lr_cosine_policy``. Depending on the
        ``client_instance_*`` flags, BatchNorm layers or all non-"fc"
        parameters are frozen first, and mixup / cutmix is applied per batch.
        :param model: model waited to be trained
        :return: (trained model, average loss over all steps)
        """
        model.train()
        # The three freezing options are mutually exclusive; the first one set wins.
        if self.args.client_instance_freeze_bn:
            model.apply(set_bn_eval)
        elif self.args.client_instance_rbn:
            for param_name, param in model.named_parameters():
                if "fc" not in param_name:
                    param.requires_grad = False
        elif self.args.client_instance_rbn_fc:
            for param_name, param in model.named_parameters():
                if "fc" not in param_name and param_name != "conv1.weight":
                    param.requires_grad = False
        model.to(self._device)

        # Probability with which a mixing transform is applied to a batch.
        if self.args.client_instance_identity_aug:
            print("Identity aug added!")
            mix_prob = 2 / 3
        else:
            mix_prob = 1.0

        batch_mixers = []
        if self.mixup_alpha > 0.0:
            batch_mixers.append(RandomMixup(self._n_class, p=mix_prob, alpha=self.mixup_alpha))
        if self.cutmix_alpha > 0.0:
            batch_mixers.append(RandomCutmix(self._n_class, p=mix_prob, alpha=self.cutmix_alpha))

        collate_fn = None
        if batch_mixers:
            pick_mixer = transforms.RandomChoice(batch_mixers)

            def collate_fn(batch):
                return pick_mixer(*default_collate(batch))

        train_loader = DataLoader(self._train_data, batch_size=self._batch_size, shuffle=True, drop_last=True,
                                  collate_fn=collate_fn)

        optimizer = torch.optim.SGD(model.parameters(), lr=self._lr, momentum=self._momentum, weight_decay=1e-4)
        lr_scheduler = lr_cosine_policy(self._lr, 0, self._epoch)
        criterion = nn.CrossEntropyLoss()

        # Training process
        overall_loss = AverageMeter()
        pbar = tqdm(range(self._epoch))
        for epoch in pbar:
            epoch_loss = AverageMeter()
            lr_scheduler(optimizer, epoch, epoch)
            for x, y in train_loader:
                with torch.no_grad():
                    b_x = x.to(self._device)  # Tensor on GPU
                    b_y = y.to(self._device)  # Tensor on GPU

                # Gradients are enabled explicitly so training also works when
                # the caller is inside a no_grad context.
                with torch.enable_grad():
                    loss = criterion(model(b_x), b_y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                loss_value = loss.data.cpu().item()
                overall_loss.update(loss_value)
                epoch_loss.update(loss_value)
                if self.args.using_wandb:
                    wandb.log({
                        f"{self._id}C local_loss": loss_value,
                        "iteration": self.local_step,
                    })
                    self.local_step += 1

            current_lr = optimizer.state_dict()['param_groups'][0]['lr']
            pbar.set_description(
                'Epoch: %d| Train loss: %.4f | lr: %.4f ' % (epoch, epoch_loss.avg, current_lr))

        return model, overall_loss.avg

    def test(self, model):
        """
        Server tests the model on test dataset.
        """
        test_loader = DataLoader(self._test_data, batch_size=self._batch_size, shuffle=False)
        model.to(self._device)
        accuracy_collector = 0
        for step, (x, y) in enumerate(test_loader):
            with torch.no_grad():
                b_x = x.to(self._device)  # Tensor on GPU
                b_y = y.to(self._device)  # Tensor on GPU

                test_output = model(b_x)
                pred_y = torch.max(test_output, 1)[1].to(self._device).data.squeeze()
                accuracy_collector = accuracy_collector + sum(pred_y == b_y)
        accuracy = accuracy_collector / len(self._test_data)

        return accuracy.cpu().numpy()

    def get_ipc(self, label):
        """
        Number of images per class to keep for ``label``.

        Without ``use_ld`` this is simply ``self.ipc``. With ``use_ld`` the
        budget ``ipc * 200 / 10`` is split across classes in proportion to the
        recorded class counts and rounded up.
        :param label: integer class label (a key of the loaded class record)
        :return: images per class for this label
        """
        if not self.use_ld:
            return self.ipc

        assert self._cls_record is not None
        label_count = self._cls_record[label]
        total_count = sum(self._cls_record.values())
        return math.ceil(self.ipc * 200 / 10 * label_count / total_count)

    def coreset_stage(self, model):
        """
        Select a per-class coreset from the local training data.

        For every class yielded by ``CategoryDataset`` the images are cropped
        with ``MultiRandomCrop``, ``selector_coreset`` is asked for
        ``get_ipc(label) * factor`` of them using a frozen copy of ``model``,
        and the selection is passed through ``mix_images``. The result replaces
        ``self._distill_data``. When ``args.fedmix_src`` is "other" or "hard",
        the additional outputs of ``selector_coreset`` (``ret_all=True``) are
        appended to ``self._rest_data``.
        :param model: model used for the selection; a deep copy is used, so the
            argument itself is not modified
        :return: (list of images, list of label tensors,
            dict mapping label to the mean of the returned distances)
        """
        scorer = deepcopy(model)
        scorer.eval()
        for param in scorer.parameters():
            param.requires_grad = False

        # Work on copies so the transform of the shared training set is untouched.
        local_set = deepcopy(self._train_data)
        local_set.dataset = deepcopy(local_set.dataset)
        local_set.dataset.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize([self._image_length, self._image_width]),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        per_class = CategoryDataset(local_set, mipc=self.mipc, ipc=self.ipc * self.factor, shuffle=True,
                                    seed=self.args.sys_i_seed, min_ipc_scale=self.args.fedsd2c_ipc_min)

        ret_x, ret_y, ret_score = [], [], {}

        cropper = MultiRandomCrop(self.num_crop, self.input_size, 1, 1)
        scorer.to(self._device)

        for images, labels in per_class:
            with torch.no_grad():
                images = cropper(images)
                ipc = self.get_ipc(labels[0].item())
                keep_rest = self.args.fedmix_src == "other" or self.args.fedmix_src == "hard"
                selected = selector_coreset(
                    ipc * self.factor,
                    scorer,
                    images,
                    labels,
                    self.input_size,
                    device=self._device,
                    m=self.num_crop,
                    descending=self.args.descending_dist,
                    ret_all=keep_rest
                )
                if keep_rest:
                    images, dists, rest_images, rest_dists, rest_preds = selected
                    rest_label = labels[0].cpu().item()
                    rest = self._rest_data
                    rest['x'].extend(item.squeeze() for item in torch.split(rest_images.cpu(), 1))
                    rest['y'].extend(rest_label for _ in range(rest_images.shape[0]))
                    rest['dist'].extend(item.squeeze() for item in torch.split(rest_dists.cpu(), 1))
                    rest['pred'].extend(item.squeeze() for item in torch.split(rest_preds.cpu(), 1))
                else:
                    images, dists = selected
                images = mix_images(images, self.input_size, 1, images.shape[0]).cpu()

            # One entry per selected image; every label entry is its own tensor.
            ret_x.extend(item.squeeze() for item in torch.split(images.cpu(), 1))
            ret_y.extend(labels[0].cpu().clone() for _ in range(images.shape[0]))
            ret_score[labels[0].item()] = torch.mean(dists).cpu().item()

        self._distill_data['x'] = ret_x
        self._distill_data['y'] = ret_y

        return ret_x, ret_y, ret_score

    def random_stage(self, model):
        """
        Baseline selection: pick images of every class uniformly at random.

        For each class yielded by ``CategoryDataset`` up to ``get_ipc(label)``
        images are drawn without replacement. The selection replaces
        ``self._distill_data`` and is also appended to ``self._rest_data``,
        where the label is stored in the 'dist' and 'pred' slots as well.
        :param model: unused; present so the method has the same signature as
            ``coreset_stage``
        :return: (list of images, list of label tensors, dict label -> 0)
        """
        _dataset = deepcopy(self._train_data)
        _dataset.dataset = deepcopy(_dataset.dataset)
        _dataset.dataset.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize([self._image_length, self._image_width]),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        _dataset = CategoryDataset(_dataset, mipc=self.mipc, ipc=self.ipc, shuffle=True,
                                   seed=self.args.sys_i_seed, min_ipc_scale=self.args.fedsd2c_ipc_min)

        ret_x = []
        ret_y = []
        ret_score = {}

        for images, labels in _dataset:
            with torch.no_grad():
                ipc = self.get_ipc(labels[0].item())
                indices = torch.randperm(len(images))[:ipc]
                images = images[indices].cpu()

            # (ipc, 3, H, W)
            n_selected = images.shape[0]
            ret_x.extend([data.squeeze() for data in torch.split(images, 1)])
            # One label per selected image. Using self.ipc here produced a
            # label list of a different length than the image list whenever
            # get_ipc() differs from self.ipc or the class has fewer images.
            ret_y.extend([labels[0].cpu().clone() for _ in range(n_selected)])
            ret_score[labels[0].item()] = 0
            self._rest_data['x'].extend([data.squeeze() for data in torch.split(images, 1)])
            self._rest_data['y'].extend([labels[0].cpu().item() for _ in range(n_selected)])
            self._rest_data['dist'].extend([labels[0].cpu() for _ in range(n_selected)])
            self._rest_data['pred'].extend([labels[0].cpu() for _ in range(n_selected)])

        self._distill_data['x'] = ret_x
        self._distill_data['y'] = ret_y

        return ret_x, ret_y, ret_score

    def mixup_stage(self, model):
        """
        Mix the selected coreset images group by group and compress them.

        ``self._distill_data`` is walked in groups of ``ipc * factor`` images.
        Each (shuffled) group is summed with ``fedmix_batch_size - 1`` partner
        images, averaged, clipped, and handed to ``feat_match_comp``. Partners
        are random permutations for the "random" method, or are chosen from the
        argsort of a score ("mse_sim", "im_cos_sim", "<pos>_feat_cos_sim").
        With ``args.fedmix_src == "other"`` the partners come from
        ``self._rest_data`` entries of the same label.
        :param model: model passed to the scoring and compression helpers
        :return: (list of compressed images, list of labels)
        :raises NotImplementedError: for an unknown ``args.fedmix_method``
        """
        logger = Logger().get_logger()

        method = self.args.fedmix_method
        n_mix = self.args.fedmix_batch_size
        group_size = self.ipc * self.factor
        n_total = len(self._distill_data['x'])

        ret_x, ret_y = [], []
        ori_img = None
        mixed_img = None
        loss_list = []

        for start in range(0, n_total, group_size):
            order = np.random.permutation(group_size).tolist()
            subset_x = torch.stack([self._distill_data['x'][start + k] for k in order])
            subset_y = torch.stack([self._distill_data['y'][start + k] for k in order])
            mixed_x = subset_x.clone()
            mixed_y = subset_y.clone()

            if self.args.fedmix_src == "other":
                # Partners are taken from the stored rest data of the same label.
                print("Using other as combining src")
                corres_idxs = np.where(np.array(self._rest_data['y']) == subset_y[0].item())[0]
                subset_x = torch.stack([self._rest_data['x'][idx] for idx in corres_idxs])
                subset_y = torch.tensor([self._rest_data['y'][idx] for idx in corres_idxs])
                print(len(corres_idxs))
                print(torch.mean(subset_y.float()))
                assert "sim" in method

            if method == "random":
                for _ in range(n_mix - 1):
                    partner = torch.randperm(mixed_x.shape[0])
                    mixed_x += subset_x[partner]
                    mixed_y += subset_y[partner]
            else:
                if method == "mse_sim":
                    score = mse_sim(model, mixed_x.to(self._device), subset_x.to(self._device)).cpu()
                elif method == "im_cos_sim":
                    score = cos_sim(subset_x, subset_x, self.img_mean).cpu()
                elif "feat_cos_sim" in method:
                    pos = method.split("_")[0]
                    feat_x = feat_extract(model, subset_x, self._device, pos).cpu()
                    score = cos_sim(feat_x, feat_x, self.feat_mean(model, pos)).cpu()
                else:
                    raise NotImplementedError(f"{method}")
                # Column j of the ranking holds, per image, its j-th partner.
                ranking = torch.argsort(score, descending=False)
                for j in range(n_mix - 1):
                    mixed_x += subset_x[ranking[:, j]]
                    mixed_y += subset_y[ranking[:, j]]

            mixed_x = clip(mixed_x / n_mix, dataset=self._dataset_id)
            # The labels are not mixed: the source labels are kept as they are.
            mixed_y = subset_y

            comp_x, losses = feat_match_comp(model,
                                             x=mixed_x,
                                             factor=self.factor,
                                             iteration=self.iterations_per_layer,
                                             dataset=self._dataset_id,
                                             device=self._device)
            ret_x.extend(item.squeeze() for item in torch.split(comp_x, 1))
            ret_y.extend(item.squeeze() for item in torch.split(mixed_y, 1))
            if ori_img is None and mixed_img is None:
                # Remember the first pair for the wandb preview below.
                ori_img = subset_x[0]
                mixed_img = comp_x[0]

            logger.info("------------idx {} / {}----------".format(start, n_total))
            logger.info("Compressed loss avg: {}, final: {}".format(np.mean(losses), losses[-1]))
            loss_list.append(losses)

        if self.args.using_wandb:
            wandb.log({
                f"C{self._id} original image": wandb.Image(
                    ori_img.cpu().numpy().transpose((1, 2, 0))),
                "iteration": 0,
            })
            wandb.log({
                f"C{self._id} mixup image": wandb.Image(mixed_img.cpu().numpy().transpose((1, 2, 0))),
                "iteration": 0,
            })

            loss_mean = np.array(loss_list).mean(axis=0).tolist()
            loss_std = np.array(loss_list).std(axis=0).tolist()
            for step, (avg, std) in enumerate(zip(loss_mean, loss_std)):
                wandb.log({
                    f"C{self._id} comp loss avg": avg,
                    f"C{self._id} comp loss std": std,
                    "iteration": step,
                })

        torch.cuda.empty_cache()
        return ret_x, ret_y

    def feat_mixup_stage(self, model):
        logger = Logger()
        logger = logger.get_logger()

        ret_x = []
        ori_x = []
        ret_y = []
        ret_z = []
        ori_imgs = None
        mixed_img = None
        loss_list = []
        loss_dict_list = {}
        model = deepcopy(model)
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        loss_r_feature_layers = []
        if isinstance(model, ResNet):
            loss_r_feature_layers.append(OutputHook(model.maxpool))
            for name, module in model.named_modules():
                if name in [f"layer{j}" for j in range(1, 5)]:
                    print(f"Adding hook to {name}")
                    loss_r_feature_layers.append(OutputHook(module))
            loss_r_feature_layers.append(OutputHook(model.avgpool))
        elif isinstance(model, ConvNet):
            for j in range(4):
                print(f"Adding hook to {j} pool")
                loss_r_feature_layers.append(OutputHook(model.layers["pool"][j]))
        else:
            raise NotImplementedError()
        loss_r_bn_layers = []
        if self.r_bn > 0:
            for module in model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    loss_r_bn_layers.append(BNFeatureHook(module))

        synset, batch_size = self._build_synset()
        synloader = torch.utils.data.DataLoader(synset, batch_size=batch_size, shuffle=False)
        max_psnr = 0
        psnrs = 0
        cnts = 0
        bss = 0
        vae =  AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
        for p in vae.parameters():
            p.requires_grad = False
        for i, batch in enumerate(synloader):
            img1, img2, y = batch
            z = None
            if self.inputs_init == "easy":
                inputs = img1.clone().to(self._device)
            elif self.inputs_init == "hard":
                inputs = img2.clone().to(self._device)
            elif self.inputs_init == "random":
                inputs = torch.randn_like(img1).to(self._device)
            elif self.inputs_init == "mixup":
                inputs = clip(((img1.clone() + img2.clone()) / 2), dataset=self._dataset_id).to(self._device)
            elif self.inputs_init == "fourier":
                inputs = img2.clone().to(self._device)
            elif self.inputs_init == "gan":
                z = torch.randn(size=(img1.shape[0], self.nz)).to(self._device)
            elif "vae" in self.inputs_init:
                with torch.no_grad():
                    vae.to(self._device)
                    if "random" in self.inputs_init:
                        z = torch.randn_like(z).to(self._device)
                    elif "fourier" in self.inputs_init:
                        z = vae.encode(denormalize(img2).to(self._device)).latent_dist.mode().clone().detach()
                    else:
                        z = vae.encode(denormalize(img1).to(self._device)).latent_dist.mode().clone().detach()
            else:
                raise NotImplementedError()
            s = img1.shape
            targets = y.to(self._device)
            entropy_criterion = nn.CrossEntropyLoss()
            if self.inputs_init == "gan":
                if self.reset_gan:
                    reset_model(self.generator)
                self.generator.train()
                self.generator.to(self._device)
                optimizer = torch.optim.Adam(self.generator.parameters(), lr=1e-3, betas=(0.5, 0.999))
                lr_scheduler = lr_cosine_policy(1e-3, 0, self.iterations_per_layer)
            elif "vae" in self.inputs_init:
                z.requires_grad = True
                optimizer = torch.optim.AdamW([z], lr=self.sre2l_lr, betas=(0.5, 0.9), eps=1e-8)
                lr_scheduler = lr_cosine_policy(self.sre2l_lr, 0, self.iterations_per_layer)
            else:
                if self.compress:
                    downsampler = nn.AvgPool2d(int(math.sqrt(self.factor)), stride=int(math.sqrt(self.factor)))
                    inputs = downsampler(inputs).clone()
                inputs.requires_grad = True
                optimizer = torch.optim.AdamW([inputs], lr=self.sre2l_lr, betas=(0.5, 0.9), eps=1e-8)
                lr_scheduler = lr_cosine_policy(self.sre2l_lr, 0, self.iterations_per_layer)

            losses = []
            if "mixup" in self.inputs_init:
                best_inputs = inputs.data.cpu().clone()
            else:
                best_inputs = None
            best_z = None
            best_cost = 1e4
            with torch.no_grad():
                model = model.to(self._device)
                # easy_sample = img1.clone().to(self._device)
                # hard_sample = img2.clone().to(self._device)

                model(img1.clone().to(self._device))
                easy_feat_lists = [mod.r_feature.clone().detach() for mod in loss_r_feature_layers]
                model(img2.clone().to(self._device))
                hard_feat_lists = [mod.r_feature.clone().detach() for mod in loss_r_feature_layers]
            losses = []
            loss_dicts = {}
            bound_sample_num = 0
            for iteration in range(self.iterations_per_layer):
                lr_scheduler(optimizer, iteration, iteration)

                # inputs gen
                if self.inputs_init == "gan":
                    inputs = self.generator(z)
                    inputs = self.normalizer(inputs)
                elif "vae" in self.inputs_init:
                    inputs = vae.decode(z).sample
                    inputs = self.normalizer(inputs)

                aug_function = transforms.Compose([
                    transforms.RandomResizedCrop(self.input_size),
                    transforms.RandomHorizontalFlip(),
                ])

                im = img1.clone().to(self._device)
                if self.compress:
                    resized_inputs = F.interpolate(inputs, size=(s[2], s[3]), mode='bilinear')
                else:
                    resized_inputs = inputs
                # _inputs = aug_function(torch.cat([resized_inputs, im], dim=0))
                _inputs = torch.cat([resized_inputs, im], dim=0)
                im = _inputs[inputs.shape[0]:]
                with torch.no_grad():
                    model(im)
                    easy_feat_lists = [mod.r_feature.clone().detach() for mod in loss_r_feature_layers]

                # _inputs = aug_function(inputs)
                _inputs = _inputs[:inputs.shape[0]]

                outputs = model(_inputs)
                input_feat_lists = [mod.r_feature for mod in loss_r_feature_layers]
                key_words = self.args.fedmix_method.split("_")
                loss = 0
                loss_dict = {}
                for key_word in key_words:
                    cf = key_word.split("-")
                    if "gram" in cf:
                        loss_fn = gram_mse_loss
                    elif "factorization" in cf:
                        loss_fn = factorization_loss
                    else:
                        loss_fn = mse_loss
                    if "hard" in cf:
                        target_feat_lists = hard_feat_lists
                    elif "easy" in cf:
                        target_feat_lists = easy_feat_lists
                    elif "mean" in cf:
                        target_feat_lists = [(ef + hf) / 2 for ef, hf in zip(easy_feat_lists, hard_feat_lists)]
                    elif "cat" in cf:
                        target_feat_lists = [torch.cat([ef, hf], dim=-1) for ef, hf in
                                             zip(easy_feat_lists, hard_feat_lists)]
                    elif "bound" in cf:
                        with torch.no_grad():
                            hard_logits, easy_logits = model.fc(torch.flatten(hard_feat_lists[-1], 1)), model.fc(
                                torch.flatten(easy_feat_lists[-1], 1))
                            hard_pred, easy_pred = torch.argmax(hard_logits, dim=1), torch.argmax(easy_logits,
                                                                                                  dim=1)
                            bound_sample = (easy_pred == targets)
                            bound_sample_num = bound_sample.sum().item()
                            # logits - target_logits
                            alpha = (easy_logits - torch.gather(easy_logits, 1, targets[:, None])) / \
                                    (torch.gather(hard_logits, 1, targets[:, None]) - hard_logits + 1e-8)
                            alpha = torch.min(torch.where(alpha > 0, alpha, 1e8), dim=1)[0]
                            alpha = alpha * 0.95
                            avgpool_feat = (easy_feat_lists[-1] + alpha * hard_feat_lists[-1]) / (1 + alpha)
                            avgpool_feat = avgpool_feat * bound_sample + (~bound_sample) * easy_feat_lists[-1]
                            target_feat_lists = easy_feat_lists[:-1] + [avgpool_feat]
                    else:
                        raise NotImplementedError()
                    if "conv" in cf:
                        loss_feat_conv = loss_fn(input_feat_lists[0], target_feat_lists[0], reduction="mean")
                        loss += loss_feat_conv
                        loss_dict["feat_conv"] = loss_feat_conv.item()
                    if "avgpool" in cf:
                        loss_feat_pool = loss_fn(input_feat_lists[-1], target_feat_lists[-1], reduction="mean")
                        loss += loss_feat_pool
                        loss_dict["feat_pool"] = loss_feat_pool.item()
                    for j in range(1, 5):
                        if f"layer{j}" in cf:
                            loss_feat_layer = loss_fn(input_feat_lists[j], target_feat_lists[j], reduction="mean")
                            loss += loss_feat_layer
                            loss_dict[f"feat_layer{j}"] = loss_feat_layer.item()
                if self.r_bn > 0:
                    rescale = [self.first_bn_multiplier] + [1. for _ in range(len(loss_r_bn_layers) - 1)]
                    loss_r_bn = sum(
                        [mod.r_feature * rescale[idx] for (idx, mod) in enumerate(loss_r_bn_layers)])
                    loss += self.r_bn * loss_r_bn
                    loss_dict["r_bn"] = loss_r_bn.item()
                if self.r_c > 0:
                    loss_r_c = entropy_criterion(outputs, targets)
                    loss += self.r_c * loss_r_c
                    loss_dict["r_ce"] = loss_r_c.item()
                if self.r_adv > 0:
                    loss_r_adv = -mse_loss(resized_inputs, img1.clone().to(self._device), reduction="mean")
                    loss += self.r_adv * loss_r_adv
                    loss_dict["r_adv"] = loss_r_adv.item()
                assert loss != 0

                if best_cost > loss.item() or iteration >= 0:
                    best_inputs = resized_inputs.data.cpu().clone()
                    if z is not None:
                        best_z = z.data.detach().cpu().clone()
                    best_cost = loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                inputs.data = clip(inputs.data, dataset=self._dataset_id)
                losses.append(loss.item())
                for k, v in loss_dict.items():
                    if k not in loss_dicts:
                        loss_dicts[k] = [v]
                    else:
                        loss_dicts[k].append(v)
            if len(losses) == 0:
                losses = [0]

            mixed_y = y.clone()
            ret_x.extend([data.squeeze() for data in torch.split(best_inputs, 1)])
            ori_x.extend([data.squeeze() for data in torch.split(img1.data.cpu().clone(), 1)])
            ret_y.extend([data.squeeze() for data in torch.split(mixed_y, 1)])
            if "vae" in self.inputs_init:
                ret_z.extend([data.squeeze() for data in torch.split(best_z, 1)])
            if ori_imgs is None and mixed_img is None:
                ori_imgs = [img1[0], img2[0]]
                mixed_img = best_inputs[0]

            psnr = 10 * torch.log10(
                1 / (denormalize(best_inputs.cpu().clone()) - denormalize(img1.cpu().clone())).pow(2).mean(
                    dim=[-3, -2, -1]))
            psnrs += psnr.sum().item()
            cnts += psnr.numel()
            psnr = psnr.max()
            logger.info("------------idx {} / {}----------".format(i * batch_size, len(self._distill_data['x'])))
            logger.info("loss avg: {}, final: {}, ".format(np.mean(losses), losses[-1]) + ", ".join(
                [f"{k}: {v[-1]}" for k, v in loss_dicts.items()]))
            logger.info("max psnr: {} ".format(psnr.item()))
            bss += bound_sample_num
            logger.info("bound sample: {} sum: {}".format(bound_sample_num, bss))
            max_psnr = max(max_psnr, psnr.item())
            loss_list.append(losses)
            for k, v in loss_dicts.items():
                if k not in loss_dict_list:
                    loss_dict_list[k] = [v]
                else:
                    loss_dict_list[k].append(v)

        if self.args.using_wandb:
            for j in range(len(ori_imgs)):
                wandb.log({
                    f"C{self._id} original image {j}": wandb.Image(
                        ori_imgs[j].cpu().numpy().transpose((1, 2, 0))),
                    "iteration": 0,
                })
            wandb.log({
                f"C{self._id} mixup image": wandb.Image(mixed_img.cpu().numpy().transpose((1, 2, 0))),
                "iteration": 0,
            })

            loss_mean = np.array(loss_list).mean(axis=0).tolist()
            loss_std = np.array(loss_list).std(axis=0).tolist()
            for i, loss in enumerate(loss_mean):
                wandb.log({
                    f"C{self._id} comp loss avg": loss,
                    f"C{self._id} comp loss std": loss_std[i],
                    "iteration": i,
                })
            for k, v in loss_dict_list.items():
                lm = np.array(v).mean(axis=0).tolist()
                ls = np.array(v).std(axis=0).tolist()
                for i, loss in enumerate(lm):
                    wandb.log({
                        f"C{self._id} {k} avg": loss,
                        f"C{self._id} {k} std": ls[i],
                        "iteration": i,
                    })
            wandb.log({
                "max psnr": max_psnr,
                "mean psnr": psnrs / cnts,
                "client_id": self._id
            })
        if self.args.fedsd2c_store_images:
            dir_path = os.path.join(self.args.sys_res_root, self.args.save_name)
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            path = os.path.join(dir_path, f"client{self._id}_")
            torch.save(torch.stack(ret_x), path + "images.pt")
            torch.save(torch.stack(ret_y), path + "labels.pt")
            torch.save(torch.stack(ori_x), path + "ori_images.pt")
            if "vae" in self.inputs_init:
                torch.save(torch.stack(ret_z), path + "latents.pt")
        del vae
        torch.cuda.empty_cache()
        return ret_x, ret_y

    def decode_latents(self, latents):
        """
        Decode VAE latents into normalized image tensors.

        The latents are decoded in chunks of ``self.ipc * self.factor`` with the
        pretrained "stabilityai/sdxl-vae" autoencoder. Unless ``self.noise_type`` is
        "None", each chunk is perturbed before decoding as
        ``(1 - self.noise_p) * z + self.noise_s * noise``, where ``noise`` is drawn
        from a NumPy generator seeded with ``self.args.sys_i_seed``.

        :param latents: Batch of latents; it is sliced along its first axis and moved with ``.to``
        :return: List holding one decoded image per latent, each passed through ``self.normalizer``
        :raises NotImplementedError: If ``self.noise_type`` is not "gaussian", "laplace" or "None"
        """
        # Fail fast: reject a bad configuration before the pretrained VAE is loaded.
        if self.noise_type not in ("gaussian", "laplace", "None"):
            raise NotImplementedError(f"Unsupported noise_type: {self.noise_type!r}")

        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
        for p in vae.parameters():
            p.requires_grad = False
        vae.eval()
        vae.to(self._device)

        bs = self.ipc * self.factor
        rng = np.random.default_rng(self.args.sys_i_seed)
        # Sampler for the configured distribution; None means "decode without noise".
        sample_noise = {"gaussian": rng.normal, "laplace": rng.laplace}.get(self.noise_type)

        samples = []
        with torch.no_grad():
            for kk in range(0, len(latents), bs):
                z = latents[kk:kk + bs].to(self._device)
                if sample_noise is not None:
                    # Noise is drawn flat and reshaped so the random stream consumed
                    # per chunk stays the same as z.numel() draws.
                    noise = torch.tensor(sample_noise(size=z.numel()), dtype=z.dtype)
                    noise = noise.reshape(z.shape).to(self._device) * self.noise_s
                    z = (1 - self.noise_p) * z + noise
                sample = vae.decode(z).sample.detach().clone().cpu()
                sample = self.normalizer(sample)
                # Split along the batch axis only, so other unit-sized axes are kept.
                samples.extend(sample.unbind(0))

        # Release the autoencoder and cached GPU memory once decoding is finished.
        del vae
        torch.cuda.empty_cache()
        return samples

    @property
    def all_select(self):
        """
        Selection used when the client uploads its whole original dataset.
        :return: The object stored in ``self._train_data``, returned as-is (no copy)
        """
        full_train_set = self._train_data
        return full_train_set

    def save_distilled_dataset(self, exp_dir='client_models', res_root='results'):
        """
        The client saves the distilled images in corresponding directory

        The data in ``self._distill_data`` is written with ``torch.save`` to
        ``<res_root>/<exp_dir>/clients/<client id>_distilled_img.pt``.

        :param exp_dir: Experiment directory name
        :param res_root: Result directory root for saving the result files
        """
        model_save_dir = os.path.join(res_root, exp_dir, 'clients')
        # exist_ok removes the check-then-create race: creation no longer fails if
        # the directory appears between an existence check and makedirs.
        os.makedirs(model_save_dir, exist_ok=True)
        # Format the id instead of concatenating so integer ids work like string ids.
        file_name = f"{self._id}_distilled_img.pt"
        torch.save(self._distill_data, os.path.join(model_save_dir, file_name))

    def _build_synset(self):
        """
        Pair each group of distilled samples with same-class samples from the rest data.

        ``self._distill_data`` is walked in groups of ``self.ipc * self.factor``
        samples. Each group is shuffled, and the class of its first sample selects
        the candidate entries of ``self._rest_data``. Candidates are ordered by their
        'dist' value and the first ones (cycled if there are too few) become the
        partners of the group.

        :return: Tuple ``(dataset, bs)``: the ``SynDataset`` built from the pairs, and a
            batch size equal to ``self.ipc * self.factor`` for the "random"/"label"
            iter modes or to the number of groups for the "ipc" iter mode
        :raises ValueError: If ``self.iter_mode`` is not "random", "label" or "ipc"
        """
        assert self.args.fedmix_src in ["other", "hard"]
        # Reject an unknown mode up front; previously it surfaced only at the return
        # statement as an unbound local variable.
        if self.iter_mode not in ("random", "label", "ipc"):
            raise ValueError(f"Unsupported iter_mode: {self.iter_mode!r}")

        group_size = self.ipc * self.factor
        # Loop-invariant, so converted once instead of once per group.
        rest_labels = np.array(self._rest_data['y'])

        dx1, dx2, dy = [], [], []
        for i in range(0, len(self._distill_data['x']), group_size):
            idxs = np.random.permutation(group_size).tolist()
            subset_x = torch.stack([self._distill_data['x'][i + idx] for idx in idxs])
            subset_y = torch.stack([self._distill_data['y'][i + idx] for idx in idxs])
            # The first sample's label stands for the whole group.
            label = subset_y[0].item()

            corres_idxs = np.where(rest_labels == label)[0]
            rest_x = torch.stack([self._rest_data['x'][idx] for idx in corres_idxs])
            rest_dists = torch.stack([self._rest_data['dist'][idx] for idx in corres_idxs])
            rest_preds = torch.stack([self._rest_data['pred'][idx] for idx in corres_idxs])

            # Prefer candidates whose prediction equals the label. The argmax must be
            # taken per sample: without `dim` it yields one index into the flattened
            # tensor, which made this filter meaningless.
            # NOTE(review): assumes each 'pred' entry is a per-class score vector or an
            # already decoded class index - confirm against the code that fills _rest_data.
            pred_labels = rest_preds.argmax(dim=-1) if rest_preds.dim() > 1 else rest_preds
            correct = torch.nonzero(pred_labels == label, as_tuple=True)[0]
            if correct.numel() != 0:
                rest_x, rest_dists = rest_x[correct], rest_dists[correct]

            order = torch.argsort(rest_dists, descending=(not self.args.descending_dist))[:subset_x.shape[0]]
            if order.shape[0] < subset_x.shape[0]:
                # Too few candidates: cycle through them until the group is filled.
                order = order.repeat((subset_x.shape[0] // order.shape[0]) + 1)[:subset_x.shape[0]]
            rest_x = rest_x[order]

            dx1.append(subset_x)
            dx2.append(rest_x)
            dy.append(subset_y)
        dx1 = torch.stack(dx1, dim=0)
        dx2 = torch.stack(dx2, dim=0)
        dy = torch.stack(dy, dim=0)

        if self.iter_mode == "ipc":
            bs = dx1.shape[0]
        else:  # "random" or "label"
            bs = group_size

        return SynDataset(dx1, dx2, dy, self.iter_mode, fourier="fourier" in self.inputs_init,
                          fourier_lambda=self.args.fourier_lambda,
                          fourier_src=self.args.fourier_src, dataset=self._dataset_id), bs
