# gymnasium_wrapper.py

from __future__ import annotations
import csv
import numpy as np
import gymnasium as gym
from typing import Any, ClassVar
from omnisafe.envs.core import CMDP, env_register, env_unregister

@env_register
@env_unregister
class GymnasiumWrapper(CMDP):
    """Adapt the project's benchmark environments to the ``CMDP`` interface.

    The wrapped environment is expected to return
    ``(state, reward, cost, done, trunc, info)`` from ``step`` and
    ``(state, info)`` from ``reset``.  Observations are converted to
    ``float32`` NumPy arrays and scalars to built-in Python types.

    One CSV row ``episode, steps, cost`` is written per episode to
    ``csv_path``: when the next episode starts (``reset``) and, for the last
    episode, when the wrapper is closed.
    """

    # env_id -> (module path, class name).  Only the selected benchmark is
    # imported, so a missing optional dependency of an unrelated benchmark
    # does not prevent constructing this one.
    _ENV_SPECS: ClassVar[dict[str, tuple[str, str]]] = {
        'BipedalWalker-v1': ('benchmarks.bipedalwalker', 'BipedalWalkerEnv'),
        'AccEnv-v1': ('benchmarks.acc', 'AccEnv'),
        'CarRacing-v1': ('benchmarks.car_racing', 'CarRacingEnv'),
        'Pendulum-v1': ('benchmarks.pendulum', 'PendulumEnv'),
        'Cheetah-v1': ('benchmarks.cheetah', 'CheetahEnv'),
    }
    _support_envs: ClassVar[list[str]] = list(_ENV_SPECS)
    # NOTE(review): presumably read by the framework to decide which of its
    # own wrappers to stack on top of this class -- confirm against omnisafe.
    need_auto_reset_wrapper = True
    need_time_limit_wrapper = True

    def __init__(self, env_id: str, *, csv_path: str, **kwargs: Any) -> None:
        """Build the benchmark selected by ``env_id`` and open the CSV log.

        Args:
            env_id: One of the ids listed in ``_support_envs``.
            csv_path: File that receives the per-episode log; it is
                truncated and a header row is written immediately.
            **kwargs: Forwarded unchanged to ``CMDP.__init__``.

        Raises:
            NotImplementedError: If ``env_id`` has no benchmark mapping.
            OSError: If ``csv_path`` cannot be opened for writing.
        """
        super().__init__(env_id=env_id, **kwargs)

        import importlib

        try:
            module_name, class_name = self._ENV_SPECS[env_id]
        except KeyError:
            raise NotImplementedError(f"Unknown env {env_id!r}") from None
        env_cls = getattr(importlib.import_module(module_name), class_name)
        self.env = env_cls()

        self._observation_space = self.env.observation_space
        self._action_space = self.env.action_space

        # CSV logging
        try:
            self._csv_file = open(csv_path, 'w', newline='', encoding='utf-8')
        except OSError:
            # Do not leave a half-constructed wrapper holding an open env.
            self.env.close()
            raise
        self._csv_writer = csv.writer(self._csv_file)
        self._csv_writer.writerow(['episode', 'steps', 'cost'])
        self._csv_file.flush()
        self._epcost = 0.0  # cost accumulated over the current episode
        self._count = 0  # steps taken in the current episode
        self._num_episodes = 0  # number of episodes started so far

    def _seed_legacy_api(self, seed: int) -> bool:
        """Call ``env.seed(seed)`` if the wrapped env provides it.

        Returns:
            ``True`` when the legacy method existed and was called.
        """
        legacy_seed = getattr(self.env, 'seed', None)
        if not callable(legacy_seed):
            return False
        legacy_seed(seed)
        return True

    def _write_episode_row(self) -> None:
        """Append the current episode's counters to the CSV and flush."""
        self._csv_writer.writerow(
            [self._num_episodes, self._count, float(self._epcost)],
        )
        self._csv_file.flush()

    def set_seed(self, seed: int) -> None:
        """Seed the wrapped environment.

        Uses ``env.seed(seed)`` when the env has such a method; otherwise
        falls back to ``env.reset(seed=seed)`` instead of failing with
        ``AttributeError``.
        """
        if not self._seed_legacy_api(seed):
            self.env.reset(seed=seed)

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[np.ndarray, dict]:
        """Log the previous episode, then start a new one.

        Args:
            seed: Optional seed; also forwarded to ``env.reset``.
            options: Forwarded unchanged to ``env.reset``.

        Returns:
            ``(state, info)`` with ``state`` as a ``float32`` array.
        """
        if seed is not None:
            # Optional legacy hook; env.reset(seed=seed) below seeds the env
            # regardless, so nothing else is needed when it is absent.
            self._seed_legacy_api(seed)
        if self._num_episodes > 0:
            # flush previous episode
            self._write_episode_row()

        self._num_episodes += 1
        self._count = 0
        self._epcost = 0.0

        state, info = self.env.reset(seed=seed, options=options)
        # return pure NumPy + Python
        return np.asarray(state, dtype=np.float32), info

    def step(
        self,
        action: Any,
    ) -> tuple[np.ndarray, float, float, bool, bool, dict]:
        """Advance the wrapped env by one step and accumulate its cost.

        Args:
            action: Action for the env.  Objects exposing ``detach`` (e.g.
                torch tensors) are converted to NumPy arrays first.

        Returns:
            ``(state, reward, cost, done, trunc, info)`` using a ``float32``
            array for ``state`` and built-in ``float``/``bool`` scalars.
            ``info['state_original']`` is always present as a ``float32``
            array and defaults to ``state`` when the env did not supply it.
        """
        self._count += 1

        # action can be torch.Tensor or np.ndarray
        if hasattr(action, 'detach'):
            action = action.detach().cpu().numpy()

        state, reward, cost, done, trunc, info = self.env.step(action)
        cost = float(cost)
        self._epcost += cost

        # ensure pure Python types
        state = np.asarray(state, dtype=np.float32)
        info['state_original'] = np.asarray(
            info.get('state_original', state),
            dtype=np.float32,
        )

        return state, float(reward), cost, bool(done), bool(trunc), info

    def render(self, **kwargs: Any) -> Any:
        """Delegate rendering to the wrapped environment."""
        return self.env.render(**kwargs)

    def close(self) -> Any:
        """Log the in-progress episode, close the CSV and the wrapped env.

        The episode row is only written when at least one step was taken
        since the last ``reset``; earlier episodes were already logged by
        ``reset``.  The wrapped env is closed even if closing the CSV fails,
        and calling ``close`` again does not touch the CSV a second time.

        Returns:
            Whatever ``env.close()`` returns.
        """
        try:
            if not self._csv_file.closed:
                if self._num_episodes > 0 and self._count > 0:
                    self._write_episode_row()
                self._csv_file.close()
        finally:
            result = self.env.close()
        return result

    @property
    def max_episode_steps(self) -> int | None:
        """The wrapped env's ``_max_episode_steps``, or ``None`` if unset."""
        return getattr(self.env, '_max_episode_steps', None)
