import os
from itertools import chain
import pickle

import torch
from torch import nn, softmax
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data.sampler as sp
import torch.distributions as d
import torch.nn.functional as F
import numpy as np
from advertorch.attacks import *
from torchvision import transforms
import classifier
import data
import loss
import sys
import kd_loss
import math
from simclr.modules import NT_Xent
from simclr.modules.identity import Identity
from simclr import SimCLR

sys.path.append('..')
import setup
from data import *
from utils import *
# import loss
import time


class SubstituteTrainerWM:
    def __init__(self, opt,
                 victim, substitute, data_gen,
                 sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                 eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                 source, x_list, y_list, labels,
                 labeled_bs=100, unlabeled_bs=200, cls_div_bs=200,
                 simclr_bs=300,
                 strategy='every', loop=None, n_epochs=10,
                 save=False, load=False):
        """Store models, datasets and hyper-parameters for one training loop.

        Side effects performed here:
          * the first ``int(opt.epoch_val_rate * len(new_sub_dataset))`` items of
            ``new_sub_dataset`` are moved into ``self.val_dataset``;
            ``new_sub_dataset.items`` is reassigned, so the caller's object is
            modified;
          * when ``load`` is true, substitute weights are read from
            ``<data_dir>/checkpoints/<victim_dataset>_<surrogate_dataset>/
            <pre_train_sub>_<source>/<sub_model>_<seed>_<sub_eval_loop>``;
          * the watermark trigger items are always unpickled from
            ``<data_dir>/checkpoints/victim_<victim_model>_<victim_dataset>_trigger.pkl``.

        Args:
            opt: option namespace (``gen_epoch``, ``epoch_val_rate``, ``data_dir``,
                dataset/model names, ``seed``, ...).
            victim, substitute, data_gen: the victim model, the substitute being
                trained and the data generator.
            sub_dataset ... div_dataset: datasets used by ``train``/``evaluate``.
            source: data source name; part of the checkpoint directory.
            x_list, y_list, labels: plot parameters, stored unchanged.
            labeled_bs, unlabeled_bs, cls_div_bs, simclr_bs: batch sizes.
            strategy, loop, n_epochs: training schedule parameters.
            save, load: whether to save / load the substitute checkpoint.
        """
        self.opt = opt
        self.victim = victim
        self.substitute = substitute
        self.data_gen = data_gen

        self.sub_dataset = sub_dataset
        self.new_sub_dataset = new_sub_dataset
        self.next_dataset = next_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.val_dataset = SubDataset()
        self.eval_dataset = eval_dataset
        self.surrogate_eval_dataset = surrogate_eval_dataset
        self.simclr_dataset = simclr_dataset
        self.aug_dataset = aug_dataset
        self.div_dataset = div_dataset

        self.source = source
        self.labeled_bs = labeled_bs
        self.unlabeled_bs = unlabeled_bs
        self.cls_div_bs = cls_div_bs
        self.simclr_bs = simclr_bs
        self.strategy = strategy
        self.loop = loop
        self.n_epochs = n_epochs
        self.gen_epoch = opt.gen_epoch

        # Augmentation applied to labeled batches: a random horizontal flip (p=0.5).
        self.augmentation = transforms.Compose([transforms.RandomHorizontalFlip()])

        # Split a validation set off the front of the newly queried samples.
        val_size = int(self.opt.epoch_val_rate * len(self.new_sub_dataset))
        self.val_dataset.items = self.new_sub_dataset.items[:val_size]
        self.new_sub_dataset.items = self.new_sub_dataset.items[val_size:]

        # plot parameters
        self.x_list = x_list
        self.y_list = y_list
        self.labels = labels

        self.save = save
        if load:
            # Note: loads checkpoint index ``opt.sub_eval_loop``, whereas train()
            # saves under ``loop + 1``.
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.opt.sub_eval_loop}'
            save_path = os.path.join(save_dir, save_name)
            self.substitute.load_state_dict(torch.load(save_path))

        self.trigger_dataset = data.LabeledDataset()
        victim_model_name = f'victim_{self.opt.victim_model}_{self.opt.victim_dataset}'
        trigger_path = os.path.join(self.opt.data_dir, f'checkpoints/{victim_model_name}_trigger.pkl')
        # SECURITY NOTE: pickle.load executes arbitrary code from the file; this
        # assumes the trigger file under data_dir is locally produced and trusted.
        with open(trigger_path, 'rb') as pf:
            self.trigger_dataset.items = pickle.load(pf)
        self.wm_target = 0

    def train(self):
        """Run one loop: substitute training, then (usually) generator training.

        Phase 1, substitute: for ``train_epoch`` epochs (``self.n_epochs``, or
        ``opt.pseudo_epoch`` on the final loop when ``opt.continue_train`` is
        set), fit the substitute to the victim probabilities stored in
        ``sub_dataset``. The loss is KL divergence on ``self.augmentation``-
        transformed batches drawn by :meth:`get_weighted_sampler`, plus a
        weak/strong-augmentation consistency term on ``aug_dataset``. That term
        is rescaled to the labeled loss' magnitude. Epochs may stop early
        according to ``opt.loss_rise_epoch`` / ``opt.div_rise_epoch``.

        Phase 2, generator (skipped on that same final loop): for
        ``self.gen_epoch`` epochs, update ``data_gen`` with a noise + adversarial
        loss on ``new_sub_dataset``. ``opt.div_epoch`` times per epoch it also
        gets a diversity loss between generated images and ``div_dataset``
        samples.

        Evaluation summaries are printed at the start, every
        ``opt.print_freq`` substitute epochs, and at the end. The substitute is
        saved when ``self.save`` is true.

        Returns:
            tuple: ``(substitute, data_gen, unlabeled_dataset, next_diff_dataset)``.

        Raises:
            ValueError: substitute epochs are requested but ``sub_dataset`` is empty.
        """
        print(f'Size of query dataset: {len(self.sub_dataset)}')
        print(f'Size of new query dataset: {len(self.new_sub_dataset)}')
        print(f'Size of diversity dataset: {len(self.div_dataset)}')
        use_gpu = self.opt.use_gpu
        if use_gpu:
            self.substitute.cuda()
            self.data_gen.cuda()

        def to_device(tensor):
            """Move ``tensor`` to the GPU when ``opt.use_gpu`` is set."""
            return tensor.cuda() if use_gpu else tensor

        # Labeled query data, drawn with the decayed per-loop weights.
        dataloader = None
        if len(self.sub_dataset) > 0:
            dataloader = torch.utils.data.DataLoader(
                self.sub_dataset,
                batch_size=self.labeled_bs,
                num_workers=4,
                sampler=self.get_weighted_sampler()
            )
        new_dataloader = torch.utils.data.DataLoader(
            self.new_sub_dataset,
            batch_size=100,
            shuffle=True,
            num_workers=4
        )
        aug_dataloader = torch.utils.data.DataLoader(
            self.aug_dataset,
            batch_size=self.unlabeled_bs,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )
        # Built for every source: the generator's diversity step always needs it.
        div_dataloader = torch.utils.data.DataLoader(
            self.div_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )

        optimizer_cls = torch.optim.Adam if self.opt.sub_optim == 'adam' else torch.optim.SGD
        substitute_optimizer = optimizer_cls(self.substitute.parameters(), lr=self.opt.sub_lr)
        data_gen_optimizer = torch.optim.Adam(self.data_gen.parameters(), lr=self.opt.gen_lr)
        self.substitute.train()
        self.data_gen.train()

        # On the final loop of a continued run only the substitute is trained.
        last_loop = self.opt.continue_train and self.loop == self.opt.n_loop - 1
        train_epoch = self.opt.pseudo_epoch if last_loop else self.n_epochs
        if train_epoch > 0 and dataloader is None:
            raise ValueError('sub_dataset is empty; cannot train the substitute')

        self._print_eval_summary('[start]')

        mse_loss = nn.MSELoss()
        # NOTE(review): reduction='mean' averages over every element (PyTorch
        # recommends 'batchmean' for the true KL). Kept to preserve the loss
        # scale the weights below were tuned for.
        kldiv_loss = nn.KLDivLoss(reduction='mean')
        div_func = loss.DivLoss()
        softmax = nn.Softmax(dim=1)
        log_softmax = nn.LogSoftmax(dim=1)

        min_loss = 1000.
        min_loss_epoch = 0
        # NOTE(review): never updated, so the div_rise_epoch rule below stops at
        # epoch max(10, div_rise_epoch) whenever the loss fails to improve.
        min_div_epoch = 0

        aug_data_iter = iter(aug_dataloader)

        # ---- phase 1: train the substitute on all labeled data ----
        for epoch in range(train_epoch):
            substitute_loss_data = 0.
            self.substitute.train()
            for _, perturbed_img, victim_prob, _ in dataloader:
                perturbed_img = to_device(perturbed_img)
                victim_prob = to_device(victim_prob)

                # labeled loss (substitute training in the s-Train process)
                perturbed_sub_output = nn.parallel.data_parallel(self.substitute,
                                                                 self.augmentation(perturbed_img))
                labeled_loss = kldiv_loss(log_softmax(perturbed_sub_output), victim_prob)

                # online unsupervised consistency loss (the U-Train process):
                # the weak view's prediction is the target for the strong view.
                try:
                    (w_inputs, s_inputs), _ = next(aug_data_iter)
                except StopIteration:
                    aug_data_iter = iter(aug_dataloader)
                    (w_inputs, s_inputs), _ = next(aug_data_iter)
                w_inputs = to_device(w_inputs)
                s_inputs = to_device(s_inputs)
                with torch.no_grad():
                    unlabeled_targets = softmax(self.substitute(w_inputs))
                unlabeled_probs = softmax(self.substitute(s_inputs))
                pseudo_label_loss = mse_loss(unlabeled_probs, unlabeled_targets)

                # Rescale the consistency term to the labeled loss' magnitude;
                # the clamp avoids 0 * inf = NaN when the consistency loss is 0.
                labeled_loss_value = labeled_loss.detach()
                pseudo_label_loss_value = torch.clamp(pseudo_label_loss.detach(), min=1e-12)
                substitute_loss = labeled_loss + \
                    pseudo_label_loss * (labeled_loss_value / pseudo_label_loss_value) \
                    * self.opt.pseudo_label_weight

                substitute_loss_data += substitute_loss.item()
                self.substitute.zero_grad()
                substitute_loss.backward()
                substitute_optimizer.step()

            substitute_loss_data /= len(self.sub_dataset)
            if substitute_loss_data < min_loss:
                min_loss = substitute_loss_data
                min_loss_epoch = epoch + 1
            elif self.opt.loss_rise_epoch:
                if epoch + 1 >= 10 and epoch + 1 - min_loss_epoch >= self.opt.loss_rise_epoch:
                    break
            elif self.opt.div_rise_epoch:
                if epoch + 1 >= 10 and epoch + 1 - min_div_epoch >= self.opt.div_rise_epoch:
                    break

            print(f'[substitute] epoch {epoch + 1} | loss {substitute_loss_data}')

            if (epoch + 1) % self.opt.print_freq == 0:
                self._print_eval_summary('[substitute]')

        # next batch data with different labels (generator input sampling based on pseudo labeling)
        next_diff_dataset = self.get_next_diff_dataset()
        next_diff_dataloader = torch.utils.data.DataLoader(
            next_diff_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4
        )

        # ---- phase 2: train the generator (generator training in the S-Train process) ----
        if not last_loop:
            for epoch in range(self.gen_epoch):
                self.data_gen.train()
                gen_loss = noise_loss = adv_loss = div_loss = None

                # current batch: stay close to the clean inputs while moving away
                # from the victim's probabilities on the substitute
                for clean_imgs, _, victim_prob, _ in new_dataloader:
                    clean_imgs = tuple(to_device(img) for img in clean_imgs)
                    victim_prob = to_device(victim_prob)

                    new_perturbed_img = self.data_gen(clean_imgs)
                    new_perturbed_sub_output = nn.parallel.data_parallel(self.substitute, new_perturbed_img)
                    new_perturbed_sub_prob = log_softmax(new_perturbed_sub_output)
                    noise_loss = 0.0
                    for i in range(self.opt.n_fuse):
                        noise_loss += mse_loss(new_perturbed_img, clean_imgs[i])
                    adv_loss = torch.exp(-1 * kldiv_loss(new_perturbed_sub_prob, victim_prob))
                    gen_loss = self.opt.noise_weight * noise_loss + self.opt.adv_weight * adv_loss
                    self.data_gen.zero_grad()
                    gen_loss.backward()
                    data_gen_optimizer.step()

                # next batch data (diversity loss)
                for _ in range(self.opt.div_epoch):
                    for clean_imgs, _, _, _ in next_diff_dataloader:
                        clean_imgs = tuple(to_device(img) for img in clean_imgs)
                        for _, perturbed_img, _, _ in div_dataloader:
                            perturbed_img = to_device(perturbed_img)
                            perturbed_img_output = nn.parallel.data_parallel(self.substitute, perturbed_img)
                            new_perturbed_img = self.data_gen(clean_imgs)
                            new_perturbed_img_output = nn.parallel.data_parallel(self.substitute,
                                                                                 new_perturbed_img)
                            div_loss = div_func(new_perturbed_img_output, perturbed_img_output, 256)
                            div_loss = div_loss / new_perturbed_img_output.shape[0]
                            self.data_gen.zero_grad()
                            div_loss.backward()
                            # opt.div_weight acts as the gradient-norm cap here.
                            nn.utils.clip_grad_norm_(self.data_gen.parameters(),
                                                     max_norm=self.opt.div_weight, norm_type=2)
                            data_gen_optimizer.step()

                if div_loss is None:
                    print(
                        f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss}')
                else:
                    print(
                        f'[data generator] epoch {epoch + 1} | loss {gen_loss} | {noise_loss},{adv_loss},{div_loss}')

        self._print_eval_summary('[substitute]')

        if self.save:
            save_dir = os.path.join(self.opt.data_dir, 'checkpoints',
                                    f'{self.opt.victim_dataset}_{self.opt.surrogate_dataset}',
                                    f'{self.opt.pre_train_sub}_{self.source}')
            os.makedirs(save_dir, exist_ok=True)
            save_name = f'{self.opt.sub_model}_{self.opt.seed}_{self.loop + 1}'
            save_path = os.path.join(save_dir, save_name)
            torch.save(self.substitute.state_dict(), save_path)
            print(f'[substitute] Saved substitute model in loop {self.loop + 1}')

        return self.substitute, self.data_gen, self.unlabeled_dataset, next_diff_dataset

    def _print_eval_summary(self, tag):
        """Run evaluate / watermark_evaluate / adv_evaluate(200) and print one summary line prefixed by ``tag``."""
        acc, fidelity, kd_score = self.evaluate()
        wm_extract_success = self.watermark_evaluate()
        asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = self.adv_evaluate(200)
        print(f'{tag} accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
              f'| KD loss {kd_score} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise} | wm extract success {wm_extract_success}')

    def evaluate(self):
        """Measure the substitute's test accuracy and its agreement with the victim.

        The substitute's weights are transferred into a freshly built
        ``opt.eval_model`` architecture via ``load_model_weights``. That model is
        evaluated on ``self.eval_dataset.test_dataloader()``.

        Returns:
            tuple: ``(accuracy, fidelity, kd_score)``.
              * accuracy: fraction of samples whose substitute top-1 prediction
                equals the target.
              * fidelity: fraction of samples where the victim and substitute
                top-1 predictions agree.
              * kd_score: ``(280000 - kd_loss) / 100000``. The FSP
                feature-distillation loss is currently disabled, so ``kd_loss``
                stays 0 and this is always 2.8.
            accuracy and fidelity are 0-d CPU tensors.

        Raises:
            ValueError: the evaluation dataloader yielded no samples.
        """
        n_outputs = self.opt.victim_n_classes
        if self.opt.eval_model.startswith('efficientnet'):
            eval_model = setup.EfficientNet.from_pretrained(
                self.opt.eval_model,
                num_classes=n_outputs)
        elif self.opt.eval_model.startswith('wrn'):
            depth = int(self.opt.eval_model.split('-')[-1])
            eval_model = setup.classifier.WideResNet(
                n_outputs=n_outputs,
                depth=depth
            )
        else:
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )

        substitute = load_model_weights(self.substitute, eval_model)
        victim = self.victim
        # Device placement is loop-invariant: do it once, not per batch.
        if self.opt.use_gpu:
            victim.cuda()
            substitute.cuda()
        substitute.eval()
        victim.eval()

        accs = 0.0
        fidelity = 0.0
        kd_loss_result = 0.0  # stays 0 while the FSP loss is disabled
        n_samples = 0
        n_batch = 0
        for batch in self.eval_dataset.test_dataloader():
            imgs = batch[0]
            targets = batch[1]
            if self.opt.use_gpu:
                imgs = imgs.cuda()
                targets = targets.cuda()
            n_samples += targets.shape[0]
            n_batch += 1
            with torch.no_grad():
                sub_pred = substitute(imgs).max(1)[1]
                # the victim returns a sequence; its last element is used as the
                # class scores (presumably logits -- confirm against the model)
                victim_pred = victim(imgs)[-1].max(1)[1]
            accs += sub_pred.eq(targets).float().sum().cpu()
            fidelity += victim_pred.eq(sub_pred).float().sum().cpu()

        if n_samples == 0:
            raise ValueError('evaluation dataloader yielded no samples')
        accs /= n_samples
        fidelity /= n_samples
        kd_loss_result /= n_batch
        max_kd = 280000
        scale_kd = 100000
        final_kd = (max_kd - kd_loss_result) / scale_kd
        return accs, fidelity, final_kd

    def adv_evaluate(self, sample_size):
        """Measure how well adversarial examples crafted on the substitute transfer to the victim.

        ``sample_size`` test images are drawn at random without replacement (the
        request is capped at the test-set size). L-inf momentum iterative attacks
        are crafted against the substitute, which is re-instantiated as
        ``opt.eval_model`` via ``load_model_weights``. Only samples the victim
        classifies correctly are counted.
          * Untargeted: success means the prediction changes away from the label
            (victim) or from the substitute's clean prediction (substitute).
          * Targeted (``opt.targeted``): the target is
            ``(label + 1) % opt.victim_n_classes`` and success means hitting it.

        Args:
            sample_size: number of test images to attack.

        Returns:
            tuple of 0-d float32 CPU tensors:
            ``(asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel,
            avg_l2_noise)``. asr and sub_asr are percentages of the victim-correct
            samples fooled on the victim and on the substitute. The L2 values are
            mean perturbation norms over successful / all victim-correct samples;
            the per-pixel value divides by C*H*W of one input. Any ratio with a
            zero denominator is NaN.
        """
        if self.opt.use_gpu:
            self.victim.cuda()
            self.substitute.cuda()
        self.victim.eval()
        self.substitute.eval()

        n_outputs = self.opt.victim_n_classes
        if self.opt.eval_model.startswith('efficientnet'):
            eval_model = setup.EfficientNet.from_pretrained(
                self.opt.eval_model,
                num_classes=n_outputs)
        elif self.opt.eval_model.startswith('wrn'):
            depth = int(self.opt.eval_model.split('-')[-1])
            eval_model = setup.classifier.WideResNet(
                n_outputs=n_outputs,
                depth=depth
            )
        else:
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=n_outputs
            )

        substitute = load_model_weights(self.substitute, eval_model)
        if self.opt.use_gpu:
            substitute.cuda()

        test_dataset = self.eval_dataset.test_dataset
        # replace=False cannot draw more samples than exist.
        sample_size = min(sample_size, len(test_dataset))
        data_list = np.random.choice(range(len(test_dataset)), sample_size, replace=False)
        batch_size = 50
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )
        targeted = bool(self.opt.targeted)
        # clip range [-1, 1] assumes inputs are normalised to that range -- confirm.
        adversary_ghost = LinfMomentumIterativeAttack(
            substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
            nb_iter=self.opt.noise_step, eps_iter=self.opt.eps_multiple * self.opt.noise_eps / self.opt.noise_step,
            clip_min=-1.0, clip_max=1.0, targeted=targeted)

        attack_success = 0
        sub_attack_success = 0
        total = 0
        l2_noise = 0.0
        success_l2_noise = 0.0
        n_pixels = None
        for inputs, labels in dataloader:
            n_pixels = inputs[0].numel()
            if self.opt.use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            with torch.no_grad():
                predicted = self.victim(inputs)[-1].max(1)[1]
                sub_labels = substitute(inputs).max(1)[1]
            correct_predict = predicted == labels

            attack_labels = (labels + 1) % n_outputs if targeted else labels
            adv_inputs_ghost = adversary_ghost.perturb(inputs, attack_labels)
            with torch.no_grad():
                adv_predicted = self.victim(adv_inputs_ghost)[-1].max(1)[1]
                adv_sub_predicted = substitute(adv_inputs_ghost).max(1)[1]

            if targeted:
                success = correct_predict & (adv_predicted == attack_labels)
                sub_success = correct_predict & (adv_sub_predicted == attack_labels)
            else:
                success = correct_predict & (adv_predicted != labels)
                sub_success = correct_predict & (adv_sub_predicted != sub_labels)

            l2_noise_list = F.pairwise_distance(adv_inputs_ghost.view(len(inputs), -1),
                                                inputs.view(len(inputs), -1))
            attack_success += int(success.sum())
            sub_attack_success += int(sub_success.sum())
            total += int(correct_predict.sum())
            l2_noise += float(l2_noise_list[correct_predict].sum())
            success_l2_noise += float(l2_noise_list[success].sum())

        def _ratio(num, den):
            # NaN instead of ZeroDivisionError when nothing was counted.
            return torch.tensor(num / den if den else float('nan'), dtype=torch.float32)

        # every successful attack is counted once in attack_success
        success_total = attack_success
        asr = _ratio(100. * attack_success, total)
        sub_asr = _ratio(100. * sub_attack_success, total)
        success_avg_l2_noise = _ratio(success_l2_noise, success_total)
        # n_pixels is None only for an empty loader, where the average is already NaN.
        success_l2_noise_per_pixel = success_avg_l2_noise / (n_pixels or 1)
        avg_l2_noise = _ratio(l2_noise, total)
        return asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise

    def get_weighted_sampler(self):
        """Build a sampler over ``sub_dataset`` that favours later chunks.

        ``sub_dataset`` is treated as consecutive chunks of
        ``len(new_sub_dataset)`` samples. Chunk ``k`` gets weight
        ``opt.dataset_decay ** (n_loop - k)``, where
        ``n_loop = len(sub_dataset) // len(new_sub_dataset)``. With a decay below
        1, later chunks are therefore drawn more often. The returned
        ``WeightedRandomSampler`` draws ``len(sub_dataset)`` indices with
        replacement (its default).

        If ``new_sub_dataset`` is empty there is no chunk structure, and every
        sample gets weight 1.
        """
        alpha = self.opt.dataset_decay
        n_total = len(self.sub_dataset)
        n_new = len(self.new_sub_dataset)
        if n_new > 0:
            n_loop = n_total // n_new
            weights = [alpha ** (n_loop - i // n_new) for i in range(n_total)]
        else:
            weights = [1.0] * n_total
        return sp.WeightedRandomSampler(torch.tensor(weights, dtype=torch.float64), n_total)

    def get_pseudo_label(self, unlabeled_items, k_avg=3, T=0.5, da=None, threshhold=0.7):
        s = 1
        color_jitter = transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                # transforms.RandomRotation(degrees=(0, 90)),
                transforms.RandomVerticalFlip(),  # with 0.5 probability
                # transforms.RandomApply([color_jitter], p=0.8),
                # transforms.RandomGrayscale(p=0.2),
            ]
        )
        softmax = nn.Softmax(dim=1)
        self.substitute.eval()
        # average probability predictions after k augmentations
        for i_aug in range(k_avg):
            unlabeled_items_aug = augmentation(unlabeled_items)
            with torch.no_grad():
                unlabeled_outputs = self.substitute(unlabeled_items_aug)
            # print(self.substitute(unlabeled_items_aug).max(1)[1])
            if i_aug == 0:
                unlabeled_probs = softmax(unlabeled_outputs)
            else:
                unlabeled_probs += softmax(unlabeled_outputs)
        unlabeled_probs = unlabeled_probs.detach()
        unlabeled_probs /= k_avg
        # sharpening
        if T is not None:
            unlabeled_pt = unlabeled_probs ** (1/T)
            unlabeled_probs = unlabeled_pt / unlabeled_pt.sum(dim=1, keepdim=True)

        # if distribution alignment is enabled
        if da is not None:
            unlabeled_probs_da = unlabeled_probs * da
            unlabeled_probs = unlabeled_probs_da / unlabeled_probs_da.sum(dim=1, keepdim=True)

        pseudo_labels = unlabeled_probs.max(1)[1]   # e.g. 5
        # pseudo_labels = nn.functional.one_hot(unlabeled_probs.max(1)[1], num_classes=10)
        # print(pseudo_labels)

        mask = (unlabeled_probs.max(1)[0]>threshhold).float()
        # print(unlabeled_probs)
        # print(mask)

        return unlabeled_probs, pseudo_labels, mask

    def get_unlabeled_sub_prob(self, sample_size):
        self.substitute.eval()
        softmax = nn.Softmax(dim=1)
        data_list = np.random.choice(range(len(self.unlabeled_dataset)), sample_size,
                                     replace=False)  # [i for i in range(7800,8000)]
        dataloader = torch.utils.data.DataLoader(
            self.unlabeled_dataset, batch_size=50,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )
        for i, (unlabeled_inputs, _, _, _) in enumerate(dataloader):
            if self.opt.use_gpu:
                unlabeled_inputs = unlabeled_inputs.cuda()
            with torch.no_grad():
                unlabeled_outputs = self.substitute(unlabeled_inputs).detach()
            unlabeled_probs = softmax(unlabeled_outputs)
            if i == 0:
                probs_sum = unlabeled_probs.sum(dim=0)
            else:
                probs_sum += unlabeled_probs.sum(dim=0)
        return probs_sum / sample_size

    def get_next_diff_dataset(self, k_avg=3):
        """Relabel the public pool and build the next batch of fused queries.

        Phase 1 (``_refresh_unlabeled_pseudo_labels``): every item of
        ``self.unlabeled_dataset`` is rescored by the substitute, using softmax
        averaged over ``k_avg`` random flips. For each item, ``items[idx][1]``
        and ``items[idx][2]`` are overwritten with the averaged probabilities
        and their argmax. Items whose "used" flag ``items[idx][3]`` is 0 are
        queued under that pseudo label.

        Phase 2: up to ``len(self.new_sub_dataset)`` entries are created. Each
        entry is a list of ``self.opt.n_fuse`` unused raw inputs
        (``items[idx][0]``). Pseudo labels are drawn with probabilities from
        ``_rarity_label_probs``, which favours classes with fewer unused items.
        Every consumed item is flagged as used (``items[idx][3] = 1``).

        Args:
            k_avg: number of augmented forward passes per batch; must be >= 1.

        Returns:
            ``data.SubDataset`` whose ``items`` are ``(inputs_list, -1, -1, -1)``.
            If fewer than ``n_fuse`` unused items remain, sampling stops early
            and fewer entries are returned. The previous implementation raised
            ``ZeroDivisionError``/``IndexError`` in that situation.

        Raises:
            ValueError: if ``k_avg`` < 1.
        """
        if k_avg < 1:
            raise ValueError(f"k_avg must be >= 1, got {k_avg}")
        loop_size = len(self.new_sub_dataset)
        n_classes = self.opt.victim_n_classes
        n_fuse = self.opt.n_fuse
        next_diff_dataset = data.SubDataset()

        label_queues = self._refresh_unlabeled_pseudo_labels(k_avg)
        items = self.unlabeled_dataset.items
        print([len(queue) for queue in label_queues])

        for n_built in range(loop_size):
            label_nums = [len(queue) for queue in label_queues]
            n_left = sum(label_nums)
            if n_left == 0 or n_left < n_fuse:
                print(f"get_next_diff_dataset: unused public data exhausted after "
                      f"{n_built} of {loop_size} fused entries")
                break
            label_probs = self._rarity_label_probs(label_nums)
            # With fewer non-empty classes than n_fuse some label must repeat,
            # so only then sample labels with replacement.
            n_valid_class = sum(1 for p in label_probs if p > 0)
            replace = n_valid_class < n_fuse
            current_labels = np.random.choice(
                range(n_classes), size=n_fuse, p=label_probs, replace=replace
            )
            fused_inputs = []
            for current_label in current_labels:
                queue = label_queues[current_label]
                if not queue:
                    # Repeated draws (replace=True) can exhaust a class; take from
                    # the class with the most unused items instead. Since
                    # n_left >= n_fuse, some class is still non-empty here.
                    queue = max(label_queues, key=len)
                idx = queue.popleft()
                fused_inputs.append(items[idx][0])
                items[idx][3] = 1  # mark as used
            next_diff_dataset.items.append((fused_inputs, -1, -1, -1))

        return next_diff_dataset

    def _refresh_unlabeled_pseudo_labels(self, k_avg):
        """Rescore ``self.unlabeled_dataset`` in place and queue unused indices by pseudo label.

        Returns:
            A list of ``self.opt.victim_n_classes`` deques. Deque ``c`` holds,
            in dataset order, the indices of items whose used flag
            (``items[idx][3]``) is 0 and whose new pseudo label is ``c``.
        """
        from collections import deque

        batch_size = 100
        # Unshuffled loader, so batch position maps directly to dataset index.
        loader = torch.utils.data.DataLoader(
            self.unlabeled_dataset,
            batch_size=batch_size,
            num_workers=4
        )
        to_probs = nn.Softmax(dim=1)
        augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # with 0.5 probability
                transforms.RandomVerticalFlip(),  # with 0.5 probability
            ]
        )
        label_queues = [deque() for _ in range(self.opt.victim_n_classes)]
        items = self.unlabeled_dataset.items
        # NOTE(review): unlike get_pseudo_label this does not call
        # self.substitute.eval(); confirm the caller sets the intended mode.
        for batch_idx, (unlabeled_items, _, _, _) in enumerate(loader):
            if self.opt.use_gpu:
                unlabeled_items = unlabeled_items.cuda()
            prob_sum = None
            for _ in range(k_avg):
                unlabeled_items_aug = augmentation(unlabeled_items)
                with torch.no_grad():
                    outputs = self.substitute(unlabeled_items_aug)
                batch_probs = to_probs(outputs)
                prob_sum = batch_probs if prob_sum is None else prob_sum + batch_probs
            probs = prob_sum.detach() / k_avg
            pseudo_labels = probs.max(1)[1].tolist()
            probs_cpu = probs.cpu()  # one device transfer per batch, not per item
            base = batch_idx * batch_size
            for j, label in enumerate(pseudo_labels):
                idx = base + j
                items[idx][1] = probs_cpu[j]
                items[idx][2] = int(label)
                # only involve unused items for next batch
                if items[idx][3] == 0:
                    label_queues[label].append(idx)
        return label_queues

    @staticmethod
    def _rarity_label_probs(label_nums):
        """Map per-class remaining counts to sampling probabilities that favour rare classes.

        A non-empty class gets weight ``-log(count / total) + 1e-6`` and an
        empty class gets 0; the weights are then normalised to sum to 1. The
        epsilon keeps a lone remaining class (``-log(1) == 0``) sampleable
        instead of producing an all-zero weight vector.
        """
        label_num_sum = sum(label_nums)
        label_weights = [
            -math.log(label_num / label_num_sum) + 1e-6 if label_num != 0 else 0
            for label_num in label_nums
        ]
        weight_sum = sum(label_weights)
        return [label_weight / weight_sum for label_weight in label_weights]

    def batch_watermark_evaluate(self, trigger_set):
        self.substitute.eval()
        if self.opt.use_gpu:
            trigger_set = trigger_set.cuda()
            self.substitute.cuda()
        with torch.no_grad():
            predictions = self.substitute(trigger_set)
            success = (predictions.max(1)[1] == self.wm_target).float().sum().detach().cpu()
        return success / trigger_set.shape[0]

    def watermark_evaluate(self, sample_size=200):
        data_list = [i for i in range(0, sample_size)]
        batch_size = 50
        dataloader = torch.utils.data.DataLoader(
            self.trigger_dataset, batch_size=batch_size,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4
        )
        success = 0.
        total = 0.
        for _, (trigger_data, _) in enumerate(dataloader):
            total += trigger_data.shape[0]
            success += self.batch_watermark_evaluate(trigger_data) * trigger_data.shape[0]
        return success / total
