import torchvision
import tempfile
import pandas as pd
from torchvision import transforms

from pe.data import Data
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import IMAGE_DATA_COLUMN_NAME

# Class names indexed by label id: "0" through "9".
MNIST_LABEL_NAMES = [str(digit) for digit in range(10)]

# Training-time preprocessing with light augmentation.
transform_train = transforms.Compose(
    [
        # Upsample to the 224x224 input size used by ResNet-style models.
        transforms.Resize((224, 224)),
        # Replicate the single grayscale channel into three identical channels.
        transforms.Grayscale(num_output_channels=3),
        # Augmentation: random rotation within +/-10 degrees, then a random
        # translation of up to 10% along each axis (no extra rotation).
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        # NOTE(review): there is deliberately no ToTensor step. Normalize does not
        # accept PIL images, so this pipeline presumably expects tensor input
        # already - confirm against callers.
        # ImageNet per-channel mean / std.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Evaluation-time preprocessing: the training pipeline minus the random augmentation.
transform_test = transforms.Compose(
    [
        # Upsample to 224x224 and replicate the grayscale channel into three.
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        # NOTE(review): no ToTensor step; Normalize requires tensor input, so
        # callers presumably pass tensors - confirm.
        # ImageNet per-channel mean / std.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Minimal pipeline: only replicates the grayscale channel into three channels
# (no resizing, augmentation, or normalization).
transform_identical = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        # NOTE(review): no ToTensor step here either - confirm the expected input type.
    ]
)


class MNIST(Data):
    """The MNIST dataset.

    Each row of the resulting data frame holds one image as an ``(H, W, 3)`` array
    (the single grayscale channel replicated three times) together with its
    integer label id.
    """

    def __init__(self, split="train", root=None):
        """Constructor.

        :param split: The split of the dataset. It should be either "train" or "test", defaults to "train"
        :type split: str, optional
        :param root: Directory in which torchvision downloads/caches the raw MNIST files. When None, the
            files are downloaded into a temporary directory that is deleted once loading finishes, defaults
            to None
        :type root: str or None, optional
        :raises ValueError: If the split is invalid
        """
        if split not in ["train", "test"]:
            raise ValueError(f"Invalid split: {split}")
        train = split == "train"
        # No ``transform`` is passed: the images are read below from ``dataset.data``,
        # which bypasses torchvision's per-item transform pipeline entirely, so a
        # transform given here would never be applied.
        if root is None:
            # The raw files are only needed while torchvision parses them into memory.
            with tempfile.TemporaryDirectory() as tmp_dir:
                dataset = torchvision.datasets.MNIST(root=tmp_dir, train=train, download=True)
        else:
            dataset = torchvision.datasets.MNIST(root=root, train=train, download=True)
        # Add a trailing channel axis and repeat it 3 times (presumably (N, H, W) -> (N, H, W, 3)).
        # NOTE: ``expand`` returns a view and ``.numpy()`` shares memory with it, so the
        # three channels of each image alias the same storage; do not modify them in place.
        image = dataset.data.unsqueeze(3).expand(-1, -1, -1, 3).numpy()
        data_frame = pd.DataFrame(
            {
                IMAGE_DATA_COLUMN_NAME: list(image),
                # Convert to NumPy so the column holds plain integers, not torch scalars.
                LABEL_ID_COLUMN_NAME: dataset.targets.numpy(),
            }
        )
        metadata = {"label_info": [{"name": n} for n in MNIST_LABEL_NAMES]}
        super().__init__(data_frame=data_frame, metadata=metadata)

    def _get_num_classes(self):
        """Return the number of MNIST classes (the length of ``MNIST_LABEL_NAMES``)."""
        print("Use _get_num_classes for MNIST()")
        return len(MNIST_LABEL_NAMES)