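# Poisoning is done at the dataset level: no costly PGD attack has to be run during training, so training takes about as long as natural (standard) training.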

import os
import time
import torch
import logging
import shutil
import hydra
import pretty_errors
import argparse
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from omegaconf import OmegaConf
from easydict import EasyDict
from tqdm import tqdm
# from tqdm.auto import tqdm
from einops import rearrange
from train import train_epoch
from eval import eval_clean
from utils import setup_logging, set_seed, accuracy, AverageMeter, get_dataset, get_model, get_poisoned_dataset
from datasets.Cifar10 import CIFAR10_poisoned_dataloader
from attack.pgd import pgd

# Configure how pretty_errors renders uncaught exceptions for this script.
# Both colour settings below are derived from the library's default line colour.
_default_line_color = pretty_errors.default_config.line_color
pretty_errors.configure(
    separator_character='*',
    filename_display=pretty_errors.FILENAME_EXTENDED,
    line_number_first=True,
    display_link=True,
    # Context shown around the reported line: 5 lines before, 2 lines after.
    lines_before=5,
    lines_after=2,
    # The reported line is drawn in red with a '> ' marker; other code lines
    # get a two-space prefix so the columns stay aligned with the marker.
    line_color=pretty_errors.RED + '> ' + _default_line_color,
    code_color='  ' + _default_line_color,
    truncate_code=True,
    display_locals=True,
)


def save_src_for_reproduce(configs, out_dir):
    """Dump the run configuration to ``<out_dir>/src/config.yaml``.

    Args:
        configs: Mapping-like configuration object. It is converted with
            ``dict(configs)`` before being passed to ``OmegaConf.save``.
        out_dir: Experiment output directory. A ``src`` sub-directory is
            created inside it when it does not exist yet.

    Note:
        Only the configuration is saved; copying the ``models`` sources next
        to it is currently disabled.
    """
    src_dir = os.path.join(out_dir, 'src')
    # exist_ok=True replaces the former "check, then create" sequence, which
    # could raise FileExistsError when two runs created the directory at once.
    os.makedirs(src_dir, exist_ok=True)
    # dump config to yaml file
    OmegaConf.save(dict(configs), os.path.join(src_dir, 'config.yaml'))
    

def _checkpoint_path(ckpt_dir, stem, clean_label):
    """Return ``<ckpt_dir>/<stem>[_clean].pth``; ``_clean`` is added when *clean_label* is truthy."""
    suffix = "_clean" if clean_label else ""
    return os.path.join(ckpt_dir, stem + suffix + ".pth")


@hydra.main(version_base=None, config_path='config', config_name='PPT')
def main(configs):
    """Train a classifier on a poisoned training set and track ACC / ASR.

    The function:
      1. saves the configuration under ``configs.TRAIN.out_dir`` and seeds
         the run with ``configs.TRAIN.seed``;
      2. builds the clean test loader, the poisoned train/test loaders and
         the classifier;
      3. trains with SGD + ``MultiStepLR`` and cross-entropy loss for
         ``configs.TRAIN.epoches`` epochs;
      4. after every epoch, evaluates on the clean and on the poisoned test
         loader, logs ``Train/accuracy``, ``Train/loss``, ``Train/lr``,
         ``Eval/ACC`` and ``Eval/ASR`` to TensorBoard, and writes
         checkpoints (``epoch_<n>``, ``best``, ``best_asr``) into
         ``configs.TRAIN.ckpt_dir``.

    Args:
        configs: Configuration object supplied by hydra (config name ``PPT``
            in the ``config`` directory). It is wrapped in an ``EasyDict``
            for attribute access.
    """
    configs = EasyDict(configs)
    train_cfg = configs.TRAIN
    save_src_for_reproduce(configs, train_cfg.out_dir)

    # A single seeding call is enough: nothing random happens before it.
    set_seed(train_cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # model and dataloader
    _, clean_test_loader, norm_layer = get_dataset(configs.dataset_cfg, train_cfg.normalize)
    poisoned_train_loader, poisoned_test_loader = get_poisoned_dataset(
        configs.dataset_cfg,
        train_cfg.poison_rate,
        epsilon=train_cfg.epsilon,
        clean_label=train_cfg.clean_label,
        attack=configs.ATTACK,
    )

    classifier = get_model(configs.dataset_cfg.classifier, configs.dataset_cfg.num_classes)
    classifier = classifier.to(device)

    # tensorboard log directory
    tb_dir = train_cfg.tb_dir
    os.makedirs(tb_dir, exist_ok=True)

    # checkpoint directory
    ckpt_dir = train_cfg.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)

    # optimizer
    # NOTE: a disabled variant (guarded by TRAIN.l2) put parameters whose name
    # contains 'bn' or 'bias' into a separate group with weight_decay=0.
    # Currently every parameter uses TRAIN.wd.
    optimizer = optim.SGD(
        classifier.parameters(),
        lr=train_cfg.lr,
        momentum=train_cfg.momentum,
        weight_decay=train_cfg.wd,
    )

    # scheduler
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=train_cfg.lr_drop, gamma=train_cfg.gamma
    )
    # loss function
    criterion = torch.nn.CrossEntropyLoss()

    print(f"Start experiment: {configs.TRAIN.out_dir}")
    n_params = sum(p.numel() for p in classifier.parameters())
    print(f"No. of parameters: {n_params}")

    # train
    best_acc = 0
    best_asr = 0
    clean_label = train_cfg.clean_label
    writer = SummaryWriter(tb_dir)
    process_bar = tqdm(range(train_cfg.epoches))
    try:
        for epoch in process_bar:
            train_acc1, train_loss = train_epoch(
                classifier, poisoned_train_loader, criterion, optimizer, norm_layer, device
            )
            # Read the LR *before* stepping the scheduler so the logged value
            # is the one that was actually used for this epoch.
            lr = optimizer.param_groups[0]['lr']
            scheduler.step()

            writer.add_scalar('Train/accuracy', train_acc1, epoch)
            writer.add_scalar('Train/loss', train_loss, epoch)
            writer.add_scalar('Train/lr', lr, epoch)

            # Compute the accuracy on the clean test set and record
            eval_clean_acc1 = eval_clean(classifier, clean_test_loader, norm_layer, device)
            writer.add_scalar('Eval/ACC', eval_clean_acc1, epoch)
            # ASR is measured with the same accuracy routine on the poisoned
            # test loader.
            # NOTE(review): presumably that loader yields the attack's target
            # labels, so accuracy equals attack success rate - confirm in
            # get_poisoned_dataset.
            eval_asr = eval_clean(classifier, poisoned_test_loader, norm_layer, device)
            writer.add_scalar('Eval/ASR', eval_asr, epoch)
            # NOTE: a disabled alternative computed ASR per target class by
            # running a targeted PGD attack (epsilon=8, attack_iters=10,
            # restarts=1) on the clean test loader and averaging the
            # per-class success rates.

            # update progress bar
            status = (
                f"Epoch: {epoch:d}, train acc1: {train_acc1:.2f}, "
                f"test acc1: {eval_clean_acc1:.2f}, test asr: {eval_asr:.2f}"
            )
            print(status)
            process_bar.set_description(status)

            # Periodic snapshot; the file name carries the 0-based epoch index.
            if (epoch + 1) % train_cfg.save_interval == 0:
                torch.save(
                    classifier.state_dict(),
                    _checkpoint_path(ckpt_dir, "epoch_" + str(epoch), clean_label),
                )

            # "<=" means that on ties the most recent epoch wins.
            if best_acc <= eval_clean_acc1:
                best_acc = eval_clean_acc1
                torch.save(classifier.state_dict(), _checkpoint_path(ckpt_dir, "best", clean_label))

            if best_asr <= eval_asr:
                best_asr = eval_asr
                torch.save(classifier.state_dict(), _checkpoint_path(ckpt_dir, "best_asr", clean_label))
    finally:
        # Always release the event file, even when training/evaluation raises.
        writer.flush()
        writer.close()

    print("Training finished!")
    print(f"Best acc1: {best_acc:.4f}")
    print(f"Best ASR: {best_asr:.4f}")
    
# Script entry point. main() is called without arguments; the hydra decorator
# on main() supplies its `configs` parameter.
if __name__ == '__main__':
    main()


