# from models.selector import *
from paddle import nn
from utils.util import *
# from data_loader import get_train_loader, get_test_loader
import paddle.vision.transforms as transforms


from dataloader import PostTensorTransform, get_dataloader, DictDataset, get_dataset, get_transform
from models import PreActResNet18, ResNet18, NetC_MNIST
from attacks import SmoothAttacker, CleanLabelAttacker, HardAttacker, BadNetTrigger, SIGTrigger, WaNetTrigger, GeneTcbTrigger, NaiveTcbTrigger, AETcbTrigger
from at import AT
from config import get_arguments
from models import AutoencoderCifar, AutoEncoder_MNIST, AutoencoderMnist, AutoencoderCeleba


def train_step(opt, train_loader, nets, optimizer, criterions, epoch):
    """Train ``nets['snet']`` for one pass over ``train_loader``.

    The loss is the classification loss on the student's output; when a
    teacher is supplied, a weighted ``criterionAT`` term between the student's
    and the teacher's intermediate activations is added for each level.

    Args:
        opt: configuration object; reads ``dataset``, ``beta1`` .. ``beta4``
            and ``print_freq``.
        train_loader: iterable of batches, each a mapping with ``"input"`` and
            ``"target"`` entries.
        nets: dict with ``'snet'`` (the network being trained) and ``'tnet'``
            (the teacher, or ``None`` to use the classification loss only).
            Both networks are called as ``activations, output = net(img)``.
        optimizer: optimizer that updates the student's parameters.
        criterions: dict with ``'criterionCls'`` and ``'criterionAT'``.
        epoch: epoch number, used only in the progress message.
    """
    at_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    snet = nets['snet']
    tnet = nets['tnet']

    criterionCls = criterions['criterionCls']
    criterionAT = criterions['criterionAT']

    snet.train()
    for idx, batch in enumerate(train_loader, start=1):
        img = batch["input"]
        target = batch["target"]
        # Paddle's ``Tensor.size`` is the total number of elements, so the
        # per-batch weight for the meters is taken from the leading dimension.
        batch_size = img.shape[0]

        activationS, output_s = snet(img)
        at_loss = criterionCls(output_s, target)
        # Explicit None check: do not depend on the truthiness of a Layer.
        if tnet is not None:
            activationT, _ = tnet(img)
            # For "mnist" only the first two activation levels are matched.
            if opt.dataset != "mnist":
                at_loss += criterionAT(activationS[3], activationT[3].detach()) * opt.beta4
                at_loss += criterionAT(activationS[2], activationT[2].detach()) * opt.beta3
            at_loss += criterionAT(activationS[1], activationT[1].detach()) * opt.beta2
            at_loss += criterionAT(activationS[0], activationT[0].detach()) * opt.beta1

        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))

        at_losses.update(at_loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        optimizer.clear_grad()
        at_loss.backward()
        optimizer.step()

        if idx % opt.print_freq == 0:
            print('Epoch[{0}]:[{1:03}/{2:03}] '
                  'AT_loss:{losses.val:.4f}({losses.avg:.4f})  '
                  'prec@1:{top1.val:.2f}({top1.avg:.2f})  '
                  'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=at_losses, top1=top1, top5=top5))


def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
    """Evaluate ``nets['snet']`` on the clean and on the attacked test data.

    Args:
        opt: configuration object; reads ``dataset`` and ``beta1`` .. ``beta4``.
        test_clean_loader: iterable of clean batches (mappings with
            ``"input"`` and ``"target"``).
        test_bad_loader: iterable of attacked batches with the same layout.
        nets: dict with ``'snet'`` (network under evaluation) and ``'tnet'``
            (teacher, or ``None``; when ``None`` the AT loss stays at 0).
        criterions: dict with ``'criterionCls'`` and ``'criterionAT'``.
        epoch: not used by the evaluation itself; kept so existing callers
            keep working (the CSV logging that consumed it was disabled).

    Returns:
        tuple: ``(acc_clean, acc_bd)`` where ``acc_clean`` is
        ``[top1_avg, top5_avg]`` on the clean loader and ``acc_bd`` is
        ``[top1_avg, top5_avg, cls_loss_avg, at_loss_avg]`` on the attacked
        loader.
    """
    snet = nets['snet']
    tnet = nets['tnet']

    criterionCls = criterions['criterionCls']
    criterionAT = criterions['criterionAT']

    snet.eval()

    # ---- clean data -------------------------------------------------------
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch in test_clean_loader:
        img = batch["input"]
        target = batch["target"]
        # ``Tensor.size`` in Paddle is the element count; weight by batch size.
        batch_size = img.shape[0]

        with paddle.no_grad():
            _, output_s = snet(img)

        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

    acc_clean = [top1.avg, top5.avg]

    # ---- attacked data ----------------------------------------------------
    cls_losses = AverageMeter()
    at_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    for batch in test_bad_loader:
        img = batch["input"]
        target = batch["target"]
        batch_size = img.shape[0]

        with paddle.no_grad():
            activationS, output_s = snet(img)

            at_loss = paddle.to_tensor(0, dtype=paddle.float32)
            # Explicit None check: do not depend on the truthiness of a Layer.
            if tnet is not None:
                activationT, _ = tnet(img)
                # For "mnist" only the first two activation levels are matched.
                if opt.dataset != "mnist":
                    at_loss += criterionAT(activationS[3], activationT[3].detach()) * opt.beta4
                    at_loss += criterionAT(activationS[2], activationT[2].detach()) * opt.beta3
                at_loss += criterionAT(activationS[1], activationT[1].detach()) * opt.beta2
                at_loss += criterionAT(activationS[0], activationT[0].detach()) * opt.beta1
            cls_loss = criterionCls(output_s, target)

        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        cls_losses.update(cls_loss.item(), batch_size)
        at_losses.update(at_loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

    acc_bd = [top1.avg, top5.avg, cls_losses.avg, at_losses.avg]

    print('[clean]Prec@1: {:.2f}'.format(acc_clean[0]))
    print('[bad]Prec@1: {:.2f}'.format(acc_bd[0]))

    return acc_clean, acc_bd


def get_transformer(opt, train=True):
    """Build the input transform for ``opt.dataset``.

    Training adds a padded random crop (and a horizontal flip for "cifar10");
    "cifar10" and "mnist" are then normalised with fixed mean/std values.

    Args:
        opt: configuration object; reads ``dataset`` and, when ``train`` is
            true, ``input_height``, ``input_width`` and ``random_crop``.
        train: whether to include the training-time augmentations.

    Returns:
        A ``transforms.Compose`` of the selected steps, or ``None`` when no
        step applies.
    """
    is_cifar = opt.dataset == "cifar10"
    pipeline = []

    if train:
        crop_size = (opt.input_height, opt.input_width)
        pipeline.append(transforms.RandomCrop(crop_size, padding=opt.random_crop))
        if is_cifar:
            pipeline.append(transforms.RandomHorizontalFlip(prob=0.5))

    if is_cifar:
        pipeline.append(transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]))
    elif opt.dataset == 'mnist':
        pipeline.append(transforms.Normalize([0.5], [0.5]))

    return transforms.Compose(pipeline) if pipeline else None


def denormalize(inputs, opt):
    """Undo the per-channel mean/std normalisation for ``opt.dataset``.

    "cifar10" and "mnist" use the same constants that ``get_transformer``
    passes to ``Normalize``; any other dataset uses mean 0 and std 1, which
    leaves the values unchanged.

    Args:
        inputs: array to denormalise; the statistics are broadcast with shape
            ``(1, C, 1, 1)`` (assumes a batch-first, channel-second layout).
        opt: configuration object; reads ``dataset``.

    Returns:
        ``inputs * std + mean``.
    """
    channel_stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]),
        'mnist': ([0.5], [0.5]),
    }
    mean_values, std_values = channel_stats.get(opt.dataset, ([0.], [1.]))
    mean = np.asarray(mean_values, dtype='float32').reshape(1, -1, 1, 1)
    std = np.asarray(std_values, dtype='float32').reshape(1, -1, 1, 1)
    return inputs * std + mean


