"""Defines the discrete-time dynamics of the marble + line-segment obstacles
system. These could easily be included in the simulation's main code, but are
instead kept in a separate module to minimize the amount of code in the
simulation, which is intended as an illustrative example of pylic's
capabilities."""
import torch
from dataclasses import dataclass
from dataclasses import replace
from pylic.predicates import less_than
from pylic.predicates import greater_than
from pylic.predicates import conjunction
from pylic.tape import IfNode
from pylic.code_transformations import get_tape


@dataclass(eq=True, frozen=True)
class Circle:
    """The marble: a circle with a position, a velocity and a radius.

    Instances are immutable; the dynamics derive updated copies with
    `dataclasses.replace`.

    NOTE(review): the generated `__eq__` compares the tensor fields with
    `==`, which can raise for distinct multi-element tensors. Avoid relying
    on `Circle` equality.
    """

    # Current velocity vector of the marble.
    velocity: torch.Tensor
    # Current position of the marble's centre.
    position: torch.Tensor
    # Radius of the marble.
    radius: float


@dataclass(frozen=True)
class Segment:
    """A straight obstacle between `p1` and `p2` with thickness `radius`.

    Equality ignores orientation: a segment compares equal to the same
    segment with its end points swapped, provided the radii match.
    `__hash__` is defined to agree with that equality, so a segment and its
    reversed copy collapse to a single entry in sets and dict keys.
    """

    # First end point.
    p1: tuple[float, float]
    # Second end point.
    p2: tuple[float, float]
    # Thickness of the obstacle; the collision code adds it to the marble's
    # radius to obtain the contact distance.
    radius: float

    @property
    def normal(self) -> torch.Tensor:
        """Unit normal of the segment's supporting line.

        Computed by normalizing `p1 - p2` and rotating it by +90 degrees.
        The normal of a straight segment is the same everywhere along it.
        NOTE(review): a degenerate segment (`p1 == p2`) yields NaNs.
        """
        a = (torch.tensor(self.p1)-torch.tensor(self.p2))
        an = a/a.norm()
        return torch.stack([-an[1], an[0]])

    def __eq__(self, other) -> bool:
        """Orientation-independent equality (same radius, same end points
        in either order). Non-`Segment` objects are never equal."""
        if not isinstance(other, Segment):
            return False
        if self.radius != other.radius:
            return False
        if self.p1 == other.p1 and self.p2 == other.p2:
            return True
        if self.p2 == other.p1 and self.p1 == other.p2:
            return True
        return False

    def __hash__(self) -> int:
        # Without this, @dataclass(frozen=True) would generate a hash over
        # (p1, p2, radius) in order. A segment and its reverse would then
        # compare equal yet hash differently, which breaks sets and dicts.
        # Hash the end points as an unordered pair so hash agrees with
        # __eq__.
        return hash((frozenset((self.p1, self.p2)), self.radius))


@dataclass(eq=True, frozen=True)
class State:
    """Immutable snapshot of the whole marble system.

    Bundles the marble, the scalar parameters read by the dynamics on every
    step, and the obstacle segments. The next state is derived with
    `dataclasses.replace` rather than by mutation.
    """

    # The marble being simulated.
    marble: Circle
    # Time-step size passed to the per-step dynamics.
    dt: float
    # Factor applied to each clipped action to obtain the applied force.
    impulse_scale: float
    # Linear drag coefficient.
    drag_constant: float
    # Factor applied to the marble's speed when it bounces off a segment.
    coefficient_of_restitution: float
    # Obstacles the marble can collide with.
    segments: tuple[Segment, ...]


