"""Tasks abstract the concept of "finding the parameters of a system
such that its 'Marble' reaches a target position".
"""
from dataclasses import dataclass
import json
import random
from trajectory_dynamics import State
from trajectory_dynamics import Segment
from trajectory_dynamics import Circle
from trajectory_dynamics import get_trajectory
from trajectory_dynamics import get_segments_hit
from maze import get_serpent
from maze import Maze
from maze import Direction
from maze import get_random_maze
import jsons
import torch
import math


@dataclass(eq=True, frozen=True)
class Goal:
    """Immutable target circle that the marble has to reach."""
    # Center of the goal circle.
    position: torch.Tensor
    # Radius of the goal circle.
    radius: float


def rotate_radians(
        v: torch.Tensor,
        radians: torch.Tensor,
        ) -> torch.Tensor:
    """Return `v` rotated counter-clockwise by `radians`.

    `v[0]` and `v[1]` are treated as the x and y components; the rotated
    components are stacked along a new first axis.
    """
    cos_t = radians.cos()
    sin_t = radians.sin()
    x, y = v[0], v[1]
    rotated_x = cos_t*x - sin_t*y
    rotated_y = sin_t*x + cos_t*y
    return torch.stack((rotated_x, rotated_y))


@dataclass(frozen=True, eq=True)
class TaskImpulse:
    """A Task is defined by an initial system state (marble position and
    velocity, wall segments and physics constants), a goal circle, a maximum
    number of time-steps and a sequence of checkpoints.

    The task is to find a sequence of actions (presumably one 2D impulse per
    time-step; see `random_parameters`) such that the marble reaches the goal
    circle within `max_timesteps` time-steps.

    It is assumed that `max_timesteps` is greater than 1.
    """
    # Target circle the marble must reach.
    goal_circle: Goal
    # System state at time-step 0 (marble, segments, physics constants).
    initial_state: State
    # Time-step budget for reaching the goal.
    max_timesteps: int
    # Waypoints along the nominal route from the start to the goal.
    checkpoints: tuple[tuple[float, float], ...]

    def cost(self, s: State) -> torch.Tensor:
        """Return the cost of the given state with respect to the task.

        The cost is the distance between the marble's center and the goal's
        center, reduced by the marble's radius and clamped at zero.
        """
        # Linear hinge on the center-to-center distance: the cost is zero
        # once that distance is at most the marble's radius.
        # NOTE(review): a previous comment described a quadratic cost that
        # vanishes within half the goal radius; the code uses the marble
        # radius instead — confirm which one is intended.
        return torch.relu(
            (s.marble.position - self.goal_circle.position).norm()
            - s.marble.radius
        )

    @torch.inference_mode()
    def get_segments_hit(self, actions: torch.Tensor) -> list[int]:
        """Return the sequence of segments that were hit when running the
        given actions from the initial state (without autograd tracking)."""
        return get_segments_hit(actions, self.initial_state)

    @torch.inference_mode()
    def get_trajectory(self, actions: torch.Tensor) -> list[State]:
        """Return the sequence of system states obtained by running the given
        actions from the initial state (without autograd tracking)."""
        return get_trajectory(actions, self.initial_state)

    @staticmethod
    def random_parameters(_: int, n: int) -> torch.Tensor:
        """Return initial parameters for the task: `n` zero action vectors.

        Despite the name, the result is deterministic (all zeros) with shape
        `(n, 2)`; the first argument is ignored.
        """
        # torch.zeros keeps the (n, 2) shape even for n == 0, whereas a tensor
        # built from an empty nested list would have shape (0,).
        return torch.zeros((n, 2))

    @property
    def json_str(self) -> str:
        """Return the task serialized as an indented JSON string, with tensors
        written as (nested) lists of numbers."""
        def serialize_tensor(t: torch.Tensor, *_, **__) -> object:
            # Nested lists, or a plain number for 0-d tensors. `.cpu()` makes
            # this work for tensors on any device (no-op for CPU tensors).
            return t.detach().cpu().numpy().tolist()
        # NOTE(review): this registers the serializer on jsons' global state,
        # so it stays active for every later jsons dump in the process.
        jsons.set_serializer(serialize_tensor, torch.Tensor)
        return jsons.dumps(
            self,
            jdkwargs=dict(indent=4),
            strip_properties=True
        )

    @staticmethod
    def from_json_str(json_str: str) -> "TaskImpulse":
        """Return the task described by `json_str` (presumably the output of
        the `json_str` property)."""
        # Most objects in a task are torch.Tensors, so every JSON list is
        # deserialized as a tensor.
        def deserialize_tensor(obj, *_, **__) -> torch.Tensor:
            return torch.tensor(obj)
        json_obj = json.loads(json_str)
        # NOTE(review): this registers the list deserializer on jsons' global
        # state, so every later jsons.load of a list in the process also
        # yields a tensor; consider a jsons.fork() instance instead.
        jsons.set_deserializer(deserialize_tensor, list)
        task = jsons.load(json_obj, TaskImpulse)
        return task


