import random
import os
import pickle

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import numpy as np
from efficientnet_pytorch import EfficientNet
from advertorch.attacks import JacobianSaliencyMapAttack

import defense
import data
import classifier
import generator
import trainer
from simclr import SimCLR
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model
from data import DoubleTransforms
from utils import *
from config import opt
from query import *
from utils import get_dataset_by_name, classifier_dict, load_model_weights
# import loss


# Dataset classes used for SimCLR pre-training, keyed by the dataset name that
# precedes the optional "-<partition>" suffix of the dataset option strings.
simclr_dataset_dict = dict(
    cifar10=data.CIFAR10,
    cifar100=data.CIFAR100,
    TinyImageNet=data.TinyImageNet,
    noise=data.Noise,
    svhn=data.SVHN,
    ImageNet=data.ImageNet,
)

# Dataset classes used with DoubleTransforms to build the augmentation dataset.
# Currently the same registry as above, kept as an independent dict so either
# can be extended without affecting the other.
semi_dataset_dict = dict(simclr_dataset_dict)


# Generator constructors keyed by ``opt.gen_model``.
# 'fusiongan' and 'fusiongan-label' both build a FusionGAN.
gen_dict = {
    'sngan': generator.SNGAN,
    'progan': generator.Progan,
    'attackgan': generator.AttackGAN,
    'fusiongan': generator.FusionGAN,
    'fusiongan-label': generator.FusionGAN,
    'random_aggr_gen': generator.RandomAggrGen,
}


# Option names echoed by print_config, in display order.
_PRINTED_CONFIG_KEYS = (
    'victim_dataset', 'simclr_dataset', 'gen_dataset', 'surrogate_dataset',
    'victim_model', 'gen_model', 'sub_model', 'simclr_model',
    'seed', 'source', 'eval_model', 'query',
    'epoch', 'gen_epoch', 'div_epoch', 'n_loop',
    'noise_weight', 'adv_weight', 'div_weight', 'pseudo_label_weight',
    'dataset_decay', 'same_origin', 'continue_train',
    'n_pseudo', 'pseudo_epoch', 'pseudo_dataset',
    'loss_threshhold', 'noise_eps', 'noise_step', 'eps_multiple', 'targeted',
    'epoch_val_rate', 'sub_optim', 'sub_lr', 'n_fuse', 'pre_train_sub',
    'simclr_sub_batch_size',
)


def print_config(cfg=None):
    """Print the experiment options, one ``<name> <value>`` line per option.

    Args:
        cfg: object exposing the options as attributes. Defaults to the global
            ``opt`` imported from ``config``.

    Each option is printed exactly once (``sub_model`` used to be printed twice).
    """
    if cfg is None:
        cfg = opt
    for key in _PRINTED_CONFIG_KEYS:
        print(key, getattr(cfg, key))


