from typing import Any, SupportsFloat

import traceback

import gymnasium as gym
from gymnasium.core import ActType, ObsType

from molecule_movement.statistics import enable_statistics_logger

from loguru import logger
from molecule_movement.logging import log_and_raise


class StatisticsInfoWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Publish sliding-window means of logged statistics in the step ``info`` dict.

    On every :meth:`step`, for each ``(info_key, column)`` pair in
    :attr:`_INFO_STATISTICS`, the value of
    ``self.statistics.last_N_mean(column, sliding_window_size)`` is stored in
    ``info[info_key]``. If that lookup raises ``ValueError`` or ``KeyError`` the key
    is simply not written. When ``check=True``, the exception is instead handed to
    ``log_and_raise``.
    """

    # (info key, statistics column) pairs, written into ``info`` in this order.
    # NOTE: "corridor_penalty" (-> "mean_corridor_penalty") is deliberately not reported.
    _INFO_STATISTICS: tuple[tuple[str, str], ...] = (
        ("mean_movement_travelled", "movement_travelled"),
        ("inside_corridor_mean", "inside_corridor"),
        ("mean_movement_reward", "movement_reward"),
        ("mean_reorientation_reward", "reorientation_reward"),
        ("mean_position_reward", "position_reward"),
        ("mean_proximity_penalty", "proximity_penalty"),
    )

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        columns: list[str],
        sliding_window_size: int = 1000,
        check: bool = False
    ):
        """Wrap ``env`` and attach a statistics logger.

        Args:
            env: The environment to wrap.
            columns: Column names stored on ``self.columns``. They are kept for
                future logging and not otherwise used by this class.
            sliding_window_size: Passed as ``maxlen`` to
                ``enable_statistics_logger``. Also used as the window length for
                every mean computed in :meth:`step`.
            check: If True, failed statistic lookups are handed to
                ``log_and_raise``. Otherwise they are silently skipped.
        """
        # Record every constructor argument so that ``env.spec`` can rebuild
        # this wrapper (e.g. ``gym.make(env.spec)``). Without them, the rebuilt
        # wrapper would be constructed without the required ``columns``.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            columns=columns,
            sliding_window_size=sliding_window_size,
            check=check,
        )
        gym.Wrapper.__init__(self, env)

        self.statistics = enable_statistics_logger(maxlen=sliding_window_size)
        self.sliding_window_size = sliding_window_size

        # We will use these for logging in the future
        self.columns = columns
        self.__check = check

        # Episode bookkeeping placeholders; nothing in this class updates them.
        self.episode_count = 0
        self.episode_start_time: float = -1
        self.episode_returns: float = 0.0
        self.episode_lengths: int = 0

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and add sliding-window statistic means to ``info``.

        Statistics whose lookup raises ``ValueError`` or ``KeyError`` are passed to
        :meth:`handle`. With ``check=True`` the first such failure propagates and
        the remaining statistics are not computed.
        """
        obs, reward, terminated, truncated, info = super().step(action)

        for info_key, column in self._INFO_STATISTICS:
            try:
                info[info_key] = self.statistics.last_N_mean(
                    column, self.sliding_window_size
                )
            except (ValueError, KeyError) as e:
                self.handle(e)

        return obs, reward, terminated, truncated, info

    def handle(self, e: Exception) -> None:
        """Deal with a failed statistic lookup from :meth:`step`.

        When the wrapper was built with ``check=True``, ``e`` is passed to
        ``log_and_raise``. Otherwise it is ignored, so missing statistics never
        interrupt stepping.
        """
        if self.__check:
            log_and_raise(e, msg="")

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the wrapped environment with ``seed`` and ``options``.

        This is a pass-through. It neither clears the statistics logger nor
        modifies the episode bookkeeping attributes.
        """
        obs, info = super().reset(seed=seed, options=options)
        return obs, info
