import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import Resize, ToTensor
import pandas as pd

def createDataset(dataset_path, new_dataset_title):
    """
    Create a partially rotated dataset and pickle it as a pandas DataFrame.

    Every image of digit ``d`` is rotated by an angle (in degrees) sampled
    uniformly from ``[-theta_d, theta_d]``, where ``theta_d = 18 * d``
    ("RotMultiple" setting). The resulting DataFrame has one row per image:
    28 * 28 pixel columns (named 0..783) followed by a ``"labels"`` column.
    Rows are grouped by label (all 0s first, then 1s, ...) and carry a unique
    0..n-1 index.

    A histogram of the sampled angles is shown for each label as a visual
    sanity check before the file is written.

    :param dataset_path: text (.amat) file readable by ``np.loadtxt``; each
        row holds 28 * 28 pixel values followed by the digit label (0-9).
    :param new_dataset_title: base name of the output file; the dataset is
        written to ``"./<new_dataset_title>.pkl"``.
    :return: None (the dataset is written to disk, not returned).
    :raises KeyError: if a label in the file is not one of the digits 0-9.
    """
    # ndmin=2 keeps a single-row file two-dimensional so the slicing works.
    data = np.loadtxt(dataset_path, ndmin=2)
    img_array_raw = data[:, :-1].reshape(len(data), 28, 28)
    labels_array_raw = data[:, -1]

    # Define transforms
    resize28 = Resize(28)
    toTensor = ToTensor()

    # Separate data by label. The labels come out of loadtxt as floats, which
    # hash and compare equal to the integer keys used here.
    data_by_label = {digit: [] for digit in range(10)}
    for img, raw_label in zip(img_array_raw, labels_array_raw):
        data_by_label[raw_label].append(img)

    # Maximum absolute rotation (degrees) per digit -- "RotMultiple" setting.
    # Other settings that have been used with this script:
    #   RotMNIST60:    theta = 60 for every label
    #   RotMNIST60-90: theta = 60 for labels 0-4, 90 for labels 5-9
    max_angle_by_label = {0: 0, 1: 18, 2: 36, 3: 54, 4: 72,
                          5: 90, 6: 108, 7: 126, 8: 144, 9: 162}

    rotations_dict = {}  # rotations applied, per label (for the histograms)
    pixel_rows = []      # one (1, 784) float array per rotated image
    row_labels = []      # label of each entry in pixel_rows

    # Apply a label-specific rotation range to every image
    for label in range(10):
        theta = max_angle_by_label[label]
        imgs = data_by_label[label]
        rotations = np.random.uniform(-theta, theta, len(imgs))  # Sample rotations uniformly
        rotations_dict[label] = list(rotations)

        for img, rotation in zip(imgs, rotations):
            rotated = Image.fromarray(img).rotate(rotation, Image.BILINEAR)
            flat = toTensor(resize28(rotated)).reshape(1, 28 * 28)
            pixel_rows.append(flat.numpy())
        row_labels.extend([label] * len(imgs))

    # Build the frame once: concatenating inside the loop copies the whole
    # frame for every image (quadratic time) and leaves every row with the
    # same index label.
    if pixel_rows:
        pixels = np.concatenate(pixel_rows, axis=0)
    else:
        pixels = np.empty((0, 28 * 28), dtype=np.float32)
    x_df = pd.DataFrame(pixels)
    x_df["labels"] = np.asarray(row_labels, dtype=np.int64)
    print("Shape: ", x_df.shape)

    # Some visual check of the distribution of the rotation angles applied
    for label, rotations in rotations_dict.items():
        plt.hist(rotations, bins=30, density=True)
        plt.title(f"Rotation distribution for label {label}")
        plt.show()

    x_df.to_pickle("./" + new_dataset_title + ".pkl")
    return None

def visualizeSample(pickle_path, sampleSize=20, manual_sample=None):
    """
    Display a single row of 28x28 images taken from a pickled dataset.

    :param pickle_path: path storing the pickle file (a DataFrame whose last
        column is dropped before plotting -- the ``"labels"`` column in files
        written by ``createDataset``).
    :param sampleSize: number of rows to draw at random (with replacement)
        when no manual sample is given.
    :param manual_sample: optional sequence of positional row indices to show
        instead of a random draw; when given and non-empty, ``sampleSize`` is
        ignored.
    :return: None; the figure is displayed with ``plt.show()``.
    """
    # NOTE: unpickling can execute arbitrary code -- only open trusted files.
    df = pd.read_pickle(pickle_path)
    print(df.shape)
    df = df.iloc[:, :-1]  # keep the pixel columns only

    # Explicit None/len check: works for lists, tuples and NumPy arrays alike
    # (a bare truth test raises on multi-element arrays).
    if manual_sample is not None and len(manual_sample) > 0:
        samples = manual_sample
        sampleSize = len(samples)
    else:
        samples = np.random.randint(df.shape[0], size=sampleSize)
    print(samples)

    # squeeze=False always returns a 2-D array of Axes, so indexing also
    # works when a single image is requested.
    fig, ax = plt.subplots(nrows=1, ncols=sampleSize, figsize=(18, 9),
                           squeeze=False)

    for col, row in enumerate(samples):
        img = df.iloc[row, :].values.reshape((28, 28))
        ax[0, col].imshow(img, cmap="gray")

    plt.show()
    return None
def main():
    """
    Build the "RotMultiple" version of the MNIST test set and preview it.

    :return: None.
    """
    dataset_title = "mnist_test_multiple"
    createDataset(dataset_path="./src/datasets/mnist_test.amat", new_dataset_title=dataset_title)
    # createDataset writes "./<title>.pkl"; visualize that same file rather
    # than an unrelated hard-coded path.
    visualizeSample("./" + dataset_title + ".pkl")
    return None

# Run the dataset creation + preview only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
