"""create dataset and dataloader"""
import logging

import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: the dataset to wrap.
        dataset_opt: per-dataset options. ``"phase"`` is always read; in the
            ``"train"`` phase ``"n_workers"`` and ``"batch_size"`` are read too.
        opt: global options. Required in the ``"train"`` phase, where
            ``"dist"`` is read and, when not distributed, ``"gpu_ids"``.
        sampler: optional sampler forwarded to the loader in the ``"train"``
            phase. When given in non-distributed mode it controls the sample
            order (random shuffling is then disabled).

    Returns:
        A DataLoader. In the ``"train"`` phase it uses ``drop_last=True``;
        in any other phase it yields one sample per batch, in order, with a
        single worker.

    Raises:
        ValueError: in distributed training, if ``batch_size`` is not
            divisible by the distributed world size.
    """
    phase = dataset_opt["phase"]
    if phase != "train":
        # Non-train phases: one sample per batch, deterministic order.
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=False)

    if opt["dist"]:
        world_size = torch.distributed.get_world_size()
        total_batch_size = dataset_opt["batch_size"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if total_batch_size % world_size != 0:
            raise ValueError(
                f"batch_size ({total_batch_size}) must be divisible by the "
                f"distributed world size ({world_size})."
            )
        num_workers = dataset_opt["n_workers"]
        # The configured batch size is global; each process takes an equal share.
        batch_size = total_batch_size // world_size
        # Ordering is left to the passed sampler (presumably a distributed
        # sampler supplied by the caller — NOTE(review): confirm at call sites).
        shuffle = False
    else:
        num_workers = dataset_opt["n_workers"] * len(opt["gpu_ids"])
        batch_size = dataset_opt["batch_size"]
        # DataLoader rejects shuffle=True together with an explicit sampler,
        # so only shuffle when no sampler was provided.
        shuffle = sampler is None

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        drop_last=True,
        pin_memory=False,
    )


def _resolve_dataset_class(mode):
    """Return the dataset class registered for ``mode``.

    Each branch imports its dataset module only when that branch is taken,
    so only the selected module gets imported.

    Raises:
        NotImplementedError: if ``mode`` is not a supported dataset name.
    """
    if mode == "REDS":
        from data.REDS_dataset import REDSDataset

        return REDSDataset
    if mode == "GOPRO":
        from data.GOPRO_dataset import GOPRODataset

        return GOPRODataset
    if mode == "fewshot":
        from data.fewshot_dataset import FewShotDataset

        return FewShotDataset
    if mode == "levin":
        from data.levin_dataset import LevinDataset

        return LevinDataset
    if mode == "mix":
        from data.mix_dataset import MixDataset

        return MixDataset
    raise NotImplementedError(f"Dataset {mode} is not recognized.")


def create_dataset(dataset_opt):
    """Instantiate and return the dataset selected by ``dataset_opt["mode"]``.

    Args:
        dataset_opt: dataset options. ``"mode"`` selects the dataset class
            (one of ``"REDS"``, ``"GOPRO"``, ``"fewshot"``, ``"levin"``,
            ``"mix"``). The whole mapping is passed to that class's
            constructor, and ``"name"`` (formatted with ``{:s}``, so it must
            be a str) appears in the log message.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``"mode"`` is not a supported dataset.
    """
    dataset_cls = _resolve_dataset_class(dataset_opt["mode"])
    dataset = dataset_cls(dataset_opt)

    base_logger = logging.getLogger("base")
    class_name = dataset.__class__.__name__
    base_logger.info("Dataset [{:s} - {:s}] is created.".format(class_name, dataset_opt["name"]))
    return dataset
