import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18
import torchvision
import matplotlib.pyplot as plt


# Load the MNIST train and test splits, downloading them into ./data when
# they are not already present there.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)

# Keep the raw per-split arrays around: `.data` is used below as the image
# stack and `.targets` as the matching digit labels.
x_train, y_train = trainset.data, trainset.targets
x_test, y_test = testset.data, testset.targets


def concatenate_2digit(data_x, data_y, pairs_per_image=7, out_size=(32, 32),
                       verbose=False):
    """Build two-digit images by pasting random pairs of digit images together.

    Every index ``i`` in ``range(len(data_x) - 20)`` acts as an anchor. For
    each anchor, ``pairs_per_image`` partners ``i + k`` are drawn at random
    (always later in the sequence than the anchor), and both orderings
    (anchor|partner and partner|anchor) are emitted. Each pair is joined along
    axis 1, resized with PIL to ``out_size`` and labelled
    ``10 * left_digit + right_digit``.

    Args:
        data_x: Indexable sequence of single-digit images. Assumed to be 2-D
            uint8 arrays/tensors of equal height (e.g. MNIST 28x28) so that
            they can be concatenated along axis 1 and passed to
            ``Image.fromarray`` -- confirm for other inputs.
        data_y: Indexable sequence of the digit labels matching ``data_x``;
            each entry must be convertible with ``int()``.
        pairs_per_image: Number of random partners drawn per anchor image.
            Each partner yields two samples (both orderings).
        out_size: ``(width, height)`` handed to ``PIL.Image.resize``.
        verbose: When true, print each anchor index as it is processed.

    Returns:
        Tuple ``(images, labels)`` of numpy arrays. ``images`` stacks the
        resized pair images and ``labels`` holds the matching two-digit
        numbers. Both are empty arrays when ``len(data_x) <= 20`` or
        ``pairs_per_image <= 0``.
    """
    concat_imgs = []
    concat_nums = []
    total = len(data_x)

    # The last 20 images are never used as anchors (they can still be picked
    # as partners). This also keeps the randint range below non-empty.
    for i in range(total - 20):
        if verbose:
            print(i)
        for _ in range(pairs_per_image):
            # randint's upper bound is exclusive, so k is in
            # [1, total - i - 2] and i + k never runs past the last index.
            k = np.random.randint(1, total - i - 1)
            j = i + k

            # Emit both orderings of the pair: (anchor, partner) and
            # (partner, anchor).
            for left, right in ((i, j), (j, i)):
                side_by_side = np.concatenate(
                    [data_x[left], data_x[right]], axis=1)
                resized = Image.fromarray(side_by_side).resize(out_size)
                concat_imgs.append(np.asarray(resized))
                # Convert to Python ints first so the label arithmetic does
                # not depend on the label container's dtype.
                concat_nums.append(
                    int(data_y[left]) * 10 + int(data_y[right]))

    return np.array(concat_imgs), np.array(concat_nums)

def CountFrequency(concat_imgs, concat_nums, save_path=None, save_num=1000,
                   num_classes=100):
    """Print a label histogram and bucket up to ``save_num`` images per label.

    The distinct labels and their counts are printed in ascending label
    order. The images are then grouped by label into one dense array, keeping
    the first ``save_num`` images seen for each label in input order and
    dropping the rest. Finally the distinct labels and the shape of that
    label array are printed.

    Args:
        concat_imgs: Sequence/array of images, indexed in step with
            ``concat_nums``. All images must share one shape.
        concat_nums: Sequence/array of integer labels, each expected to lie
            in ``[0, num_classes)``.
        save_path: If not ``None``, the bucketed array is written there with
            ``np.save``.
        save_num: Maximum number of images kept per label.
        num_classes: Number of label buckets (100 covers the two-digit
            numbers 00-99).

    Returns:
        Array of shape ``(num_classes, save_num, *image_shape)`` where
        ``out[label, n]`` is the n-th kept image with that label. Slots of
        labels that have fewer than ``save_num`` images stay all-zero.

    NOTE: the output uses numpy's default float64 dtype, so it occupies
    ``num_classes * save_num * prod(image_shape) * 8`` bytes (about 4.9 GB
    for 100 classes, ``save_num=6000`` and 32x32 images).
    """
    labels = np.asarray(concat_nums)
    images = np.asarray(concat_imgs)

    # np.unique yields the distinct labels in ascending order along with
    # their counts, i.e. the sorted histogram.
    distinct, counts = np.unique(labels, return_counts=True)
    for key, value in zip(distinct, counts):
        print("% d : % d" % (key, value))

    # Take the per-image shape from the data. An empty input has no trailing
    # dimensions, so fall back to the 32x32 layout in that case.
    sample_shape = tuple(images.shape[1:]) or (32, 32)
    dataset_np_sort = np.zeros((num_classes, save_num) + sample_shape)
    filled = np.zeros(num_classes, dtype=int)  # images stored so far per label

    for i, img in enumerate(images):
        idx = int(labels[i])
        slot = filled[idx]
        if slot == save_num:
            # This label's bucket is already full; drop the surplus image.
            continue
        dataset_np_sort[idx, slot] = img
        filled[idx] = slot + 1

    print(distinct)
    print(distinct.shape)
    if save_path is not None:
        np.save(save_path, dataset_np_sort)
    return dataset_np_sort


# Build the two-digit datasets from the raw train and test splits.
concat_train_imgs, concat_train_nums = concatenate_2digit(
    data_x=x_train, data_y=y_train)
concat_test_imgs, concat_test_nums = concatenate_2digit(
    data_x=x_test, data_y=y_test)

# Print the label statistics and write the per-label image arrays to disk.
CountFrequency(
    concat_train_imgs,
    concat_train_nums,
    save_path='./data/mnist2d_train.npy',
    save_num=6000,
)
CountFrequency(
    concat_test_imgs,
    concat_test_nums,
    save_path='./data/mnist2d_test.npy',
    save_num=1000,
)

# Show the first ten generated training images, each titled with its label.
for i in range(10):
    plt.imshow(concat_train_imgs[i])
    plt.title(concat_train_nums[i])
    plt.show()