import pytorch_lightning as pl
import torchvision
import torch
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from src.datasets.DatasetSubset import DatasetFromSubset
from src.datasets.Visualization import log_datasets

class STL10LightlyModule(pl.LightningDataModule):
    """Lightning data module serving STL10 for SSL pretraining and linear probing.

    The labelled ``train`` split is divided, with a generator seeded from
    ``args.seed``, into a 1000-image validation set and a 4000-image set.
    SSL runs train on the ``unlabeled`` split, with augmentations applied by
    ``SimCLRCollateFunction``, and use the 4000-image set for the kNN loader.
    Probe runs (``is_ssl_run=False``) train on the 4000-image set directly.
    The ``test`` split backs the test loader.

    Args:
        args: run configuration. Attributes read here: ``batch_size``,
            ``probe_batch_size``, ``num_workers``, ``online_classifier``,
            ``data_folder`` and ``seed``.
        is_ssl_run: True for self-supervised pretraining, False for a
            linear-probe run. Selects the batch size, the training dataset
            and the collate function.
        logger: optional logger stored on the instance (presumably consumed
            by ``log_datasets``; confirm against that helper).

    Raises:
        ValueError: if ``args.online_classifier`` is enabled, because the SSL
            training split is unlabelled.
    """

    def __init__(self, args, is_ssl_run, logger=None):
        super().__init__()
        self.args = args
        self.num_classes = 10
        self.is_ssl_run = is_ssl_run
        self.batch_size = args.batch_size if self.is_ssl_run else args.probe_batch_size
        self.has_test_set = True
        self.logger = logger

        if args.online_classifier:
            raise ValueError("Online classifier not supported for stl10 because training set is unlabelled")

        # Options shared by every loader. Evaluation loaders override
        # drop_last via _eval_dataloader_options so no samples are skipped.
        self.common_dataloader_options = {
            'batch_size': self.batch_size,
            'num_workers': self.args.num_workers,
            'drop_last': True,
        }

        # create ssl transforms
        # everything below matches transforms used in https://gitlab.com/generally-intelligent/ssl_dynamics/-/blob/main/self_supervised/data.py
        # For SSL runs the augmentations are applied per batch by the collate
        # function. Probe runs fall back to DataLoader's default collate.
        self.collate_fn = SimCLRCollateFunction(input_size=96, gaussian_blur=0.) if self.is_ssl_run else None

        # NOTE(review): these mean/std values are the commonly quoted CIFAR-10
        # statistics rather than STL10's own; kept to match the reference above.
        self.normalize_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.4914, 0.4823, 0.4466],
                std=[0.247, 0.243, 0.261],
            )
        ])

        # NOTE(review): STL10 images are 96x96, so Resize(124) upsamples before
        # the 96-pixel center crop (a mild zoom-in); kept to match the reference.
        self.val_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(124),
            torchvision.transforms.CenterCrop(96),
            self.normalize_transforms
        ])

    @property
    def _data_root(self):
        """Directory holding the STL10 files.

        Plain string concatenation, as in the original code. This assumes
        ``args.data_folder`` ends with a path separator; confirm.
        """
        return self.args.data_folder + 'stl10'

    # runs on a single thread
    def prepare_data(self) -> None:
        """Make sure the STL10 files are downloaded.

        torchvision ships every STL10 split in a single archive, and its
        integrity check covers all split files. One ``download=True`` call
        therefore fetches train, unlabeled and test alike. 'test' is the split
        loaded here because it is the cheapest to read into memory (unlike
        'train+unlabeled', which loads ~105k images).
        """
        torchvision.datasets.STL10(root=self._data_root, split='test', download=True)

    def setup(self, stage=None, log_images=True) -> None:
        """Build the datasets needed for ``stage``.

        Args:
            stage: 'fit' or 'validate' builds the train/validation datasets
                (plus the kNN dataset for SSL runs). 'test' builds the test
                dataset. None builds all of them.
            log_images: whether to call ``log_datasets`` after the fit
                datasets are built.

        Raises:
            ValueError: for any other stage (e.g. 'predict').
        """
        if stage not in (None, "fit", "validate", "test"):
            raise ValueError(f"Stage {stage} not supported yet for stl10")

        if stage in (None, "fit", "validate"):
            self._setup_fit(log_images)
        if stage in (None, "test"):
            self._setup_test()

    def _setup_fit(self, log_images):
        """Create the validation, training and (SSL only) kNN datasets."""
        labelled_train = torchvision.datasets.STL10(root=self._data_root, split='train')

        # Split the 5000 labelled training images into a validation set and a
        # kNN/probe set. The seeded generator keeps the split identical
        # across runs.
        stl10_dataset_val, stl10_dataset_knn = torch.utils.data.random_split(
            labelled_train,
            [1000, 4000],
            generator=torch.Generator().manual_seed(self.args.seed),
        )

        self.dataset_val = LightlyDataset.from_torch_dataset(
            DatasetFromSubset(stl10_dataset_val),
            self.val_transforms,
        )

        dataset_train_knn_and_probe = LightlyDataset.from_torch_dataset(
            DatasetFromSubset(stl10_dataset_knn),
            self.val_transforms,
        )

        if self.is_ssl_run:
            # SSL training set: no transform here, the collate function
            # applies the augmentations.
            self.dataset_train = LightlyDataset.from_torch_dataset(
                torchvision.datasets.STL10(root=self._data_root, split='unlabeled'),
            )
            self.dataset_train_knn = dataset_train_knn_and_probe
        else:
            # Linear-probe training set.
            self.dataset_train = dataset_train_knn_and_probe

        if log_images:
            log_datasets(self)

    def _setup_test(self):
        """Create the test dataset with the deterministic evaluation transforms."""
        self.dataset_test = LightlyDataset.from_torch_dataset(
            torchvision.datasets.STL10(root=self._data_root, split='test'),
            transform=self.val_transforms,
        )

    def _eval_dataloader_options(self):
        """Return loader options for evaluation that keep every sample.

        The shared options use ``drop_last=True``. For evaluation that
        silently skips up to ``batch_size - 1`` samples. With a batch size
        larger than the 1000-image validation set, it would yield no
        validation batches at all.
        """
        return {**self.common_dataloader_options, 'drop_last': False}

    def train_dataloader(self):
        """Shuffled training loader. SSL runs use the SimCLR collate function."""
        return torch.utils.data.DataLoader(
            self.dataset_train,
            **self.common_dataloader_options,
            shuffle=True,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        """Validation loader over the full 1000-image validation split."""
        return torch.utils.data.DataLoader(
            self.dataset_val,
            **self._eval_dataloader_options(),
            shuffle=False,
        )

    def train_dataloader_knn(self):
        """Loader over the 4000-image labelled subset used by the kNN monitor (SSL runs only)."""
        # dataset_train_knn is only created by _setup_fit for SSL runs.
        assert self.is_ssl_run, "KNN train dataloader only available for SSL runs"
        return torch.utils.data.DataLoader(
            self.dataset_train_knn,
            **self.common_dataloader_options,
            shuffle=True,
        )

    def test_dataloader(self):
        """Test loader over the full STL10 'test' split."""
        return torch.utils.data.DataLoader(
            self.dataset_test,
            **self._eval_dataloader_options(),
            shuffle=False,
        )

    def predict_dataloader(self):
        raise NotImplementedError("Predict dataloader not implemented for stl10")