import pytorch_lightning as pl
import os
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from typing import Optional, Dict, List
import torchvision
from lightly.transforms import SimCLRTransform
from lightly.transforms.multi_view_transform import MultiViewTransform
import lightly.data as lightly_data
from datasets.dataset_utils import DatasetWithIndices
from lightly.transforms import MoCoV2Transform
import numpy as np

import datasets.utils as utils
import copy

from abc import abstractmethod

# Login name of the current user, used to build the default dataset directory.
# This is None when the USER environment variable is not set.
USER_NAME = os.getenv("USER")


class CorruptedDataModule(pl.LightningDataModule):
    """Base data module serving corrupted versions of a train/val dataset.

    The training set is wrapped in ``utils.FrozenNoiseDataset`` with a single
    named corruption; validation uses one ``FrozenNoiseDataset`` per corruption
    strength. On top of that, the training transform of every view gets an
    additional, randomly applied named augmentation.

    Subclasses must:
      * set ``self.image_size`` (it is read in ``setup`` and when building the
        augmentation, but never assigned in this class), and
      * implement ``define_train_dataset``, ``define_val_dataset``,
        ``val_transform`` and ``original_train_transform``.

    NOTE(review): the hooks are marked ``@abstractmethod``, which only blocks
    instantiation when the metaclass derives from ``ABCMeta`` -- confirm whether
    ``LightningDataModule`` provides that; otherwise an un-overridden hook
    silently returns ``None``.
    """

    # Number of validation corruption strengths (0 .. _NUM_VAL_STRENGTHS - 1).
    # Shared by ``val_dataset_names`` and ``setup`` so that the names and the
    # datasets/dataloaders always line up one-to-one.
    _NUM_VAL_STRENGTHS = 6

    def __init__(
        self,
        train_augmentation="zoomblur",
        train_augmentation_strength=0,
        train_augmentation_before=True,
        train_corruption="zoomblur",
        train_corruption_strength=0,
        val_corruption="zoomblur",
        batch_size: int = 512,
        num_workers: int = 4,
        data_dir=f"/checkpoint/{USER_NAME}/datasets/cifar",
    ):
        """Store the configuration; datasets are only built in ``setup``.

        Args:
            train_augmentation: Name of the augmentation, resolved through
                ``utils.get_named_version``, that is randomly applied inside
                each training view.
            train_augmentation_strength: Strength passed along with
                ``train_augmentation``.
            train_augmentation_before: If true, the augmentation is the first
                step of each view pipeline; otherwise it is inserted directly
                before the pipeline's ``ToTensor`` step.
            train_corruption: Name of the corruption used as the training
                set's ``noise_transform``.
            train_corruption_strength: Strength passed along with
                ``train_corruption``.
            val_corruption: Name of the corruption used for all validation
                sets (one set per strength).
            batch_size: Batch size of the train and validation dataloaders.
            num_workers: Worker count of the train and validation dataloaders.
            data_dir: Dataset location. Only stored here; presumably consumed
                by the subclasses' ``define_*_dataset`` hooks.
        """
        super().__init__()
        self.train_augmentation = train_augmentation
        self.train_augmentation_strength = train_augmentation_strength
        self.train_augmentation_before = train_augmentation_before
        self.train_corruption = train_corruption
        self.train_corruption_strength = train_corruption_strength
        self.val_corruption = val_corruption
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.data_dir = data_dir

        # One name per validation dataloader, in the order ``setup`` builds them.
        self.val_dataset_names = [
            f"{self.val_corruption}_{i}" for i in range(self._NUM_VAL_STRENGTHS)
        ]

    def setup(self, stage: str = ""):
        """Build the corrupted training set and the validation sets.

        Args:
            stage: Lightning stage name; ignored, everything is always built.
        """
        # Set up train.
        train_dataset = self.define_train_dataset()

        self.train_dataset = utils.FrozenNoiseDataset(
            train_dataset,
            noise_transform=utils.get_named_version(
                self.image_size, self.train_corruption, self.train_corruption_strength
            ),
            transforms=self.train_transform(),
        )

        # Set up val: one dataset per corruption strength.
        val_dataset = self.define_val_dataset()

        self.val_dataset = [
            utils.FrozenNoiseDataset(
                val_dataset,
                noise_transform=utils.get_named_version(
                    self.image_size,
                    self.val_corruption,
                    strength,
                ),
                transforms=self.val_transform(),
                # NOTE(review): presumably offsets the noise seeds so that
                # validation does not reuse the training seeds -- confirm in
                # FrozenNoiseDataset.
                seed_base=len(train_dataset),
            )
            for strength in range(self._NUM_VAL_STRENGTHS)
        ]

    def train_transform(self):
        """Corrupt each individual view of the original multi-view transform.

        Returns:
            A ``MultiViewTransform`` with one corrupted pipeline per original
            view, or a single corrupted ``Compose`` when the original
            transform has no ``transforms`` attribute.
        """
        original_transform = self.original_train_transform()

        if hasattr(original_transform, "transforms"):
            corrupted_views = []
            for view in original_transform.transforms:
                view_copy = copy.deepcopy(view)  # prevent modifying twice
                corrupted_views.append(self._corrupt_view(view_copy))
            return MultiViewTransform(corrupted_views)
        else:  # for AIM / MAE there is no multi-view transform
            return self._corrupt_view(original_transform)

    def _random_augmentation(self):
        """Return the configured train augmentation wrapped in ``RandomApply``.

        ``RandomApply`` is used with its default probability.
        """
        return torchvision.transforms.RandomApply(
            [
                utils.get_named_version(
                    self.image_size,
                    self.train_augmentation,
                    self.train_augmentation_strength,
                )
            ]
        )

    def _corrupt_view(self, regular_transform):
        """Add the random train augmentation to one view's pipeline.

        Args:
            regular_transform: View transform exposing its steps as
                ``regular_transform.transform.transforms``.

        Returns:
            A new ``torchvision.transforms.Compose``. The step list of
            ``regular_transform`` is copied, never modified in place.
        """
        pipeline = list(regular_transform.transform.transforms)

        if self.train_augmentation_before:
            pipeline.insert(0, self._random_augmentation())
        else:
            # Insert the augmentation exactly once, directly before the first
            # ToTensor step. The previous loop inserted into the list it was
            # indexing, so ToTensor kept shifting right and matched again on
            # every following iteration, stacking several copies of the
            # augmentation; it also placed them one step too early.
            # If the pipeline has no ToTensor step, nothing is inserted.
            for position, step in enumerate(pipeline):
                if isinstance(step, torchvision.transforms.ToTensor):
                    pipeline.insert(position, self._random_augmentation())
                    break

        return torchvision.transforms.Compose(pipeline)

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete batches are dropped."""
        data_with_indices = DatasetWithIndices(
            self.train_dataset,
            # if False, only index is returned
            with_labels=True,
        )
        training_loader = DataLoader(
            data_with_indices,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
        )
        return training_loader

    def val_dataloader(self):
        """Return one unshuffled loader per validation dataset.

        The list order matches ``self.val_dataset`` and therefore
        ``self.val_dataset_names``.
        """
        loaders = [
            DataLoader(
                DatasetWithIndices(f, with_labels=True),
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=self.num_workers,
            )
            for f in self.val_dataset
        ]
        return loaders

    @abstractmethod
    def define_train_dataset(self):
        """Return the uncorrupted training dataset; it must support ``len()``."""
        pass

    @abstractmethod
    def define_val_dataset(self):
        """Return the uncorrupted validation dataset."""
        pass

    @abstractmethod
    def val_transform(self):
        """Return the transform passed to every validation dataset."""
        pass

    @abstractmethod
    def original_train_transform(self):
        """Return the uncorrupted training transform.

        Either a multi-view transform exposing its views as ``.transforms``,
        or a single view transform; each view must expose its steps as
        ``.transform.transforms``.
        """
        pass
