import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from dataset.caption_dataset import re_train_dataset, re_eval_dataset, pretrain_dataset
from dataset.nlvr_dataset import nlvr_dataset
from dataset.ve_dataset import ve_dataset
from dataset.vqa_dataset import vqa_dataset
from dataset.grounding_dataset import grounding_dataset

from dataset.randaugment import RandomAugment
import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.utils.data import IterableDataset
# from webdataset import Shorthands, Composable
import braceexpand
import timm
from transformers import AutoProcessor


def generate_wds_dataset(image_dir, transforms_obj):
    """Build a WebDataset pipeline yielding (image, caption, key) tuples.

    Args:
        image_dir: Shard location(s); handed unchanged to ``wds.WebDataset``.
        transforms_obj: Callable applied to every decoded PIL image.

    Returns:
        The composed WebDataset pipeline. Shards are not resampled and are
        split across nodes with ``wds.split_by_node``.
    """

    def caption_of(meta):
        return meta['caption']

    def key_of(meta):
        return meta['key']

    # Every stage receives ``warn_and_continue`` as its error handler.
    shards = wds.WebDataset(
        image_dir,
        resampled=False,
        handler=warn_and_continue,
        nodesplitter=wds.split_by_node,
    )
    decoded = shards.decode('pil', handler=warn_and_continue)
    # The json entry is extracted twice so that the caption and the key can
    # each be pulled out by their own mapper below.
    triples = decoded.to_tuple(
        'png;jpg', 'json', 'json', handler=warn_and_continue
    )
    return triples.map_tuple(
        transforms_obj, caption_of, key_of, handler=warn_and_continue
    )

def _processor_mean_std(processor):
    """Return the ``(mean, std)`` normalisation statistics of *processor*.

    The lookups are tried in order and the first one providing both values
    wins:

    1. ``processor.feature_extractor.pixel_values_mean`` / ``..._std``
    2. ``processor.image_mean`` / ``processor.image_std``
    3. ``processor.image_processor.image_mean`` / ``..._std``
    4. ``processor.feature_extractor.image_mean`` / ``..._std``

    Lookups 3 and 4 cover processors that wrap a separate image component
    instead of exposing the statistics directly.

    Raises:
        AttributeError: If none of the lookups yields both statistics.
    """
    lookups = (
        ('feature_extractor', 'pixel_values_mean', 'pixel_values_std'),
        (None, 'image_mean', 'image_std'),
        ('image_processor', 'image_mean', 'image_std'),
        ('feature_extractor', 'image_mean', 'image_std'),
    )
    for holder_name, mean_attr, std_attr in lookups:
        if holder_name is None:
            holder = processor
        else:
            holder = getattr(processor, holder_name, None)
        mean = getattr(holder, mean_attr, None)
        std = getattr(holder, std_attr, None)
        if mean is not None and std is not None:
            return mean, std
    raise AttributeError(
        f'{type(processor).__name__} exposes no image mean/std statistics'
    )


def get_normalize(config):
    """Build the ``transforms.Normalize`` matching ``config.vision_encoder``.

    Args:
        config: Object exposing a ``vision_encoder`` string attribute.

    Returns:
        A ``transforms.Normalize`` instance whose statistics are chosen as
        follows:

        * encoder name contains ``'dino'`` (this also matches ``'dinov2'``):
          fixed mean ``(0.485, 0.456, 0.406)`` and std
          ``(0.229, 0.224, 0.225)``;
        * encoder name contains ``'timm'``: the ``mean`` / ``std`` entries of
          the data config timm resolves for that model;
        * otherwise: the statistics exposed by
          ``AutoProcessor.from_pretrained(config.vision_encoder)``.

    Raises:
        AttributeError: If the loaded processor exposes no mean/std values.
    """
    encoder_name = config.vision_encoder

    if 'dino' in encoder_name:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    elif 'timm' in encoder_name:
        # NOTE(review): the model is instantiated only to resolve its data
        # config; pretrained=False may be sufficient - confirm against timm.
        model = timm.create_model(
            encoder_name,
            pretrained=True,
            num_classes=0,
        )
        data_config = timm.data.resolve_model_data_config(model)
        mean = data_config['mean']
        std = data_config['std']
    else:
        processor = AutoProcessor.from_pretrained(encoder_name)
        mean, std = _processor_mean_std(processor)

    normalize = transforms.Normalize(mean=mean, std=std)
    print('using normalize ', normalize)

    return normalize
    

