import gymnasium as gym
import numpy as np

from typing import SupportsFloat

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType

from loguru import logger

class CrashPenalizingWrapper(gym.RewardWrapper[ObsType, ActType]):
    """Reward wrapper that subtracts a fixed penalty when the molecule crashed.

    On every reward transformation the wrapper fetches ``current_molecule``
    from the wrapped environment via ``get_wrapper_attr`` and inspects its
    ``crashed`` flag. The penalty applied for the step (``-penalty`` or
    ``0.0``) is emitted as a loguru TRACE record bound with ``task="stats"``
    and ``crash_penalty``.
    """

    def __init__(self, env: gym.Env[ObsType, ActType], penalty: float = 1.0) -> None:
        """Wrap *env* and store the penalty to subtract on a crash.

        Args:
            env: Environment to wrap; ``current_molecule`` must be reachable
                through ``env.get_wrapper_attr``.
            penalty: Amount subtracted from the reward when the current
                molecule reports ``crashed``.
        """
        super().__init__(env)
        self.penalty = penalty

    def reward(self, reward: SupportsFloat) -> SupportsFloat:
        """Return *reward*, reduced by ``self.penalty`` if the molecule crashed.

        Args:
            reward: Reward produced by the wrapped environment for this step.

        Returns:
            ``reward - self.penalty`` when ``current_molecule.crashed`` is
            truthy, otherwise *reward* unchanged.
        """
        molecule = self.env.get_wrapper_attr("current_molecule")
        crashed = molecule.crashed

        # The record carries its payload in the bound fields only; the
        # message itself is intentionally empty.
        applied_penalty = -float(self.penalty) if crashed else 0.0
        logger.bind(task="stats", crash_penalty=applied_penalty).trace("")

        if crashed:
            return reward - self.penalty
        return reward

