import torch
import torchvision
import torchvision.transforms.v2 as v2
from custom_callbacks import ImageLabelVisualizationCallback
from lightning import LightningDataModule
from torch.utils.data import DataLoader


class TorchvisionDataModule(LightningDataModule):
    """Abstract class to easily turn a torchvision dataset into a Lightning DataModule.

    Subclasses configure the module through class attributes:

    * ``dataset``: dataset class, instantiated with
      ``root``/``train``/``download``/``transform`` keyword arguments. The
      resulting object must expose ``classes`` (its length is the number of
      classes).
    * ``known_shapes``: shapes keyed by name, *without* the batch dimension;
      ``__init__`` prepends ``batch_size`` to each of them.
    * ``transforms``: transforms applied after ``v2.ToTensor()`` to training
      data. Evaluation data only receives the last one.
    * ``dl_kwargs``: extra keyword arguments forwarded to every ``DataLoader``.
    * ``data_root``: directory passed as ``root`` to the dataset constructor.
    """

    prediction_callback = ImageLabelVisualizationCallback
    known_shapes: dict[str, tuple[int, ...]]
    transforms: list[v2.Transform]
    dataset: type = torchvision.datasets.VisionDataset
    # Shared at class level: override it in subclasses instead of mutating it.
    dl_kwargs: dict = {}
    data_root: str = "../data"

    @property
    def dataset_name(self) -> str:
        """Name of the wrapped dataset class."""
        return self.dataset.__name__

    def __init__(self, batch_size: int, is_test: bool = False):
        """Store the configuration and prepend the batch size to ``known_shapes``.

        Args:
            batch_size: Batch size used by every dataloader.
            is_test: When ``False``, validation data is a held-out 5% split of
                the training set. When ``True``, the full training set is used
                for training and the ``train=False`` split for validation.
        """
        super().__init__()
        self.batch_size = batch_size
        self.is_test = is_test

        # Add batch size to known_shapes (creates an instance-level copy, the
        # class-level dict is left untouched).
        self.known_shapes = {
            attr: (batch_size,) + shape for attr, shape in self.known_shapes.items()
        }

    def _build_dataset(self, train: bool, transform):
        """Instantiate ``self.dataset`` for the requested split, downloading if needed."""
        return self.dataset(
            root=self.data_root,
            train=train,
            download=True,
            transform=transform,
        )

    def _build_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a ``DataLoader`` with the module-wide settings.

        Incomplete final batches are always dropped, so every batch has
        exactly ``batch_size`` samples.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=True,
            **self.dl_kwargs,
        )

    def setup(self, stage: str) -> None:
        """Create the datasets required for ``stage``.

        * ``"fit"`` / ``"validate"``: sets ``train_set``, ``val_set`` and
          ``num_classes``.
        * ``"test"`` / ``"predict"``: sets ``test_set`` (and ``num_classes``
          when it has not been set yet).

        Any other stage value leaves the module untouched.
        """
        transform = v2.Compose([v2.ToTensor(), *self.transforms])
        # Slicing instead of indexing keeps this valid when ``transforms`` is empty.
        eval_transform = v2.Compose([v2.ToTensor(), *self.transforms[-1:]])

        # "validate" needs ``val_set`` as well, so it shares the "fit" branch.
        if stage in ("fit", "validate"):
            train_set = self._build_dataset(train=True, transform=transform)
            self.num_classes = len(train_set.classes)

            if not self.is_test:
                # Fixed seed so the train/validation split is reproducible.
                train_indices, val_indices = torch.utils.data.random_split(
                    range(len(train_set)),
                    [0.95, 0.05],
                    generator=torch.Generator().manual_seed(42),
                )
                self.train_set = torch.utils.data.Subset(train_set, train_indices)
                # The validation samples come from the training split, but are
                # loaded through a second dataset using the evaluation transform.
                self.val_set = torch.utils.data.Subset(
                    self._build_dataset(train=True, transform=eval_transform),
                    val_indices,
                )
            else:
                self.train_set = train_set
                self.val_set = self._build_dataset(train=False, transform=eval_transform)

        elif stage in ("test", "predict"):
            self.test_set = self._build_dataset(train=False, transform=eval_transform)
            # ``on_after_batch_transfer`` needs ``num_classes`` even when "fit"
            # never ran; only fill it in when it is missing.
            if not hasattr(self, "num_classes"):
                self.num_classes = len(self.test_set.classes)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """
        Transforms batch after being placed on device
        Same as the 'transform' argument in torchvision datasets, but batched.

        Transforms:
        * Image is passed through unchanged (it was already transformed by the
          dataset transforms configured in ``setup``)
        * Class label y as one-hot (float32)

        Returns a dict with keys ``"img"`` and ``"y"``.
        """
        img, y = batch

        return {
            "img": img,
            "y": torch.nn.functional.one_hot(y, num_classes=self.num_classes).to(torch.float32),
        }

    def train_dataloader(self):
        """Shuffled dataloader over ``train_set``."""
        return self._build_dataloader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Unshuffled dataloader over ``val_set``."""
        return self._build_dataloader(self.val_set, shuffle=False)

    def test_dataloader(self):
        """Unshuffled dataloader over ``test_set``."""
        return self._build_dataloader(self.test_set, shuffle=False)

    def predict_dataloader(self):
        """Shuffled dataloader over ``test_set``."""
        # NOTE(review): prediction data is shuffled, presumably so the
        # visualization callback sees a random sample -- confirm this is intended.
        return self._build_dataloader(self.test_set, shuffle=True)

    def metrics(self, node_dict, batch, prefix: str = ""):
        """Returns classification accuracy (top-1 and top-5). Used at val/test time

        Args:
            node_dict: Mapping whose ``"y"`` entry holds the predicted class
                scores (assumed to be ``(batch, num_classes)`` -- TODO confirm).
            batch: Mapping whose ``"y"`` entry holds the one-hot targets, as
                produced by ``on_after_batch_transfer``.
            prefix: String prepended to every metric name.

        Returns:
            Dict with ``f"{prefix}acc"`` and ``f"{prefix}acc_top5"``.
        """
        y = batch["y"]
        y_pred = node_dict["y"]

        class_pred_top1 = y_pred.argmax(1, keepdim=True)
        class_target = y.argmax(1, keepdim=True)
        top1_accuracy = (class_pred_top1 == class_target).float().mean()

        # Compute top-5 accuracy. ``topk`` raises when k exceeds the number of
        # classes, so clamp k for datasets with fewer than 5 classes.
        k = min(5, y_pred.shape[1])
        top5_pred = y_pred.topk(k, dim=1).indices
        top5_accuracy = (top5_pred == class_target.view(-1, 1)).any(dim=1).float().mean()

        return {
            f"{prefix}acc": top1_accuracy,
            f"{prefix}acc_top5": top5_accuracy,
        }


class ShuffledTorchvisionDataModule(TorchvisionDataModule):
    """TorchvisionDataModule whose labels are partially permuted.

    A random subset of ``num_classes_shuffled`` classes is chosen at
    construction time and their labels are rotated cyclically among each
    other; all remaining classes keep their label.

    Attributes:
        shuffled_classes: 1-D integer tensor with the selected class indices
            (empty when ``num_classes_shuffled <= 0``).
        class_mapping: dict mapping every original label to the label that is
            actually emitted.
        inverse_class_mapping: inverse of ``class_mapping``.
    """

    def __init__(self, batch_size: int, is_test: bool = False, num_classes_shuffled: int = 2):
        """Pick the classes to shuffle and build the label mappings.

        Args:
            batch_size: Forwarded to ``TorchvisionDataModule``.
            is_test: Forwarded to ``TorchvisionDataModule``.
            num_classes_shuffled: Number of classes whose labels are permuted.
                Values ``<= 0`` disable shuffling.

        Raises:
            ValueError: If ``num_classes_shuffled`` exceeds the number of classes.
        """
        # The number of classes is the last entry of the "y" shape.
        self.num_classes = self.known_shapes["y"][-1]

        super().__init__(batch_size, is_test)
        self.num_classes_shuffled = num_classes_shuffled

        if num_classes_shuffled > self.num_classes:
            raise ValueError(
                f"num_classes_shuffled ({num_classes_shuffled}) cannot exceed "
                f"the number of classes ({self.num_classes})"
            )

        # randomly select the index of the classes to shuffle
        if self.num_classes_shuffled > 0:
            self.shuffled_classes = torch.randperm(self.num_classes)[: self.num_classes_shuffled]
        else:
            # Always define the attribute so that consumers never hit an
            # AttributeError when shuffling is disabled.
            self.shuffled_classes = torch.empty(0, dtype=torch.long)

        # define mapping of classes: identity everywhere, except that each
        # selected class is mapped to its predecessor in ``shuffled_classes``
        # (cyclically).
        self.class_mapping = {i: i for i in range(self.num_classes)}
        targets = self.shuffled_classes.tolist()
        sources = targets[1:] + targets[:1]
        self.class_mapping.update(zip(sources, targets))

        self.inverse_class_mapping = {v: k for k, v in self.class_mapping.items()}

        # Lookup table (index = original label, value = mapped label) used to
        # remap a whole batch of labels with a single indexing operation.
        self._label_lut = torch.tensor(
            [self.class_mapping[i] for i in range(self.num_classes)],
            dtype=torch.long,
        )

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """
        Transforms batch after being placed on device
        Same as the 'transform' argument in torchvision datasets, but batched.

        Transforms:
        * Image is passed through unchanged
        * Apply class mapping for shuffled labels
        * Class label y as one-hot (float32)

        Returns a dict with keys ``"img"`` and ``"y"``.
        """
        img, y = batch

        # Keep the lookup table on the same device as the labels; it is moved
        # at most once per device change.
        if self._label_lut.device != y.device:
            self._label_lut = self._label_lut.to(y.device)

        # Apply class mapping to labels (vectorised: no per-label host sync).
        mapped_y = self._label_lut[y.long()]

        return {
            "img": img,
            "y": torch.nn.functional.one_hot(mapped_y, num_classes=self.num_classes).to(torch.float32),
        }

    def metrics(self, node_dict, batch, prefix: str = ""):
        """Extend the base metrics with MSE and per-group (shuffled / not shuffled) scores.

        Args:
            node_dict: Mapping whose ``"y"`` entry holds the predictions.
            batch: Mapping whose ``"y"`` entry holds the one-hot targets.
            prefix: String prepended to every metric name.

        Returns:
            The base-class metrics plus ``f"{prefix}mse"``. When
            ``num_classes_shuffled > 0`` it also contains ``acc_shuffled``,
            ``acc_not_shuffled``, ``mse_shuffled`` and ``mse_not_shuffled``
            (each with ``prefix``). A per-group value is NaN when the batch
            contains no sample of that group, because it is the mean of an
            empty selection.
        """
        base_dict = super().metrics(node_dict, batch, prefix)
        y = batch["y"]
        y_pred = node_dict["y"]

        mse = (y_pred - y).pow(2)
        base_dict[f"{prefix}mse"] = mse.mean()

        # calculate accuracy of classes that were shuffled
        if self.num_classes_shuffled > 0:
            class_pred_top1 = y_pred.argmax(1, keepdim=True)
            class_target = y.argmax(1, keepdim=True)
            top1_accuracy = (class_pred_top1 == class_target).float()

            # Boolean mask over the batch: True where the (mapped) target
            # label is one of the shuffled classes.
            idx_in_shuffled = torch.isin(
                class_target.view(-1), self.shuffled_classes.to(class_target.device)
            )

            base_dict[f"{prefix}acc_shuffled"] = top1_accuracy[idx_in_shuffled].mean()
            base_dict[f"{prefix}acc_not_shuffled"] = top1_accuracy[~idx_in_shuffled].mean()

            base_dict[f"{prefix}mse_shuffled"] = mse[idx_in_shuffled].mean()
            base_dict[f"{prefix}mse_not_shuffled"] = mse[~idx_in_shuffled].mean()

        return base_dict