from glob import glob
from PIL import Image
from typing import Callable, Optional
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset


# Module-level registry mapping a dataset name to the class registered for it.
__DATASET__ = {}

def register_dataset(name: str):
    """Return a class decorator that records the class under ``name``.

    The decorated class is stored in ``__DATASET__`` and returned unchanged,
    so the decorator has no effect on the class itself.

    Args:
        name: Key under which the decorated class is stored.

    Raises:
        NameError: When the decorator is applied and ``name`` is already
            present in the registry.
    """
    def wrapper(cls):
        # Membership test rather than truthiness of the stored value: an
        # entry whose class evaluates falsy must still count as registered.
        if name in __DATASET__:
            raise NameError(f"Name {name} is already registered!")
        __DATASET__[name] = cls
        return cls
    return wrapper


def get_dataset(name: str, root: str, **kwargs):
    """Instantiate the dataset class registered under ``name``.

    Args:
        name: Registry key to look up in ``__DATASET__``.
        root: Forwarded to the dataset constructor as ``root=``.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The object produced by calling the registered class.

    Raises:
        NameError: If nothing is registered under ``name``.
    """
    dataset_cls = __DATASET__.get(name)
    if dataset_cls is None:
        raise NameError(f"Dataset {name} is not defined.")
    return dataset_cls(root=root, **kwargs)


def get_dataloader(dataset: VisionDataset,
                   batch_size: int,
                   num_workers: int,
                   train: bool):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Passed to ``DataLoader`` as its batch size.
        num_workers: Passed to ``DataLoader`` as ``num_workers``.
        train: Used for both ``shuffle`` and ``drop_last``, so the two
            options are always switched on or off together.

    Returns:
        The constructed ``DataLoader``.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        drop_last=train,
    )
    return DataLoader(dataset, **loader_options)


@register_dataset(name='ffhq')
class FFHQDataset(VisionDataset):
    """Image-only dataset over every ``.png`` file found under ``root``.

    Registered under the name ``'ffhq'``. Files are discovered recursively
    and kept in sorted path order.

    Args:
        root: Directory searched recursively for ``*.png`` files.
        transforms: Optional callable applied to each RGB image; whatever it
            returns is what ``__getitem__`` yields.

    Raises:
        AssertionError: If no ``.png`` file is found under ``root``.
    """

    def __init__(self, root: str, transforms: Optional[Callable]=None):
        super().__init__(root, transforms)

        self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
        # Raised explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``; exception type and message are unchanged.
        if not self.fpaths:
            raise AssertionError("File list is empty. Check the root.")

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.fpaths)

    def __getitem__(self, index: int):
        """Load the image at ``index`` as RGB and apply ``transforms`` if set."""
        fpath = self.fpaths[index]
        # The context manager releases the file handle deterministically;
        # ``convert`` returns a separate image, so it stays usable afterwards.
        with Image.open(fpath) as raw:
            img = raw.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        return img


# =================
# Li's modification
# =================

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import os

def _convert_image_to_rgb(image):
    """Return ``image.convert("RGB")``.

    A named module-level function so it can be used as a step inside a
    ``Compose`` pipeline.
    """
    rgb_image = image.convert("RGB")
    return rgb_image


@register_dataset(name='alexnet')
class AlexDataset(VisionDataset):
    """Dataset of ``.png`` images under ``root`` with a fixed preprocessing.

    Registered under the name ``'alexnet'``. Each item is the result of
    ``Resize(224)`` -> ``CenterCrop(224)`` -> RGB conversion -> ``ToTensor``
    -> ``Normalize``.

    Args:
        root: Directory searched recursively for ``*.png`` files.

    Raises:
        AssertionError: If no ``.png`` file is found under ``root``.
    """

    def __init__(self, root):
        super().__init__(root)
        self.sample = sorted(glob(root + '/**/*.png', recursive=True))
        # Raised explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``; exception type and message are unchanged.
        if not self.sample:
            raise AssertionError("File list is empty. Check the root.")

        self.path = root

        self.transform = Compose([
            Resize(224),
            CenterCrop(224),
            _convert_image_to_rgb,
            ToTensor(),
            # NOTE(review): these look like the usual ImageNet per-channel
            # mean/std -- confirm against the model this dataset feeds.
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.sample)

    def __getitem__(self, idx):
        """Open the file at ``idx`` and return it after ``self.transform``."""
        # The transform runs inside the ``with`` block so the pixel data is
        # read before the file handle is closed on exit.
        with Image.open(self.sample[idx]) as img:
            return self.transform(img)


@register_dataset(name='titok')
class TiTokDataset(VisionDataset):
    """Dataset of ``.png`` images under ``root`` with a fixed preprocessing.

    Registered under the name ``'titok'``. Each item is the result of
    ``Resize(256)`` -> ``CenterCrop(256)`` -> RGB conversion -> ``ToTensor``;
    no normalization step is applied.

    Args:
        root: Directory searched recursively for ``*.png`` files.

    Raises:
        AssertionError: If no ``.png`` file is found under ``root``.
    """

    def __init__(self, root):
        super().__init__(root)
        self.sample = sorted(glob(root + '/**/*.png', recursive=True))
        # Raised explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``; exception type and message are unchanged.
        if not self.sample:
            raise AssertionError("File list is empty. Check the root.")

        self.path = root

        self.transform = Compose([
            Resize(256),
            CenterCrop(256),
            _convert_image_to_rgb,
            ToTensor(),
        ])

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.sample)

    def __getitem__(self, idx):
        """Open the file at ``idx`` and return it after ``self.transform``."""
        # The transform runs inside the ``with`` block so the pixel data is
        # read before the file handle is closed on exit.
        with Image.open(self.sample[idx]) as img:
            return self.transform(img)


@register_dataset(name='clip')
class SLiCSDataset(VisionDataset):
    """Dataset of ``.png`` images under ``root`` with a fixed preprocessing.

    Registered under the name ``'clip'``. Each item is the result of
    ``Resize(224)`` -> ``CenterCrop(224)`` -> RGB conversion -> ``ToTensor``
    -> ``Normalize``.

    Args:
        root: Directory searched recursively for ``*.png`` files.

    Raises:
        AssertionError: If no ``.png`` file is found under ``root``.
    """

    def __init__(self, root):
        super().__init__(root)
        self.sample = sorted(glob(root + '/**/*.png', recursive=True))
        # Raised explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``; exception type and message are unchanged.
        if not self.sample:
            raise AssertionError("File list is empty. Check the root.")

        self.path = root

        self.transform = Compose([
            Resize(224),
            CenterCrop(224),
            _convert_image_to_rgb,
            ToTensor(),
            # NOTE(review): given the 'clip' registry name, these are
            # presumably the CLIP preprocessing mean/std -- confirm.
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.sample)

    def __getitem__(self, idx):
        """Open the file at ``idx`` and return it after ``self.transform``."""
        # The transform runs inside the ``with`` block so the pixel data is
        # read before the file handle is closed on exit.
        with Image.open(self.sample[idx]) as img:
            return self.transform(img)