@dataclass(eq=True, frozen=True)
class RandomTaskImpulseParameters:
    """Parameters for random `TaskImpulse` generation (see
    `random_task_impulse`)."""
    # Radius of the goal circle.
    goal_radius: float
    # Radius of the marble.
    marble_radius: float
    # Radius given to every maze wall segment.
    segment_radius: float
    # Time-step forwarded to `State.dt`.
    dt: float
    # Physics constants forwarded unchanged to `State`.
    impulse_scale: float
    drag_constant: float
    coefficient_of_restitution: float
    # Width of a maze corridor (and length of one maze step).
    corridor_width: float
    # Maze length forwarded to the maze generators.
    maze_len: int


def get_maze_segments(
        maze: Maze,
        corridor_width: float,
        segment_radius: float,
        ) -> tuple[list[Segment], torch.Tensor, torch.Tensor, float, list[tuple[float, float]]]:
    """Return the wall segments described by `maze`, together with the start
    position, the goal position, the total travel distance and the list of
    checkpoints.

    The maze is walked from the origin: each entry of `maze.directions`
    advances the corridor center by `corridor_width`. The two corridor walls
    run `corridor_width/2` to either side of the corridor center, and both
    ends of the maze are capped with a segment.

    Args:
        maze: Maze whose `directions` describe the corridor path.
        corridor_width: Width of a corridor and length of one maze step.
        segment_radius: Radius given to every wall `Segment`.

    Returns:
        A tuple `(segments, start_position, goal_position, total_distance,
        checkpoints)`: the wall segments in creation order (zero-length and
        duplicate segments skipped), the start position (the origin), the
        goal position (the final corridor center),
        `len(maze.directions)*corridor_width`, and the corridor-center
        checkpoints from start to goal.

    Raises:
        ValueError: If `maze.directions` is empty, or if two consecutive
            directions are opposite (e.g. UP followed by DOWN).
    """
    if len(maze.directions) == 0:
        raise ValueError("Cannot get segments for maze with empty directions!")
    start_position = torch.tensor([0.0, 0.0])
    corridor_start = start_position
    segments = list[Segment]()

    def add_segment(p1: torch.Tensor, p2: torch.Tensor) -> None:
        """Append the segment p1-p2 to `segments`, skipping zero-length
        segments and segments already present (by `Segment` equality), so
        the list keeps its first-seen order."""
        if not torch.is_nonzero((p1-p2).norm()):
            return
        segment = Segment(
            p1=tuple(p1.tolist()),
            p2=tuple(p2.tolist()),
            radius=segment_radius,
        )
        if segment not in segments:
            segments.append(segment)

    # Direction conventions that match the maze definition
    # (up is +y, right is +x)
    up = torch.tensor([0.0, 1.0])
    right = torch.tensor([1.0, 0.0])
    down = -up
    left = -right

    # Decide the start of the first corridor walls. "Left" and "right" are
    # relative to the direction of travel; both walls start half a corridor
    # width behind the start position.
    if maze.directions[0] == Direction.LEFT:
        left_wall_start = corridor_start+(right+down)*corridor_width/2
        right_wall_start = corridor_start+(right+up)*corridor_width/2
    elif maze.directions[0] == Direction.RIGHT:
        left_wall_start = corridor_start+(left+up)*corridor_width/2
        right_wall_start = corridor_start+(left+down)*corridor_width/2
    elif maze.directions[0] == Direction.UP:
        left_wall_start = corridor_start+(left+down)*corridor_width/2
        right_wall_start = corridor_start+(right+down)*corridor_width/2
    else: # maze.directions[0] == Direction.DOWN
        left_wall_start = corridor_start+(right+up)*corridor_width/2
        right_wall_start = corridor_start+(left+up)*corridor_width/2

    # Cap the entrance of the first corridor with a segment
    add_segment(left_wall_start, right_wall_start)

    corridor_end = corridor_start

    # Track checkpoints (corridor centers; a nominal trajectory of sorts)
    checkpoints: list[tuple[float, float]] = [tuple(corridor_end.tolist())]

    # Walk the maze. Straight moves only extend the current walls; at each
    # turn, both walls end at a corner of the current cell and the next
    # walls start there. Then the corridor center advances one step.
    for i in range(0, len(maze.directions)):
        #  +---------+
        #  |         |
        #  |         |
        #  +----+    |
        #       |    |
        #       |    |
        #       +----+
        #  (UP, LEFT)
        if i > 0 and maze.directions[i] != maze.directions[i-1]:
            # Decide corridor wall end positions.
            # There are only 4 cases: the corner positions do not depend on
            # the order of (dir[i-1], dir[i]).
            directions = [maze.directions[i], maze.directions[i-1]]
            if Direction.LEFT in directions and Direction.UP in directions:
                left_wall_end = corridor_end+(left+down)*corridor_width/2
                right_wall_end = corridor_end+(right+up)*corridor_width/2
            elif Direction.LEFT in directions and Direction.DOWN in directions:
                left_wall_end = corridor_end+(right+down)*corridor_width/2
                right_wall_end = corridor_end+(left+up)*corridor_width/2
            elif Direction.RIGHT in directions and Direction.UP in directions:
                left_wall_end = corridor_end+(left+up)*corridor_width/2
                right_wall_end = corridor_end+(right+down)*corridor_width/2
            elif Direction.RIGHT in directions and Direction.DOWN in directions:
                left_wall_end = corridor_end+(right+up)*corridor_width/2
                right_wall_end = corridor_end+(left+down)*corridor_width/2
            else:
                # Opposite directions: the corridor would fold back on itself
                raise ValueError(f"Invalid maze: maze cannot move from {maze.directions[i-1]} to {maze.directions[i]}!")
            # Add segments
            add_segment(left_wall_start, left_wall_end)
            add_segment(right_wall_start, right_wall_end)

            # The next walls start where these ended
            left_wall_start = left_wall_end
            right_wall_start = right_wall_end

        # Advance the corridor center by one step
        if maze.directions[i] == Direction.LEFT:
            corridor_end = corridor_end + left*corridor_width
        elif maze.directions[i] == Direction.RIGHT:
            corridor_end = corridor_end + right*corridor_width
        elif maze.directions[i] == Direction.UP:
            corridor_end = corridor_end + up*corridor_width
        elif maze.directions[i] == Direction.DOWN:
            corridor_end = corridor_end + down*corridor_width

        # Use corridor end as checkpoint
        checkpoints.append(tuple(corridor_end.tolist()))

    # Close the last corridor: both walls extend half a corridor width past
    # the final corridor center
    if maze.directions[-1] == Direction.LEFT:
        left_wall_end = corridor_end+(left+down)*corridor_width/2
        right_wall_end = corridor_end+(left+up)*corridor_width/2
    elif maze.directions[-1] == Direction.RIGHT:
        left_wall_end = corridor_end+(right+up)*corridor_width/2
        right_wall_end = corridor_end+(right+down)*corridor_width/2
    elif maze.directions[-1] == Direction.UP:
        left_wall_end = corridor_end+(left+up)*corridor_width/2
        right_wall_end = corridor_end+(right+up)*corridor_width/2
    else: # maze.directions[-1] == Direction.DOWN
        left_wall_end = corridor_end+(right+down)*corridor_width/2
        right_wall_end = corridor_end+(left+down)*corridor_width/2
    add_segment(left_wall_start, left_wall_end)
    add_segment(right_wall_start, right_wall_end)

    # Cap the end of the last corridor with a segment between the wall ends
    add_segment(left_wall_end, right_wall_end)

    # Add final checkpoint at goal position
    # NOTE(review): corridor_end was already appended as the last loop
    # checkpoint, so this repeats it — confirm consumers expect the duplicate.
    goal_position = corridor_end
    checkpoints.append(tuple(goal_position.tolist()))

    # Total distance: every direction advances one corridor_width, so this is
    # already in world units (not a count of steps) and must not be rescaled
    # by corridor_width again.
    total_distance = len(maze.directions)*corridor_width
    return segments, start_position, goal_position, total_distance, checkpoints


