import os
from typing import Callable, Optional

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image

from .BaseDataset import BaseDataset


# CelebA dataset wrapper for easier integration


class CelebADataset(BaseDataset):
    """CelebA wrapper that yields each sample as a dict of image plus attributes.

    Wraps ``torchvision.datasets.CelebA`` (``target_type="attr"``) and serves
    its samples in a shuffled order that is reproducible from ``seed``. Each
    item is ``{"image": image, "<attribute name>": value, ...}`` with one
    entry per CelebA attribute name.
    """

    # Attribute names used as prediction targets. Declared at class level so it
    # is shared by all instances (readable as CelebADataset.prediction_targets);
    # its length is passed to BaseDataset in __init__.
    prediction_targets = [
        "Wearing_Lipstick",
        "Smiling",
        "Mouth_Slightly_Open",
        "High_Cheekbones",
        "Attractive",
        "Heavy_Makeup",
        "Male",
        "Young",
        "Wavy_Hair",
        "Straight_Hair",
        # "Bags_Under_Eyes",
        # "5_o_Clock_Shadow",
        # "Black_Hair",
        # "No_Beard",
        # "Arched_Eyebrows",
    ]

    def __init__(self,
                 split: str = 'train',
                 transform: Optional[Callable] = None,
                 seed: int = 42,
                 *,
                 root: str = "./data",
                 download: bool = True,
                 **kwargs):
        """
        Args:
            split (string): One of 'train', 'valid', 'test'; forwarded to
                torchvision's CelebA.
            transform (callable, optional): Optional transform to be applied on a sample.
            seed (int, optional): Seed for the shuffle of sample indices.
            root (string, optional): Directory torchvision stores/reads CelebA in.
            download (bool, optional): Whether torchvision may download CelebA
                into ``root`` if it is missing.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(split,
                         transform,
                         len(self.prediction_targets),
                         seed)
        # NOTE(review): `transform` is handed to both BaseDataset and the
        # torchvision dataset below; confirm BaseDataset does not apply it a
        # second time.
        self.dataset = torchvision.datasets.CelebA(
            root=root,
            split=split,
            target_type="attr",
            download=download,
            transform=transform,
        )

        # attr_names has been observed to end with an empty string. Strip only
        # trailing empty names (instead of always slicing off the last entry)
        # so a real attribute is never dropped and positions stay aligned with
        # the attribute vector indexed in __getitem__.
        names = list(self.dataset.attr_names)
        while names and not names[-1]:
            names.pop()
        self.attribute_names = names

        # Shuffled index list over the whole split, intended for subdividing
        # the dataset by task. A private RandomState yields the same
        # permutation as np.random.seed(seed) + np.random.shuffle(...), but
        # without reseeding NumPy's global RNG for the rest of the process.
        self.annotations = list(range(len(self.dataset)))
        np.random.RandomState(seed).shuffle(self.annotations)
        # sub_dataset is the same list object as annotations.
        self.sub_dataset = self.annotations

    def __getitem__(self, idx):
        """Return the sample at shuffled position ``idx`` as a dict.

        The dict holds ``"image"`` plus one entry per name in
        ``self.attribute_names``.
        """
        image, attributes = self.dataset[self.sub_dataset[idx]]
        sample = {"image": image}
        # Names and values are paired by position; this assumes attr_names and
        # the attribute vector share column order (TODO confirm per
        # torchvision version).
        for i, name in enumerate(self.attribute_names):
            sample[name] = attributes[i]
        return sample


if __name__ == "__main__":
    # Manual smoke test: build the training split and print one sample's
    # prediction targets. Constructing the dataset may download CelebA into
    # ./data on first use (the class passes download=True to torchvision).
    demo_dataset = CelebADataset(split="train", seed=42)
    print("attributes:", demo_dataset.attribute_names)

    demo_sample = demo_dataset[0]
    print("image type:", type(demo_sample["image"]).__name__)
    for target in CelebADataset.prediction_targets:
        print(f"{target}: {demo_sample[target]}")
