import torch

class PoissonEncoder:
    """Poisson (Bernoulli-per-step) spike encoder with a mild temporal decay.

    Each input value is treated as a per-step firing probability. At time
    step ``t`` an element fires when a uniform random draw in [0, 1) falls
    below ``value * time_factor(t) * dt``, where
    ``time_factor(t) = 1 - 0.3 * t / num_steps`` decays linearly from 1.0
    toward 0.7. The decay is identical for every element, so larger values
    fire more often at every step; it lowers the overall rate over time but
    does not by itself make larger values fire *earlier*.

    Args:
        num_steps: Number of time steps to generate (size of the leading
            output dimension).
        dt: Multiplier applied to the firing probability at every step.
            Probabilities >= 1 always fire; probabilities <= 0 never fire.
        normalize: If True, rescale the input to [0, 1] using its global
            min/max whenever any value lies outside [0, 1].
    """

    def __init__(self, num_steps=25, dt=1.0, normalize=True):
        self.num_steps = num_steps
        self.dt = dt
        self.normalize = normalize

    def __call__(self, data):
        """Convert ``data`` to a spike train.

        Args:
            data: Tensor (or any array-like accepted by ``torch.as_tensor``)
                of intensities. Integer/bool inputs are converted to float32.

        Returns:
            A float32 tensor of shape ``(num_steps, *data.shape)`` holding 1.0
            where a spike occurred and 0.0 elsewhere.
        """
        data = torch.as_tensor(data)
        # torch.rand / rand_like are only defined for floating dtypes, so
        # integer or bool inputs must be promoted before sampling.
        if not torch.is_floating_point(data):
            data = data.float()

        # min()/max() raise on empty tensors, and there is nothing to rescale.
        if self.normalize and data.numel() > 0:
            data_min = data.min()
            data_max = data.max()
            # The range is taken over the whole tensor, so every element
            # (e.g. every sample of a batch) shares a single scale.
            if data_max > 1.0 or data_min < 0.0:
                data = (data - data_min) / (data_max - data_min + 1e-8)

        num_steps = self.num_steps
        # Per-step decay factors (same Python-float expression as a per-step
        # loop would use), shaped (num_steps, 1, ..., 1) to broadcast over data.
        time_factors = torch.tensor(
            [1.0 - (t / num_steps) * 0.3 for t in range(num_steps)],
            dtype=data.dtype,
            device=data.device,
        ).view((-1,) + (1,) * data.dim())

        spike_prob = data.unsqueeze(0) * time_factors * self.dt
        # One draw for all steps at once instead of num_steps separate calls;
        # this also yields a correctly shaped empty result when num_steps == 0.
        draws = torch.rand(
            (num_steps,) + tuple(data.shape),
            dtype=data.dtype,
            device=data.device,
        )
        return (draws < spike_prob).float()
