import numpy as np
import math

from typing import Optional
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType

from molecule_movement import VerticalAction, LateralAction
from molecule_movement.Simulator import Simulator
from molecule_movement.Molecule import find_nearest
from molecule_movement.wrapper.TaskSchedulingWrapper import Task

from shapely import Point
import shapely

from loguru import logger


class ManipulateCurrentMoleculeDiscreteWrapper(gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs):
    """Deprecated discrete action wrapper; every call to :meth:`action` fails.

    The class is kept so that existing imports and constructions still work,
    but it can no longer translate actions. The former translation (an integer
    index mapped onto a 15-column grid of 0.3-spaced offsets starting at -2.1
    around the current molecule) was already unreachable behind an
    ``assert False`` and has been removed.
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Wrap ``env``; no additional constructor arguments are recorded."""
        gym.Wrapper.__init__(self, env)
        gym.utils.RecordConstructorArgs.__init__(self)

    def action(self, action: WrapperActType) -> ActType:
        """Reject the action unconditionally.

        Raises:
            AssertionError: Always. Raised explicitly (rather than through an
                ``assert`` statement) so the guard also holds when Python runs
                with ``-O``, which strips ``assert`` statements.
        """
        raise AssertionError("This wrapper will be deprecated.")

class ManipulateCurrentMoleculeWrapper(gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs):
    """Translate a molecule-relative ``(x, y)`` offset into an absolute position.

    The incoming action is an offset in the current molecule's own frame. It is
    rounded to one decimal, rotated by the molecule's ``orientation`` and added
    to the molecule's ``center``; the resulting absolute ``(x, y)`` is forwarded
    to the wrapped environment.
    """

    def __init__(self,
                 env: gym.Env[ObsType, ActType],
                 x_: tuple[float, float, float],
                 y_: tuple[float, float, float]):
        """Set up the 2-D box action space.

        Args:
            env: Environment to wrap.
            x_: ``(low, high, step)`` for the x offset.
            y_: ``(low, high, step)`` for the y offset.
        """
        gym.Wrapper.__init__(self, env)
        # The recorded kwargs must use the constructor's own parameter names,
        # otherwise the wrapper cannot be re-created from its recorded spec.
        gym.utils.RecordConstructorArgs.__init__(self, x_=x_, y_=y_)
        self.action_space = gym.spaces.Box(np.array([x_[0], y_[0]]), np.array([x_[1], y_[1]]))
        self.x_step = x_[2]
        self.y_step = y_[2]
        # Discrete grids derived from the bounds; np.arange excludes the upper
        # bound. They are not referenced by the methods of this class.
        self.x_space = np.arange(self.action_space.low[0], self.action_space.high[0], self.x_step)
        self.y_space = np.arange(self.action_space.low[1], self.action_space.high[1], self.y_step)

    def action(self, action: WrapperActType) -> ActType:
        """Return the absolute ``[x, y]`` array for a relative 2-element action.

        Raises:
            AssertionError: If ``action`` is neither a ``(2,)`` array nor a
                sequence of length 2.
        """
        assert (isinstance(action, np.ndarray) and action.shape == (2,)) or len(action) == 2
        target = self._translate_action(action)
        return np.array([target.x, target.y])

    def _translate_action(self, action):
        """Map a molecule-relative offset onto an absolute shapely ``Point``."""
        current_molecule = self.env.unwrapped.get_wrapper_attr("current_molecule")
        # NOTE: offsets are always rounded to one decimal, independent of
        # x_step / y_step.
        offset = Point(round(action[0], 1), round(action[1], 1))
        # shapely interprets the angle as degrees unless use_radians=True.
        offset = shapely.affinity.rotate(offset, current_molecule.orientation, origin=(0, 0))
        # Shift the molecule centre by the rotated offset.
        return shapely.affinity.translate(current_molecule.center, offset.x, offset.y)


