import torch
from torch.utils.data import IterableDataset, Dataset, DataLoader
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import random

# Load the MNIST training split (train=True) and test split (train=False)
# into 'dataset/mnist', downloading if needed; images are converted with
# transforms.ToTensor().
mnist_train, mnist_test = (
    dataset.MNIST(
        root='dataset/mnist',
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)
            
class MnistADDDataset(IterableDataset):
    """Stream random pairs of digit images labelled with the sum of the digits.

    Each yielded sample is ``(left_image, right_image, target)``. The two
    digits ``a`` and ``b`` are drawn uniformly from 0-9, an image of each is
    picked at random from ``mnist_dataset``, and ``target`` is a one-hot
    vector of length 19 encoding ``a + b``.

    Args:
        mnist_dataset: dataset exposing ``targets`` (digit labels, tensor or
            sequence) and ``__getitem__`` returning ``(image, label)``.
        lenth: total number of samples produced per pass, summed across all
            ``DataLoader`` workers.
    """

    # The largest possible sum is 9 + 9, so sums span 0..18 (19 classes).
    NUM_SUM_CLASSES = 9 * 2 + 1

    def __init__(self, mnist_dataset, lenth=8192):
        super().__init__()
        self.mnist_dataset = mnist_dataset
        self.lenth = lenth
        # as_tensor also accepts datasets whose ``targets`` is a plain list.
        targets = torch.as_tensor(mnist_dataset.targets)
        # Digit -> numpy array of dataset indices whose label is that digit.
        self.label_index_dict = {
            digit: (targets == digit).nonzero().flatten().numpy()
            for digit in range(10)
        }

    def __len__(self):
        return self.lenth

    @staticmethod
    def _shard_length(total, worker_info):
        """Return how many of ``total`` samples the current worker should yield.

        ``worker_info`` is what ``torch.utils.data.get_worker_info()`` returns:
        ``None`` in the main process, otherwise an object with ``id`` and
        ``num_workers``. Samples are split as evenly as possible; the first
        ``total % num_workers`` workers each take one extra.
        """
        if worker_info is None:
            return total
        base, extra = divmod(total, worker_info.num_workers)
        return base + (1 if worker_info.id < extra else 0)

    def __iter__(self):
        # Every DataLoader worker gets its own copy of an IterableDataset.
        # Without sharding, each one would replay all ``lenth`` samples and
        # multiply the epoch size by ``num_workers``.
        n_samples = self._shard_length(self.lenth, torch.utils.data.get_worker_info())
        for _ in range(n_samples):
            left_label, right_label = random.randint(0, 9), random.randint(0, 9)
            left_index = random.choice(self.label_index_dict[left_label])
            right_index = random.choice(self.label_index_dict[right_label])
            left_image = self.mnist_dataset[left_index][0]
            right_image = self.mnist_dataset[right_index][0]
            label = torch.tensor([left_label + right_label])
            target = torch.nn.functional.one_hot(label, num_classes=self.NUM_SUM_CLASSES).flatten()
            yield left_image, right_image, target

class MNISTMultiDigitADDDataset(IterableDataset):
    """Stream pairs of random ``n_digits``-digit numbers rendered as digit images.

    Each sample is ``(left_images, right_images, labels, carry)``:

    * ``left_images`` / ``right_images``: the digit images of the two operands,
      stacked along a new leading dimension, least-significant digit first.
    * ``labels``: ``(n_digits, 10)`` one-hot tensor of the low ``n_digits``
      digits of the sum, least-significant digit first.
    * ``carry``: length-2 one-hot tensor of the sum's extra leading digit.
      It is 0 or 1, because two ``n_digits``-digit numbers sum to less than
      ``2 * 10**n_digits``.

    Args:
        mnist_dataset: dataset exposing ``targets`` (digit labels, tensor or
            sequence) and ``__getitem__`` returning ``(image, label)``.
        n_digits: number of digits per operand; must be >= 1.
        lenth: total number of samples produced per pass, summed across all
            ``DataLoader`` workers.

    Raises:
        ValueError: if ``n_digits`` < 1.
    """

    def __init__(self, mnist_dataset, n_digits, lenth=8192):
        super().__init__()
        if n_digits < 1:
            raise ValueError(f"n_digits must be >= 1, got {n_digits}")
        self.mnist_dataset = mnist_dataset
        self.n_digits = n_digits
        self.lenth = lenth
        # as_tensor also accepts datasets whose ``targets`` is a plain list.
        targets = torch.as_tensor(mnist_dataset.targets)
        # Digit -> numpy array of dataset indices whose label is that digit.
        self.label_index_dict = {
            digit: (targets == digit).nonzero().flatten().numpy()
            for digit in range(10)
        }

    def __len__(self):
        return self.lenth

    @staticmethod
    def _shard_length(total, worker_info):
        """Return how many of ``total`` samples the current worker should yield.

        ``worker_info`` is what ``torch.utils.data.get_worker_info()`` returns:
        ``None`` in the main process, otherwise an object with ``id`` and
        ``num_workers``. Samples are split as evenly as possible; the first
        ``total % num_workers`` workers each take one extra.
        """
        if worker_info is None:
            return total
        base, extra = divmod(total, worker_info.num_workers)
        return base + (1 if worker_info.id < extra else 0)

    def __iter__(self):
        # Every DataLoader worker gets its own copy of an IterableDataset.
        # Without sharding, each one would replay all ``lenth`` samples and
        # multiply the epoch size by ``num_workers``.
        n_samples = self._shard_length(self.lenth, torch.utils.data.get_worker_info())
        carry_place = 10 ** self.n_digits
        for _sample in range(n_samples):
            left_images, right_images = [], []
            left_value = right_value = 0
            for _position in range(self.n_digits):
                left_label, right_label = random.randint(0, 9), random.randint(0, 9)
                left_index = random.choice(self.label_index_dict[left_label])
                right_index = random.choice(self.label_index_dict[right_label])
                left_images.append(self.mnist_dataset[left_index][0])
                right_images.append(self.mnist_dataset[right_index][0])
                # Digits are drawn most-significant first.
                left_value = left_value * 10 + left_label
                right_value = right_value * 10 + right_label
            total = left_value + right_value
            # Low n_digits digits of the sum, least-significant first.
            sum_digits = [(total // 10 ** k) % 10 for k in range(self.n_digits)]
            # Reorder images least-significant first to line up with sum_digits.
            left_images.reverse()
            right_images.reverse()
            labels = torch.nn.functional.one_hot(torch.tensor(sum_digits), num_classes=10)
            carry = torch.nn.functional.one_hot(torch.tensor(total // carry_place), num_classes=2).flatten()
            yield torch.stack(left_images, dim=0), torch.stack(right_images, dim=0), labels, carry