import os

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

# Names of the 40 binary face attributes, as a plain list of strings.
# NOTE(review): the order presumably mirrors the attribute order of the
# CelebA annotation file -- confirm before relying on positional indexing.
attr_list = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
    'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
    'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses',
    'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',
    'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',
    'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',
    'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',
    'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',
]


class CelebA(Dataset):
    """Dataset that serves images and attribute labels from ``<root>/celeba.hdf5``.

    The HDF5 file is read as follows:

    * ``file["columns"]`` -- attribute names, used to locate label columns.
    * ``file[<type>]["label"]`` -- one label row per sample.
    * ``file[<type>]["data"]`` -- one image per sample.

    Several target (or domain) attributes are packed into one integer label:
    the i-th requested attribute contributes ``2 ** i`` times its value.
    NOTE(review): this assumes every stored attribute value is 0 or 1 --
    confirm against the file that is actually used.

    Args:
        root: Directory containing ``celeba.hdf5``.
        target_attrs: A single attribute name or a sequence of names whose
            values form the class label.
        domain_attrs: Optional attribute name or sequence of names whose
            values form a second (domain) label. ``None`` disables it.
        img_transform: Optional callable applied to the image tensor.
        type: Which split to read: ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        IndexError: If a requested attribute is not present in
            ``file["columns"]``.
    """

    def __init__(self, root, target_attrs="Smiling", domain_attrs=None, img_transform=None, type="train") -> None:
        super().__init__()
        assert type in ["train", "val", "test"], "type must be 'train', 'val' or 'test'"
        self.type = type
        self.root = os.path.join(root, "celeba.hdf5")
        self.img_transform = img_transform

        # Normalise once so that a str, list, tuple or any other sequence of
        # names is handled the same way everywhere below.
        target_names = self._normalize_attrs(target_attrs)
        domain_names = None if domain_attrs is None else self._normalize_attrs(domain_attrs)

        # Default to bytes names; switched back to str below if the file
        # stores its column names as str.
        self.target_attrs = [bytes(name, 'utf-8') for name in target_names]
        if domain_names is not None:
            self.domain_attrs = [bytes(name, 'utf-8') for name in domain_names]
        else:
            self.domain_attrs = None

        # One bit per target attribute (a single attribute gives 2 classes).
        self.num_classes = 2 ** len(target_names)

        self.z_index = []
        with h5py.File(self.root, mode='r') as file:
            # Read the column names once instead of once per attribute.
            columns = np.array(file["columns"])
            if isinstance(columns[0], str):
                # Depending on the system, the column names come back either
                # as bytes or as str; when they are str, compare against the
                # plain str names instead of the encoded ones.
                self.target_attrs = target_names
                if domain_names is not None:
                    self.domain_attrs = domain_names

            self.y_index = [self._column_index(columns, attr) for attr in self.target_attrs]
            if self.domain_attrs is not None:
                self.z_index = [self._column_index(columns, attr) for attr in self.domain_attrs]

            label_ds = file[self.type]['label']
            self.total = label_ds.shape[0]
            self.start_point = 0
            self.end_point = self.total
            # A single sliced read replaces one HDF5 read per sample; the
            # result is still stored as a list with one entry per sample.
            self.labels = list(label_ds[self.start_point:self.end_point])
        self.lens = len(self.labels)

    @staticmethod
    def _normalize_attrs(attrs):
        """Return ``attrs`` as a list of names; a lone str becomes a one-item list."""
        if isinstance(attrs, str):
            return [attrs]
        return list(attrs)

    @staticmethod
    def _column_index(columns, attr):
        """Return the position of the first entry of ``columns`` equal to ``attr``.

        Raises:
            IndexError: If ``attr`` does not occur in ``columns``.
        """
        matches = np.where(np.asarray(columns) == attr)[0]
        if matches.size == 0:
            raise IndexError("attribute {!r} not found in the 'columns' dataset".format(attr))
        return matches[0]

    def __len__(self):
        """Return the number of samples whose labels were loaded."""
        return self.lens

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        ``label`` is whatever :meth:`get_label` returns for the same index.
        """
        # The file is opened on every access on purpose: per the original
        # author, opening it in __init__ prevents the DataLoader from being
        # used with num_workers.
        with h5py.File(self.root, mode='r') as file:
            # Scale by 1/255 and move the last axis to the front.
            # NOTE(review): presumably uint8 HWC -> float CHW; confirm.
            image = torch.Tensor(file[self.type]['data'][self.start_point + index] / 255.).permute(2, 0, 1)
            if self.img_transform is not None:
                image = self.img_transform(image)
            return image, self.get_label(index)

    def get_label(self, index):
        """Return the packed label(s) of the sample at ``index``.

        Returns:
            ``label_y`` when no domain attributes were requested, otherwise
            the tuple ``(label_y, label_z)``. Each label is an int in which
            the i-th requested attribute contributes ``2 ** i`` times its
            stored value.
        """
        row = self.labels[index]
        label_y = sum((2 ** i) * int(row[col]) for i, col in enumerate(self.y_index))
        if self.domain_attrs is None:
            return label_y
        label_z = sum((2 ** i) * int(row[col]) for i, col in enumerate(self.z_index))
        return label_y, label_z
