import gymnasium as gym
import numpy as np

from dataclasses import dataclass

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType, Wrapper
from gymnasium import Env

from molecule_movement import VerticalAction, LateralAction, Matching

import shapely
from shapely import Point

from loguru import logger

# Upper clamp applied to the waypoint-distance entry of the observation built
# by CurrentMoleculeObservationWrapper.observation(). Units are whatever
# `distance_to_waypoint` returns — TODO confirm against molecule_movement.
MAX_DISTANCE = 6.0

class CurrentMoleculeObservationWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Replace the wrapped env's observation with a 5-element molecule summary.

    The observation vector is laid out as:

    0. ``molecule.orientation``
    1. ``molecule.orientation - goal.rotation``
    2. x component of the unit vector from the molecule centre to the next target
    3. y component of that unit vector
    4. distance to the next waypoint, clamped to ``MAX_DISTANCE``

    The wrapped environment must expose ``current_matching`` (and, for
    :meth:`increment_matching`, an ``increment_matching`` callable) through
    ``get_wrapper_attr``.

    NOTE(review): the 0..360 bounds suggest the angles are in degrees — confirm.
    """

    def __init__(self,
                 env: Env[ObsType, ActType]):
        """Wrap ``env`` and declare the 5-element observation space.

        Args:
            env: The environment to wrap.
        """
        # Validate before the base classes touch ``env``.
        assert isinstance(env, Env), f"expected a gymnasium Env, got {type(env)}"
        # Initialise each base exactly once; ``gym.Wrapper.__init__`` already
        # stores ``env`` on the instance.
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        # Bounds are built directly in the space's dtype so no implicit
        # float64 -> float32 cast happens inside ``Box``.
        low = np.array([0.0, 0.0, -np.inf, -np.inf, 0.0], dtype=np.float32)
        high = np.array([360.0, 360.0, np.inf, np.inf, np.inf], dtype=np.float32)
        self.observation_space = gym.spaces.Box(low, high, shape=(5,), dtype=np.float32)

    def __sync_matching(self) -> None:
        """Re-read ``current_matching`` from the wrapped env and cache its parts."""
        matching = self.env.get_wrapper_attr("current_matching")
        self.current_matching = matching
        self.molecule = matching.molecule
        self.goal = matching.goal

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[NDArray, dict]:
        """Reset the wrapped env and return this wrapper's observation.

        Args:
            seed: Forwarded to the wrapped env's ``reset``.
            options: Forwarded to the wrapped env's ``reset``.

        Returns:
            ``(observation, info)`` where ``observation`` comes from
            :meth:`observation` and ``info`` is passed through unchanged.
        """
        _, info = self.env.reset(seed=seed, options=options)
        self.__sync_matching()
        # Kept from the original implementation: a render is triggered on
        # every reset.
        self.render()
        return self.observation(), info

    def step(self, action: VerticalAction | LateralAction) -> tuple[NDArray, float, bool, bool, dict]:
        """Advance the wrapped env by one step and return this wrapper's observation.

        Args:
            action: Action forwarded to the wrapped env.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; everything
            except ``observation`` is passed through from the wrapped env.
        """
        # NOTE(review): the cached matching/molecule/goal are refreshed *before*
        # the env steps, unlike reset() and increment_matching(). If the env
        # replaces its matching during step(), the observation below is built
        # from the previous molecule/goal — confirm this ordering is intended.
        self.__sync_matching()
        _, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(), reward, terminated, truncated, info

    def observation(self) -> NDArray:
        """Build the 5-element observation from the cached molecule and goal.

        Returns:
            A 1-D array in the dtype of ``observation_space`` (see the class
            docstring for the layout).
        """
        heading = self.__directional_vector_to_goal()
        # The distance is asked of the env's *current* matching rather than the
        # cached one, as in the original implementation.
        live_matching = self.env.get_wrapper_attr("current_matching")
        clamped_distance = min(MAX_DISTANCE, live_matching.distance_to_waypoint(self.molecule))
        # Cast to the declared dtype: a default float64 array can never satisfy
        # ``observation_space.contains`` for a float32 Box.
        return np.array(
            [
                self.molecule.orientation,
                self.__relative_orientation_to_goal(),
                heading[0],
                heading[1],
                clamped_distance,
            ],
            dtype=self.observation_space.dtype,
        )

    def __directional_vector_to_goal(self) -> NDArray:
        """Return the unit vector from the molecule centre to the next target.

        The target is the matching's next waypoint when it has waypoints,
        otherwise the goal position. A zero vector is returned when the
        molecule is already exactly on the target.
        """
        if self.current_matching.has_waypoints:
            target = self.current_matching.next_waypoint(self.molecule)
        else:
            target = self.goal.position
        centre = self.molecule.center
        offset = np.array([target.x - centre.x, target.y - centre.y])
        length = np.linalg.norm(offset)
        if length <= 0.0:
            # Avoid dividing by zero when molecule and target coincide.
            return np.array([0.0, 0.0])
        return offset / length

    def __relative_angle_to_goal_position(self) -> float:
        """Return the unsigned angle, in degrees, between the molecule's heading and the target direction.

        Currently not used by :meth:`observation`.

        NOTE(review): this reads ``molecule.rotation`` while the rest of the
        class reads ``molecule.orientation`` — confirm which attribute is meant.
        """
        # Import the submodule explicitly; a bare ``import shapely`` is not
        # guaranteed to make ``shapely.affinity`` available.
        from shapely import affinity

        rotated = affinity.rotate(Point(1, 0), self.molecule.rotation, origin=(0, 0))
        heading = np.array([rotated.x, rotated.y])
        to_goal = self.__directional_vector_to_goal()
        # Clip guards arccos against dot products that drift outside [-1, 1].
        cosine = np.clip(np.dot(heading, to_goal), -1.0, 1.0)
        return np.degrees(np.arccos(cosine))

    def __relative_orientation_to_goal(self) -> float:
        """Return ``molecule.orientation - goal.rotation``.

        NOTE(review): the difference is not wrapped, so it can be negative,
        while the observation space declares a lower bound of 0 for this
        entry — confirm the intended range.
        """
        return self.molecule.orientation - self.goal.rotation

    def increment_matching(self):
        """Advance the wrapped env to its next matching and return a fresh observation.

        Returns:
            ``(observation, info)`` where ``info`` is the second element
            returned by the wrapped env's ``increment_matching``.
        """
        _, info = self.env.get_wrapper_attr("increment_matching")()
        self.__sync_matching()
        return self.observation(), info
