from .base import BaseSynthesis
from .hooks import DeepInversionHook
from .utils import ImagePool, DataIter
from utils import *
from utils_kd import *
from tqdm import tqdm


class FEDFTGSynthesizer(BaseSynthesis):
    """Generator-based data-free synthesizer (FedFTG-style).

    Each ``synthesize`` call re-initialises the generator, then jointly optimises
    the generator and a noise batch ``z``.

    The minimised loss is

        oh * CE(teacher(x), targets) + beta_div * diversity - beta_md * KL(teacher(x), student(x))

    Targets are labels drawn at random. The lowest-loss batch is stored in an
    ``ImagePool``, and ``self.data_iter`` is rebuilt over everything collected so far.
    """

    # Number of token positions passed (twice) to the teacher/student forward.
    # NOTE(review): presumably 196 ViT patches + 1 [CLS] token (224x224 input,
    # 16x16 patches) -- confirm against the model definition.
    NUM_TOKENS = 197

    def __init__(self, teacher, mdl_list, student, generator, nz, num_classes, img_size, save_dir, iterations=1,
                 lr_g=1e-3, synthesis_batch_size=128, sample_batch_size=128,
                 adv=0, bn=0, oh=0, balance=0, criterion=None, transform=None,
                 normalizer=None,
                 # TODO: FP16 and distributed training
                 autocast=None, use_fp16=False, distributed=False, args=None):
        """
        Args:
            teacher: model whose first output is used as logits for the
                cross-entropy term against the sampled targets.
            mdl_list: iterable of models. A ``DeepInversionHook`` is registered
                on each of their ``BatchNorm2d`` layers. They are put in eval
                mode and have their grads cleared during ``synthesize``.
            student: model whose disagreement with the teacher is maximised.
            generator: maps a noise batch of shape ``(B, nz)`` to images.
                It is moved to CUDA here.
            nz: noise dimension.
            num_classes: targets are drawn uniformly from ``[0, num_classes)``.
            img_size: image shape; must have exactly 3 entries (asserted).
            save_dir: root directory of the ``ImagePool``.
            iterations: optimisation steps per ``synthesize`` call.
            lr_g: Adam learning rate for the generator and the noise.
            synthesis_batch_size: number of images generated per call.
            sample_batch_size: batch size of the loaders built over the pool.
            oh: weight of the teacher cross-entropy term.
            adv, bn, balance, criterion: stored, but not used by any method
                of this class.
            transform: transform applied when reading images from the pool.
            normalizer: called as ``normalizer(x)`` on generator output and as
                ``normalizer(x, True)`` before pooling (presumably the inverse
                transform -- confirm). May be ``None``.
            autocast, use_fp16: stored only; FP16 is not implemented.
            distributed: wrap the pool dataset in a ``DistributedSampler``.
            args: namespace providing ``T``, ``beta_div``, ``beta_md``,
                ``batchonly`` and ``batchused``. Required despite the default.
        """
        super(FEDFTGSynthesizer, self).__init__(teacher, student)

        self.mdl_list = mdl_list
        self.args = args
        assert len(img_size) == 3, "image size should be a 3-dimension tuple"
        self.img_size = img_size
        self.iterations = iterations
        self.save_dir = save_dir
        self.transform = transform

        self.nz = nz
        self.num_classes = num_classes
        if criterion is None:
            criterion = kldiv
        self.criterion = criterion
        self.normalizer = normalizer
        self.synthesis_batch_size = synthesis_batch_size
        self.sample_batch_size = sample_batch_size

        self.data_pool = ImagePool(root=self.save_dir)
        self.data_iter = None
        # scaling factors
        self.lr_g = lr_g
        self.adv = adv
        self.bn = bn
        self.oh = oh
        self.balance = balance
        # generator
        self.generator = generator.cuda().train()
        self.distributed = distributed
        self.use_fp16 = use_fp16
        self.autocast = autocast  # for FP16
        self.data_loader = None
        self.diversity_loss = DiversityLoss(metric='l2').cuda()
        self.KL = KLDiv(T=args.T)

        # NOTE(review): ``self.hooks`` is assumed to be created by BaseSynthesis.
        for m_list in self.mdl_list:
            for m in m_list.modules():
                if isinstance(m, nn.BatchNorm2d):
                    self.hooks.append(DeepInversionHook(m))

    def synthesize(self, cur_ep=None):
        """Optimise the generator, bank the best batch and refresh ``self.data_iter``.

        Args:
            cur_ep: forwarded to ``ImagePool.add`` as ``batch_id``.

        Raises:
            RuntimeError: if no iteration produced a finite loss (including
                ``iterations == 0``), so there is no batch to store.
        """
        # Classifiers in eval mode; only the generator (and z) are optimised here.
        self.student.eval()
        self.generator.train()
        self.teacher.eval()
        for m in self.mdl_list:
            m.eval()

        # inf (not a large finite number) so that any finite loss is accepted
        # as the first "best" batch.
        best_cost = float('inf')
        best_inputs = None
        z = torch.randn(size=(self.synthesis_batch_size, self.nz)).cuda()
        z.requires_grad = True
        targets = torch.randint(low=0, high=self.num_classes, size=(self.synthesis_batch_size,))
        targets = targets.sort()[0].cuda()
        reset_model(self.generator)

        optimizer = torch.optim.Adam([{'params': self.generator.parameters()}, {'params': [z]}], self.lr_g,
                                     betas=(0.5, 0.999))
        bar_iter = tqdm(range(self.iterations), file=sys.stdout)
        for _ in bar_iter:
            optimizer.zero_grad()
            inputs = self.generator(z)
            if self.normalizer:
                inputs = self.normalizer(inputs)

            # One position-id tensor per step, shared by both index arguments
            # and both models (previously rebuilt four times per step).
            token_ids = torch.arange(self.NUM_TOKENS).repeat(inputs.shape[0], 1).to(inputs.device)
            t_out, _, _ = self.teacher(inputs, token_ids, token_ids)
            s_out, _, _ = self.student(inputs, token_ids, token_ids)

            loss_div = self.diversity_loss(noises=z, layer=inputs)
            loss_oh = F.cross_entropy(t_out, targets)
            loss_md = self.KL(t_out, s_out)
            # Minimise CE + diversity term, *maximise* teacher/student disagreement.
            loss = self.oh * loss_oh + self.args.beta_div * loss_div - self.args.beta_md * loss_md
            # NOTE(review): backward also accumulates .grad on teacher/student
            # parameters; only mdl_list grads are cleared below.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=self.generator.parameters(), max_norm=10)

            for m in self.mdl_list:
                m.zero_grad()
            optimizer.step()

            torch.cuda.empty_cache()

            # ``inputs`` was produced before this step's update, so it matches ``loss``.
            loss_value = loss.item()
            if loss_value < best_cost:
                best_cost = loss_value
                best_inputs = inputs.detach()

        if best_inputs is None:
            raise RuntimeError(
                "FEDFTGSynthesizer.synthesize produced no finite loss "
                "(iterations=%d); nothing to add to the data pool" % self.iterations)

        self.student.train()

        if self.normalizer:
            best_inputs = self.normalizer(best_inputs, True)

        self.data_pool.add(best_inputs, batch_id=cur_ep, targets=targets, his=True)
        dst = self.data_pool.get_dataset(transform=self.transform, labeled=True)

        train_sampler = torch.utils.data.distributed.DistributedSampler(dst) if self.distributed else None
        loader = torch.utils.data.DataLoader(
            dst, batch_size=self.sample_batch_size, shuffle=(train_sampler is None),
            num_workers=8, pin_memory=True, sampler=train_sampler)
        self.data_iter = DataIter(loader)
        del z, targets

    def sample(self):
        """Return synthetic training data.

        NOTE(review): the return shape depends on ``args``.
        - When ``batchonly`` is True and ``batchused`` is False, a fresh
          normalised image batch is generated and returned as a bare tensor.
          The generator is left in eval mode in this case.
        - Otherwise, ``(images, labels)`` is taken from ``self.data_iter``.
        Callers must handle both shapes.
        """
        if self.args.batchonly == True and self.args.batchused == False:
            self.generator.eval()
            z = torch.randn(size=(self.sample_batch_size, self.nz)).cuda()
            images = self.normalizer(self.generator(z))
            return images
        images, labels = self.data_iter.next()
        return images, labels

    def get_data(self, labeled=True):
        """Build (and cache in ``self.data_loader``) a shuffled loader over every pooled image."""
        # All images collected so far in this run.
        datasets = self.data_pool.get_dataset(transform=self.transform, labeled=labeled)
        self.data_loader = torch.utils.data.DataLoader(
            datasets, batch_size=self.sample_batch_size, shuffle=True,
            num_workers=8, pin_memory=True, )
        return self.data_loader


