import os
from itertools import chain
import random
import sys
import math

import torch
from torch import nn, softmax
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data.sampler as sp
import torch.distributions as d
import torch.nn.functional as F
import numpy as np
from advertorch.attacks import *
from torchvision import transforms

import classifier
import data
import loss
import kd_loss
from simclr.modules import NT_Xent
from simclr.modules.identity import Identity
from simclr import SimCLR

sys.path.append('..')
import setup
from data import *
from utils import *
# import loss
import time


class SubstituteTrainerSeeker:
    def __init__(self, opt,
                 victim, substitute, data_gen,
                 sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                 eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                 source, x_list, y_list, labels,
                 labeled_bs=200,
                 strategy='every', loop=None, n_epochs=10,
                 save=False, load=False):
        """Store the models, datasets and hyper-parameters for a training run.

        Side effects:
            * ``new_sub_dataset.items`` is modified in place: its first
              ``int(opt.epoch_val_rate * len(new_sub_dataset))`` entries are
              moved into a fresh ``SubDataset`` kept as ``self.val_dataset``.
            * When ``load`` is true, the substitute's weights are replaced by
              the checkpoint found under ``opt.data_dir`` (see below).

        Args:
            opt: Options object. Attributes read here: ``gen_epoch``,
                ``epoch_val_rate`` and, only when ``load`` is true,
                ``data_dir``, ``victim_dataset``, ``surrogate_dataset``,
                ``pre_train_sub``, ``sub_model``, ``seed``, ``sub_eval_loop``.
            victim: Victim model; stored as-is.
            substitute: Substitute model; stored as-is. Must provide
                ``load_state_dict`` when ``load`` is true.
            data_gen: Data generator model; stored as-is.
            sub_dataset, next_dataset, unlabeled_dataset, eval_dataset,
            surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset:
                Datasets; stored as-is.
            new_sub_dataset: Dataset exposing ``items`` and ``len()``; split
                into validation/training parts as described above.
            source: Stored as ``self.source``; also used in the checkpoint
                directory name when ``load`` is true.
            x_list, y_list, labels: Plot parameters; stored as-is.
            labeled_bs: Batch size for labeled data. The unlabeled,
                class-diversity and SimCLR batch sizes are each set to twice
                this value.
            strategy, loop, n_epochs: Stored as-is.
            save: Stored as ``self.save``.
            load: If true, load substitute weights from
                ``<data_dir>/checkpoints/<victim_dataset>_<surrogate_dataset>/``
                ``<pre_train_sub>_<source>/<sub_model>_<seed>_<sub_eval_loop>``.
        """
        self.opt = opt
        self.victim = victim
        self.substitute = substitute
        self.data_gen = data_gen

        self.sub_dataset = sub_dataset
        self.new_sub_dataset = new_sub_dataset
        self.next_dataset = next_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.val_dataset = SubDataset()
        self.eval_dataset = eval_dataset
        self.surrogate_eval_dataset = surrogate_eval_dataset
        self.simclr_dataset = simclr_dataset
        self.aug_dataset = aug_dataset
        self.div_dataset = div_dataset

        self.source = source

        # Every non-labeled batch size is derived from the labeled one (2x).
        self.labeled_bs = labeled_bs
        self.unlabeled_bs = labeled_bs * 2
        self.cls_div_bs = labeled_bs * 2
        self.simclr_bs = labeled_bs * 2

        self.strategy = strategy
        self.loop = loop
        self.n_epochs = n_epochs
        self.gen_epoch = opt.gen_epoch

        # Only a random horizontal flip (torchvision default p=0.5) is applied;
        # crop, rotation, vertical flip, colour jitter and grayscale are not
        # enabled.
        self.augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
            ]
        )

        # Split off the validation set: the leading `val_size` items become the
        # validation set and are removed from the new-query dataset.
        val_size = int(self.opt.epoch_val_rate * len(self.new_sub_dataset))
        self.val_dataset.items = self.new_sub_dataset.items[:val_size]
        self.new_sub_dataset.items = self.new_sub_dataset.items[val_size:]

        # Plot parameters.
        self.x_list = x_list
        self.y_list = y_list
        self.labels = labels

        self.save = save

        # Optionally restore the substitute model from a checkpoint.
        if load:
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.opt.sub_eval_loop}'
            save_path = os.path.join(save_dir, save_name)
            # Deserialize onto the CPU so a checkpoint saved from a CUDA model
            # still loads on a host without that GPU; load_state_dict copies
            # the values onto whatever device the parameters already live on.
            state_dict = torch.load(save_path, map_location='cpu')
            self.substitute.load_state_dict(state_dict)


    def train(self):
        print(f'Size of query dataset: {len(self.sub_dataset)}')
        print(f'Size of new query dataset: {len(self.new_sub_dataset)}')
        print(f'Size of diversity dataset: {len(self.div_dataset)}')
        # preparation
        if self.opt.use_gpu:
            self.substitute.cuda()
            self.data_gen.cuda()
            # if self.source == 'fusiongan':
            #     self.substitute_projected.projector.cuda()
        if len(self.sub_dataset) > 0:
            weighted_sampler = self.get_weighted_sampler()
            dataloader = torch.utils.data.DataLoader(
                self.sub_dataset,
                batch_size=self.labeled_bs, #100
                # shuffle = True,
                num_workers=4,
                sampler=weighted_sampler
            )
            all_dataloader = torch.utils.data.DataLoader(
                self.sub_dataset,
                batch_size=self.labeled_bs, #50
                shuffle=True,
                num_workers=4
            )
        new_dataloader = torch.utils.data.DataLoader(
            self.new_sub_dataset,
            batch_size=300, #100
            shuffle=True,
            num_workers=4
        )
        aug_dataloader = torch.utils.data.DataLoader(
            self.aug_dataset,
            batch_size=self.unlabeled_bs,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )
        unlabeled_dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=self.unlabeled_bs,
            shuffle=True,
            num_workers=4,
            drop_last=True
        )
        fuse_unlabeled_dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=self.cls_div_bs * self.opt.n_fuse,
            shuffle=True,
            num_workers=4,
            drop_last=True
        )
        simclr_dataloader = torch.utils.data.DataLoader(
            self.simclr_dataset,
            batch_size=self.simclr_bs,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )
        if self.opt.div_epoch > 0:
            div_dataloader = torch.utils.data.DataLoader(
                self.div_dataset,
                batch_size=self.labeled_bs,
                shuffle=True,
                num_workers=4,
                drop_last=True,
            )
        if len(self.next_dataset) > 0:
            next_dataloader = torch.utils.data.DataLoader(
                self.next_dataset,
                batch_size=50,
                shuffle=True,
                num_workers=4
            )

        if self.opt.sub_optim == 'adam':
            substitute_optimizer = torch.optim.Adam(
                self.substitute.parameters(),
                lr=self.opt.sub_lr,
                # weight_decay=self.opt.sub_weight_decay
            )
        else:
            substitute_optimizer = torch.optim.SGD(
                self.substitute.parameters(),
                lr=self.opt.sub_lr,
                momentum=0.9,
                weight_decay=self.opt.sub_weight_decay
            )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(substitute_optimizer, T_max=50)
        # scheduler = torch.optim.lr_scheduler.StepLR(substitute_optimizer, 40, gamma=0.1)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(substitute_optimizer, milestones=[10, 30, 40], gamma=0.2)

        # simclr_optimizer = torch.optim.Adam(simclr_substitute.parameters(), lr=self.opt.online_simclr_lr)
        data_gen_optimizer = torch.optim.Adam(
            self.data_gen.parameters(),
            lr=self.opt.gen_lr
        )
        # data_gen_next_optimizer = torch.optim.Adam(
        #     self.data_gen.parameters(),
        #     lr=self.opt.gen_next_lr
        # )
        self.substitute.train()
        self.data_gen.train()
        # if self.source == 'fusiongan':
        #     self.substitute_projected.projector.train()

        ckpt_dir = f'{self.opt.data_dir}checkpoints/'
        writer = SummaryWriter(f'runs/{self.opt.caption}')
        sub_ckpt_path = f'{ckpt_dir}substitute_{self.strategy}'
        gen_ckpt_path = f'{ckpt_dir}data_gen_{self.strategy}'

        acc, fidelity, kd_loss = self.evaluate()
        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
        print(f'[start] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise}')

        train_epoch = self.n_epochs


        # update D and G in same loop

        if self.source == 'fusiongan':
            if self.opt.continue_train and self.loop == self.opt.n_loop - 1:
                train_epoch = self.opt.pseudo_epoch

            # define loss functions
            l1_loss = nn.L1Loss()
            mse_loss = nn.MSELoss()
            cos_loss = nn.CosineEmbeddingLoss()
            kldiv_loss = nn.KLDivLoss(reduction='mean')
            ce_loss = nn.CrossEntropyLoss()
            div_func = loss.DivLoss()
            cw_loss_func = loss.CWLoss()

            softmax = nn.Softmax(dim=1)

            min_loss = 1000.
            min_loss_epoch = 0
            min_div = 1000.
            min_div_epoch = 0

            unlabeled_data_iter = iter(unlabeled_dataloader)
            simclr_data_iter = iter(simclr_dataloader)
            aug_data_iter = iter(aug_dataloader)

            # query data - victim prob
            # for i, (_, _, victim_prob, _) in enumerate(all_dataloader):
            #     if self.opt.use_gpu:
            #         victim_prob = victim_prob.cuda()
            #     if i == 0:
            #         query_vic_prob = victim_prob
            #         label_dist = nn.functional.one_hot(victim_prob.max(1)[1],
            #                                            num_classes=self.opt.victim_n_classes).float()
            #     else:
            #         query_vic_prob += victim_prob
            #         label_dist += nn.functional.one_hot(victim_prob.max(1)[1],
            #                                            num_classes=self.opt.victim_n_classes).float()
            # # da /= len(self.sub_dataset)
            # query_vic_prob /= (len(self.sub_dataset) // self.labeled_bs)
            # query_vic_prob = query_vic_prob.mean(dim=0)
            # label_dist /= (len(self.sub_dataset) // self.labeled_bs)
            # label_dist = label_dist.mean(dim=0)
            # print(f'query data-victim prob-avg:{query_vic_prob}')
            # print(f'query data-victim label dist-avg:{label_dist}')


            # calculate avg. prob. for public dataset with victim
            # for i, (unlabeled_inputs, _, _, _) in enumerate(unlabeled_dataloader):
            #     if self.opt.use_gpu:
            #         unlabeled_inputs = unlabeled_inputs.cuda()
            #     with torch.no_grad():
            #         unlabeled_probs = softmax(self.victim(unlabeled_inputs))
            #     if i == 0:
            #         da = unlabeled_probs
            #         label_dist = nn.functional.one_hot(unlabeled_probs.max(1)[1],
            #                                            num_classes=self.opt.victim_n_classes).float()
            #     else:
            #         da += unlabeled_probs
            #         label_dist += nn.functional.one_hot(unlabeled_probs.max(1)[1],
            #                                            num_classes=self.opt.victim_n_classes).float()
            #     # print(da, label_dist)
            # # da /= len(self.sub_dataset)
            # da /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
            # da = da.mean(dim=0)
            # label_dist /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
            # label_dist = label_dist.mean(dim=0)
            # print(f'public data-victim prob-avg:{da}')
            # print(f'public data-victim label dist-avg:{label_dist}')

            # public data - substitute prob.
            # for i, (unlabeled_inputs, _, _, _) in enumerate(unlabeled_dataloader):
            #     if self.opt.use_gpu:
            #         unlabeled_inputs = unlabeled_inputs.cuda()
            #     with torch.no_grad():
            #         unlabeled_probs = softmax(self.substitute(unlabeled_inputs))
            #     if i == 0:
            #         da = unlabeled_probs
            #         label_dist = nn.functional.one_hot(unlabeled_probs.max(1)[1],
            #                                            num_classes=self.opt.victim_n_classes).float()
            #     else:
            #         da += unlabeled_probs
            #         label_dist += nn.functional.one_hot(unlabeled_probs.max(1)[1],
            #                                             num_classes=self.opt.victim_n_classes).float()
            #
            # # da /= len(self.sub_dataset)
            # da /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
            # da = da.mean(dim=0)
            # label_dist /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
            # label_dist = label_dist.mean(dim=0)
            # print(f'public data-sub prob-avg:{da}')
            # print(f'public data-sub label dist-avg:{label_dist}')


            '''
                new version
            '''
            # train substitute with all data
            for epoch in range(train_epoch):
                substitute_loss_data = 0.
                self.substitute.train()
                for _, (_, perturbed_img, victim_prob, _) in enumerate(dataloader):  # all_dataloader dataloader new_dataloader
                    if self.opt.use_gpu:
                        perturbed_img = perturbed_img.cuda()
                        victim_prob = victim_prob.cuda()
                    log_softmax = nn.LogSoftmax(dim=1)
                    softmax = nn.Softmax(dim=1)

                    # calculate labeled loss (substitute training in the s-Train process)
                    perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                        self.augmentation(perturbed_img))
                    perturbed_sub_prob = log_softmax(perturbed_sub_output)
                    if self.opt.victim_return_type == 'label':
                        sub_same_loss = ce_loss(perturbed_sub_output, victim_prob)
                    else:
                        sub_same_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
                    
                    labeled_loss = sub_same_loss

                    # online unsupervised loss (the U-Train process)
                    if self.opt.pseudo_label_weight == 0:
                        substitute_loss = labeled_loss
                    else:
                        try:
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        except:
                            aug_data_iter = iter(aug_dataloader)
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        if self.opt.use_gpu:
                            w_inputs = w_inputs.cuda()
                            s_inputs = s_inputs.cuda()

                        # projector version
                        # with torch.no_grad():
                        #     unlabeled_targets = self.substitute_projected(w_inputs)
                        # # unlabeled_targets = self.substitute_projected(w_inputs)
                        # unlabeled_outputs = self.substitute_projected(s_inputs)
                        # # max_probs = unlabeled_probs.max(1)[0]
                        # # mask = max_probs.ge(0.5).float().view([unlabeled_outputs.shape[0],1])
                        # pseudo_label_loss = mse_loss(unlabeled_outputs, unlabeled_targets)

                        # no projector version
                        with torch.no_grad():
                            unlabeled_targets = softmax(self.substitute(w_inputs))
                        unlabeled_outputs = self.substitute(s_inputs)
                        unlabeled_probs = softmax(unlabeled_outputs)
                        # max_probs = unlabeled_probs.max(1)[0]
                        # mask = max_probs.ge(0.5).float().view([unlabeled_outputs.shape[0],1])
                        pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)

                        # print(unlabeled_outputs.shape, unlabeled_targets.shape, mask.shape)
                        # pseudo_label_loss = (F.mse_loss(unlabeled_outputs, unlabeled_targets, reduction='none')*mask).mean()
                        # pseudo_label_loss = kldiv_loss(log_softmax(unlabeled_outputs), unlabeled_targets)

                        # # get contrastive loss
                        # try:
                        #     (simclr_x_i, simclr_x_j), _ = simclr_data_iter.next()
                        # except:
                        #     simclr_data_iter = iter(simclr_dataloader)
                        #     (simclr_x_i, simclr_x_j), _ = simclr_data_iter.next()
                        # if self.opt.use_gpu:
                        #     simclr_x_i = simclr_x_i.cuda()
                        #     simclr_x_j = simclr_x_j.cuda()
                        # substitute_fc = self.substitute.fc
                        # self.substitute.fc = Identity()
                        # simclr_h_i = self.substitute(simclr_x_i)
                        # simclr_h_j = self.substitute(simclr_x_j)
                        # simclr_criterion = NT_Xent(self.simclr_bs,
                        #                     self.opt.temperature, self.opt.world_size)
                        # simclr_loss = simclr_criterion(simclr_h_i, simclr_h_j)
                        # self.substitute.fc = substitute_fc

                        # optimize combined loss
                        labeled_loss_value = labeled_loss.detach()
                        pseudo_label_loss_value = pseudo_label_loss.detach()
                        # simclr_loss_value = simclr_loss.detach()
                        # substitute_loss = labeled_loss
                        # substitute_loss = labeled_loss + \
                        #                   simclr_loss*(labeled_loss_value/simclr_loss_value)*self.opt.simclr_weight
                        substitute_loss = labeled_loss + \
                                            pseudo_label_loss*(labeled_loss_value/pseudo_label_loss_value)*self.opt.pseudo_label_weight

                    substitute_loss_data += substitute_loss.item()
                    self.substitute.zero_grad()
                    substitute_loss.backward()
                    substitute_optimizer.step()

                scheduler.step() #fusiongan

                # # simclr optimization
                # for _, ((simclr_x_i, simclr_x_j), _) in enumerate(simclr_dataloader):
                #     simclr_optimizer.zero_grad()
                #     if self.opt.use_gpu:
                #         simclr_x_i = simclr_x_i.cuda()
                #         simclr_x_j = simclr_x_j.cuda()
                #     simclr_h_i = simclr_substitute(simclr_x_i)
                #     simclr_h_j = simclr_substitute(simclr_x_j)
                #     simclr_criterion = NT_Xent(self.simclr_bs,
                #                         self.opt.temperature, self.opt.world_size)
                #     simclr_loss = simclr_criterion(simclr_h_i, simclr_h_j)
                #     simclr_loss.backward()
                #     simclr_optimizer.step()

                substitute_loss_data /= len(self.sub_dataset)
                if substitute_loss_data < min_loss:
                    min_loss = substitute_loss_data
                    min_loss_epoch = epoch + 1
                elif self.opt.loss_rise_epoch:
                    if epoch+1-min_loss_epoch >= self.opt.loss_rise_epoch:
                        break

                # unlabeled_sub_prob = self.get_unlabeled_sub_prob(2000)
                # query_vic_public_sub_div = kldiv_loss(unlabeled_sub_prob.log(), query_vic_prob)
                # if query_vic_public_sub_div < min_div:
                #     min_div = query_vic_public_sub_div
                #     min_div_epoch = epoch + 1
                elif self.opt.div_rise_epoch:
                    if epoch+1 >= 10 and epoch+1-min_div_epoch >= self.opt.div_rise_epoch:
                        break
                # change
                if epoch == train_epoch-1:
                    break
                # print(f'dist. for unlabeled sub. prob. and labeled victim prob.: {kldiv_loss(unlabeled_sub_prob.log(), da)}')

                # # training with unlabeled data
                # for _, (unlabeled_inputs, _, _, _) in enumerate(unlabeled_dataloader):
                #     if self.opt.use_gpu:
                #         unlabeled_inputs = unlabeled_inputs.cuda()
                #     unlabeled_targets, pseudo_labels, mask = self.get_pseudo_label(unlabeled_inputs)
                #     unlabeled_outputs = self.substitute(unlabeled_inputs)
                #     unlabeled_loss = mse_loss(unlabeled_outputs,unlabeled_targets)
                #     self.substitute.zero_grad()
                #     unlabeled_loss.backward()
                #     substitute_optimizer.step()

                print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data}')
                # print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data} | pseudo label loss {pseudo_label_loss}')

                if (epoch + 1) % self.opt.print_freq == 0:
                    # query data - substitute prob.
                    # for i, (_, query_img, _, _) in enumerate(all_dataloader):
                    #     if self.opt.use_gpu:
                    #         query_img = query_img.cuda()
                    #     with torch.no_grad():
                    #         sub_prob = softmax(self.substitute(query_img))
                    #     if i == 0:
                    #         da = sub_prob
                    #     else:
                    #         da += sub_prob
                    # # da /= len(self.sub_dataset)
                    # da /= (len(self.sub_dataset) // self.labeled_bs)
                    # da = da.mean(dim=0)
                    # print(f'query data-sub prob-avg:{da}')
                    #
                    # # public data - substitute prob.
                    # for i, (unlabeled_inputs, _, _, _) in enumerate(unlabeled_dataloader):
                    #     if self.opt.use_gpu:
                    #         unlabeled_inputs = unlabeled_inputs.cuda()
                    #     with torch.no_grad():
                    #         unlabeled_probs = softmax(self.substitute(unlabeled_inputs))
                    #     if i == 0:
                    #         da = unlabeled_probs
                    #         label_dist = nn.functional.one_hot(unlabeled_probs.max(1)[1],
                    #                                    num_classes=self.opt.victim_n_classes).float()
                    #     else:
                    #         da += unlabeled_probs
                    #         label_dist += nn.functional.one_hot(unlabeled_probs.max(1)[1],
                    #                                    num_classes=self.opt.victim_n_classes).float()
                    #
                    # # da /= len(self.sub_dataset)
                    # da /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
                    # da = da.mean(dim=0)
                    # label_dist /= (len(self.unlabeled_dataset) // self.unlabeled_bs)
                    # label_dist = label_dist.mean(dim=0)
                    # print(f'public data-sub prob-avg:{da}')
                    # print(f'public data-sub label dist-avg:{label_dist}')

                    acc, fidelity, kd_loss = self.evaluate()
                    asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                    print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                            f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                            f'| L2 noise {avg_l2_noise}')
                    self.x_list[0].append(self.loop * self.n_epochs + epoch)
                    self.x_list[1].append(self.loop * self.n_epochs + epoch)
                    self.x_list[2].append(self.loop * self.n_epochs + epoch)
                    self.x_list[3].append(self.loop * self.n_epochs + epoch)
                    self.x_list[4].append(self.loop * self.n_epochs + epoch)
                    self.y_list[0].append(avg_l2_noise / 4)
                    self.y_list[1].append(asr / 100)
                    self.y_list[2].append(acc)
                    self.y_list[3].append(fidelity)
                    self.y_list[4].append(kd_loss)

            # acc, fidelity, kd_loss = self.evaluate()
            # asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
            # print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
            #       f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
            #       f'| L2 noise {avg_l2_noise}')
            # # save substitute model
            # if self.save:
            #     save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
            #                             f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}')
            #     os.makedirs(save_dir, exist_ok=True)
            #     save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.loop+1}'
            #     save_path = os.path.join(save_dir, save_name)
            #     torch.save(self.substitute.state_dict(), save_path)
            #     print(f'[substitute] Saved substitute model in loop {self.loop+1}')

            # get next batch data with different labels (generator input sampling based on pseudo labeling)
            next_diff_dataset = self.get_next_diff_dataset()
            next_diff_dataloader = torch.utils.data.DataLoader(
                next_diff_dataset,
                batch_size=self.labeled_bs,
                shuffle=True,
                num_workers=4
            )

            torch.cuda.empty_cache()

            # train gan with new data only (generator training in the S-Train process)
            if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                for epoch in range(self.gen_epoch):
                    self.data_gen.train()
                    # current batch
                    # unlabeled_data_iter = iter(fuse_unlabeled_dataloader)
                    for _, (clean_imgs, perturbed_img, victim_prob, _) in enumerate(new_dataloader):
                        clean_imgs = tuple([img.cuda() for img in clean_imgs])
                        # next_img = tuple([img.cuda() for img in next_img])
                        if self.opt.use_gpu:
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()

                        log_softmax = nn.LogSoftmax(dim=1)
                        softmax = nn.Softmax(dim=1)
                        # perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                        #                                                                         self.augmentation(
                        #                                                                             perturbed_img))
                        perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                            self.augmentation(perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)
                        new_perturbed_img = self.data_gen(clean_imgs)
                        # new_perturbed_img = nn.parallel.data_parallel(self.data_gen, clean_imgs)
                        # new_perturbed_sub_output, new_perturbed_sub_fake_output = nn.parallel.data_parallel(
                        #     self.substitute, new_perturbed_img)
                        new_perturbed_sub_output = nn.parallel.data_parallel(self.substitute, new_perturbed_img)
                        new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                        # print(new_perturbed_sub_prob.size())

                        # train generator with current batch data
                        noise_loss = 0.0
                        for i in range(self.opt.n_fuse):
                            noise_loss += mse_loss(new_perturbed_img, clean_imgs[i])
                        # adv_loss = torch.exp(-1 * mse_loss(new_perturbed_sub_prob, victim_prob)) # not log softmax
                        if self.opt.victim_return_type == 'label':
                            adv_loss = torch.exp(-1 * ce_loss(new_perturbed_sub_output, victim_prob))
                        else:
                            adv_loss = torch.exp(-1 * kldiv_loss(new_perturbed_sub_prob, victim_prob))
                        # new_perturbed_argmax = new_perturbed_sub_prob.argmax(1).cuda()
                        # uncertain_loss = torch.exp(-1 * ce_loss(new_perturbed_sub_prob, new_perturbed_argmax))
                        # cw_loss = cw_loss_func(new_perturbed_sub_output, new_perturbed_sub_output.max(1)[1])

                        # # class diversity loss with unlabeled data
                        # try:
                        #     unlabeled_inputs, _, _, _ = unlabeled_data_iter.next()
                        # except:
                        #     unlabeled_data_iter = iter(fuse_unlabeled_dataloader)
                        #     unlabeled_inputs, _, _, _ = unlabeled_data_iter.next()
                        # if self.opt.use_gpu:
                        #     unlabeled_inputs = unlabeled_inputs.cuda()
                        # for i in range(self.cls_div_bs):
                        #     cls_div_input = tuple(
                        #         [unlabeled_inputs[i*self.opt.n_fuse+j].view(1, 3, self.opt.public_img_size, self.opt.public_img_size)
                        #                             for j in range(self.opt.n_fuse)])
                        #     cls_div_gen_output = self.data_gen(cls_div_input)
                        #     cls_div_output = self.substitute(cls_div_gen_output)
                        #     cls_div_prob = softmax(cls_div_output)
                        #     # print(cls_div_output,cls_div_prob)
                        #     if i == 0:
                        #         cls_div_prob_sum = cls_div_prob
                        #     else:
                        #         cls_div_prob_sum = torch.cat([cls_div_prob_sum, cls_div_prob], 0)
                        #     # print(cls_div_prob_sum)
                        # avg_cls_div_prob = cls_div_prob_sum.mean(dim=0)
                        # print(avg_cls_div_prob)
                        # cls_div_loss = (-avg_cls_div_prob * torch.log(avg_cls_div_prob)).sum()

                        # cls_div_loss = 1 - torch.distributions.Categorical(avg_cls_div_prob).entropy().mean()
                        # uncertain_loss = torch.mean(new_perturbed_sub_prob.max(1)[0])
                        # div_loss = 0.0
                        # perturbed_next_img = self.data_gen(next_img)
                        # perturbed_next_img_output, _ = nn.parallel.data_parallel(self.substitute, perturbed_next_img)
                        # for _, (_, all_perturbed_img, _, _) in enumerate(all_dataloader):
                        #     if self.opt.use_gpu:
                        #         all_perturbed_img = all_perturbed_img.cuda()
                        #     all_perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, all_perturbed_img)
                        #     # div_loss += div_func(perturbed_next_img, all_perturbed_img, 2048)
                        #     div_loss += div_func(perturbed_next_img_output, all_perturbed_img_output, 4096)
                        # div_loss = div_loss / (len(self.sub_dataset)*perturbed_next_img_output.shape[0])
                        # gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss + \
                        #            self.opt.uncertain_weight * uncertain_loss# + self.opt.cls_div_weight * cls_div_loss
                        # + self.opt.div_weight * div_loss  # +fake_loss
                        gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss# + self.opt.cw_weight * cw_loss
                        self.data_gen.zero_grad()
                        gen_loss.backward()
                        data_gen_optimizer.step()

                    # train generator with next batch data (diversity loss)
                    for _ in range(self.opt.div_epoch):
                        for _, (clean_imgs, _, _, _) in enumerate(next_diff_dataloader): #next_dataloader
                            for _, (_, perturbed_img, _, _) in enumerate(div_dataloader): #all_dataloader div_dataloader
                                clean_imgs = tuple([img.cuda() for img in clean_imgs])
                                if self.opt.use_gpu:
                                    perturbed_img = perturbed_img.cuda()
                                # perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                perturbed_img_output = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                div_func = loss.DivLoss()
                                new_perturbed_img = self.data_gen(clean_imgs)
                                # new_perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute,
                                #                                                         new_perturbed_img)
                                new_perturbed_img_output = nn.parallel.data_parallel(self.substitute,
                                                                                        new_perturbed_img)
                                div_loss = div_func(new_perturbed_img_output, perturbed_img_output, 256)
                                div_loss = div_loss / new_perturbed_img_output.shape[0]
                                gen_loss_new = div_loss
                                self.data_gen.zero_grad()
                                gen_loss_new.backward()
                                nn.utils.clip_grad_norm_(self.data_gen.parameters(),
                                                            max_norm=self.opt.div_weight, norm_type=2)
                                data_gen_optimizer.step()
                                # data_gen_next_optimizer.step()
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{cls_div_loss}')
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{div_loss},{cls_div_loss}')
                    try:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{div_loss}')
                    except:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss}')
            torch.cuda.empty_cache()

        elif self.source == 'fusiongan-label':
            if self.opt.continue_train and self.loop == self.opt.n_loop - 1:
                train_epoch = self.opt.pseudo_epoch

            # define loss functions
            l1_loss = nn.L1Loss()
            mse_loss = nn.MSELoss()
            cos_loss = nn.CosineEmbeddingLoss()
            kldiv_loss = nn.KLDivLoss(reduction='mean')
            ce_loss = nn.CrossEntropyLoss()
            div_func = loss.DivLoss()
            cw_loss_func = loss.CWLoss()

            softmax = nn.Softmax(dim=1)

            min_loss = 1000.
            min_loss_epoch = 0
            min_div = 1000.
            min_div_epoch = 0

            unlabeled_data_iter = iter(unlabeled_dataloader)
            simclr_data_iter = iter(simclr_dataloader)
            aug_data_iter = iter(aug_dataloader)

            '''
                new version
            '''
            # train substitute with all data
            for epoch in range(train_epoch):
                substitute_loss_data = 0.
                self.substitute.train()
                for _, (_, perturbed_img, victim_prob, _) in enumerate(dataloader):  # all_dataloader dataloader new_dataloader
                    if self.opt.use_gpu:
                        perturbed_img = perturbed_img.cuda()
                        victim_prob = victim_prob.cuda()
                    log_softmax = nn.LogSoftmax(dim=1)
                    softmax = nn.Softmax(dim=1)
                    victim_label = victim_prob.max(1)[1]

                    # calculate labeled loss (substitute training in the s-Train process)
                    perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                        self.augmentation(perturbed_img))
                    # perturbed_sub_prob = log_softmax(perturbed_sub_output)
                    sub_same_loss = ce_loss(perturbed_sub_output, victim_label)
                    labeled_loss = sub_same_loss

                    # online unsupervised loss (the U-Train process)
                    if self.opt.pseudo_label_weight == 0:
                        substitute_loss = labeled_loss
                    else:
                        try:
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        except:
                            aug_data_iter = iter(aug_dataloader)
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        if self.opt.use_gpu:
                            w_inputs = w_inputs.cuda()
                            s_inputs = s_inputs.cuda()


                        # no projector version
                        with torch.no_grad():
                            unlabeled_targets = softmax(self.substitute(w_inputs))
                        unlabeled_outputs = self.substitute(s_inputs)
                        unlabeled_probs = softmax(unlabeled_outputs)
                        pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)

                        # optimize combined loss
                        labeled_loss_value = labeled_loss.detach()
                        pseudo_label_loss_value = pseudo_label_loss.detach()
                        # simclr_loss_value = simclr_loss.detach()
                        # substitute_loss = labeled_loss
                        # substitute_loss = labeled_loss + \
                        #                   simclr_loss*(labeled_loss_value/simclr_loss_value)*self.opt.simclr_weight
                        substitute_loss = labeled_loss + \
                                            pseudo_label_loss*(labeled_loss_value/pseudo_label_loss_value)*self.opt.pseudo_label_weight

                    substitute_loss_data += substitute_loss.item()
                    self.substitute.zero_grad()
                    substitute_loss.backward()
                    substitute_optimizer.step()

                scheduler.step() #fusiongan

                substitute_loss_data /= len(self.sub_dataset)
                if substitute_loss_data < min_loss:
                    min_loss = substitute_loss_data
                    min_loss_epoch = epoch + 1
                elif self.opt.loss_rise_epoch:
                    if epoch+1-min_loss_epoch >= self.opt.loss_rise_epoch:
                        break
                elif self.opt.div_rise_epoch:
                    if epoch+1 >= 10 and epoch+1-min_div_epoch >= self.opt.div_rise_epoch:
                        break
                # change
                if epoch == train_epoch-1:
                    break
                # print(f'dist. for unlabeled sub. prob. and labeled victim prob.: {kldiv_loss(unlabeled_sub_prob.log(), da)}')

                print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data}')
                # print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data} | pseudo label loss {pseudo_label_loss}')

                if (epoch + 1) % self.opt.print_freq == 0:

                    acc, fidelity, kd_loss = self.evaluate()
                    asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                    print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                            f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                            f'| L2 noise {avg_l2_noise}')
                    self.x_list[0].append(self.loop * self.n_epochs + epoch)
                    self.x_list[1].append(self.loop * self.n_epochs + epoch)
                    self.x_list[2].append(self.loop * self.n_epochs + epoch)
                    self.x_list[3].append(self.loop * self.n_epochs + epoch)
                    self.x_list[4].append(self.loop * self.n_epochs + epoch)
                    self.y_list[0].append(avg_l2_noise / 4)
                    self.y_list[1].append(asr / 100)
                    self.y_list[2].append(acc)
                    self.y_list[3].append(fidelity)
            next_diff_dataset = self.get_next_diff_dataset()
            next_diff_dataloader = torch.utils.data.DataLoader(
                next_diff_dataset,
                batch_size=self.labeled_bs,
                shuffle=True,
                num_workers=4
            )

            torch.cuda.empty_cache()

            # train gan with new data only (generator training in the S-Train process)
            if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                for epoch in range(self.gen_epoch):
                    self.data_gen.train()
                    # current batch
                    # unlabeled_data_iter = iter(fuse_unlabeled_dataloader)
                    for _, (clean_imgs, perturbed_img, victim_prob, _) in enumerate(new_dataloader):
                        clean_imgs = tuple([img.cuda() for img in clean_imgs])
                        # next_img = tuple([img.cuda() for img in next_img])
                        if self.opt.use_gpu:
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()
                        
                        victim_label = victim_prob.max(1)[1]

                        log_softmax = nn.LogSoftmax(dim=1)
                        softmax = nn.Softmax(dim=1)
                        # perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                        #                                                                         self.augmentation(
                        #                                                                             perturbed_img))
                        perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                            self.augmentation(perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)
                        new_perturbed_img = self.data_gen(clean_imgs)
                        new_perturbed_sub_output = nn.parallel.data_parallel(self.substitute, new_perturbed_img)
                        # new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                        # print(new_perturbed_sub_prob.size())

                        # train generator with current batch data
                        noise_loss = 0.0
                        for i in range(self.opt.n_fuse):
                            noise_loss += mse_loss(new_perturbed_img, clean_imgs[i])
                        # adv_loss = torch.exp(-1 * mse_loss(new_perturbed_sub_prob, victim_prob)) # not log softmax
                        adv_loss = torch.exp(-1 * ce_loss(new_perturbed_sub_output, victim_label))
                        gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss# + self.opt.cw_weight * cw_loss
                        self.data_gen.zero_grad()
                        gen_loss.backward()
                        data_gen_optimizer.step()

                    # train generator with next batch data (diversity loss)
                    for _ in range(self.opt.div_epoch):
                        for _, (clean_imgs, _, _, _) in enumerate(next_diff_dataloader): #next_dataloader
                            for _, (_, perturbed_img, _, _) in enumerate(div_dataloader): #all_dataloader div_dataloader
                                clean_imgs = tuple([img.cuda() for img in clean_imgs])
                                if self.opt.use_gpu:
                                    perturbed_img = perturbed_img.cuda()
                                # perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                perturbed_img_output = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                div_func = loss.DivLoss()
                                new_perturbed_img = self.data_gen(clean_imgs)
                                # new_perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute,
                                #                                                         new_perturbed_img)
                                new_perturbed_img_output = nn.parallel.data_parallel(self.substitute,
                                                                                        new_perturbed_img)
                                div_loss = div_func(new_perturbed_img_output, perturbed_img_output, 256)
                                div_loss = div_loss / new_perturbed_img_output.shape[0]
                                gen_loss_new = div_loss
                                self.data_gen.zero_grad()
                                gen_loss_new.backward()
                                nn.utils.clip_grad_norm_(self.data_gen.parameters(),
                                                            max_norm=self.opt.div_weight, norm_type=2)
                                data_gen_optimizer.step()
                                # data_gen_next_optimizer.step()
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{cls_div_loss}')
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{div_loss},{cls_div_loss}')
                    try:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{div_loss}')
                    except:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss}')
            torch.cuda.empty_cache()

        elif self.source == 'random_aggr_gen':
            if self.opt.continue_train and self.loop == self.opt.n_loop - 1:
                train_epoch = self.opt.pseudo_epoch

            # define loss functions
            l1_loss = nn.L1Loss()
            mse_loss = nn.MSELoss()
            cos_loss = nn.CosineEmbeddingLoss()
            kldiv_loss = nn.KLDivLoss(reduction='mean')
            ce_loss = nn.CrossEntropyLoss()
            div_func = loss.DivLoss()
            cw_loss_func = loss.CWLoss()

            softmax = nn.Softmax(dim=1)

            min_loss = 1000.
            min_loss_epoch = 0
            min_div = 1000.
            min_div_epoch = 0

            unlabeled_data_iter = iter(unlabeled_dataloader)
            simclr_data_iter = iter(simclr_dataloader)
            aug_data_iter = iter(aug_dataloader)


            # train substitute with all data
            for epoch in range(train_epoch):
                substitute_loss_data = 0.
                self.substitute.train()
                for _, (_, perturbed_img, victim_prob, _) in enumerate(dataloader):  # all_dataloader dataloader new_dataloader
                    if self.opt.use_gpu:
                        perturbed_img = perturbed_img.cuda()
                        victim_prob = victim_prob.cuda()
                    log_softmax = nn.LogSoftmax(dim=1)
                    softmax = nn.Softmax(dim=1)

                    # calculate labeled loss (substitute training in the s-Train process)
                    perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                        self.augmentation(perturbed_img))
                    perturbed_sub_prob = log_softmax(perturbed_sub_output)
                    sub_same_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
                    labeled_loss = sub_same_loss

                    # online unsupervised loss (the U-Train process)
                    if self.opt.pseudo_label_weight == 0:
                        substitute_loss = labeled_loss
                    else:
                        try:
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        except:
                            aug_data_iter = iter(aug_dataloader)
                            (w_inputs, s_inputs), _ = aug_data_iter.next()
                        if self.opt.use_gpu:
                            w_inputs = w_inputs.cuda()
                            s_inputs = s_inputs.cuda()

                        with torch.no_grad():
                            unlabeled_targets = softmax(self.substitute(w_inputs))
                        unlabeled_outputs = self.substitute(s_inputs)
                        unlabeled_probs = softmax(unlabeled_outputs)
                        # max_probs = unlabeled_probs.max(1)[0]
                        # mask = max_probs.ge(0.5).float().view([unlabeled_outputs.shape[0],1])
                        pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)

                        # optimize combined loss
                        labeled_loss_value = labeled_loss.detach()
                        pseudo_label_loss_value = pseudo_label_loss.detach()
                        substitute_loss = labeled_loss + \
                                            pseudo_label_loss*(labeled_loss_value/pseudo_label_loss_value)*self.opt.pseudo_label_weight

                    substitute_loss_data += substitute_loss.item()
                    self.substitute.zero_grad()
                    substitute_loss.backward()
                    substitute_optimizer.step()

                substitute_loss_data /= len(self.sub_dataset)
                if substitute_loss_data < min_loss:
                    min_loss = substitute_loss_data
                    min_loss_epoch = epoch + 1
                elif self.opt.loss_rise_epoch:
                    if epoch+1 >= 10 and epoch+1-min_loss_epoch >= self.opt.loss_rise_epoch:
                        break

                elif self.opt.div_rise_epoch:
                    if epoch+1 >= 10 and epoch+1-min_div_epoch >= self.opt.div_rise_epoch:
                        break

                print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data}')
                # print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data} | pseudo label loss {pseudo_label_loss}')

                if (epoch + 1) % self.opt.print_freq == 0:
                    acc, fidelity, kd_loss = self.evaluate()
                    asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                    print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                            f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                            f'| L2 noise {avg_l2_noise}')
                    self.x_list[0].append(self.loop * self.n_epochs + epoch)
                    self.x_list[1].append(self.loop * self.n_epochs + epoch)
                    self.x_list[2].append(self.loop * self.n_epochs + epoch)
                    self.x_list[3].append(self.loop * self.n_epochs + epoch)
                    self.x_list[4].append(self.loop * self.n_epochs + epoch)
                    self.y_list[0].append(avg_l2_noise / 4)
                    self.y_list[1].append(asr / 100)
                    self.y_list[2].append(acc)
                    self.y_list[3].append(fidelity)
                    self.y_list[4].append(kd_loss)


            # get next batch data with different labels (generator input sampling based on pseudo labeling)
            next_diff_dataset = self.get_next_diff_dataset()
            next_diff_dataloader = torch.utils.data.DataLoader(
                next_diff_dataset,
                batch_size=self.labeled_bs,
                shuffle=True,
                num_workers=4
            )

            torch.cuda.empty_cache()

            # train gan with new data only (generator training in the S-Train process)
            if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                for epoch in range(self.gen_epoch):
                    self.data_gen.train()
                    # current batch
                    # unlabeled_data_iter = iter(fuse_unlabeled_dataloader)
                    for _, (clean_imgs, perturbed_img, victim_prob, _) in enumerate(new_dataloader):
                        clean_imgs = tuple([img.cuda() for img in clean_imgs])
                        noise = torch.randn(perturbed_img.shape[0], self.opt.noise_dim)
                        # next_img = tuple([img.cuda() for img in next_img])
                        if self.opt.use_gpu:
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()
                            noise = noise.cuda()

                        log_softmax = nn.LogSoftmax(dim=1)
                        softmax = nn.Softmax(dim=1)
                        # perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                        #                                                                         self.augmentation(
                        #                                                                             perturbed_img))
                        perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                            self.augmentation(perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)
                        new_perturbed_img = self.data_gen(clean_imgs, noise)
                        # new_perturbed_img = nn.parallel.data_parallel(self.data_gen, clean_imgs)
                        # new_perturbed_sub_output, new_perturbed_sub_fake_output = nn.parallel.data_parallel(
                        #     self.substitute, new_perturbed_img)
                        new_perturbed_sub_output = nn.parallel.data_parallel(self.substitute, new_perturbed_img)
                        new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                        # print(new_perturbed_sub_prob.size())

                        # train generator with current batch data
                        noise_loss = 0.0
                        for i in range(self.opt.n_fuse):
                            noise_loss += mse_loss(new_perturbed_img, clean_imgs[i])
                        # adv_loss = torch.exp(-1 * mse_loss(new_perturbed_sub_prob, victim_prob)) # not log softmax
                        adv_loss = torch.exp(-1 * kldiv_loss(new_perturbed_sub_prob, victim_prob))
                        gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss
                        self.data_gen.zero_grad()
                        gen_loss.backward()
                        data_gen_optimizer.step()

                    # train generator with next batch data (diversity loss)
                    for _ in range(self.opt.div_epoch):
                        for _, (clean_imgs, _, _, _) in enumerate(next_diff_dataloader): #next_dataloader
                            for _, (_, perturbed_img, _, _) in enumerate(div_dataloader): #all_dataloader div_dataloader
                                noise = torch.randn(perturbed_img.shape[0], self.opt.noise_dim)
                                clean_imgs = tuple([img.cuda() for img in clean_imgs])
                                if self.opt.use_gpu:
                                    perturbed_img = perturbed_img.cuda()
                                    noise = noise.cuda()
                                # perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                perturbed_img_output = nn.parallel.data_parallel(self.substitute, perturbed_img)
                                div_func = loss.DivLoss()
                                new_perturbed_img = self.data_gen(clean_imgs, noise)
                                # new_perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute,
                                #                                                         new_perturbed_img)
                                new_perturbed_img_output = nn.parallel.data_parallel(self.substitute,
                                                                                        new_perturbed_img)
                                div_loss = div_func(new_perturbed_img_output, perturbed_img_output, 256)
                                div_loss = div_loss / new_perturbed_img_output.shape[0]
                                gen_loss_new = div_loss
                                self.data_gen.zero_grad()
                                gen_loss_new.backward()
                                nn.utils.clip_grad_norm_(self.data_gen.parameters(),
                                                            max_norm=self.opt.div_weight, norm_type=2)
                                data_gen_optimizer.step()
                                # data_gen_next_optimizer.step()
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{cls_div_loss}')
                    # print(
                    #     f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{uncertain_loss},{div_loss},{cls_div_loss}')
                    try:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{div_loss}')
                    except:
                        print(
                            f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss}')


        writer.flush()
        writer.close()
        acc, fidelity, kd_loss = self.evaluate()
        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise}')

        # save substitute model
        if self.save:
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            os.makedirs(save_dir, exist_ok=True)
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.loop + 1}'
            save_path = os.path.join(save_dir, save_name)
            torch.save(self.substitute.state_dict(), save_path)
            print(f'[substitute] Saved substitute model in loop {self.loop + 1}')

        # # train generator

        # # train with new queries
        # for epoch in range(50):
        #     for _, (seed, _, prob) in enumerate(new_dataloader):
        #         if self.opt.use_gpu:
        #             seed = seed.cuda()
        #             prob = prob.cuda()
        #         self.data_gen.zero_grad()
        #         sub_output = self.substitute(self.data_gen(seed))
        #         # KL divergence loss
        #         # softmax = nn.LogSoftmax(dim=1)
        #         # sub_prob = softmax(sub_output)
        #         # loss_func = torch.nn.KLDivLoss(reduction='mean')
        #         # eps = 1e-7
        #         # loss = 1/(loss_func(sub_prob, prob)+eps)

        #         # L2 loss
        #         softmax = nn.Softmax(dim=1)
        #         sub_prob = softmax(sub_output)
        #         loss_func = nn.MSELoss()
        #         loss = torch.exp(-1*loss_func(sub_prob,prob))
        #         loss.backward()
        #         # nn.utils.clip_grad_value_(self.data_gen.parameters(),clip_value=0.7)
        #         data_gen_optimizer.step()
        #     print(f'[data generator] epoch {epoch} | loss {loss}')
        #     torch.save(self.data_gen.state_dict(),gen_ckpt_path)

        # train with random input
        # for epoch in range(40):
        #     for _ in range(len(dataloader)):
        #         self.data_gen.zero_grad()
        #         noise = torch.Tensor(
        #             np.random.uniform(-5.,5.,size=(16,256))
        #         )
        #         if self.opt.use_gpu:
        #             noise = noise.cuda()
        #         sub_output = self.substitute(self.data_gen(noise))
        #         softmax = nn.Softmax(dim=1)
        #         sub_prob = softmax(sub_output)
        #         # print(torch.max(sub_prob,1))
        #         data_gen_loss = torch.mean(torch.max(sub_prob,1)[0])
        #         # print(data_gen_loss)
        #         data_gen_loss.backward()
        #         data_gen_optimizer.step()
        #     print(f'[data generator] epoch {epoch} | loss {data_gen_loss}')
        #     torch.save(self.data_gen.state_dict(),gen_ckpt_path)

        if 'fusiongan' in self.source or self.source == 'random_aggr_gen':
            # return self.substitute, self.data_gen, self.substitute_projected, self.unlabeled_dataset, next_diff_dataset
            return self.substitute, self.data_gen, self.unlabeled_dataset, next_diff_dataset
        return self.substitute, self.data_gen, self.unlabeled_dataset

    def evaluate(self):
        """Score the substitute on the evaluation test split.

        An evaluation network is built according to ``self.opt.eval_model``
        (EfficientNet, WideResNet, or an entry of ``classifier_dict``) and
        handed to ``load_model_weights`` together with ``self.substitute``;
        the model returned by that helper is the one being scored.
        NOTE(review): presumably the helper copies the trained substitute
        weights into the fresh network -- confirm against ``utils``.

        Returns:
            tuple: ``(accuracy, fidelity, final_kd)``

            * ``accuracy`` -- 0-dim CPU tensor, fraction of test samples whose
              arg-max prediction equals the target.
            * ``fidelity`` -- 0-dim CPU tensor, fraction of test samples on
              which the substitute's and the victim's arg-max predictions
              agree.
            * ``final_kd`` -- ``(280000 - kd_loss) / 100000``. The FSP
              knowledge-distillation term is currently disabled, so
              ``kd_loss`` stays at 0 and this is the constant ``2.8``.

            If the dataloader yields no samples, accuracy and fidelity are
            reported as 0 instead of raising ``ZeroDivisionError``.
        """
        n_outputs = self.opt.victim_n_classes
        model_name = self.opt.eval_model
        if model_name.startswith('efficientnet'):
            eval_model = setup.EfficientNet.from_pretrained(
                model_name,
                num_classes=n_outputs)
        elif model_name.startswith('wrn'):
            # the depth is encoded as the last '-'-separated token of the name
            depth = int(model_name.split('-')[-1])
            eval_model = setup.classifier.WideResNet(
                n_outputs=n_outputs,
                depth=depth
            )
        else:
            eval_model = classifier_dict[model_name](
                n_outputs=n_outputs
            )

        substitute = load_model_weights(self.substitute, eval_model)

        victim = self.victim
        substitute.eval()
        victim.eval()
        dataloader = self.eval_dataset.test_dataloader()

        # Moving the models is loop-invariant: do it once, not once per batch.
        if self.opt.use_gpu:
            victim.cuda()
            substitute.cuda()

        correct = torch.zeros(())  # running count of correct predictions
        agree = torch.zeros(())    # running count of victim/substitute agreements
        # The FSP feature-matching KD loss is disabled; to re-enable it,
        # accumulate kd_loss.FSPLoss()(substitute.features, victim.features)
        # into kd_loss_result for every batch.
        kd_loss_result = 0.0
        n_samples = 0
        n_batch = 0
        for batch in dataloader:
            imgs = batch[0]
            targets = batch[1]
            if self.opt.use_gpu:
                imgs = imgs.cuda()
                targets = targets.cuda()
            n_samples += targets.shape[0]
            n_batch += 1
            with torch.no_grad():
                sub_pred = substitute(imgs).max(1)[1]
                victim_pred = victim(imgs).max(1)[1]
                correct += sub_pred.eq(targets).float().sum().detach().cpu()
                agree += victim_pred.eq(sub_pred).float().sum().detach().cpu()

        max_kd = 280000
        scale_kd = 100000
        if n_samples == 0:
            # Nothing was evaluated: avoid dividing by zero below.
            print('[substitute] evaluation dataloader yielded no samples; '
                  'reporting zero accuracy and fidelity')
            return correct, agree, (max_kd - kd_loss_result) / scale_kd

        correct /= n_samples
        agree /= n_samples
        kd_loss_result /= n_batch
        final_kd = (max_kd - kd_loss_result) / scale_kd
        return correct, agree, final_kd

    # def adv_evaluate(self, sample_size):
    #     n_outputs = self.opt.victim_n_classes
    #     eval_model = classifier_dict[self.opt.eval_model](
    #         n_outputs=n_outputs
    #     )
    #     substitute = load_model_weights(self.substitute, eval_model)
    #     self.victim.eval()
    #     substitute.eval()
    #     data_list = np.random.choice(range(len(self.eval_dataset.test_dataset)),sample_size,replace=False) #[i for i in range(7800,8000)]
    #     dataloader = torch.utils.data.DataLoader(
    #         self.eval_dataset.test_dataset, batch_size=50,
    #         sampler=sp.SubsetRandomSampler(data_list), num_workers=4
    #     )
    #     # dataloader = self.eval_dataset.test_dataloader()
    #     adversary_ghost = LinfBasicIterativeAttack(
    #         substitute, loss_fn = nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #         nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple*self.opt.noise_eps/self.opt.noise_step,
    #         clip_min=-1.0, clip_max=1.0, targeted=False)
    #     attack_success = 0.0
    #     total = 0.0
    #     l2_noise = 0.0
    #     self.victim.eval()
    #     for data in dataloader:
    #         inputs, labels = data
    #         inputs = inputs.cuda()
    #         labels = labels.cuda()
    #         with torch.no_grad():
    #             outputs = self.victim(inputs)
    #             _, predicted = torch.max(outputs.data, 1)
    #         correct_predict = predicted == labels
    #         total += correct_predict.float().sum()
    #         if self.opt.targeted:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, (labels+1) % 10)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict*(predicted == (labels+1) % 10)).sum()
    #             for i, correct in enumerate(correct_predict):
    #                 if correct and predicted[i] == (labels[i]+1) % 10:
    #                     l2_noise += torch.dist(adv_inputs_ghost[i], inputs[i], 2)
    #         else:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, labels)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict*(predicted != labels)).sum()
    #             for i, correct in enumerate(correct_predict):
    #                 if correct and predicted[i] != labels[i]:
    #                     l2_noise += torch.dist(adv_inputs_ghost[i], inputs[i], 2)
    #
    #         # with torch.no_grad():
    #         #     victim_labels = self.victim(inputs).max(1)[1]
    #         #     sub_labels = self.substitute(inputs).max(1)[1]
    #
    #         # adv_inputs_ghost = adversary_ghost.perturb(inputs, sub_labels)
    #         # with torch.no_grad():
    #         #     outputs = self.victim(adv_inputs_ghost)
    #         #     _, predicted = torch.max(outputs.data,1)
    #         # # print('victim_labels',victim_labels)
    #         # # print('sub_labels',sub_labels)
    #         # # print('predicted',predicted)
    #         # total += labels.size(0)
    #         # correct_ghost += (predicted == victim_labels).sum()
    #     asr = (100. * attack_success.float() / total).cpu()
    #     avg_l2_noise = (l2_noise / total).cpu()
    #     l2_noise_per_pixel = avg_l2_noise / (data[0].shape[1]*data[0].shape[2]*data[0].shape[3])
    #     return asr, avg_l2_noise, l2_noise_per_pixel

    # def adv_evaluate(self, sample_size):
    #     if self.opt.use_gpu:
    #         self.victim.cuda()
    #         self.substitute.cuda()
    #     self.victim.eval()
    #     self.substitute.eval()
    #     data_list = np.random.choice(range(len(self.eval_dataset.test_dataset)), sample_size,
    #                                  replace=False)  # [i for i in range(7800,8000)]
    #     batch_size = 50
    #     dataloader = torch.utils.data.DataLoader(
    #         self.eval_dataset.test_dataset, batch_size=batch_size,
    #         sampler=sp.SubsetRandomSampler(data_list), num_workers=4
    #     )
    #     # dataloader = self.eval_dataset.test_dataloader()
    #     # adversary_ghost = LinfBasicIterativeAttack(
    #     #     self.substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #     #     nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
    #     #     clip_min=-1.0, clip_max=1.0, targeted=False)
    #     adversary_ghost = LinfMomentumIterativeAttack(
    #         self.substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #         nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
    #         clip_min=-1.0, clip_max=1.0, targeted=False)
    #     attack_success = 0.0
    #     sub_attack_success = 0.0
    #     total = 0.0
    #     success_total = 0.0
    #     l2_noise = 0.0
    #     success_l2_noise = 0.0
    #     for data in dataloader:
    #         inputs, labels = data
    #         inputs = inputs.cuda()
    #         labels = labels.cuda()
    #         with torch.no_grad():
    #             outputs = self.victim(inputs)
    #             _, predicted = torch.max(outputs.data, 1)
    #             sub_outputs = self.substitute(inputs)
    #             _, sub_labels = torch.max(sub_outputs.data, 1)
    #         correct_predict = predicted == labels
    #         if self.opt.targeted:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, (labels + 1) % 10)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict * (predicted == (labels + 1) % 10)).sum()
    #             l2_noise_list = F.pairwise_distance(adv_inputs_ghost.view(len(inputs), -1),
    #                                                 inputs.view(len(inputs), -1))
    #             for i, correct in enumerate(correct_predict):
    #                 if correct:
    #                     l2_noise += l2_noise_list[i]
    #                     total += 1
    #                     if predicted[i] == (labels[i] + 1) % 10:
    #                         success_l2_noise += l2_noise_list[i]
    #                         success_total += 1
    #         else:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, labels)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #                 sub_outputs = self.victim(adv_inputs_ghost)
    #                 _, sub_predicted = torch.max(sub_outputs.data, 1)
    #             attack_success += (correct_predict * (predicted != labels)).sum()
    #             sub_attack_success += (correct_predict * (sub_predicted != sub_labels)).sum()
    #             l2_noise_list = F.pairwise_distance(adv_inputs_ghost.view(len(inputs), -1),
    #                                                 inputs.view(len(inputs), -1))
    #             for i, correct in enumerate(correct_predict):
    #                 if correct:
    #                     l2_noise += l2_noise_list[i]
    #                     total += 1
    #                     if predicted[i] != labels[i]:
    #                         success_l2_noise += l2_noise_list[i]
    #                         success_total += 1
    #
    #     asr = (100. * attack_success.float() / total).cpu()
    #     sub_asr = (100. * sub_attack_success.float() / total).cpu()
    #     success_avg_l2_noise = (success_l2_noise / success_total).cpu()
    #     success_l2_noise_per_pixel = success_avg_l2_noise / (data[0].shape[1] * data[0].shape[2] * data[0].shape[3])
    #     avg_l2_noise = (l2_noise / total).cpu()
    #     # l2_noise_per_pixel = avg_l2_noise / (data[0].shape[1]*data[0].shape[2]*data[0].shape[3])
    #     return asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise

    def adv_evaluate(self, sample_size):
        """Measure how adversarial examples crafted on the substitute transfer to the victim.

        ``sample_size`` test images are drawn without replacement, perturbed
        with an L-inf momentum iterative attack run against the evaluation copy
        of the substitute, and then fed to the victim.  Only samples the victim
        classifies correctly before the attack enter the statistics.

        In targeted mode (``self.opt.targeted``) the target class is
        ``(label + 1) % victim_n_classes`` and the attack is run as a targeted
        attack; a sample counts as a success when the model predicts that
        target.  In untargeted mode a victim success is any prediction that
        differs from the true label, and a substitute success is any change of
        the substitute's own prediction relative to the clean input.

        Args:
            sample_size: number of test samples to draw.

        Returns:
            tuple of 0-dim CPU tensors
            ``(asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel,
            avg_l2_noise)``:
                * ``asr`` -- victim attack success rate in percent.
                * ``sub_asr`` -- substitute attack success rate in percent.
                * ``success_avg_l2_noise`` -- mean L2 perturbation over the
                  successful transfers.
                * ``success_l2_noise_per_pixel`` -- the previous value divided
                  by the number of elements of one input image.
                * ``avg_l2_noise`` -- mean L2 perturbation over all samples
                  the victim classified correctly.
            A value is NaN when its denominator is zero (no correctly
            classified sample, or no successful transfer).
        """
        use_gpu = self.opt.use_gpu
        if use_gpu:
            self.victim.cuda()
            self.substitute.cuda()
        self.victim.eval()
        self.substitute.eval()

        n_outputs = self.opt.victim_n_classes
        model_name = self.opt.eval_model
        if model_name.startswith('efficientnet'):
            eval_model = setup.EfficientNet.from_pretrained(
                model_name,
                num_classes=n_outputs)
        elif model_name.startswith('wrn'):
            depth = int(model_name.split('-')[-1])
            eval_model = setup.classifier.WideResNet(
                n_outputs=n_outputs,
                depth=depth
            )
        else:
            eval_model = classifier_dict[model_name](
                n_outputs=n_outputs
            )

        substitute = load_model_weights(self.substitute, eval_model)
        # The attack and all forward passes below use this rebuilt model, so
        # it must be in eval mode and on the same device as the inputs.
        substitute.eval()
        if use_gpu:
            substitute.cuda()

        test_dataset = self.eval_dataset.test_dataset
        data_list = np.random.choice(range(len(test_dataset)), sample_size,
                                     replace=False)
        batch_size = 50
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )

        targeted = bool(self.opt.targeted)
        eps = self.opt.noise_eps
        n_steps = self.opt.noise_step
        # ``targeted`` must match how the labels passed to ``perturb`` are
        # meant: target labels with an untargeted attack would push the input
        # *away* from the target class.
        adversary_ghost = LinfMomentumIterativeAttack(
            substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=eps,
            nb_iter=n_steps, eps_iter=self.opt.eps_multiple * eps / n_steps,
            clip_min=-1.0, clip_max=1.0, targeted=targeted)

        total = 0                 # samples the victim got right before the attack
        success_total = 0         # ... of which the attack fooled the victim
        sub_attack_success = 0    # ... of which the attack fooled the substitute
        l2_noise = 0.0
        success_l2_noise = 0.0
        n_pixels = 1
        for inputs, labels in dataloader:
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            n_pixels = inputs[0].numel()

            with torch.no_grad():
                clean_pred = self.victim(inputs).max(1)[1]
                sub_clean_pred = substitute(inputs).max(1)[1]
            correct_predict = clean_pred == labels

            attack_labels = (labels + 1) % n_outputs if targeted else labels
            adv_inputs_ghost = adversary_ghost.perturb(inputs, attack_labels)
            with torch.no_grad():
                adv_pred = self.victim(adv_inputs_ghost).max(1)[1]
                sub_adv_pred = substitute(adv_inputs_ghost).max(1)[1]

            if targeted:
                victim_fooled = adv_pred == attack_labels
                sub_fooled = sub_adv_pred == attack_labels
            else:
                victim_fooled = adv_pred != labels
                sub_fooled = sub_adv_pred != sub_clean_pred
            transferred = correct_predict & victim_fooled

            n_batch = len(inputs)
            l2_noise_list = F.pairwise_distance(adv_inputs_ghost.reshape(n_batch, -1),
                                                inputs.reshape(n_batch, -1))

            total += int(correct_predict.sum())
            success_total += int(transferred.sum())
            sub_attack_success += int((correct_predict & sub_fooled).sum())
            l2_noise += float(l2_noise_list[correct_predict].sum())
            success_l2_noise += float(l2_noise_list[transferred].sum())

        nan = float('nan')
        asr = torch.tensor(100. * success_total / total if total else nan)
        sub_asr = torch.tensor(100. * sub_attack_success / total if total else nan)
        success_avg_l2_noise = torch.tensor(
            success_l2_noise / success_total if success_total else nan)
        success_l2_noise_per_pixel = success_avg_l2_noise / n_pixels
        avg_l2_noise = torch.tensor(l2_noise / total if total else nan)
        return asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise

    def surrogate_adv_evaluate(self, sample_size):
        """Attack the substitute with its own adversarial examples (white-box check).

        ``sample_size`` images are drawn without replacement from the surrogate
        evaluation test set and perturbed with an L-inf basic iterative attack
        computed on ``self.substitute``.  Only samples the substitute
        classifies correctly before the attack are counted.

        In targeted mode (``self.opt.targeted``) the target class is
        ``(label + 1) % victim_n_classes`` and the attack is run as a targeted
        attack; otherwise any prediction different from the true label counts
        as a success.

        Args:
            sample_size: number of test samples to draw.

        Returns:
            tuple of 0-dim CPU tensors ``(asr, avg_l2_noise, l2_noise_per_pixel)``:
                * ``asr`` -- attack success rate in percent.
                * ``avg_l2_noise`` -- summed L2 perturbation of the successful
                  attacks divided by the number of correctly classified
                  samples (not by the number of successes).
                * ``l2_noise_per_pixel`` -- ``avg_l2_noise`` divided by the
                  number of elements of one input image.
            Values are NaN when no sample was classified correctly.
        """
        substitute = self.substitute
        substitute.eval()
        use_gpu = self.opt.use_gpu
        if use_gpu:
            substitute.cuda()

        test_dataset = self.surrogate_eval_dataset.test_dataset
        data_list = np.random.choice(range(len(test_dataset)), sample_size,
                                     replace=False)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=50,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )

        # NOTE(review): this method reads the step count from the trainer
        # itself; fall back to the option object when the trainer does not
        # carry a ``noise_step`` attribute -- confirm which one is intended.
        n_steps = self.noise_step if hasattr(self, 'noise_step') else self.opt.noise_step
        targeted = bool(self.opt.targeted)
        n_classes = self.opt.victim_n_classes
        eps = self.opt.noise_eps
        # ``targeted`` must agree with the kind of labels handed to
        # ``perturb`` below.
        adversary_ghost = LinfBasicIterativeAttack(
            substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=eps,
            nb_iter=n_steps, eps_iter=self.opt.eps_multiple * eps / n_steps,
            clip_min=-1.0, clip_max=1.0, targeted=targeted)

        attack_success = 0
        total = 0
        l2_noise = 0.0
        n_pixels = 1
        for inputs, labels in dataloader:
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            n_pixels = inputs[0].numel()

            with torch.no_grad():
                clean_pred = substitute(inputs).max(1)[1]
            correct_predict = clean_pred == labels

            attack_labels = (labels + 1) % n_classes if targeted else labels
            adv_inputs_ghost = adversary_ghost.perturb(inputs, attack_labels)
            with torch.no_grad():
                adv_pred = substitute(adv_inputs_ghost).max(1)[1]

            fooled = (adv_pred == attack_labels) if targeted else (adv_pred != labels)
            succeeded = correct_predict & fooled

            # per-sample L2 norm of the perturbation
            per_sample_l2 = (adv_inputs_ghost - inputs).reshape(len(inputs), -1).norm(p=2, dim=1)

            total += int(correct_predict.sum())
            attack_success += int(succeeded.sum())
            l2_noise += float(per_sample_l2[succeeded].sum())

        nan = float('nan')
        asr = torch.tensor(100. * attack_success / total if total else nan)
        avg_l2_noise = torch.tensor(l2_noise / total if total else nan)
        l2_noise_per_pixel = avg_l2_noise / n_pixels
        return asr, avg_l2_noise, l2_noise_per_pixel

    def epoch_val(self):
        """Return the mean KL-divergence loss of the substitute on ``self.val_dataset``.

        Each validation batch is unpacked as ``(_, perturbed_img, victim_prob)``;
        the loss is ``KLDivLoss(reduction='mean')`` between the substitute's
        log-softmax output on ``perturbed_img`` and ``victim_prob``, averaged
        over batches.  The forward passes run under ``torch.no_grad()`` since
        nothing is back-propagated here.

        Returns:
            The summed per-batch loss (CPU tensor) divided by the number of
            batches.

        Raises:
            ZeroDivisionError: if the validation set yields no batch.
        """
        if self.source == 'attackgan':
            n_outputs = self.opt.victim_n_classes
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )
            substitute = load_model_weights(self.substitute, eval_model)
        else:
            substitute = self.substitute
        substitute.eval()

        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4
        )

        # The loss modules are stateless, so build them once.
        log_softmax = nn.LogSoftmax(dim=1)
        kldiv_loss = nn.KLDivLoss(reduction='mean')

        total_loss = 0.0
        num_batch = 0.0
        for _, perturbed_img, victim_prob in val_dataloader:
            num_batch += 1
            if self.opt.use_gpu:
                perturbed_img = perturbed_img.cuda()
                victim_prob = victim_prob.cuda()

            with torch.no_grad():
                # NOTE(review): the forward pass goes through ``self.substitute``
                # (expected here to return a 2-tuple), not the ``substitute``
                # prepared above, so for source == 'attackgan' the rebuilt
                # model is never used -- confirm which model is intended.
                perturbed_sub_output, _ = self.substitute(perturbed_img)
                perturbed_sub_prob = log_softmax(perturbed_sub_output)
                sub_same_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
            total_loss += sub_same_loss.detach().cpu()
        return total_loss / num_batch

    def get_weighted_sampler(self):
        """Build a weighted sampler over ``self.sub_dataset``.

        The dataset is viewed as consecutive chunks of
        ``len(self.new_sub_dataset)`` samples.  A sample in chunk ``c`` gets
        the weight ``dataset_decay ** (n_chunks - c)``, where ``n_chunks`` is
        the number of whole chunks in the dataset, so the exponent shrinks
        towards the end of the dataset.

        Returns:
            sp.WeightedRandomSampler: draws ``len(self.sub_dataset)`` indices
            per pass using those weights.
        """
        decay = self.opt.dataset_decay
        total_size = len(self.sub_dataset)
        chunk_size = len(self.new_sub_dataset)
        n_chunks = total_size // chunk_size

        sample_weights = torch.tensor(
            [decay ** (n_chunks - position // chunk_size)
             for position in range(total_size)],
            dtype=torch.double,
        )
        return sp.WeightedRandomSampler(sample_weights, total_size)

    def get_pseudo_label(self, unlabeled_items, k_avg=3, T=0.5, da=None, threshhold=0.7):
        """Pseudo-label a batch of unlabeled inputs with the substitute.

        The substitute (switched to eval mode) is run on ``k_avg`` randomly
        flipped copies of the batch and the softmax outputs are averaged.  The
        averaged distribution is then optionally sharpened and re-weighted.

        Args:
            unlabeled_items: batch of inputs accepted by ``self.substitute``
                and by the torchvision flip transforms.
            k_avg: number of augmented forward passes to average; must be >= 1.
            T: sharpening temperature.  Probabilities are raised to ``1 / T``
                and renormalised per row; ``None`` disables sharpening.
            da: optional distribution-alignment factor multiplied into the
                probabilities before a per-row renormalisation.
            threshhold: confidence a sample's top probability must strictly
                exceed to be kept by the mask.

        Returns:
            tuple: ``(unlabeled_probs, pseudo_labels, mask)`` -- the processed
            probabilities, the index of the largest probability per row, and a
            float mask that is 1.0 where that largest probability is greater
            than ``threshhold``.

        Raises:
            ValueError: if ``k_avg`` is smaller than 1.
        """
        if k_avg < 1:
            raise ValueError(f'k_avg must be at least 1, got {k_avg}')

        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                transforms.RandomVerticalFlip(),  # with 0.5 probability
            ]
        )
        self.substitute.eval()

        # average probability predictions after k augmentations
        probs_sum = None
        for _ in range(k_avg):
            unlabeled_items_aug = augmentation(unlabeled_items)
            with torch.no_grad():
                unlabeled_outputs = self.substitute(unlabeled_items_aug)
            probs = torch.softmax(unlabeled_outputs, dim=1)
            probs_sum = probs if probs_sum is None else probs_sum + probs
        unlabeled_probs = probs_sum.detach() / k_avg

        # sharpening
        if T is not None:
            unlabeled_pt = unlabeled_probs ** (1 / T)
            unlabeled_probs = unlabeled_pt / unlabeled_pt.sum(dim=1, keepdim=True)

        # if distribution alignment is enabled
        if da is not None:
            unlabeled_probs_da = unlabeled_probs * da
            unlabeled_probs = unlabeled_probs_da / unlabeled_probs_da.sum(dim=1, keepdim=True)

        # one pass over the rows yields both the confidence and the label
        confidence, pseudo_labels = unlabeled_probs.max(1)
        mask = (confidence > threshhold).float()

        return unlabeled_probs, pseudo_labels, mask

    def get_unlabeled_sub_prob(self, sample_size):
        """Return the substitute's mean class distribution on a random unlabeled subset.

        ``sample_size`` indices are drawn without replacement from
        ``self.unlabeled_dataset``; the substitute (in eval mode) is run on
        them in batches of 50 and the softmax outputs are summed over all
        samples and divided by ``sample_size``.

        Args:
            sample_size: number of unlabeled samples to draw.

        Returns:
            1-D tensor with one averaged probability per class.
        """
        self.substitute.eval()
        chosen = np.random.choice(range(len(self.unlabeled_dataset)), sample_size,
                                  replace=False)
        loader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=50,
            sampler=sp.SubsetRandomSampler(chosen),
            num_workers=4,
        )
        for step, (batch_inputs, _, _, _) in enumerate(loader):
            if self.opt.use_gpu:
                batch_inputs = batch_inputs.cuda()
            with torch.no_grad():
                logits = self.substitute(batch_inputs)
            # per-class probability mass contributed by this batch
            batch_total = torch.softmax(logits, dim=1).sum(dim=0)
            probs_sum = batch_total if step == 0 else probs_sum + batch_total
        return probs_sum / sample_size

    def get_next_diff_dataset(self, k_avg=3):
        """Assemble the next query set from the unlabeled public pool.

        Every entry appended to the returned dataset is
        ``(images, -1, -1, -1)`` where ``images`` is a list of
        ``self.opt.n_fuse`` images taken from ``self.unlabeled_dataset``.

        Two modes are used:

        * **Full-scale** -- when nothing in the pool has been used yet and the
          pool holds at most ``n_fuse * query_per_loop`` items, the pool is
          shuffled (repeatedly if needed) and cut into groups of ``n_fuse``
          until ``query_per_loop`` groups exist.  No model is involved.
        * **Class-balanced sampling** -- otherwise every pool item is
          pseudo-labelled by averaging the substitute's softmax over ``k_avg``
          randomly flipped copies, and ``len(self.new_sub_dataset)`` groups
          are drawn from the still-unused items.  Classes are sampled with
          probability proportional to ``-log(class_share) + 1e-6`` so that
          rare pseudo-classes are favoured, and within one draw the classes
          are distinct.

        Side effects: slot ``[3]`` of every selected pool item is set to 1
        (used); in the sampling mode slots ``[1]`` and ``[2]`` of every pool
        item are overwritten with the averaged probabilities (CPU tensor) and
        the pseudo-label (int).

        Args:
            k_avg: number of augmented forward passes averaged per item in the
                sampling mode; must be >= 1 there.

        Returns:
            data.SubDataset: the dataset holding the selected groups.

        Raises:
            ValueError: in the sampling mode, if ``k_avg`` is smaller than 1
                or if the pool runs out of unused items before all groups are
                filled.
        """
        from collections import deque

        def _class_probs(pools):
            """Map per-class pool sizes to sampling probabilities (all zero if nothing is left)."""
            sizes = [len(pool) for pool in pools]
            n_left = sum(sizes)
            weights = [-math.log(size / n_left) + 1e-6 if size != 0 else 0
                       for size in sizes]
            weight_sum = sum(weights)
            if weight_sum == 0:
                return [0. for _ in weights]
            return [weight / weight_sum for weight in weights]

        next_diff_dataset = data.SubDataset()
        pool_items = self.unlabeled_dataset.items
        n_fuse = self.opt.n_fuse

        used = sum(item[3] for item in pool_items)
        unused = len(self.unlabeled_dataset) - used
        query_per_loop = int(self.opt.query / self.opt.n_loop)

        if unused <= n_fuse * query_per_loop and used == 0:
            print(f'Full-scale public dataset needed. Skip sampling process.')
            groups_per_pass = unused // n_fuse
            times = math.ceil(query_per_loop / groups_per_pass)
            img_size = self.opt.public_img_size
            for _ in range(times):
                # fresh permutation of the whole pool for every pass
                random_idx_list = random.sample(range(unused), unused)
                for query in range(groups_per_pass):
                    items = []
                    for idx in random_idx_list[query * n_fuse:(query + 1) * n_fuse]:
                        items.append(pool_items[idx][0].view(3, img_size, img_size))
                        pool_items[idx][3] = 1  # mark as used
                    next_diff_dataset.items.append((items, -1, -1, -1))
            next_diff_dataset.items = next_diff_dataset.items[:query_per_loop]
            return next_diff_dataset

        if k_avg < 1:
            raise ValueError(f'k_avg must be at least 1, got {k_avg}')

        loop_size = len(self.new_sub_dataset)
        n_classes = self.opt.victim_n_classes
        batch_size = 100
        # No shuffling: batch position i * batch_size + j is used below as the
        # index into the pool.
        unlabeled_dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        # stores the indexes of unused public data for each pseudo label class
        label_idx_list = [deque() for _ in range(n_classes)]
        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                transforms.RandomVerticalFlip(),  # with 0.5 probability
            ]
        )
        # NOTE(review): the substitute's train/eval mode is left as the caller
        # set it -- confirm it is in eval mode when this is called.
        for i, (unlabeled_items, _, _, _) in enumerate(unlabeled_dataloader):
            if self.opt.use_gpu:
                unlabeled_items = unlabeled_items.cuda()
            probs_sum = None
            for _ in range(k_avg):
                unlabeled_items_aug = augmentation(unlabeled_items)
                with torch.no_grad():
                    unlabeled_outputs = self.substitute(unlabeled_items_aug)
                probs = torch.softmax(unlabeled_outputs, dim=1)
                probs_sum = probs if probs_sum is None else probs_sum + probs
            # one device-to-host transfer per batch instead of one per item
            unlabeled_probs = (probs_sum.detach() / k_avg).cpu()
            pseudo_labels = unlabeled_probs.max(1)[1].tolist()
            offset = i * batch_size
            for j, pseudo_label in enumerate(pseudo_labels):
                entry = pool_items[offset + j]
                entry[1] = unlabeled_probs[j].clone()
                entry[2] = pseudo_label
                # only involve unused items for next batch
                if entry[3] == 0:
                    label_idx_list[pseudo_label].append(offset + j)

        # sample by frequency
        label_nums = [len(idxs) for idxs in label_idx_list]
        print(label_nums)
        label_probs = _class_probs(label_idx_list)

        for _ in range(loop_size):
            items = []
            num_unsampled = n_fuse
            while num_unsampled > 0:
                n_valid_class = sum(1 for p in label_probs if p > 0)
                if n_valid_class == 0:
                    raise ValueError(
                        'unlabeled pool exhausted: no unused samples left to draw from')
                # without replacement we can draw at most one item per
                # non-empty class in a single call
                num_current_sample = min(n_valid_class, num_unsampled)
                num_unsampled -= num_current_sample
                current_labels = np.random.choice(range(n_classes), size=num_current_sample,
                                                  p=label_probs, replace=False)
                for current_label in current_labels:
                    idx = label_idx_list[current_label].popleft()
                    items.append(pool_items[idx][0])
                    pool_items[idx][3] = 1  # mark as used
                # update weights
                label_probs = _class_probs(label_idx_list)
            next_diff_dataset.items.append((items, -1, -1, -1))

        return next_diff_dataset
