"""
Get the shifted MNIST dataset

func:
    get_shifted_mnist: Get the shifted MNIST dataset
    load_shifted_mnist: Load the shifted MNIST dataset (by shifting test set)
    get_domain: Get the shifted MNIST dataset for a specific domain (by shifting test set)
    get_domains: Get the shifted MNIST dataset for multiple domains (by shifting test set)
    get_source: Get the source dataset (train and test)
"""

import torch
import os
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from typing import Tuple


def get_shifted_mnist(ori_dir, shift: float, train: bool = True) -> Dataset:
    """Return torchvision's MNIST with every pixel offset by ``shift``.

    ``ToTensor`` scales pixels into [0, 1]. ``Normalize`` then evaluates
    ``(x - mean) / std`` with ``mean = -shift`` and ``std = 1``, which reduces
    to ``x + shift``.

    Args:
        ori_dir: root directory for the raw MNIST files (downloaded if absent).
        shift: constant added to every pixel after ``ToTensor``.
        train: load the training split when True, otherwise the test split.

    Returns:
        A ``torchvision.datasets.MNIST`` instance using the shifting transform.
    """
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize((-shift,), (1,)),
    ]
    shift_transform = transforms.Compose(pipeline)
    return torchvision.datasets.MNIST(
        root=ori_dir,
        train=train,
        download=True,
        transform=shift_transform,
    )

def load_shifted_mnist(data_dir, ori_dir, shift: float) -> Dataset:
    """Load a cached shifted-MNIST dataset, building and caching it if absent.

    The cache file is ``<data_dir>/shifted_mnist_<shift with 2 decimals>.pt``.
    Because the name keeps only two decimals, shifts that round to the same
    value share one cache entry.

    Args:
        data_dir: directory holding the cached ``.pt`` files. It is created when
            a new cache file is written. A trailing path separator is optional.
        ori_dir: root directory for the raw MNIST download.
        shift: pixel offset passed to ``get_shifted_mnist``.

    Returns:
        The shifted dataset. A freshly built one comes from
        ``get_shifted_mnist`` with its default ``train=True``.
    """
    # os.path.join works whether or not data_dir ends with a separator. Plain
    # concatenation wrote "<data_dir>shifted_mnist_..." next to the directory.
    file = os.path.join(data_dir, f"shifted_mnist_{shift:.2f}.pt")
    if os.path.exists(file):
        # The cache holds a whole pickled Dataset object, not just tensors.
        # It therefore cannot be read with weights_only=True, which is
        # torch.load's default since torch 2.6.
        # SECURITY: this unpickles arbitrary objects. Only load cache files
        # this module wrote itself.
        return torch.load(file, weights_only=False)
    dataset = get_shifted_mnist(ori_dir, shift)
    os.makedirs(data_dir, exist_ok=True)
    torch.save(dataset, file)
    return dataset

def get_domain(data_dir, ori_dir, domains_num: int, max_shift: float, idx: int)->Dataset:
    """Return the shifted MNIST dataset for domain ``idx`` out of ``domains_num``.

    Shifts are evenly spaced: domain 0 is unshifted and domain
    ``domains_num - 1`` is shifted by ``max_shift``. An ``idx`` outside that
    range extrapolates linearly.

    Args:
        data_dir: directory for cached shifted datasets.
        ori_dir: root directory for the raw MNIST download.
        domains_num: total number of domains (must be >= 1).
        max_shift: shift applied to the last domain.
        idx: index of the requested domain.

    Returns:
        The dataset returned by ``load_shifted_mnist`` for the computed shift.

    Raises:
        ValueError: if ``domains_num`` < 1.
    """
    if domains_num < 1:
        raise ValueError(f"domains_num must be >= 1, got {domains_num}")
    if domains_num == 1:
        # With a single domain there is no spacing to interpolate. Treat it
        # like domain 0 (unshifted) instead of dividing by zero.
        shift = 0.0
    else:
        # Keep the original evaluation order so computed shifts (and hence
        # cache file names) are bit-identical to before.
        shift = (max_shift / (domains_num - 1)) * idx
    return load_shifted_mnist(data_dir, ori_dir, shift)

