import numpy as np
import torch
from torchvision.datasets import SVHN as torchSVHN
import torchvision.transforms as transforms

# @inproceedings{netzer2011reading,
#  title = {Reading Digits in Natural Images with Unsupervised Feature Learning},
#  author = {Netzer, Yuval and Wang, Tao and Coates, Adam and Bissacco, Alessandro and Wu, Bo and Ng, Andrew Y.},
#  booktitle = {NIPS Workshop on Deep Learning and Unsupervised Feature Learning},
#  year = {2011}
# }

def is_prime(x):
    """Elementwise primality test by trial division.

    Accepts a Python int or a numeric tensor. Returns a bool for an int
    input and a boolean tensor of the same shape for a tensor input.

    Divisors run from 2 up to floor(sqrt(max(x))). A hard-coded bound would
    silently misclassify larger composites (e.g. 121 = 11 * 11).
    """
    prime = x >= 2
    values = torch.as_tensor(x)
    largest = int(values.max()) if values.numel() else 0
    i = 2
    while i * i <= largest:
        # ``i >= x`` keeps entries that are <= the divisor itself (e.g. x == i).
        prime &= (i >= x) | (x % i != 0)
        i += 1
    return prime


def foo(x):
    """Flag entries that are prime or whose repeated digit sums hit a prime.

    ``x`` is cast to ``int32`` and tested with ``is_prime``. While any entry
    still has two or more decimal digits, every entry is replaced by the sum
    of its digits and tested again. The returned boolean tensor is the
    elementwise OR of all of those tests.

    NOTE(review): written for non-negative integer inputs. ``max()`` is
    called on ``x``, so an empty tensor is not supported.
    """
    current = x.int()
    flags = is_prime(current)
    # Keep collapsing to digit sums until every entry is a single digit.
    while current.max() > 9:
        digit_sum = torch.zeros_like(current)
        remaining = current
        while remaining.max() > 0:
            digit_sum += remaining % 10
            remaining = remaining // 10
        current = digit_sum
        flags = flags | is_prime(current)
    return flags


class SVHN(torchSVHN):
    """Paired-digit SVHN dataset producing a tuple of multi-task targets.

    Each sample is the image at ``index`` concatenated along the last axis
    with a partner image chosen by a fixed permutation (``self.pairs``).
    Both images are resized to 32x32 and converted to grayscale, so a sample
    has ``1 * 32 * 64`` elements. The target is a tuple whose slots are
    indexed by ``tasks``.

    Args:
        root: Directory passed to torchvision's SVHN (downloaded if missing).
        tag: ``'train'`` -> first ``num_train`` images of the official train
            split; ``'val'`` -> the remaining official train images;
            ``'test'`` -> the official test split. Any other value uses the
            full official train split unsliced.
        num_train: Number of official-train images assigned to ``'train'``.
        seed: Optional seed for the partner permutation. ``None`` (default)
            draws from NumPy's global RNG.
    """

    # task name -> [index into the target tuple, loss/metric spec]
    tasks = {
        'left': [0, ['nll', 'acc']],
        'right': [1, ['nll', 'acc']],
        'left2c': [2, ['nll', 'acc']],
        'right2c': [3, ['nll', 'acc']],
        'sum': [4, 'mse'],
        'multiply': [5, 'mse'],
        'density': [6, 'mse'],
        'product_prime': [7, ['bce', 'f1']],
        'number': [8, ['mse', 'mse']],
        'divide': [9, ['mse', 'mse']],
        'bigger_than': [10, ['bce', 'f1']],
        'binary_and': [11, ['nll', 'acc']],
        'odd': [12, ['bce', 'f1']]
    }

    def __init__(self, root, tag, num_train=50000, seed=None):
        super().__init__(
            str(root),
            split='test' if tag == 'test' else 'train',
            download=True,
            transform=transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.Grayscale(),
                transforms.ToTensor(),
            ]),
        )

        # 'train' and 'val' are carved out of the official train split.
        if tag == 'train':
            self.data, self.labels = self.data[:num_train], self.labels[:num_train]
        elif tag == 'val':
            self.data, self.labels = self.data[num_train:], self.labels[num_train:]

        # One grayscale channel, 32 rows, two 32-wide images side by side.
        self.input_size = 1 * 32 * 64
        # Fixed partner for each index; seedable for reproducible evaluation.
        rng = np.random if seed is None else np.random.RandomState(seed)
        self.pairs = rng.permutation(len(self))

    def __getitem__(self, index):
        """Return ``(paired_image, targets)`` for ``index``.

        The slot order in ``targets`` matches the indices in ``tasks``.
        """
        data1, target1 = super().__getitem__(index)
        data2, target2 = super().__getitem__(int(self.pairs[index]))

        data = torch.cat((data1, data2), dim=-1)
        target1 = torch.tensor(target1)
        target2 = torch.tensor(target2)
        # Shared by the multiply / product_prime / bigger_than / odd tasks.
        product = torch.unsqueeze(target1 * target2, dim=-1)

        target = (
            target1.float(),                                                   # 0 left
            target2.float(),                                                   # 1 right
            (torch.abs(10 - target1) % 10).float(),                            # 2 left2c
            (torch.abs(10 - target2) % 10).float(),                            # 3 right2c
            torch.unsqueeze(target1 + target2, dim=-1).float(),                # 4 sum
            product.float(),                                                   # 5 multiply
            (data > 0.5).flatten(start_dim=1).float().mean(dim=1),             # 6 density
            foo(product).float(),                                              # 7 product_prime
            torch.unsqueeze(10 * target1 + target2, dim=-1).float(),           # 8 number
            torch.unsqueeze(target1.float() / (1 + target2.float()), dim=-1),  # 9 divide
            (product >= 25).float(),                                           # 10 bigger_than
            (target1 & target2).float(),                                       # 11 binary_and
            (product % 2 != 0).float(),                                        # 12 odd
        )

        return data, target

