

# import all we need
from __future__ import annotations

import random
import omnisafe
import csv
from typing import Any, ClassVar
import numpy as np

import torch
from gymnasium import spaces

from omnisafe.envs.core import CMDP, env_register, env_unregister
from benchmarks.bipedalwalker import  BipedalWalkerEnv
from benchmarks.acc import AccEnv
from benchmarks.car_racing import CarRacingEnv
from benchmarks.pendulum import PendulumEnv
from benchmarks.cheetah import CheetahEnv
from benchmarks.hopper import HopperEnv
from benchmarks.ant import AntEnv
from benchmarks.humanoid import HumanoidEnv

class GymnasiumWrapper(CMDP):
    """OmniSafe ``CMDP`` adapter around the project's benchmark environments.

    The wrapper builds the benchmark env selected by ``env_id`` and converts
    observations, rewards, costs and termination flags to ``torch`` tensors.
    It also records per-episode statistics (episode index, step count,
    accumulated cost) to a CSV file. A row for an episode is written when the
    *next* episode starts (in ``reset``), so the episode still in progress
    when ``close`` is called is not recorded.
    """

    _support_envs: ClassVar[list[str]] = [
        'BipedalWalker-v1',
        "AccEnv-v1",
        "CarRacing-v1",
        "Pendulum-v1",
        "Cheetah-v1",
        "Hopper-v1",
        "Ant-v1",
        "HumanoidEnv-v1",
    ]  # Supported task names

    need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed
    need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed

    def __init__(self, env_id: str, *, csv_path: str, **kwargs: Any) -> None:
        """Create the underlying benchmark env and open the episode CSV log.

        Args:
            env_id: Task name; must be one of ``_support_envs``.
            csv_path: Path of the per-episode CSV log. The file is truncated
                and a header row is written immediately.
            **kwargs: Forwarded to ``CMDP.__init__``.

        Raises:
            NotImplementedError: If ``env_id`` is not a supported task.
        """
        super().__init__(env_id=env_id, **kwargs)

        # Only the requested env class is instantiated. Envs are created with
        # no constructor arguments (no reduced dims passed).
        env_factories = {
            'BipedalWalker-v1': BipedalWalkerEnv,
            'AccEnv-v1': AccEnv,
            'CarRacing-v1': CarRacingEnv,
            'Pendulum-v1': PendulumEnv,
            'Cheetah-v1': CheetahEnv,
            'Hopper-v1': HopperEnv,
            'Ant-v1': AntEnv,
            'HumanoidEnv-v1': HumanoidEnv,
        }
        try:
            env_cls = env_factories[env_id]
        except KeyError:
            raise NotImplementedError(f'Unsupported env_id: {env_id!r}') from None
        self.env = env_cls()

        self._observation_space = self.env.observation_space
        self._action_space = self.env.action_space
        self._count = 0  # steps taken in the current episode
        self._num_envs = 1
        self._epcost = 0  # cost accumulated over the current episode
        self._num_episodes = 0  # episodes started so far (incremented in reset)

        # newline='' is what the csv module requires for correct line endings.
        self._csv_file = open(csv_path, 'w', newline='', encoding='utf-8')
        self._csv_writer = csv.writer(self._csv_file)
        self._csv_writer.writerow(['episode', 'steps', 'cost'])
        self._csv_file.flush()

    def set_seed(self, seed: int) -> None:
        """Seed the underlying env and Python's global ``random`` module."""
        self.env.seed(seed)
        random.seed(seed)

    def reset(
            self,
            seed: int | None = None,
            options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """Start a new episode, first logging the statistics of the previous one.

        Args:
            seed: If given, passed to ``set_seed`` before the env is reset.
            options: Accepted for interface compatibility; currently unused.

        Returns:
            The initial observation as a float32 tensor and the env's info dict.
        """
        if seed is not None:
            self.set_seed(seed)
        if self._num_episodes != 0:
            self._log_episode()

        self._num_episodes += 1
        # Per-episode counters start from zero. Resetting _epcost here keeps the
        # logged cost per-episode instead of a running total across episodes.
        self._count = 0
        self._epcost = 0
        state, info = self.env.reset()
        return torch.as_tensor(state, dtype=torch.float32), info

    def _log_episode(self) -> None:
        """Print and append to the CSV the stats of the episode that just ended."""
        print(f"For episode num {self._num_episodes}  Steps count? : {self._count}, Cost: {self._epcost}")
        self._csv_writer.writerow([
            self._num_episodes,
            self._count,
            float(self._epcost),
        ])
        # Flush so the log survives an abrupt termination of training.
        self._csv_file.flush()

    def render(self) -> Any:
        """Delegate rendering to the underlying env."""
        return self.env.render()

    @property
    def max_episode_steps(self) -> int | None:
        """The max steps per episode, as reported by the underlying env."""
        return self.env._max_episode_steps

    def close(self) -> None:
        """Close the underlying env and the episode CSV log."""
        try:
            return self.env.close()
        finally:
            # Release the handle opened in __init__; file.close() is a no-op
            # when the file is already closed, so repeated close() is safe.
            self._csv_file.close()

    def step(
            self,
            action: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """Advance the env one step and convert its outputs to tensors.

        Args:
            action: Action forwarded unchanged to the underlying env's ``step``.

        Returns:
            ``(obs, reward, cost, terminated, truncated, info)``. Obs, reward and
            cost are float32 tensors, the two flags are bool tensors, and
            ``info`` contains only ``'state_original'`` as a float32 tensor.
            The env's info dict must provide that key.
        """
        self._count += 1
        state, reward, cost, done, truncation, info = self.env.step(action)
        obs = torch.as_tensor(state, dtype=torch.float32)
        reward = torch.as_tensor(reward, dtype=torch.float32)
        cost = torch.as_tensor(cost, dtype=torch.float32)
        self._epcost += cost
        done = torch.as_tensor(done, dtype=torch.bool)
        truncation = torch.as_tensor(truncation, dtype=torch.bool)
        # Only the original (unprocessed) state is forwarded; other info keys are dropped.
        final_info = {'state_original': torch.as_tensor(info['state_original'], dtype=torch.float32)}
        return obs, reward, cost, done, truncation, final_info

    def get_cost_from_obs_tensor(self, obs: torch.Tensor, is_binary: bool = True) -> torch.Tensor:
        """Delegate cost computation from an observation tensor to the underlying env."""
        return self.env.get_cost_from_obs_tensor(obs, is_binary)

# env = GymnasiumWrapper(env_id='BipedalWalker-v1', csv_path='episodes.csv')
# env.reset(seed=0)
# n = 10
# while n > 0:
#     action = env.action_space.sample()
#     obs, reward, cost, terminated, truncated, info = env.step(action)
#     print('-' * 20)
#     print(f'obs: {obs}')
#     print(f'reward: {reward}')
#     print(f'cost: {cost}')
#     print(f'terminated: {terminated}')
#     print(f'truncated: {truncated}')
#     print('*' * 20)
#     if terminated or truncated:
#         break
#     n=n-1
# env.close()