import torch
import os

import torchvision
from torch.utils.data import Dataset
import torch.nn as nn
from torchvision import transforms, datasets
from typing import *

from utils import DATA_DIR

# Dataset names accepted by get_dataset() and get_dataset_mean_var().
DATASETS = ["cifar10", "cifar100", "mnist"]


def get_dataset(dataset: str, split: str) -> Dataset:
    """Return the requested dataset split as a PyTorch Dataset object.

    Args:
        dataset: Name of the dataset; one of "cifar10", "cifar100" or "mnist".
        split: Which portion to load, either "train" or "test".

    Returns:
        The dataset object built by the matching private loader
        (``_cifar10``, ``_cifar100`` or ``_mnist``).

    Raises:
        Exception: If ``dataset`` is not one of the supported names.
        ValueError: If ``split`` is neither "train" nor "test".
    """
    if dataset not in ("cifar10", "cifar100", "mnist"):
        raise Exception("Unknown dataset")
    # Validate here so an unsupported split is reported at the public entry
    # point instead of this function handing back a non-Dataset value.
    if split not in ("train", "test"):
        raise ValueError(
            f"Unknown split: {split!r} (expected 'train' or 'test')"
        )

    if dataset == "cifar10":
        return _cifar10(split)
    if dataset == "cifar100":
        return _cifar100(split)
    return _mnist(split)


def get_dataset_mean_var(dataset: str):
    """Look up the normalization constants registered for ``dataset``.

    Note that, despite the function name, the second element of the returned
    pair is the module's ``_*_STDDEV`` constant (a standard deviation), not a
    variance.

    Args:
        dataset: One of "cifar10", "cifar100" or "mnist".

    Returns:
        A ``(mean, stddev)`` tuple of the module-level constant lists.

    Raises:
        Exception: If ``dataset`` is not a supported name.
    """
    if dataset == "cifar10":
        stats = (_CIFAR10_MEAN, _CIFAR10_STDDEV)
    elif dataset == "cifar100":
        stats = (_CIFAR100_MEAN, _CIFAR100_STDDEV)
    elif dataset == "mnist":
        stats = (_MNIST_MEAN, _MNIST_STDDEV)
    else:
        raise Exception("Unknown dataset")
    return stats


def _cifar10(split: str) -> Dataset:
    """Build the CIFAR-10 dataset for ``split``.

    The dataset is rooted at ``DATA_DIR`` and requested with
    ``download=True``. Both splits end with ``ToTensor`` followed by
    ``Normalize`` using ``_CIFAR10_MEAN`` / ``_CIFAR10_STDDEV``; the training
    split is preceded by a random horizontal flip and the test split by
    ``Resize(32)``.

    Args:
        split: Either "train" or "test".

    Returns:
        A ``datasets.CIFAR10`` instance with the split's transform attached.

    Raises:
        ValueError: If ``split`` is neither "train" nor "test" (previously
            this fell through and returned ``None``).
    """
    if split not in ("train", "test"):
        raise ValueError(
            f"Unknown split: {split!r} (expected 'train' or 'test')"
        )

    is_train = split == "train"
    normalize = transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STDDEV)
    if is_train:
        # NOTE: transforms.RandomCrop(32, padding=4) is currently disabled.
        steps = [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        steps = [
            transforms.Resize(32),
            transforms.ToTensor(),
            normalize,
        ]

    return datasets.CIFAR10(
        DATA_DIR,
        train=is_train,
        download=True,
        transform=transforms.Compose(steps),
    )


def _cifar100(split: str) -> Dataset:
    """Build the CIFAR-100 dataset for ``split``.

    The dataset is rooted at ``DATA_DIR`` and requested with
    ``download=True``. Both splits end with ``ToTensor`` followed by
    ``Normalize`` using ``_CIFAR100_MEAN`` / ``_CIFAR100_STDDEV``; the
    training split additionally starts with a random horizontal flip.

    Args:
        split: Either "train" or "test".

    Returns:
        A ``datasets.CIFAR100`` instance with the split's transform attached.

    Raises:
        ValueError: If ``split`` is neither "train" nor "test" (previously
            this fell through and returned ``None``).
    """
    if split not in ("train", "test"):
        raise ValueError(
            f"Unknown split: {split!r} (expected 'train' or 'test')"
        )

    is_train = split == "train"
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=_CIFAR100_MEAN, std=_CIFAR100_STDDEV),
    ]
    if is_train:
        # NOTE: transforms.RandomCrop(32, padding=4) is currently disabled.
        steps.insert(0, transforms.RandomHorizontalFlip())
    # NOTE: transforms.Resize(32) is currently disabled for the test split.

    return datasets.CIFAR100(
        DATA_DIR,
        train=is_train,
        download=True,
        transform=transforms.Compose(steps),
    )


def _mnist(split: str) -> datasets.MNIST:
    """Build the MNIST dataset for ``split``.

    The dataset is rooted at ``DATA_DIR`` and requested with
    ``download=True``. Both splits share one transform: the image is expanded
    to three channels with ``Grayscale(num_output_channels=3)``, converted
    with ``ToTensor`` and normalized with ``_MNIST_MEAN`` / ``_MNIST_STDDEV``.

    Args:
        split: Either "train" or "test".

    Returns:
        A ``datasets.MNIST`` instance with the transform attached.

    Raises:
        ValueError: If ``split`` is neither "train" nor "test" (previously
            this fell through and returned ``None``).
    """
    if split not in ("train", "test"):
        raise ValueError(
            f"Unknown split: {split!r} (expected 'train' or 'test')"
        )

    transform = transforms.Compose([
        # NOTE: transforms.Resize((224, 224)) is currently disabled.
        # Three output channels to match the 3-element mean/std lists.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_MNIST_MEAN, std=_MNIST_STDDEV),
    ])

    return datasets.MNIST(
        DATA_DIR,
        train=(split == "train"),
        download=True,
        transform=transform,
    )


# Per-channel normalization constants. Each MEAN/STDDEV pair is passed to
# transforms.Normalize by the matching loader above and is also returned by
# get_dataset_mean_var().
_CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
_CIFAR10_STDDEV = [0.2023, 0.1994, 0.2010]

_CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
_CIFAR100_STDDEV = [0.2675, 0.2565, 0.2761]

# Three entries because _mnist() expands the images to three channels via
# Grayscale(num_output_channels=3) before normalizing.
_MNIST_MEAN = [0.5, 0.5, 0.5]
_MNIST_STDDEV = [0.5, 0.5, 0.5]


if __name__ == "__main__":
    # Manual smoke test: build the MNIST training split.
    # The result is bound to a name other than ``datasets`` so the
    # torchvision ``datasets`` module imported at the top of this file is
    # not rebound, which would break any later loader call.
    mnist_train = get_dataset(dataset="mnist", split="train")
