import os
from itertools import chain
import random
import sys
import math

import torch
from torch import nn, softmax
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data.sampler as sp
import torch.distributions as d
import torch.nn.functional as F
import numpy as np
from advertorch.attacks import *
from torchvision import transforms

import data
import loss
import kd_loss

sys.path.append('..')
import setup
from data import *
from utils import *
# import loss
import time


class SubstituteTrainer:
    def __init__(self, opt,
                 victim, substitute, data_gen,
                 sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                 eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                 source, x_list, y_list, labels,
                 labeled_bs=200,
                 strategy='every', loop=None, n_epochs=10,
                 save=False, load=False,
                 unlabeled_bs=None, cls_div_bs=None, simclr_bs=None):
        """Store models, datasets and options used by ``train``.

        Args:
            opt: Option namespace. Read here: ``gen_epoch``, ``epoch_val_rate``
                and, when ``load`` is true, ``data_dir``, ``victim_dataset``,
                ``surrogate_dataset``, ``pre_train_sub``, ``sub_model``,
                ``seed`` and ``sub_eval_loop``.
            victim: Victim model (stored as ``self.victim``).
            substitute: Substitute model trained by ``train``; its weights are
                restored from a checkpoint when ``load`` is true.
            data_gen: Data generator model (stored as ``self.data_gen``).
            sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
            eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset,
            div_dataset: Datasets consumed by ``train``. ``new_sub_dataset``
                must expose a sliceable ``items`` attribute: its leading
                ``epoch_val_rate`` fraction is moved into ``self.val_dataset``
                and removed from ``new_sub_dataset`` (modified in place).
            source: Query-data source name; selects the branch taken in
                ``train`` and is part of the checkpoint directory on load.
            x_list, y_list: Plot series that ``train`` appends points to.
            labels: Plot labels, stored as ``self.labels``.
            labeled_bs: Batch size for labeled data loaders.
            strategy: Training strategy (``'static'`` or the default ``'every'``).
            loop: Index of the current outer loop, used by ``train``.
            n_epochs: Number of training epochs.
            save: Stored as ``self.save``.
            load: If true, load substitute weights from the checkpoint path
                derived from ``opt``.
            unlabeled_bs, cls_div_bs, simclr_bs: Optional batch sizes for the
                unlabeled, class-diversity and SimCLR loaders. ``None`` keeps
                the previous default of ``2 * labeled_bs``.
        """
        self.opt = opt
        self.victim = victim
        self.substitute = substitute
        self.data_gen = data_gen

        self.sub_dataset = sub_dataset
        self.new_sub_dataset = new_sub_dataset
        self.next_dataset = next_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.val_dataset = SubDataset()
        self.eval_dataset = eval_dataset
        self.surrogate_eval_dataset = surrogate_eval_dataset
        self.simclr_dataset = simclr_dataset
        self.aug_dataset = aug_dataset
        self.div_dataset = div_dataset

        self.source = source
        self.labeled_bs = labeled_bs
        # Auxiliary loaders default to twice the labeled batch size.
        default_aux_bs = labeled_bs * 2
        self.unlabeled_bs = default_aux_bs if unlabeled_bs is None else unlabeled_bs
        self.cls_div_bs = default_aux_bs if cls_div_bs is None else cls_div_bs
        self.simclr_bs = default_aux_bs if simclr_bs is None else simclr_bs
        self.strategy = strategy
        self.loop = loop
        self.n_epochs = n_epochs
        self.gen_epoch = opt.gen_epoch

        # Augmentation applied to substitute inputs in ``train``; only a
        # random horizontal flip (probability 0.5) is enabled.
        self.augmentation = transforms.Compose([
            transforms.RandomHorizontalFlip(),
        ])

        # Split a validation set off the front of new_sub_dataset.
        # ``train`` treats a falsy epoch_val_rate as "no validation", so a
        # None rate is accepted here instead of failing on ``None * int``.
        val_rate = self.opt.epoch_val_rate or 0
        val_size = int(val_rate * len(self.new_sub_dataset))
        self.val_dataset.items = self.new_sub_dataset.items[:val_size]
        self.new_sub_dataset.items = self.new_sub_dataset.items[val_size:]

        # Plot bookkeeping, appended to by ``train``.
        self.x_list = x_list
        self.y_list = y_list
        self.labels = labels

        self.save = save
        # load substitute model
        if load:
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.opt.sub_eval_loop}'
            save_path = os.path.join(save_dir, save_name)
            # Deserialize onto the CPU: load_state_dict copies tensors into the
            # module's parameters on whatever device they live on, so a
            # checkpoint saved from a GPU run also loads on a CPU-only host.
            state_dict = torch.load(save_path, map_location='cpu')
            self.substitute.load_state_dict(state_dict)


    def train(self):
        print(f'Size of query dataset: {len(self.sub_dataset)}')
        print(f'Size of new query dataset: {len(self.new_sub_dataset)}')
        print(f'Size of diversity dataset: {len(self.div_dataset)}')
        # preparation
        if self.opt.use_gpu:
            self.substitute.cuda()
            self.data_gen.cuda()
            # if self.source == 'fusiongan':
            #     self.substitute_projected.projector.cuda()
        if len(self.sub_dataset) > 0:
            weighted_sampler = self.get_weighted_sampler()
            dataloader = torch.utils.data.DataLoader(
                self.sub_dataset,
                batch_size=self.labeled_bs, #100
                # shuffle = True,
                num_workers=4,
                sampler=weighted_sampler
            )
            all_dataloader = torch.utils.data.DataLoader(
                self.sub_dataset,
                batch_size=self.labeled_bs, #50
                shuffle=True,
                num_workers=4
            )
        new_dataloader = torch.utils.data.DataLoader(
            self.new_sub_dataset,
            batch_size=300, #100
            shuffle=True,
            num_workers=4
        )
        aug_dataloader = torch.utils.data.DataLoader(
            self.aug_dataset,
            batch_size=self.unlabeled_bs,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )
        if 'fusiongan' in self.source or self.source == 'random_aggr_gen':
            unlabeled_dataloader = torch.utils.data.DataLoader(
                self.unlabeled_dataset,
                batch_size=self.unlabeled_bs,
                shuffle=True,
                num_workers=4,
                drop_last=True
            )
            fuse_unlabeled_dataloader = torch.utils.data.DataLoader(
                self.unlabeled_dataset,
                batch_size=self.cls_div_bs * self.opt.n_fuse,
                shuffle=True,
                num_workers=4,
                drop_last=True
            )
            simclr_dataloader = torch.utils.data.DataLoader(
                self.simclr_dataset,
                batch_size=self.simclr_bs,
                shuffle=True,
                num_workers=4,
                drop_last=True,
            )
            if self.opt.div_epoch > 0:
                div_dataloader = torch.utils.data.DataLoader(
                    self.div_dataset,
                    batch_size=self.labeled_bs,
                    shuffle=True,
                    num_workers=4,
                    drop_last=True,
                )
            if len(self.next_dataset) > 0:
                next_dataloader = torch.utils.data.DataLoader(
                    self.next_dataset,
                    batch_size=50,
                    shuffle=True,
                    num_workers=4
                )
        # if self.source == 'fusiongan':
        #     if self.opt.sub_optim == 'adam':
        #         substitute_optimizer = torch.optim.Adam(
        #             chain(self.substitute.parameters(),
        #                   self.substitute_projected.projector.parameters()),
        #             lr=self.opt.sub_lr
        #         )
        #     else:
        #         substitute_optimizer = torch.optim.SGD(
        #             chain(self.substitute.parameters(),
        #                   self.substitute_projected.projector.parameters()),
        #             lr=self.opt.sub_lr
        #         )
        # else:
        #     if self.opt.sub_optim == 'adam':
        #         substitute_optimizer = torch.optim.Adam(
        #             self.substitute.parameters(),
        #             lr=self.opt.sub_lr
        #         )
        #     else:
        #         substitute_optimizer = torch.optim.SGD(
        #             self.substitute.parameters(),
        #             lr=self.opt.sub_lr
        #         )

        if self.opt.sub_optim == 'adam':
            substitute_optimizer = torch.optim.Adam(
                self.substitute.parameters(),
                lr=self.opt.sub_lr,
                # weight_decay=self.opt.sub_weight_decay
            )
        else:
            substitute_optimizer = torch.optim.SGD(
                self.substitute.parameters(),
                lr=self.opt.sub_lr,
                momentum=0.9,
                weight_decay=self.opt.sub_weight_decay
            )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(substitute_optimizer, T_max=50)
        # scheduler = torch.optim.lr_scheduler.StepLR(substitute_optimizer, 40, gamma=0.1)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(substitute_optimizer, milestones=[10, 30, 40], gamma=0.2)

        # simclr_optimizer = torch.optim.Adam(simclr_substitute.parameters(), lr=self.opt.online_simclr_lr)
        data_gen_optimizer = torch.optim.Adam(
            self.data_gen.parameters(),
            lr=self.opt.gen_lr
        )
        # data_gen_next_optimizer = torch.optim.Adam(
        #     self.data_gen.parameters(),
        #     lr=self.opt.gen_next_lr
        # )
        self.substitute.train()
        self.data_gen.train()
        # if self.source == 'fusiongan':
        #     self.substitute_projected.projector.train()

        ckpt_dir = f'{self.opt.data_dir}checkpoints/'
        writer = SummaryWriter(f'runs/{self.opt.caption}')
        sub_ckpt_path = f'{ckpt_dir}substitute_{self.strategy}'
        gen_ckpt_path = f'{ckpt_dir}data_gen_{self.strategy}'

        acc, fidelity, kd_loss = self.evaluate()
        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
        print(f'[start] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise}')

        train_epoch = self.n_epochs

        if self.strategy == 'static':
            for epoch in range(train_epoch):
                for _, (_, data, prob) in enumerate(dataloader):
                    if self.opt.use_gpu:
                        data = data.cuda()
                        prob = prob.cuda()
                    self.substitute.zero_grad()
                    sub_output = self.substitute(self.augmentation(data))
                    softmax = nn.LogSoftmax(dim=1)
                    sub_prob = softmax(sub_output)
                    substitute_loss_function = torch.nn.KLDivLoss(reduction='mean')
                    substitute_loss = substitute_loss_function(sub_prob, prob)
                    substitute_loss.backward()
                    substitute_optimizer.step()

                print(f'[substitute] epoch {epoch} | loss {substitute_loss}')

                if epoch % self.opt.print_freq == 0:
                    acc, fidelity, kd_loss = self.evaluate()
                    asr, l2_noise, noise_per_pixel = self.adv_evaluate(200)
                    print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} ' +
                          f'| KD loss {kd_loss} | L2 noise {l2_noise}({noise_per_pixel})')
                    # writer.add_scalar("loss",substitute_loss,epoch)
                    # writer.add_scalar("accuracy",acc,epoch)
                    # writer.add_scalar("fidelity",fidelity,epoch)
                    # writer.add_scalar("ASR",asr,epoch)

                torch.save(self.substitute.state_dict(), sub_ckpt_path)
        else:
            # update D and G in same loop
            if self.source == 'attackgan':
                if self.opt.continue_train and self.loop == self.opt.n_loop - 1:
                    train_epoch = self.opt.pseudo_epoch
                max_certainty = 0.0
                max_certainty_epoch = 0
                max_noise = 0.0
                max_noise_epoch = 0
                for epoch in range(train_epoch):
                    certainty = 0.0
                    num_batch = 0.0
                    for _, (clean_img, perturbed_img, victim_prob, next_img) in enumerate(new_dataloader):
                        num_batch += 1
                        if self.opt.use_gpu:
                            clean_img = clean_img.cuda()
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()
                            next_img = next_img.cuda()

                        log_softmax = nn.LogSoftmax(dim=1)
                        softmax = nn.Softmax(dim=1)
                        # clean_sub_output, clean_fake_output = self.substitute(self.augmentation(clean_img))
                        clean_sub_output, clean_fake_output = nn.parallel.data_parallel(self.substitute,
                                                                                        self.augmentation(clean_img))
                        clean_sub_prob = log_softmax(clean_sub_output)
                        clean_fake_prob = softmax(clean_fake_output)
                        # perturbed_sub_output, perturbed_fake_output = self.substitute(self.augmentation(perturbed_img))
                        perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                                                                                                self.augmentation(
                                                                                                    perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)
                        perturbed_fake_prob = softmax(perturbed_fake_output)
                        # new_perturbed_img = self.data_gen(clean_img)
                        new_perturbed_img = nn.parallel.data_parallel(self.data_gen, clean_img)
                        new_perturbed_sub_output, new_perturbed_sub_fake_output = nn.parallel.data_parallel(
                            self.substitute, new_perturbed_img)
                        new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                        new_perturbed_fake_prob = softmax(new_perturbed_sub_fake_output)
                        # print(new_perturbed_sub_prob.size())

                        # define loss functions
                        l1_loss = nn.L1Loss()
                        mse_loss = nn.MSELoss()
                        cos_loss = nn.CosineEmbeddingLoss()
                        kldiv_loss = nn.KLDivLoss(reduction='mean')
                        ce_loss = nn.CrossEntropyLoss()
                        div_func = loss.DivLoss()

                        # train generator
                        self.data_gen.zero_grad()
                        if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                            noise_loss = mse_loss(new_perturbed_img, clean_img)
                            # diff_loss = torch.exp(-1*loss_func(sub_prob,prob))
                            adv_loss = torch.exp(-1 * mse_loss(new_perturbed_sub_prob, victim_prob))
                            new_perturbed_argmax = new_perturbed_sub_prob.argmax(1).cuda()
                            # uncertain_loss = torch.exp(-1 * ce_loss(new_perturbed_sub_prob, new_perturbed_argmax))
                            # uncertain_loss = torch.mean(new_perturbed_sub_prob.max(1)[0])
                            target = torch.full([clean_img.size()[0]], 1, dtype=torch.int64).cuda()
                            fake_loss = ce_loss(new_perturbed_fake_prob, target)
                            # perturbed_next_img = nn.parallel.data_parallel(self.data_gen, next_img)
                            # perturbed_next_img_output = nn.parallel.data_parallel(self.substitute, perturbed_next_img)
                            # div_loss = 0.0
                            # for _, (_, all_perturbed_img, _, _) in enumerate(all_dataloader):
                            #     if self.opt.use_gpu:
                            #         all_perturbed_img = all_perturbed_img.cuda()
                            #     all_perturbed_img_output = nn.parallel.data_parallel(self.substitute, all_perturbed_img)
                            #     # div_loss += div_func(perturbed_next_img, all_perturbed_img, 2048)
                            #     div_loss += div_func(perturbed_next_img_output, all_perturbed_img_output, 256)
                            # gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss +\
                            #            self.opt.div_weight * div_loss  # +fake_loss
                            gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss  # + self.opt.uncertain_weight * uncertain_loss
                            gen_loss.backward(retain_graph=True)
                            data_gen_optimizer.step()

                        certain_probs = softmax(new_perturbed_sub_output)
                        max_certain_prob = torch.mean(certain_probs.max(1)[0])
                        certainty += max_certain_prob.detach().cpu()

                        # train substitute
                        self.substitute.zero_grad()
                        # new_perturbed_img = self.data_gen(clean_img)
                        # new_perturbed_sub_output, new_perturbed_sub_fake_output = self.substitute(new_perturbed_img)
                        # new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                        # sub_same_loss = l1_loss(softmax(perturbed_sub_output), victim_prob)
                        # sub_same_loss = mse_loss(softmax(perturbed_sub_output), victim_prob)
                        # sub_same_loss = cos_loss(softmax(perturbed_sub_output), victim_prob, torch.Tensor([1]).cuda())
                        sub_same_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
                        target = torch.full([clean_img.size()[0]], 1, dtype=torch.int64).cuda()
                        sub_fake_prob1 = ce_loss(clean_fake_prob, target)
                        target = torch.full([clean_img.size()[0]], 0, dtype=torch.int64).cuda()
                        sub_fake_prob2 = ce_loss(new_perturbed_fake_prob, target)
                        substitute_loss = sub_same_loss  # +sub_fake_prob1+sub_fake_prob2
                        # substitute_loss.backward(retain_graph=True)
                        substitute_loss.backward()
                        substitute_optimizer.step()

                    # # train generator with next batch data
                    # self.data_gen.zero_grad()
                    # if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                    #     for _, (img, _, _, _) in enumerate(next_dataloader):
                    #         for _, (_, perturbed_img, _, _) in enumerate(all_dataloader):
                    #             if self.opt.use_gpu:
                    #                 img = img.cuda()
                    #                 perturbed_img = perturbed_img.cuda()
                    #             perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, perturbed_img)
                    #             div_func = loss.DivLoss()
                    #             new_perturbed_img = nn.parallel.data_parallel(self.data_gen, img)
                    #             new_perturbed_img_output, _ = nn.parallel.data_parallel(self.substitute, new_perturbed_img)
                    #             div_loss = div_func(new_perturbed_img_output, perturbed_img_output, 256)
                    #             gen_loss_new = div_loss
                    #             gen_loss_new.backward()
                    #             nn.utils.clip_grad_norm_(self.data_gen.parameters(), max_norm=1e-6, norm_type=2)
                    #             data_gen_optimizer.step()

                    # gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss + self.opt.div_weight * div_loss#+fake_loss
                    # gen_loss.backward()
                    # data_gen_optimizer.step()

                    print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss}')
                    if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                        # print(f'[data generator] epoch {epoch} | loss {gen_loss} | {noise_loss},{adv_loss},{div_loss}')
                        print(f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss}')

                    if (epoch + 1) % self.opt.print_freq == 0:
                        acc, fidelity, kd_loss = self.evaluate()
                        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                        if noise_per_pixel > max_noise:
                            max_noise = noise_per_pixel
                            max_noise_epoch = epoch + 1
                        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                              f'| L2 noise {avg_l2_noise}')
                        if self.opt.epoch_val_rate:
                            val_loss = self.epoch_val()
                            print(f'validation loss: {val_loss}', val_loss)
                        self.x_list[0].append(self.loop * self.n_epochs + epoch)
                        self.x_list[1].append(self.loop * self.n_epochs + epoch)
                        self.x_list[2].append(self.loop * self.n_epochs + epoch)
                        self.x_list[3].append(self.loop * self.n_epochs + epoch)
                        self.x_list[4].append(self.loop * self.n_epochs + epoch)
                        self.y_list[0].append(l2_noise / 3)
                        self.y_list[1].append(asr / 100)
                        self.y_list[2].append(acc)
                        self.y_list[3].append(fidelity)
                        self.y_list[4].append(kd_loss)
                        # plot_line('noise-asr',self.x_list,self.y_list,self.labels)
                        # writer.add_scalar("loss",substitute_loss,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("accuracy",acc,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("fidelity",fidelity,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("ASR",asr,self.loop*self.n_epochs+epoch)

                    torch.save(self.substitute.state_dict(), sub_ckpt_path)
                    if not (self.opt.continue_train and self.loop == self.opt.n_loop - 1):
                        torch.save(self.data_gen.state_dict(), gen_ckpt_path)

                    if self.opt.loss_threshhold:
                        if substitute_loss < self.opt.loss_threshhold:
                            break
                    avg_certainty = certainty / num_batch
                    if avg_certainty > max_certainty:
                        max_certainty = avg_certainty
                        max_certainty_epoch = epoch + 1
                    if self.opt.same_certainty_epoch:
                        if epoch + 1 - max_certainty_epoch >= self.opt.same_certainty_epoch:
                            break
                    if self.opt.noise_fall_epoch:
                        if epoch + 1 - max_noise_epoch >= self.opt.noise_fall_epoch:
                            break
                    if self.opt.certainty_threshhold:
                        if avg_certainty > self.opt.certainty_threshhold:
                            break

            elif 'active' in self.source or self.source == 'random' or self.source == 'papernot':
                # prepare data for online unsupervised training
                aug_data_iter = iter(aug_dataloader)
                # define functions
                mse_loss = nn.MSELoss()
                kldiv_loss = nn.KLDivLoss(reduction='mean')
                ce_loss = nn.CrossEntropyLoss()
                softmax = nn.Softmax(dim=1)

                for epoch in range(train_epoch):
                    # train substitute with all data
                    for _, (_, perturbed_img, victim_prob, _) in enumerate(all_dataloader):
                        if self.opt.use_gpu:
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()

                        log_softmax = nn.LogSoftmax(dim=1)
                        # perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                        #                                                                         self.augmentation(
                        #                                                                             perturbed_img))
                        perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                         self.augmentation(perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)

                        # train substitute
                        self.substitute.zero_grad()

                        # online supervised loss
                        labeled_loss = kldiv_loss(perturbed_sub_prob, victim_prob)

                        # online unsupervised loss (the U-Train process)
                        if self.opt.pseudo_label_weight == 0:
                            substitute_loss = labeled_loss
                        else:
                            try:
                                (w_inputs, s_inputs), _ = aug_data_iter.next()
                            except:
                                aug_data_iter = iter(aug_dataloader)
                                (w_inputs, s_inputs), _ = aug_data_iter.next()
                            if self.opt.use_gpu:
                                w_inputs = w_inputs.cuda()
                                s_inputs = s_inputs.cuda()

                            # no projector version
                            with torch.no_grad():
                                unlabeled_targets = softmax(self.substitute(w_inputs))
                            unlabeled_outputs = self.substitute(s_inputs)
                            unlabeled_probs = softmax(unlabeled_outputs)
                            pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)
                            # optimize combined loss
                            labeled_loss_value = labeled_loss.detach()
                            pseudo_label_loss_value = pseudo_label_loss.detach()
                            # simclr_loss_value = simclr_loss.detach()
                            # substitute_loss = labeled_loss
                            # substitute_loss = labeled_loss + \
                            #                   simclr_loss*(labeled_loss_value/simclr_loss_value)*self.opt.simclr_weight
                            substitute_loss = labeled_loss + \
                                              pseudo_label_loss * (
                                                          labeled_loss_value / pseudo_label_loss_value) * self.opt.pseudo_label_weight



                        # online unsupervised loss
                        # try:
                        #     (w_inputs, s_inputs), _ = aug_data_iter.next()
                        # except:
                        #     aug_data_iter = iter(aug_dataloader)
                        #     (w_inputs, s_inputs), _ = aug_data_iter.next()
                        # if self.opt.use_gpu:
                        #     w_inputs = w_inputs.cuda()
                        #     s_inputs = s_inputs.cuda()
                        # with torch.no_grad():
                        #     unlabeled_targets = softmax(self.substitute(w_inputs))
                        # unlabeled_outputs = self.substitute(s_inputs)
                        # unlabeled_probs = softmax(unlabeled_outputs)
                        # # max_probs = unlabeled_probs.max(1)[0]
                        # # mask = max_probs.ge(0.5).float().view([unlabeled_outputs.shape[0],1])
                        # pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)

                        # labeled_loss_value = labeled_loss.detach()
                        # pseudo_label_loss_value = pseudo_label_loss.detach()
                        # substitute_loss = labeled_loss
                        # substitute_loss = labeled_loss + \
                        #                   pseudo_label_loss * (
                        #                               labeled_loss_value / pseudo_label_loss_value) * self.opt.pseudo_label_weight
                        substitute_loss.backward()
                        substitute_optimizer.step()
                    # scheduler.step() # random
                    print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss}')
                    if (epoch + 1) % self.opt.print_freq == 0:
                        acc, fidelity, kd_loss = self.evaluate()
                        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                              f'| L2 noise {avg_l2_noise}')
                        self.x_list[0].append(self.loop * self.n_epochs + epoch)
                        self.x_list[1].append(self.loop * self.n_epochs + epoch)
                        self.x_list[2].append(self.loop * self.n_epochs + epoch)
                        self.x_list[3].append(self.loop * self.n_epochs + epoch)
                        self.x_list[4].append(self.loop * self.n_epochs + epoch)
                        self.y_list[0].append(avg_l2_noise / 4)
                        self.y_list[1].append(asr / 100)
                        self.y_list[2].append(acc)
                        self.y_list[3].append(fidelity)
                        self.y_list[4].append(kd_loss)

            elif self.source == 'mosafi' or self.source == 'avg':
                kldiv_loss = nn.KLDivLoss(reduction='mean')

                for epoch in range(train_epoch):
                    # train substitute with all data
                    for _, (_, perturbed_img, victim_prob, _) in enumerate(new_dataloader):
                        if self.opt.use_gpu:
                            perturbed_img = perturbed_img.cuda()
                            victim_prob = victim_prob.cuda()

                        log_softmax = nn.LogSoftmax(dim=1)
                        # perturbed_sub_output, perturbed_fake_output = nn.parallel.data_parallel(self.substitute,
                        #                                                                         self.augmentation(
                        #                                                                             perturbed_img))
                        perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                         self.augmentation(perturbed_img))
                        perturbed_sub_prob = log_softmax(perturbed_sub_output)

                        # train substitute
                        self.substitute.zero_grad()

                        # online supervised loss
                        labeled_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
                        substitute_loss = labeled_loss
                        substitute_loss.backward()
                        substitute_optimizer.step()
                    # scheduler.step()
                    print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss}')
                    if (epoch + 1) % self.opt.print_freq == 0:
                        acc, fidelity, kd_loss = self.evaluate()
                        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                              f'| L2 noise {avg_l2_noise}')
                        self.x_list[0].append(self.loop * self.n_epochs + epoch)
                        self.x_list[1].append(self.loop * self.n_epochs + epoch)
                        self.x_list[2].append(self.loop * self.n_epochs + epoch)
                        self.x_list[3].append(self.loop * self.n_epochs + epoch)
                        self.x_list[4].append(self.loop * self.n_epochs + epoch)
                        self.y_list[0].append(avg_l2_noise / 4)
                        self.y_list[1].append(asr / 100)
                        self.y_list[2].append(acc)
                        self.y_list[3].append(fidelity)
                        self.y_list[4].append(kd_loss)

                # acc, fidelity, kd_loss = self.evaluate()
                # asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                # print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                #       f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                #       f'| L2 noise {avg_l2_noise}')

            else:
                for epoch in range(train_epoch):
                    for _, (seed, data, prob) in enumerate(new_dataloader):
                        if self.opt.use_gpu:
                            seed = seed.cuda()
                            data = data.cuda()
                            prob = prob.cuda()

                        # define loss functions
                        softmax = nn.Softmax(dim=1)
                        mse_loss = nn.MSELoss()
                        log_softmax = nn.LogSoftmax(dim=1)
                        substitute_loss_function = torch.nn.KLDivLoss(reduction='mean')

                        # train generator with noise and new query
                        # train generator with new query
                        self.data_gen.zero_grad()
                        sub_output = self.substitute(self.data_gen(seed))
                        # L2 loss (MSE)
                        sub_prob = softmax(sub_output)
                        # diff_loss = torch.exp(-1*loss_func(sub_prob,prob))
                        same_loss = loss_func(sub_prob, prob)

                        # train generator with noise
                        # noise_batch = 128
                        # noise = torch.Tensor(
                        #     np.random.uniform(-5.,5.,size=(noise_batch,256))
                        # )
                        # if self.opt.use_gpu:
                        #     noise = noise.cuda()
                        # noise_data = self.data_gen(noise)
                        # noise_output = self.substitute(noise_data)
                        # noise_prob = softmax(noise_output)
                        # # div_loss = torch.exp(-1*loss_func(noise_data,torch.mean(noise_data,0)))
                        # item_entropy = d.Categorical(probs=noise_prob).entropy()
                        # batch_prob = softmax(item_entropy)
                        # entropy = d.Categorical(probs=batch_prob).entropy()
                        # div_loss = torch.exp(-1*entropy)

                        # generator backprop
                        gen_loss = same_loss  # + 2*div_loss
                        gen_loss.backward()
                        data_gen_optimizer.step()

                        # train substitute
                        self.substitute.zero_grad()
                        sub_output = self.substitute(self.augmentation(data))
                        sub_prob = softmax(sub_output)
                        substitute_loss = substitute_loss_function(sub_prob, prob)
                        substitute_loss.backward()
                        substitute_optimizer.step()

                    print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss}')
                    print(f'[data generator] epoch {epoch + 1} | loss {gen_loss}')

                    if (epoch + 1) % self.opt.print_freq == 0:
                        acc, fidelity, kd_loss = self.evaluate()
                        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
                        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                              f'| L2 noise {avg_l2_noise}')
                        # writer.add_scalar("loss",substitute_loss,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("accuracy",acc,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("fidelity",fidelity,self.loop*self.n_epochs+epoch)
                        # writer.add_scalar("ASR",asr,self.loop*self.n_epochs+epoch)

                    torch.save(self.substitute.state_dict(), sub_ckpt_path)
                    torch.save(self.data_gen.state_dict(), gen_ckpt_path)

        writer.flush()
        writer.close()
        acc, fidelity, kd_loss = self.evaluate()
        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
        print(f'[substitute] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise}')

        # save substitute model
        if self.save:
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            os.makedirs(save_dir, exist_ok=True)
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.loop + 1}'
            save_path = os.path.join(save_dir, save_name)
            torch.save(self.substitute.state_dict(), save_path)
            print(f'[substitute] Saved substitute model in loop {self.loop + 1}')

        # # train generator

        # # train with new queries
        # for epoch in range(50):
        #     for _, (seed, _, prob) in enumerate(new_dataloader):
        #         if self.opt.use_gpu:
        #             seed = seed.cuda()
        #             prob = prob.cuda()
        #         self.data_gen.zero_grad()
        #         sub_output = self.substitute(self.data_gen(seed))
        #         # KL divergence loss
        #         # softmax = nn.LogSoftmax(dim=1)
        #         # sub_prob = softmax(sub_output)
        #         # loss_func = torch.nn.KLDivLoss(reduction='mean')
        #         # eps = 1e-7
        #         # loss = 1/(loss_func(sub_prob, prob)+eps)

        #         # L2 loss
        #         softmax = nn.Softmax(dim=1)
        #         sub_prob = softmax(sub_output)
        #         loss_func = nn.MSELoss()
        #         loss = torch.exp(-1*loss_func(sub_prob,prob))
        #         loss.backward()
        #         # nn.utils.clip_grad_value_(self.data_gen.parameters(),clip_value=0.7)
        #         data_gen_optimizer.step()
        #     print(f'[data generator] epoch {epoch} | loss {loss}')
        #     torch.save(self.data_gen.state_dict(),gen_ckpt_path)

        # train with random input
        # for epoch in range(40):
        #     for _ in range(len(dataloader)):
        #         self.data_gen.zero_grad()
        #         noise = torch.Tensor(
        #             np.random.uniform(-5.,5.,size=(16,256))
        #         )
        #         if self.opt.use_gpu:
        #             noise = noise.cuda()
        #         sub_output = self.substitute(self.data_gen(noise))
        #         softmax = nn.Softmax(dim=1)
        #         sub_prob = softmax(sub_output)
        #         # print(torch.max(sub_prob,1))
        #         data_gen_loss = torch.mean(torch.max(sub_prob,1)[0])
        #         # print(data_gen_loss)
        #         data_gen_loss.backward()
        #         data_gen_optimizer.step()
        #     print(f'[data generator] epoch {epoch} | loss {data_gen_loss}')
        #     torch.save(self.data_gen.state_dict(),gen_ckpt_path)

        return self.substitute, self.data_gen, self.unlabeled_dataset

    def evaluate(self):
        """Measure the substitute's test accuracy and its agreement with the victim.

        The substitute is passed through ``load_model_weights`` together with a
        freshly built ``self.opt.eval_model`` architecture (presumably copying the
        substitute's weights into it -- TODO confirm). That model and the victim
        are then run on ``self.eval_dataset.test_dataloader()``.

        Returns:
            tuple ``(accuracy, fidelity, final_kd)``:
              * accuracy -- fraction of test samples whose substitute prediction
                equals the label (CPU scalar tensor);
              * fidelity -- fraction of test samples on which substitute and
                victim predict the same class (CPU scalar tensor);
              * final_kd -- ``(max_kd - kd_loss) / scale_kd``. The feature-based
                KD computation is disabled, so the KD loss stays 0.0 and this is
                always the constant 2.8.

        Raises:
            ValueError: if the evaluation dataloader yields no samples.
        """
        n_outputs = self.opt.victim_n_classes
        if self.opt.eval_model.startswith('efficientnet'):
            eval_model = setup.EfficientNet.from_pretrained(
                self.opt.eval_model,
                num_classes=n_outputs)
        elif self.opt.eval_model.startswith('wrn'):
            depth = int(self.opt.eval_model.split('-')[-1])
            eval_model = setup.classifier.WideResNet(
                n_outputs=n_outputs,
                depth=depth
            )
        else:
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )
        substitute = load_model_weights(self.substitute, eval_model)
        victim = self.victim

        # Move the models once instead of on every batch.
        if self.opt.use_gpu:
            victim.cuda()
            substitute.cuda()

        # Remember the current modes so that evaluating does not leave a model
        # that is still being trained stuck in eval mode.
        sub_was_training = substitute.training
        victim_was_training = victim.training
        substitute.eval()
        victim.eval()

        accs = 0.0
        fidelity = 0.0
        kd_loss_result = 0.0  # stays 0.0: the feature-based KD term is disabled
        n_samples = 0
        n_batch = 0
        try:
            with torch.no_grad():
                for data in self.eval_dataset.test_dataloader():
                    imgs = data[0]
                    targets = data[1]
                    if self.opt.use_gpu:
                        imgs = imgs.cuda()
                        targets = targets.cuda()
                    sub_pred = substitute(imgs).max(1)[1]
                    victim_pred = victim(imgs).max(1)[1]
                    accs += sub_pred.eq(targets).float().sum().cpu()
                    fidelity += victim_pred.eq(sub_pred).float().sum().cpu()
                    # kd_loss.FSPLoss on substitute/victim features used to be
                    # accumulated here; it is currently disabled.
                    n_samples += targets.shape[0]
                    n_batch += 1
        finally:
            substitute.train(sub_was_training)
            victim.train(victim_was_training)

        if n_samples == 0:
            raise ValueError('evaluation dataloader yielded no samples')
        accs /= n_samples
        fidelity /= n_samples
        kd_loss_result /= n_batch
        max_kd = 280000
        scale_kd = 100000
        final_kd = (max_kd - kd_loss_result) / scale_kd
        return accs, fidelity, final_kd

    # def adv_evaluate(self, sample_size):
    #     n_outputs = self.opt.victim_n_classes
    #     eval_model = classifier_dict[self.opt.eval_model](
    #         n_outputs=n_outputs
    #     )
    #     substitute = load_model_weights(self.substitute, eval_model)
    #     self.victim.eval()
    #     substitute.eval()
    #     data_list = np.random.choice(range(len(self.eval_dataset.test_dataset)),sample_size,replace=False) #[i for i in range(7800,8000)]
    #     dataloader = torch.utils.data.DataLoader(
    #         self.eval_dataset.test_dataset, batch_size=50,
    #         sampler=sp.SubsetRandomSampler(data_list), num_workers=4
    #     )
    #     # dataloader = self.eval_dataset.test_dataloader()
    #     adversary_ghost = LinfBasicIterativeAttack(
    #         substitute, loss_fn = nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #         nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple*self.opt.noise_eps/self.opt.noise_step,
    #         clip_min=-1.0, clip_max=1.0, targeted=False)
    #     attack_success = 0.0
    #     total = 0.0
    #     l2_noise = 0.0
    #     self.victim.eval()
    #     for data in dataloader:
    #         inputs, labels = data
    #         inputs = inputs.cuda()
    #         labels = labels.cuda()
    #         with torch.no_grad():
    #             outputs = self.victim(inputs)
    #             _, predicted = torch.max(outputs.data, 1)
    #         correct_predict = predicted == labels
    #         total += correct_predict.float().sum()
    #         if self.opt.targeted:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, (labels+1) % 10)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict*(predicted == (labels+1) % 10)).sum()
    #             for i, correct in enumerate(correct_predict):
    #                 if correct and predicted[i] == (labels[i]+1) % 10:
    #                     l2_noise += torch.dist(adv_inputs_ghost[i], inputs[i], 2)
    #         else:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, labels)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict*(predicted != labels)).sum()
    #             for i, correct in enumerate(correct_predict):
    #                 if correct and predicted[i] != labels[i]:
    #                     l2_noise += torch.dist(adv_inputs_ghost[i], inputs[i], 2)
    #
    #         # with torch.no_grad():
    #         #     victim_labels = self.victim(inputs).max(1)[1]
    #         #     sub_labels = self.substitute(inputs).max(1)[1]
    #
    #         # adv_inputs_ghost = adversary_ghost.perturb(inputs, sub_labels)
    #         # with torch.no_grad():
    #         #     outputs = self.victim(adv_inputs_ghost)
    #         #     _, predicted = torch.max(outputs.data,1)
    #         # # print('victim_labels',victim_labels)
    #         # # print('sub_labels',sub_labels)
    #         # # print('predicted',predicted)
    #         # total += labels.size(0)
    #         # correct_ghost += (predicted == victim_labels).sum()
    #     asr = (100. * attack_success.float() / total).cpu()
    #     avg_l2_noise = (l2_noise / total).cpu()
    #     l2_noise_per_pixel = avg_l2_noise / (data[0].shape[1]*data[0].shape[2]*data[0].shape[3])
    #     return asr, avg_l2_noise, l2_noise_per_pixel

    # def adv_evaluate(self, sample_size):
    #     if self.opt.use_gpu:
    #         self.victim.cuda()
    #         self.substitute.cuda()
    #     self.victim.eval()
    #     self.substitute.eval()
    #     data_list = np.random.choice(range(len(self.eval_dataset.test_dataset)), sample_size,
    #                                  replace=False)  # [i for i in range(7800,8000)]
    #     batch_size = 50
    #     dataloader = torch.utils.data.DataLoader(
    #         self.eval_dataset.test_dataset, batch_size=batch_size,
    #         sampler=sp.SubsetRandomSampler(data_list), num_workers=4
    #     )
    #     # dataloader = self.eval_dataset.test_dataloader()
    #     # adversary_ghost = LinfBasicIterativeAttack(
    #     #     self.substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #     #     nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
    #     #     clip_min=-1.0, clip_max=1.0, targeted=False)
    #     adversary_ghost = LinfMomentumIterativeAttack(
    #         self.substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
    #         nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
    #         clip_min=-1.0, clip_max=1.0, targeted=False)
    #     attack_success = 0.0
    #     sub_attack_success = 0.0
    #     total = 0.0
    #     success_total = 0.0
    #     l2_noise = 0.0
    #     success_l2_noise = 0.0
    #     for data in dataloader:
    #         inputs, labels = data
    #         inputs = inputs.cuda()
    #         labels = labels.cuda()
    #         with torch.no_grad():
    #             outputs = self.victim(inputs)
    #             _, predicted = torch.max(outputs.data, 1)
    #             sub_outputs = self.substitute(inputs)
    #             _, sub_labels = torch.max(sub_outputs.data, 1)
    #         correct_predict = predicted == labels
    #         if self.opt.targeted:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, (labels + 1) % 10)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #             attack_success += (correct_predict * (predicted == (labels + 1) % 10)).sum()
    #             l2_noise_list = F.pairwise_distance(adv_inputs_ghost.view(len(inputs), -1),
    #                                                 inputs.view(len(inputs), -1))
    #             for i, correct in enumerate(correct_predict):
    #                 if correct:
    #                     l2_noise += l2_noise_list[i]
    #                     total += 1
    #                     if predicted[i] == (labels[i] + 1) % 10:
    #                         success_l2_noise += l2_noise_list[i]
    #                         success_total += 1
    #         else:
    #             adv_inputs_ghost = adversary_ghost.perturb(inputs, labels)
    #             with torch.no_grad():
    #                 outputs = self.victim(adv_inputs_ghost)
    #                 _, predicted = torch.max(outputs.data, 1)
    #                 sub_outputs = self.victim(adv_inputs_ghost)
    #                 _, sub_predicted = torch.max(sub_outputs.data, 1)
    #             attack_success += (correct_predict * (predicted != labels)).sum()
    #             sub_attack_success += (correct_predict * (sub_predicted != sub_labels)).sum()
    #             l2_noise_list = F.pairwise_distance(adv_inputs_ghost.view(len(inputs), -1),
    #                                                 inputs.view(len(inputs), -1))
    #             for i, correct in enumerate(correct_predict):
    #                 if correct:
    #                     l2_noise += l2_noise_list[i]
    #                     total += 1
    #                     if predicted[i] != labels[i]:
    #                         success_l2_noise += l2_noise_list[i]
    #                         success_total += 1
    #
    #     asr = (100. * attack_success.float() / total).cpu()
    #     sub_asr = (100. * sub_attack_success.float() / total).cpu()
    #     success_avg_l2_noise = (success_l2_noise / success_total).cpu()
    #     success_l2_noise_per_pixel = success_avg_l2_noise / (data[0].shape[1] * data[0].shape[2] * data[0].shape[3])
    #     avg_l2_noise = (l2_noise / total).cpu()
    #     # l2_noise_per_pixel = avg_l2_noise / (data[0].shape[1]*data[0].shape[2]*data[0].shape[3])
    #     return asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise

    def adv_evaluate(self, sample_size):
        """Measure how well adversarial examples crafted on the substitute transfer to the victim.

        ``sample_size`` images are drawn without replacement from
        ``self.eval_dataset.test_dataset``. An advertorch
        ``LinfMomentumIterativeAttack`` is run against the substitute (the model
        returned by ``load_model_weights`` for ``self.opt.eval_model``), and the
        perturbed images are fed to the victim. Only images the victim
        classifies correctly before the attack are counted.

        When ``self.opt.targeted`` is set, every image is attacked towards class
        ``(label + 1) % victim_n_classes``; otherwise the attack is untargeted.

        The train/eval modes of ``self.victim`` and ``self.substitute`` are
        restored on exit. This method is called in the middle of substitute
        training.

        Args:
            sample_size: number of test images to draw.

        Returns:
            tuple ``(asr, sub_asr, success_avg_l2_noise,
            success_l2_noise_per_pixel, avg_l2_noise)`` of CPU scalar tensors:
              * asr -- % of counted images on which the attack fooled the victim;
              * sub_asr -- % of counted images on which the attack fooled the
                substitute itself (changed its clean prediction, or hit the
                target in targeted mode);
              * success_avg_l2_noise -- mean L2 perturbation over the images on
                which the victim was fooled;
              * success_l2_noise_per_pixel -- the above divided by the number of
                elements in one image;
              * avg_l2_noise -- mean L2 perturbation over all counted images.
            A ratio whose denominator is zero (no counted image, or no
            successful attack) is reported as NaN instead of raising.
        """
        victim_was_training = self.victim.training
        sub_was_training = self.substitute.training
        if self.opt.use_gpu:
            self.victim.cuda()
            self.substitute.cuda()
        self.victim.eval()
        self.substitute.eval()

        try:
            n_outputs = self.opt.victim_n_classes
            if self.opt.eval_model.startswith('efficientnet'):
                eval_model = setup.EfficientNet.from_pretrained(
                    self.opt.eval_model,
                    num_classes=n_outputs)
            elif self.opt.eval_model.startswith('wrn'):
                depth = int(self.opt.eval_model.split('-')[-1])
                eval_model = setup.classifier.WideResNet(
                    n_outputs=n_outputs,
                    depth=depth
                )
            else:
                eval_model = classifier_dict[self.opt.eval_model](
                    n_outputs=n_outputs
                )

            substitute = load_model_weights(self.substitute, eval_model)
            # Same device/mode handling as evaluate(): attack the model in eval mode.
            if self.opt.use_gpu:
                substitute.cuda()
            substitute.eval()

            data_list = np.random.choice(range(len(self.eval_dataset.test_dataset)), sample_size,
                                         replace=False)
            batch_size = 50
            dataloader = torch.utils.data.DataLoader(
                self.eval_dataset.test_dataset, batch_size=batch_size,
                sampler=sp.SubsetRandomSampler(data_list), num_workers=4
            )
            targeted = bool(self.opt.targeted)
            # `targeted` must match how the labels passed to perturb() are meant:
            # with targeted=False advertorch pushes *away* from the given labels.
            adversary_ghost = LinfMomentumIterativeAttack(
                substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
                nb_iter=self.opt.noise_step,
                eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
                clip_min=-1.0, clip_max=1.0, targeted=targeted)

            attack_success = 0
            sub_attack_success = 0
            total = 0
            l2_noise = 0.0
            success_l2_noise = 0.0
            n_pixels = 0
            for inputs, labels in dataloader:
                if self.opt.use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                n_pixels = inputs[0].numel()  # elements in one image
                with torch.no_grad():
                    victim_clean = self.victim(inputs).max(1)[1]
                    sub_clean = substitute(inputs).max(1)[1]
                # Only images the victim already classifies correctly are counted.
                counted = victim_clean == labels

                if targeted:
                    goal = (labels + 1) % n_outputs
                    adv_inputs = adversary_ghost.perturb(inputs, goal)
                else:
                    adv_inputs = adversary_ghost.perturb(inputs, labels)
                with torch.no_grad():
                    victim_adv = self.victim(adv_inputs).max(1)[1]
                    sub_adv = substitute(adv_inputs).max(1)[1]

                if targeted:
                    fooled = victim_adv == goal
                    sub_fooled = sub_adv == goal
                else:
                    fooled = victim_adv != labels
                    sub_fooled = sub_adv != sub_clean
                success = counted & fooled

                l2 = F.pairwise_distance(adv_inputs.reshape(len(inputs), -1),
                                         inputs.reshape(len(inputs), -1))
                total += int(counted.sum())
                attack_success += int(success.sum())
                sub_attack_success += int((counted & sub_fooled).sum())
                l2_noise += float(l2[counted].sum())
                success_l2_noise += float(l2[success].sum())
        finally:
            self.victim.train(victim_was_training)
            self.substitute.train(sub_was_training)

        def _ratio(num, den):
            # NaN instead of ZeroDivisionError, e.g. when no attack succeeded.
            return num / den if den else float('nan')

        success_avg_l2_noise = _ratio(success_l2_noise, attack_success)
        asr = torch.tensor(_ratio(100. * attack_success, total))
        sub_asr = torch.tensor(_ratio(100. * sub_attack_success, total))
        success_l2_noise_per_pixel = torch.tensor(_ratio(success_avg_l2_noise, n_pixels))
        avg_l2_noise = torch.tensor(_ratio(l2_noise, total))
        return asr, sub_asr, torch.tensor(success_avg_l2_noise), success_l2_noise_per_pixel, avg_l2_noise

    def surrogate_adv_evaluate(self, sample_size):
        """White-box attack the substitute on surrogate-domain test images.

        ``sample_size`` images are drawn without replacement from
        ``self.surrogate_eval_dataset.test_dataset`` and attacked with advertorch
        ``LinfBasicIterativeAttack`` against the substitute itself. Only images
        the substitute classifies correctly are counted. When
        ``self.opt.targeted`` is set, the target is ``(label + 1) % victim_n_classes``.
        The substitute's train/eval mode is restored on exit.

        Args:
            sample_size: number of test images to draw.

        Returns:
            tuple ``(asr, avg_l2_noise, l2_noise_per_pixel)`` of CPU scalar tensors:
              * asr -- % of counted images on which the attack succeeded;
              * avg_l2_noise -- summed L2 perturbation of the *successful*
                attacks divided by the number of counted images (metric
                definition kept from the original);
              * l2_noise_per_pixel -- avg_l2_noise divided by the number of
                elements in one image.
            All three are NaN when no image was classified correctly.
        """
        if self.source == 'attackgan':
            n_outputs = self.opt.victim_n_classes
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )
            substitute = load_model_weights(self.substitute, eval_model)
        else:
            substitute = self.substitute
        was_training = substitute.training
        if self.opt.use_gpu:
            substitute.cuda()
        substitute.eval()

        try:
            data_list = np.random.choice(range(len(self.surrogate_eval_dataset.test_dataset)), sample_size,
                                         replace=False)
            dataloader = torch.utils.data.DataLoader(
                self.surrogate_eval_dataset.test_dataset, batch_size=50,
                sampler=sp.SubsetRandomSampler(data_list), num_workers=4
            )
            targeted = bool(self.opt.targeted)
            # NOTE(review): nb_iter reads self.noise_step here, whereas
            # adv_evaluate reads self.opt.noise_step -- confirm which is intended.
            adversary_ghost = LinfBasicIterativeAttack(
                substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
                nb_iter=self.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.noise_step,
                clip_min=-1.0, clip_max=1.0, targeted=targeted)
            n_classes = self.opt.victim_n_classes

            attack_success = 0
            total = 0
            l2_noise = 0.0
            n_pixels = 0
            for inputs, labels in dataloader:
                if self.opt.use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                n_pixels = inputs[0].numel()  # elements in one image
                with torch.no_grad():
                    clean_pred = substitute(inputs).max(1)[1]
                counted = clean_pred == labels

                if targeted:
                    goal = (labels + 1) % n_classes
                    adv_inputs = adversary_ghost.perturb(inputs, goal)
                else:
                    adv_inputs = adversary_ghost.perturb(inputs, labels)
                with torch.no_grad():
                    adv_pred = substitute(adv_inputs).max(1)[1]

                fooled = (adv_pred == goal) if targeted else (adv_pred != labels)
                success = counted & fooled
                total += int(counted.sum())
                attack_success += int(success.sum())
                # Per-image L2 distance; equivalent to torch.dist(adv[i], inputs[i], 2).
                l2 = (adv_inputs - inputs).flatten(1).norm(p=2, dim=1)
                l2_noise += float(l2[success].sum())
        finally:
            substitute.train(was_training)

        if total == 0:
            return (torch.tensor(float('nan')), torch.tensor(float('nan')),
                    torch.tensor(float('nan')))
        avg_l2_noise = l2_noise / total
        return (torch.tensor(100. * attack_success / total),
                torch.tensor(avg_l2_noise),
                torch.tensor(avg_l2_noise / n_pixels))

    def epoch_val(self):
        if self.source == 'attackgan':
            n_outputs = self.opt.victim_n_classes
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )
            substitute = load_model_weights(self.substitute, eval_model)
        else:
            substitute = self.substitute
        substitute.eval()
        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4
        )
        total_loss = 0.0
        num_batch = 0.0
        for _, (_, perturbed_img, victim_prob) in enumerate(val_dataloader):
            num_batch += 1
            if self.opt.use_gpu:
                perturbed_img = perturbed_img.cuda()
                victim_prob = victim_prob.cuda()

            log_softmax = nn.LogSoftmax(dim=1)
            perturbed_sub_output, _ = self.substitute(perturbed_img)
            perturbed_sub_prob = log_softmax(perturbed_sub_output)

            # define loss functions
            l1_loss = nn.L1Loss()
            mse_loss = nn.MSELoss()
            kldiv_loss = nn.KLDivLoss(reduction='mean')

            # sub_same_loss = l1_loss(softmax(new_perturbed_sub_output), victim_prob)
            # sub_same_loss = mse_loss(softmax(new_perturbed_sub_output), victim_prob)
            sub_same_loss = kldiv_loss(perturbed_sub_prob, victim_prob)
            total_loss += sub_same_loss.detach().cpu()
        return total_loss / num_batch

    def get_weighted_sampler(self):
        alpha = self.opt.dataset_decay
        n_total = len(self.sub_dataset)
        n_new = len(self.new_sub_dataset)
        n_loop = n_total // n_new
        weights = []
        for i in range(n_total):
            loop = i // n_new
            weights.append(alpha ** (n_loop - loop))
        weights = torch.DoubleTensor(weights)
        weighted_sampler = sp.WeightedRandomSampler(weights, n_total)
        return weighted_sampler

    def get_pseudo_label(self, unlabeled_items, k_avg=3, T=0.5, da=None, threshhold=0.7):
        """Produce soft and hard pseudo labels for a batch of unlabeled inputs.

        The substitute's softmax output is averaged over ``k_avg`` randomly
        flipped copies of ``unlabeled_items``. Each copy goes through a
        torchvision ``RandomHorizontalFlip`` and ``RandomVerticalFlip``, each
        applied with probability 0.5 to the whole tensor at once. The averaged
        probabilities are then optionally sharpened and re-weighted. The
        substitute's train/eval mode is restored afterwards.

        Args:
            unlabeled_items: tensor fed to ``self.substitute``. The flips act on
                its last two dims (presumably ``(N, C, H, W)`` -- TODO confirm).
            k_avg: number of augmented forward passes to average; must be >= 1.
            T: sharpening temperature; probabilities are raised to ``1 / T`` and
                renormalised. ``None`` disables sharpening.
            da: optional per-class weights for distribution alignment; they are
                multiplied in and the result is renormalised. ``None`` disables it.
            threshhold: confidence above which the mask is 1.

        Returns:
            ``(probs, pseudo_labels, mask)``:
              * probs -- the processed probabilities;
              * pseudo_labels -- their argmax class indices;
              * mask -- a float tensor that is 1.0 where the max probability
                exceeds ``threshhold`` and 0.0 elsewhere.

        Raises:
            ValueError: if ``k_avg`` < 1.
        """
        if k_avg < 1:
            raise ValueError(f'k_avg must be >= 1, got {k_avg}')
        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                transforms.RandomVerticalFlip(),  # with 0.5 probability
            ]
        )
        softmax = nn.Softmax(dim=1)

        # Pseudo-labelling runs mid-training: put the substitute back into
        # whatever mode it was in once predictions are collected.
        was_training = self.substitute.training
        self.substitute.eval()
        try:
            with torch.no_grad():
                # Average probability predictions over k augmentations.
                unlabeled_probs = sum(
                    softmax(self.substitute(augmentation(unlabeled_items)))
                    for _ in range(k_avg)
                ) / k_avg
        finally:
            self.substitute.train(was_training)

        # Sharpening.
        if T is not None:
            unlabeled_pt = unlabeled_probs ** (1 / T)
            unlabeled_probs = unlabeled_pt / unlabeled_pt.sum(dim=1, keepdim=True)

        # Distribution alignment.
        if da is not None:
            unlabeled_probs_da = unlabeled_probs * da
            unlabeled_probs = unlabeled_probs_da / unlabeled_probs_da.sum(dim=1, keepdim=True)

        max_probs, pseudo_labels = unlabeled_probs.max(1)
        mask = (max_probs > threshhold).float()
        return unlabeled_probs, pseudo_labels, mask

    def get_unlabeled_sub_prob(self, sample_size):
        self.substitute.eval()
        softmax = nn.Softmax(dim=1)
        data_list = np.random.choice(range(len(self.unlabeled_dataset)), sample_size,
                                     replace=False)  # [i for i in range(7800,8000)]
        dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset, batch_size=50,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )
        for i, (unlabeled_inputs, _, _, _) in enumerate(dataloader):
            if self.opt.use_gpu:
                unlabeled_inputs = unlabeled_inputs.cuda()
            with torch.no_grad():
                unlabeled_outputs = self.substitute(unlabeled_inputs).detach()
            unlabeled_probs = softmax(unlabeled_outputs)
            if i == 0:
                probs_sum = unlabeled_probs.sum(dim=0)
            else:
                probs_sum += unlabeled_probs.sum(dim=0)
        return probs_sum / sample_size

    def get_next_diff_dataset(self, k_avg=3):
        """Build the next set of fused (multi-image) public queries.

        Every entry appended to the returned dataset's ``items`` has the form
        ``(list_of_n_fuse_images, -1, -1, -1)``. Two strategies are used:

        * If no public sample has been used yet (all ``items[idx][3] == 0``) and
          the pool holds at most ``n_fuse * query_per_loop`` samples, the whole
          pool is shuffled and chunked into fused queries. This is repeated as
          often as needed, and the result is truncated to ``query_per_loop``
          entries.
        * Otherwise every public sample is pseudo-labelled by the substitute,
          and ``len(self.new_sub_dataset)`` fused queries are drawn from the
          still-unused samples, favouring rarer pseudo-label classes.

        Every public sample placed in a query is marked used (``items[idx][3] = 1``).

        Args:
            k_avg: number of randomly flipped forward passes whose softmax
                outputs are averaged per sample when pseudo-labelling.

        Returns:
            A ``data.SubDataset`` holding the fused queries.

        Raises:
            ValueError: if the pool is smaller than ``n_fuse`` in the full-scale
                case, if ``k_avg < 1``, or if the unused samples run out before
                all queries are filled.
        """
        used = sum(item[3] for item in self.unlabeled_dataset.items)
        unused = len(self.unlabeled_dataset) - used
        query_per_loop = int(self.opt.query / self.opt.n_loop)
        if unused <= self.opt.n_fuse * query_per_loop and used == 0:
            print(f'Full-scale public dataset needed. Skip sampling process.')
            return self._full_scale_diff_dataset(unused, query_per_loop)

        label_idx_list = self._pseudo_label_unlabeled(k_avg)
        return self._sample_by_frequency(label_idx_list, len(self.new_sub_dataset))

    def _full_scale_diff_dataset(self, unused, query_per_loop):
        """Chunk the whole (never used) public pool into ``query_per_loop`` fused queries."""
        n_fuse = self.opt.n_fuse
        queries_per_pass = unused // n_fuse
        if queries_per_pass == 0:
            raise ValueError(
                f'public dataset has {unused} samples, fewer than n_fuse={n_fuse}')
        img_size = self.opt.public_img_size
        items_store = self.unlabeled_dataset.items
        next_diff_dataset = data.SubDataset()
        for _ in range(math.ceil(query_per_loop / queries_per_pass)):
            # fresh random permutation of the pool for every pass
            order = random.sample(range(unused), unused)
            for query in range(queries_per_pass):
                fused = []
                for offset in range(n_fuse):
                    idx = order[query * n_fuse + offset]
                    fused.append(items_store[idx][0].view(3, img_size, img_size))
                    items_store[idx][3] = 1  # mark as used
                next_diff_dataset.items.append((fused, -1, -1, -1))
        next_diff_dataset.items = next_diff_dataset.items[:query_per_loop]
        return next_diff_dataset

    def _pseudo_label_unlabeled(self, k_avg):
        """Pseudo-label every public sample and group the unused indices by label.

        Writes the averaged soft label to ``items[idx][1]`` and the arg-max
        pseudo-label to ``items[idx][2]`` for every public sample. Returns
        ``label_idx_list`` where ``label_idx_list[c]`` lists, in dataset order,
        the indices of unused samples (``items[idx][3] == 0``) labelled ``c``.
        """
        if k_avg < 1:
            raise ValueError(f'k_avg must be at least 1, got {k_avg}')
        batch_size = 100
        # No shuffling: batch i covers dataset indices [i*batch_size, (i+1)*batch_size).
        unlabeled_dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                transforms.RandomVerticalFlip(),  # with 0.5 probability
            ]
        )
        items_store = self.unlabeled_dataset.items
        label_idx_list = [[] for _ in range(self.opt.victim_n_classes)]
        # NOTE(review): the substitute is not switched to eval() here, so
        # BatchNorm/Dropout follow whatever mode the caller left it in — confirm.
        for i, (unlabeled_items, _, _, _) in enumerate(unlabeled_dataloader):
            if self.opt.use_gpu:
                unlabeled_items = unlabeled_items.cuda()
            probs_sum = None
            for _ in range(k_avg):
                unlabeled_items_aug = augmentation(unlabeled_items)
                with torch.no_grad():
                    probs = torch.softmax(self.substitute(unlabeled_items_aug), dim=1)
                probs_sum = probs if probs_sum is None else probs_sum + probs
            unlabeled_probs = probs_sum / k_avg
            # one device->host transfer for the whole batch of labels
            pseudo_labels = unlabeled_probs.max(1)[1].tolist()
            for j, label in enumerate(pseudo_labels):
                idx = i * batch_size + j
                items_store[idx][1] = unlabeled_probs[j].cpu()
                items_store[idx][2] = label
                # only still-unused samples are candidates for the next queries
                if items_store[idx][3] == 0:
                    label_idx_list[label].append(idx)
        return label_idx_list

    def _sample_by_frequency(self, label_idx_list, loop_size):
        """Draw ``loop_size`` fused queries, favouring rare pseudo-label classes.

        Each fused query takes ``n_fuse`` samples. Within one draw the classes
        are distinct; when fewer non-empty classes than needed remain, further
        draws fill the rest. Samples are taken per class in ascending index order.
        """
        from collections import deque

        queues = [deque(idxs) for idxs in label_idx_list]
        label_nums = [len(q) for q in queues]
        print(label_nums)
        label_probs = self._frequency_label_probs(label_nums)
        n_classes = self.opt.victim_n_classes
        items_store = self.unlabeled_dataset.items
        next_diff_dataset = data.SubDataset()
        for _ in range(loop_size):
            fused = []
            num_unsampled = self.opt.n_fuse
            while num_unsampled > 0:
                n_valid_class = sum(1 for p in label_probs if p > 0)
                if n_valid_class == 0:
                    raise ValueError(
                        'ran out of unused public samples while building fused queries')
                num_current_sample = min(n_valid_class, num_unsampled)
                num_unsampled -= num_current_sample
                current_labels = np.random.choice(range(n_classes), size=num_current_sample,
                                                  p=label_probs, replace=False)
                for current_label in current_labels:
                    idx = queues[current_label].popleft()
                    fused.append(items_store[idx][0])
                    items_store[idx][3] = 1  # mark as used
                # re-weight with the updated class counts
                label_probs = self._frequency_label_probs([len(q) for q in queues])
            next_diff_dataset.items.append((fused, -1, -1, -1))
        return next_diff_dataset

    @staticmethod
    def _frequency_label_probs(label_nums):
        """Map per-class counts to sampling probabilities that favour rare classes.

        A non-empty class gets weight ``-log(count / total) + 1e-6``. The offset
        keeps a lone remaining class, whose log-ratio is 0, selectable instead
        of producing an all-zero weight vector. Empty classes get 0. Weights are
        normalised to sum to 1, or all zeros are returned when every class is
        empty.
        """
        total = sum(label_nums)
        weights = [-math.log(n / total) + 1e-6 if n != 0 else 0 for n in label_nums]
        weight_sum = sum(weights)
        if weight_sum == 0:
            return [0. for _ in weights]
        return [w / weight_sum for w in weights]
