# Generates rotated versions of the MNIST dataset, splits into train/val/prior/test, and saves them as .pt files.

import os
import torch
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import functional as TF
from PIL import Image
import random
import matplotlib.pyplot as plt

def generate_rotated_mnist(
    out_dir="rotated_mnist",
    root="./data",
    train_split=(40000, 10000, 10000),   # must sum to 60000 (train, val, prior)
    seed=0,
    verbose=True,
):
    """
    Generate rotated MNIST with angles sampled uniformly in [-90, 90] degrees.

    Every MNIST image is rotated by its own random angle (bilinear resampling,
    zero-filled background) and converted to a float tensor via ``to_tensor``.
    The 60000 training images are shuffled and cut into train/val/prior
    subsets; the MNIST test set is rotated the same way and kept whole.

    Each split is saved to ``out_dir`` as ``train.pt``, ``val.pt``,
    ``prior.pt`` and ``test.pt``. Every file holds a dict with the keys
    ``"images"`` (stacked image tensors), ``"labels"`` (long tensor) and
    ``"angles"`` (float32 tensor, degrees).

    Args:
        out_dir: Directory that receives the four .pt files (created if needed).
        root: Directory where torchvision stores / looks for the raw MNIST data.
        train_split: Sizes of the (train, val, prior) subsets; must sum to 60000.
        seed: Seed applied to the torch, numpy and ``random`` global generators.
        verbose: When true, print progress information.

    Returns:
        ``out_dir``, unchanged.

    Raises:
        AssertionError: If ``train_split`` does not sum to 60000.
        ValueError: If ``train_split`` does not hold exactly three sizes.
    """
    # Validate before touching the filesystem or downloading anything, so bad
    # input fails immediately instead of after all images were processed.
    # Kept as an assert so callers observe the same exception type as before.
    assert sum(train_split) == 60000, "train_split must sum to 60000 (MNIST train size)."
    n_train, n_val, n_prior = train_split

    os.makedirs(out_dir, exist_ok=True)

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if verbose:
        print("Downloading MNIST (if needed) to", root)

    mnist_train = MNIST(root=root, train=True, download=True)
    mnist_test = MNIST(root=root, train=False, download=True)

    # ===== Process training set =====
    if verbose:
        print("Processing training set (60000 images)...")
    imgs_train, labels_train, angles_train = _rotate_dataset(mnist_train, verbose=verbose)

    # Shuffle once, then carve consecutive slices of the permutation so the
    # three subsets are disjoint.
    perm = torch.randperm(len(imgs_train))
    val_end = n_train + n_val
    split_indices = {
        "train": perm[:n_train],
        "val": perm[n_train:val_end],
        "prior": perm[val_end:val_end + n_prior],
    }
    splits = {
        name: _index_split(imgs_train, labels_train, angles_train, idx)
        for name, idx in split_indices.items()
    }

    # ===== Process test set =====
    # The test angles are drawn after the training angles and the permutation,
    # which keeps the random streams identical to a straight-line implementation.
    if verbose:
        print("Processing test set (10000 images)...")
    imgs_test, labels_test, angles_test = _rotate_dataset(mnist_test)
    splits["test"] = {
        "images": imgs_test,
        "labels": labels_test,
        "angles": angles_test,
    }

    # ===== Save all splits to disk =====
    for name, payload in splits.items():
        torch.save(payload, os.path.join(out_dir, f"{name}.pt"))

    if verbose:
        # Report every file that was written, including the prior split.
        print("Saved train/val/prior/test in", out_dir)
        print("Train/Val/Prior/Test sizes:",
              splits["train"]["images"].shape[0],
              splits["val"]["images"].shape[0],
              splits["prior"]["images"].shape[0],
              splits["test"]["images"].shape[0])

    return out_dir


