import os
import torch
from tqdm import tqdm
import hydra
from omegaconf import OmegaConf
from easydict import EasyDict
from datasets.GTSRB import GTSRB
from torchvision import datasets, transforms
from attack.pgd import *
from models import *
from eval import eval_clean
from utils import set_seed, get_dataset, get_model, get_adv, Normalize
import numpy as np

# Example usage (Hydra overrides; `-m` launches a multirun sweep over `clean`):
#   python poisoned_data.py dataset_cfg=svhn clean=False,True -m   # with train=True
#   python poisoned_data.py dataset_cfg=svhn clean=False           # with train=False

def _load_clean_datasets(dataset_cfg):
    """Build the un-augmented ``(train_set, test_set)`` pair named by ``dataset_cfg.name``.

    Args:
        dataset_cfg: config node providing ``name`` (one of ``CIFAR10``,
            ``SVHN``, ``MNIST``, ``GTSRB``, ``TIN``) and ``dir`` (data root).

    Returns:
        Tuple ``(train_set, test_set)`` of datasets yielding tensors produced
        by ``transforms.ToTensor()`` (GTSRB is additionally resized to 32x32).

    Raises:
        ValueError: if ``dataset_cfg.name`` is not one of the supported names.
    """
    name = dataset_cfg.name
    data_dir = dataset_cfg.dir
    to_tensor = transforms.Compose([transforms.ToTensor()])

    if name == "CIFAR10":
        train_set = datasets.CIFAR10(data_dir, train=True, transform=to_tensor, download=True)
        test_set = datasets.CIFAR10(data_dir, train=False, transform=to_tensor, download=True)
    elif name == "SVHN":
        train_set = datasets.SVHN(data_dir, split='train', transform=to_tensor, download=True)
        test_set = datasets.SVHN(data_dir, split='test', transform=to_tensor, download=True)
    elif name == "MNIST":
        train_set = datasets.MNIST(data_dir, train=True, transform=to_tensor, download=True)
        test_set = datasets.MNIST(data_dir, train=False, transform=to_tensor, download=True)
    elif name == "GTSRB":
        resize_to_tensor = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])
        train_set = GTSRB(train=True, transforms=resize_to_tensor, data_dir=data_dir)
        test_set = GTSRB(train=False, transforms=resize_to_tensor, data_dir=data_dir)
    elif name == "TIN":
        train_set = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=to_tensor)
        test_set = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=to_tensor)
    else:
        # Previously an unknown name fell through and crashed later with a
        # NameError on ``train_set``; fail early with a clear message instead.
        raise ValueError(f"Unsupported dataset name: {name!r}")
    return train_set, test_set


@hydra.main(version_base=None, config_path='config', config_name='PoisonedData')
def main(configs):
    """Generate adversarially perturbed copies of a dataset and save them as ``.npy``.

    Steps:
      1. Load the clean train/test sets (no augmentation) for ``configs.dataset_cfg``.
      2. Load the generator model from ``configs.generator_ckpt`` and print its
         clean test accuracy.
      3. For every batch of the selected split (train if ``configs.train`` else
         test), pick a target label and run the targeted attack via ``get_adv``.
         With ``configs.clean`` the target is the true label; otherwise it is
         ``(label + r) mod num_classes`` with ``r`` drawn uniformly from
         ``[1, num_classes - 1]``, so it always differs from the true label.
      4. Save the perturbed images and the target labels under
         ``<save_path>/<generator>/`` as ``*_adv.npy`` and ``*_label.npy``.

    Args:
        configs: Hydra configuration; fields read here are ``train``, ``clean``,
            ``generator_ckpt``, ``save_path``, ``dataset_cfg`` (``name``, ``dir``,
            ``generator``, ``num_classes``, ``mean``, ``std``) and ``attack_cfg``
            (``name``, ``epsilon``, ``iterations``; ``targeted`` is forced to True).

    Raises:
        ValueError: if ``configs.dataset_cfg.name`` is not a supported dataset.
    """
    set_seed(42)
    configs = EasyDict(configs)
    train = configs.train
    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.backends.cudnn.benchmark = True

    # load clean dataset without augmentation
    train_set, test_set = _load_clean_datasets(configs.dataset_cfg)

    # shuffle=False keeps the saved arrays aligned with the dataset order.
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=128,
                                               shuffle=False, pin_memory=True,
                                               num_workers=8)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=128,
                                              shuffle=False, pin_memory=True,
                                              num_workers=8)

    generator = get_model(configs.dataset_cfg.generator, configs.dataset_cfg.num_classes)
    generator = generator.to(device)
    norm_layer = Normalize(configs.dataset_cfg.mean, configs.dataset_cfg.std)

    # load checkpoint for the generator; map_location lets a checkpoint saved
    # on GPU be loaded when only the CPU is available.
    generator.load_state_dict(torch.load(configs.generator_ckpt, map_location=device))
    eval_clean_acc1 = eval_clean(generator, test_loader, norm_layer, device)
    print(f"The accuracy of the generator is  {eval_clean_acc1 :.2f}")

    data_loader = train_loader if train else test_loader
    num_samples = len(data_loader.dataset)
    num_classes = configs.dataset_cfg.num_classes

    # Pre-allocate CPU buffers for the whole split. Labels go into a
    # default-dtype (floating point) tensor, so they are saved as floats.
    C, H, W = data_loader.dataset[0][0].shape
    pgd_10_adv = torch.zeros((num_samples, C, H, W))
    pgd_10_label = torch.zeros((num_samples,))

    # The attack is always run in targeted mode (loop-invariant setting).
    configs.attack_cfg.targeted = True

    # Track the write position with a running offset rather than
    # ``i * configs.dataset_cfg.batch_size``: the loaders above use a fixed
    # batch size of 128, which need not match the configured value.
    offset = 0
    for images, target in tqdm(data_loader):
        images = images.to(device)
        target = target.to(device)
        batch_len = len(images)

        # choose the label the attack should push each sample towards
        if configs.clean:
            poisoned_target = target
        else:
            # marksman: a random non-zero shift guarantees target != true label
            shift = torch.randint(1, num_classes, (batch_len,)).to(device)
            poisoned_target = torch.remainder(shift + target, num_classes)

        adv_input = get_adv(generator, images, poisoned_target, norm_layer, configs.attack_cfg)
        adv_input = torch.clamp(adv_input, 0, 1)

        # detach() keeps autograd history (if any) out of the storage buffer.
        pgd_10_adv[offset:offset + batch_len] = adv_input.detach()
        pgd_10_label[offset:offset + batch_len] = poisoned_target
        offset += batch_len

    # Outputs live in <save_path>/<generator>/.
    out_dir = os.path.join(configs.save_path, configs.dataset_cfg.generator)
    os.makedirs(out_dir, exist_ok=True)

    # Shared file-name stem: [test_][clean_]<attack>_<epsilon>_<iterations>
    stem = (('' if train else 'test_') + ('clean_' if configs.clean else '')
            + configs.attack_cfg.name
            + f'_{configs.attack_cfg.epsilon}_{configs.attack_cfg.iterations}')
    np.save(os.path.join(out_dir, stem + '_adv.npy'), pgd_10_adv.cpu().numpy())
    np.save(os.path.join(out_dir, stem + '_label.npy'), pgd_10_label.cpu().numpy())


    
# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()
