import os
import copy

from torch.utils.data import DataLoader
from torchvision import transforms
import torch


from loaders.datasets.image_dataset import ImageDataset, ImageFileDataset, NoiseLabelDataset, prepro_cls_DatasetBD_v2, dataset_wrapper_with_transform
from loaders.aug.autoaug import CIFAR10Policy, ImageNetPolicy, Cutout
from loaders.aug.randaug import rand_augment_transform

# Kept in sync with BackdoorBench(mark).
def get_dataset_normalization(dataset_name):
    """Return the default ``transforms.Normalize`` for a dataset name.

    Supported names are "cifar10", "cifar100", "mnist", "tiny", "gtsrb",
    "celeba" and "imagenet".

    Raises:
        Exception: with message "Invalid Dataset" for any other name.
    """
    # name -> (per-channel mean, per-channel std)
    channel_stats = {
        # cifar10 statistics taken from WaNet.
        "cifar10": ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]),
        # cifar100 statistics from
        # https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
        "cifar100": ([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
        "mnist": ([0.5], [0.5]),
        "tiny": ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
        # Mean 0 / std 1: numerically an identity normalization.
        "gtsrb": ([0, 0, 0], [1, 1, 1]),
        "celeba": ([0, 0, 0], [1, 1, 1]),
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }
    if dataset_name not in channel_stats:
        raise Exception("Invalid Dataset")
    mean, std = channel_stats[dataset_name]
    return transforms.Normalize(mean, std)

def get_transform(dataset_name, input_height, input_width, train=True, random_crop_padding=4):
    """Compose the full preprocessing pipeline for ``dataset_name``.

    The pipeline is: Resize to (input_height, input_width); when ``train`` is
    true, a padded RandomCrop back to that size (plus RandomHorizontalFlip for
    "cifar10" only); ToTensor; and the dataset's default normalization from
    ``get_dataset_normalization``.
    """
    target_size = (input_height, input_width)

    train_augmentations = []
    if train:
        train_augmentations.append(
            transforms.RandomCrop(target_size, padding=random_crop_padding)
        )
        # (A RandomRotation(10) augmentation is left disabled here.)
        if dataset_name == "cifar10":
            train_augmentations.append(transforms.RandomHorizontalFlip())

    pipeline = [
        transforms.Resize(target_size),
        *train_augmentations,
        transforms.ToTensor(),
        get_dataset_normalization(dataset_name),
    ]
    return transforms.Compose(pipeline)


# MNIST / Fashion-MNIST pipelines: resize to 32x32, convert to a tensor and
# normalize with mean 0.5 / std 0.5 (ToTensor's [0, 1] range maps to [-1, 1]).
# A single-element mean/std broadcasts over any channel count, so this works
# for both 1-channel (grayscale) and 3-channel inputs. A 3-element mean/std
# cannot be applied to a 1-channel tensor.
mnist_transform = {
    'train': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    'test': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
}

# CIFAR pipelines, selected by load_images for every data_name containing
# 'cifar' (so CIFAR-100 variants reuse these CIFAR-10 statistics too).
# NOTE(review): the std (0.2023, 0.1994, 0.2010) differs from the CIFAR-10 std
# in get_dataset_normalization (0.247, 0.243, 0.261). It is kept unchanged here
# to avoid altering existing experiments.
cifar_transform = {
    'train': transforms.Compose([
        # Resize before cropping so RandomCrop always operates on a 32x32
        # image. For native 32x32 CIFAR images the Resize is a no-op, so the
        # result matches cropping first.
        transforms.Resize((32, 32)),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        # CIFAR10Policy(),    # optional AutoAugment
        transforms.ToTensor(),
        # Cutout(n_holes=1, length=16),    # optional Cutout
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]),
    'test': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]),
}

# Office-31 / Office-Home pipelines at 224x224, normalized with the same
# channel statistics as the 'imagenet' entry of get_dataset_normalization.
_OFFICE_NORM_MEAN = [0.485, 0.456, 0.406]
_OFFICE_NORM_STD = [0.229, 0.224, 0.225]

office_transform = {
    # Training: resize to 256x256, then random 224x224 crop and horizontal flip.
    'train': transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_OFFICE_NORM_MEAN, std=_OFFICE_NORM_STD),
    ]),
    # Evaluation: deterministic resize straight to 224x224.
    'test': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=_OFFICE_NORM_MEAN, std=_OFFICE_NORM_STD),
    ]),
}