def _rotate_dataset(dataset, verbose=False):
    """Rotate each (PIL image, label) pair of ``dataset`` by a random angle.

    Angles are drawn one at a time from numpy's global generator, uniformly in
    [-90, 90] degrees. Returns ``(images, labels, angles)`` as stacked tensors.
    When ``verbose`` is true, progress is printed every 10000 images.
    """
    total = len(dataset)
    images, labels, angles = [], [], []

    for i, (pil_img, label) in enumerate(dataset):
        angle_deg = float(np.random.uniform(-90.0, 90.0))

        # Bilinear resampling; pixels exposed by the rotation are filled with 0.
        rotated_pil = pil_img.rotate(angle_deg, resample=Image.BILINEAR, fillcolor=0)

        images.append(TF.to_tensor(rotated_pil))  # float tensor in [0,1]
        labels.append(int(label))
        angles.append(angle_deg)

        if verbose and (i + 1) % 10000 == 0:
            print(f"  processed {i+1}/{total}")

    return (
        torch.stack(images),
        torch.tensor(labels, dtype=torch.long),
        torch.tensor(angles, dtype=torch.float32),
    )


def _index_split(images, labels, angles, idx):
    """Return the split dict holding the rows of each tensor selected by ``idx``."""
    return {
        "images": images[idx],
        "labels": labels[idx],
        "angles": angles[idx],
    }

