import gymnasium
import mo_gymnasium
from gymnasium import Env
from gymnasium.core import ActType, ObsType, RenderFrame
import gymnasium as gym
import safety_gymnasium as safe_gym
from safety_gymnasium.wrappers import SafetyGymnasium2Gymnasium
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, SubprocVecEnv
from typing import List, Callable, SupportsFloat, Any
import numpy as np
from copy import deepcopy
from collections import deque


def reward_dim(env: "Env") -> int:
    """Return the ``reward_dim`` attribute of ``env`` or of the env it wraps.

    The attribute is looked up on ``env`` first. When it is missing there,
    the lookup moves on to ``env.unwrapped`` and repeats.

    Args:
        env: Environment (possibly wrapped) to inspect.

    Returns:
        The value of the first ``reward_dim`` attribute found.

    Raises:
        AttributeError: If no ``reward_dim`` attribute has been found by the
            time an object whose ``unwrapped`` is itself is reached.
    """
    current = env
    while True:
        try:
            return current.reward_dim
        except AttributeError:
            pass
        inner = current.unwrapped
        if inner is current:
            # An env that is its own ``unwrapped`` has nothing further to
            # search; stop here instead of retrying the same object forever.
            raise AttributeError(
                f"no 'reward_dim' attribute found on {type(env).__name__} "
                "or the environment it wraps"
            )
        current = inner


class MODummyVecEnv(DummyVecEnv):
    """``DummyVecEnv`` whose reward buffer has one column per reward component."""

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        """Build the vectorised env and give it a 2-D reward buffer.

        Args:
            env_fns: Factories that each create one environment; passed
                unchanged to the ``DummyVecEnv`` constructor.
        """
        super().__init__(env_fns)
        # The reward dimensionality is read from the first env only.
        first_env = self.envs[0]
        self.reward_dim = reward_dim(first_env)
        # Replace the buffer set up by the parent constructor with one of
        # shape (num_envs, reward_dim).
        buffer_shape = (self.num_envs, self.reward_dim)
        self.buf_rews = np.zeros(buffer_shape, dtype=np.float32)


class MOSubprocVecEnv(SubprocVecEnv):
    """``SubprocVecEnv`` subclass that adds no behaviour of its own.

    NOTE(review): presumably the subprocess counterpart of ``MODummyVecEnv``;
    unlike that class it does not set ``reward_dim`` - confirm whether callers
    rely on it.
    """