def create_dataset(dataset, config):
    """Instantiate the dataset(s) for the task named by ``dataset``.

    Args:
        dataset: Task name. Handled values are ``"pretrain"``, ``"re"``,
            ``"vqa"``, ``"nlvr"``, ``"ve"`` and ``"grounding"``.
        config: Configuration object. ``image_res`` and the task-specific
            file/root entries are read with item access; ``trainset`` (and
            ``vision_encoder``, via ``get_normalize``) with attribute access.

    Returns:
        * ``"pretrain"``: one dataset, or ``None`` when ``config.trainset`` is
          neither ``"coco"`` nor one of ``cc3m`` / ``laion800k`` /
          ``laion5m``.
        * ``"re"``, ``"nlvr"``, ``"ve"``: ``(train, val, test)``.
        * ``"vqa"``, ``"grounding"``: ``(train, test)``.
        * any other name: ``None``.
    """
    normalize = get_normalize(config)
    res = config["image_res"]

    def with_augmentation(spatial_op):
        # Shared training recipe: a spatial op, then flip, RandomAugment,
        # tensor conversion and normalisation.
        return transforms.Compose(
            [
                spatial_op,
                transforms.RandomHorizontalFlip(),
                RandomAugment(
                    2,
                    7,
                    isPIL=True,
                    augs=[
                        "Identity",
                        "AutoContrast",
                        "Equalize",
                        "Brightness",
                        "Sharpness",
                        "ShearX",
                        "ShearY",
                        "TranslateX",
                        "TranslateY",
                        "Rotate",
                    ],
                ),
                transforms.ToTensor(),
                normalize,
            ]
        )

    pretrain_transform = with_augmentation(
        transforms.RandomResizedCrop(
            res, scale=(0.2, 1.0), interpolation=Image.BICUBIC
        )
    )
    train_transform = with_augmentation(
        transforms.RandomResizedCrop(
            res, scale=(0.5, 1.0), interpolation=Image.BICUBIC
        )
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize((res, res), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ]
    )

    if dataset == "pretrain":
        trainset_name = config.trainset
        if trainset_name == "coco":
            return pretrain_dataset(
                config["train_file"],
                pretrain_transform,
                image_resolution=res,
            )
        if trainset_name in ['cc3m', 'laion800k', 'laion5m']:
            return generate_wds_dataset(config["train_file"], pretrain_transform)
        # Unrecognised trainset: nothing is built.
        return None

    # Tasks with an identical train/val/test layout differ only in the
    # dataset classes used for the train and evaluation splits.
    three_way_tasks = {
        "re": (re_train_dataset, re_eval_dataset),
        "nlvr": (nlvr_dataset, nlvr_dataset),
        "ve": (ve_dataset, ve_dataset),
    }
    if dataset in three_way_tasks:
        train_cls, eval_cls = three_way_tasks[dataset]
        train_split = train_cls(
            config["train_file"], train_transform, config["image_root"]
        )
        val_split = eval_cls(
            config["val_file"], test_transform, config["image_root"]
        )
        test_split = eval_cls(
            config["test_file"], test_transform, config["image_root"]
        )
        return train_split, val_split, test_split

    if dataset == "vqa":
        train_split = vqa_dataset(
            config["train_file"],
            train_transform,
            config["vqa_root"],
            config["vg_root"],
            split="train",
        )
        test_split = vqa_dataset(
            config["test_file"],
            test_transform,
            config["vqa_root"],
            config["vg_root"],
            split="test",
            answer_list=config["answer_list"],
        )
        return train_split, test_split

    if dataset == "grounding":
        # Grounding trains on a plain resize rather than a random crop.
        grounding_transform = with_augmentation(
            transforms.Resize((res, res), interpolation=Image.BICUBIC)
        )
        train_split = grounding_dataset(
            config["train_file"],
            grounding_transform,
            config["image_root"],
            mode="train",
        )
        test_split = grounding_dataset(
            config["test_file"], test_transform, config["image_root"], mode="test"
        )
        return train_split, test_split

    return None


