from typing import Any

import gymnasium
from gymnasium.spaces.box import Box
import numpy as np

from offline.utils.suppress_warnings import gym


def gym_box_to_gymnasium_box(space: gym.Space) -> Box:
    """Build a gymnasium ``Box`` from the bounds and dtype of ``space``.

    Args:
        space: Space whose ``dtype``, ``high`` and ``low`` attributes are
            read (presumably a legacy ``gym`` box space -- any other space
            lacking ``high``/``low`` fails with ``AttributeError``).

    Returns:
        A ``Box`` constructed from those three attributes, forwarded
        unchanged as keyword arguments.
    """
    box_kwargs = {
        "dtype": space.dtype,
        "high": space.high,  # type: ignore
        "low": space.low,  # type: ignore
    }
    return Box(**box_kwargs)  # type: ignore


class GymEnvWrapper(gymnasium.Env):
    """Expose a legacy ``gym`` environment through the ``gymnasium.Env`` API.

    The wrapped environment is created with ``gym.make`` and driven through
    the legacy calls used below (``seed``, ``reset()`` returning only the
    observation, ``step()`` returning a 4-tuple, ``render(mode=...)``). This
    class translates those into the ``gymnasium`` conventions: ``reset``
    returns ``(observation, info)`` and ``step`` returns the 5-tuple with
    separate ``terminated`` / ``truncated`` flags.
    """

    def __init__(
        self,
        gym_env_id: str = "",
        make_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        """Create the wrapped environment and mirror its spaces.

        Args:
            gym_env_id: ID passed to ``gym.make``. Must be non-empty.
            make_kwargs: Extra keyword arguments forwarded to ``gym.make``.
                ``None`` is treated as an empty dict.
            **kwargs: Forwarded to the parent class initializer.

        Raises:
            ValueError: If ``gym_env_id`` is empty.
        """
        super().__init__(**kwargs)
        if not gym_env_id:
            raise ValueError("Empty gym env ID.")
        if make_kwargs is None:
            make_kwargs = {}
        self.env: gym.Env[np.float64, np.float32]
        self.env = gym.make(gym_env_id, **make_kwargs)
        self.action_space = gym_box_to_gymnasium_box(self.env.action_space)
        self.observation_space = gym_box_to_gymnasium_box(
            self.env.observation_space
        )

    def render(self):
        """Render the wrapped environment using ``self.render_mode``.

        ``render_mode`` is not assigned by this class, so it must be set
        (e.g. on the instance or by a subclass) before calling ``render``.
        """
        assert (
            self.render_mode is not None
        ), "render_mode must be set before calling render()."
        return self.env.render(mode=self.render_mode)

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment.

        Args:
            seed: If not ``None``, passed to the wrapped env's ``seed``
                method before resetting.
            options: Accepted for API compatibility and ignored.

        Returns:
            ``(observation, info)`` where ``info`` is always an empty dict,
            because the wrapped ``reset()`` result is used as the
            observation alone.
        """
        del options
        if seed is not None:
            self.env.seed(seed)
        return self.env.reset(), {}

    def step(self, action):
        """Advance the wrapped environment by one step.

        The single ``done`` flag of the wrapped env is split in two: an
        episode that is done with a truthy ``"TimeLimit.truncated"`` entry
        in ``info`` is reported as truncated, any other finished episode as
        terminated.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with
            ``terminated`` and ``truncated`` as plain ``bool`` values.
        """
        observations, rewards, done, info = self.env.step(action)
        if done:
            # Coerce to ``bool`` so callers never receive whatever object
            # happened to be stored under the info key.
            truncated = bool(info.get("TimeLimit.truncated", False))
            terminated = not truncated
        else:
            truncated = terminated = False
        return observations, rewards, terminated, truncated, info

    def close(self):
        """Close the wrapped environment so its resources are released."""
        # ``__init__`` can raise before ``self.env`` is assigned; closing
        # such a partially constructed wrapper must not fail.
        wrapped = getattr(self, "env", None)
        if wrapped is not None:
            wrapped.close()
