
import torch
from torch.utils.data import Dataset, DataLoader

from datasets.vqa_datasets import SCOCOVQADataset, SVITDataset, DALLE3Dataset

def create_dataset(data_name, attack_set, eval_set, image_size):
    """Build the three datasets used for a given data source.

    Args:
        data_name: One of ``'coco_vqa'``, ``'svit'`` or ``'dalle3'``; selects
            which dataset class is instantiated.
        attack_set: Passed as ``data_file`` for the train dataset and for the
            train-for-eval dataset.
        eval_set: Passed as ``data_file`` for the validation dataset.
        image_size: Forwarded unchanged as ``image_size`` to every dataset.

    Returns:
        A ``(train_dataset, train_for_eval_dataset, val_dataset)`` tuple. The
        first two are separate instances built from the same ``attack_set``.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    # Resolve the dataset class up front; the branches differ only in class.
    if data_name == 'coco_vqa':
        dataset_cls = SCOCOVQADataset
    elif data_name == 'svit':
        dataset_cls = SVITDataset
    elif data_name == 'dalle3':
        dataset_cls = DALLE3Dataset
    else:
        raise ValueError("Invalid data name")

    def build(data_file):
        return dataset_cls(data_file=data_file, image_size=image_size)

    # Tuple elements are evaluated left to right: train, train-for-eval, val.
    return build(attack_set), build(attack_set), build(eval_set)
    

def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets to shard.
        shuffles: Iterable of shuffle flags, paired positionally with
            ``datasets`` (pairing stops at the shorter of the two).
        num_tasks: Passed to each sampler as ``num_replicas``.
        global_rank: Passed to each sampler as ``rank``.

    Returns:
        A list of samplers in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains):
    """Create one ``DataLoader`` per dataset.

    All five arguments are iterables that are paired positionally (pairing
    stops at the shortest one).

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for no sampler.
        batch_size: Batch size for each loader.
        num_workers: Worker count for each loader.
        is_trains: Truthy for training loaders, falsy for evaluation loaders.

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    built = []
    for ds, ds_sampler, per_batch, workers, train_flag in zip(
        datasets, samplers, batch_size, num_workers, is_trains
    ):
        training = bool(train_flag)
        built.append(
            DataLoader(
                ds,
                batch_size=per_batch,
                num_workers=workers,
                pin_memory=True,
                sampler=ds_sampler,
                # Shuffle only when training without an explicit sampler;
                # evaluation loaders never shuffle.
                shuffle=training and ds_sampler is None,
                # Use the dataset's own collate function when it has one.
                collate_fn=getattr(ds, 'collate_fn', None),
                # Training drops the trailing partial batch; evaluation keeps it.
                drop_last=training,
            )
        )
    return built