def get_marble_speed(
        marble_radius: float,
        dt: float,
        impulse_scale: float,
        drag_constant: float,
        coefficient_of_restitution: float,
        ) -> float:
    """Return the average distance traveled per time-step by a marble pushed
    with a constant action in a scene without walls.

    The marble starts at rest at the origin and receives the action `[1, 0]`
    at every time-step. The distance from the origin of the trajectory entry
    at index `timestep_n` (presumably the state after `timestep_n` steps) is
    divided by `timestep_n`. The arguments are the physics constants
    forwarded to `State`.
    """
    timestep_n = 100
    initial_state = State(
        segments=tuple(),
        marble=Circle(
            position=torch.tensor([0.0, 0.0]),
            velocity=torch.tensor([0.0, 0.0]),
            radius=marble_radius,
        ),
        dt=dt,
        impulse_scale=impulse_scale,
        drag_constant=drag_constant,
        coefficient_of_restitution=coefficient_of_restitution,
    )
    task = TaskImpulse(
        initial_state=initial_state,
        # Placeholder goal, far away so it never interferes
        goal_circle=Goal(position=torch.tensor([0.0, 10000.0]), radius=1.0),
        max_timesteps=timestep_n+1,
        # Placeholder checkpoints
        checkpoints=tuple(),
    )
    action = torch.tensor([1.0, 0.0]).expand(timestep_n+1, -1)
    systems = task.get_trajectory(action)
    final_position = systems[timestep_n].marble.position
    # Return a Python float, as annotated, rather than a 0-d inference-mode
    # tensor (which would leak into downstream integer arithmetic).
    return float(final_position.norm()/timestep_n)


