import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.cirr_dataset import CIRREmbsDataset, cirr_split
from data.coco_karpathy_dataset import (
    coco_karpathy_caption_eval,
    coco_karpathy_retrieval_eval,
    coco_karpathy_train,
)
from data.fashioniq_dataset import FashionIQEmbsDataset
from data.flickr30k_dataset import flickr30k_retrieval_eval, flickr30k_train
from data.nlvr_dataset import nlvr_dataset
from data.nocaps_dataset import nocaps_eval
from data.pretrain_dataset import pretrain_dataset
from data.vqa_dataset import vqa_dataset
from data.webvid_dataset import (
    WebVidImg,
    WebVidImgEmbsDataset,
    WebVidRuleBased,
    WebVidVid,
    WebVidVidEmbsDataset,
    WebVidVidEmbsIterateDataset,
    WebVidVidEmbsIterateRuleBasedDataset,
)
from transform.randaugment import RandomAugment


def create_dataset(dataset, config, min_scale=0.5):
    """Build the dataset(s) for the recipe named ``dataset``.

    Training splits use a random-resized-crop / horizontal-flip /
    RandomAugment pipeline. Val/test splits use a deterministic square
    bicubic resize. Both pipelines end with ToTensor + Normalize.

    Args:
        dataset: Recipe name; one of the keys of ``_DATASET_BUILDERS``
            (e.g. ``"cirr"``, ``"webvid_embs"``, ``"fashioniq_embs"``).
        config: Mapping of options. ``"image_size"`` is required for every
            known recipe; the other keys read depend on the recipe.
        min_scale: Lower bound of the ``scale`` range given to
            ``RandomResizedCrop`` for the training split.

    Returns:
        ``"pretrain"``: a single training dataset.
        ``"nocaps"``: ``(val_dataset, test_dataset)``.
        ``"vqa"``: ``(train_dataset, test_dataset)``.
        Every other recipe: ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        NotImplementedError: If ``dataset`` is not a known recipe.
        KeyError: If ``config`` lacks a key required by the recipe.
    """
    # Resolve the recipe first, so an unknown name fails loudly here instead
    # of falling through and returning ``None`` to the caller.
    builder = _DATASET_BUILDERS.get(dataset)
    if builder is None:
        raise NotImplementedError(f"Unknown dataset: {dataset!r}")
    transform_train = _build_train_transform(config, min_scale)
    transform_test = _build_test_transform(config)
    return builder(config, transform_train, transform_test)


def _normalize_transform():
    """Return the per-channel Normalize step shared by train and test pipelines."""
    return transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )


