import torch
from torchvision.datasets import CelebA as torchCelebA
import torchvision.transforms as transforms

# @inproceedings{liu2015faceattributes,
#  title = {Deep Learning Face Attributes in the Wild},
#  author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
#  booktitle = {Proceedings of International Conference on Computer Vision (ICCV)},
#  month = {December},
#  year = {2015}
# }


class CelebA(torchCelebA):
    """CelebA as 40 binary attribute tasks on 64x64 RGB images.

    Each item is ``(image, targets)``:

    * ``image`` is the picture resized to 64x64 and converted with ``ToTensor``
      (3 channels, hence ``input_size == 3 * 64 * 64``).
    * ``targets`` is a tuple of 40 float tensors of shape ``(1,)``, one per
      attribute, in the order torchvision stores the attribute columns.
    """

    # Task name -> [position in the returned ``targets`` tuple, loss/metric names].
    # The loss/metric names are presumably resolved by the training code
    # elsewhere in the project.
    tasks = {
        f'attr{i+1}': [i, ['bce', 'f1']] for i in range(40)
    }

    def __init__(self, root, tag, download=False):
        """Load the given CelebA split.

        Args:
            root: dataset root directory (converted with ``str``).
            tag: split name, forwarded to torchvision as ``split``.
            download: forwarded to torchvision; defaults to ``False`` (the
                previous hard-coded behaviour).
        """
        super().__init__(
            str(root),
            split=tag,
            download=download,
            transform=transforms.Compose([
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
            ]),
        )
        # Flattened input dimensionality: channels * height * width.
        self.input_size = 3 * 64 * 64

    def __getitem__(self, index):
        """Return ``(image, targets)`` with one ``(1,)``-shaped float target per attribute."""
        image, attrs = super().__getitem__(index)
        # (40,) -> (1, 40) -> tuple of 40 tensors of shape (1,), one per task.
        return image, attrs.unsqueeze(dim=0).float().unbind(dim=-1)


# import torch.utils.data as data
# import torch
#
# class CelebA(data.Dataset):
#     tasks = {
#         f'attr{i+1}': [i, ['bce', 'f1']] for i in range(40)
#     }
#
#     def __init__(self, root, tag):
#         super(CelebA, self).__init__()
#         self.input_size = 3 * 64 * 64
#
#     def __getitem__(self, index):
#         return torch.randn(3, 64, 64), torch.randint(0, 1, (40,)).float().unsqueeze(0).unbind(dim=-1)
#
#     def __len__(self):
#         return 8


class CelebALandmarks(torchCelebA):
    """CelebA with two tasks: the 'Attractive' attribute and 5 facial landmarks.

    Images are resized to 40x40 and converted to single-channel grayscale.
    Landmark coordinates are given by torchvision in pixels of the aligned
    CelebA images (178 wide x 218 tall). They are rescaled here into the
    40x40 frame of the transformed image.

    Each item is ``(image, (attractive, keypoints))``:

    * ``attractive`` is a 0-dim float tensor.
    * ``keypoints`` is a float tensor of 10 values in torchvision's interleaved
      order ``(x1, y1, x2, y2, ..., x5, y5)``.
    """

    # Task name -> [position in the returned target tuple, loss/metric names].
    tasks = {
        'attractive': [0, ['nll', 'acc']],
        'keypoints': [1, ['mse', 'keypoints']]
    }

    # Size (width, height) of the aligned CelebA images the landmarks refer to.
    _ORIG_W, _ORIG_H = 178, 218
    # Side length of the square images produced by the transform.
    _SIZE = 40
    # Column of 'Attractive' in torchvision's attribute table.
    _ATTRACTIVE_COL = 2

    def __init__(self, root, tag, download=True):
        """Load the given CelebA split with landmark targets.

        Args:
            root: dataset root directory (converted with ``str``).
            tag: split name, forwarded to torchvision as ``split``.
            download: forwarded to torchvision; defaults to ``True`` (the
                previous hard-coded behaviour).
        """
        super().__init__(
            str(root),
            split=tag,
            download=download,
            target_type='landmarks',
            transform=transforms.Compose([
                transforms.Resize((self._SIZE, self._SIZE)),
                transforms.Grayscale(),
                transforms.ToTensor(),
            ]),
        )
        # Flattened input dimensionality: 1 grayscale channel * height * width.
        self.input_size = 1 * self._SIZE * self._SIZE

    def __getitem__(self, index):
        """Return ``(image, (attractive, keypoints))`` for sample ``index``."""
        image, landmarks = super().__getitem__(index)
        attractive = self.attr[index, self._ATTRACTIVE_COL].float()

        # torchvision's landmark row follows list_landmarks_align_celeba.txt:
        # lefteye_x, lefteye_y, righteye_x, righteye_y, nose_x, nose_y, ...
        # i.e. x/y interleaved. Scale x by the width ratio and y by the height
        # ratio. Arithmetic is out-of-place so the cached landmark table is
        # never modified, even if it is already a float tensor.
        scale = torch.tensor([self._SIZE / self._ORIG_W, self._SIZE / self._ORIG_H])
        keypoints = (landmarks.float().reshape(-1, 2) * scale).reshape(-1)

        return image, (attractive, keypoints)