# Small dataset wrapper to load the saved .pt files
class RotatedMNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by one saved split file.

    The file at ``pt_path`` is read with ``torch.load`` and must contain a dict
    with the entries ``"images"``, ``"labels"`` and ``"angles"`` (the layout
    written by ``generate_rotated_mnist``). An optional ``transform`` callable
    is applied to each image when it is fetched.
    """

    def __init__(self, pt_path, transform=None):
        payload = torch.load(pt_path)
        self.images, self.labels, self.angles = (
            payload[key] for key in ("images", "labels", "angles")
        )
        self.transform = transform

    def __len__(self):
        # The sample count is the leading dimension of the image tensor.
        shape = self.images.shape
        return shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label, angle)`` with label/angle as Python scalars."""
        sample = self.images[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, int(self.labels[idx]), float(self.angles[idx])

# Visualization helpers: a symbolic network sketch and a grid of sample images

def draw_symbolic_nn(ax, layers=(3, 5, 2), padding=0.5):
    """
    Draw a simple symbolic fully connected neural network on ``ax``.

    Layer ``i`` is drawn at x = ``i``. Its neurons are placed at the interior
    points of an even grid spanning ``[0, max(layers) - 1]``, so no neuron sits
    exactly on the vertical axis limits. Every neuron is connected to every
    neuron of the following layer.

    ax: matplotlib axes to draw on; its ticks are removed and limits are set
    layers: sequence with the number of neurons per layer (must be non-empty)
    padding: horizontal space left before the first and after the last layer

    Raises ValueError if ``layers`` is empty.
    """
    # Work on a private copy: the caller's sequence is never touched, and any
    # iterable of ints is accepted.
    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one layer")

    ax.set_xticks([])
    ax.set_yticks([])

    n_layers = len(layers)
    max_neurons = max(layers)

    ax.set_xlim(-padding, n_layers - 1 + padding)
    ax.set_ylim(0, max_neurons - 1)

    # Vertical neuron positions, computed once and shared by both drawing
    # passes. The two outer grid points are dropped to leave a margin.
    layer_ys = [
        np.linspace(0, max_neurons - 1, n_neurons + 2)[1:-1]
        for n_neurons in layers
    ]

    # Draw neurons (high zorder so they cover the line ends)
    for x, ys in enumerate(layer_ys):
        for y in ys:
            circle = plt.Circle((x, y), 0.15, color='skyblue', ec='k', zorder=3)
            ax.add_patch(circle)

    # Draw connections between each pair of consecutive layers
    for x, (ys_from, ys_to) in enumerate(zip(layer_ys, layer_ys[1:])):
        for y_start in ys_from:
            for y_end in ys_to:
                ax.plot([x, x + 1], [y_start, y_end], 'k-', lw=0.7, zorder=1)
'''   
def show_grid(pt_path, n=16, seed=0):
    torch.manual_seed(seed)
    data = torch.load(pt_path)
    imgs = data["images"]
    labels = data["labels"]
    angles = data["angles"]
    idxs = torch.randperm(len(imgs))[:n]
    ncols = int(np.sqrt(n))
    nrows = (n + ncols - 1) // ncols

    plt.figure(figsize=(ncols*2, nrows*2))
    for i, idx in enumerate(idxs):
        img = imgs[idx].squeeze().numpy()  # [28,28]
        lbl = int(labels[idx])
        ang = float(angles[idx])
        plt.subplot(nrows, ncols, i+1)
        plt.imshow(img, cmap="gray", vmin=0, vmax=1)
        # plt.title(f"{lbl}, {ang:.1f}°")
        plt.axis("off")
    plt.tight_layout()
    # plt.show()
    plt.savefig("rotated_examples.png")


 
def show_grid(pt_path, n=16, seed=0):
    torch.manual_seed(seed)
    data = torch.load(pt_path)
    imgs = data["images"]
    labels = data["labels"]
    angles = data["angles"]

    # Filter for a specific label (e.g., 2)
    mask = labels == 2
    imgs_filtered = imgs[mask]
    labels_filtered = labels[mask]
    angles_filtered = angles[mask]

    idxs = torch.randperm(len(imgs_filtered))[2:n+2]

    ncols = int(np.sqrt(n))
    nrows = (n + ncols - 1) // ncols

    plt.figure(figsize=(ncols * 2, nrows * 2))

    middle_idx = n // 2  # Index of the middle image in the grid

    for i, idx in enumerate(idxs):
        ax = plt.subplot(nrows, ncols, i + 1)

        if i == middle_idx:
            # Draw symbolic neural network
            draw_symbolic_nn(ax, layers=[3, 5, 2])
            # ax.set_title("Neural Net")
        else:
            img = imgs_filtered[idx].squeeze().numpy()  # [28,28]
            ang = float(angles_filtered[idx])
            ax.imshow(img, cmap="gray", vmin=0, vmax=1)
            ax.set_title(f"{ang:.1f}°")

        ax.axis("off")

    plt.tight_layout()
    plt.savefig("rotated_examples.png")
'''
def show_grid(pt_path, n=16, seed=2, label=4, out_path="rotated_examples.png"):
    """
    Save a grid of randomly chosen images from a saved split file.

    The file is read with ``torch.load`` and must contain a dict with an
    ``"images"`` entry and, when ``label`` is given, a ``"labels"`` entry.
    Up to ``n`` images are picked at random and drawn without axes on a grid
    with ``int(sqrt(n))`` columns; the figure is written to ``out_path``.
    If fewer than ``n`` images match, only those are drawn.

    Args:
        pt_path: Path of the .pt file to visualize.
        n: Maximum number of images to show; must be at least 1.
        seed: Seed for the random selection. Note that this reseeds torch's
            global random generator.
        label: Only images whose label equals this value are shown. Pass
            ``None`` to sample from all images.
        out_path: Where the figure is saved.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        # Without this check n == 0 yields zero columns and an opaque
        # ZeroDivisionError in the row computation below.
        raise ValueError(f"n must be a positive integer, got {n}")

    torch.manual_seed(seed)
    data = torch.load(pt_path)
    imgs = data["images"]

    if label is not None:
        # Boolean mask keeps only the images of the requested class.
        imgs = imgs[data["labels"] == label]

    idxs = torch.randperm(len(imgs))[:n]
    ncols = int(np.sqrt(n))
    nrows = (n + ncols - 1) // ncols  # ceiling division

    # The figure is left open after saving; the caller may still show or close it.
    plt.figure(figsize=(ncols * 2, nrows * 2))
    for i, idx in enumerate(idxs):
        img = imgs[idx].squeeze().numpy()  # drop singleton dims for imshow
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(img, cmap="gray", vmin=0, vmax=1)
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)

if __name__ == "__main__":
    # Script entry point: render example images from the saved training split.
    # The split must already exist on disk; to (re)build the dataset first, run
    # generate_rotated_mnist(out_dir="rotated_mnist", train_split=(40000,10000,10000), seed=123)
    train_pt = os.path.join("rotated_mnist", "train.pt")
    show_grid(train_pt)