def setup_seed(seed):
    """Seed every RNG the experiment uses and make cuDNN deterministic.

    Args:
        seed: integer seed applied to torch (CPU and all CUDA devices), NumPy's
            global RNG and Python's ``random`` module.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Deterministic kernels alone are not enough: with benchmark mode enabled,
    # cuDNN may select a different (possibly non-deterministic) algorithm per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def weight_init(model):
    """Re-initialise the layers of ``model`` and of all its descendants.

    Layers found through ``model.modules()`` are initialised as follows:

    * ``nn.Linear``: Xavier-normal weights (bias left untouched);
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: every weight set to 1e-3
      (bias left untouched);
    * ``nn.BatchNorm2d``: weight 0.2, bias 0.

    All other module types are left as they are.
    """
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)
            continue
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.constant_(layer.weight, 1e-3)
            continue
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.constant_(layer.weight, 2e-1)
            nn.init.constant_(layer.bias, 0)


# train victim model
def prepare_victim():
    """Build and train (or resume) the victim classifier configured in ``opt``.

    The training regime depends on the options:

    * ``opt.victim_wm_dataset`` set: ``trainer.ClassifierTrainerWM`` (Adam,
      lr 1e-3);
    * else, ``opt.victim_ood_dataset`` set: ``trainer.ClassifierTrainerOOD``
      (10 SGD epochs, lr 1e-3, weight decay 5e-4), given the plain victim's
      checkpoint name as its pre-model name, followed by evaluation;
    * otherwise: ``trainer.ClassifierTrainer`` driven by the ``opt.victim_*``
      hyper-parameters, followed by evaluation.

    Whenever ``opt.victim_ood_dataset`` is set, CIP energy thresholds are
    computed for the trained model. They are merged into
    ``<opt.work_dir>/checkpoints/cip_ckpt/cip_energy_dict.pkl`` under the model's
    checkpoint name.

    Returns:
        The victim model returned by the trainer's ``train`` call.
    """
    img_size = opt.victim_img_size
    victim_dataset = get_dataset_by_name(opt.victim_dataset, img_size)
    if opt.victim_wm_dataset:
        victim_wm_dataset = get_dataset_by_name(opt.victim_wm_dataset, img_size)
    elif opt.victim_ood_dataset:
        victim_ood_dataset = get_dataset_by_name(opt.victim_ood_dataset, img_size)

    arch, n_classes = opt.victim_model, opt.victim_n_classes
    if arch.startswith('efficientnet'):
        victim = EfficientNet.from_pretrained(arch, num_classes=n_classes)
    elif arch.startswith('wrn'):
        # e.g. 'wrn-28' -> depth 28
        victim = classifier.WideResNet(n_outputs=n_classes, depth=int(arch.split('-')[-1]))
    else:
        victim = classifier_dict[arch](n_outputs=n_classes)

    base_name = f'victim_{opt.victim_model}_{opt.victim_dataset}'
    if opt.victim_wm_dataset:
        model_name = f'{base_name}_wm-{opt.victim_wm_dataset}'
        wm_trainer = trainer.ClassifierTrainerWM(
            opt, victim, victim_dataset, victim_wm_dataset, model_name,
            optim='adam', continue_train=True)
        victim = wm_trainer.train(lr=0.001)
    elif opt.victim_ood_dataset:
        model_name = f'{base_name}_ood-{opt.victim_ood_dataset}'
        ood_trainer = trainer.ClassifierTrainerOOD(
            opt, victim, victim_dataset, victim_ood_dataset,
            model_name, base_name, optim='sgd', lr_decay=5e-4,
            continue_train=True)
        victim = ood_trainer.train(n_epochs=10, optim='sgd', lr=0.001, weight_decay=5e-4)
        ood_trainer.evaluate()
    else:
        model_name = base_name
        plain_trainer = trainer.ClassifierTrainer(
            opt, victim, victim_dataset, model_name,
            optim=opt.victim_optim, lr_decay=opt.victim_lr_decay,
            continue_train=True)
        victim = plain_trainer.train(
            n_epochs=opt.victim_epoch, optim=opt.victim_optim,
            lr=opt.victim_lr, weight_decay=opt.victim_weight_decay)
        plain_trainer.evaluate()

    if opt.victim_ood_dataset:
        # Evaluate the energy threshold values for the CIP defense.
        fpr_energy, open_set_energy = defense.get_cip_threshold(victim_dataset, victim)
        print(f'Energy threshold for CIP defense: FPR energy {fpr_energy}, open set energy {open_set_energy}')
        _store_cip_energy(model_name, fpr_energy, open_set_energy)

    return victim


def _store_cip_energy(model_name, fpr_energy, open_set_energy):
    """Merge one model's CIP energy thresholds into the pickled registry on disk."""
    ckpt_dir = os.path.join(opt.work_dir, 'checkpoints/cip_ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    registry_path = os.path.join(ckpt_dir, 'cip_energy_dict.pkl')
    registry = {}
    if os.path.exists(registry_path):
        # Local, self-written file; never point this at untrusted pickles.
        with open(registry_path, 'rb') as fin:
            registry = pickle.load(fin)
    registry[model_name] = {'FPR_energy': fpr_energy, 'open_set_energy': open_set_energy}
    with open(registry_path, 'wb') as fout:
        pickle.dump(registry, fout)


# load pre-trained data generator
def prepare_data_gen():
    """Instantiate the query-data generator selected by ``opt.gen_model``.

    When ``opt.gen_pretrain`` is set, weights are restored as follows:

    * ``progan`` / ``sngan``: from checkpoint files under
      ``<opt.data_dir>/checkpoints/``;
    * any ``fusiongan*`` model: a SimCLR encoder is trained (or resumed) on
      ``opt.gen_dataset``, and its weights are copied into each of the
      ``opt.n_fuse`` fusion encoders.

    Without ``opt.gen_pretrain``, every generator except ``attackgan`` is
    re-initialised with ``weight_init``.

    Returns:
        The generator module, or ``None`` when ``opt.source == 'active'``
        (active querying needs no generator).
    """
    if opt.source == 'active':
        return
    data_gen = gen_dict[opt.gen_model]()
    if opt.gen_pretrain:
        if opt.gen_model == 'progan':
            # os.path.join tolerates data_dir with or without a trailing slash;
            # map_location lets GPU-saved checkpoints load on CPU-only hosts
            # (load_state_dict copies into the freshly built CPU module anyway).
            ckpt_path = os.path.join(opt.data_dir, 'checkpoints', 'split_90_generator_final')
            data_gen.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        elif opt.gen_model == 'sngan':
            ckpt_path = os.path.join(opt.data_dir, 'checkpoints', 'cifar_100_90_classes_gan.pth')
            data_gen.load_state_dict(torch.load(ckpt_path, map_location='cpu')['gen_state_dict'])
        elif 'fusiongan' in opt.gen_model:
            dataset_parse = opt.gen_dataset.split('-')
            dataset_name = dataset_parse[0]
            partition = dataset_parse[1] if len(dataset_parse) > 1 else None
            simclr_dataset = simclr_dataset_dict[dataset_name](
                transform=TransformsSimCLR(size=opt.pre_sub_img_size), partition=partition
            ).dataset
            encoder = generator.SingleEncoder()
            n_features = encoder.ch
            model = SimCLR(encoder, opt.projection_dim, n_features)
            model_name = f'fusiongan_simclr_{opt.gen_dataset}'
            simclr_trainer = trainer.SimCLRTrainer(opt, model, simclr_dataset, model_name, opt.simclr_gen_epoch,
                                                   opt.simclr_gen_batch_size)
            model = simclr_trainer.train()
            # Every fusion branch starts from the same SimCLR-trained encoder.
            for i in range(opt.n_fuse):
                data_gen.encoder.encoders[i] = load_model_weights(model.encoder, data_gen.encoder.encoders[i])
    elif opt.gen_model != 'attackgan':
        # weight_init already walks data_gen.modules(); routing it through
        # data_gen.apply() re-initialised every layer once per ancestor.
        weight_init(data_gen)

    return data_gen


# pre-train substitute model
def prepare_substitute():
    """Build the backbone that the substitute model is initialised from.

    * ``opt.pre_train_sub == 'simclr'``: train (or resume) a SimCLR model on
      ``opt.simclr_dataset`` around an ``opt.simclr_model`` encoder, and return
      that encoder.
    * Otherwise: build an ``opt.pre_sub_model`` classifier with
      ``opt.pre_sub_n_classes`` outputs. When ``opt.pre_train_sub == 'pre'`` it
      is first trained on ``opt.pre_sub_dataset`` under the checkpoint name
      ``victim_<model>_<dataset>``; otherwise it is returned untrained.
    """
    def make_classifier(arch, n_classes):
        # 'efficientnet*' -> pretrained EfficientNet, 'wrn-<depth>' -> WideResNet,
        # anything else -> looked up in classifier_dict.
        if arch.startswith('efficientnet'):
            return EfficientNet.from_pretrained(arch, num_classes=n_classes)
        if arch.startswith('wrn'):
            return classifier.WideResNet(n_outputs=n_classes, depth=int(arch.split('-')[-1]))
        return classifier_dict[arch](n_outputs=n_classes)

    def split_spec(spec):
        # '<name>[-<partition>]' -> (name, partition or None)
        fields = spec.split('-')
        return fields[0], (fields[1] if len(fields) > 1 else None)

    if opt.pre_train_sub == 'simclr':
        name, partition = split_spec(opt.simclr_dataset)
        simclr_dataset = simclr_dataset_dict[name](
            transform=TransformsSimCLR(size=opt.pre_sub_img_size), partition=partition
        ).dataset
        encoder = make_classifier(opt.simclr_model, opt.victim_n_classes)
        model = SimCLR(encoder, opt.projection_dim, encoder.fc.in_features)
        simclr_trainer = trainer.SimCLRTrainer(
            opt, model, simclr_dataset,
            f'simclr_{opt.simclr_model}_{opt.simclr_dataset}',
            opt.simclr_sub_epoch, opt.simclr_sub_batch_size,
            continue_train=True)
        return simclr_trainer.train().encoder

    name, partition = split_spec(opt.pre_sub_dataset)
    substitute_dataset = dataset_dict[name](input_size=opt.pre_sub_img_size, partition=partition)
    substitute = make_classifier(opt.pre_sub_model, opt.pre_sub_n_classes)
    if opt.pre_train_sub != 'pre':
        return substitute

    # The 'victim_' prefix reuses checkpoints produced by prepare_victim.
    substitute_trainer = trainer.ClassifierTrainer(
        opt, substitute, substitute_dataset,
        f'victim_{opt.pre_sub_model}_{opt.pre_sub_dataset}')
    substitute = substitute_trainer.train(
        n_epochs=opt.pre_sub_pre_epoch,
        optim=opt.pre_sub_optim, lr=opt.pre_sub_pre_lr)
    substitute_trainer.evaluate()
    return substitute


# train substitute model and data generator
def train_substitute(victim, pre_sub, data_gen, strategy='static', source='cifar10'):
    """Train a substitute (extracted) model by querying ``victim``.

    Args:
        victim: the black-box victim model that labels queries.
        pre_sub: backbone whose weights initialise the substitute (see
            ``prepare_substitute``). The substitute's ``fc`` head is then
            replaced by a fresh ``opt.victim_n_classes``-way linear layer.
        data_gen: query-data generator, or ``None`` for sources that need none.
        strategy: ``'static'`` (query once), ``'adaptive'`` (``opt.n_loop``
            rounds of query + train) or ``'every'``.
        source: where query samples come from. Handled values are
            ``'cifar10'``/``'cifar100'``, ``'random'``, ``'attackgan'``,
            ``'fusiongan*'``, ``'random_aggr_gen'``, ``'papernot'``,
            ``'active-df'``/``'active-kc'``, ``'mosafi'`` and ``'avg'``; anything
            else falls back to ``get_sub_dataset``.

    Returns:
        ``(substitute, data_gen)``. For non-static strategies without
        ``opt.continue_train``, the returned substitute is a fresh model trained
        by ``trainer.PseudoTrainer``. Returns ``None`` when ``opt.sub_eval_loop``
        only asks to evaluate a saved substitute.

    Raises:
        ValueError: unknown ``strategy`` or ``opt.pseudo_dataset``.
        NotImplementedError: unknown ``'active-*'`` query method.
    """
    if opt.sub_model.startswith('efficientnet'):
        substitute = EfficientNet.from_pretrained(
            opt.sub_model,
            num_classes=opt.pre_sub_n_classes)
    elif opt.sub_model.startswith('wrn'):
        depth = int(opt.sub_model.split('-')[-1])
        substitute = classifier.WideResNet(
            n_outputs=opt.pre_sub_n_classes,
            depth=depth
        )
    else:
        substitute = classifier_dict[opt.sub_model](
            n_outputs=opt.pre_sub_n_classes
        )

    # Reuse the pre-trained backbone, then swap in a head sized for the victim.
    substitute = load_model_weights(pre_sub, substitute)
    substitute.fc = nn.Linear(substitute.fc.in_features, opt.victim_n_classes)

    # define datasets
    sub_dataset = data.SubDataset()
    div_dataset = data.SubDataset()
    unqueried_sub_dataset = data.SubDataset()
    unlabeled_dataset = data.UnlabeledDataset()
    next_diff_dataset = data.UnlabeledDataset()

    dataset_name, partition = _split_dataset_spec(opt.victim_dataset)
    eval_dataset = dataset_dict[dataset_name](input_size=opt.victim_img_size, partition=partition)

    dataset_name, partition = _split_dataset_spec(opt.surrogate_dataset)
    surrogate_eval_dataset = dataset_dict[dataset_name](input_size=opt.victim_img_size, partition=partition)

    dataset_name, partition = _split_dataset_spec(opt.simclr_dataset)
    simclr_dataset = simclr_dataset_dict[dataset_name](
        transform=TransformsSimCLR(size=opt.pre_sub_img_size), partition=partition
    ).dataset

    dataset_name, partition = _split_dataset_spec(opt.surrogate_dataset)
    aug_dataset = semi_dataset_dict[dataset_name](
        transform=DoubleTransforms(size=opt.pre_sub_img_size), partition=partition
    ).dataset

    base_dataset = get_base_dataset(opt.n_fuse, int(opt.query / opt.n_loop), opt.surrogate_dataset)

    # defense strategy
    if opt.prada_test:
        prada_agent = defense.PradaAgent(shapiro_threshold=0.9)
    if opt.victim_ood_dataset:
        cip_ood_counter = {'OOD-high-suspectible': 0,
                           'OOD-low-suspectible': 0,
                           'close-set': 0}
    else:
        cip_ood_counter = None

    # static strategy: one-time query
    if strategy == 'static':
        epoch = opt.epoch
        if source in ['cifar10', 'cifar100']:
            sub_dataset, new_sub_dataset = get_baseline_dataset(int(opt.query), victim, sub_dataset, dataset=source)
        else:
            sub_dataset, new_sub_dataset = get_sub_dataset(int(opt.query), victim, data_gen, substitute, sub_dataset)
        # NOTE(review): this constructor call and the 2-value train() unpacking
        # differ from the adaptive branch below; verify against SubstituteTrainer.
        sub_trainer = trainer.SubstituteTrainer(
            opt, victim, substitute, data_gen,
            sub_dataset, new_sub_dataset, eval_dataset,
            source=source, strategy=opt.strategy,
            n_epochs=epoch
        )
        substitute, data_gen = sub_trainer.train()

    # adaptive query
    elif strategy == 'adaptive':
        query_per_loop = int(opt.query / opt.n_loop)
        epoch = opt.epoch
        x_list = [[], [], [], [], []]
        y_list = [[], [], [], [], []]
        labels = ['noise', 'ASR', 'accuracy', 'fidelity', 'KD loss']
        for loop in range(opt.n_loop):
            print(f'Substitute training: loop {loop + 1}/{opt.n_loop}')
            # Sources that produce no follow-up dataset leave this as None
            # (previously it was unbound for them, crashing trainer construction).
            next_dataset = None
            # get dataset
            if source in ['cifar10', 'cifar100']:
                sub_dataset, new_sub_dataset = get_baseline_dataset(query_per_loop, victim, sub_dataset, dataset=source)
            elif source == 'random':
                sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset = get_random_dataset(
                    sub_dataset, unqueried_sub_dataset, victim,
                    query_per_loop, dataset=opt.surrogate_dataset, return_idx=False)
            elif source == 'attackgan':
                sub_dataset, new_sub_dataset, next_dataset = get_adv_dataset(query_per_loop, victim, data_gen,
                                                                             sub_dataset, dataset=opt.surrogate_dataset)
            elif 'fusiongan' in source:
                if opt.enlarge_every_loop:
                    base_dataset = get_base_dataset(opt.n_fuse, int(opt.query / opt.n_loop), opt.surrogate_dataset)
                sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset = get_fusion_dataset(
                    base_dataset, opt.n_fuse, query_per_loop, victim,
                    data_gen,
                    sub_dataset, unlabeled_dataset, next_diff_dataset,
                    dataset=opt.surrogate_dataset,
                    victim_return_type=opt.victim_return_type,
                    cip_ood_counter=cip_ood_counter)
                if opt.div_epoch > 0:
                    div_dataset = update_div_dataset(div_dataset, new_sub_dataset, div_threshold=opt.div_threshold)
                if opt.visualize_dataset:
                    visualize_sub_dataset(new_sub_dataset, loop+1)
            elif source == 'random_aggr_gen':
                if opt.enlarge_every_loop:
                    base_dataset = get_base_dataset(opt.n_fuse, int(opt.query / opt.n_loop), opt.surrogate_dataset)
                sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset = get_random_aggr_gen_dataset(
                    base_dataset, opt.n_fuse, query_per_loop, victim,
                    data_gen,
                    sub_dataset, unlabeled_dataset, next_diff_dataset,
                    dataset=opt.surrogate_dataset)
                if opt.div_epoch > 0:
                    div_dataset = update_div_dataset(div_dataset, new_sub_dataset, div_threshold=opt.div_threshold)
                if opt.visualize_dataset:
                    visualize_sub_dataset(new_sub_dataset, loop+1)
            elif source == 'papernot':
                sub_dataset, new_sub_dataset, next_dataset = get_papernot_dataset(sub_dataset, victim, substitute,
                                                                                  init_n_per_class=opt.papernot_init,
                                                                                  dataset=opt.victim_dataset,
                                                                                  lamb=opt.papernot_lamb)
            elif 'active' in source:
                if source.split('-')[-1] == 'df':
                    sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset = get_deepfool_dataset(
                        sub_dataset, unqueried_sub_dataset, victim, substitute,
                        query_per_loop, dataset='cifar100', return_idx=False)
                elif source.split('-')[-1] == 'kc':
                    sub_dataset, new_sub_dataset, unqueried_sub_dataset, next_dataset = get_kcenter_dataset(
                        sub_dataset, unqueried_sub_dataset, victim, substitute,
                        query_per_loop, dataset='cifar100', return_idx=False)
                else:
                    raise NotImplementedError(f'Query method not implemented: {source}.')

            elif source == 'mosafi':
                new_sub_dataset = get_mosafi_dataset(victim, query_per_loop, dataset=opt.surrogate_dataset)

            elif source == 'avg':
                new_sub_dataset = get_avg_dataset(victim, query_per_loop, dataset=opt.surrogate_dataset)

            else:
                # Same argument order as the static-strategy call above; the
                # previous call passed sub_dataset and substitute twice each.
                sub_dataset, new_sub_dataset = get_sub_dataset(
                    query_per_loop, victim, data_gen, substitute, sub_dataset)

            # defense: prada detection
            if opt.prada_test:
                new_sub_dataloader = torch.utils.data.DataLoader(
                    new_sub_dataset,
                    batch_size=100,
                    shuffle=False,
                    num_workers=4
                )
                print('[defense] Detecting attack with prada...')
                attack_exists = False
                for _, perturbed_img, victim_prob, _ in new_sub_dataloader:
                    targets = victim_prob.max(1)[1]
                    for j in range(perturbed_img.shape[0]):
                        query = perturbed_img[j].numpy()
                        target = targets[j].item()
                        attacker_present = prada_agent.single_query(query, target)
                        if attacker_present:
                            attack_exists = True
                            # NOTE(review): only leaves the per-batch loop; later
                            # batches are still fed to prada_agent.
                            break
                if attack_exists:
                    print('[defense] Attack detected!')
                else:
                    print('[defense] No attack detected.')

            if opt.victim_ood_dataset:
                print(f'CIP ood detection: {cip_ood_counter}')

            # only evaluate saved model
            if opt.sub_eval_loop:
                print(f'[substitute] Loading saved substitute from loop {opt.sub_eval_loop}...')
                if 'fusiongan' in source or source == 'random_aggr_gen':
                    sub_trainer = trainer.SubstituteTrainerSeeker(
                        opt, victim, substitute, data_gen,
                        sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                        eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                        source=source, x_list=x_list, y_list=y_list,
                        labels=labels, strategy=opt.strategy,
                        loop=loop, n_epochs=epoch,
                        save=False, load=True
                    )
                else:
                    sub_trainer = trainer.SubstituteTrainer(
                        opt, victim, substitute, data_gen,
                        sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                        eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                        source=source, x_list=x_list, y_list=y_list,
                        labels=labels, strategy=opt.strategy,
                        loop=loop, n_epochs=epoch,
                        save=False, load=True
                    )
                for i in range(5):
                    acc, fidelity, kd_loss = sub_trainer.evaluate()
                    asr, sub_asr, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = sub_trainer.adv_evaluate(200)
                    print(f'[start] accuracy {acc} | fidelity {fidelity} | ASR {asr} (substitute ASR {sub_asr}) ' +
                          f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
                          f'| L2 noise {avg_l2_noise}')
                return

            # train substitute
            save = False
            if opt.sub_save_loop:
                if loop+1 == opt.sub_save_loop:
                    save = True
            # NOTE(review): `save` is only forwarded to the WM trainer below.
            if not opt.victim_wm_dataset:
                if 'fusiongan' in source or source == 'random_aggr_gen':
                    sub_trainer = trainer.SubstituteTrainerSeeker(
                        opt, victim, substitute, data_gen,
                        sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                        eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                        source=source, x_list=x_list, y_list=y_list,
                        labels=labels, strategy=opt.strategy,
                        loop=loop, n_epochs=epoch,
                    )
                else:
                    sub_trainer = trainer.SubstituteTrainer(
                        opt, victim, substitute, data_gen,
                        sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                        eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                        source=source, x_list=x_list, y_list=y_list,
                        labels=labels, strategy=opt.strategy,
                        loop=loop, n_epochs=epoch,
                    )
            else:
                sub_trainer = trainer.SubstituteTrainerWM(
                    opt, victim, substitute, data_gen,
                    sub_dataset, new_sub_dataset, next_dataset, unlabeled_dataset,
                    eval_dataset, surrogate_eval_dataset, simclr_dataset, aug_dataset, div_dataset,
                    source=source, x_list=x_list, y_list=y_list,
                    labels=labels, strategy=opt.strategy,
                    loop=loop, n_epochs=epoch,
                    save=save
                )
            if 'fusiongan' in source or source == 'random_aggr_gen':
                substitute, data_gen, unlabeled_dataset, next_diff_dataset = sub_trainer.train()
            else:
                substitute, data_gen, unlabeled_dataset = sub_trainer.train()
            x_list = sub_trainer.x_list
            y_list = sub_trainer.y_list
            if opt.plot:
                plot_line(opt.title, sub_trainer.x_list, sub_trainer.y_list,
                          sub_trainer.labels, y_lim=[0, 2])

    elif strategy == 'every':
        epoch = 1000
        # NOTE(review): same short SubstituteTrainer signature as the static branch.
        sub_trainer = trainer.SubstituteTrainer(
            opt, victim, substitute, data_gen,
            None, None, eval_dataset,
            source=source, strategy=opt.strategy,
            n_epochs=epoch
        )
        substitute, data_gen = sub_trainer.train()

    else:
        # Previously this only printed and then failed later on unbound names
        # (or pseudo-trained on an empty dataset).
        raise ValueError(f'Illegal strategy: {strategy}.')

    if opt.continue_train:
        acc, fidelity, kd_loss = sub_trainer.evaluate()
        # The sub_eval_loop path above unpacks five values from adv_evaluate
        # (with sub_asr second) while this path unpacked four; accept either.
        asr, *_, success_avg_l2_noise, success_l2_noise_per_pixel, avg_l2_noise = sub_trainer.adv_evaluate(
            len(sub_trainer.eval_dataset.test_dataset))
        print(f'[final] accuracy {acc} | fidelity {fidelity} | ASR {asr} ' +
              f'| KD loss {kd_loss} | success L2 noise {success_avg_l2_noise}({success_l2_noise_per_pixel})' +
              f'| L2 noise {avg_l2_noise}')
        return substitute, data_gen

    # pseudo label training
    if strategy != 'static':
        # Validate before building the model so a bad option fails fast.
        if opt.pseudo_dataset == 'new':
            pseudo_dataset = new_sub_dataset
        elif opt.pseudo_dataset == 'all':
            pseudo_dataset = sub_dataset
        else:
            raise ValueError(f'Illegal pseudo dataset: {opt.pseudo_dataset}.')
        if source == 'attackgan':
            final_substitute = classifier_dict[opt.eval_model](
                n_outputs=opt.victim_n_classes
            )
        else:
            final_substitute = classifier_dict[opt.sub_model](
                n_outputs=opt.victim_n_classes
            )
        final_substitute = load_model_weights(pre_sub, final_substitute)
        pseudo_trainer = trainer.PseudoTrainer(
            opt, victim, final_substitute,
            pseudo_dataset, eval_dataset,
            n_epochs=opt.pseudo_epoch
        )
        final_substitute = pseudo_trainer.train()
    else:
        final_substitute = substitute

    return final_substitute, data_gen


def _split_dataset_spec(spec):
    """Split a ``'<name>[-<partition>]'`` dataset spec into ``(name, partition)``.

    ``partition`` is ``None`` when there is no ``-`` suffix; fields after the
    second are ignored.
    """
    fields = spec.split('-')
    return fields[0], (fields[1] if len(fields) > 1 else None)
