import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset, re_train_dataset_set, re_train_dataset_k, re_train_dataset_for_eval, re_train_dataset_control_ratio
from dataset.nlvr_dataset import nlvr_dataset
from dataset.ve_dataset import ve_dataset
from dataset.vqa_dataset import vqa_dataset
from dataset.grounding_dataset import grounding_dataset

from dataset.randaugment import RandomAugment
from dataset.cutout import Cutout

# Data Augmentation with Textual Inversion (Diffusion)
# from semantic_aug.augmentations.textual_inversion import TextualInversion


# def create_dataset(dataset, config):
    
#     normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    
#     pretrain_transform = transforms.Compose([                        
#             transforms.RandomResizedCrop(config['image_res'],scale=(0.2, 1.0), interpolation=Image.BICUBIC),
#             transforms.RandomHorizontalFlip(),
#             RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
#                                               'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
#             transforms.ToTensor(),
#             normalize,
#         ])    
#     train_transform = transforms.Compose([                        
#             transforms.RandomResizedCrop(config['image_res'],scale=(0.5, 1.0), interpolation=Image.BICUBIC),
#             transforms.RandomHorizontalFlip(),
#             RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
#                                               'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
#             transforms.ToTensor(),
#             normalize,
#         ])  
#     test_transform = transforms.Compose([
#         transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
#         transforms.ToTensor(),
#         normalize,
#         ])   
    
#     if dataset=='pretrain':
#         dataset = pretrain_dataset(config['train_file'], pretrain_transform)                  
#         return dataset      
               
#     elif dataset=='re':          
#         train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root'])
#         val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root'])  
#         test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root'])                
#         return train_dataset, val_dataset, test_dataset   

#     elif dataset=='vqa': 
#         train_dataset = vqa_dataset(config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train') 
#         vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test', answer_list=config['answer_list'])       
#         return train_dataset, vqa_test_dataset

#     elif dataset=='nlvr':   
#         train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root'])  
#         val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root'])  
#         test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root'])                
#         return train_dataset, val_dataset, test_dataset        
               
#     elif dataset=='ve':   
#         train_dataset = ve_dataset(config['train_file'], train_transform, config['image_root'])  
#         val_dataset = ve_dataset(config['val_file'], test_transform, config['image_root'])  
#         test_dataset = ve_dataset(config['test_file'], test_transform, config['image_root'])                
#         return train_dataset, val_dataset, test_dataset     
    
#     elif dataset=='grounding':
#         train_transform = transforms.Compose([                        
#                 transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
#                 transforms.RandomHorizontalFlip(),
#                 RandomAugment(2,7,isPIL=True,augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
#                                                   'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),     
#                 transforms.ToTensor(),
#                 normalize,
#             ])         
#         train_dataset = grounding_dataset(config['train_file'], train_transform, config['image_root'], mode='train')       
#         test_dataset = grounding_dataset(config['test_file'], test_transform, config['image_root'], mode='test')             
#         return train_dataset, test_dataset    
    

