import numpy as np
from torchvision.datasets import MNIST
from typing import Tuple
from torch.utils.data.dataset import TensorDataset
import os
import torch
from torch import Tensor
import torch.nn.functional as F

# Root directory passed to torchvision as `root=`; read once from the DATASETS
# environment variable when this module is imported.
try:
    DATASETS_FOLDER = os.environ["DATASETS"]
except KeyError as err:
    # Keep the KeyError type, but tell the user how to fix the problem instead
    # of surfacing a bare `KeyError: 'DATASETS'`.
    raise KeyError(
        "DATASETS environment variable is not set; point it at the directory "
        "where datasets should be stored/downloaded"
    ) from err

def normalize(X_train: np.ndarray, X_test: np.ndarray):
    """Divide both splits by 255 so 8-bit pixel intensities land in [0, 1].

    Returns a new ``(X_train, X_test)`` tuple; the inputs are left untouched.
    """
    scale = 255.0
    return tuple(split / scale for split in (X_train, X_test))

def flatten(arr: np.ndarray):
    """Collapse every non-leading axis so each sample becomes one flat row.

    An (N, d1, d2, ...) input becomes (N, d1*d2*...); a 1-D (N,) input
    becomes (N, 1).
    """
    # Compute the feature size explicitly instead of passing -1: NumPy (and
    # torch) cannot infer a -1 dimension when N == 0, so an empty batch would
    # otherwise raise ValueError/RuntimeError.
    n_features = int(np.prod(arr.shape[1:], dtype=np.int64))
    return arr.reshape(arr.shape[0], n_features)

def unflatten(arr: np.ndarray, shape: Tuple):
    """Reshape a batch of flat rows back into per-sample arrays of ``shape``.

    The leading (batch) dimension is kept, so the result has shape
    ``(arr.shape[0], *shape)``; the total element count must match.
    """
    batch_size = arr.shape[0]
    target = (batch_size, *shape)
    return arr.reshape(target)

def _one_hot(tensor: Tensor, num_classes: int, default=0):
    M = F.one_hot(tensor, num_classes)
    M[M == 0] = default
    return M.float()

def make_labels(y, loss, num_classes: int = 10):
    """Encode integer class labels in the form the given loss expects.

    Args:
        y: integer tensor of class indices.
        loss: ``"ce"`` returns ``y`` unchanged (class indices for
            cross-entropy); ``"mse"`` returns a float one-hot encoding with
            zeros in the non-hot positions.
        num_classes: width of the one-hot encoding for ``"mse"``
            (defaults to 10, the MNIST class count).

    Raises:
        ValueError: if ``loss`` is neither ``"ce"`` nor ``"mse"``.
    """
    if loss == "ce":
        return y
    if loss == "mse":
        return _one_hot(y, num_classes, 0)
    # Fail loudly here rather than returning None, which would only surface
    # later as an obscure error when building the TensorDataset.
    raise ValueError(f"Unsupported loss {loss!r}; expected 'ce' or 'mse'")

# === MNIST LOADING FUNCTION ===
def load_mnist(loss: str) -> Tuple[TensorDataset, TensorDataset]:
    """Load MNIST and return ``(train, test)`` as TensorDatasets.

    Images are divided by 255 and returned as float32 tensors of shape
    (N, 1, 28, 28). Labels are encoded by ``make_labels(..., loss)``: class
    indices for ``"ce"``, float one-hot rows for ``"mse"``.

    The raw data is requested with ``download=True`` under ``DATASETS_FOLDER``.
    """
    mnist_train = MNIST(root=DATASETS_FOLDER, download=True, train=True)
    mnist_test = MNIST(root=DATASETS_FOLDER, download=True, train=False)

    # Flatten each image to a row, then scale pixel values to [0, 1].
    X_train, X_test = flatten(mnist_train.data.numpy()), flatten(mnist_test.data.numpy())
    X_train, X_test = normalize(X_train, X_test)

    # Encode labels. torch.as_tensor reuses an existing tensor without copying
    # (and still accepts a plain sequence); torch.tensor() on a tensor makes a
    # copy and emits a UserWarning.
    y_train = make_labels(torch.as_tensor(mnist_train.targets), loss)
    y_test = make_labels(torch.as_tensor(mnist_test.targets), loss)

    # Restore the (channel, height, width) layout and convert to float32 tensors.
    image_shape = (1, 28, 28)
    train = TensorDataset(torch.from_numpy(unflatten(X_train, image_shape)).float(), y_train)
    test = TensorDataset(torch.from_numpy(unflatten(X_test, image_shape)).float(), y_test)

    return train, test
