import os

import PIL
import pytorch_lightning as pl
import torchvision
import torch
import lightly
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction

from src.datasets.DatasetSubset import DatasetFromSubset
from src.datasets.Visualization import log_datasets

class CIFAR10LightlyModule(pl.LightningDataModule):
    """CIFAR-10 data module for SSL pre-training and linear-probe runs.

    In an SSL run the training loader yields untransformed samples that are
    augmented by a ``SimCLRCollateFunction``; an additional, deterministically
    transformed copy of the training split is exposed through
    ``train_dataloader_knn``. In a non-SSL (probe) run the training split uses
    the same deterministic transforms as the validation split.

    Args:
        args: Configuration object. The attributes read here are
            ``batch_size``, ``probe_batch_size``, ``num_workers``,
            ``data_folder`` and ``seed``.
        is_ssl_run: Selects SSL behaviour (``True``) or probe behaviour
            (``False``); also picks which batch size from ``args`` is used.
        logger: Optional logger object, stored on the instance as ``logger``.
    """

    # Sizes of the train/validation split carved out of the CIFAR-10 train set.
    _TRAIN_VAL_SPLIT = [45000, 5000]

    def __init__(self, args, is_ssl_run, logger=None):
        super().__init__()
        self.args = args
        self.num_classes = 10
        self.is_ssl_run = is_ssl_run
        self.batch_size = args.batch_size if self.is_ssl_run else args.probe_batch_size
        self.has_test_set = True
        self.logger = logger

        # Options shared by every DataLoader this module creates.
        self.common_dataloader_options = {
            'batch_size': self.batch_size,
            'num_workers': self.args.num_workers,
            'drop_last': False,
        }

        # create ssl transforms
        # everything below matches transforms used in https://gitlab.com/generally-intelligent/ssl_dynamics/-/blob/main/self_supervised/data.py
        # The collate function is only used for SSL runs; probe runs fall back
        # to the DataLoader's default collation (collate_fn=None).
        self.collate_fn = SimCLRCollateFunction(input_size=32, gaussian_blur=0.) if self.is_ssl_run else None

        # Tensor conversion followed by per-channel normalization.
        self.normalize_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465],
                std=[0.2023, 0.1994, 0.2010],
            )
        ])

        # Deterministic evaluation pipeline: resize to 36, center-crop to 32,
        # then normalize.
        self.val_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(36),
            torchvision.transforms.CenterCrop(32),
            self.normalize_transforms
        ])

    def _cifar10_root(self):
        """Return the directory that holds the CIFAR-10 files.

        ``os.path.join`` is used so the result is ``<data_folder>/cifar10``
        whether or not ``args.data_folder`` ends with a path separator.
        """
        return os.path.join(self.args.data_folder, 'cifar10')

    def _split_train_val(self):
        """Load the CIFAR-10 train set and split it into train/val subsets.

        The split is seeded with ``args.seed`` so that every call (e.g. for
        the ``"fit"`` and ``"validate"`` stages) produces the same partition.

        Returns:
            A ``(train_subset, val_subset)`` pair as produced by
            ``torch.utils.data.random_split``.
        """
        full_trainset = torchvision.datasets.CIFAR10(root=self._cifar10_root(), train=True)
        train_subset, val_subset = torch.utils.data.random_split(
            full_trainset,
            self._TRAIN_VAL_SPLIT,
            generator=torch.Generator().manual_seed(self.args.seed)
        )
        return train_subset, val_subset

    # NOTE(review): Lightning is documented to call this from a single process
    # only, which is why the download happens here and not in setup().
    def prepare_data(self) -> None:
        """Make sure both CIFAR-10 splits are downloaded to the data folder."""
        root = self._cifar10_root()
        torchvision.datasets.CIFAR10(root=root, train=True, download=True)   # train
        torchvision.datasets.CIFAR10(root=root, train=False, download=True)  # test

    def setup(self, stage=None, log_images=True) -> None:
        """Create the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` builds the train and validation datasets (plus
                the kNN train dataset for SSL runs), ``"validate"`` builds
                only the validation dataset, ``"test"`` builds only the test
                dataset, and ``None`` builds everything.
            log_images: When true and the train datasets were built, call
                ``log_datasets(self)`` afterwards.

        Raises:
            ValueError: If ``stage`` is any other value (e.g. ``"predict"``).
        """
        if stage not in (None, "fit", "validate", "test"):
            raise ValueError(f"Stage {stage} not supported yet for cifar10")

        builds_train = stage in (None, "fit")

        if builds_train or stage == "validate":
            # make train set and validation set via splitting
            train_subset, val_subset = self._split_train_val()

            self.dataset_val = LightlyDataset.from_torch_dataset(
                DatasetFromSubset(val_subset),
                self.val_transforms
            )

            if builds_train:
                if self.is_ssl_run:
                    # create ssl dataset. transforms are handled by the collate function
                    self.dataset_train = LightlyDataset.from_torch_dataset(train_subset)
                    self.dataset_train_knn = LightlyDataset.from_torch_dataset(
                        DatasetFromSubset(train_subset),
                        self.val_transforms
                    )
                else:
                    # create linear probe dataset
                    self.dataset_train = LightlyDataset.from_torch_dataset(
                        DatasetFromSubset(train_subset),
                        self.val_transforms
                    )

        if stage in (None, "test"):
            self.dataset_test = LightlyDataset.from_torch_dataset(
                torchvision.datasets.CIFAR10(root=self._cifar10_root(), train=False),
                self.val_transforms
            )

        # Logging is tied to the train datasets, exactly as in the "fit" stage.
        if builds_train and log_images:
            log_datasets(self)

    def train_dataloader(self):
        """Return the shuffled training loader (SSL collate for SSL runs)."""
        return torch.utils.data.DataLoader(
            self.dataset_train,
            **self.common_dataloader_options,
            shuffle=True,
            collate_fn=self.collate_fn
        )

    def train_dataloader_knn(self):
        """Return a shuffled loader over the val-transformed training split.

        Only available for SSL runs, since ``dataset_train_knn`` is created
        only in that case.
        """
        assert self.is_ssl_run, "KNN train dataloader only available for SSL runs"
        return torch.utils.data.DataLoader(
            self.dataset_train_knn,
            **self.common_dataloader_options,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the unshuffled validation loader."""
        return torch.utils.data.DataLoader(
            self.dataset_val,
            **self.common_dataloader_options,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return the unshuffled test loader."""
        return torch.utils.data.DataLoader(
            self.dataset_test,
            **self.common_dataloader_options,
            shuffle=False,
        )

    def predict_dataloader(self):
        """Prediction is not supported for this data module."""
        raise NotImplementedError("Predict dataloader not implemented for cifar10")