def reset_model(model):
    """Re-initialise the parameters of ``model``'s conv/linear and BatchNorm2d layers in place.

    For every submodule (``model`` itself included):
    - ``Conv2d`` / ``ConvTranspose2d`` / ``Linear``: weights are drawn from
      N(0, 0.02) and biases (when present) are set to 0.
    - ``BatchNorm2d`` with learnable affine parameters: weights are drawn from
      N(1, 0.02) and biases are set to 0. Non-affine BatchNorm2d layers have no
      parameters and are skipped.

    Running statistics are left untouched.
    """
    for m in model.modules():
        if isinstance(m, (nn.ConvTranspose2d, nn.Linear, nn.Conv2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        # ``affine=False`` BatchNorm has weight/bias == None; init would crash.
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)


class DiversityLoss(nn.Module):
    """Diversity loss that rewards a generator for spreading out its outputs.

    For a batch of N samples it builds two (N, N) matrices:
    - pairwise distances between the flattened generator outputs (``metric``);
    - pairwise mean-squared distances between the flattened input noises.

    It returns ``exp(-mean(noise_dist * layer_dist))``. With
    ``diversity_loss_type='div2'``, each product term is additionally weighted
    by ``exp(l1_dist(y_input))``.
    """

    def __init__(self, metric):
        """
        Args:
            metric: distance between generated samples. One of ``'l1'``
                (mean absolute difference), ``'l2'`` (mean squared
                difference) or ``'cosine'`` (1 - cosine similarity).
        """
        super().__init__()
        self.metric = metric
        self.cosine = nn.CosineSimilarity(dim=2)

    def compute_distance(self, tensor1, tensor2, metric):
        """Reduce two broadcastable 3-D tensors to distances over dimension 2.

        Raises:
            ValueError: if ``metric`` is not ``'l1'``, ``'l2'`` or ``'cosine'``.
        """
        if metric == 'l1':
            return torch.abs(tensor1 - tensor2).mean(dim=(2,))
        elif metric == 'l2':
            return torch.pow(tensor1 - tensor2, 2).mean(dim=(2,))
        elif metric == 'cosine':
            return 1 - self.cosine(tensor1, tensor2)
        else:
            raise ValueError(metric)

    def pairwise_distance(self, tensor, how):
        """Return the (N, N) matrix of distances between the rows of an (N, D) tensor.

        Entry ``[i, j]`` is the distance between row ``j`` and row ``i``. For
        l1/l2 the broadcast (N, N, D) difference tensor is materialised.
        """
        n_data = tensor.size(0)
        tensor1 = tensor.expand((n_data, n_data, tensor.size(1)))
        tensor2 = tensor.unsqueeze(dim=1)
        return self.compute_distance(tensor1, tensor2, how)

    @staticmethod
    def _flatten(tensor):
        """Collapse every dimension after the first (batch) one; 1-D/2-D input is returned as is."""
        if tensor.dim() > 2:
            return tensor.reshape(tensor.size(0), -1)
        return tensor

    def forward(self, noises, layer, y_input=None, diversity_loss_type=None):
        """Compute the diversity loss.

        Args:
            noises: generator input, shape (N, ...). Flattened per sample.
            layer: generator output, shape (N, ...). Flattened per sample.
            y_input: required when ``diversity_loss_type == 'div2'``; shape
                (N, ...). Flattened per sample.
            diversity_loss_type: ``'div2'`` to weight by ``exp(l1_dist(y_input))``;
                any other value gives the plain loss.

        Raises:
            ValueError: if ``'div2'`` is requested without ``y_input``, or if
                ``self.metric`` is unknown.
        """
        layer = self._flatten(layer)
        noises = self._flatten(noises)
        layer_dist = self.pairwise_distance(layer, how=self.metric)
        noise_dist = self.pairwise_distance(noises, how='l2')
        if diversity_loss_type == 'div2':
            if y_input is None:
                raise ValueError("y_input is required when diversity_loss_type='div2'")
            y_input_dist = self.pairwise_distance(self._flatten(y_input), how='l1')
            return torch.exp(-torch.mean(noise_dist * layer_dist * torch.exp(y_input_dist)))
        return torch.exp(-torch.mean(noise_dist * layer_dist))