import gymnasium as gym
import numpy as np

from dataclasses import dataclass

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType
from gymnasium import Env

from molecule_movement import VerticalAction, LateralAction, Matching

from shapely import Point, LineString
import shapely
import math
from molecule_movement.shapes import compute_corridor
from ..colour_utils import Highlight

from loguru import logger

# Cap for the distance-to-waypoint observation feature: CorridorWrapper.observation
# reports min(MAX_DISTANCE, distance_to_waypoint).
MAX_DISTANCE = 6.0


def signed_lateral_offset(line: LineString, point: Point, eps: float = 1e-3) -> float:
    """
    Signed perpendicular distance from ``point`` to ``line``.

    The magnitude is the distance from ``point`` to its closest point on
    ``line``. The sign is the sign of the 2D cross product of the local line
    tangent (oriented in the direction the line's coordinates run) with the
    offset vector from the closest point to ``point``: positive on the
    counter-clockwise side (left of the line's direction in a y-up frame),
    negative on the other side. A collinear offset (e.g. a point beyond an
    endpoint along the line's extension) counts as positive. A point exactly
    on the line yields 0.0.

    Args:
        line: Reference line.
        point: Query point.
        eps: Half-width, in line units, of the arc-length window used to
            estimate the local tangent around the projected point.

    Returns:
        The signed distance. For a zero-length ``line`` the unsigned distance
        is returned, since no direction is defined.
    """
    # Arc-length position of the closest point on the line, and that point.
    s = line.project(point)
    closest = line.interpolate(s)

    length = line.length
    if length == 0.0:
        # Degenerate line: no direction, so the sign is undefined.
        return point.distance(closest)

    # Approximate the tangent by a short chord around s, clamped to the line.
    # (With length > 0 the window [max(0, s-eps), min(length, s+eps)] can only
    # collapse when eps == 0; no fallback can recover a direction then.)
    behind = line.interpolate(max(0.0, s - eps))
    ahead = line.interpolate(min(length, s + eps))
    tx, ty = ahead.x - behind.x, ahead.y - behind.y  # tangent
    vx, vy = point.x - closest.x, point.y - closest.y  # offset vector

    dist = math.hypot(vx, vy)
    if dist == 0.0:
        return 0.0

    # signed distance = sign(cross2d(t, v)) * |v|
    cross_z = tx * vy - ty * vx
    return dist if cross_z >= 0.0 else -dist

@dataclass(frozen=True)
class CorridorConfiguration:
    """Immutable bundle of the corridor parameters used by CorridorWrapper.

    CorridorWrapper passes these values to ``compute_corridor`` and forwards the
    whole object to the wrapped env (``options["corridor_config"]`` on reset)
    and to the renderer.
    """

    # Passed to compute_corridor as ``corridor_width``.
    width: float
    # Passed to compute_corridor as ``parking_distance``.
    parking_distance: float = 1.5
    # Passed to compute_corridor as ``parking_buffer``.
    parking_buffer: float = 1.0


class CorridorWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Wrapper that confines the molecule to a corridor around the current matching.

    On reset/step/increment the corridor polygon is rebuilt from the wrapped
    env's ``current_matching`` via ``compute_corridor``. ``info["inside_corridor"]``
    reports whether the molecule polygon lies within the corridor, and
    ``corridor_violation_penalty`` is added to the reward of any step that ends
    outside it. Observations are replaced by the 7-element vector built in
    :meth:`observation`.
    """

    def __init__(self,
                 env: Env[ObsType, ActType],
                 corridor_width: float,
                 parking_distance: float = 1.5,
                 parking_buffer: float = 1.0,
                 corridor_violation_penalty: float = -0.1):
        """
        Args:
            env: Environment exposing ``current_matching`` (and ``renderer``
                when rendering, ``increment_matching`` when advancing) through
                ``get_wrapper_attr``.
            corridor_width: Corridor width passed to ``compute_corridor``.
            parking_distance: Forwarded to ``compute_corridor``.
            parking_buffer: Forwarded to ``compute_corridor``.
            corridor_violation_penalty: Added to the reward when a step ends
                with the molecule outside the corridor.
        """
        # Record *every* constructor argument so the wrapper can be re-created
        # faithfully from the env spec (parking_* used to be dropped).
        gym.utils.RecordConstructorArgs.__init__(
            self,
            corridor_width=corridor_width,
            parking_distance=parking_distance,
            parking_buffer=parking_buffer,
            corridor_violation_penalty=corridor_violation_penalty,
        )
        # Initialise the Wrapper base exactly once (it stores ``self.env``).
        gym.Wrapper.__init__(self, env)
        assert isinstance(env, Env)

        self.corridor_width = corridor_width
        self.corridor_violation_penalty = corridor_violation_penalty
        self.corridor_config = CorridorConfiguration(corridor_width, parking_distance, parking_buffer)

        # Layout matches ``observation``: orientation, relative orientation,
        # direction-to-goal (x, y), normalised lateral offset, inside flag,
        # clamped distance.
        # - Index 1 is ``orientation - goal.rotation`` and can be negative, so
        #   its lower bound is -360 (NOTE(review): assumes both angles lie in
        #   [0, 360]).
        # - ``observation`` returns float64, so the space is declared float64
        #   to let ``observation_space.contains(obs)`` succeed.
        self.observation_space = gym.spaces.Box(
            low=np.array([0.0, -360.0, -np.inf, -np.inf, -np.inf, 0.0, 0.0]),
            high=np.array([360.0, 360.0, np.inf, np.inf, np.inf, 1.0, np.inf]),
            shape=(7,),
            dtype=np.float64,
        )

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[NDArray, dict]:
        """Reset the wrapped env and rebuild the corridor.

        A copy of ``options`` (never the caller's dict itself) is forwarded with
        ``"corridor_config"`` set to this wrapper's configuration.

        Returns:
            ``(observation, info)`` with ``info["inside_corridor"]`` set.
        """
        # Copy so the caller's options dict is not modified as a side effect.
        forwarded_options = dict(options) if options is not None else {}
        forwarded_options["corridor_config"] = self.corridor_config
        _, info = self.env.reset(seed=seed, options=forwarded_options)
        self._sync_matching()
        self.render()
        info["inside_corridor"] = self.__check_inside()
        return self.observation(), info

    def step(self, action: VerticalAction | LateralAction):
        """Step the wrapped env, penalising steps that end outside the corridor.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``reward`` includes ``corridor_violation_penalty`` if the molecule
            is outside the corridor, and ``info["inside_corridor"]`` is set.
        """
        # The corridor is built from the matching as it stands before the
        # action; containment is checked afterwards against ``self.molecule``.
        # NOTE(review): this presumes the wrapped env moves that molecule in
        # place rather than replacing the matching during ``step`` — confirm.
        self._sync_matching()
        self.render()
        _, reward, terminated, truncated, info = self.env.step(action)
        inside_corridor = self.__check_inside()
        info["inside_corridor"] = inside_corridor
        logger.bind(task="stats", inside_corridor=int(inside_corridor)).trace("")

        if not inside_corridor:
            # Same penalty value as ``_corridor_reward``, which is not called here.
            reward = reward + self.corridor_violation_penalty
        return self.observation(), reward, terminated, truncated, info

    def render(self) -> None:
        """Redraw the current matching with the corridor highlighted.

        Draws only when the wrapped env's ``render_mode`` is ``"human"`` or
        ``"rgb_array"``; always returns ``None``.
        """
        if self.env.render_mode not in ("human", "rgb_array"):
            return
        renderer = self.env.get_wrapper_attr("renderer")
        renderer.clear()
        renderer.render_matching(
            self.current_matching,
            highlight=Highlight(enabled=True, colour="", draw_corridor=True),
            corridor_config=self.corridor_config,
        )

    def increment_matching(self) -> tuple[NDArray, dict]:
        """Advance the wrapped env to its next matching and rebuild the corridor.

        Returns:
            ``(observation, info)``; ``info`` comes from the wrapped env's
            ``increment_matching`` and the observation is recomputed here.
        """
        _, info = self.env.get_wrapper_attr("increment_matching")()
        self._sync_matching()
        return self.observation(), info

    def observation(self) -> NDArray:
        """Build the 7-element float observation vector.

        Elements, in order:
            0. molecule orientation
            1. ``molecule.orientation - goal.rotation`` (not wrapped)
            2-3. unit vector from the molecule centre to the next waypoint
                 (or the goal position when the matching has no waypoints);
                 ``(0, 0)`` when the molecule is already there
            4. signed lateral offset from the matching line divided by the
               centre clearance, clipped to [-3, 3]
            5. 1 if the molecule is inside the corridor, else 0
            6. distance to the waypoint, capped at ``MAX_DISTANCE``
        """
        matching = self.env.get_wrapper_attr("current_matching")
        matching_line = matching.matching_line
        direction_to_goal = self.__directional_vector_to_goal()
        clamped_distance = min(MAX_DISTANCE, matching.distance_to_waypoint(self.molecule))

        raw_offset = signed_lateral_offset(matching_line, self.molecule.center)

        # Normalise the offset by how far the centre can move before the
        # molecule edge reaches a nominal corridor edge of 1.5 * diameter.
        # NOTE(review): this nominal half-width is independent of
        # ``corridor_config.width``, and a zero molecule size divides by zero.
        d = self.molecule.get_size()[0]
        r = d / 2.0
        corridor_half_width = 1.5 * d
        center_clearance = corridor_half_width - r

        offset_norm = np.clip(raw_offset / center_clearance, -3.0, 3.0)

        return np.array([
            self.molecule.orientation,
            self.__relative_orientation_to_goal(),
            direction_to_goal[0],
            direction_to_goal[1],
            offset_norm,
            int(self.__check_inside()),
            clamped_distance,
        ], dtype=float)

    def get_corridor_geometry(self) -> shapely.Polygon:
        """Return the most recently computed corridor polygon."""
        return self.corridor

    def _sync_matching(self) -> None:
        """Refresh the cached matching, molecule and goal from the wrapped env and rebuild the corridor."""
        self.current_matching = self.env.get_wrapper_attr("current_matching")
        self.molecule = self.current_matching.molecule
        self.goal = self.current_matching.goal
        self.corridor = compute_corridor(
            self.current_matching,
            corridor_width=self.corridor_config.width,
            parking_distance=self.corridor_config.parking_distance,
            parking_buffer=self.corridor_config.parking_buffer,
        )

    def __directional_vector_to_goal(self) -> NDArray:
        """Unit vector from the molecule centre to the next waypoint (or goal); zeros if coincident."""
        if self.current_matching.has_waypoints:
            next_goal = self.current_matching.next_waypoint(self.molecule)
        else:
            next_goal = self.goal.position
        molecule_goal_vector = np.array([next_goal.x - self.molecule.center.x, next_goal.y - self.molecule.center.y])
        norm = np.linalg.norm(molecule_goal_vector)
        if norm <= 0.0:
            return np.array([0.0, 0.0])
        return molecule_goal_vector / norm

    def __relative_angle_to_goal_position(self) -> float:
        """Unsigned angle (degrees) between the molecule heading and the direction to the goal.

        Not called anywhere in this class.
        """
        molecule_orientation_vector = shapely.affinity.rotate(Point(1, 0), self.molecule.orientation, origin=(0, 0))
        molecule_orientation_vector = np.array([molecule_orientation_vector.x, molecule_orientation_vector.y])
        molecule_goal_vector = self.__directional_vector_to_goal()
        return np.degrees(np.arccos(np.clip(np.dot(molecule_orientation_vector, molecule_goal_vector), -1.0, 1.0)))

    def __relative_orientation_to_goal(self) -> float:
        """Raw difference ``molecule.orientation - goal.rotation`` (not wrapped to a range)."""
        return self.molecule.orientation - self.goal.rotation

    def _corridor_reward(self, inside_corridor: bool) -> float:
        """Return 0.0 inside the corridor, else ``corridor_violation_penalty``, logging the value."""
        reward = 0.0 if inside_corridor else self.corridor_violation_penalty
        logger.bind(task="stats", corridor_penalty=reward).trace("")
        return reward

    def __check_inside(self) -> bool:
        """True if the molecule polygon lies within the current corridor polygon."""
        # Coerce to a plain bool so ``info`` never carries a numpy scalar.
        return bool(shapely.within(self.molecule.polygon, self.corridor))

