"""
Contrastive Learning Dataset Wrappers

Provides dataset classes that apply SimCLR-style augmentations
for both standard ANNs and Spiking Neural Networks.
"""

import os
from torchvision.transforms import transforms
from torchvision import datasets
from data_aug.gaussian_blur import GaussianBlur
from data_aug.view_generator import ContrastiveLearningViewGenerator, ContrastiveLearningViewGeneratorSNN


class ContrastiveLearningDataset:
    """
    Dataset wrapper for contrastive learning with ANNs.

    Applies the SimCLR augmentation pipeline to generate multiple views
    of each image for contrastive learning.

    Args:
        root_folder: Path to dataset root
        dataset_name: Name of dataset ('cifar10' or 'tinyimagenet')
    """
    def __init__(self, root_folder, dataset_name):
        self.root_folder = root_folder
        self.dataset_name = dataset_name

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """
        Create SimCLR augmentation pipeline.

        Augmentations (applied in this order):
        - Random resized crop to ``size``
        - Random horizontal flip
        - Color jitter (brightness, contrast, saturation, hue), applied with p=0.8
        - Random grayscale, p=0.2
        - Gaussian blur with kernel size ``int(0.1 * size)``
        - Conversion to tensor

        Args:
            size: Output image size
            s: Strength of color jitter (default: 1)

        Returns:
            Composed transform pipeline
        """
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Blur kernel scales with image size: int(0.1 * 32) == 3, int(0.1 * 64) == 6.
            # NOTE(review): 6 is even; confirm GaussianBlur handles even kernel sizes.
            GaussianBlur(kernel_size=int(0.1 * size)),
            transforms.ToTensor()
        ])
        return data_transforms

    def get_dataset(self, n_views):
        """
        Get the training dataset with contrastive augmentation.

        Args:
            n_views: Number of views to generate per image

        Returns:
            PyTorch Dataset whose transform is a ContrastiveLearningViewGenerator
            wrapping the SimCLR pipeline (32px for CIFAR-10, 64px for Tiny ImageNet).

        Raises:
            ValueError: If ``dataset_name`` is not 'cifar10' or 'tinyimagenet'.
        """
        if self.dataset_name == 'cifar10':
            return datasets.CIFAR10(
                self.root_folder, train=True,
                transform=ContrastiveLearningViewGenerator(
                    self.get_simclr_pipeline_transform(32), n_views
                ),
                download=True
            )
        if self.dataset_name == 'tinyimagenet':
            # Training images are read from <root_folder>/train.
            root = os.path.join(self.root_folder, 'train')
            return datasets.ImageFolder(
                root=root,
                transform=ContrastiveLearningViewGenerator(
                    self.get_simclr_pipeline_transform(64), n_views
                )
            )
        # Previously an unknown name fell through and returned None, which
        # only failed later (e.g. inside the DataLoader) with an unclear error.
        raise ValueError(
            f"Unsupported dataset_name {self.dataset_name!r}; "
            "expected 'cifar10' or 'tinyimagenet'"
        )


class ContrastiveLearningDatasetSNN:
    """
    Dataset wrapper for contrastive learning with SNNs.

    Uses the same SimCLR augmentation pipeline as the ANN wrapper, but wraps
    it in ContrastiveLearningViewGeneratorSNN together with ``timesteps``.
    That generator is intended to produce augmented views across simulation
    timesteps (see its definition for the exact output layout).

    Args:
        root_folder: Path to dataset root
        dataset_name: Name of dataset ('cifar10' or 'tinyimagenet')
        timesteps: Number of simulation timesteps
    """
    def __init__(self, root_folder, dataset_name, timesteps=1):
        self.root_folder = root_folder
        self.dataset_name = dataset_name
        self.timesteps = timesteps

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """
        Create the SimCLR augmentation pipeline.

        Delegates to ``ContrastiveLearningDataset.get_simclr_pipeline_transform``
        so the ANN and SNN wrappers always use an identical pipeline instead of
        two hand-maintained copies.

        Args:
            size: Output image size
            s: Strength of color jitter (default: 1)

        Returns:
            Composed transform pipeline
        """
        return ContrastiveLearningDataset.get_simclr_pipeline_transform(size, s)

    def get_dataset(self, n_views):
        """
        Get the training dataset with temporal contrastive augmentation.

        Args:
            n_views: Number of views per image

        Returns:
            PyTorch Dataset whose transform is a ContrastiveLearningViewGeneratorSNN
            built from the SimCLR pipeline (32px for CIFAR-10, 64px for
            Tiny ImageNet), ``n_views`` and ``self.timesteps``.

        Raises:
            ValueError: If ``dataset_name`` is not 'cifar10' or 'tinyimagenet'.
        """
        if self.dataset_name == 'cifar10':
            return datasets.CIFAR10(
                self.root_folder, train=True,
                transform=ContrastiveLearningViewGeneratorSNN(
                    self.get_simclr_pipeline_transform(32),
                    n_views, self.timesteps
                ),
                download=True
            )
        if self.dataset_name == 'tinyimagenet':
            # Training images are read from <root_folder>/train.
            root = os.path.join(self.root_folder, 'train')
            return datasets.ImageFolder(
                root=root,
                transform=ContrastiveLearningViewGeneratorSNN(
                    self.get_simclr_pipeline_transform(64),
                    n_views, self.timesteps
                )
            )
        # Previously an unknown name fell through and returned None, which
        # only failed later (e.g. inside the DataLoader) with an unclear error.
        raise ValueError(
            f"Unsupported dataset_name {self.dataset_name!r}; "
            "expected 'cifar10' or 'tinyimagenet'"
        )