def collision_dynamics(
        s: State,
        i: int,
        dt: float,
        coefficient_of_restitution: float
        ) -> State:
    """Return the next state assuming a collision with the i-th obstacle.

    The velocity is mirrored about the segment's line: the normal component
    is reversed, the tangential component is kept, and the speed is scaled
    by `coefficient_of_restitution`. The position is advanced along the
    pre-collision velocity by a sub-step found with a coarse line search
    over `iter_n` evenly spaced candidates in `[0, dt]`.

    No external force is applied here. The part of the step after the chosen
    sub-step is not simulated.

    Args:
        s: Current system state.
        i: Index into `s.segments` of the obstacle being hit.
        dt: Full time-step size.
        coefficient_of_restitution: Factor applied to the speed on bounce.

    Returns:
        A copy of `s` whose marble has the new position and velocity.
    """
    normal = s.segments[i].normal
    # NOTE(review): a zero velocity would give NaN here. `simulation` only
    # calls this when the predicate includes velocity·delta > 0, so the
    # velocity is presumably non-zero in practice.
    incident = s.marble.velocity/s.marble.velocity.norm()
    # 2n(n·d) - d is the *negated* mirror image of d across the segment's
    # line. The `*-1.0` when building the velocity below undoes the negation.
    reflected_direction_ = (
        (normal*2.0*normal.dot(incident))-incident
    )
    reflected_direction = reflected_direction_/reflected_direction_.norm()

    # Approximate the time of impact with a line search instead of solving
    # for it analytically, because that makes it easy to account for the
    # obstacle's thickness. Candidates are dt*k/(iter_n-1) for
    # k = 0..iter_n-1. The chosen sub-step is the last candidate before the
    # first one at which the marble overlaps the segment's (infinite,
    # unclamped) supporting line. It stays 0.0 if the first or second
    # candidate already overlaps, or if the segment is degenerate.
    new_dt = 0.0
    iter_n = 4
    min_dist = s.marble.radius + s.segments[i].radius
    seg_a = torch.tensor(s.segments[i].p1)
    seg_b = torch.tensor(s.segments[i].p2)
    seg_delta = seg_b - seg_a
    # Squared length (despite the name). The projection parameter onto the
    # line a + t*(b - a) is t = (b - a)·(p - a) / |b - a|^2, so the squared
    # norm is the correct denominator.
    seg_delta_length = seg_delta.norm()**2
    # Skip the search for (near-)zero-length segments to avoid dividing by ~0.
    if seg_delta_length > 1e-6:
        for i_dt in range(iter_n):
            new_dt_candidate = dt*i_dt/(iter_n-1)
            circle_next_position = s.marble.position + new_dt_candidate*s.marble.velocity
            closest_t = seg_delta.dot(circle_next_position-seg_a)/seg_delta_length
            closest = seg_a + seg_delta*closest_t
            delta = closest - circle_next_position
            # Signed clearance: <= 0 means the marble touches/overlaps.
            dist = delta.norm() - min_dist
            if dist <= 0.0:
                break
            # Only accept a candidate after verifying it is collision-free.
            new_dt = new_dt_candidate

    # Equals (v - 2(n·v)n) * coefficient_of_restitution.
    circle_next_velocity = reflected_direction*-1.0*s.marble.velocity.norm()*coefficient_of_restitution
    # Move with the pre-collision velocity up to the chosen sub-step.
    circle_next_position = s.marble.position + new_dt*s.marble.velocity
    circle_next = replace(
        s.marble,
        position=circle_next_position,
        velocity=circle_next_velocity,
    )
    s_next = replace(
        s,
        marble=circle_next,
    )
    return s_next


def collision_soft_predicate(s: State, dt: float, i: int) -> torch.Tensor:
    """Soft collision test between the marble and segment `i`.

    The marble is advanced by one step of size `dt` along its current
    velocity. The predicted centre is then projected onto the line through
    the segment.

    Returns a value greater than zero if the marble will collide with the
    i-th obstacle at the next time-step, and a value not greater than zero
    otherwise. The individual conditions combined are:

    * the predicted centre is closer to the line than the sum of the marble
      and segment radii;
    * the projection lies strictly between the end points (0 < t < 1), so
      contacts with the end-caps are ignored;
    * the marble is moving towards the line.

    Caveat: only the end of the step is examined. A marble fast enough to
    pass completely through an obstacle within one step is not detected
    (time-discretization "tunnelling"); lower `dt` if that happens.
    """
    marble = s.marble
    obstacle = s.segments[i]
    predicted_centre = marble.position + dt*marble.velocity
    start = torch.tensor(obstacle.p1)
    end = torch.tensor(obstacle.p2)
    direction = end - start
    # Projection parameter of the predicted centre onto start + t*direction.
    # The epsilon keeps zero-length segments from dividing by zero.
    t = direction.dot(predicted_centre - start)/(1e-6 + direction.norm()**2)
    foot = start + direction*t
    contact_distance = marble.radius + obstacle.radius
    # Vector from the predicted centre to its closest point on the line.
    to_line = foot - predicted_centre
    gap = to_line.norm()

    overlapping = less_than(gap, contact_distance)
    # Positive when the velocity points towards the line.
    approaching = greater_than(
        marble.velocity.dot(to_line),
        torch.tensor(0.0),
    )
    after_start = less_than(torch.tensor(0.0), t)
    before_end = less_than(t, torch.tensor(1.0))
    return conjunction(overlapping, after_start, before_end, approaching)


