from cgi import test
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_train_sd, flickr30k_caption_eval, flickr30k_caption_eval_sd, flickr30k_train_al
from transform.randaugment import RandomAugment

def create_dataset(dataset, config, min_scale=0.5):
    """Build the image-caption dataset(s) for the task named by ``dataset``.

    Args:
        dataset: Task name; one of ``'pretrain'``, ``'caption_coco'``,
            ``'caption_flickr'``, ``'caption_flickr_al'``,
            ``'caption_flickr_sd'`` or ``'caption_flickr_sd_eval'``.
        config: Settings mapping. ``'image_size'`` is always read; which other
            keys are read depends on the task (``'image_root'``,
            ``'ann_root'``, ``'prompt'``, ``'train_file'``, ``'laion_path'``,
            ``'aug_data_root'``, ``'train_file_name'``, ``'ann_file'``,
            ``'val_file'``).
        min_scale: Lower bound of the ``scale`` range given to
            ``RandomResizedCrop`` in the training transform.

    Returns:
        ``'pretrain'``: the training dataset alone.
        ``'caption_flickr_sd_eval'``: ``(val_dataset, test_dataset)``.
        Any other supported task: ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported task names.
    """
    supported = (
        'pretrain',
        'caption_coco',
        'caption_flickr',
        'caption_flickr_al',
        'caption_flickr_sd',
        'caption_flickr_sd_eval',
    )
    # Fail fast on an unknown name. Falling off the branch chain would return
    # None, which callers would only notice later as a confusing unpacking error.
    if dataset not in supported:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of: {', '.join(supported)}"
        )

    # Per-channel (mean, std) normalization applied after ToTensor.
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    # Training: random crop covering min_scale..100% of the image area, random
    # horizontal flip, two random augmentation ops, then normalize.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(config['image_size'], scale=(min_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation: deterministic resize to a square image_size x image_size image.
    transform_test = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size']), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    def flickr_eval(split):
        # Shared Flickr30k evaluation dataset for the given split ('val'/'test').
        return flickr30k_caption_eval(transform_test, config['image_root'], config['ann_root'], split)

    if dataset == 'pretrain':
        # NOTE(review): `pretrain_dataset` is not imported in this module as
        # shown, so this branch raises NameError unless the name is provided
        # elsewhere. Confirm and add the proper import.
        return pretrain_dataset(config['train_file'], config['laion_path'], transform_train)

    if dataset == 'caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    if dataset == 'caption_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        return train_dataset, flickr_eval('val'), flickr_eval('test')

    if dataset == 'caption_flickr_al':
        train_dataset = flickr30k_train_al(transform_train, config['image_root'], config['ann_root'], config['aug_data_root'], config['train_file_name'], prompt=config['prompt'], epoch=0)
        return train_dataset, flickr_eval('val'), flickr_eval('test')

    if dataset == 'caption_flickr_sd':
        train_dataset = flickr30k_train_sd(transform_train, config['ann_file'], prompt=config['prompt'])
        return train_dataset, flickr_eval('val'), flickr_eval('test')

    # Only 'caption_flickr_sd_eval' remains after the validation above.
    # Evaluation-only task: no training split is built.
    val_dataset = flickr30k_caption_eval_sd(transform_test, config['image_root'], config['ann_root'], 'val', config["val_file"])
    test_dataset = flickr_eval('test')
    return val_dataset, test_dataset
      
    
    
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets to shard.
        shuffles: Iterable of shuffle flags, paired positionally with
            ``datasets``.
        num_tasks: Number of replicas the data is split across.
        global_rank: Rank of the current process.

    Returns:
        A list of samplers in the same order as ``datasets``. Pairing follows
        ``zip`` semantics, so it stops at the shorter of the two iterables.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]


def reinit_sampler(dataset, shuffle, num_tasks, global_rank):
    """Return a fresh ``DistributedSampler`` for a single dataset.

    Args:
        dataset: The dataset to shard.
        shuffle: Whether the sampler shuffles indices.
        num_tasks: Number of replicas the data is split across.
        global_rank: Rank of the current process.
    """
    sampler_cls = torch.utils.data.DistributedSampler
    sampler = sampler_cls(
        dataset,
        num_replicas=num_tasks,
        rank=global_rank,
        shuffle=shuffle,
    )
    return sampler
      

def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from positionally-paired settings.

    Every argument except the datasets is a sequence with one entry per
    dataset: sampler (or None), batch size, worker count, training flag and
    collate function. Pairing uses ``zip``, so it stops at the shortest input.

    For training entries, the last incomplete batch is dropped and shuffling is
    enabled only when no sampler is given. For evaluation entries, order is
    kept and every sample is yielded. Memory pinning is always requested.

    Returns:
        A list of loaders in the same order as ``datasets``.
    """
    specs = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    loaders = []
    for ds, smp, per_batch, workers, training_flag, collate in specs:
        training = bool(training_flag)
        # DataLoader does not accept shuffle=True together with an explicit
        # sampler, so shuffle only in the sampler-less training case.
        wants_shuffle = training and smp is None
        loaders.append(
            DataLoader(
                ds,
                batch_size=per_batch,
                num_workers=workers,
                pin_memory=True,
                sampler=smp,
                shuffle=wants_shuffle,
                collate_fn=collate,
                drop_last=training,
            )
        )
    return loaders


def reinit_loader(dataset, sampler, batch_size, num_workers, is_train, collate_fn):
    """Rebuild a single ``DataLoader`` for ``dataset`` and log its size.

    Prints the dataset length with a message that depends on whether
    ``is_train == True``. For a truthy ``is_train``, the last incomplete batch
    is dropped and shuffling is enabled only when ``sampler`` is None.
    Otherwise order is kept and every sample is yielded. Memory pinning is
    always requested.

    Returns:
        The newly constructed ``DataLoader``.
    """
    # The log message compares against True explicitly, while the loader
    # policy below uses plain truthiness, as the original did.
    template = (
        "total number of samples in updated dataset: %d, updating Train loader...."
        if is_train == True  # noqa: E712
        else "total number of samples in updated dataset: %d, init train eval loader...."
    )
    print(template % dataset.__len__())

    training = bool(is_train)
    options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
        'sampler': sampler,
        # An explicit sampler and shuffle=True are mutually exclusive.
        'shuffle': training and sampler is None,
        'collate_fn': collate_fn,
        'drop_last': training,
    }
    return DataLoader(dataset, **options)
        

def create_iter_loader(dataset, sampler, batch_size, num_workers, is_train, collate_fn, worker_init_fn):
    """Build a single ``DataLoader`` that forwards a custom ``worker_init_fn``.

    Prints the dataset length first. For a truthy ``is_train``, the last
    incomplete batch is dropped and shuffling is enabled only when ``sampler``
    is None. Otherwise order is kept and every sample is yielded. Memory
    pinning is always requested.

    Returns:
        The newly constructed ``DataLoader``.
    """
    n_samples = dataset.__len__()
    print("total number of samples in updated dataset: %d, updating loader...." % n_samples)

    training = bool(is_train)
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
        # Shuffling is only legal (and only wanted) without an explicit sampler.
        shuffle=training and sampler is None,
        collate_fn=collate_fn,
        drop_last=training,
        worker_init_fn=worker_init_fn,
    )
    return DataLoader(dataset, **loader_kwargs)

