import os

import lightly
import pytorch_lightning as pl
import torch
import torchvision
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.transforms import BYOLTransform

from src.datasets.ImageNetKaggle import ImageNetKaggle
from src.datasets.Visualization import log_datasets

class ImageNetLightlyModule(pl.LightningDataModule):
    """Lightning data module that serves ImageNet wrapped in lightly datasets.

    The module operates in one of two modes, selected by ``is_ssl_run``:

    * SSL run: ``dataset_train`` uses the BYOL transform and an additional
      ``dataset_train_knn`` (probe/KNN transforms) is created.
    * Linear-probe run: ``dataset_train`` uses the probe/KNN transforms and
      no KNN dataset is created.

    In both modes ``dataset_val`` uses the deterministic validation
    transforms. Only the ``"fit"`` stage is supported; there is no test or
    predict split.
    """

    def __init__(self, args, is_ssl_run, logger=None):
        """Store the configuration and build the transform pipelines.

        Args:
            args: Run configuration. This class reads ``batch_size``,
                ``probe_batch_size``, ``num_workers`` and ``data_folder``.
            is_ssl_run: ``True`` for self-supervised pre-training, ``False``
                for linear probing. Selects the batch size and the datasets
                created in :meth:`setup`.
            logger: Optional logger, stored as ``self.logger``. It is not used
                by this class itself (presumably consumed by ``log_datasets``
                -- verify against that helper).
        """
        super().__init__()
        self.args = args
        self.num_classes = 1000
        self.is_ssl_run = is_ssl_run
        self.batch_size = args.batch_size if self.is_ssl_run else args.probe_batch_size
        self.has_test_set = False
        self.logger = logger

        # Options shared by the data loaders. The validation loader overrides
        # ``drop_last`` (see ``val_dataloader``).
        self.common_dataloader_options = {
            'batch_size': self.batch_size,
            'num_workers': self.args.num_workers,
            'drop_last': True,
        }

        # SSL training augmentation: lightly's BYOL transform with defaults.
        self.train_transform = BYOLTransform()

        # Tensor conversion followed by normalization with the ImageNet
        # statistics shipped with lightly.
        self.normalize_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=lightly.data.collate.imagenet_normalize['mean'],
                std=lightly.data.collate.imagenet_normalize['std'],
            ),
        ])

        # Random augmentation used for linear-probe training and for the KNN
        # training set.
        self.probe_knn_train_transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            self.normalize_transforms,
        ])

        # Deterministic resize + center crop used for validation.
        self.val_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            self.normalize_transforms,
        ])

    def _build_dataset(self, split, transform):
        """Return a ``LightlyDataset`` wrapping one ImageNet split.

        Args:
            split: Split name forwarded to ``ImageNetKaggle`` (``'train'`` or
                ``'val'`` in this class).
            transform: Transform handed to ``LightlyDataset``.
        """
        # ``os.path.join`` works whether or not ``data_folder`` ends with a
        # path separator; plain string concatenation silently produced a wrong
        # path (e.g. ``/dataimagenet``) when the separator was missing.
        root = os.path.join(self.args.data_folder, 'imagenet')
        # NOTE(review): a fresh ImageNetKaggle is created for every wrapper on
        # purpose -- lightly may attach the transform to the wrapped dataset,
        # so instances should not be shared. Confirm before caching.
        return LightlyDataset.from_torch_dataset(
            ImageNetKaggle(root=root, split=split),
            transform=transform,
        )

    def prepare_data(self) -> None:
        """No-op: this module performs no download or preprocessing."""
        pass

    def setup(self, stage=None, log_images=True) -> None:
        """Create the training and validation datasets for the ``"fit"`` stage.

        Args:
            stage: Stage being set up. Only ``"fit"`` is supported.
            log_images: When true, call ``log_datasets(self)`` after the
                datasets have been created.

        Raises:
            ValueError: If ``stage`` is anything other than ``"fit"``
                (including the default ``None``).
        """
        if stage != "fit":
            raise ValueError(f"Stage {stage} not supported yet for imagenet")

        if self.is_ssl_run:
            # SSL dataset plus a second view of the training split for KNN.
            self.dataset_train = self._build_dataset('train', self.train_transform)
            self.dataset_train_knn = self._build_dataset(
                'train', self.probe_knn_train_transforms
            )
        else:
            # Linear-probe dataset.
            self.dataset_train = self._build_dataset(
                'train', self.probe_knn_train_transforms
            )

        self.dataset_val = self._build_dataset('val', self.val_transforms)

        if log_images:
            log_datasets(self)

    def train_dataloader(self):
        """Return a shuffled loader over ``dataset_train``."""
        return torch.utils.data.DataLoader(
            self.dataset_train,
            **self.common_dataloader_options,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the complete ``dataset_val``."""
        # Evaluation must see every validation sample. The shared options set
        # ``drop_last=True``, which would silently discard the final partial
        # batch and skew the reported metrics, so it is overridden here.
        options = {**self.common_dataloader_options, 'drop_last': False}
        return torch.utils.data.DataLoader(
            self.dataset_val,
            **options,
            shuffle=False,
        )

    def train_dataloader_knn(self):
        """Return a shuffled loader over ``dataset_train_knn`` (SSL runs only).

        Raises:
            AssertionError: If this module was not created for an SSL run.
        """
        assert self.is_ssl_run, "KNN train dataloader only available for SSL runs"
        return torch.utils.data.DataLoader(
            self.dataset_train_knn,
            **self.common_dataloader_options,
            shuffle=True,
        )

    def test_dataloader(self):
        """Not available: always raises ``NotImplementedError``."""
        raise NotImplementedError("Test dataloader not implemented for imagenet")

    def predict_dataloader(self):
        """Not available: always raises ``NotImplementedError``."""
        raise NotImplementedError("Predict dataloader not implemented for imagenet")


