import gymnasium as gym
import numpy as np

from typing import SupportsFloat

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType

from loguru import logger

def angle_to_index(angle_deg, symmetry_angle: int, num_bins: int) -> int:
    """Return the index of the angular bin that contains ``angle_deg``.

    ``symmetry_angle`` is divided into ``num_bins`` bins of equal width and
    the angle is floor-divided by that width.

    Floor division rounds toward negative infinity: angles in ``[0, width)``
    map to 0, while every negative angle maps to -1 or below. No wrapping is
    applied, so angles at or beyond ``symmetry_angle`` yield indices greater
    than or equal to ``num_bins``.
    """
    width_per_bin = symmetry_angle / num_bins
    # divmod's quotient is the floor-division result; the remainder is unused.
    whole_bins, _ = divmod(angle_deg, width_per_bin)
    return int(whole_bins)

class RotationRewardWrapper(
        gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
        ):
    """
    Add a weighted orientation term to the reward of the wrapped environment.

    On every step the wrapper reads ``current_matching`` from the wrapped
    environment and compares ``molecule.orientation`` with ``goal.rotation``.
    If the absolute difference is at most ``precision`` the term is
    ``rotation_reward``, otherwise it is ``-rotation_penalty``. The term is
    multiplied by ``weight`` and added to the incoming reward.
    """
    def __init__(self, env: gym.Env[ObsType, ActType], weight: float = 1.0, rotation_reward: float = 1.0, rotation_penalty: float = 0.1, precision: float = 1.0):
        """
        Args:
            env: Environment to wrap; it must expose ``current_matching`` via
                ``get_wrapper_attr``.
            weight: Factor applied to the orientation term before it is added.
            rotation_reward: Term used when the orientation is within
                ``precision`` of the goal rotation.
            rotation_penalty: Magnitude of the (negative) term used otherwise.
            precision: Largest absolute orientation difference that still
                counts as aligned.
        """
        gym.RewardWrapper.__init__(self, env)
        # Record every constructor argument so a wrapper rebuilt from the
        # recorded kwargs gets the same weight and precision, not the defaults.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            weight=weight,
            rotation_reward=rotation_reward,
            rotation_penalty=rotation_penalty,
            precision=precision,
        )
        self.rotation_penalty = rotation_penalty
        self.rotation_reward = rotation_reward
        self.weight = weight
        self.precision = precision

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` plus the weighted orientation term.

        The weighted term is also emitted as a trace-level log record with the
        bound fields ``task="stats"`` and ``rotation_reward``.
        """
        current_matching = self.env.get_wrapper_attr("current_matching")
        molecule = current_matching.molecule
        goal = current_matching.goal

        # NOTE(review): the difference is not wrapped modulo a full turn or a
        # symmetry angle; confirm that both orientations are normalized
        # upstream, otherwise equivalent orientations are penalized.
        deviation = np.abs(molecule.orientation - goal.rotation)
        if deviation <= self.precision:
            rotation_term = self.rotation_reward
        else:
            rotation_term = -self.rotation_penalty

        weighted_term = self.weight * rotation_term
        logger.bind(task="stats", rotation_reward=float(weighted_term)).trace("")
        return reward + weighted_term

class ReorientationRewardWrapper(gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs):
    """
    Shape the reward according to how the last action changed the molecule's
    orientation relative to the goal rotation.

    The wrapper only acts once the wrapped environment's ``current_distance``
    is at most its ``SUCCESS_DISTANCE``; otherwise the reward passes through
    unchanged. The offsets ``goal.rotation - molecule.previous_orientation``
    and ``goal.rotation - molecule.orientation`` are converted to bin indices
    with ``angle_to_index`` (index 0 is the goal bin) and compared:

    * arriving in the goal bin            -> ``reward_at_goal_orientation``
    * staying in the goal bin             -> ``reward_stayed_on_goal_orientation``
    * leaving the goal bin                -> ``cost_reorientation``
                                             + ``reward_moved_from_goal_orientation``
    * moving closer (smaller ``|index|``) -> ``reward_reorientation``
    * moving away (larger ``|index|``)    -> ``cost_reorientation``
    * same ``|index|`` outside the goal   -> ``reward_no_rotation``

    The sum of the applicable terms is multiplied by ``weight`` and added to
    the incoming reward.
    """
    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        weight: float = 1.0,
        reward_at_goal_orientation: float = 2,
        *,
        reward_no_rotation: float = -1.5,
        reward_reorientation: float = 0.5,
        cost_reorientation: float = -1.5,
        reward_stayed_on_goal_orientation: float = 2,
        reward_moved_from_goal_orientation: float = -2,
    ):
        """
        Args:
            env: Environment to wrap; it must expose ``current_matching``,
                ``current_distance`` and ``SUCCESS_DISTANCE`` via
                ``get_wrapper_attr``.
            weight: Factor applied to the summed orientation terms.
            reward_at_goal_orientation: Term for rotating into the goal bin.
            reward_no_rotation: Term for ending in a bin with the same
                absolute index as before while not resting on the goal bin.
            reward_reorientation: Term for reducing the absolute bin index.
            cost_reorientation: Term for increasing the absolute bin index.
            reward_stayed_on_goal_orientation: Term for remaining in the goal
                bin.
            reward_moved_from_goal_orientation: Additional term for leaving
                the goal bin.
        """
        gym.RewardWrapper.__init__(self, env)
        self.weight = weight
        self.reward_at_goal_orientation = reward_at_goal_orientation
        self.reward_no_rotation = reward_no_rotation
        self.reward_reorientation = reward_reorientation
        self.cost_reorientation = cost_reorientation
        self.reward_stayed_on_goal_orientation = reward_stayed_on_goal_orientation
        self.reward_moved_from_goal_orientation = reward_moved_from_goal_orientation
        # Every recorded name is a real constructor parameter, so the wrapper
        # can be rebuilt by calling __init__ with the recorded kwargs.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            weight=weight,
            reward_at_goal_orientation=reward_at_goal_orientation,
            reward_no_rotation=reward_no_rotation,
            reward_reorientation=reward_reorientation,
            cost_reorientation=cost_reorientation,
            reward_stayed_on_goal_orientation=reward_stayed_on_goal_orientation,
            reward_moved_from_goal_orientation=reward_moved_from_goal_orientation,
        )

    def reward(self, reward: SupportsFloat = 0) -> SupportsFloat:
        """Return ``reward`` plus the weighted reorientation term.

        The weighted term (0 when the distance gate is not met) is also
        emitted as a trace-level log record with the bound fields
        ``task="stats"`` and ``reorientation_reward``.
        """
        current_matching = self.env.get_wrapper_attr("current_matching")
        molecule = current_matching.molecule
        goal = current_matching.goal
        goal_orientation = goal.rotation
        current_molecule_orientation = molecule.orientation
        previous_molecule_orientation = molecule.previous_orientation

        # Orientation is only scored once the distance criterion is satisfied.
        if self.env.get_wrapper_attr("current_distance") > self.env.get_wrapper_attr("SUCCESS_DISTANCE"):
            logger.bind(task="stats", reorientation_reward=0).trace("")
            return reward

        # NOTE(review): the offsets are not wrapped modulo the symmetry angle
        # before binning; confirm orientations are normalized upstream.
        symmetry = molecule.angle_symmetry
        symmetry_angle = symmetry._symmetry_angle_moiety
        num_bins = symmetry._num_adsorption_angles
        index_before = angle_to_index(goal_orientation - previous_molecule_orientation, symmetry_angle, num_bins)
        index_after = angle_to_index(goal_orientation - current_molecule_orientation, symmetry_angle, num_bins)

        # Classify the transition between the two bin indices.
        was_on_goal = index_before == 0
        is_on_goal = index_after == 0
        rotated = index_before != index_after
        rotated_towards_goal = abs(index_before) > abs(index_after)
        rotated_away_from_goal = abs(index_before) < abs(index_after)
        equivalent_orientation = abs(index_before) == abs(index_after)
        stayed_on_goal_orientation = was_on_goal and is_on_goal
        rotate_at_goal_orientation = (not was_on_goal) and is_on_goal
        moved_from_goal_orientation = was_on_goal and (not is_on_goal)

        reward_at_goal_orientation = 0
        reward_no_rotation = 0
        reward_reorientation = 0
        reward_correctness = 0
        reward_incorrectness = 0

        if rotate_at_goal_orientation:
            # Arriving in the goal bin is rewarded on its own; no other term
            # is combined with it.
            reward_at_goal_orientation = self.reward_at_goal_orientation
        else:
            # Penalize not rotating, or rotating to a bin with the same
            # absolute index, unless the molecule is resting on the goal bin.
            if not stayed_on_goal_orientation and (not rotated or equivalent_orientation):
                reward_no_rotation = self.reward_no_rotation

            # Reward moving closer to the goal bin, penalize moving away.
            if rotated_towards_goal:
                reward_reorientation = self.reward_reorientation
            elif rotated_away_from_goal:
                reward_reorientation = self.cost_reorientation

            # Reward remaining in the goal bin.
            if stayed_on_goal_orientation:
                reward_correctness = self.reward_stayed_on_goal_orientation

            # Penalize leaving the goal bin (in addition to the moving-away
            # cost applied above).
            if moved_from_goal_orientation:
                reward_incorrectness = self.reward_moved_from_goal_orientation

        # All five terms contribute, including the stay-on-goal and
        # leave-goal terms.
        complete_reward = self.weight * (
            reward_at_goal_orientation
            + reward_reorientation
            + reward_no_rotation
            + reward_correctness
            + reward_incorrectness
        )
        logger.bind(task="stats", reorientation_reward=complete_reward).trace("")
        return reward + complete_reward
