# Generates randomly rotated and translated MNIST variants, splits the MNIST training images into train/val/prior (the official test set is kept as test), and saves each split as a .pt file.

import os
import torch
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode
from PIL import Image
import random
import matplotlib.pyplot as plt

def _rotate_and_translate_pil(pil_img, angle_deg, tx, ty):
    """Rotate ``pil_img`` about its centre by ``angle_deg`` degrees, then shift it by (tx, ty) pixels.

    Uses bilinear interpolation; pixels uncovered by the transform are filled with 0.
    """
    affine_kwargs = dict(
        angle=angle_deg,
        translate=(int(tx), int(ty)),
        scale=1.0,
        shear=0.0,
    )
    try:
        # Newer torchvision API: `interpolation` + `fill`.
        return TF.affine(
            pil_img,
            interpolation=InterpolationMode.BILINEAR,
            fill=0,
            **affine_kwargs,
        )
    except TypeError:
        # Fallback for older torchvision, which spelled these `resample` + `fillcolor`.
        return TF.affine(
            pil_img,
            resample=Image.BILINEAR,
            fillcolor=0,
            **affine_kwargs,
        )


def _transform_split(dataset, choices, angle_range, verbose, name):
    """Randomly rotate and translate every (PIL image, label) pair of ``dataset``.

    For each image, in this order, one angle is drawn from ``np.random.uniform`` and then
    tx and ty are drawn from ``random.choice(choices)``. Keeping that order fixed is what
    makes the output reproducible for a given seed.

    Returns:
        (images, labels, angles, translations): images are stacked ``TF.to_tensor``
        outputs (float in [0, 1]), labels are long, angles are float32 degrees, and
        translations are long with shape [N, 2] holding (tx, ty).
    """
    total = len(dataset)
    if verbose:
        print(f"Processing {name} set ({total} images)...")

    low, high = angle_range
    imgs, labels, angles, translations = [], [], [], []
    for i, (pil_img, label) in enumerate(dataset):
        angle_deg = float(np.random.uniform(low, high))
        tx = int(random.choice(choices))
        ty = int(random.choice(choices))

        transformed_pil = _rotate_and_translate_pil(pil_img, angle_deg, tx, ty)
        imgs.append(TF.to_tensor(transformed_pil))  # [1, H, W] in [0, 1]
        labels.append(int(label))
        angles.append(angle_deg)
        translations.append([tx, ty])

        if verbose and (i + 1) % 10000 == 0:
            print(f"  processed {i + 1}/{total}")

    return (
        torch.stack(imgs),
        torch.tensor(labels, dtype=torch.long),
        torch.tensor(angles, dtype=torch.float32),
        torch.tensor(translations, dtype=torch.long),
    )


def generate_rotated_translated_mnist(
    out_dir="rotated_translated_mnist",
    root="./data",
    train_split=(40000, 10000, 10000),   # must sum to 60000 (train, val, prior)
    seed=0,
    translation_choices=(-2, -1, 0, 1, 2),
    verbose=True,
    angle_range=(-90.0, 90.0),
):
    """Generate randomly rotated and translated MNIST and save the splits as .pt files.

    Every image (train and test) gets an angle drawn uniformly from ``angle_range``
    (degrees, default [-90, 90)) and independent integer pixel shifts tx, ty drawn from
    ``translation_choices``. The 60000 MNIST training images are shuffled and split into
    train/val/prior by ``train_split``; the official MNIST test set becomes "test".
    Each split is saved to ``out_dir`` as ``train.pt`` / ``val.pt`` / ``prior.pt`` /
    ``test.pt``. Each file holds a dict with the keys "images", "labels", "angles"
    (degrees) and "translations" ((tx, ty) per image).

    Note: reseeds the global torch, numpy and ``random`` RNGs with ``seed``.

    Args:
        out_dir: Directory the .pt files are written to (created if missing).
        root: Directory MNIST is downloaded to / loaded from.
        train_split: (n_train, n_val, n_prior); must sum to 60000.
        seed: Seed for all RNGs used.
        translation_choices: Integer pixel offsets sampled independently for tx and ty.
        verbose: Print progress messages.
        angle_range: (low, high) bounds in degrees for the uniform rotation angle.

    Returns:
        ``out_dir``.

    Raises:
        ValueError: If ``train_split`` is not three sizes summing to 60000, or
            ``translation_choices`` is empty.
    """
    # Validate before touching the filesystem. Use explicit raises because `assert`
    # is stripped under `python -O`.
    train_split = tuple(train_split)
    if len(train_split) != 3 or sum(train_split) != 60000:
        raise ValueError("train_split must sum to 60000 (MNIST train size).")
    choices = tuple(translation_choices)
    if not choices:
        raise ValueError("translation_choices must contain at least one value.")

    os.makedirs(out_dir, exist_ok=True)

    # Set seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if verbose:
        print("Downloading MNIST (if needed) to", root)

    mnist_train = MNIST(root=root, train=True, download=True)
    mnist_test = MNIST(root=root, train=False, download=True)

    # Order matters for reproducibility: train draws, then randperm, then test draws.
    imgs_train, labels_train, angles_train, translations_train = _transform_split(
        mnist_train, choices, angle_range, verbose, "training"
    )
    train_columns = {
        "images": imgs_train,
        "labels": labels_train,
        "angles": angles_train,
        "translations": translations_train,
    }

    # Shuffle and split into train/val/prior.
    perm = torch.randperm(len(imgs_train))
    n_train, n_val, n_prior = train_split
    train_idx = perm[:n_train]
    val_idx = perm[n_train:n_train + n_val]
    prior_idx = perm[n_train + n_val:n_train + n_val + n_prior]

    train_data = {key: col[train_idx] for key, col in train_columns.items()}
    val_data = {key: col[val_idx] for key, col in train_columns.items()}
    prior_data = {key: col[prior_idx] for key, col in train_columns.items()}

    imgs_test, labels_test, angles_test, translations_test = _transform_split(
        mnist_test, choices, angle_range, verbose, "test"
    )
    test_data = {
        "images": imgs_test,
        "labels": labels_test,
        "angles": angles_test,
        "translations": translations_test,
    }

    # ===== Save to disk =====
    torch.save(train_data, os.path.join(out_dir, "train.pt"))
    torch.save(val_data, os.path.join(out_dir, "val.pt"))
    torch.save(prior_data, os.path.join(out_dir, "prior.pt"))
    torch.save(test_data, os.path.join(out_dir, "test.pt"))

    if verbose:
        print("Saved train/val/prior/test in", out_dir)
        print("Train/Val/Prior/Test sizes:",
              train_data["images"].shape[0],
              val_data["images"].shape[0],
              prior_data["images"].shape[0],
              test_data["images"].shape[0])

    return out_dir

