import logging
import numpy as np
from typing import Sequence, Optional, List, Tuple
from itertools import repeat, groupby


_logger = logging.getLogger(__name__)


def gen(p: int, n: int, start_pos, vel):
    """
    Generate a sequence of "crossing zero" events for a cyclic group.

    Slow, but simple: time is advanced in unit steps.

    The cyclic group is a product of 1-dimensional cyclic groups, all with
    the same number of elements, `p`. Elements are placed in `start_pos` and
    incremented by `vel` at every step. If an element value reaches or
    leaves the range `[0, p)` in any dimension, it wraps around (modulo `p`)
    and produces a "crossing zero" event.

    Args:
        p: the number of elements in each 1-dimensional cyclic group
        n: number of events to generate
        start_pos: initial position of each element; every value must lie
            in `[0, p)`
        vel: velocity vector of each element (broadcast against `start_pos`)
    Returns:
        (ts, ids): where `ts` is the list of event times (integer steps,
            starting at 1) and `ids` is the list of sets of first-axis
            indices of `start_pos` that crossed zero at each event time
            (possibly multiple IDs per event time). For a 2D `start_pos`
            of shape (element, dim) these are element IDs.
    Raises:
        ValueError: if any start position lies outside `[0, p)`, or if
            `n > 0` and every velocity component is zero (no event could
            ever occur).
    """
    prev = np.array(start_pos)
    if np.any(prev < 0) or np.any(prev >= p):
        raise ValueError("start_pos must lie in [0, p).")
    vel = np.array(vel)
    if n > 0 and not np.any(vel):
        # Nothing moves, so no crossing ever happens and the stepping loop
        # below would never terminate.
        raise ValueError("vel must have at least one non-zero component.")
    ts = []
    ids = []
    t = 0
    while len(ts) < n:
        t += 1  # No collisions at t=0
        before_mod = prev + vel
        after_mod = before_mod % p
        # A value changed by the modulo means it wrapped around this step.
        # [0] keeps the first-axis index, so a 2D input reports element IDs.
        crossed = np.nonzero(before_mod != after_mod)[0]
        if crossed.size:
            ts.append(t)
            ids.append(set(crossed.tolist()))
        prev = after_mod
    return (ts, ids)


def gen_continuous(p: float, n: int, start_pos, vel):
    """
    Generate a sequence of "crossing zero" events for a continuous cyclic group.

    The cyclic group is a product of 1-dimensional cyclic groups, all with
    the same length, `p`. Elements are placed in `start_pos` and move with
    constant velocity `vel`. When an element's coordinate reaches `p` (and
    every multiple of `p` after that) in any dimension, it wraps around and
    produces a "crossing zero" event. Event times are computed analytically
    as `(k * p - x) / v` for `k = 1..n`.

    Args:
        p: the length of the dimensions (all the same)
        n: number of events to generate
        start_pos: initial position of each element, array-like of shape
            (element, dim) with every coordinate in `[0, p)`
        vel: velocity of each element, same shape as `start_pos`, with every
            component non-negative; zero components never cross
    Returns:
        (ts, ids): `ts` is a tuple holding the first `n` event times in
            strictly increasing order. `ids` is a tuple holding, for each
            event time, a sorted tuple of the element IDs (row indices of
            `start_pos`) that crossed zero at that time. Elements crossing
            at exactly the same time are joined into a single event. Both
            are empty when `n <= 0`.
    Raises:
        ValueError: if `start_pos` is not 2D, if `start_pos` and `vel`
            differ in shape, if a position lies outside `[0, p)`, if a
            velocity component is negative, or if `n > 0` and no velocity
            component is positive.
    """
    # Convert first so list inputs get the same shape checks as arrays.
    start_pos = np.asarray(start_pos)
    vel = np.asarray(vel)
    if start_pos.ndim != 2:
        raise ValueError("start_pos must be a 2D array: (element, pos).")
    if start_pos.shape != vel.shape:
        raise ValueError("start_pos and vel must have the same shape.")
    if np.any(start_pos < 0) or np.any(start_pos >= p):
        raise ValueError("start_pos must lie in [0, p).")
    if np.any(vel < 0):
        raise ValueError("vel must be non-negative.")
    if n <= 0:
        return ((), ())
    if not np.any(vel > 0):
        raise ValueError("vel must have at least one positive component.")

    # (time, element ID) for every face crossing of every moving component.
    # A single component with positive velocity already contributes n
    # distinct times, so the first n joined events are always available.
    faces = np.arange(1, n + 1) * p
    events: List[Tuple[float, int]] = []
    for obj_id, d in zip(*np.nonzero(vel > 0)):
        times = (faces - start_pos[obj_id, d]) / vel[obj_id, d]
        events.extend(zip(times, repeat(int(obj_id))))

    def by_time(event):
        return event[0]

    events.sort(key=by_time)
    # Join events sharing a time: [(5, 0), (5, 0), (5, 1)] -> (5, (0, 1)).
    events_joined: List[Tuple[float, Tuple[int, ...]]] = [
        (t, tuple(sorted({obj for _, obj in group})))
        for t, group in groupby(events, key=by_time)
    ]
    ts, ids = zip(*events_joined[:n])
    return (ts, ids)


if __name__ == "__main__":
    # Demo: one element moving through a 12-dimensional cyclic group.
    rng = np.random.default_rng(123)
    n_samples = int(1e5)
    p = 1021
    n_objs = 1
    max_step = 5
    n_dim = 12
    max_tries = 4
    start_pos = rng.integers(0, p, size=(n_objs, n_dim))
    # Redraw velocities until every element moves in at least one dimension;
    # an element whose velocity is zero everywhere never crosses zero.
    for _attempt in range(max_tries):
        vel = rng.binomial(n=max_step, p=0.3, size=(n_objs, n_dim))
        if np.all(vel.sum(axis=1) != 0):
            break
    else:
        raise ValueError("Could not generate non-zero velocity vectors.")
    ts, ids = gen(p, n_samples, start_pos, vel)
