import gymnasium as gym
import numpy as np

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType, Wrapper
from gymnasium import Env

from molecule_movement.Objects import VerticalAction, LateralAction
from molecule_movement.Simulator import Simulator

from shapely import Point, LineString
import shapely

from loguru import logger

class PcAu111_ObservationWrapper(gym.ObservationWrapper[WrapperObsType, ActType, ObsType]):
    """Observation wrapper that declares a flat ``float32`` Box observation space.

    On construction it builds a name -> integer-id vocabulary from the names
    of the molecules held by the unwrapped environment and sets
    ``observation_space`` to a 1-D ``Box`` with ``4 + num_sensors + 1``
    entries.

    NOTE(review): no active ``observation`` method is defined in this class
    body (an earlier implementation is kept commented out in this file), so
    observation conversion falls back to the ``gym.ObservationWrapper``
    default -- presumably that raises ``NotImplementedError``; confirm before
    using this wrapper in a rollout.
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Wrap ``env`` and set up the vocabulary, task flags and observation space.

        Args:
            env: Environment to wrap. Its unwrapped environment must expose
                ``molecules`` (objects with a ``name`` attribute) and
                ``num_sensors``.
        """
        # The base initializer stores ``env`` on ``self.env``; no need to
        # assign it a second time here.
        super().__init__(env)

        # Stable name -> id mapping over the molecule names known at construction.
        self.vocab = self._build_vocab([m.name for m in self.unwrapped.molecules])

        self.info = {}

        # Task-selection flags. No active code in this class reads them; the
        # commented-out ``task_scheduler`` in this file branches on all three.
        self.do_positioning = True
        self.do_reorientation = True
        self.do_reorientation_first = False

        num_sensors = self.unwrapped.num_sensors

        # Per-entry bounds: two entries in [0, 360], two in [0, inf), one
        # [0, inf) entry per sensor, and a final entry in [0, 1].
        # NOTE(review): the commented-out ``observation()`` in this file puts
        # its type-change entry *before* the sensor readings, whereas the
        # [0, 1] bound here comes last -- confirm the intended ordering
        # before re-enabling it.
        sensor_low = [0.0] * num_sensors
        sensor_high = [np.inf] * num_sensors
        low = np.array([0.0, 0.0, 0.0, 0.0, *sensor_low, 0.0], dtype=np.float32)
        high = np.array([360.0, 360.0, np.inf, np.inf, *sensor_high, 1.0], dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def _build_vocab(self, names) -> dict:
        """Map each distinct name to its index in the sorted set of names.

        Args:
            names: Iterable of hashable, mutually orderable names; duplicates
                are collapsed.

        Returns:
            Dict from name to a consecutive integer id starting at 0.
        """
        return {name: index for index, name in enumerate(sorted(set(names)))}

    def _encode_names(self, names) -> NDArray[np.int64]:
        """Translate names into their vocabulary ids.

        Args:
            names: Iterable of names that are all present in ``self.vocab``.

        Returns:
            ``int64`` array with one id per input name, in input order.

        Raises:
            KeyError: If a name is not in ``self.vocab``.
        """
        return np.array([self.vocab[name] for name in names], dtype=np.int64)

    # def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, NDArray], dict]:
        
    #     self.unwrapped.steps = 0
    #     # This matches the closest moiety to the first goal
    #     self.unwrapped.current_matching_index = 0
    #     self.unwrapped._set_goals()
    #     self.unwrapped._set_obstacles()
        
    #     # XXX This is not right now
    #     # if not self.unwrapped.continue_after_done(): self.unwrapped._create_initial_distribution()

    #     self.name_to_molecule_map = {molecule.name : molecule for molecule in self.env.unwrapped.molecules}
        
    #     self.info["done"] = self.env.unwrapped._done
        
    #     self.current_episode = self.unwrapped.get_lastest_episode()+1
    #     self.unwrapped.simulator = Simulator(self.unwrapped.molecules, self.unwrapped.goals)

    #     self.unwrapped._initialize_renderer()
    #     self.unwrapped._get_matching()

    #     # Scan at the center of the molecule
    #     self.unwrapped.center_of_scan_position = self.unwrapped.current_molecule.center
    #     self.unwrapped._get_current_stm_state()

    #     self.molecule = self.unwrapped.current_matching.molecule
    #     self.goal = self.unwrapped.current_matching.goal

    #     if self.unwrapped.renderer:
    #         self.unwrapped.render(None)
    #         self.unwrapped.renderer.update()

    #     self.info['agent_task'] = self.task_scheduler()
    #     return self.observation(), self.info
    
    # def step(self, action):
    #     self.unwrapped.steps += 1
    #     current_matching = self.env.get_wrapper_attr("current_matching")
    #     self.molecule = current_matching.molecule
    #     self.goal = current_matching.goal
        
    #     self.current_molecule_position_before = self.molecule.center
    #     self.current_molecule_rotation_before = self.molecule.rotation

    #     self.unwrapped.perform_action_on_stm(action, 'lateral')
    #     self.unwrapped._get_current_stm_state()

    #     self.unwrapped.render()

    #     if self.info['agent_task'] == 'positioning':
    #         reward = self.env.reward_positioning()
    #     elif self.info['agent_task'] == 'rotation':
    #         pass
    #     reward = self.env.env.env.reward_reorientation()

    #     terminated = self.unwrapped.terminated()
    #     truncated = self.unwrapped.truncated()

    #     print("Reward: ", reward)
    #     print("Terminated: ", terminated)
    #     print("Truncated: ", truncated)

    #     self.info['agent_task'] = self.task_scheduler()
    #     self.info["done"] = self.unwrapped.is_done()
    #     return self.observation(), reward, terminated, truncated, self.info

    # def task_scheduler(self):
    #     """ Sets the task to perform on the current molecule. This includes precise positioning and reorientation.

    #         Returns:
    #             str: The task to perform on the current molecule.
    #             'positioning' if the task is to position the molecule, 'rotation' if the task is to reorient the molecule.
    #     """
    #     # do positioning first, then reorientation
    #     if not self.do_reorientation_first and self.do_positioning and self.do_reorientation:
    #         if self.unwrapped.reached_goal_position() and self.unwrapped.reached_goal_orientation():
    #             self.unwrapped.increment_matching()
    #             if self.current_matching_index == len(self.goals):
    #                 self.unwrapped.set_done()
    #         elif not self.unwrapped.reached_goal_position():
    #             info = 'positioning'
    #         elif self.unwrapped.reached_goal_position() and not self.unwrapped.reached_goal_orientation():
    #             info = 'rotation'
    #     # do reorientation first, then positioning
    #     elif self.do_reorientation_first and self.do_positioning and self.do_reorientation :
    #         if self.unwrapped.reached_goal_position() and self.unwrapped.reached_goal_orientation():
    #             self.unwrapped.increment_matching()
    #             if self.current_matching_index == len(self.goals):
    #                 self.unwrapped.set_done()
    #         elif not self.unwrapped.reached_goal_position() and self.unwrapped.reached_goal_orientation():
    #             info = 'positioning'
    #         elif not self.unwrapped.reached_goal_position() and not self.unwrapped.reached_goal_orientation():
    #             info = 'rotation'
    #     elif self.do_positioning and not self.do_reorientation:
    #         if self.unwrapped.reached_goal_position():
    #             self.unwrapped.increment_matching()
    #             info = 'positioning'
    #             if self.current_matching_index == len(self.goals):
    #                 self.unwrapped.set_done()
    #     elif not self.do_positioning and self.do_reorientation:
    #         if self.unwrapped.reached_goal_orientation():
    #             self.unwrapped.increment_matching()
    #             info = 'rotation'
    #             if self.current_matching_index == len(self.goals):
    #                 self.unwrapped.set_done()
    #     self.unwrapped.set_agent_task(agent_task=info)
    #     return info

    # def observation(self):

    #     (moiety_types,
    #      moiety_position_nm,
    #      moiety_orientation_rad,
    #      moiety_bbox_width_height_nm,
    #      confidence,
    #      moiety_target_contour_nm,
    #      moiety_matching_reference_contour_nm,
    #      moiety_types_before,
    #      moiety_type_changes
    #      ) = self.unwrapped._get_observation()

    #     current_molecule = self.unwrapped.current_matching.molecule
    #     current_molecule.name = moiety_types
    #     current_molecule.type = moiety_types
    #     current_molecule.type_before = moiety_types_before
    #     current_molecule.center = Point(moiety_position_nm)
    #     current_molecule.set_orientation(np.rad2deg(moiety_orientation_rad))
    #     # current_molecule._adsorption_geometries = self.get_yaml_data(os.path.join(os.getcwd(), "input.yaml"), moiety_type, "_adsorption_geometries")

    #     surface_orientation = 0 if not self.include_surface_orientation else current_molecule.rotation
    #     orientation_to_goal = self.__orientation_to_goal()
    #     direction_to_goal = self.__directional_vector_to_goal()
    #     angle_to_goal = self.__angle_to_goal()
    #     distance_to_goal = self.__distance_to_goal()
        

    #     if current_molecule.num_sensors > 0:
    #         sensor_readings = np.full((current_molecule.num_sensors,), np.inf)

    #         for j, sensor in enumerate(current_molecule.sensors):
    #             for other in self.unwrapped.molecules:
    #                 if current_molecule == other: # don't compare a moiety with itself
    #                     continue
    #                 if sensor.intersects(other.polygon):
    #                     intersection = shapely.intersection(sensor, other.polygon)
    #                     distance = current_molecule.polygon.distance(intersection)
    #                     sensor_readings[j] = min(sensor_readings[j], distance)
    #     else: 
    #         sensor_readings = np.zeros(self.unwrapped.num_sensors)

    #     obs = np.array(
    #         [surface_orientation, orientation_to_goal, angle_to_goal, distance_to_goal, moiety_type_changes, sensor_readings],
    #         dtype=np.object_
    #         )

    #     return self.flatten(obs)
            
    
    # def flatten(self, obs):
    #     flat = []
    #     for part in obs:
    #         if isinstance(part, (np.ndarray, list)):
    #             flat.extend(np.asarray(part, dtype=np.float32).flatten())
    #         else:
    #             flat.append(np.float32(part))
    #     return np.array(flat, dtype=np.float32)

    # def __orientation_to_goal(self) -> float:
    #     """Calculates the orientation difference between the molecule and the goal in degrees.
    #     Returns:
    #         float: The difference in orientation between the molecule and the goal in degrees.
    #     """
    #     return self.goal.rotation - self.molecule.rotation

    # def __directional_vector_to_goal(self) -> NDArray:
    #     molecule_goal_vector = np.array([self.goal.position.x - self.molecule.center.x, self.goal.position.y - self.molecule.center.y])
    #     return molecule_goal_vector / np.linalg.norm(molecule_goal_vector)
    
    # def __angle_to_goal(self) -> float:
    #     molecule_goal_vector = np.array([self.goal.position.x - self.molecule.center.x, self.goal.position.y - self.molecule.center.y])
    #     return np.arctan2(molecule_goal_vector[1], molecule_goal_vector[0]) - np.deg2rad(self.molecule.rotation)
    
    # def __distance_to_goal(self) -> float:
    #     return np.linalg.norm(np.array([self.goal.position.x - self.molecule.center.x, self.goal.position.y - self.molecule.center.y]))