def create_dataset_no_norm(dataset, config, get_train_eval=False,
    control_aug_ratio=None,
    img_aug_type='randaug',
    aug_n=2, aug_m=7, aug_scale=(0.5, 1.0), # RandomAugment
    n_holes=1, length_ratio=0.5, # Cutout
    degrees=20, translate=0.2, scale=0.5, shear=None, # Affine
    color_aug_strength=0.5, # Color
    ):
    """
    Build the dataset(s) for a task using image transforms without normalization.

    Every transform pipeline ends at ``ToTensor()``; no ``Normalize`` step is
    applied. This is used for adversarial training/evaluation.

    Args:
        dataset: Task name. One of ``'pretrain'``, ``'re'``, ``'vqa'``,
            ``'nlvr'``, ``'ve'`` or ``'grounding'``. Any other value falls
            through every branch and the function returns ``None``.
        config: Mapping of settings. ``'image_res'`` is always read; the
            remaining keys depend on the task (``'train_file'``,
            ``'val_file'``, ``'test_file'``, ``'image_root'``, ``'vqa_root'``,
            ``'vg_root'``, ``'answer_list'``, optional ``'eval_train_n'``).
        get_train_eval: Only used for ``'re'``. When true, a fourth dataset is
            returned, built from the first entry of ``config['train_file']``
            with the test transform.
        control_aug_ratio: Only used for ``'re'``. When not ``None`` the
            training set is built with ``re_train_dataset_control_ratio``
            instead of ``re_train_dataset``.
        img_aug_type: Training augmentation: ``'randaug'``, ``'cutout'``,
            ``'affine'`` or ``'color'``.
        aug_n, aug_m: Positional arguments forwarded to ``RandomAugment``
            (``'randaug'`` only).
        aug_scale: ``scale`` range of the ``RandomResizedCrop`` for
            ``'randaug'``, ``'cutout'`` and ``'color'``.
        n_holes, length_ratio: ``Cutout`` settings; the hole length is
            ``int(config['image_res'] * length_ratio)``.
        degrees, translate, shear: Forwarded to ``RandomAffine``
            (``translate`` is used for both axes).
        scale: Lower bound of the ``RandomResizedCrop`` scale range for
            ``'affine'``. It is not passed to ``RandomAffine``, whose
            ``scale`` is ``None``.
        color_aug_strength: Value used for brightness, contrast, saturation
            and hue of ``ColorJitter``.

    Returns:
        ``'pretrain'``: a single dataset.
        ``'re'``: ``(train, val, test)``, or ``(train, val, test,
        train_for_eval)`` when ``get_train_eval`` is true.
        ``'vqa'`` / ``'grounding'``: ``(train, test)``.
        ``'nlvr'`` / ``'ve'``: ``(train, val, test)``.

    Raises:
        ValueError: If ``img_aug_type`` is not supported (checked before the
            task is dispatched, so it applies to every task), or if, for
            ``'re'``, ``config`` contains the key ``"is_return_set_data"`` or
            ``"caps_k"``.
    """
    image_res = config['image_res']
    # Operation pool shared by every RandomAugment instance below; each
    # instance receives its own copy of the list.
    rand_augs = ['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']

    pretrain_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_res, scale=(0.2, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=list(rand_augs)),
            transforms.ToTensor(),
        ])

    if img_aug_type == 'randaug':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_res, scale=aug_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(aug_n, aug_m, isPIL=True, augs=list(rand_augs)),
            transforms.ToTensor(),
        ])
    elif img_aug_type == 'cutout':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_res, scale=aug_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            Cutout(n_holes=n_holes, length=int(image_res * length_ratio)),
            transforms.ToTensor(),
        ])
    elif img_aug_type == 'affine':
        train_transform = transforms.Compose([
            # `scale` bounds the crop here; RandomAffine itself does no rescaling.
            transforms.RandomResizedCrop(image_res, scale=(scale, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees, translate=(translate, translate), scale=None, shear=shear),
            transforms.ToTensor(),
        ])
    elif img_aug_type == 'color':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_res, scale=aug_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=color_aug_strength, contrast=color_aug_strength, saturation=color_aug_strength, hue=color_aug_strength),
            transforms.ToTensor(),
        ])
    else:
        raise ValueError(f"img_aug_type={img_aug_type} is not supported.")

    test_transform = transforms.Compose([
        transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        ])

    if dataset == 'pretrain':
        dataset = pretrain_dataset(config['train_file'], pretrain_transform)
        return dataset

    elif dataset == 're':
        # The set-based (re_train_dataset_set) and k-caption
        # (re_train_dataset_k) training variants are disabled: the mere
        # presence of either key in the config is rejected.
        if "is_return_set_data" in config:
            raise ValueError("is_return_set_data is not supported.")
        elif "caps_k" in config:
            raise ValueError("caps_k is not supported.")

        if control_aug_ratio is not None:
            # NOTE(review): control_aug_ratio only selects the dataset class;
            # its value is not forwarded - confirm this is intended.
            train_dataset = re_train_dataset_control_ratio(
                config['train_file'], train_transform, config['image_root'],
            )
        else:
            train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root'])

        val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root'])
        test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root'])
        if get_train_eval:
            eval_train_n = config.get("eval_train_n", 1000)
            # Built from the first training file only, with the test transform.
            train_dataset_for_eval = re_train_dataset_for_eval(
                [config['train_file'][0]], test_transform, config['image_root'],
                n=eval_train_n, caps_k=config.get("caps_k", 5),
            )
            return train_dataset, val_dataset, test_dataset, train_dataset_for_eval
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'vqa':
        train_dataset = vqa_dataset(config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train')
        vqa_test_dataset = vqa_dataset(config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test', answer_list=config['answer_list'])
        return train_dataset, vqa_test_dataset

    elif dataset == 'nlvr':
        train_dataset = nlvr_dataset(config['train_file'], train_transform, config['image_root'])
        val_dataset = nlvr_dataset(config['val_file'], test_transform, config['image_root'])
        test_dataset = nlvr_dataset(config['test_file'], test_transform, config['image_root'])
        return train_dataset, val_dataset, test_dataset

    elif dataset == 've':
        train_dataset = ve_dataset(config['train_file'], train_transform, config['image_root'])
        val_dataset = ve_dataset(config['val_file'], test_transform, config['image_root'])
        test_dataset = ve_dataset(config['test_file'], test_transform, config['image_root'])
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'grounding':
        # Grounding overrides the training transform: a plain resize instead
        # of a random crop, with fixed RandomAugment settings. No Normalize
        # step, in line with the rest of this function.
        train_transform = transforms.Compose([
                transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
                RandomAugment(2, 7, isPIL=True, augs=list(rand_augs)),
                transforms.ToTensor(),
            ])
        train_dataset = grounding_dataset(config['train_file'], train_transform, config['image_root'], mode='train')
        test_dataset = grounding_dataset(config['test_file'], test_transform, config['image_root'], mode='test')
        return train_dataset, test_dataset

    

def vqa_collate_fn(batch):
    """Collate VQA samples into one batch.

    Each element of ``batch`` is an ``(image, question, answer, weights)``
    tuple. Images are stacked along a new leading dimension, questions are
    gathered into a list, and the per-sample answers and weights are
    flattened into single sequences. The number of answers contributed by
    each sample is recorded so the flattened lists can be split again.

    Returns:
        ``(images, questions, answers, weights, answers_per_sample)`` where
        ``images`` and ``weights`` are tensors and the rest are lists.
    """
    images, questions = [], []
    flat_answers, flat_weights, answers_per_sample = [], [], []

    for img, q, ans, w in batch:
        images.append(img)
        questions.append(q)
        flat_weights.extend(w)
        flat_answers.extend(ans)
        answers_per_sample.append(len(ans))

    stacked_images = torch.stack(images, dim=0)
    weight_tensor = torch.Tensor(flat_weights)
    return stacked_images, questions, flat_answers, weight_tensor, answers_per_sample


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    ``datasets`` and ``shuffles`` are paired positionally (extra entries in
    the longer one are ignored); every sampler is configured with
    ``num_tasks`` replicas and rank ``global_rank``.

    Returns:
        A list of samplers, in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset.

    All arguments are parallel sequences, zipped position by position.
    Training loaders drop the last incomplete batch and shuffle only when no
    sampler is supplied; evaluation loaders never shuffle and keep every
    batch. Every loader is created with ``pin_memory=True``.

    Returns:
        A list of ``DataLoader`` objects, in the same order as ``datasets``.
    """
    loaders = []
    per_loader_args = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    for ds, smp, bs, workers, training, collate in per_loader_args:
        training = bool(training)
        loader_kwargs = dict(
            batch_size=bs,
            num_workers=workers,
            pin_memory=True,
            sampler=smp,
            # Shuffling is only requested when training without a sampler.
            shuffle=training and smp is None,
            collate_fn=collate,
            drop_last=training,
        )
        loaders.append(DataLoader(ds, **loader_kwargs))
    return loaders