import torch

def bin_spike_times(
    spikes: torch.Tensor,
    n_bins: int,
    tmax: float = 1.0,
    tmin: float = 0.0,
    mode: str = 'count',  # 'binary' or 'count'
    reshape: bool = True
) -> torch.Tensor:
    """
    Bin event times into per-train occupancy indicators or counts.

    The interval ``[tmin, tmax)`` is split into ``n_bins`` equal-width,
    half-open bins ``[edge_b, edge_{b+1})``. Events that fall outside
    ``[tmin, tmax)`` are ignored, so out-of-range values can be used as
    padding for trains with fewer than ``T`` events.

    Parameters
    ----------
    spikes : torch.Tensor
        Tensor of shape (N, N_trains, T) containing event times.
    n_bins : int
        Number of equal-width bins; must be >= 1.
    tmax : float
        Upper edge of last bin (exclusive). Should be greater than `tmin`,
        since the bin edges must be ascending for the lookup to be valid.
    tmin : float
        Lower edge of first bin (inclusive).
    mode : {'binary', 'count'}
        If 'binary', an entry is 1.0 if >= 1 event fell into the bin,
        else 0.0. If 'count', an entry is the number of events in the bin.
    reshape : bool
        If True, return a 2-D design-matrix layout (see Returns);
        otherwise keep the 3-D layout.

    Returns
    -------
    output : torch.Tensor
        float32 tensor on the same device as `spikes`.
        If ``reshape`` is False: shape (N, N_trains, n_bins).
        If ``reshape`` is True: shape (n_bins * N, N_trains), where row
        ``b * N + n`` holds bin ``b`` of sample ``n`` for every train.

    Raises
    ------
    ValueError
        If `mode` is not 'binary' or 'count', or if `n_bins` < 1.
    """
    # Fail fast on bad arguments, before doing any tensor work.
    if mode not in ('binary', 'count'):
        raise ValueError("Invalid mode. Use 'binary' or 'count'.")
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer.")

    N, N_trains, T = spikes.shape

    # n_bins + 1 equally spaced edges spanning [tmin, tmax].
    edges = torch.linspace(tmin, tmax, n_bins + 1, device=spikes.device)

    # NOTE: torch.bucketize's `right` flag is the opposite of numpy.digitize's.
    # right=True returns i with edges[i-1] <= v < edges[i], so after the -1
    # shift an event at exactly `tmin` lands in bin 0 and an event at exactly
    # `tmax` gets index n_bins, which the mask below rejects.
    bin_idx = torch.bucketize(spikes, edges, right=True) - 1

    # Keep only events that landed in a real bin (drops v < tmin and v >= tmax).
    valid = (bin_idx >= 0) & (bin_idx < n_bins)

    # Collapse (N, N_trains) into a single row axis so one scatter handles all trains.
    M = N * N_trains
    flat_bins = bin_idx.reshape(M, T)
    flat_valid = valid.reshape(M, T)

    output = torch.zeros((M, n_bins), device=spikes.device, dtype=torch.float32)

    # (row, bin) coordinates of every valid event.
    row_idx, time_idx = flat_valid.nonzero(as_tuple=True)
    b_idx = flat_bins[row_idx, time_idx]

    if mode == 'binary':
        # Duplicate coordinates are harmless: every write stores the same 1.0.
        output[row_idx, b_idx] = 1.0
    else:
        # accumulate=True sums contributions of repeated (row, bin) pairs,
        # which plain indexed assignment would collapse to a single write.
        output.index_put_(
            (row_idx, b_idx),
            torch.ones_like(b_idx, dtype=output.dtype),
            accumulate=True,
        )

    output = output.view(N, N_trains, n_bins)
    if reshape:
        # (N, N_trains, n_bins) -> (n_bins, N, N_trains) -> (n_bins * N, N_trains)
        output = output.permute(2, 0, 1).reshape(-1, N_trains)

    return output