def normal_dynamics_force(
        s: State,
        force: torch.Tensor,
        drag_constant: float,
        dt: float,
        ) -> State:
    """Return the next state assuming no collision occurs during this step.

    Performs one explicit-Euler step. The position advances with the
    velocity at the *start* of the step. The velocity is updated with the
    applied `force` plus a linear drag term `-drag_constant * velocity`.

    Args:
        s: Current system state.
        force: Force applied to the marble during this step. The marble has
            unit mass, so this is also its acceleration.
        drag_constant: Linear (Stokes) drag coefficient.
        dt: Time-step size.

    Returns:
        A copy of `s` whose marble has the updated position and velocity.
        All other fields are unchanged.
    """
    velocity = s.marble.velocity
    # The circle is of mass one, so a = F.
    acc = force
    # We assume the marble moves at low speed with no turbulence, i.e.
    # Stokes' (linear) drag. Applied to the whole vector at once, which
    # matches the previous per-component form in 2-D and also works for any
    # dimensionality.
    drag = -velocity*drag_constant
    circle_next_velocity = velocity + acc*dt + drag*dt
    # Explicit Euler: use the pre-update velocity for the position.
    circle_next_position = s.marble.position + velocity*dt
    circle_next = replace(
        s.marble,
        position=circle_next_position,
        velocity=circle_next_velocity,
    )
    s_next = replace(
        s,
        marble=circle_next,
    )
    return s_next


def simulation(
        actions: torch.Tensor,
        state: State,
        ) -> State:
    """Run the simulation for one step per entry of `actions` and return the
    final state.

    The control is open-loop: `actions` is a fixed sequence that does not
    depend on the state. Each action is clipped to [-1, 1] and scaled by
    `state.impulse_scale`, then applied as the force of a collision-free
    step.

    Next, every segment is tested with the soft collision predicate against
    the state at the *start* of the step. When the predicate is positive,
    the collision-free result is replaced by that segment's collision
    response. If several segments report a collision in the same step, the
    one with the highest index wins, because `next_state` is overwritten.

    This function's source is traced by pylic's `get_tape`, and
    `get_segments_hit` reads the loop variable `i` from the recorded branch.
    Keep the loop variable name and the `if` structure unchanged.
    """
    for t in range(len(actions)):
        # Compute next state with no collisions
        # (actions are clipped to [-1, 1], then scaled into a force).
        action = actions[t].clip(min=-1, max=1)*state.impulse_scale
        next_state = normal_dynamics_force(
            state,
            action,
            state.drag_constant,
            state.dt
        )

        # Check for collisions (against the pre-step state).
        for i in range(len(state.segments)):
            v = collision_soft_predicate(state, state.dt, i)
            if v > 0:  # ID: collision_check
                # If collided, adjust next state (discarding the
                # collision-free result computed above).
                next_state = collision_dynamics(
                    state,
                    i,
                    state.dt,
                    state.coefficient_of_restitution
                )

        # Update state
        state = next_state
    return state


def get_trajectory(
        actions: torch.Tensor,
        state: State,
        ) -> list[State]:
    """Roll the simulation forward one action at a time.

    Returns every visited state: first the given initial `state`, then one
    state per action. The result therefore has `len(actions) + 1` entries.
    """
    current = state
    visited = [current]
    for step_action in actions:
        # `simulation` consumes a sequence of actions, so wrap this single
        # action in a length-1 leading dimension.
        current = simulation(step_action.unsqueeze(0), current)
        visited.append(current)
    return visited


def get_segments_hit(
        actions: torch.Tensor,
        state: State,
        ) -> list[int]:
    """Return the index of each segment the marble collided with.

    The simulation is re-traced with pylic's `get_tape`. Every recorded
    `IfNode` whose value is positive contributes the segment index `i` that
    was in scope at that branch. Indices appear in tape order, presumably
    chronological. A segment hit on several steps appears once per hit.

    NOTE(review): this assumes every `IfNode` on the tape has `i` in scope,
    which holds while `simulation` contains a single `if`.
    """
    hits: list[int] = []
    for node in get_tape(simulation, None, actions=actions, state=state):
        if isinstance(node, IfNode) and float(node.value) > 0:
            hits.append(dict(node.variables_in_scope)["i"])
    return hits
