"""
Get the rotated MNIST dataset

func:
    get_rotated_mnist: Get the rotated MNIST dataset
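    load_rotated_mnist: Load the cached rotated MNIST dataset for one angle, building and saving it on first use (rotates the training split, because get_rotated_mnist is called with its default train=True)
    get_domain: Get the rotated MNIST dataset for a specific domain (angle = max_degree / (domains_num - 1) * idx)
    get_domains: Get the rotated MNIST datasets for all domains (angles evenly spaced from 0 to max_degree)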
    get_source: Get the source dataset (train and test)
"""

import torch
import os
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from typing import Tuple


def get_rotated_mnist(ori_dir, degree: float, train: bool = True) -> Dataset:
    """Build an MNIST dataset whose transform rotates images by a fixed angle.

    Args:
        ori_dir: Root directory handed to ``torchvision.datasets.MNIST``
            (``download=True`` is always passed).
        degree: Rotation angle in degrees.
        train: Forwarded to the MNIST ``train`` flag to pick the split.

    Returns:
        The MNIST dataset with a rotation + ``ToTensor`` transform attached.
    """
    # Using the same value for both ends of the range pins RandomRotation
    # to exactly ``degree``.
    fixed_rotation = transforms.RandomRotation(degrees=(degree, degree))
    pipeline = transforms.Compose([fixed_rotation, transforms.ToTensor()])
    return torchvision.datasets.MNIST(
        root=ori_dir, train=train, download=True, transform=pipeline
    )

def load_rotated_mnist(data_dir, ori_dir, degree: float) -> Dataset:
    """Return the rotated MNIST dataset for ``degree``, caching it on disk.

    The cache file is ``rotated_mnist_<degree, 2 decimals>.pt`` inside
    ``data_dir``. If it exists it is loaded; otherwise the dataset is built
    with ``get_rotated_mnist`` (using that function's default split), saved to
    the cache file, and returned.

    Args:
        data_dir: Directory holding the cached ``.pt`` files; created if missing.
        ori_dir: Root directory of the original MNIST data.
        degree: Rotation angle in degrees.

    Returns:
        The cached or freshly built dataset object.
    """
    # Join path components instead of concatenating strings so ``data_dir``
    # works with or without a trailing separator; plain concatenation would
    # otherwise write the file next to the directory rather than inside it.
    file = os.path.join(data_dir, f"rotated_mnist_{degree:.2f}.pt")
    if os.path.exists(file):
        # NOTE(review): torch.load unpickles the file, so only load caches
        # this code wrote itself. Recent PyTorch releases default to
        # weights_only=True, which may reject a pickled Dataset object --
        # confirm against the installed torch version.
        return torch.load(file)
    dataset = get_rotated_mnist(ori_dir, degree)
    os.makedirs(data_dir, exist_ok=True)
    torch.save(dataset, file)
    return dataset

def get_domain(data_dir, ori_dir, domains_num: int, max_degree: float, idx: int) -> Dataset:
    """Return the rotated MNIST dataset for domain ``idx``.

    Domain angles are evenly spaced: domain 0 is unrotated and domain
    ``domains_num - 1`` is rotated by ``max_degree``.

    Args:
        data_dir: Directory used for the cached rotated datasets.
        ori_dir: Root directory of the original MNIST data.
        domains_num: Total number of domains.
        max_degree: Rotation angle of the last domain, in degrees.
        idx: Index of the requested domain.

    Returns:
        The dataset produced by ``load_rotated_mnist`` for the computed angle.
    """
    if domains_num == 1:
        # A single domain has no interval to divide; the spacing formula
        # would raise ZeroDivisionError, so treat that domain as unrotated.
        degree = 0.0
    else:
        degree = (max_degree / (domains_num - 1)) * idx
    return load_rotated_mnist(data_dir, ori_dir, degree)

def get_domains(data_dir, domains_num: int, max_degree: float) -> list[Dataset]:
    """Return the rotated MNIST datasets for every domain index.

    ``data_dir`` is used both as the cache directory and as the root of the
    original MNIST data.

    Args:
        data_dir: Directory for the original and the cached rotated data.
        domains_num: Number of domains to build.
        max_degree: Rotation angle of the last domain, in degrees.

    Returns:
        One dataset per domain, ordered by domain index.
    """
    return [
        get_domain(data_dir, data_dir, domains_num, max_degree, domain_idx)
        for domain_idx in range(domains_num)
    ]

def get_source(data_dir) -> Tuple[Dataset, Dataset]:
    """Return the source-domain datasets (rotation angle 0).

    Args:
        data_dir: Root directory of the original MNIST data.

    Returns:
        A ``(train_dataset, test_dataset)`` pair.
    """
    # Build the train split first, then the test split.
    source_train, source_test = (
        get_rotated_mnist(data_dir, 0, train=use_train)
        for use_train in (True, False)
    )
    return source_train, source_test

# ------------ test-code ------------

def show_rotated_mnist(data_dir, domains: list[Dataset]):
    """Plot the first three samples of every domain, then save and show the figure.

    Each domain is one column and each of its first three samples is one row.
    The figure is written to ``rotated_mnist.png`` inside ``data_dir`` and the
    shape of the last plotted image is printed.

    Args:
        data_dir: Directory the figure is saved into.
        domains: Datasets to compare; indexing one must yield an
            ``(image, label)`` pair.

    Raises:
        ValueError: If ``domains`` is empty.
    """
    if not domains:
        # Fail fast with a clear message instead of erroring inside matplotlib.
        raise ValueError("domains must contain at least one dataset")

    # set the number of columns and rows
    ncols = len(domains)
    nrows = 3
    # squeeze=False keeps ``axs`` two-dimensional even with a single column,
    # so ``axs[row, col]`` is valid for any number of domains.
    _, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False
    )
    for row in range(nrows):
        for col, dataset in enumerate(domains):
            img, target = dataset[row]
            ax = axs[row, col]
            ax.imshow(img.squeeze(), cmap="gray")
            ax.axis("off")
            ax.set_title(f"Label: {target}")
    plt.tight_layout()
    # Join path components so ``data_dir`` need not end with a separator.
    plt.savefig(os.path.join(data_dir, "rotated_mnist.png"))
    plt.show()
    print("Image shape: ", img.shape)
    
if __name__ == "__main__":
    # Smoke test: build six domains spanning 0-45 degrees, plot a sample grid,
    # then print the shape of the first domain's ``data`` attribute.
    root = "./data/rotate_mnist/"
    rotated_domains = get_domains(root, domains_num=6, max_degree=45)
    show_rotated_mnist(root, domains=rotated_domains)
    first_domain = rotated_domains[0]
    print(first_domain.data.shape)