def random_task_impulse(
        parameters: RandomTaskImpulseParameters,
        seed: int,
        *,
        allow_serpent: bool = False,
        ) -> TaskImpulse:
    """Return a random task sampled using the given `parameters`.

    A maze with `parameters.maze_len` is generated from `seed` and turned into
    wall segments. The result is wrapped in a `TaskImpulse` whose goal is the
    end of the maze. The time-step budget is 8 times the number of time-steps
    that a constantly pushed marble (see `get_marble_speed`) needs to cover
    the maze's travel distance.

    Args:
        parameters: Physical and geometric parameters of the task.
        seed: Seed for the task's random number generator.
        allow_serpent: If True, randomly choose between a serpent maze and a
            random maze. By default only random mazes are generated, and the
            tasks produced for a given seed are unaffected by this option.

    Raises:
        ValueError: If the generated maze is invalid (e.g. raised by
            `get_maze_segments`).
    """
    _random = random.Random(seed)

    # Generate a maze. The serpent/random coin is only flipped when serpents
    # are allowed, so the default random stream is unchanged.
    if allow_serpent and _random.choice([True, False]):
        serpent_width = _random.randint(1, 4)
        serpent_height = _random.randint(1, 4)
        maze = get_serpent(
            maze_len=parameters.maze_len,
            serpent_width=serpent_width,
            serpent_height=serpent_height,
        )
    else:
        maze_seed = str(_random.random())
        maze = get_random_maze(
            maze_len=parameters.maze_len,
            seed=maze_seed,
        )

    # Get maze as list of line segments
    segments, start_position, end_position, travel_distance, checkpoints = (
        get_maze_segments(
            maze=maze,
            corridor_width=parameters.corridor_width,
            segment_radius=parameters.segment_radius,
        )
    )

    # Generate task
    initial_state = State(
        segments=tuple(segments),
        marble=Circle(
            position=start_position,
            velocity=torch.tensor([0.0, 0.0]),
            radius=parameters.marble_radius,
        ),
        dt=parameters.dt,
        impulse_scale=parameters.impulse_scale,
        drag_constant=parameters.drag_constant,
        coefficient_of_restitution=parameters.coefficient_of_restitution,
    )
    goal = Goal(
        position=end_position,
        radius=parameters.goal_radius,
    )
    # float() ensures math.ceil below yields a plain int even if the speed
    # comes back as a 0-d tensor.
    canonical_speed = float(get_marble_speed(
        marble_radius=parameters.marble_radius,
        dt=parameters.dt,
        impulse_scale=parameters.impulse_scale,
        drag_constant=parameters.drag_constant,
        coefficient_of_restitution=parameters.coefficient_of_restitution,
    ))
    # `travel_distance` is already in world units
    # (len(maze.directions)*corridor_width). Scaling it by corridor_width
    # again would make the budget grow with the square of the corridor width.
    canonical_distance = travel_distance
    canonical_timesteps = math.ceil(canonical_distance/canonical_speed)
    # Allow 8x the time a constantly pushed marble would need.
    return TaskImpulse(
        initial_state=initial_state,
        goal_circle=goal,
        max_timesteps=canonical_timesteps*8,
        checkpoints=tuple(checkpoints),
    )


