from torchvision import transforms
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import torch
import os
# CREDIT: https://github.com/lucaspascal/Maximum-Roaming-Mutli-Task-Learning/blob/master/celeba/dataset.py

# Face attributes listed together with the task group they belong to.
# The order of the pairs is significant: both public lists below are derived
# from it, so entry i of one list always corresponds to entry i of the other.
_TASK_GROUP_ATTRIBUTES = (
    ("Global", [
        "Attractive", "Blurry", "Chubby", "Double Chin", "Heavy Makeup",
        "Male", "Oval Face", "Pale Skin", "Young",
    ]),
    ("Eyes", [
        "Bags Under Eyes", "Eyeglasses", "Narrow Eyes", "Arched Eyebrows",
        "Bushy Eyebrows",
    ]),
    ("Hair", [
        "Bald", "Bangs", "Black Hair", "Blond Hair", "Brown Hair",
        "Gray Hair", "Receding Hairline", "Straight Hair", "Wavy Hair",
    ]),
    ("Mouth", [
        "Big Lips", "Mouth Slightly Open", "Smiling", "Wearing Lipstick",
    ]),
    ("Nose", [
        "Big Nose", "Pointy Nose",
    ]),
    ("Beard", [
        "5 o’ Clock Shadow", "Goatee", "Mustache", "No Beard", "Sideburns",
    ]),
    ("Cheeks", [
        "High Cheekbones", "Rosy Cheeks",
    ]),
    ("Wearings", [
        "Wearing Earrings", "Wearing Hat", "Wearing Necklace",
        "Wearing Necktie",
    ]),
)

# Attribute names of every task group (one inner list per group).
task_groups_classess_names = [attributes for _, attributes in _TASK_GROUP_ATTRIBUTES]

# Name of each task group, aligned with ``task_groups_classess_names``.
task_groups_names = [group for group, _ in _TASK_GROUP_ATTRIBUTES]

    
class CelebaGroupedDataset(VisionDataset):
    """CelebA face images with per-image binary attributes, grouped by task.

    Files read relative to ``data_dir``::

        Eval/list_eval_partition.txt   rows of "<image name> <partition code>"
        Anno/list_attr_celeba.txt      two header lines, then rows of
                                       "<image name> <value> <value> ..."
        Img/img_align_celeba/          the image files themselves

    Partition codes 0, 1 and 2 select the ``'train'``, ``'val'`` and
    ``'test'`` splits respectively. An attribute value of ``'1'`` is stored as
    label 1; any other value is stored as 0.

    Indexing returns ``(img, labels)`` where ``img`` is the transformed RGB
    image and ``labels`` is a list with one entry per task group.

    Args:
        data_dir: Directory containing the ``Eval``, ``Anno`` and ``Img``
            sub-directories.
        task_groups: Indexable collection such that ``task_groups[i]`` for
            ``i`` in ``range(len(task_groups))`` is used to index the float
            label tensor of an image (presumably a list of attribute column
            indices per group -- confirm against the caller).
        split: ``'train'``, ``'val'`` or ``'test'``.
        image_size: Size handed to ``transforms.Resize`` by the built-in
            pipelines.
        transform: Callable applied to the PIL image. When ``None`` a
            resize + ``ToTensor`` pipeline is used.
        add_augmentations: When true, a resize + random horizontal flip +
            ``ToTensor`` pipeline is used. NOTE: this replaces ``transform``
            even when one was supplied.

    Raises:
        ValueError: If ``split`` is not a known split name, if the partition
            file has fewer rows than the annotation file, or if the two files
            disagree on the image name of a row.
    """

    # Partition code used in list_eval_partition.txt for each split name.
    _SPLIT_CODES = {'train': 0, 'val': 1, 'test': 2}

    def __init__(self, data_dir, task_groups, split='train', image_size=64, transform=None, add_augmentations=False):
        # An unknown split used to yield a silently empty dataset.
        if split not in self._SPLIT_CODES:
            raise ValueError(
                "split must be one of {}, got {!r}".format(sorted(self._SPLIT_CODES), split))

        # Same precedence as before: augmentations win over everything,
        # then a caller-supplied transform, then the default pipeline.
        if add_augmentations:
            transform = transforms.Compose([transforms.Resize(image_size),
                                            transforms.RandomHorizontalFlip(p=0.5),
                                            transforms.ToTensor()])
        elif transform is None:
            transform = transforms.Compose([transforms.Resize(image_size),
                                            transforms.ToTensor()])

        # Initialise the VisionDataset state (root, transforms, ...) so that
        # base-class features such as repr() work.
        super().__init__(data_dir, transform=transform)
        self.transform = transform

        rep_file = os.path.join(data_dir, 'Eval/list_eval_partition.txt')
        self.img_dir = os.path.join(data_dir, 'Img/img_align_celeba/')
        self.ann_file = os.path.join(data_dir, 'Anno/list_attr_celeba.txt')
        self.image_size = image_size

        partition_rows = self._read_rows(rep_file)
        # The first two lines of the annotation file are headers, not records.
        attr_rows = self._read_rows(self.ann_file, header_lines=2)

        if len(partition_rows) < len(attr_rows):
            raise ValueError(
                "{} has {} rows but {} has {}".format(
                    rep_file, len(partition_rows), self.ann_file, len(attr_rows)))

        wanted_code = self._SPLIT_CODES[split]
        self.img_names = []
        self.labels = []
        for attr_row, partition_row in zip(attr_rows, partition_rows):
            name = attr_row[0]
            # Both files must list the images in the same order. An explicit
            # raise (unlike assert) is not stripped when running with -O.
            if name != partition_row[0]:
                raise ValueError(
                    "annotation/partition mismatch: {!r} != {!r}".format(name, partition_row[0]))
            if int(partition_row[1]) == wanted_code:
                self.img_names.append(name)
                self.labels.append([1 if value == '1' else 0 for value in attr_row[1:]])

        self.labels_rep = task_groups

    @staticmethod
    def _read_rows(path, header_lines=0):
        """Return the whitespace-split fields of each non-blank line of *path*.

        The first ``header_lines`` lines are skipped. Blank lines are ignored,
        so the result does not depend on whether the file ends with a newline.
        """
        with open(path) as f:
            lines = f.read().splitlines()
        return [line.split() for line in lines[header_lines:] if line.strip()]

    def __getitem__(self, index):
        """Return ``(img, labels)`` for the image at position ``index``."""
        path = os.path.join(self.img_dir, self.img_names[index])
        # The context manager releases the file handle once the pixel data
        # has been converted.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Build the full label tensor once, then pick each group's entries.
        all_labels = torch.tensor(self.labels[index], dtype=torch.float32)
        labels = [all_labels[self.labels_rep[task]] for task in range(len(self.labels_rep))]

        return img, labels

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.img_names)