import gymnasium as gym
import numpy as np

from typing import SupportsFloat

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType

from loguru import logger

class NormalizeDistanceRewardWrapper(gym.RewardWrapper[ObsType, ActType]):
    """Add the ratio ``current_distance / goal_distance`` to each step's reward.

    Both quantities are looked up on the wrapped environment stack through
    ``get_wrapper_attr``.

    NOTE(review): the bonus grows as ``current_distance`` grows; confirm this
    is the intended direction for a distance-based reward.
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` plus the current-to-goal distance ratio."""
        lookup = self.env.get_wrapper_attr
        distance_now = lookup("current_distance")
        distance_goal = lookup("goal_distance")
        ratio = distance_now / distance_goal
        return reward + ratio

class DistanceRewardBuffWrapper(gym.RewardWrapper[ObsType, ActType]):
    """Add a small bonus for the fraction of the start-to-goal distance covered.

    The covered fraction ``1 - current / initial`` is rounded to one decimal
    place and scaled by 0.1. This yields 0.1 when the molecule sits on the
    goal and 0.0 when it is as far away as its starting position.

    NOTE(review): the fraction divides by the start-to-goal distance; a
    molecule that starts exactly on its goal is not handled here.
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` plus the rounded, 0.1-scaled progress bonus."""
        matching = self.env.get_wrapper_attr("current_matching")
        goal_position = matching.goal.position
        molecule = matching.molecule

        start_gap = goal_position.distance(molecule.starting_position)
        remaining_gap = goal_position.distance(molecule.center)

        covered_fraction = round(1 - remaining_gap / start_gap, 1)
        return reward + 0.1 * covered_fraction

class MovementRewardWrapper(
        gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
        ):
    """Reward the progress towards the goal made by the last action.

    Progress is ``distance_before_action - current_distance`` (read from the
    unwrapped env) minus ``progress_offset``. It is squashed with ``tanh``
    relative to half of the molecule's maximum movement, so each step adds a
    value strictly between ``-weight`` and ``weight``.

    Args:
        env: Environment to wrap.
        weight: Multiplier applied to the squashed movement reward.
        progress_offset: Progress that must be exceeded before the movement
            reward becomes positive (defaults to the previously hard-coded 0.15).
    """

    def __init__(self, env: gym.Env[ObsType, ActType], weight: float = 1.0,
                 progress_offset: float = 0.15):
        # Record the constructor kwargs so the wrapper can be rebuilt from
        # ``env.spec`` (without this mixin Gymnasium stores ``kwargs=None``).
        gym.utils.RecordConstructorArgs.__init__(
            self, weight=weight, progress_offset=progress_offset
        )
        gym.RewardWrapper.__init__(self, env)
        self.weight = weight
        self.progress_offset = progress_offset

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` plus the weighted, tanh-squashed movement reward."""
        distance_before_movement = self.unwrapped.distance_before_action
        current_distance = self.unwrapped.current_distance
        # Moving closer by more than `progress_offset` is rewarded; less is penalised.
        travelled = (distance_before_movement - current_distance) - self.progress_offset
        action_max = self.unwrapped.get_wrapper_attr("current_molecule").stochastic_updates.maximum_movement
        movement_reward = self.weight * np.tanh(travelled / (0.5 * action_max))
        logger.bind(task="stats", movement_reward=float(movement_reward)).trace(f"MovementRewardWrapper: {movement_reward}")
        return reward + movement_reward

class PerStepCostWrapper(
        gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
        ):
    """Subtract a constant cost from the reward on every step.

    Args:
        env: Environment to wrap.
        cost: Amount subtracted from each step's reward.
    """

    def __init__(self, env: gym.Env[ObsType, ActType], cost: float = 0.01):
        # The recorded keyword must match this constructor's parameter name:
        # Gymnasium rebuilds the wrapper as ``PerStepCostWrapper(env=..., **kwargs)``
        # from ``env.spec``, so recording ``per_step_cost=`` would raise TypeError.
        gym.utils.RecordConstructorArgs.__init__(self, cost=cost)
        gym.RewardWrapper.__init__(self, env)
        self.per_step_cost = cost

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return ``reward`` reduced by the per-step cost."""
        return reward - self.per_step_cost

class PositionRewardWrapper(
        gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
        ):
    """Shape the reward by the distance to the goal after each action.

    If the last action reduced the distance to the goal (as reported by the
    unwrapped env's ``distance_before_action`` and ``current_distance``),
    ``weight * (linear + exponential)`` is added, where::

        linear      = linear_slope * d + linear_offset
        exponential = exp_magnitude * exp(-exp_slope * |d|) - exp_offset

    and ``d`` is the current distance. Otherwise ``no_progress_penalty`` is
    subtracted; note the penalty is NOT scaled by ``weight``.

    Args:
        env: Environment to wrap.
        linear_slope: Slope of the linear term in the current distance.
        linear_offset: Constant added to the linear term.
        exp_slope: Decay rate of the exponential term.
        exp_offset: Constant subtracted from the exponential term.
        exp_magnitude: Scale of the exponential term.
        weight: Multiplier applied to the progress bonus.
        no_progress_penalty: Amount subtracted when the distance did not decrease.
    """

    def __init__(self, env: gym.Env[ObsType, ActType],
                 linear_slope: float = -0.25,
                 linear_offset: float = 0.4,
                 exp_slope: float = 6,
                 exp_offset: float = 0.001,
                 exp_magnitude: float = 2,
                 weight: float = 1.0,
                 no_progress_penalty: float = 1.0,
                 ):
        # Record every tunable so ``env.spec`` can rebuild this wrapper with
        # the same configuration instead of silently falling back to defaults.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            linear_slope=linear_slope,
            linear_offset=linear_offset,
            exp_slope=exp_slope,
            exp_offset=exp_offset,
            exp_magnitude=exp_magnitude,
            weight=weight,
            no_progress_penalty=no_progress_penalty,
        )
        gym.RewardWrapper.__init__(self, env)
        self.weight = weight
        self.linear_slope = linear_slope
        self.linear_offset = linear_offset

        self.exp_slope = exp_slope
        self.exp_offset = exp_offset
        self.exp_magnitude = exp_magnitude

        self.no_progress_penalty = no_progress_penalty

    def reward(self, reward: SupportsFloat = 0.0) -> SupportsFloat:
        """Return ``reward`` plus the position bonus, or minus the no-progress penalty."""
        distance_before_movement = self.unwrapped.distance_before_action
        current_distance = self.unwrapped.current_distance

        if distance_before_movement > current_distance:
            linear_reward_term = self.linear_slope * current_distance + self.linear_offset
            exp_reward_term = self.exp_magnitude * np.exp(-self.exp_slope * np.abs(current_distance)) - self.exp_offset
            position_reward = self.weight * (linear_reward_term + exp_reward_term)
        else:
            # Unweighted on purpose: matches the value this wrapper has always applied.
            position_reward = -self.no_progress_penalty

        # Log exactly the value that is added to the reward.
        logger.bind(task="stats", position_reward=float(position_reward)).trace(f"PositionRewardWrapper: {position_reward}")
        return reward + position_reward