class Hopper2d(gymnasium.Env):
    """Two-objective environment built on Gymnasium's ``Hopper-v5``.

    ``step`` discards the scalar reward of the inner environment and returns
    a length-2 ``float32`` vector computed from the step's ``info`` dict:
    ``[x_velocity + 1, 10 * z_distance_from_origin + 1]``.
    """

    def __init__(self, max_episode_steps=500):
        """Create the inner ``Hopper-v5`` environment.

        Args:
            max_episode_steps: Forwarded to ``gymnasium.make`` under the same
                keyword.
        """
        self._inner = gymnasium.make(
            'Hopper-v5', max_episode_steps=max_episode_steps, frame_skip=5
        )
        self.reward_dim = 2

    @property
    def observation_space(self):
        """Observation space of the inner environment."""
        return self._inner.observation_space

    @property
    def action_space(self):
        """Action space of the inner environment."""
        return self._inner.action_space

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the inner environment and return its ``(obs, info)`` pair.

        Args:
            seed: Forwarded to the inner environment's ``reset``.
            options: Forwarded to the inner environment's ``reset``.
        """
        return self._inner.reset(seed=seed, options=options)

    def step(self, action: ActType):
        """Advance the inner environment and return a two-component reward.

        Args:
            action: Action passed to the inner environment's ``step``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``reward`` is a
            ``float32`` array ``[x_velocity + 1, 10 * z_distance + 1]`` built
            from ``info``; the other items come straight from the inner
            environment.
        """
        # The inner scalar reward is deliberately unused; both objectives are
        # rebuilt from ``info`` below.
        obs, _, terminated, truncated, info = self._inner.step(action)
        alive_bonus = 1
        velocity_reward = info['x_velocity'] + alive_bonus
        height_reward = 10 * info["z_distance_from_origin"] + alive_bonus
        reward = np.asarray([velocity_reward, height_reward], dtype=np.float32)
        return obs, reward, terminated, truncated, info

    def close(self):
        """Close the inner environment so its resources are released."""
        self._inner.close()



class SafetyGymnasiumMO(gym.Env):
    """Safety-Gymnasium environment exposed with a two-component reward.

    Each ``step`` returns the reward vector
    ``[-cost * scale_cost, reward * scale_reward]``. Observations are the last
    ``history_length`` observations of the wrapped environment, flattened
    into a single 1-D array; after ``reset`` the slots that have no real
    observation yet hold zeros.
    """

    def __init__(self, env_name,
                 history_length: int = 5,
                 scale_reward: float = 1,
                 scale_cost: float = 1,
                 *args, **kwargs):
        """Create the wrapped environment and the observation history.

        Args:
            env_name: Environment id passed to ``safety_gymnasium.make``.
            history_length: Number of consecutive observations stacked into
                each returned observation.
            scale_reward: Multiplier applied to the wrapped env's reward.
            scale_cost: Multiplier applied to the wrapped env's cost before
                it is negated.
            *args: Extra positional arguments for ``safety_gymnasium.make``.
            **kwargs: Extra keyword arguments for ``safety_gymnasium.make``.
        """
        self.env = safe_gym.make(env_name, *args, **kwargs)
        self.reward_dim = 2
        self.scale_reward = scale_reward
        self.scale_cost = scale_cost
        self.history_length = history_length
        self.obs_buffer = deque(maxlen=history_length)
        # Build the stacked space once so every access returns the same
        # object instead of a freshly constructed Box.
        stacked_size = self.env.observation_space.shape[0] * history_length
        self._observation_space = gymnasium.spaces.Box(
            low=-np.inf, high=np.inf, shape=(stacked_size, )
        )

    @property
    def observation_space(self):
        """Unbounded 1-D ``Box`` sized for ``history_length`` observations."""
        return self._observation_space

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self.env.action_space

    def _stacked_observation(self) -> np.ndarray:
        """Return the buffered observations flattened into one new array."""
        # ``np.asarray`` builds a fresh array from the deque and ``flatten``
        # returns a copy, so the result never aliases the buffer entries.
        return np.asarray(self.obs_buffer).flatten()

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and restart the history.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(stacked_obs, info)``: the history is zero-filled and then the
            first real observation is appended as its newest entry.
        """
        self.obs_buffer = deque(maxlen=self.history_length)
        for _ in range(self.history_length):
            self.obs_buffer.append(np.zeros(self.env.observation_space.shape))
        obs, info = self.env.reset(seed=seed, options=options)
        # With the buffer already full, appending drops the oldest zero frame.
        self.obs_buffer.append(obs)
        return self._stacked_observation(), info

    def step(
            self, action: ActType
    ) -> tuple[ObsType, np.ndarray, bool, bool, dict[str, Any]]:
        """Step the wrapped environment and return the vector reward.

        Args:
            action: Action passed to the wrapped environment's ``step``.

        Returns:
            ``(stacked_obs, reward, terminated, truncated, info)`` where
            ``reward`` is ``[-cost * scale_cost, reward * scale_reward]``.
        """
        obs, reward, cost, terminated, truncated, info = self.env.step(action)
        reward = np.asarray([-cost * self.scale_cost, reward * self.scale_reward])
        self.obs_buffer.append(obs)
        return self._stacked_observation(), reward, terminated, truncated, info

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Return whatever the wrapped environment's ``render`` returns."""
        return self.env.render()

    def close(self):
        """Close the wrapped environment so its resources are released."""
        self.env.close()