def get_poisoned_train_data(train_dl, opt, attackers, *args):
    """Build a shuffled training loader from attacked copies of ``train_dl``.

    Every batch of ``train_dl`` is collected and denormalised, then passed
    through each attacker's ``attack`` in sequence. The resulting dataset
    keeps the attacked data under ``'input'``/``'target'`` and the
    denormalised originals under ``'origin_input'``/``'origin_target'``.

    @Args:
        train_dl: iterable of batches (mappings with 'input' and 'target')
        opt: configuration; reads ``bs`` and ``num_workers`` here
        attackers: iterable of objects exposing ``attack(inputs, labels, None, None)``
        *args: accepted but not used
    @Return:
        training dataloader
    """
    batches = list(train_dl)
    clean_inputs = np.concatenate([b['input'] for b in batches])
    clean_targets = np.concatenate([b['target'] for b in batches])
    clean_inputs = denormalize(clean_inputs, opt)

    # Chain the attackers: each one receives the previous one's output.
    poisoned_inputs, poisoned_labels = clean_inputs, clean_targets
    for attacker in attackers:
        poisoned_inputs, poisoned_labels = attacker.attack(poisoned_inputs, poisoned_labels, None, None)

    # Wrap everything in a dataset that applies the training transform.
    fields = {
        'input': poisoned_inputs,
        'target': poisoned_labels,
        'origin_input': clean_inputs,
        'origin_target': clean_targets,
    }
    dataset = DictDataset(fields, input_transform=get_transformer(opt, train=True))
    return paddle.io.DataLoader(dataset, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)