def get_domains(data_dir, domains_num: int, max_shift: float, ori_dir=None)->list[Dataset]:
    """Return one shifted MNIST dataset per domain index ``0 .. domains_num - 1``.

    Args:
        data_dir: directory for cached shifted datasets.
        domains_num: number of domains to build. Zero or negative yields ``[]``.
        max_shift: shift applied to the last domain.
        ori_dir: root directory for the raw MNIST download. Defaults to
            ``data_dir`` (the previous hard-wired behavior).

    Returns:
        A list of datasets ordered by domain index.
    """
    raw_dir = data_dir if ori_dir is None else ori_dir
    return [
        get_domain(data_dir, raw_dir, domains_num, max_shift, i)
        for i in range(domains_num)
    ]

def get_source(data_dir) -> Tuple[Dataset, Dataset]:
    """Return the unshifted (shift = 0) MNIST training and test splits.

    Args:
        data_dir: root directory for the raw MNIST download.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    # The training split is built first, then the test split (same order as before).
    train_split, test_split = (
        get_shifted_mnist(data_dir, 0, train=use_train) for use_train in (True, False)
    )
    return train_split, test_split

# ------------ test-code ------------

def show_shifted_mnist(data_dir, domains: list[Dataset], max_shift: float, nrows: int = 3):
    """Plot the first ``nrows`` samples of every domain and print diagnostics.

    Each domain gets one column and each sample gets one row, titled with the
    image's min/max pixel value. The figure is saved to
    ``<data_dir>/shifted_mnist.png`` and shown. Diagnostics are then printed
    for the last domain in ``domains`` and the last image plotted from it.

    Args:
        data_dir: directory the PNG is written to (created if missing).
        domains: datasets to display. Indexing must return
            ``(image_tensor, target)``, and each dataset must expose ``.data``
            and ``.transform`` (as ``torchvision.datasets.MNIST`` does).
        max_shift: currently unused; kept for interface compatibility.
        nrows: number of samples (rows) shown per domain.

    Raises:
        ValueError: if ``domains`` is empty or ``nrows`` < 1.
    """
    if not domains:
        raise ValueError("show_shifted_mnist needs at least one domain")
    if nrows < 1:
        raise ValueError(f"nrows must be >= 1, got {nrows}")
    ncols = len(domains)
    # squeeze=False keeps axs 2-D even for a single row or column, so the
    # axs[row, col] indexing below always works.
    _, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False)
    for row in range(nrows):
        for col, dataset in enumerate(domains):
            img, _ = dataset[row]
            ax = axs[row, col]
            ax.imshow(img.squeeze(), cmap="gray")
            ax.axis("off")
            ax.set_title(f"min: {img.min():.2f}, max: {img.max():.2f}")
    plt.tight_layout()
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    # os.path.join works whether or not data_dir ends with a separator.
    plt.savefig(os.path.join(data_dir, "shifted_mnist.png"))
    plt.show()
    # `dataset` and `img` are the last domain / last sample from the loop above.
    print("dataset data type: ", type(dataset.data))
    print("dataset data max pixel: ", dataset.data.max())
    print("dataset data min pixel: ", dataset.data.min())
    print("dataset transform: ", dataset.transform)
    print("dataset data shape: ", dataset.data.shape)
    print("Image shape: ", img.shape)
    print("Image type: ", type(img))
    print("Image data max pixel: ", img.max())
    print("Image data min pixel: ", img.min())
    print("dataset data[0]: ", dataset.data[0])
    print("dataset[0]: ", dataset[0])
    
if __name__ == "__main__":
    data_dir = "./data/color_mnist/"
    domains_datasets = get_domains(data_dir, 6, 1)
    show_shifted_mnist(data_dir, domains_datasets, max_shift=1)
