import numpy as np
import torch

def gen_labits(xs, ys, ts, ps=None, framesize=None, num_bins=5, t_range=None, t_start=0, norm=True, events_extended_backward=True):
    """
    Generate the Layered BIdirectional Time Surface (LaBiTS) from the events (xs, ys, ts).

    The time span is probed at `num_bins` time points spaced `t_range` apart. For each
    probe time and each pixel, the layer stores the signed offset (event time minus probe
    time) of the most recent event at that pixel in the preceding window. If there is no
    such past event, the earliest event at that pixel in the following window of length
    `t_range` is used instead. Pixels with no event in either window get `-t_range`.

    Args:
        xs: x coordinates (column indices) of events, numpy array, shape: (N,)
        ys: y coordinates (row indices) of events, numpy array, shape: (N,)
        ts: timestamps of events, a sorted (ascending) numpy array, shape: (N,)
        ps: polarities of events. Accepted for interface compatibility; not used.
        framesize: the size of the CMOS, tuple, shape: (H, W)
        num_bins: the number of layers (probe time points), int
        t_range: the search window length, float. If None, it is derived from the span
            of `ts`: divided into `num_bins+1` parts when `events_extended_backward`
            is True, otherwise into `num_bins` parts.
        t_start: absolute timestamp at which the targeted span starts, float
        norm: whether to divide the result by `t_range`. Empty cells then become -1, bool.
        events_extended_backward: Boolean. Whether the input events (xs, ys, ts) already
            include an additional `t_range` of events before the targeted `t_start`.
            If True, the event time span is divided into (`num_bins+1`) parts, and the
            layers are taken at the `num_bins` intermediate time points
            (`t_start + k*t_range`, k = 1..num_bins). Otherwise the time span is divided
            into `num_bins` parts, with `t_start` itself included as the first probe time.
    Returns:
        labits: numpy array, shape: (num_bins, H, W)
    """
    assert framesize is not None, "size of the frame must be provided"
    assert len(framesize) == 2, "framesize must be a tuple of 2 integers"

    H, W = framesize
    B = num_bins

    # Resolve the default t_range first: the first probe time below depends on it.
    if t_range is None:
        t_range = (ts[-1] - ts[0]) / (B + 1) if events_extended_backward else (ts[-1] - ts[0]) / B

    # Every layer is overwritten in the loop below; -inf is only the initial fill.
    labits = np.full((num_bins, H, W), -np.inf)

    # Absolute timestamp of the first probe point. With the backward extension, the
    # first t_range of events only serves as "past" context for the first layer.
    t_cur = t_range + t_start if events_extended_backward else t_start
    t_cur = t_cur.item() if isinstance(t_cur, torch.Tensor) else t_cur

    if events_extended_backward:
        # Events at or before the first probe time form the initial "past" set.
        # Search for the absolute probe time (not the window length t_range).
        cur_idx = np.searchsorted(ts, t_cur, side='right')
        past_indices = np.arange(0, cur_idx)
    else:
        cur_idx = 0
        past_indices = np.array([], dtype=int)

    for layer in range(B):
        # t_cur may become a tensor again if t_range is a tensor.
        t_cur = t_cur.item() if isinstance(t_cur, torch.Tensor) else t_cur
        # -inf marks pixels that have not been set by a past event yet.
        labits_bin = np.full((H, W), -np.inf)

        # Future window: events from cur_idx up to t_cur + t_range (ts is sorted).
        after_idx = np.searchsorted(ts[cur_idx:], t_cur + t_range, side='right') + cur_idx
        future_indices = np.arange(cur_idx, after_idx)

        # Past events: keep the most recent one per pixel, i.e. the largest
        # (least negative) offset.
        if len(past_indices) > 0:
            np.maximum.at(labits_bin, (ys[past_indices], xs[past_indices]), ts[past_indices] - t_cur)

        # Future events: the earliest one per pixel, used only where no past event exists.
        if len(future_indices) > 0:
            future_labits_bin = np.full_like(labits_bin, np.inf)
            np.minimum.at(future_labits_bin, (ys[future_indices], xs[future_indices]), ts[future_indices] - t_cur)
            mask = np.isinf(labits_bin)
            labits_bin[mask] = future_labits_bin[mask]

        # Pixels with no event in either direction get -t_range.
        labits_bin[np.isinf(labits_bin)] = -t_range
        labits[layer] = labits_bin

        # Slide to the next probe time: the current future window becomes the past.
        t_cur += t_range
        cur_idx = after_idx
        past_indices = future_indices

    # Normalize by the window length; empty cells become -1.
    if norm:
        labits = labits / t_range

    return labits

