import torch
import torchvision
import torchvision.transforms.v2 as v2
from custom_callbacks import ImageLabelVisualizationCallback
from lightning import LightningDataModule
from torch.utils.data import DataLoader


class TorchvisionDataModule(LightningDataModule):
    """Abstract class to easily turn a torchvision dataset into a Lightning DataModule.

    Subclasses are expected to define:

    * ``dataset``: a torchvision dataset class constructed with ``root``, ``train``,
      ``download`` and ``transform`` keyword arguments and exposing ``classes``.
    * ``transforms``: transforms applied (after ``ToTensor``) to training images.
      Only the LAST one is applied at evaluation time, so it is presumably the
      deterministic one (e.g. normalization) -- NOTE(review): confirm per subclass.
    * ``known_shapes``: per-sample shapes keyed by attribute name, WITHOUT the
      batch dimension; ``__init__`` prepends the batch size.
    * optionally ``dl_kwargs``: extra keyword arguments for every DataLoader, and
      ``data_root`` / ``val_fraction`` / ``split_seed`` to override the defaults.
    """

    prediction_callback = ImageLabelVisualizationCallback
    known_shapes: dict[str, tuple[int, ...]]
    transforms: list[v2.Transform]
    dataset: type = torchvision.datasets.VisionDataset
    # Shared class-level default: it is only read (unpacked into DataLoader), never mutated.
    dl_kwargs: dict = {}
    # Directory where datasets are stored / downloaded to.
    data_root: str = "../data"
    # Fraction of the training set held out for validation when ``is_test`` is False.
    val_fraction: float = 0.05
    # Seed of the train/val split generator, so the split is reproducible across runs.
    split_seed: int = 42

    @property
    def dataset_name(self):
        """Class name of the wrapped torchvision dataset."""
        return self.dataset.__name__

    def __init__(self, batch_size: int, is_test: bool = False):
        """
        Args:
            batch_size: Batch size used by every DataLoader.
            is_test: If False, train on a split of the training set and validate on
                the held-out part. If True, train on the full training set and use
                the official test split as the validation set.
        """
        super().__init__()
        self.batch_size = batch_size
        self.is_test = is_test

        # Add batch size to known_shapes (reads the class-level per-sample shapes)
        self.known_shapes = {
            attr: (batch_size,) + shape for attr, shape in self.known_shapes.items()
        }

    def _make_dataset(self, train: bool, transform):
        """Instantiate ``self.dataset`` under ``data_root``, downloading it if needed."""
        return self.dataset(
            root=self.data_root,
            train=train,
            download=True,
            transform=transform,
        )

    def setup(self, stage: str):
        """Build the datasets required for ``stage``.

        Args:
            stage: Lightning stage name: "fit", "validate", "test" or "predict".
                "fit"/"validate" set ``train_set`` and ``val_set``; "test"/"predict"
                set ``test_set``. Every stage sets ``num_classes``.
        """
        # NOTE: v2.ToTensor is deprecated in torchvision in favour of
        # v2.ToImage() + v2.ToDtype(torch.float32, scale=True); kept unchanged here.
        transform = v2.Compose([v2.ToTensor(), *self.transforms])
        # Evaluation applies only the last transform; slicing (instead of [-1])
        # keeps this valid when ``transforms`` is empty.
        eval_transform = v2.Compose([v2.ToTensor(), *self.transforms[-1:]])

        if stage in ("fit", "validate"):
            train_set = self._make_dataset(train=True, transform=transform)
            self.num_classes = len(train_set.classes)

            if not self.is_test:
                train_indices, val_indices = torch.utils.data.random_split(
                    range(len(train_set)),
                    [1 - self.val_fraction, self.val_fraction],
                    generator=torch.Generator().manual_seed(self.split_seed),
                )
                self.train_set = torch.utils.data.Subset(train_set, train_indices)
                # A second dataset instance so that validation samples get
                # eval_transform instead of the training transforms.
                self.val_set = torch.utils.data.Subset(
                    self._make_dataset(train=True, transform=eval_transform),
                    val_indices,
                )
            else:
                self.train_set = train_set
                self.val_set = self._make_dataset(train=False, transform=eval_transform)

        elif stage in ("test", "predict"):
            self.test_set = self._make_dataset(train=False, transform=eval_transform)
            # on_after_batch_transfer needs this even when "fit" never ran.
            self.num_classes = len(self.test_set.classes)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Convert a device-side ``(img, y)`` batch into a dict.

        Runs after the batch has been placed on the device. Images are passed
        through unchanged; tensor conversion and the configured transforms are
        applied per-sample by the dataset (see ``setup``). Integer class labels
        ``y`` are converted to float32 one-hot vectors of length ``num_classes``.

        Returns:
            dict with keys "img" (images as given) and "y" (one-hot float32 labels).
        """
        img, y = batch

        return {
            "img": img,
            "y": torch.nn.functional.one_hot(y, num_classes=self.num_classes).to(torch.float32),
        }

    def _make_dataloader(self, dataset, shuffle: bool):
        """Shared DataLoader construction for all stages.

        ``drop_last=True`` keeps every batch at exactly ``batch_size``, presumably
        to match the batch dimension baked into ``known_shapes``. NOTE(review): this
        also drops up to ``batch_size - 1`` samples from val/test metrics.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=True,
            **self.dl_kwargs,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_set, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_set, shuffle=False)

    def predict_dataloader(self):
        # Shuffled, so that prediction visualizations show varied samples.
        return self._make_dataloader(self.test_set, shuffle=True)

    def metrics(self, node_dict, batch, prefix: str = ""):
        """Return top-1 and top-5 classification accuracy. Used at val/test time.

        Args:
            node_dict: Model outputs; ``node_dict["y"]`` holds class scores with
                classes along dim 1.
            batch: Batch dict from ``on_after_batch_transfer``; ``batch["y"]`` is one-hot.
            prefix: String prepended to the metric names.

        Returns:
            ``{f"{prefix}acc": ..., f"{prefix}acc_top5": ...}`` as scalar tensors.
            With fewer than 5 classes, "top-5" is taken over all classes.
        """
        y_pred = node_dict["y"]
        class_target = batch["y"].argmax(1, keepdim=True)  # (B, 1)

        class_pred_top1 = y_pred.argmax(1, keepdim=True)  # (B, 1)
        top1_accuracy = (class_pred_top1 == class_target).float().mean()

        # topk raises when k exceeds the number of classes, so clamp k.
        k = min(5, y_pred.shape[1])
        topk_pred = y_pred.topk(k, dim=1).indices  # (B, k)
        top5_accuracy = (topk_pred == class_target).any(dim=1).float().mean()

        return {
            f"{prefix}acc": top1_accuracy,
            f"{prefix}acc_top5": top5_accuracy,
        }