class LateralManipulateCurrentMoleculeWrapper(gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs):
    """Manipulate the current molecule via a start point, an end point, z and V.

    The incoming 6-element action ``[x, y, x_dest, y_dest, z, V]`` holds start
    and end offsets relative to the current molecule plus a z coordinate and a
    voltage. Every component is snapped to its discrete grid; the two offsets
    are then rotated by the molecule's ``orientation`` and shifted by its
    ``center`` to obtain absolute coordinates.
    """

    def __init__(self,
                env: gym.Env[ObsType, ActType],
                x: tuple[float, float, float],
                y: tuple[float, float, float],
                x_dest: tuple[float, float, float],
                y_dest: tuple[float, float, float],
                z: tuple[float, float, float],
                V: tuple[float, float, float]
                ):
        """Build the 6-D box action space and one discrete grid per component.

        Args:
            env: Environment to wrap.
            x: ``(low, high, step)`` of the start-point x offset.
            y: ``(low, high, step)`` of the start-point y offset.
            x_dest: ``(low, high, step)`` of the end-point x offset.
            y_dest: ``(low, high, step)`` of the end-point y offset.
            z: ``(low, high, step)`` of the z coordinate.
            V: ``(low, high, step)`` of the voltage.
        """
        gym.Wrapper.__init__(self, env)
        # Record every constructor argument under its parameter name so the
        # wrapper can be re-created from its recorded spec.
        gym.utils.RecordConstructorArgs.__init__(
            self, x=x, y=y, x_dest=x_dest, y_dest=y_dest, z=z, V=V
        )
        self.action_space = gym.spaces.Box(
            np.array([x[0], y[0], x_dest[0], y_dest[0], z[0], V[0]]),
            np.array([x[1], y[1], x_dest[1], y_dest[1], z[1], V[1]])
            )

        self.x_step = x[2]
        self.y_step = y[2]
        self.x_dest_step = x_dest[2]
        self.y_dest_step = y_dest[2]
        self.z_step = z[2]
        self.V_step = V[2]

        # All six grids are built the same way; an invalid bound or step now
        # surfaces its real error instead of being swallowed and replaced by
        # an uninformative assertion failure.
        low, high = self.action_space.low, self.action_space.high
        self.action_space_x = self._build_grid(low[0], high[0], self.x_step)
        self.action_space_y = self._build_grid(low[1], high[1], self.y_step)
        self.action_space_dest_x = self._build_grid(low[2], high[2], self.x_dest_step)
        self.action_space_dest_y = self._build_grid(low[3], high[3], self.y_dest_step)
        self.action_space_z = self._build_grid(low[4], high[4], self.z_step)
        self.action_space_V = self._build_grid(low[5], high[5], self.V_step)

    @staticmethod
    def _build_grid(low, high, step):
        """Return evenly spaced values from ``low`` to ``high`` (both included)."""
        num_points = int(np.round((high - low) / step, 0)) + 1
        return np.linspace(low, high, num_points)

    @staticmethod
    def _to_absolute(offset_x, offset_y, molecule):
        """Rotate a molecule-relative offset and shift it to the molecule centre."""
        point = Point(offset_x, offset_y)
        # shapely interprets the angle as degrees unless use_radians=True.
        point = shapely.affinity.rotate(point, molecule.orientation, origin=(0, 0))
        return shapely.affinity.translate(point, molecule.center.x, molecule.center.y)

    def action(self, action: WrapperActType) -> ActType:
        """Translate the relative action into absolute coordinates."""
        return self._absolute_action(action)

    def _absolute_action(self, action) -> NDArray:
        """Convert the action to an absolute action based on the current molecule.

        Returns:
            Array ``[start_x, start_y, end_x, end_y, z, V]`` with the start and
            end points in absolute coordinates.

        Raises:
            AssertionError: If ``action`` is neither a ``(6,)`` array nor a
                sequence of length 6.
        """
        current_molecule = self.env.unwrapped.get_wrapper_attr("current_molecule")
        assert (isinstance(action, np.ndarray) and action.shape == (6,)) or len(action) == 6

        # Snap every component onto its discrete grid.
        snapped = np.array([find_nearest(self.action_space_x, action[0]),
                            find_nearest(self.action_space_y, action[1]),
                            find_nearest(self.action_space_dest_x, action[2]),
                            find_nearest(self.action_space_dest_y, action[3]),
                            find_nearest(self.action_space_z, action[4]),
                            find_nearest(self.action_space_V, action[5])])

        start = self._to_absolute(snapped[0], snapped[1], current_molecule)
        end = self._to_absolute(snapped[2], snapped[3], current_molecule)
        return np.array([start.x, start.y, end.x, end.y, snapped[4], snapped[5]])

    def _compute_start_position(self, end_position, start_offset=1.25):
        """Compute a start point on the side of the molecule facing away from ``end_position``.

        Args:
            end_position: Target point exposing ``x`` and ``y``.
            start_offset: Multiplier applied to the molecule's larger shape
                dimension to obtain the distance from its centre.

        Returns:
            A shapely ``Point`` at that distance from the molecule centre, in
            the direction opposite to ``end_position``.
        """
        molecule = self.env.unwrapped.get_wrapper_attr("current_molecule")
        molecule_pos = molecule.center
        molecule_max_size = max(molecule.shape_size_x, molecule.shape_size_y)
        # np.arctan2 is available in every NumPy release; the np.atan2 alias
        # only exists from NumPy 2.0 on.
        goal_angle = np.arctan2(end_position.y - molecule_pos.y, end_position.x - molecule_pos.x)
        start_angle = goal_angle - np.pi  # Opposite direction
        distance = start_offset * molecule_max_size
        return Point(
            molecule_pos.x + distance * np.cos(start_angle),
            molecule_pos.y + distance * np.sin(start_angle)
        )