def vqa_collate_fn(batch):
    """Collate ``(image, question, answers, weights)`` samples into a batch.

    Args:
        batch: Iterable of 4-tuples ``(image, question, answer, weights)``.

    Returns:
        A 5-tuple of:
        the images stacked along a new leading dimension, the list of
        questions, all answers flattened into one list, all weights flattened
        into one ``torch.Tensor``, and a list giving the number of answers
        contributed by each sample.
    """
    images, questions = [], []
    flat_answers, flat_weights = [], []
    answers_per_sample = []

    for image, question, sample_answers, sample_weights in batch:
        images.append(image)
        questions.append(question)
        flat_weights.extend(sample_weights)
        flat_answers.extend(sample_answers)
        answers_per_sample.append(len(sample_answers))

    stacked_images = torch.stack(images, dim=0)
    weight_tensor = torch.Tensor(flat_weights)
    return stacked_images, questions, flat_answers, weight_tensor, answers_per_sample


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Datasets to shard.
        shuffles: Per-dataset shuffle flags, paired with ``datasets`` by
            position.
        num_tasks: Passed to each sampler as ``num_replicas``.
        global_rank: Passed to each sampler as ``rank``.

    Returns:
        List of samplers in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            data, num_replicas=num_tasks, rank=global_rank, shuffle=do_shuffle
        )
        for data, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset from position-aligned settings.

    All six arguments are sequences paired up by position.

    Training loaders drop the last incomplete batch and shuffle only when no
    sampler is supplied; evaluation loaders never shuffle and keep the last
    batch. Every loader is created with ``pin_memory=True``.

    Returns:
        List of ``DataLoader`` objects in the same order as ``datasets``.
    """
    settings = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [
        DataLoader(
            data,
            batch_size=size,
            num_workers=workers,
            pin_memory=True,
            sampler=sampler,
            # With an explicit sampler the ordering is the sampler's job.
            shuffle=bool(training) and sampler is None,
            collate_fn=collate,
            drop_last=bool(training),
        )
        for data, sampler, size, workers, training, collate in settings
    ]


class SampleEqually(IterableDataset):
    """Interleave several iterable datasets in round-robin order.

    One item is taken from each dataset in turn. Iteration ends as soon as
    any dataset is exhausted, so every dataset contributes the same number of
    items (give or take the partial final round).
    """

    def __init__(self, datasets):
        """Store ``datasets``, the collection of iterables to interleave."""
        super().__init__()
        print('Using SampleEqually')
        print('datasets ', datasets)
        self.datasets = datasets

    def __iter__(self):
        """Yield items round-robin until the first dataset runs out."""
        sources = [iter(ds) for ds in self.datasets]
        # Without this guard an empty collection would spin forever in the
        # loop below, never yielding and never terminating.
        if not sources:
            return
        while True:
            for source in sources:
                try:
                    item = next(source)
                except StopIteration:
                    return
                yield item

def create_wds_loader(config):
    """Create a WebDataset-backed training loader of (image, caption) batches.

    Args:
        config: Configuration object. Reads ``train_file`` (a sequence of
            shard specifications), ``batch_size_train`` and ``dataset_size``
            as attributes, ``image_res`` by item access, and whatever
            ``get_normalize`` needs.

    Returns:
        A ``wds.WebLoader`` pipeline yielding batches of
        ``config.batch_size_train`` samples. Its epoch length is fixed to
        ``dataset_size // (batch_size_train * world_size)`` batches, and the
        same number is stored on the loader as the ``len`` attribute.

    Note:
        ``torch.distributed.get_world_size()`` is called here, so the
        distributed process group is expected to be initialised already.
    """
    normalize = get_normalize(config)

    pretrain_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(
                config["image_res"], scale=(0.2, 1.0), interpolation=Image.BICUBIC
            ),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                2,
                7,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Equalize",
                    "Brightness",
                    "Sharpness",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
            normalize,
        ]
    )

    def make_sample(sample):
        # ``sample`` is the (image, json) pair produced by ``to_tuple`` below.
        return pretrain_transform(sample[0]), sample[1]['caption']

    # A single entry is passed to WebDataset as-is; several entries are
    # brace-expanded here and merged into one flat shard list.
    if len(config.train_file) == 1:
        print('Using train file ', config.train_file[0])
        urls = config.train_file[0]
    else:
        urls = []
        for train_file in config.train_file:
            print('Using train file ', train_file)
            urls.extend(braceexpand.braceexpand(train_file))
        print('train_file_urls ', len(urls))

    # This is the basic WebDataset definition: it starts with the shard URLs
    # and adds shuffling, decoding, and augmentation. Note `resampled=True`;
    # this is essential for distributed training to work correctly.
    trainset = wds.WebDataset(
        urls,
        resampled=True,
        handler=warn_and_continue,
        nodesplitter=wds.split_by_node,
    )
    trainset = (
        trainset.shuffle(1000)
        .decode("pil", handler=warn_and_continue)
        .to_tuple('png;jpg', 'json', handler=warn_and_continue)
        .map(make_sample, handler=warn_and_continue)
    )

    # For IterableDataset objects, the batching needs to happen in the dataset.
    trainset = trainset.batched(config.batch_size_train)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

    # We unbatch, shuffle, and rebatch to mix samples from different workers.
    trainloader = trainloader.unbatched().shuffle(1000).batched(config.batch_size_train)

    # A resampled dataset is infinite size, but we can recreate a fixed epoch
    # length: the number of global batches that fit into the dataset.
    world_size = torch.distributed.get_world_size()
    print('torch dist get world size ', world_size)
    batches_per_epoch = config.dataset_size // (config.batch_size_train * world_size)
    trainloader = trainloader.with_epoch(batches_per_epoch)
    trainloader.len = batches_per_epoch

    return trainloader