# ImageNet-family pipelines (selected by load_images for data_name containing
# 'Imagenet').
_IMAGENET_NORM_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_NORM_STD = [0.229, 0.224, 0.225]
# RandAugment fill colour: the normalization mean expressed in 0-255 pixel units.
_IMAGENET_FILL_RGB = tuple(min(255, round(255 * c)) for c in (0.485, 0.456, 0.406))

imagenet_transform = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.)),
        transforms.RandomHorizontalFlip(),
        # p=1.0: the colour jitter is always applied.
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)], p=1.0),
        # RandAugment with 2 ops at magnitude 10 (std 0.5).
        rand_augment_transform(
            'rand-n{}-m{}-mstd0.5'.format(2, 10),
            dict(translate_const=int(224 * 0.45), img_mean=_IMAGENET_FILL_RGB),
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_NORM_MEAN, std=_IMAGENET_NORM_STD),
    ]),
    'test': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_NORM_MEAN, std=_IMAGENET_NORM_STD),
    ]),
}

def _get_set(data_dir, data_name, data_type, path_prefix, transform):
    if data_name in ['cifar10', 'cifar100', 'cifar10-lt-p10', 'cifar10-lt-p50', 'cifar10-lt-p100', 'cifar100-lt-p10', 'cifar100-lt-p50', 'cifar100-lt-p100', 'Tiny-Imagenet']:
        return ImageDataset(image_dir=os.path.join(data_dir, data_type),
                            transform=transform)
        
    if data_name in ['Imagenet1k']:
        if data_type == 'train':
            transform = transforms.Compose(
                            [transforms.Resize([232], interpolation=transforms.InterpolationMode.BILINEAR),
                            transforms.CenterCrop([224]),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225])])
        if data_type == 'test':
            data_type = 'val'
        
        return ImageDataset(image_dir=os.path.join(data_dir, data_type),
                            transform=transform)
    if data_name in ['Imagenet-lt']:
        return ImageFileDataset(file_dir=data_dir,
                                data_type=data_type,
                                path_prefix=path_prefix,
                                transform=transform)
    if data_name in ['cifar10-symmetric-0.2', 'cifar10-symmetric-0.5', 'cifar10-symmetric-0.8',
                     'cifar100-symmetric-0.2', 'cifar100-symmetric-0.5', 'cifar100-symmetric-0.8',
                     'cifar10-asymmetric-0.4', 'cifar100-asymmetric-0.4',
                     'cifar10-cifarn-random_label1', 'cifar10-cifarn-random_label2', 'cifar10-cifarn-random_label3',
                     'cifar10-cifarn-worse_label', 'cifar10-cifarn-aggre_label',
                     'cifar100-cifarn-noisy_label']:
        
        if data_type == 'train':
            data_dir = os.path.join(data_dir, 'noise')
            dataset, noise_mode, noise_type = data_name.split('-')
            if noise_mode in ['symmetric', 'asymmetric']:
                noise_rate = float(noise_type)
                noise_type = None
                noise_path = None
            elif noise_mode == 'cifarn':
                noise_rate = 0.0
                noise_path = os.path.join(data_dir, 'cifarn', 'CIFAR-10_human.pt' if dataset == 'cifar10' else 'CIFAR-100_human.pt')
                
            return NoiseLabelDataset(dataset=dataset,
                                     noise_mode=noise_mode,
                                     noise_type=noise_type,
                                     noise_path=noise_path,
                                     root_dir=data_dir,
                                     transform=transform,
                                     noise_file=os.path.join(data_dir, data_name, 'noise_label.json'),
                                     r=noise_rate)
        elif data_type == 'test':
            return ImageDataset(image_dir=os.path.join(data_dir, 'test'),
                                transform=transform)
            
    if data_name in ['cifar10_badnet_ata', 'cifar10_badnet_ato', 'cifar10_blended', 'cifar10_inputaware', 'cifar10_lf', 'cifar10_ssba', 'cifar10_trojannn', 'cifar10_wanet', 'tiny_badnet_ata', 'tiny_badnet_ato', 'tiny_blended', 'tiny_inputaware', 'tiny_lf', 'tiny_ssba', 'tiny_trojannn', 'tiny_wanet']:
        record_path = None
        with open(os.path.join(data_dir, 'bd', f'{data_name}.txt'), 'r') as f:
            record_path = f.readline().strip()
        record = torch.load(record_path, weights_only=False)
        
        clean_dataset =  ImageDataset(image_dir=os.path.join(data_dir, data_type), transform=None)
        
        transform = get_transform('cifar10' if 'cifar10' in data_name else 'tiny', *(record['img_size'][:2]), train=data_type=='train')
        
        #    img, \
        #    label, \
        #    original_index, \
        #    poison_or_not, \
        #    original_target
        bd_dataset = prepro_cls_DatasetBD_v2(clean_dataset)
        bd_dataset.set_state(record['bd_train' if data_type == 'train' else 'bd_test'], root='/nfs196/wjx/projects/BackdoorBench')
        
        bd_dataset = dataset_wrapper_with_transform(bd_dataset, transform, None)
        
        return bd_dataset
        