def get_tasks(
        n: int,
        task_parameters: RandomTaskImpulseParameters,
        seed: str,
        *,
        max_consecutive_failures: int = 10_000,
        ) -> tuple[TaskImpulse, ...]:
    """Return `n` random tasks, generated sequentially from `seed`.

    Each attempt draws a fresh integer seed for `random_task_impulse`.
    Attempts that raise `ValueError` (e.g. invalid mazes) are discarded and
    retried with the next seed.

    Raises:
        RuntimeError: If `max_consecutive_failures` attempts in a row raise
            `ValueError` (e.g. parameters that can never yield a valid maze).
            Without this bound the loop would never terminate.
    """
    tasks: list[TaskImpulse] = []
    _random = random.Random(seed)
    consecutive_failures = 0
    while len(tasks) < n:
        try:
            task = random_task_impulse(
                task_parameters,
                _random.randint(0, 100000000)
            )
        except ValueError as err:
            # Invalid sample: try again with the next seed, but give up if
            # nothing valid has been produced for too long.
            consecutive_failures += 1
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Failed to generate a valid task after "
                    f"{consecutive_failures} consecutive attempts"
                ) from err
            continue
        consecutive_failures = 0
        tasks.append(task)
    return tuple(tasks)


def get_quality(parameters: torch.Tensor, task: TaskImpulse) -> float:
    """Return the normalized "closeness" to the goal reached by `parameters`.

    Distances are measured from the marble's center to the goal circle's
    boundary (negative inside the goal). The quality is
    `1 - final_distance/starting_distance`. It is 0 when the marble ends as
    far away as it started, 1 when it ends on the goal boundary, and above 1
    when it ends inside the goal. If the marble already starts at or inside
    the goal boundary, the ratio is undefined or sign-flipped, so 1.0 is
    returned.
    """
    # The result is a plain float, so simulate without autograd tracking.
    trajectory = task.get_trajectory(parameters)
    goal = task.goal_circle

    def get_distance_to_goal(state) -> float:
        # Signed distance from the marble's center to the goal boundary.
        distance = (state.marble.position - goal.position).norm() - goal.radius
        return float(distance)

    final_distance = get_distance_to_goal(trajectory[-1])
    starting_distance = get_distance_to_goal(trajectory[0])
    if starting_distance <= 0:
        # Goal already reached at the start; avoid dividing by zero or by a
        # negative distance.
        return 1.0
    return 1 - final_distance/starting_distance
