# import all we need
from __future__ import annotations

import csv
import random
from typing import Any, ClassVar

import numpy as np
import omnisafe
import torch
from gymnasium import spaces
from omnisafe.envs.core import CMDP, env_register, env_unregister

from benchmarks.acc import AccEnv
from benchmarks.bipedalwalker import BipedalWalkerEnv
from benchmarks.lunar_lander import LunarLanderEnv

class GymnasiumWrapper(CMDP):
    """OmniSafe ``CMDP`` adapter around the project's custom benchmark environments.

    The wrapped environment is chosen by ``env_id``. Its ``step`` is expected to
    return ``(state, reward, cost, terminated, truncated, info)``; those values
    are converted to ``float32`` tensors on the configured device.

    Simple bookkeeping is printed at every reset after the first one:
    the number of steps of the finished episode (``_count``), the cost summed
    over the wrapper's whole lifetime (``_epcost``, never reset) and the cost of
    the finished episode only (``_resetingepcost``, cleared on reset).
    """

    _support_envs: ClassVar[list[str]] = ['CustomLunarLander-v2', 'CustomBipedalWalker-v1', "AccEnv-v1"]  # Supported task names

    need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed
    need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed

    def __init__(
            self,
            env_id: str,
            num_envs: int = 1,
            device: torch.device | str = 'cpu',
            **kwargs: Any,
    ) -> None:
        """Create the underlying environment for ``env_id``.

        Args:
            env_id: One of ``_support_envs``.
            num_envs: Stored as ``_num_envs``; only a single environment
                instance is created here.
            device: Device on which returned tensors are placed. OmniSafe
                passes this as a keyword; previously it was silently dropped
                into ``kwargs`` and ``self._device`` was never set.
            **kwargs: Extra options; currently ignored.

        Raises:
            NotImplementedError: If ``env_id`` is not a supported task.
        """
        if env_id == 'CustomBipedalWalker-v1':
            print(f"Creating env - CustomBipedalWalker-v1")
            self.env = BipedalWalkerEnv()
        elif env_id == 'CustomLunarLander-v2':
            print(f"Creating env - CustomLunarLander-v2")
            self.env = LunarLanderEnv()
        elif env_id == 'AccEnv-v1':
            print(f"Creating env - AccEnv")
            self.env = AccEnv()
        else:
            # Fail loudly here instead of with an AttributeError on `self.env` below.
            raise NotImplementedError(
                f'env_id {env_id} is not supported by {self.__class__.__name__}',
            )
        self._device = torch.device(device)
        self._count = 0  # steps taken in the current episode
        self._num_envs = num_envs
        self._observation_space = self.env.observation_space
        self._action_space = self.env.action_space
        self._epcost = 0  # cost accumulated over all episodes (never reset)
        self._resetingepcost = 0  # cost accumulated in the current episode
        self._num_episodes = 0

    def set_seed(self, seed: int) -> None:
        """Seed the wrapped environment via its ``seed`` method."""
        self.env.seed(seed)

    def reset(
            self,
            seed: int | None = None,
            options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """Start a new episode.

        Prints the statistics of the previous episode (if any), resets the
        per-episode counters and resets the wrapped environment.

        Args:
            seed: If given, passed to :meth:`set_seed` before resetting.
            options: Accepted for interface compatibility; not used.

        Returns:
            The initial observation as a ``float32`` tensor on ``self._device``,
            and the info dict returned by the wrapped environment.
        """
        if self._num_episodes != 0:
            print(f"For episode num {self._num_episodes}  Steps count? : {self._count}, Cost: {self._epcost}, Reset Cost:{self._resetingepcost}")

        self._num_episodes += 1

        if seed is not None:
            self.set_seed(seed)
        self._count = 0
        self._resetingepcost = 0
        state, info = self.env.reset()
        # Same dtype/device as observations returned by `step`.
        obs = torch.as_tensor(state, dtype=torch.float32, device=self._device)
        return obs, info

    def render(self) -> Any:
        """Render the wrapped environment and return whatever it returns."""
        return self.env.render()

    @property
    def max_episode_steps(self) -> int | None:
        """The max steps per episode, read from the wrapped env's ``_max_episode_steps``."""
        return self.env._max_episode_steps

    def close(self) -> None:
        """Close the wrapped environment."""
        return self.env.close()

    def step(
            self,
            action: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """Advance the wrapped environment by one step.

        Args:
            action: Action tensor; detached and moved to CPU/NumPy before
                being passed to the wrapped environment.

        Returns:
            ``(obs, reward, cost, terminated, truncated, info)``. All elements
            except ``info`` are ``float32`` tensors on ``self._device``.
        """
        self._count += 1
        action_np = action.detach().cpu().numpy()
        state, reward, cost, done, truncation, info = self.env.step(action_np)

        # Accumulate the raw (pre-tensor) cost for the reset-time printout.
        self._epcost += cost
        self._resetingepcost += cost

        def _to_tensor(value: Any) -> torch.Tensor:
            # torch.tensor always copies, so an env that reuses its state buffer
            # cannot mutate tensors that were already returned.
            return torch.tensor(value, dtype=torch.float32, device=self._device)

        obs = _to_tensor(state)
        reward = _to_tensor(reward)
        cost = _to_tensor(cost)
        done = _to_tensor(done)  # boolean -> 0.0 / 1.0
        truncation = _to_tensor(truncation)  # boolean -> 0.0 / 1.0

        if 'final_observation' in info:
            # Replace missing (None) final observations with zeros so the
            # entries can be stacked into a single array.
            info['final_observation'] = np.array(
                [
                    array if array is not None else np.zeros(obs.shape[-1])
                    for array in info['final_observation']
                ],
            )
            # Convert the last observation recorded in info into a torch tensor.
            info['final_observation'] = torch.as_tensor(
                info['final_observation'],
                dtype=torch.float32,
                device=self._device,
            )

        return obs, reward, cost, done, truncation, info