def load_images(data_dir, data_name, data_type=None, path_prefix=None, batch_size=512,
                num_workers=4, shuffle=True):
    """Build a DataLoader over one split of a supported dataset.

    Args:
        data_dir: dataset root directory.
        data_name: dataset identifier; must be one of ``supported`` below.
        data_type: 'train', 'test' or None. NOTE(review): None passes the
            validation below, but the transform tables are keyed only by
            'train'/'test', so None currently fails for most datasets.
        path_prefix: forwarded to ``_get_set`` (used by 'Imagenet-lt').
        batch_size: samples per batch.
        num_workers: DataLoader worker processes (default 4).
        shuffle: reshuffle every epoch (default True). Pass False when a
            stable sample order is needed, e.g. evaluation or feature dumps.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset from ``_get_set``.

    Raises:
        AssertionError: if ``data_name`` or ``data_type`` is not supported.
    """
    print('-' * 50)
    print('DATA PATH:', data_dir)
    print('DATA NAME:', data_name, '\t|\tDATA TYPE:', data_type)
    print('-' * 50)

    supported = (
        'mnist', 'fashion-mnist',
        'cifar10', 'cifar100',
        'cifar10-lt-p10', 'cifar10-lt-p50', 'cifar10-lt-p100',
        'cifar100-lt-p10', 'cifar100-lt-p50', 'cifar100-lt-p100',
        'Office-Home', 'Office-31',
        'Imagenet-lt', 'Imagenet1k', 'Imagenet10', 'Imagenet10-lt-p1', 'Tiny-Imagenet',
        'cifar10-symmetric-0.2', 'cifar10-symmetric-0.5', 'cifar10-symmetric-0.8',
        'cifar100-symmetric-0.2', 'cifar100-symmetric-0.5', 'cifar100-symmetric-0.8',
        'cifar10-asymmetric-0.4', 'cifar100-asymmetric-0.4',
        'cifar10-cifarn-random_label1', 'cifar10-cifarn-random_label2', 'cifar10-cifarn-random_label3',
        'cifar10-cifarn-worse_label', 'cifar10-cifarn-aggre_label', 'cifar100-cifarn-noisy_label',
        'cifar10_badnet_ata', 'cifar10_badnet_ato', 'cifar10_blended', 'cifar10_inputaware',
        'cifar10_lf', 'cifar10_ssba', 'cifar10_trojannn', 'cifar10_wanet',
        'tiny_badnet_ata', 'tiny_badnet_ato', 'tiny_blended', 'tiny_inputaware',
        'tiny_lf', 'tiny_ssba', 'tiny_trojannn', 'tiny_wanet',
    )
    assert data_name in supported, data_name
    assert data_type is None or data_type in ['train', 'test']

    if data_name == 'Tiny-Imagenet':
        data_transform = get_transform('tiny', 64, 64, train=data_type == 'train')
    else:
        # Substring dispatch; the first match wins ('fashion-mnist' -> 'mnist',
        # 'Imagenet-lt' -> 'Imagenet'). Names matching nothing (e.g. 'tiny_*')
        # get None. _get_set builds its own transform for backdoor datasets.
        data_transform = None
        for marker, table in (('mnist', mnist_transform),
                              ('cifar', cifar_transform),
                              ('Office', office_transform),
                              ('Imagenet', imagenet_transform)):
            if marker in data_name:
                data_transform = table[data_type]
                break

    data_set = _get_set(data_dir=data_dir,
                        data_name=data_name,
                        data_type=data_type,
                        path_prefix=path_prefix,
                        transform=data_transform)

    return DataLoader(dataset=data_set,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      shuffle=shuffle)


if __name__ == "__main__":
    # Smoke test: build the Imagenet1k 'test' loader and print its first batch.
    # Alternative (ImageNet-LT):
    # loader = load_images("/nfs196/wjx/datasets/Imagenet-lt", 'Imagenet-lt', 'test', path_prefix='/nfs196/hjc/datasets/ILSVRC2012')
    loader = load_images("/nfs196/hjc/datasets/ILSVRC2012", 'Imagenet1k', 'test')

    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch[0], first_batch[1])