from BACKEND import cp, sp

def random_spike_train_event(num_samples: int, num_trains: int,
                             min_spikes: int = 1, max_spikes: int = 10,
                             t_max: float = 1, seed: int = 42, round_digits=None):
    """Generate random event-based spike trains, padded with ``inf``.

    Each of the ``num_samples * num_trains`` trains gets a spike count drawn
    from the integers ``[min_spikes, max_spikes]`` (both inclusive). Spike
    times are drawn from ``[0, t_max)``. Slots beyond a train's spike count
    are set to ``inf``. Each train is sorted, so real spike times come first
    and the padding comes last.

    Args:
        num_samples: Size of the first output axis.
        num_trains: Size of the second output axis.
        min_spikes: Smallest number of spikes per train (inclusive).
        max_spikes: Largest number of spikes per train (inclusive); also the
            length of the last output axis.
        t_max: Scale of the spike times; times lie in ``[0, t_max)`` before
            any rounding.
        seed: Seed passed to ``cp.random.default_rng``.
        round_digits: If not ``None``, spike times are rounded to this many
            decimals; ``0`` rounds to whole numbers. Rounding can move a time
            up to ``t_max`` or make neighbouring times equal.

    Returns:
        Array of shape ``(num_samples, num_trains, max_spikes)``, sorted
        ascending along the last axis, with ``inf`` marking absent spikes.

    Raises:
        ValueError: If ``min_spikes > max_spikes``.
    """
    if min_spikes > max_spikes:
        raise ValueError(
            f"min_spikes ({min_spikes}) must not exceed max_spikes ({max_spikes})")

    rng = cp.random.default_rng(seed)

    # `high` is exclusive in Generator.integers, hence the +1 so that
    # max_spikes itself can be drawn.
    n_spikes = rng.integers(size=(num_samples, num_trains), low=min_spikes, high=max_spikes + 1)
    s = rng.random(size=(num_samples, num_trains, max_spikes), dtype=cp.float32) * t_max

    # Slot k of a train is unused when k >= that train's spike count.
    mask = cp.arange(max_spikes)[None, None, :] >= n_spikes[:, :, None]
    s[mask] = cp.inf

    # Sorting pushes the inf padding to the end of every train.
    s = cp.sort(s, axis=-1)

    # Compare against None explicitly: round_digits=0 is a valid request
    # (round to whole numbers) and must not be treated as "disabled".
    if round_digits is not None:
        s = cp.round(s, round_digits)
    return s

def random_sins(num_samples: int, num_trains: int,
                ts, seed: int = 42):
    """Evaluate randomly parameterized sinusoids on a shared 1-D time grid.

    Every ``(sample, train)`` pair gets an angular frequency in
    ``[4*pi, 12*pi)`` (2 to 6 cycles per unit of ``ts``) and a phase in
    ``[0, 2*pi)``. Frequencies are drawn first and phases second, both from
    ``cp.random.default_rng(seed)``.

    Args:
        num_samples: Size of the first output axis.
        num_trains: Size of the second output axis.
        ts: 1-D time grid. Any array-like accepted by ``cp.asarray`` works,
            including a plain list.
        seed: Seed passed to ``cp.random.default_rng``.

    Returns:
        Array of shape ``(num_samples, num_trains, len(ts))`` holding
        ``sin(ts * freq + phase)``; the dtype follows the backend's promotion
        of ``ts`` with the float32 draws.
    """
    # Coerce to a backend array so that lists (and other array-likes) can be
    # indexed with None below; this is a no-op for backend arrays.
    ts = cp.asarray(ts)
    rng = cp.random.default_rng(seed)
    # (u + 0.5) * 8*pi with u in [0, 1) gives angular frequencies in [4*pi, 12*pi).
    freqs = (rng.random(size=(num_samples, num_trains), dtype=cp.float32) + 0.5) * 8 * cp.pi
    phases = rng.random(size=(num_samples, num_trains), dtype=cp.float32) * 2 * cp.pi
    # Broadcast (1, 1, T) times against (N, M, 1) parameters -> (N, M, T).
    s = cp.sin(ts[None, None, :] * freqs[:, :, None] + phases[:, :, None])
    return s