# file: user_extensions/datasets/celeba.py
from torchvision.datasets import CelebA
from torchvision import transforms

from prism.core.registry import DATASETS
from prism.data.base_datamodule import BaseDataModule

# All 40 CelebA attribute names. Position i is expected to be column i of the
# attribute vector that torchvision's CelebA returns for target_type='attr'
# (the wrapper below indexes that vector with positions taken from this list).
CELEBA_FULL_ATTRIBUTE_NAMES = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald',  # 0-4
    'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',  # 5-9
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',  # 10-14
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',  # 15-19
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',  # 20-24
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',  # 25-29
    'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',  # 30-34
    'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',  # 35-39
]

# The attribute used as the prediction target, and its column in the full vector.
TARGET_ATTRIBUTE_NAME = 'Smiling'
TARGET_ATTR_IDX = CELEBA_FULL_ATTRIBUTE_NAMES.index(TARGET_ATTRIBUTE_NAME)

# Attributes exposed as "style" labels. Their order fixes the order of the
# style_labels tensor produced per sample.
STYLE_SUBSET_NAMES = [
    'Bald',
    'Bangs',
    'Eyeglasses',
    'Goatee',
    'Heavy_Makeup',
    'Male',
    'Wearing_Hat',
    'Wearing_Lipstick',
]
# Style attribute name -> its position inside the style_labels tensor.
CELEBA_TENSOR_INDEX_MAP = dict(zip(STYLE_SUBSET_NAMES, range(len(STYLE_SUBSET_NAMES))))


class _CelebAWrapper(CelebA):
    """CelebA split that yields ``(image, target_label, style_labels)``.

    ``target_label`` is the ``TARGET_ATTR_IDX`` column of the attribute vector
    returned by the parent class. ``style_labels`` gathers the columns listed in
    ``style_indices`` (in that order) and casts them to float.

    Assumes the parent returns ``(image, attributes)`` where ``attributes`` is an
    indexable tensor (i.e. ``target_type='attr'`` and no ``target_transform``)
    — TODO confirm for other configurations.
    """

    def __init__(self, style_indices, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Column positions (into the full attribute vector) of the style attributes.
        self.style_indices = style_indices

    def __getitem__(self, index):
        image, attr_vector = super().__getitem__(index)
        # Tuple elements are evaluated left to right: target first, then styles.
        return (
            image,
            attr_vector[TARGET_ATTR_IDX],
            attr_vector[self.style_indices].float(),
        )


@DATASETS.register("celeba")
class CelebADataModule(BaseDataModule):
    """CelebA data module predicting ``TARGET_ATTRIBUTE_NAME`` plus style labels.

    Every dataset item is ``(image, target_label, style_labels)`` as produced by
    ``_CelebAWrapper``. ``style_labels`` follows the order of
    ``STYLE_SUBSET_NAMES``; ``style_feature_map`` exposes that name -> position map.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transform = self._get_celeba_transform()
        # Column indices of the style attributes within the full attribute vector.
        self.style_indices = [CELEBA_FULL_ATTRIBUTE_NAMES.index(name) for name in STYLE_SUBSET_NAMES]

    def _get_celeba_transform(self):
        """Build the image transform shared by all splits.

        Output height/width come from ``self.data_config.image_shape[1]`` and
        ``[2]`` (presumably a ``(C, H, W)`` shape — TODO confirm).
        """
        h, w = self.data_config.image_shape[1], self.data_config.image_shape[2]
        return transforms.Compose([
            # Square 178 px center crop — presumably the width of the aligned
            # CelebA images, so the crop is the largest centered square; confirm.
            transforms.CenterCrop(178),
            transforms.Resize((h, w)),
            transforms.ToTensor(),
            # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        """Download CelebA into ``self.data_dir`` if it is not already there.

        torchvision's CelebA downloads and verifies the complete file set
        regardless of ``split``. One ``split='all'`` instantiation is therefore
        enough; per-split calls would only repeat the integrity check and
        re-parse the annotation files.
        """
        CelebA(root=self.data_dir, split='all', download=True)

    def _build_split_dataset(self, split):
        """Return a ``_CelebAWrapper`` for ``split`` using the shared transform."""
        return _CelebAWrapper(
            root=self.data_dir,
            split=split,
            target_type='attr',
            transform=self.transform,
            style_indices=self.style_indices,
        )

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        ``'fit'`` builds train and val, ``'validate'`` builds val, ``'test'``
        builds test, and ``None`` builds all three.
        """
        if stage in ('fit', None):
            self.train_ds = self._build_split_dataset('train')
        # 'validate' is the stage name Lightning-style trainers pass for a
        # validation-only run; without this branch val_ds would never exist there.
        if stage in ('fit', 'validate', None):
            self.val_ds = self._build_split_dataset('valid')
        if stage in ('test', None):
            self.test_ds = self._build_split_dataset('test')

    @property
    def style_feature_map(self):
        """Map each style attribute name to its position in ``style_labels``."""
        return CELEBA_TENSOR_INDEX_MAP