import torch

class TemporalContrastEncoder:
    """Encode pixel intensities using temporal contrast (latency coding).

    Each element of the input is min-max normalized against the whole
    tensor, then mapped to an onset step: brighter elements start spiking
    earlier. Once an element's onset step is reached it keeps spiking on
    every remaining step (the output is a step function in time, not a
    single spike per element). Elements whose normalized value is not
    strictly above ``threshold`` never spike.

    Args:
        num_steps: Number of time steps in the generated spike train.
        threshold: Normalized intensity an element must exceed in order to
            spike at all.
    """

    def __init__(self, num_steps=25, threshold=0.1):
        self.num_steps = num_steps
        self.threshold = threshold

    def __call__(self, data):
        """Convert ``data`` into a spike train; higher intensities spike earlier.

        Args:
            data: Tensor of intensities, any shape. Normalization uses the
                global min and max over all elements, so in a batched tensor
                every sample is scaled relative to the whole batch.

        Returns:
            A float32 tensor of shape ``(num_steps, *data.shape)`` holding
            0.0/1.0 spikes, on the same device as ``data``.
        """
        # An empty tensor has no min/max to normalize against; return an
        # empty spike train of the expected shape instead of letting min() raise.
        if data.numel() == 0:
            return torch.zeros(
                (self.num_steps, *data.shape),
                dtype=torch.float32,
                device=data.device,
            )

        # Min-max normalize to roughly [0, 1]; the epsilon avoids division by
        # zero when the tensor is constant (everything then maps to 0).
        data = (data - data.min()) / (data.max() - data.min() + 1e-8)

        # Inverse mapping: intensity 1 -> onset step 0, intensity 0 -> onset
        # step num_steps (i.e. never, since steps run 0..num_steps-1).
        spike_times = ((1.0 - data) * self.num_steps).long()

        # Time axis shaped (num_steps, 1, 1, ...) so a single broadcast
        # comparison produces one frame per step; an element fires on every
        # step t >= its onset step.
        steps = torch.arange(self.num_steps, device=data.device).reshape(
            (-1,) + (1,) * data.dim()
        )
        active = data > self.threshold
        return ((steps >= spike_times) & active).float()