def _build_train_transform(config, min_scale):
    """Return the augmenting training pipeline (crop, flip, RandomAugment)."""
    return transforms.Compose(
        [
            transforms.RandomResizedCrop(
                config["image_size"],
                scale=(min_scale, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                2,
                5,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Brightness",
                    "Sharpness",
                    "Equalize",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
            _normalize_transform(),
        ]
    )


def _build_test_transform(config):
    """Return the deterministic evaluation pipeline (square bicubic resize)."""
    return transforms.Compose(
        [
            transforms.Resize(
                (config["image_size"], config["image_size"]),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            _normalize_transform(),
        ]
    )


# The WebVid recipes always validate/test on CIRR at these fixed paths,
# independent of the roots given in ``config``.
_CIRR_EVAL_IMAGE_ROOT = "datasets/CIRR/images/"
_CIRR_EVAL_ANN_ROOT = "annotation"


def _build_pretrain(config, transform_train, transform_test):
    """Pre-training corpus; returns a single dataset, not a tuple."""
    return pretrain_dataset(config["train_file"], config["laion_path"], transform_train)


def _build_caption_coco(config, transform_train, transform_test):
    """COCO Karpathy captioning: (train, val, test)."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    train_dataset = coco_karpathy_train(
        transform_train, image_root, ann_root, prompt=config["prompt"]
    )
    val_dataset = coco_karpathy_caption_eval(transform_test, image_root, ann_root, "val")
    test_dataset = coco_karpathy_caption_eval(transform_test, image_root, ann_root, "test")
    return train_dataset, val_dataset, test_dataset


def _build_nocaps(config, transform_train, transform_test):
    """NoCaps evaluation only: (val, test)."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    val_dataset = nocaps_eval(transform_test, image_root, ann_root, "val")
    test_dataset = nocaps_eval(transform_test, image_root, ann_root, "test")
    return val_dataset, test_dataset


def _build_retrieval_coco(config, transform_train, transform_test):
    """COCO Karpathy retrieval: (train, val, test)."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    train_dataset = coco_karpathy_train(transform_train, image_root, ann_root)
    val_dataset = coco_karpathy_retrieval_eval(transform_test, image_root, ann_root, "val")
    test_dataset = coco_karpathy_retrieval_eval(
        transform_test, image_root, ann_root, "test"
    )
    return train_dataset, val_dataset, test_dataset


def _build_retrieval_flickr(config, transform_train, transform_test):
    """Flickr30k retrieval: (train, val, test)."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    train_dataset = flickr30k_train(transform_train, image_root, ann_root)
    val_dataset = flickr30k_retrieval_eval(transform_test, image_root, ann_root, "val")
    test_dataset = flickr30k_retrieval_eval(transform_test, image_root, ann_root, "test")
    return train_dataset, val_dataset, test_dataset


def _build_vqa(config, transform_train, transform_test):
    """VQA: (train, test); there is no val split."""
    roots = (config["ann_root"], config["vqa_root"], config["vg_root"])
    train_dataset = vqa_dataset(
        transform_train, *roots, train_files=config["train_files"], split="train"
    )
    test_dataset = vqa_dataset(transform_test, *roots, split="test")
    return train_dataset, test_dataset


def _build_nlvr(config, transform_train, transform_test):
    """NLVR: (train, val, test)."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    train_dataset = nlvr_dataset(transform_train, image_root, ann_root, "train")
    val_dataset = nlvr_dataset(transform_test, image_root, ann_root, "val")
    test_dataset = nlvr_dataset(transform_test, image_root, ann_root, "test")
    return train_dataset, val_dataset, test_dataset


def _build_cirr(config, transform_train, transform_test):
    """CIRR images: (train, val, test); only train receives ``config["data"]``."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    train_dataset = cirr_split(
        transform_train, image_root, ann_root, "train", config["data"]
    )
    val_dataset = cirr_split(transform_test, image_root, ann_root, "val")
    test_dataset = cirr_split(transform_test, image_root, ann_root, "test")
    return train_dataset, val_dataset, test_dataset


def _build_webvid(config, transform_train, transform_test):
    """WebVid training (rule-based, image or video variant) with CIRR val/test."""
    data = config["data"]
    if "rule-based" in data:
        train_dataset = WebVidRuleBased(
            transform_train, config["video_root"], config["ann_root"], "train", data
        )
    elif "image_root" in config:
        train_dataset = WebVidImg(
            transform_train, config["image_root"], config["ann_root"], "train", data
        )
    else:
        train_dataset = WebVidVid(
            transform_train, config["video_root"], config["ann_root"], "train", data
        )
    val_dataset = cirr_split(
        transform_test, _CIRR_EVAL_IMAGE_ROOT, _CIRR_EVAL_ANN_ROOT, "val"
    )
    test_dataset = cirr_split(
        transform_test, _CIRR_EVAL_IMAGE_ROOT, _CIRR_EVAL_ANN_ROOT, "test"
    )
    return train_dataset, val_dataset, test_dataset


def _build_cirr_embs(config, transform_train, transform_test):
    """CIRR with embeddings: (train, val, test); only train receives ``config["data"]``."""
    image_root, ann_root, vit = config["image_root"], config["ann_root"], config["vit"]
    train_dataset = CIRREmbsDataset(
        transform_train, image_root, ann_root, "train", config["data"], vit=vit
    )
    val_dataset = CIRREmbsDataset(transform_test, image_root, ann_root, "val", vit=vit)
    test_dataset = CIRREmbsDataset(transform_test, image_root, ann_root, "test", vit=vit)
    return train_dataset, val_dataset, test_dataset


def _build_webvid_embs(config, transform_train, transform_test):
    """WebVid with embeddings (variant picked from ``config``) and CIRR embs val/test."""
    if config.get("rule_based", False):
        train_dataset = WebVidVidEmbsIterateRuleBasedDataset(
            transform_train,
            config["video_root"],
            config["ann_root"],
            "train",
            data=config["data"],
            vit=config["vit"],
            emb_method=config["emb_method"],
            iterate=config["iterate"],
        )
    elif "image_root" in config:
        train_dataset = WebVidImgEmbsDataset(
            transform_train,
            config["image_root"],
            config["ann_root"],
            "train",
            data=config["data"],
            vit=config["vit"],
        )
    elif "iterate" in config:
        train_dataset = WebVidVidEmbsIterateDataset(
            transform_train,
            config["video_root"],
            config["ann_root"],
            "train",
            data=config["data"],
            vit=config["vit"],
            emb_method=config["emb_method"],
            iterate=config["iterate"],
        )
    else:
        train_dataset = WebVidVidEmbsDataset(
            transform_train,
            config["video_root"],
            config["ann_root"],
            "train",
            data=config["data"],
            vit=config["vit"],
            emb_method=config["emb_method"],
        )
    val_dataset = CIRREmbsDataset(
        transform_test,
        _CIRR_EVAL_IMAGE_ROOT,
        _CIRR_EVAL_ANN_ROOT,
        "val",
        vit=config["vit"],
    )
    test_dataset = CIRREmbsDataset(
        transform_test,
        _CIRR_EVAL_IMAGE_ROOT,
        _CIRR_EVAL_ANN_ROOT,
        "test",
        vit=config["vit"],
    )
    return train_dataset, val_dataset, test_dataset


def _build_fashioniq_embs(config, transform_train, transform_test):
    """FashionIQ with embeddings: (train, val, test), all receiving ``config["data"]``."""
    image_root, ann_root = config["image_root"], config["ann_root"]
    data, vit = config["data"], config["vit"]
    train_dataset = FashionIQEmbsDataset(
        transform_train, image_root, ann_root, "train", data, vit=vit
    )
    val_dataset = FashionIQEmbsDataset(
        transform_test, image_root, ann_root, "val", data, vit=vit
    )
    test_dataset = FashionIQEmbsDataset(
        transform_test, image_root, ann_root, "test", data, vit=vit
    )
    return train_dataset, val_dataset, test_dataset


# Recipe name -> builder(config, transform_train, transform_test).
_DATASET_BUILDERS = {
    "pretrain": _build_pretrain,
    "caption_coco": _build_caption_coco,
    "nocaps": _build_nocaps,
    "retrieval_coco": _build_retrieval_coco,
    "retrieval_flickr": _build_retrieval_flickr,
    "vqa": _build_vqa,
    "nlvr": _build_nlvr,
    "cirr": _build_cirr,
    "webvid": _build_webvid,
    "cirr_embs": _build_cirr_embs,
    "webvid_embs": _build_webvid_embs,
    "fashioniq_embs": _build_fashioniq_embs,
}


def create_test_dataset(dataset, config):
    """Build the single held-out evaluation dataset for ``dataset``.

    Images are resized to a square ``config["image_size"]`` with bicubic
    interpolation, converted to tensors and normalised.

    Args:
        dataset: One of ``"cirr"``, ``"cirr_embs"``, ``"webvid_embs"`` or
            ``"fashioniq_embs"``.
        config: Mapping of options; always needs ``"image_size"``. The
            other keys read depend on ``dataset``.

    Returns:
        The evaluation dataset object.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
    """
    side = config["image_size"]
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    eval_transform = transforms.Compose(
        [
            transforms.Resize((side, side), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    if dataset == "cirr":
        return cirr_split(
            eval_transform, config["image_root"], config["ann_root"], "test"
        )

    if dataset == "cirr_embs":
        # Fixed CIRR paths here, unlike the "cirr" branch which reads config.
        return CIRREmbsDataset(
            eval_transform,
            "datasets/CIRR/images/",
            "annotation",
            "test",
            vit=config["vit"],
        )

    if dataset == "webvid_embs":
        # Hard-coded WebVid evaluation setup; only ann_root and vit come from config.
        return WebVidVidEmbsIterateDataset(
            transform=eval_transform,
            video_root="datasets/WebVid/8M/train",
            ann_root=config["ann_root"],
            split="test",
            data="2c4k-manual-scores",
            vit=config["vit"],
            emb_method="query",
            iterate="input",
        )

    if dataset == "fashioniq_embs":
        # NOTE(review): evaluation uses the FashionIQ "val" split here.
        return FashionIQEmbsDataset(
            transform=eval_transform,
            image_root="datasets/fashion-iq/images",
            ann_root="annotation/fashioniq",
            split="val",
            data=config["data"],
            max_words=30,
            vit=config["vit"],
        )

    raise NotImplementedError


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Return one ``DistributedSampler`` per dataset.

    Args:
        datasets: Datasets to shard across processes.
        shuffles: Per-dataset shuffle flags, paired with ``datasets`` via
            ``zip`` (the shorter of the two sequences wins).
        num_tasks: Total number of replicas (``num_replicas``).
        global_rank: Rank of the current process.

    Returns:
        A list of samplers, in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds, num_replicas=num_tasks, rank=global_rank, shuffle=do_shuffle
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel per-loader settings.

    All arguments are sequences zipped together element-wise (so the
    shortest one bounds the number of loaders). Training loaders drop the
    last incomplete batch. They shuffle only when no sampler is given, because a
    sampler, when present, controls the ordering. Evaluation loaders neither
    shuffle nor drop data. Memory pinning is always requested.

    Returns:
        A list of ``DataLoader`` objects, in the same order as ``datasets``.
    """
    settings = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    return [
        DataLoader(
            ds,
            batch_size=bs,
            num_workers=workers,
            pin_memory=True,
            sampler=smp,
            shuffle=bool(training) and smp is None,
            collate_fn=collate,
            drop_last=bool(training),
        )
        for ds, smp, bs, workers, training, collate in settings
    ]
