"""
Multi-View Generation for Contrastive Learning

Generates multiple augmented views of each image for contrastive learning.
"""

import numpy as np

# Seed NumPy's global RNG once, when this module is imported, so that any
# np.random draws made afterwards are reproducible.
# NOTE(review): this is a process-wide side effect of importing the module,
# and it only touches NumPy's global RNG -- transforms that draw from other
# generators (e.g. Python's `random` module or torch) are not seeded here;
# confirm that is intended.
np.random.seed(seed=0)


class ContrastiveLearningViewGenerator:
    """
    Generate multiple randomly augmented views of each image.

    Each call applies ``base_transform`` to the same input ``n_views``
    times. The resulting views differ only through whatever randomness
    ``base_transform`` draws internally on each call; this class does not
    reseed or otherwise control that randomness.

    Args:
        base_transform: Callable applied to the input image, e.g. a
            torchvision transform pipeline.
        n_views: Number of views to generate per image (default: 2).
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        """
        Generate ``n_views`` augmented versions of the input image.

        Args:
            x: Input image (a PIL Image when used with torchvision transforms).

        Returns:
            List of ``n_views`` results of ``base_transform(x)``, in call
            order. The list is empty when ``n_views`` is 0 or negative.
        """
        # Each call is independent; any diversity between views comes from
        # the transform's own randomness.
        return [self.base_transform(x) for _ in range(self.n_views)]


class ContrastiveLearningViewGeneratorSNN:
    """
    Generate multiple views with per-timestep augmentation for SNNs.

    For each of ``n_views`` views, ``base_transform`` is applied to the same
    input ``timesteps`` times, so every timestep of every view is produced
    by an independent call to the transform. Any diversity between
    timesteps comes from randomness inside ``base_transform``.

    Args:
        base_transform: Callable applied to the input image, e.g. a
            torchvision transform pipeline.
        n_views: Number of views per image (default: 2).
        timesteps: Number of timesteps per view (default: 1).
    """

    def __init__(self, base_transform, n_views=2, timesteps=1):
        self.base_transform = base_transform
        self.n_views = n_views
        self.timesteps = timesteps

    def __call__(self, x):
        """
        Generate temporally augmented views.

        Args:
            x: Input image (a PIL Image when used with torchvision transforms).

        Returns:
            List of ``n_views`` inner lists, each holding ``timesteps``
            results of ``base_transform(x)``. Transform calls are made view
            by view, and within a view timestep by timestep.
        """
        # Outer comprehension iterates views, inner iterates timesteps,
        # preserving view-major call order of base_transform.
        return [
            [self.base_transform(x) for _ in range(self.timesteps)]
            for _ in range(self.n_views)
        ]
