from BACKEND import cp

def bin_spike_times(
        spikes: cp.ndarray,
        n_bins: int,
        tmax: float = 1,
        tmin: float = 0,
        mode: str = 'count',  # 'binary' or 'count'
        reshape: bool = True,
        dtype=cp.float32,
) -> cp.ndarray:
    """
    Vectorized binning of event times into per-bin occupancy or counts.

    Parameters
    ----------
    spikes : cp.ndarray
        Array of shape (N, N_trains, T) containing event times. Times outside
        ``[tmin, tmax)`` are ignored. (With a NumPy backend, NaN entries, e.g.
        padding of ragged trains, also fall outside and are ignored; confirm
        for other backends.)
    n_bins : int
        Number of equal-width bins spanning ``[tmin, tmax)``. Must be >= 1.
    tmax : float
        Upper edge of the last bin (exclusive).
    tmin : float
        Lower edge of the first bin (inclusive). Must be < ``tmax``.
    mode : {'binary', 'count'}
        If 'binary', a bin is 1 if at least one event fell into it, else 0.
        If 'count', a bin holds the number of events that fell into it.
    reshape : bool
        If True, return a 2-D array of shape (n_bins * N, N_trains) whose row
        ``b * N + n`` holds bin ``b`` of unit ``n``. If False, return an array
        of shape (N, N_trains, n_bins).
    dtype : dtype, optional
        Dtype of the output buffer. Defaults to float32.

    Returns
    -------
    B : cp.ndarray
        Binned events with dtype ``dtype``, shaped as described under
        ``reshape``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'binary' or 'count', if ``n_bins < 1``, or if
        ``tmax <= tmin``.
    """
    if mode not in ('binary', 'count'):
        raise ValueError(f"mode must be 'binary' or 'count', got {mode!r}")
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if not tmax > tmin:  # also rejects NaN bounds
        raise ValueError(
            f"tmax must be greater than tmin, got tmin={tmin}, tmax={tmax}"
        )

    N, N_trains, T = spikes.shape

    # linspace pins the last edge to exactly tmax (interior edges are the same
    # tmin + i * dt values). Computing tmin + dt * arange can land one ulp
    # short of tmax (e.g. 49 * (1 / 49) == 0.9999999999999999), which would
    # silently drop events just below tmax.
    edges = cp.linspace(tmin, tmax, n_bins + 1)

    # digitize(right=False) returns i with edges[i-1] <= t < edges[i];
    # subtract 1 to get a 0-based bin index.
    bin_idx = cp.digitize(spikes, edges, right=False) - 1

    # -1 means t < tmin and n_bins means t >= tmax: both are dropped.
    valid = (bin_idx >= 0) & (bin_idx < n_bins)

    # Flatten (N, N_trains) -> M rows for a single vectorized scatter.
    M = N * N_trains
    flat_bins = bin_idx.reshape(M, T)
    flat_valid = valid.reshape(M, T)

    B = cp.zeros((M, n_bins), dtype=dtype)

    # Coordinates of all in-range events and the bin each one falls into.
    row_idx, time_idx = cp.nonzero(flat_valid)
    b_idx = flat_bins[row_idx, time_idx]

    if mode == 'binary':
        # Duplicate (row, bin) pairs just write 1 again.
        B[row_idx, b_idx] = 1
    else:
        # Unbuffered scatter-add so repeated (row, bin) pairs accumulate;
        # a plain fancy-indexed `+=` would count each pair only once.
        cp.add.at(B, (row_idx, b_idx), 1)

    B = B.reshape(N, N_trains, n_bins)
    if reshape:
        # (N, N_trains, n_bins) -> (n_bins, N, N_trains) -> (n_bins * N, N_trains).
        # The explicit row count avoids an ambiguous -1 when N_trains == 0.
        B = B.transpose((2, 0, 1)).reshape(n_bins * N, N_trains)
    return B