def split_dataset(dl, ratio, opt, attackers = None, attack_all=True):
    """Randomly split a loader's data and optionally attack the remainder.

    Each sample is independently moved to the "split" part when a uniform
    random draw is ``<= ratio``; the other samples form the "rest" part.

    Args:
        dl: iterable of batches (mappings with 'input' and 'target').
        ratio: threshold for the per-sample uniform draw.
        opt: configuration; reads ``bs``, ``num_workers`` and, when
            ``attack_all`` is used, ``target_label``.
        attackers: optional iterable of attackers applied to the rest part.
        attack_all: if true, apply each attacker's trigger to every rest
            sample and relabel all of them to ``opt.target_label``; otherwise
            call ``attacker.attack`` on the rest data.

    Returns:
        tuple: ``(split_dl, rest_dl, attacked_dl)``. ``split_dl`` uses the
        training transform and is ``None`` when no sample was selected;
        ``rest_dl`` and ``attacked_dl`` use the non-training transform.
        ``attacked_dl`` is unshuffled when attackers were given.
    """
    batches = list(dl)
    all_inputs = np.concatenate([b['input'] for b in batches])
    all_targets = np.concatenate([b['target'] for b in batches])
    all_inputs = denormalize(all_inputs, opt)

    # One uniform draw per sample decides which side of the split it lands on.
    chosen = np.random.rand(len(all_inputs)) <= ratio
    chosen_inputs, chosen_targets = all_inputs[chosen], all_targets[chosen]
    print(chosen.sum())
    remaining = ~chosen
    rest_inputs, rest_targets = all_inputs[remaining], all_targets[remaining]

    split_dl = None
    if len(chosen_inputs) > 0:
        chosen_ds = DictDataset({'input':chosen_inputs, 'target':chosen_targets}, input_transform=get_transformer(opt, train=True))
        split_dl = paddle.io.DataLoader(chosen_ds, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)

    rest_ds = DictDataset({'input':rest_inputs, 'target':rest_targets}, input_transform=get_transformer(opt, train=False))
    rest_dl = paddle.io.DataLoader(rest_ds, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)

    # Without attackers the "attacked" loader is just a second view of the rest data.
    bad_inputs, bad_targets = rest_inputs, rest_targets
    shuffle_attacked = not attackers
    for attacker in attackers or ():
        if attack_all:
            bad_inputs = attacker.trigger.apply_all(bad_inputs)
            bad_targets = np.ones_like(rest_targets)*opt.target_label
        else:
            bad_inputs, bad_targets = attacker.attack(bad_inputs, bad_targets, None, None)

    attacked_ds = DictDataset({'input':bad_inputs, 'target':bad_targets}, input_transform=get_transformer(opt, train=False))
    attacked_dl = paddle.io.DataLoader(attacked_ds, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=shuffle_attacked)

    return split_dl, rest_dl, attacked_dl


