from .dataloader import prepare_dataloader, prepare_variable_dataloader
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
from .utils import get_transforms_image, get_transforms_video, save_sample


from .sky_datasets import Sky
from torchvision import transforms
from .taichi_datasets import Taichi
from . import video_transforms
from .ucf101_datasets import UCF101
from .ffs_datasets import FaceForensics
from .ffs_image_datasets import FaceForensicsImages
from .sky_image_datasets import SkyImages
from .ucf101_image_datasets import UCF101Images
from .taichi_image_datasets import TaichiImages
from .journeydb_datasets import JourneyDB
from .coco_datasets import COCO
from .utils import *

def _normalize():
    """Return a fresh in-place Normalize with mean = std = 0.5 for each of 3 channels.

    For inputs already scaled to [0, 1] this maps values to [-1, 1].
    """
    return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)


def _flip_crop_transform(cfg):
    """Video pipeline: ToTensorVideo, random horizontal flip, UCFCenterCropVideo, normalize."""
    return transforms.Compose([
        video_transforms.ToTensorVideo(),  # TCHW
        video_transforms.RandomHorizontalFlipVideo(),
        video_transforms.UCFCenterCropVideo(cfg.image_size),
        _normalize(),
    ])


def _flip_only_transform(_cfg):
    """Video pipeline without any crop/resize: ToTensorVideo, random horizontal flip, normalize."""
    # NOTE(review): this is the only video pipeline with no crop/resize step, so frames
    # keep whatever spatial size they are loaded at; confirm this is intended for 'taichi_img'.
    return transforms.Compose([
        video_transforms.ToTensorVideo(),  # TCHW
        video_transforms.RandomHorizontalFlipVideo(),
        _normalize(),
    ])


def _crop_resize_transform(cfg):
    """Video pipeline: ToTensorVideo, CenterCropResizeVideo, normalize.

    Unlike the other video pipelines, no random horizontal flip is applied.
    """
    return transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.CenterCropResizeVideo(cfg.image_size),
        _normalize(),
    ])


def _t2i_transform(cfg):
    """Image pipeline for text-to-image datasets: ToTensor, Resize, CenterCrop, normalize."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(cfg.image_size),
        transforms.CenterCrop(cfg.image_size),
        _normalize(),
    ])


# Video datasets: name -> (dataset class, transform factory taking cfg).
# These are also given a temporal sampler (see get_dataset).
_VIDEO_DATASETS = {
    'ffs': (FaceForensics, _flip_crop_transform),
    'ffs_img': (FaceForensicsImages, _flip_crop_transform),
    'ucf101': (UCF101, _flip_crop_transform),
    'ucf101_img': (UCF101Images, _flip_crop_transform),
    'taichi': (Taichi, _flip_crop_transform),
    'taichi_img': (TaichiImages, _flip_only_transform),
    'sky': (Sky, _crop_resize_transform),
    'sky_img': (SkyImages, _crop_resize_transform),
}

# Text-to-image datasets: name -> dataset class. These take no temporal sampler.
_T2I_DATASETS = {
    'journeydb': JourneyDB,
    'coco': COCO,
}


def get_dataset(cfg):
    """Build the dataset selected by ``cfg.dataset`` together with its preprocessing pipeline.

    Args:
        cfg: Config object read via attribute access. ``dataset`` selects the dataset;
            ``image_size`` is read by every pipeline except the one for 'taichi_img';
            ``num_frames`` and ``frame_interval`` are read only for video datasets.
            ``cfg`` itself is passed through to the dataset constructor.

    Returns:
        An instance of the dataset class registered for ``cfg.dataset``. Video datasets
        are constructed with ``transform`` and ``temporal_sample``; text-to-image
        datasets ('journeydb', 'coco') with ``transform`` only.

    Raises:
        NotImplementedError: If ``cfg.dataset`` is not a known dataset name.
    """
    name = cfg.dataset

    if name in _T2I_DATASETS:
        return _T2I_DATASETS[name](cfg, transform=_t2i_transform(cfg))

    try:
        dataset_cls, build_transform = _VIDEO_DATASETS[name]
    except KeyError:
        raise NotImplementedError(name) from None

    # Text-to-image datasets never use the temporal sampler, so it is built only here.
    # T2I configs therefore do not need num_frames / frame_interval.
    temporal_sample = video_transforms.TemporalRandomCrop(cfg.num_frames * cfg.frame_interval)
    return dataset_cls(cfg, transform=build_transform(cfg), temporal_sample=temporal_sample)