import gymnasium as gym
import numpy as np

from typing import SupportsFloat

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType

from loguru import logger

class ProximityPenaltyWrapper(gym.RewardWrapper):
    """Add a quadratic penalty to the reward when sensors report nearby obstacles.

    For a single distance reading ``d`` the per-sensor penalty is::

        lambda_obs * max(0, (safe_clearance - d) / safe_clearance) ** 2

    i.e. zero at or beyond ``safe_clearance``, growing quadratically as ``d``
    shrinks, and equal to ``lambda_obs`` at ``d == 0``. The readings are
    looked up on the wrapped environment stack as ``_sensor_readings``.
    """

    def __init__(self, env,
                 safe_clearance: float,
                 lambda_obs: float = 1.0,
                 use_mean_cones: bool = False):
        """Configure the wrapper.

        Args:
            env: Environment to wrap. Some layer of it must expose a
                ``_sensor_readings`` attribute (read via ``get_wrapper_attr``).
            safe_clearance: Distance at or above which no penalty applies.
                Also used as the normalisation divisor, so it must be > 0.
            lambda_obs: Weight of the penalty.
            use_mean_cones: If True, average the per-sensor penalties over all
                readings; otherwise penalise only the closest reading.

        Raises:
            ValueError: If ``safe_clearance`` is not strictly positive (zero
                would divide by zero on every step; a negative value inverts
                the sign logic of the penalty).
        """
        # `not x > 0` also rejects NaN.
        if not safe_clearance > 0:
            raise ValueError(f"safe_clearance must be > 0, got {safe_clearance!r}")
        super().__init__(env)
        self.safe_clearance = safe_clearance
        self.lambda_obs = lambda_obs
        self.use_mean = use_mean_cones

    def reward(self, reward=0.0):
        """Return ``reward`` plus the (non-positive) proximity penalty.

        Args:
            reward: Reward emitted by the wrapped environment for this step.

        Returns:
            ``reward + prox_pen``, where ``prox_pen`` is ``-lambda_obs`` times
            either the mean (``use_mean_cones=True``) or the maximum of the
            per-sensor normalised squared intrusions. With no readings the
            penalty is 0.
        """
        # Flatten so any array-like (including nested) readings are handled uniformly.
        sensors = np.asarray(self.get_wrapper_attr("_sensor_readings"), dtype=float).ravel()
        safe_clearance = self.safe_clearance

        if sensors.size == 0:
            # No readings: nothing to penalise (avoids dividing by zero in the
            # mean and np.min failing on an empty array).
            prox_pen = 0.0
        else:
            intrusion = np.clip((safe_clearance - sensors) / safe_clearance, 0.0, None) ** 2
            # A NaN reading contributes no penalty, but must not mask the other
            # sensors (np.min would propagate NaN and zero the whole penalty).
            intrusion = np.where(np.isnan(intrusion), 0.0, intrusion)
            # The per-sensor penalty is non-increasing in distance, so the max
            # penalty is exactly the penalty of the closest reading.
            agg = intrusion.mean() if self.use_mean else intrusion.max()
            prox_pen = -self.lambda_obs * float(agg)

        # Log in both modes; `{}` lets loguru skip formatting when TRACE is disabled.
        logger.bind(task="stats", proximity_penalty=float(prox_pen)).trace(
            "ProximityPenaltyWrapper: {}", prox_pen)

        return reward + prox_pen