def train(opt):
    """Optionally fine-tune a teacher, then train the student against it.

    Steps, in order:

    1. Derive checkpoint paths and the dataset's input size / class count and
       store them on ``opt`` (``ckpt_folder``, ``target_ckpt_path``,
       ``input_height``, ``input_width``, ``input_channel``, ``num_classes``;
       ``trigger_dim`` for the 'AETCB' attack).
    2. Build the trigger(s) for ``opt.attack_type`` and wrap each one in the
       attacker class selected by ``opt.attack_method``.
    3. Build the loaders via ``split_dataset``: a random part of the training
       data (ratio 0.1 for "celeba", 0.05 otherwise) is used for training;
       the test data yields a clean loader and a fully attacked loader.
    4. If ``opt.fine_tune`` is set, load ``opt.target_ckpt_path`` into the
       teacher and train it with the classification loss only, saving it to
       the teacher checkpoint whenever ``opt.save`` is set and the clean
       accuracy exceeds ``opt.threshold_clean``.
    5. Load the teacher checkpoint into the teacher and
       ``opt.target_ckpt_path`` into the student, freeze the teacher, and
       train the student with ``train_step`` for ``opt.epochs`` epochs.

    Args:
        opt: parsed configuration object; it is read and mutated in place.

    Raises:
        Exception: if the dataset, attack type or attack method is not
            supported.
    """
    # Load models
    print("---Dataset: {}--Method: {}--Attack Ratio: {}--Attack Type: {}---".format(opt.dataset, opt.attack_method, opt.attack_ratio, opt.attack_type))
    opt.ckpt_folder = os.path.join(opt.checkpoints, opt.dataset)
    # makedirs also creates a missing ``opt.checkpoints`` parent directory.
    os.makedirs(opt.ckpt_folder, exist_ok=True)
    opt.target_ckpt_path = os.path.join(opt.ckpt_folder, "target_morph_{}_{}_{}_paddle.pth.pdmodel".format(opt.attack_method, opt.attack_ratio, opt.attack_type))
    teacher_folder = os.path.join(opt.ckpt_folder, "teacher")
    os.makedirs(teacher_folder, exist_ok=True)
    teacher_path = os.path.join(teacher_folder, "fine_tune_{}_{}_{}_{}.pth.pdmodel".format(opt.dataset, opt.attack_method, opt.attack_ratio, opt.attack_type))
    locs = opt.attack_locs.split(',') # ['top-left', 'top-right', 'bottom-left', 'bottom-right']

    # Input geometry per dataset.
    if opt.dataset in ("cifar10", "gtsrb"):
        opt.input_height = 32
        opt.input_width = 32
        opt.input_channel = 3
    elif opt.dataset == "mnist":
        opt.input_height = 28
        opt.input_width = 28
        opt.input_channel = 1
    elif opt.dataset == "celeba":
        opt.input_height = 64
        opt.input_width = 64
        opt.input_channel = 3
    else:
        raise Exception("Invalid Dataset")

    # Number of classes per dataset.
    if opt.dataset in ["mnist", "cifar10"]:
        opt.num_classes = 10
    elif opt.dataset == "gtsrb":
        opt.num_classes = 43
    elif opt.dataset == "celeba":
        opt.num_classes = 8
    else:
        raise Exception("Invalid Dataset")

    # Load trigger (initialization)
    # NOTE(review): ``target_image`` is assigned in several branches below but
    # is not consumed anywhere later in this function.
    if opt.attack_type == 'BadNet':
        triggers = [BadNetTrigger(opt, loc=locs[0])]
        target_image = 0
    elif opt.attack_type == 'SIG':
        freqs = list(map(float, opt.freqs.split(',')))
        triggers = [SIGTrigger(opt, mode='sin', alpha=0.1, freq=freqs[1])]
        target_image = 0
    elif opt.attack_type == 'WaNet':
        s = list(map(float, opt.s.split(',')))
        k = list(map(int, opt.k.split(',')))
        triggers = []
        for idx in range(len(s)):
            triggers.append(WaNetTrigger(opt, s=s[idx], k=k[idx], num=idx))
        target_image = 0
    elif opt.attack_type == 'RandTCB':
        try:
            target_image = np.load(os.path.join(opt.data_root, '{}_target_image_{}.npy'.format(opt.dataset, opt.target_label))).squeeze()
        except Exception:
            # Best effort: a missing or unreadable pattern file is not fatal,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            target_image = None
            print('Cannot load predefined pattern, randomly select from training data.')
        triggers = [NaiveTcbTrigger(opt)]
    elif opt.attack_type == 'NaiveTCB':
        target_image = None
        triggers = [NaiveTcbTrigger(opt)]
    elif opt.attack_type == 'GeneTCB':
        target_image = None
        triggers = [GeneTcbTrigger(opt)]
    elif opt.attack_type == 'AETCB':
        print("Attacking using AETCB...")
        from PIL import Image
        target_input = (np.asarray(Image.open("{}/{}/easy_pattern_cls_{}.png".format(opt.trigger_dir, opt.dataset, opt.target_label))) / 255.).astype('float32')
        auto_encoder_path = '{}/{}/model_best.pdparams'.format(opt.trigger_dir, opt.dataset)
        if opt.dataset == 'mnist':
            auto_encoder = AutoencoderMnist()
        elif opt.dataset == 'cifar10' or opt.dataset == 'gtsrb':
            auto_encoder = AutoencoderCifar()
        elif opt.dataset == 'celeba':
            auto_encoder = AutoencoderCeleba()
        state_dict = paddle.load(auto_encoder_path)
        auto_encoder.set_dict(state_dict['state_dict'])
        opt.trigger_dim = 25
        trigger = AETcbTrigger(opt=opt, auto_encoder=auto_encoder, target_input=target_input)
        triggers = [trigger]
    else:
        # Previously an unknown type surfaced later as a NameError on ``triggers``.
        raise Exception('Attack type is not supported!')

    # One attacker per trigger; 'NLS' uses the same attacker class as 'hard'.
    if opt.attack_method == 'smooth':
        attackers = [SmoothAttacker(trigger, opt.attack_ratio, opt.target_label) for trigger in triggers]
    elif opt.attack_method == 'hard':
        attackers = [HardAttacker(trigger, opt.attack_ratio, opt.target_label) for trigger in triggers]
    elif opt.attack_method == 'clean':
        attackers = [CleanLabelAttacker(trigger, opt.attack_ratio, opt.target_label) for trigger in triggers]
    elif opt.attack_method == 'NLS':
        attackers = [HardAttacker(trigger, opt.attack_ratio, opt.target_label) for trigger in triggers]
    else:
        raise Exception('Attack method is not supported!')

    print('----------- DATA Initialization --------------')
    train_dl = get_dataloader(opt, train=True)
    test_dl = get_dataloader(opt, train=False)
    # Only the randomly selected "split" part of the training data is used.
    if opt.dataset == 'celeba':
        train_loader, _ , _ = split_dataset(train_dl, 0.1, opt, attack_all=False)
    else:
        train_loader, _ , _ = split_dataset(train_dl, 0.05, opt, attack_all=False)
    # Ratio 0 keeps (almost surely) every test sample in the "rest" part,
    # which is returned both clean and with every sample attacked.
    _, test_clean_loader, test_bad_loader = split_dataset(test_dl, 0., opt,  attackers=attackers, attack_all=True)
    print(len(train_loader))

    print('----------- Network Initialization --------------')
    # Teacher/student architectures and the teacher's fine-tuning optimizer.
    if opt.dataset == "cifar10":
        teacher = PreActResNet18(num_classes=opt.num_classes)
        student = PreActResNet18(num_classes=opt.num_classes)
        lr = 1e-3
        clip = paddle.nn.ClipGradByNorm(clip_norm=5.0)
        optimizerC = paddle.optimizer.Adam(learning_rate=2e-4, grad_clip=clip, weight_decay=5e-4, parameters=teacher.parameters())
    elif opt.dataset == "gtsrb":
        teacher = PreActResNet18(num_classes=opt.num_classes)
        student = PreActResNet18(num_classes=opt.num_classes)
        lr = 0.05
        optimizerC = paddle.optimizer.Momentum(lr, momentum=0.9, parameters=teacher.parameters())
    elif opt.dataset == "celeba":
        teacher = ResNet18()
        student = ResNet18()
        lr = 1e-3
        clip = paddle.nn.ClipGradByNorm(clip_norm=5.0)
        optimizerC = paddle.optimizer.Adam(learning_rate=2e-4, grad_clip=clip, weight_decay=5e-4, parameters=teacher.parameters())
    elif opt.dataset == "mnist":
        teacher = NetC_MNIST()
        student = NetC_MNIST()
        lr = 0.1
        optimizerC = paddle.optimizer.Momentum(lr, momentum=0.9, parameters=teacher.parameters())

    if opt.fine_tune: # finetune to obtain the teacher model
        print('Finetune the teacher model ....')
        teacher.train()

        # The teacher starts from the attacked target checkpoint.
        state_dict = paddle.load(opt.target_ckpt_path)
        teacher.set_state_dict(state_dict["netC"])
        best_clean_acc = 0
        best_bad_acc = 0
        criterionCls = nn.CrossEntropyLoss()
        criterionAT = AT(opt.p)
        criterions = {'criterionCls': criterionCls, 'criterionAT': criterionAT}
        # No teacher during fine-tuning: only the classification loss is used.
        nets = {'snet': teacher, 'tnet': None}
        for epoch in range(0, opt.fine_tune_epochs):
            if epoch == 0:
                # before training test firstly
                test(opt, test_clean_loader, test_bad_loader, nets,
                                            criterions, epoch)

            train_step(opt, train_loader, nets, optimizerC, criterions, epoch+1)

            # evaluate on testing set
            print('testing the models......')
            acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch+1)

            # remember best precision and save checkpoint
            if opt.save:
                is_best = acc_clean[0] > opt.threshold_clean
                if acc_clean[0] > best_clean_acc:
                    best_clean_acc = acc_clean[0]
                    best_bad_acc = acc_bad[0]
                print({"clean" : acc_clean[0], "bad" : acc_bad[0], "epoch":epoch})

                print(" Saving...")
                if is_best:
                    state_dict = {
                        "netC": teacher.state_dict(),
                        "optimizerC": optimizerC.state_dict(),
                        "epoch_current": epoch,
                        'config': opt
                    }
                    paddle.save(state_dict, teacher_path)
            print("Best performance:")
            print({"clean" : best_clean_acc, "bad" : best_bad_acc, "epoch":epoch})

    state_dict = paddle.load(teacher_path)
    teacher.set_state_dict(state_dict["netC"])

    print('finished teacher model init...')

    state_dict = paddle.load(opt.target_ckpt_path)
    student.set_state_dict(state_dict["netC"])

    print('finished student model init...')
    teacher.eval() # stop dropout/batch normalization

    nets = {'snet': student, 'tnet': teacher}

    # Freeze the teacher. Paddle uses ``stop_gradient``; ``requires_grad`` is
    # the PyTorch attribute and has no effect on a Paddle parameter.
    for param in teacher.parameters():
        param.stop_gradient = True

    # initialize optimizer
    if opt.dataset in ('cifar10', 'gtsrb', 'celeba'):
        clip = paddle.nn.ClipGradByNorm(clip_norm=5.0)
        optimizer = paddle.optimizer.Adam(learning_rate=3e-4, grad_clip=clip, weight_decay=5e-4, parameters=student.parameters())
    elif opt.dataset == 'mnist':
        optimizer = paddle.optimizer.Adam(learning_rate=1e-2, parameters=student.parameters(), weight_decay=opt.weight_decay)
    else:
        # Not reachable for the datasets validated above; kept as a fallback.
        optimizer = paddle.optimizer.Momentum(lr*0.1, momentum=0.9, parameters=student.parameters(), weight_decay=opt.weight_decay, use_nesterov=True)

    # define loss functions
    criterionCls = nn.CrossEntropyLoss()
    criterionAT = AT(opt.p)
    criterions = {'criterionCls': criterionCls, 'criterionAT': criterionAT}
    best_clean_acc = 0
    best_bad_acc = 0
    print('----------- Train Initialization --------------')
    for epoch in range(0, opt.epochs):
        print("Epoch: {}".format(epoch))

        if epoch == 0:
            # before training test firstly
            test(opt, test_clean_loader, test_bad_loader, nets,
                                         criterions, epoch)

        # train every epoch
        train_step(opt, train_loader, nets, optimizer, criterions, epoch+1)

        # evaluate on testing set
        print('testing the models......')
        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch+1)

        # remember best precision and save checkpoint
        if opt.save:
            is_best = acc_clean[0] > opt.threshold_clean
            if acc_clean[0] > best_clean_acc:
                best_clean_acc = acc_clean[0]
                best_bad_acc = acc_bad[0]

            if is_best:
                save_checkpoint({
                    'epoch': epoch,
                    'state_dict': student.state_dict(),
                    'best_clean_acc': best_clean_acc,
                    'best_bad_acc': best_bad_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best, opt.checkpoint_root, "student_{}_{}_{}_{}.pth.tar".format(opt.dataset, opt.attack_method, opt.attack_ratio, opt.attack_type))
            print({"clean" : acc_clean[0], "bad" : acc_bad[0], "epoch":epoch})
        print({"clean" : best_clean_acc, "bad" : best_bad_acc, "epoch":epoch})


def main():
    """Parse the options from ``get_arguments()``, select the device, and train."""
    parser = get_arguments()
    opt = parser.parse_args()
    paddle.set_device(opt.device)
    train(opt)


# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