# Small dataset wrapper to load the saved .pt files
class RotatedTranslatedMNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset over one saved split file (e.g. ``train.pt``).

    The file must hold a dict with the keys "images", "labels", "angles" and
    "translations", each indexable along its first dimension. Indexing returns
    ``(image, label, angle, (tx, ty))``:
    - ``image`` is the stored tensor, passed through ``transform`` when one is given.
    - ``label`` is an ``int``.
    - ``angle`` is a ``float``.
    - the translation is a tuple of ``int``s.
    """

    def __init__(self, pt_path, transform=None):
        payload = torch.load(pt_path)
        self.images, self.labels, self.angles, self.translations = (
            payload[field] for field in ("images", "labels", "angles", "translations")
        )
        self.transform = transform  # optional callable applied to the image only

    def __len__(self):
        # Sample count is the leading dimension of the stored image tensor.
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        label, angle = int(self.labels[idx]), float(self.angles[idx])
        shift = tuple(map(int, self.translations[idx].tolist()))
        if self.transform is None:
            return image, label, angle, shift
        return self.transform(image), label, angle, shift

# visualize a few samples
def show_grid(pt_path, n=16, seed=0, out_file="rotated_translated_examples.png"):
    """Save a grid of ``n`` randomly chosen samples from a saved split file to ``out_file``.

    Each panel is titled "label, angle°, (tx=..,ty=..)". Samples are chosen with a
    private ``torch.Generator`` seeded with ``seed``, so the global torch RNG is not
    reseeded as a side effect. The figure is always closed, so repeated calls do not
    accumulate open matplotlib figures.

    Args:
        pt_path: Path to a .pt dict with "images", "labels", "angles", "translations".
        n: Number of samples to show (fewer if the split is smaller).
        seed: Seed for choosing which samples are shown.
        out_file: Image file the grid is saved to.

    Returns:
        ``out_file``.

    Raises:
        ValueError: If ``n`` < 1. Previously, ``n=0`` failed with ZeroDivisionError.
    """
    if n < 1:
        raise ValueError("n must be >= 1")

    data = torch.load(pt_path)
    imgs = data["images"]
    labels = data["labels"]
    angles = data["angles"]
    translations = data["translations"]

    generator = torch.Generator().manual_seed(seed)
    idxs = torch.randperm(len(imgs), generator=generator)[:n]
    ncols = int(np.sqrt(n))
    nrows = (n + ncols - 1) // ncols

    fig = plt.figure(figsize=(ncols * 2, nrows * 2))
    try:
        for i, idx in enumerate(idxs.tolist()):
            img = imgs[idx].squeeze().numpy()  # drop the channel dim for imshow
            lbl = int(labels[idx])
            ang = float(angles[idx])
            tx, ty = int(translations[idx][0]), int(translations[idx][1])
            ax = fig.add_subplot(nrows, ncols, i + 1)
            ax.imshow(img, cmap="gray", vmin=0, vmax=1)
            # Small font: the default size overlaps neighbouring 2-inch panels.
            ax.set_title(f"{lbl}, {ang:.1f}°, (tx={tx},ty={ty})", fontsize=8)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_file)
    finally:
        plt.close(fig)
    return out_file

if __name__ == "__main__":
    # Build the dataset with a fixed seed, then render a preview grid of training samples.
    dataset_dir = generate_rotated_translated_mnist(
        out_dir="rotated_translated_mnist",
        train_split=(40000, 10000, 10000),
        seed=123,
    )
    train_file = os.path.join(dataset_dir, "train.pt")
    show_grid(train_file, n=16)
