from itertools import islice, product

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from utils import set_seed

def load_mnist():
    """Return the MNIST ``(train_set, test_set)`` pair, downloading if needed.

    Both splits live under ``./data`` and share one ``transforms.ToTensor()``
    pipeline, so every image is delivered as a tensor.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # Train split first, then test, matching the returned tuple order.
    return tuple(
        torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=to_tensor)
        for is_train in (True, False)
    )

def create_multi_digit_images(images, labels, n_digits, n_samples, n_vals):
    """Build images of several MNIST digits placed side by side.

    Digit tuples are enumerated in ``itertools.product(range(10),
    repeat=n_digits)`` order (``(0, 0, 0)``, ``(0, 0, 1)``, ...). Only the
    first ``n_vals`` tuples are used. For each tuple, ``n_samples`` images
    are generated. For every digit position, a random image of that digit is
    drawn with ``np.random.choice``, and the draws are concatenated along the
    last (width) axis.

    Args:
        images: Tensor of single-digit images indexed as ``images[i]``. Each
            ``images[i]`` must be a 28x28 image, either ``(1, 28, 28)`` or
            ``(28, 28)``, so the concatenation fits the ``(1, 28, 28 *
            n_digits)`` output slot.
        labels: 1-D torch tensor of digit labels aligned with ``images``.
        n_digits: Number of digits per generated image.
        n_samples: Number of random images generated per digit tuple.
        n_vals: Number of digit tuples (rows) to generate.

    Returns:
        tuple: ``(multi_digit_images, multi_digit_labels)``.
        ``multi_digit_images`` is float32 with shape
        ``(n_vals, n_samples, 1, 28, 28 * n_digits)``.
        ``multi_digit_labels`` is int32 with shape ``(n_vals, n_samples)``
        and holds the tuple read as a base-10 integer, e.g.
        ``(0, 4, 2) -> 42``.

    Raises:
        ValueError: from ``np.random.choice`` if a digit that is needed has
            no examples in ``labels``.

    Note:
        If ``n_vals`` exceeds ``10 ** n_digits``, the trailing rows stay as
        all-zero images with label 0.
    """
    new_width = 28 * n_digits
    multi_digit_images = torch.zeros((n_vals, n_samples, 1, 28, new_width), dtype=torch.float32)
    multi_digit_labels = torch.zeros((n_vals, n_samples), dtype=torch.int32)

    # Candidate image indices per digit, computed once. Previously the whole
    # label tensor was re-scanned for every digit of every sample.
    pools = {digit: (labels == digit).nonzero(as_tuple=True)[0].numpy() for digit in range(10)}

    # islice avoids materialising all 10**n_digits tuples when only n_vals are used.
    for row, num_tuple in enumerate(islice(product(range(10), repeat=n_digits), n_vals)):
        number = int("".join(map(str, num_tuple)))
        for sample in range(n_samples):
            # Draw order (tuple, then sample, then digit position) is kept
            # as-is so seeded runs reproduce the same dataset.
            digit_indices = [np.random.choice(pools[digit]) for digit in num_tuple]
            multi_digit_images[row, sample] = torch.cat([images[i] for i in digit_indices], dim=-1)
            multi_digit_labels[row, sample] = number

    return multi_digit_images, multi_digit_labels

def _dataset_to_tensors(dataset):
    """Materialise an indexable ``(image, label)`` dataset into two tensors.

    Each item is fetched exactly once, so the dataset's transform runs once
    per sample instead of being re-run just to read the label.

    Returns:
        tuple: ``(torch.stack(images), torch.tensor(labels))``.
    """
    images, labels = [], []
    # torch-style datasets are not guaranteed to be iterable, so index explicitly.
    for i in range(len(dataset)):
        image, label = dataset[i]
        images.append(image)
        labels.append(label)
    return torch.stack(images), torch.tensor(labels)


def create_multi_digit_mnist(n_digits, n_train_samples, n_test_samples, n_vals):
    """Build multi-digit train and test sets from MNIST.

    Loads MNIST via ``load_mnist`` and converts each split to tensors. It then
    calls ``create_multi_digit_images`` with the same ``n_digits`` and
    ``n_vals`` for both splits, using ``n_train_samples`` / ``n_test_samples``
    images per digit tuple respectively.

    Returns:
        tuple: ``((x_train_multi, y_train_multi), (x_test_multi, y_test_multi))``.
    """
    train_set, test_set = load_mnist()

    train_images, train_labels = _dataset_to_tensors(train_set)
    test_images, test_labels = _dataset_to_tensors(test_set)

    # Train is generated before test, preserving the random-draw order.
    x_train_multi, y_train_multi = create_multi_digit_images(train_images, train_labels, n_digits, n_train_samples, n_vals)
    x_test_multi, y_test_multi = create_multi_digit_images(test_images, test_labels, n_digits, n_test_samples, n_vals)

    return (x_train_multi, y_train_multi), (x_test_multi, y_test_multi)

# Script entry: generate a 3-digit MNIST dataset and save the images.
# set_seed comes from the local utils module (not visible here); presumably it
# seeds the NumPy RNG used by np.random.choice during sampling. TODO confirm.
set_seed(0)

n_digits = 3
n_train_samples = 2000
n_test_samples = 100
n_vals = 200

(x_train_multi, y_train_multi), (x_test_multi, y_test_multi) = create_multi_digit_mnist(n_digits, n_train_samples, n_test_samples, n_vals)

# Only the images are saved. The labels are constant along each row, so row v
# of both 'train' and 'test' is the v-th tuple of
# itertools.product(range(10), repeat=n_digits).
# The filename is derived from the parameters so it cannot drift out of sync
# (with the values above it is ./data/mnist_200_2000.pt).
torch.save({
    'train': x_train_multi,
    'test': x_test_multi
}, f'./data/mnist_{n_vals}_{n_train_samples}.pt')
