# The following code is manually copied from stable-baselines3: it depends on an older version of gym, and I don't want to resolve that version conflict.

from utils.common import AllowedStates, Info, Action, Reward, S, LazyFrames
import gym
from torchvision import transforms as T
import numpy as np
from gym import spaces
from collections import deque
from typing import (
    Optional,
    Any,
    TypeVar,
    cast,
    List,
    Union,
    Dict,
    Generic,
    Callable,
    Tuple,
)
import torch
import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper
from typing import Callable
import os
import gym
from typing import Callable, Optional

from gym import logger
import json
import os
import os.path
import pkgutil
import subprocess
import tempfile
from io import StringIO
from copy import deepcopy
import distutils.spawn
import distutils.version
import numpy as np

from gym import error, logger
from utils.episode import SingleEpisode
from utils.step import NotNoneStep, Step
from utils.transition import Transition
from utils.reporter import Reporter, get_reporter, ReportTrait

# Torch device string used by the tensor-building helpers in this file:
# "cuda" when a CUDA device is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class FireResetEnv(gym.Wrapper):
    """
    Press FIRE (and then action 2) right after every reset.

    Intended for environments that stay frozen until the fire button is used.

    :param env: the environment to wrap; action 1 must be named "FIRE" and
        at least three actions must exist
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == "FIRE"
        assert len(meanings) >= 3

    def reset(self, **kwargs) -> np.ndarray:
        """
        Reset the wrapped env, then step with action 1 followed by action 2.

        If either step reports ``done`` the wrapped env is reset again.

        :param kwargs: extra keywords forwarded to every ``env.reset()`` call
        :return: the observation produced by the last step (action 2)
        """
        self.env.reset(**kwargs)
        obs = None
        for startup_action in (1, 2):
            obs, _, done, _ = self.env.step(startup_action)
            if done:
                self.env.reset(**kwargs)
        return obs


class SkipFrames(gym.Wrapper):
    """Repeat each action for up to ``skip`` env steps.

    Rewards are summed over the repeated steps and the observation, done flag
    and info of the last executed step are returned.
    """

    def __init__(self, env, skip: int = 4):
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        """Run ``action`` up to ``self.skip`` times, stopping early on ``done``.

        ``info["steps"]`` is set to the number of env steps actually executed.
        """
        accumulated = 0
        for executed in range(1, self.skip + 1):
            obs, reward, done, info = self.env.step(action)
            accumulated += reward
            info["steps"] = executed
            if done:
                break
        return obs, accumulated, done, info


class NoopResetEnv(gym.Wrapper):
    """
    Randomise the initial state by executing a number of no-ops after reset.
    Action 0 is required to be the no-op (checked via the action meanings).

    :param env: the environment to wrap
    :param noop_max: upper bound (inclusive) of the number of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30):
        super().__init__(env)
        self.noop_max = noop_max
        # When set to an int, this fixed count replaces the random draw.
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> np.ndarray:
        """
        Reset the wrapped env and then take ``noops`` no-op steps.

        :param kwargs: extra keywords forwarded to ``env.reset()``
        :return: the observation after the last no-op (or after the reset
            triggered by a ``done`` during the no-ops)
        """
        self.env.reset(**kwargs)

        noops = self.override_num_noops
        if noops is None:
            # Draw uniformly from [1, noop_max].
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0

        obs = np.zeros(0)
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs


class EpisodicLifeEnv(gym.Wrapper):
    """
    Report the loss of a life as the end of an episode, while only resetting
    the underlying game on a true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int):
        """
        Step the wrapped env and flag ``done`` when a life was lost.

        :param action: action forwarded to the wrapped env
        :return: ``(obs, reward, done, info)`` where ``done`` is also True
            when the life counter decreased but is still positive
        """
        obs, reward, done, info = self.env.step(action)
        # Remember the env's own done flag before it is possibly overridden.
        self.was_real_done = done

        remaining = self.env.unwrapped.ale.lives()
        # For Qbert the env sometimes stays at lives == 0 for a few frames, so
        # only a strictly positive, decreased counter counts as a lost life;
        # the real reset happens once the env itself advertises done.
        lost_life = 0 < remaining < self.lives
        if lost_life:
            done = True
        # Always track the current counter so bonus lives are handled.
        self.lives = remaining
        return obs, reward, done, info

    def reset(self, **kwargs):
        """
        Reset the wrapped env only when the last episode really ended.

        After a mere life loss a single step with action 0 is taken instead,
        so that all states remain reachable while lives look episodic to the
        learner.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if not self.was_real_done:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class FrameStackVec(gym.Wrapper):
    """Stack the last ``num_stack`` observations of every sub-environment.

    The wrapped env must expose ``is_vector`` (truthy) and ``vec_nums`` (the
    number of sub-environments). For each sub-environment the most recent
    frames are kept in a fixed-size ring buffer; ``self.p[i]`` is the slot
    that will be overwritten next, i.e. the oldest frame.
    """

    def __init__(self, env, num_stack):
        """
        :param env: vectorised environment to wrap
        :param num_stack: number of frames kept per sub-environment
        """
        super().__init__(env)
        self.num_stack = num_stack
        assert env.is_vector

        # frames[i][j]: j-th slot of the ring buffer of sub-environment i.
        # Slots hold None until the first reset() fills them.
        self.frames: List[List[np.ndarray]] = cast(
            List[List[np.ndarray]],
            [[None for _ in range(self.num_stack)] for _ in range(env.vec_nums)],
        )
        # Per sub-environment write cursor into the ring buffer.
        self.p: List[int] = [0 for _ in range(env.vec_nums)]
        # Set to False by reset(); not read anywhere else in this class.
        self.empty = True

        # Insert a stack axis at position 1 (right after the leading axis of
        # the wrapped space) and repeat the bounds num_stack times along it.
        low = np.repeat(
            self.env.observation_space.low[:, np.newaxis, ...], num_stack, axis=1
        )
        high = np.repeat(
            self.env.observation_space.high[:, np.newaxis, ...], num_stack, axis=1
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def transform_observation(self):
        """Return one ``LazyFrames`` per sub-environment as an object array.

        Each buffer is rotated so that the slot at the write cursor (the
        oldest frame) comes first and the most recently written frame last.
        """
        assert len(self.frames) == self.env.vec_nums
        assert len(self.frames[0]) == self.num_stack
        return np.array(
            [
                LazyFrames(
                    self.frames[i][self.p[i] :] + self.frames[i][: self.p[i]], self.num_stack
                )
                for i in range(self.env.vec_nums)
            ],
            dtype=object,
        )

    def step(self, action):
        """Step the wrapped env and return stacked observations.

        The first element of the wrapped env's result is replaced by the
        stacked observation; all remaining elements are passed through.
        For finished sub-environments, ``info["final_observation"]`` is
        replaced by the stacked frames ending with that final observation.
        """
        rlt = self.env.step(action)
        obs = rlt[0]
        list_info = rlt[-1]
        dones = rlt[2]
        if any(dones):
            for i, d in enumerate(dones):
                if d:
                    # Complete the finished episode's stack with its final
                    # observation, taken from the info dict.
                    last_obs = list_info[i]["final_observation"]
                    self.frames[i][self.p[i]] = last_obs
                    # Cursor position after that write (wrapping around).
                    _np = self.p[i] + 1 if self.p[i] + 1 < self.num_stack else 0
                    # The slices build a new list, so the refill below does
                    # not alter the frames stored in final_observation.
                    list_info[i]["final_observation"] = LazyFrames(
                        self.frames[i][_np :] + self.frames[i][: _np],
                        self.num_stack,
                    )

                    # Restart the buffer with obs[i] in every slot
                    # (presumably the first observation of the next episode
                    # after an auto-reset -- TODO confirm against the env).
                    for j in range(self.num_stack):
                        self.frames[i][j] = obs[i]
                    self.p[i] = 0

        # Write the new observation over the oldest slot and advance every
        # cursor by one, wrapping at num_stack.
        for i in range(self.env.vec_nums):
            self.frames[i][self.p[i]] = rlt[0][i]
        self.p = [ p+1 if p+1 < self.num_stack else 0 for p in self.p]

        return (self.transform_observation(), *rlt[1:])

    def reset(self, **kwargs):
        """Reset the wrapped env and fill every buffer with its first frame.

        :param kwargs: extra keywords forwarded to ``env.reset()``
        :return: ``(stacked_observation, info)``
        """
        obs, info = self.env.reset(**kwargs)
        for i in range(self.env.vec_nums):
            for j in range(self.num_stack):
                self.frames[i][j] = obs[i]
        self.empty = False
        self.p = [0 for _ in range(self.env.vec_nums)]

        return self.transform_observation(), info


class PreprocessObservation(gym.Wrapper):
    """Resize observations to 84x84 and convert them to a single channel.

    The advertised observation space is a float32 ``Box`` in ``[0, 1]`` whose
    shape is the wrapped shape without its last three axes, plus ``(84, 84)``.
    NOTE(review): nothing in this class rescales pixel values into ``[0, 1]``
    -- confirm that the wrapped env already delivers that range.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

        leading_shape = self.env.observation_space.shape[:-3]
        self.observation_space = Box(
            low=0,
            high=1,
            shape=leading_shape + (84, 84),
            dtype=np.float32,
        )

        # Operates directly on tensors (no PIL round trip); the final Lambda
        # drops the leading channel axis.
        pipeline = [
            T.Resize((84, 84)),
            T.Grayscale(),
            T.Lambda(lambda x: x.squeeze(0)),
        ]
        self.transform = T.Compose(pipeline)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(transform_observation(obs), info)``."""
        obs, info = self.env.reset(**kwargs)
        processed = self.transform_observation(obs)
        return processed, info

    def step(self, action):
        """Step the wrapped env, replacing its first result element with
        ``transform_observation`` of it; the other elements pass through."""
        first, *rest = self.env.step(action)
        return (self.transform_observation(first), *rest)

    def transform_observation(self, observation):
        """Convert one observation to a float32 array matching the space shape.

        The input is expected to have exactly three axes; they are reversed
        (``permute(2, 1, 0)``) before the torchvision pipeline runs.
        NOTE(review): for a height-width-channel image this also swaps the
        two spatial axes -- confirm that this is intended.
        """
        as_tensor = torch.as_tensor(observation, dtype=torch.float32)
        processed = self.transform(as_tensor.permute(2, 1, 0))
        assert processed.shape == self.observation_space.shape
        return processed.numpy()


def touch(path):
    """Make sure ``path`` exists as a file without altering existing content.

    The file is opened in append mode and closed again immediately.
    """
    with open(path, "a"):
        pass


class VideoRecorder:
    """VideoRecorder renders a nice movie of a rollout, frame by frame. It
    comes with an `enabled` option so you can still use the same code
    on episodes where you don't want to record video.

    Note:
        You are responsible for calling `close` on a created
        VideoRecorder, or else you may leak an encoder process.

    Args:
        env (Env): Environment to take video of.
        path (Optional[str]): Path to the video file; will be randomly chosen if omitted.
        base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
        metadata (Optional[dict]): Contents to save to the metadata file.
        enabled (bool): Whether to actually record video, or just no-op (for convenience)

    Raises:
        error.Error: If both `path` and `base_path` are given, or if the
            chosen path does not carry the required file extension.
    """

    # Class-level defaults keep `functional`, `close()` and `__del__` safe on
    # an instance whose `__init__` returned early or raised part-way through.
    enabled = False
    broken = False
    encoder = None
    _closed = False
    _initialized = False

    def __init__(self, env, path=None, metadata=None, enabled=True, base_path=None):
        # NOTE: hard-coded, so the "ansi" fallback below is never selected.
        modes = ["human", "rgb_array"]

        self._async = env.metadata.get("semantics.async")
        self.enabled = enabled
        self._closed = False

        # Don't bother setting anything else if not enabled
        if not self.enabled:
            return

        self.ansi_mode = False
        if "rgb_array" not in modes:
            if "ansi" in modes:
                self.ansi_mode = True
            else:
                logger.info(
                    f'Disabling video recorder because {env} neither supports video mode "rgb_array" nor "ansi".'
                )
                # Whoops, turns out we shouldn't be enabled after all
                self.enabled = False
                return

        if path is not None and base_path is not None:
            raise error.Error("You can pass at most one of `path` or `base_path`.")

        self.last_frame = None
        self.env = env

        required_ext = ".json" if self.ansi_mode else ".mp4"
        if path is None:
            if base_path is not None:
                # Base path given, append ext
                path = base_path + required_ext
            else:
                # Otherwise, just generate a unique filename
                with tempfile.NamedTemporaryFile(
                    suffix=required_ext, delete=False
                ) as f:
                    path = f.name
        self.path = path

        path_base, actual_ext = os.path.splitext(self.path)

        if actual_ext != required_ext:
            hint = (
                " HINT: The environment is text-only, therefore we're recording its text output in a structured JSON format."
                if self.ansi_mode
                else ""
            )
            raise error.Error(
                f"Invalid path given: {self.path} -- must have file extension {required_ext}.{hint}"
            )
        # Touch the file in any case, so we know it's present. (This
        # corrects for platform differences. Using ffmpeg on
        # OS X, the file is precreated, but not on Linux.)
        touch(path)

        self.frames_per_sec = env.metadata.get("render_fps", 30)
        self.output_frames_per_sec = env.metadata.get("render_fps", self.frames_per_sec)

        # backward-compatibility mode: the deprecated keys only win when they
        # are actually present. Their defaults are the "render_fps" derived
        # values, so an env that only sets "render_fps" is neither overridden
        # nor warned about.
        self.backward_compatible_frames_per_sec = env.metadata.get(
            "video.frames_per_second", self.frames_per_sec
        )
        self.backward_compatible_output_frames_per_sec = env.metadata.get(
            "video.output_frames_per_second", self.output_frames_per_sec
        )
        if self.frames_per_sec != self.backward_compatible_frames_per_sec:
            logger.deprecation(
                '`env.metadata["video.frames_per_second"] is marked as deprecated and will be replaced with `env.metadata["render_fps"]` '
                "see https://github.com/openai/gym/pull/2654 for more details"
            )
            self.frames_per_sec = self.backward_compatible_frames_per_sec
        if self.output_frames_per_sec != self.backward_compatible_output_frames_per_sec:
            logger.deprecation(
                '`env.metadata["video.output_frames_per_second"] is marked as deprecated and will be replaced with `env.metadata["render_fps"]` '
                "see https://github.com/openai/gym/pull/2654 for more details"
            )
            self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec

        self.encoder = None  # lazily start the process
        self.broken = False

        # Dump metadata
        self.metadata = metadata or {}
        self.metadata["content_type"] = (
            "video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
        )
        self.metadata_path = f"{path_base}.meta.json"
        # From here on close() has every attribute it needs.
        self._initialized = True
        self.write_metadata()

        logger.info("Starting new video recorder writing to %s", self.path)
        self.empty = True

    @property
    def functional(self):
        """True while the recorder is enabled and not marked as broken."""
        return self.enabled and not self.broken

    def capture_frame(self):
        """Render the given `env` and add the resulting frame to the video."""
        if not self.functional:
            return
        if self._closed:
            logger.warn(
                "The video recorder has been closed and no frames will be captured anymore."
            )
            return
        logger.debug("Capturing video frame: path=%s", self.path)

        render_mode = "ansi" if self.ansi_mode else "rgb_array"
        frame = self.env.render(mode=render_mode)

        if frame is None:
            if self._async:
                return
            else:
                # Indicates a bug in the environment: don't want to raise
                # an error here.
                logger.warn(
                    "Env returned None on render(). Disabling further rendering for video recorder by marking as disabled: path=%s metadata_path=%s",
                    self.path,
                    self.metadata_path,
                )
                self.broken = True
        else:
            self.last_frame = frame
            if self.ansi_mode:
                self._encode_ansi_frame(frame)
            else:
                self._encode_image_frame(frame)

    def close(self):
        """Flush all data to disk and close any open frame encoders.

        A no-op when the recorder is disabled, already closed, or was never
        fully constructed.
        """
        if not self.enabled or self._closed or not self._initialized:
            return

        if self.encoder:
            logger.debug("Closing video encoder: path=%s", self.path)
            self.encoder.close()
            self.encoder = None
        else:
            # No frames captured. Set metadata, and remove the empty output file.
            if os.path.exists(self.path):
                os.remove(self.path)

            if self.metadata is None:
                self.metadata = {}
            self.metadata["empty"] = True

        # If broken, get rid of the output file, otherwise we'd leak it.
        if self.broken:
            logger.info(
                "Cleaning up paths for broken video recorder: path=%s metadata_path=%s",
                self.path,
                self.metadata_path,
            )

            # Might have crashed before even starting the output file, don't try to remove in that case.
            if os.path.exists(self.path):
                os.remove(self.path)

            if self.metadata is None:
                self.metadata = {}
            self.metadata["broken"] = True

        self.write_metadata()

        # Stop tracking this for autoclose
        self._closed = True

    def write_metadata(self):
        """Serialise `self.metadata` as JSON to `self.metadata_path`."""
        with open(self.metadata_path, "w") as f:
            json.dump(self.metadata, f)

    def __del__(self):
        # Make sure we've closed up shop when garbage collecting
        self.close()

    def _encode_ansi_frame(self, frame):
        """Append a text frame, creating the `TextEncoder` on first use."""
        if not self.encoder:
            self.encoder = TextEncoder(self.path, self.frames_per_sec)
            self.metadata["encoder_version"] = self.encoder.version_info
        self.encoder.capture_frame(frame)
        self.empty = False

    def _encode_image_frame(self, frame):
        """Append an image frame, creating the `ImageEncoder` on first use.

        An `error.InvalidFrame` raised by the encoder marks the recorder as
        broken instead of propagating.
        """
        if not self.encoder:
            self.encoder = ImageEncoder(
                self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
            )
            self.metadata["encoder_version"] = self.encoder.version_info

        try:
            self.encoder.capture_frame(frame)
        except error.InvalidFrame as e:
            logger.warn("Tried to pass invalid video frame, marking as broken: %s", e)
            self.broken = True
        else:
            self.empty = False


class TextEncoder:
    """Store a moving picture made out of ANSI frames. Format adapted from
    https://github.com/asciinema/asciinema/blob/master/doc/asciicast-v1.md

    Frames are buffered in memory as UTF-8 bytes and written to
    ``output_path`` as one JSON document when :meth:`close` is called.
    """

    def __init__(self, output_path, frames_per_sec):
        self.output_path = output_path
        # Stored but not used for timing: close() uses a fixed frame duration.
        self.frames_per_sec = frames_per_sec
        self.frames = []

    def capture_frame(self, frame):
        """Validate ``frame`` and append it to the buffer.

        Args:
            frame: a ``str`` or ``StringIO`` holding one text frame.

        Raises:
            error.InvalidFrame: if ``frame`` has another type, does not end
                with a newline, or contains a carriage return.
        """
        if isinstance(frame, str):
            string = frame
        elif isinstance(frame, StringIO):
            string = frame.getvalue()
        else:
            raise error.InvalidFrame(
                f"Wrong type {type(frame)} for {frame}: text frame must be a string or StringIO"
            )

        frame_bytes = string.encode("utf-8")

        if frame_bytes[-1:] != b"\n":
            raise error.InvalidFrame(f'Frame must end with a newline: """{string}"""')

        if b"\r" in frame_bytes:
            raise error.InvalidFrame(
                f'Frame contains carriage returns (only newlines are allowed: """{string}"""'
            )

        self.frames.append(frame_bytes)

    def close(self):
        """Write all buffered frames to ``self.output_path`` as JSON.

        Also works when no frame was captured (for example because the first
        frame was rejected): the document then has an empty event list.
        """
        # Fixed duration per frame; ``frames_per_sec`` is deliberately unused.
        frame_duration = 0.5

        # Turn frames into events: clear screen beforehand
        # https://rosettacode.org/wiki/Terminal_control/Clear_the_screen#Python
        # https://rosettacode.org/wiki/Terminal_control/Cursor_positioning#Python
        clear_code = b"%c[2J\033[1;1H" % (27)
        # Decode the bytes as UTF-8 since JSON may only contain UTF-8
        events = [
            (
                frame_duration,
                (clear_code + frame.replace(b"\n", b"\r\n")).decode("utf-8"),
            )
            for frame in self.frames
        ]

        # Calculate frame size from the largest frames.
        # Add some padding since we'll get cut off otherwise.
        # ``default=0`` keeps this from raising when no frame was captured.
        height = max((frame.count(b"\n") for frame in self.frames), default=0) + 1
        width = (
            max(
                (len(line) for frame in self.frames for line in frame.split(b"\n")),
                default=0,
            )
            + 2
        )

        data = {
            "version": 1,
            "width": width,
            "height": height,
            "duration": len(self.frames) * frame_duration,
            "command": "-",
            "title": "gym VideoRecorder episode",
            "env": {},  # could add some env metadata here
            "stdout": events,
        }

        with open(self.output_path, "w") as f:
            json.dump(data, f)

    @property
    def version_info(self):
        """Static description of this encoder, stored in the video metadata."""
        return {"backend": "TextEncoder", "version": 1}


class ImageEncoder:
    """Encode raw RGB(A) frames to a video file through an external encoder.

    Frames are written to the stdin of an ``avconv``/``ffmpeg`` subprocess
    that is started by the constructor.

    Args:
        output_path: path of the video file handed to the encoder.
        frame_shape: ``(height, width, channels)`` with 3 or 4 channels.
        frames_per_sec: frame rate declared for the raw input stream.
        output_frames_per_sec: frame rate requested for the output file.

    Raises:
        error.InvalidFrame: if ``frame_shape`` does not have 3 or 4 channels.
        error.DependencyNotInstalled: if no encoder executable is found.
    """

    def __init__(self, output_path, frame_shape, frames_per_sec, output_frames_per_sec):
        # Imported locally: these replace the `distutils.spawn` and
        # `pkgutil.find_loader` look-ups, which newer Python versions
        # have removed or deprecated.
        import importlib.util
        import shutil

        self.proc = None
        self.output_path = output_path
        # Frame shape should be lines-first, so w and h are swapped
        h, w, pixfmt = frame_shape
        if pixfmt != 3 and pixfmt != 4:
            raise error.InvalidFrame(
                "Your frame has shape {}, but we require (w,h,3) or (w,h,4), i.e., RGB values for a w-by-h image, with an optional alpha channel.".format(
                    frame_shape
                )
            )
        self.wh = (w, h)
        self.includes_alpha = pixfmt == 4
        self.frame_shape = frame_shape
        self.frames_per_sec = frames_per_sec
        self.output_frames_per_sec = output_frames_per_sec

        # Preference order: avconv, ffmpeg on PATH, then the binary bundled
        # with the optional imageio_ffmpeg package.
        if shutil.which("avconv") is not None:
            self.backend = "avconv"
        elif shutil.which("ffmpeg") is not None:
            self.backend = "ffmpeg"
        elif importlib.util.find_spec("imageio_ffmpeg") is not None:
            import imageio_ffmpeg

            self.backend = imageio_ffmpeg.get_ffmpeg_exe()
        else:
            raise error.DependencyNotInstalled(
                """Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`. Alternatively, please install imageio-ffmpeg with `pip install imageio-ffmpeg`"""
            )

        self.start()

    @property
    def version_info(self):
        """Backend name, its ``-version`` output and the command line used.

        Runs ``<backend> -version`` on every access.
        """
        return {
            "backend": self.backend,
            "version": str(
                subprocess.check_output(
                    [self.backend, "-version"], stderr=subprocess.STDOUT
                )
            ),
            "cmdline": self.cmdline,
        }

    def start(self):
        """Build the encoder command line and launch the subprocess."""
        self.cmdline = (
            self.backend,
            "-nostats",
            "-loglevel",
            "error",  # suppress warnings
            "-y",
            # input
            "-f",
            "rawvideo",
            "-s:v",
            "{}x{}".format(*self.wh),
            "-pix_fmt",
            ("rgb32" if self.includes_alpha else "rgb24"),
            "-framerate",
            "%d" % self.frames_per_sec,
            "-i",
            "-",  # this used to be /dev/stdin, which is not Windows-friendly
            # output
            "-vf",
            "scale=trunc(iw/2)*2:trunc(ih/2)*2",
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-r",
            "%d" % self.output_frames_per_sec,
            self.output_path,
        )

        logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
        if hasattr(os, "setsid"):  # setsid not present on Windows
            self.proc = subprocess.Popen(
                self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
            )
        else:
            self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)

    def capture_frame(self, frame):
        """Send one frame to the encoder.

        Args:
            frame: ``uint8`` numpy array with the shape given at construction.

        Raises:
            error.InvalidFrame: on a wrong type, shape or dtype.

        A failing write to the encoder is logged, not raised.
        """
        if not isinstance(frame, (np.ndarray, np.generic)):
            raise error.InvalidFrame(
                f"Wrong type {type(frame)} for {frame} (must be np.ndarray or np.generic)"
            )
        if frame.shape != self.frame_shape:
            raise error.InvalidFrame(
                f"Your frame has shape {frame.shape}, but the VideoRecorder is configured for shape {self.frame_shape}."
            )
        if frame.dtype != np.uint8:
            raise error.InvalidFrame(
                f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
            )

        try:
            self.proc.stdin.write(frame.tobytes())
        except Exception as exc:
            # Best effort: wait for the encoder to exit and report the failure
            # instead of crashing the rollout. stderr is not piped, so
            # communicate() yields None for it; fall back to the exception.
            _, stderr = self.proc.communicate()
            logger.error(
                "VideoRecorder encoder failed: %s",
                stderr if stderr is not None else exc,
            )

    def close(self):
        """Close the encoder's stdin and wait for it to finish.

        A non-zero exit status is logged.
        """
        self.proc.stdin.close()
        ret = self.proc.wait()
        if ret != 0:
            logger.error(f"VideoRecorder encoder exited with status {ret}")


def capped_cubic_video_schedule(episode_id):
    """Default recording schedule.

    Below episode 1000 an episode is selected when its id is a perfect cube
    (0, 1, 8, 27, ...); from 1000 on every 1000th episode is selected.
    """
    if episode_id >= 1000:
        return episode_id % 1000 == 0
    # Round the floating-point cube root and check that it cubes back exactly.
    nearest_root = int(round(episode_id ** (1.0 / 3)))
    return nearest_root ** 3 == episode_id


class RecordVideo(gym.Wrapper):
    """Record videos of rollouts into ``video_folder``.

    A recording starts whenever the active trigger returns True: either
    ``episode_trigger(episode_id)`` or ``step_trigger(step_id)``. With
    ``video_length > 0`` a recording stops after that many frames, otherwise
    it stops when the episode ends.

    :param env: the environment to wrap
    :param video_folder: directory the videos are written to (created if needed)
    :param episode_trigger: decides from the episode id whether to start recording
    :param step_trigger: decides from the step id whether to start recording
    :param video_length: number of frames per video; 0 records whole episodes
    :param name_prefix: prefix of the generated video file names
    """

    def __init__(
        self,
        env,
        video_folder: str,
        episode_trigger: Optional[Callable[[int], bool]] = None,
        step_trigger: Optional[Callable[[int], bool]] = None,
        video_length: int = 0,
        name_prefix: str = "rl-video",
    ):
        super().__init__(env)

        if episode_trigger is None and step_trigger is None:
            episode_trigger = capped_cubic_video_schedule

        trigger_count = sum(x is not None for x in [episode_trigger, step_trigger])
        assert trigger_count == 1, "Must specify exactly one trigger"

        self.episode_trigger = episode_trigger
        self.step_trigger = step_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        if os.path.isdir(self.video_folder):
            logger.warn(
                f"Overwriting existing videos at {self.video_folder} folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
            )
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.episode_id = 0

    def reset(self, **kwargs):
        """Reset the wrapped env and start a recording if the trigger fires."""
        observations = super().reset(**kwargs)
        if not self.recording and self._video_enabled():
            self.start_video_recorder()
        return observations

    def start_video_recorder(self):
        """Close any running recording and start a new one with a first frame."""
        self.close_video_recorder()

        video_name = f"{self.name_prefix}-step-{self.step_id}"
        if self.episode_trigger:
            video_name = f"{self.name_prefix}-episode-{self.episode_id}"

        base_path = os.path.join(self.video_folder, video_name)
        self.video_recorder = VideoRecorder(
            env=self.env,
            base_path=base_path,
            metadata={"step_id": self.step_id, "episode_id": self.episode_id},
        )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self):
        """Evaluate whichever trigger is configured for the current ids."""
        if self.step_trigger:
            return self.step_trigger(self.step_id)
        else:
            return self.episode_trigger(self.episode_id)

    def step(self, action):
        """Step the wrapped env, capturing a frame while recording.

        For vector envs only the first sub-environment's done flag is used
        to count episodes and to end a recording.
        """
        observations, rewards, dones, infos = super().step(action)

        # increment steps and episodes
        self.step_id += 1
        if not self.is_vector_env:
            if dones:
                self.episode_id += 1
        elif dones[0]:
            self.episode_id += 1

        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if self.video_length > 0:
                if self.recorded_frames > self.video_length:
                    self.close_video_recorder()
            else:
                if not self.is_vector_env:
                    if dones:
                        self.close_video_recorder()
                elif dones[0]:
                    self.close_video_recorder()

        elif self._video_enabled():
            self.start_video_recorder()

        return observations, rewards, dones, infos

    def close_video_recorder(self) -> None:
        """Finish the current recording, if any, and reset the frame counter."""
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        self.recorded_frames = 1

    def close(self):
        """Finish any running recording and close the wrapped environment."""
        self.close_video_recorder()
        # Forward the close: without this the wrapped env was never closed
        # when close() was called on the wrapper.
        return super().close()

    def __del__(self):
        # `recording` is missing when __init__ failed early. Check the
        # instance dict directly so the lookup is not forwarded to the
        # wrapped env by gym.Wrapper.__getattr__.
        if "recording" in self.__dict__:
            self.close_video_recorder()


class RescaleAction(gym.ActionWrapper):
    r"""Expose a symmetric ``Box`` action space ``[-scale, scale]``.

    Incoming actions must lie inside that range; they are multiplied by
    ``scale`` and clipped back to ``[-scale, scale]`` before being passed on.

    :param env: environment with a ``Box`` action space
    :param scale: positive bound of the new action space
    """

    def __init__(self, env, scale):
        assert scale > 0
        lower, upper = -scale, scale
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(lower, upper).all(), (lower, upper)

        super().__init__(env)
        self.scale = scale

        shape = env.action_space.shape
        dtype = env.action_space.dtype
        # Per-dimension bounds used by the range checks in `action`.
        self.min_action = np.zeros(shape, dtype=dtype) + lower
        self.max_action = np.zeros(shape, dtype=dtype) + upper
        self.action_space = spaces.Box(
            low=lower,
            high=upper,
            shape=shape,
            dtype=dtype,
        )

    def action(self, action):
        """Check the range of ``action``, then scale and clip it."""
        above_min = np.greater_equal(action, self.min_action)
        assert np.all(above_min), (
            action,
            self.min_action,
        )
        below_max = np.less_equal(action, self.max_action)
        assert np.all(below_max), (action, self.max_action)
        return np.clip(self.scale * action, -self.scale, self.scale)


def flat_to_episode(
    states: List[S],
    actions: List[Action],
    rewards: List[Reward],
    dones: List[bool],
    infos: List[Info],
    has_next_state: bool = False,
) -> List[SingleEpisode]:
    """Split flat, equally long step lists into one ``SingleEpisode`` per episode.

    An episode ends at every index where ``dones`` is truthy, and the last
    entry of ``dones`` must be truthy. Each info dict is copied and gets an
    ``"end"`` key, which must not already be present.

    * ``has_next_state=True``: the terminal state is read from
      ``infos[i]["next_state"]`` and appended with an ``{"end": True}`` info.
    * ``has_next_state=False``: the last recorded state is the terminal one,
      so its action and reward are dropped and its info is marked as the end.

    NOTE(review): only the appended next state is moved to ``DEVICE``; the
    other states are left where ``torch.from_numpy`` puts them -- confirm
    that this is intended.
    """
    assert len(states) == len(actions) == len(rewards) == len(dones) == len(infos)
    assert dones[-1]

    episodes = []
    ep_states, ep_actions, ep_rewards, ep_infos = [], [], [], []

    for state, action, reward, done, raw_info in zip(
        states, actions, rewards, dones, infos
    ):
        ep_states.append(torch.from_numpy(state).type(torch.float32))
        ep_actions.append(torch.from_numpy(action).type(torch.float32))
        ep_rewards.append(reward)

        step_info = dict(raw_info)
        assert "end" not in step_info
        step_info["end"] = False
        ep_infos.append(step_info)

        if not done:
            continue

        if has_next_state:
            terminal = torch.from_numpy(raw_info["next_state"])
            ep_states.append(terminal.type(torch.float32).to(DEVICE))
            ep_infos.append(dict(end=True))
        else:
            assert not ep_infos[-1]["end"]
            ep_actions.pop(-1)
            ep_rewards.pop(-1)
            ep_infos[-1]["end"] = True

        episodes.append(
            SingleEpisode.from_list((ep_states, ep_actions, ep_rewards, ep_infos))
        )
        # Start fresh lists; the finished episode keeps the old ones.
        ep_states, ep_actions, ep_rewards, ep_infos = [], [], [], []

    return episodes


def flat_to_transitions(
    states: List[S],
    actions: List[Action],
    rewards: List[Reward],
    dones: List[bool],
    infos: Optional[List[Info]] = None,
    has_next_state: bool = False,
) -> List[Transition]:
    """Turn flat, equally long step lists into a list of ``Transition`` objects.

    All inputs are converted to float32 tensors on ``DEVICE``.

    * ``has_next_state=True`` (requires ``infos``): every step yields a
      transition whose successor state comes from ``infos[i]["next_state"]``
      and whose ``end`` flag reflects ``dones[i]``.
    * ``has_next_state=False``: the successor is the next entry of ``states``;
      steps flagged done and the very last step yield no transition.
    """
    assert len(states) == len(actions) == len(rewards) == len(dones)
    if infos is not None:
        assert len(infos) == len(states)

    if has_next_state:
        assert infos is not None

    count = len(states)
    transitions: List[Transition] = []

    state_batch = torch.as_tensor(states, dtype=torch.float32, device=DEVICE)
    action_batch = torch.as_tensor(actions, dtype=torch.float32, device=DEVICE)
    done_batch = torch.as_tensor(dones, dtype=torch.float32, device=DEVICE)
    reward_batch = torch.as_tensor(rewards, dtype=torch.float32, device=DEVICE)

    for idx in range(count):
        done_flag = done_batch[idx]

        if has_next_state:
            successor: torch.Tensor = torch.as_tensor(
                infos[idx]["next_state"], dtype=torch.float32, device=DEVICE
            )
            current = NotNoneStep(state_batch[idx], action_batch[idx], reward_batch[idx])
            following = Step(successor, None, None, dict(end=done_flag.item() == 1))
            transitions.append(Transition((current, following)))
        elif idx < count - 1 and not done_flag:
            # Without explicit next states only non-terminal steps that have
            # a following entry can be paired up.
            current = NotNoneStep(state_batch[idx], action_batch[idx], reward_batch[idx])
            following = Step(state_batch[idx + 1], None, None, dict(end=False))
            transitions.append(Transition((current, following)))

    return transitions


class HistoryRecorder(gym.Wrapper):
    """Keep per-sub-environment buffers of observations, actions, rewards and infos.

    Every one of the ``vec_nums`` sub-environments owns four lists of length
    ``max_timesteps`` that are filled by index as the sub-environment advances.
    When a sub-environment reports ``done``, a deep copy of its buffers is
    stored as its last episode and its timestep counter restarts at 0.

    :param env: the environment to wrap; its step outputs are indexed per
        sub-environment (presumably a vectorized env -- confirm with callers)
    :param max_timesteps: length of every per-sub-environment buffer
    :param vec_nums: number of sub-environments to track
    """

    def __init__(self, env: gym.Env, max_timesteps: int, vec_nums: int = 1):
        super().__init__(env)
        self.vec_nums = vec_nums
        self.max_timesteps = max_timesteps
        self._reset_self()

        # Insert a new axis at position 1 and tile the wrapped bounds
        # ``max_timesteps`` times along it.
        tiled_low = np.repeat(
            self.env.observation_space.low[:, np.newaxis, ...], max_timesteps, axis=1
        )
        tiled_high = np.repeat(
            self.env.observation_space.high[:, np.newaxis, ...], max_timesteps, axis=1
        )
        self.observation_space = Box(
            low=tiled_low, high=tiled_high, dtype=self.observation_space.dtype
        )

    def _reset_self(self):
        """Re-create all buffers, counters and statistics from scratch."""
        env_count = self.vec_nums
        horizon = self.max_timesteps

        def blank_buffer():
            # One independent list of ``None`` placeholders per sub-environment.
            return [[None] * horizon for _ in range(env_count)]

        self.episode_info: Dict[str, Any] = {}
        self.observation_episode = blank_buffer()
        self.action_episode = blank_buffer()
        self.reward_episode = blank_buffer()
        self.step_info = blank_buffer()

        # Per sub-environment: (observations, actions, rewards, step infos)
        # of the most recently finished episode, or None if none finished yet.
        self._last_episode: List[Optional[Tuple[List, List, List, List]]] = (
            [None] * env_count
        )

        self.timestep_tracker = [0] * env_count
        self.last_end = [None] * env_count
        self.env_statistics: List[Dict] = [
            {"finished_episode": 0} for _ in range(env_count)
        ]

    def last_episode_reward(self, idx: int):
        """Return the rewards of sub-environment ``idx``'s last finished episode."""
        recorded_rewards = self._last_episode[idx][2]
        return list(recorded_rewards[: self.last_end[idx]])

    def last_observation(self, idx: int):
        """Return the observation stored at the end index of the last episode."""
        recorded_observations = self._last_episode[idx][0]
        return recorded_observations[self.last_end[idx]]

    @property
    def latest_observation(self):
        """Object array with each sub-environment's observation at its current timestep."""
        current = [
            episode[t]
            for episode, t in zip(self.observation_episode, self.timestep_tracker)
        ]
        return np.array(current, dtype=object)

    def reset(self, *args, **kwargs) -> np.ndarray:
        """Clear all recorded history, reset the wrapped env and store the first observations.

        If the wrapped ``reset`` returns a tuple, its second element is kept in
        ``episode_info`` together with the call arguments; only the
        observations are returned.
        """
        self._reset_self()

        result = self.env.reset(*args, **kwargs)
        with_info = isinstance(result, tuple)
        first_obs = result[0] if with_info else result

        for i in range(self.vec_nums):
            self.observation_episode[i][0] = first_obs[i]

        if with_info:
            self.episode_info = {
                "args": args,
                "kwargs": kwargs,
                "reset_returned_info": result[1],
            }

        return np.array(first_obs, dtype=object)

    def step(self, action: Tuple[np.ndarray, List[Dict[str, Any]]]):
        """Step the wrapped env with ``action[0]`` and record the outcome.

        :param action: ``(actions, per_env_action_infos)``; only the first
            element is forwarded to the wrapped environment
        :return: the wrapped environment's ``(observation, reward, done, info)``
        """
        assert isinstance(action, tuple)
        assert isinstance(action[1], list)

        act = action[0]

        observation, reward, done, env_info = self.env.step(act)

        # Actions, rewards and infos are stored at the timestep they were taken.
        for i, t in enumerate(self.timestep_tracker):
            self.action_episode[i][t] = act[i]
            self.reward_episode[i][t] = reward[i]
            self.step_info[i][t] = {
                "step_action_info": action[1][i],
                "step_returned_info": env_info[i],
            }

        self.timestep_tracker = [t + 1 for t in self.timestep_tracker]

        # Archive finished episodes first: this may rewind some counters to 0,
        # which decides where the returned observation is stored below.
        self.reset_episodes(done)

        for i, t in enumerate(self.timestep_tracker):
            self.observation_episode[i][t] = observation[i]

        return observation, reward, done, env_info

    def reset_episodes(self, dones: List[bool]):
        """Archive and rewind every sub-environment whose done flag is set."""
        assert len(dones) == self.vec_nums
        for i, finished in enumerate(dones):
            if not finished:
                continue

            end = self.timestep_tracker[i]

            # The episode's true last observation is read from the info of
            # the step that ended it.
            returned_info = self.step_info[i][end - 1]["step_returned_info"]
            self.observation_episode[i][end] = returned_info["final_observation"]

            self._last_episode[i] = (
                deepcopy(self.observation_episode[i]),
                deepcopy(self.action_episode[i]),
                deepcopy(self.reward_episode[i]),
                deepcopy(self.step_info[i]),
            )
            self.last_end[i] = end

            self.timestep_tracker[i] = 0

            self.env_statistics[i]["finished_episode"] += 1


class WithReporter(gym.Wrapper, ReportTrait):
    """Environment wrapper that is also initialized as a ``ReportTrait``.

    :param env: the environment to wrap
    :param reporter: optional reporter, forwarded to ``ReportTrait`` as
        ``with_reporter``
    """

    def __init__(self, env: gym.Env, reporter: Optional[Reporter] = None):
        # The two bases take different constructor arguments, so each one is
        # initialized explicitly instead of through ``super()``.
        gym.Wrapper.__init__(self, env)
        ReportTrait.__init__(self, with_reporter=reporter)


class NormalizeScore(gym.Wrapper):
    """Wrapper that carries a score-normalization function alongside the env.

    The wrapper does not override ``step`` or ``reset``; it only exposes
    ``normalize_score``.

    :param env: the environment to wrap
    :param normalize_fn: callable mapping a raw score to a normalized score
    """

    def __init__(self, env: gym.Env, normalize_fn: Callable[[float], float]):
        super().__init__(env)
        self._normalize_fn = normalize_fn

    def normalize_score(self, raw_score: float) -> float:
        """Return ``raw_score`` passed through the stored normalization function."""
        return self._normalize_fn(raw_score)


class VecEnvLabeler(gym.Wrapper):
    """Tag an environment as vectorized by setting ``vec_nums`` and ``is_vector``.

    :param env: the environment to wrap
    :param nums: value stored as ``vec_nums`` (presumably the number of
        sub-environments -- confirm with callers)
    """

    def __init__(self, env: gym.Env, nums: int):
        super().__init__(env)
        self.vec_nums = nums
        self.is_vector = True


class ObsToTensor(gym.Wrapper):
    """Convert NumPy observations of the wrapped env into float32 tensors.

    :param env: the environment to wrap; its observations must be NumPy arrays
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return its observation as a float tensor.

        :return: ``(observation, info)`` when the wrapped ``reset`` returns a
            tuple, otherwise just the observation
        """
        obs = self.env.reset(*args, **kwargs)
        if isinstance(obs, tuple):
            assert isinstance(obs[1], dict)
            return torch.from_numpy(obs[0]).float(), obs[1]

        # ``torch.from_numpy`` keeps the array's dtype; cast to float here as
        # well so this path matches the tuple path above and ``step``.
        return torch.from_numpy(obs).float()

    def step(self, *args, **kwargs):
        """Step the wrapped env and replace the observation with a float tensor.

        :return: the wrapped step result with its first element converted
        """
        result = self.env.step(*args, **kwargs)
        # Unpacking accepts any sequence result, not only a tuple.
        return (torch.from_numpy(result[0]).float(), *result[1:])


class ActUnwrapper(gym.Wrapper):
    """Unwrap single-element action arrays before stepping the wrapped env.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, act):
        """Step with ``act[0]`` when ``act`` has shape ``(1,)``, else with ``act``."""
        if act.shape == (1,):
            return self.env.step(act[0])
        return self.env.step(act)


class ActConverter(gym.Wrapper):
    """Convert the tensor part of an action into a NumPy array before stepping.

    :param env: the environment to wrap; it receives ``(ndarray, infos)``
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, action: Tuple[torch.Tensor, Dict]):
        """Step the wrapped env with the tensor action converted to NumPy.

        :param action: ``(tensor, infos)``; the second element is asserted to
            be a list and is forwarded unchanged
        :return: the wrapped environment's ``(obs, reward, done, info)``
        """
        assert isinstance(action, tuple)
        assert torch.is_tensor(action[0])
        assert isinstance(action[1], list)

        tensor_action, action_infos = action

        # Detach from the graph and move to host memory before converting.
        numpy_action = tensor_action.detach().cpu().numpy()

        obs, rwd, done, info = self.env.step((numpy_action, action_infos))
        return obs, rwd, done, info


class TerminalOrTruncate(gym.Wrapper):
    """Collapse a five-item step result into four by merging the two end flags.

    :param env: the environment to wrap; its ``step`` must return
        ``(obs, reward, termination, truncation, info)``
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, *args, **kwargs):
        """Step the wrapped env and return ``(obs, reward, done, info)``.

        ``done`` is ``termination or truncation`` when both are Python bools,
        otherwise their element-wise logical OR.
        """
        obs, rwd, termination, truncation, info = self.env.step(*args, **kwargs)
        assert isinstance(termination, (bool, np.ndarray))
        assert isinstance(truncation, (bool, np.ndarray))

        both_plain_bools = isinstance(termination, bool) and isinstance(
            truncation, bool
        )
        if both_plain_bools:
            done = termination or truncation
        else:
            done = np.logical_or(termination, truncation)

        return obs, rwd, done, info


class AddTruncate(gym.Wrapper):
    """Expand a four-item step result to five items by duplicating ``done``.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

    def step(self, *args, **kwargs):
        """Step the wrapped env; pass through any result that is not four items long.

        A four-item ``(obs, reward, done, info)`` result is returned as
        ``(obs, reward, done, done, info)``.
        """
        outcome = self.env.step(*args, **kwargs)
        if len(outcome) != 4:
            return outcome

        obs, rwd, done, info = outcome
        # The single done flag fills both end-of-episode positions.
        return obs, rwd, done, done, info


class RenderCollectionVec(gym.Wrapper):
    """Save collection of render frames."""

    def __init__(
        self,
        env: gym.Env,
        periods: int,
        pop_frames: bool = True,
        reset_clean: bool = True,
    ):
        """Initialize a :class:`RenderCollectionVec` instance.

        Args:
            env: The environment that is being wrapped. It must expose a truthy
            ``is_vector`` and a ``render_mode`` that is set and does not
            already end with ``"_list"``.
            periods (int): Stored as ``_period``; not used by any method of
            this class.
            pop_frames (bool): If true, clear the collection frames after .render() is called.
            Default value is True.
            reset_clean (bool): If true, clear the collection frames when .reset() is called.
            Default value is True.
        """
        super().__init__(env)
        assert env.is_vector
        assert env.render_mode is not None
        assert not env.render_mode.endswith("_list")
        # ``step``, ``reset`` and ``render`` all treat ``frame_list`` as one
        # flat list of rendered frames, so it starts out empty. Seeding it
        # with one empty list per sub-environment would leak those
        # placeholders into the collected frames.
        self.frame_list = []
        self.reset_clean = reset_clean
        self._period = periods
        self.pop_frames = pop_frames

    @property
    def render_mode(self):
        """Returns the collection render_mode name."""
        return f"{self.env.render_mode}_list"

    def step(self, *args, **kwargs):
        """Perform a step in the base environment and collect a frame."""
        output = self.env.step(*args, **kwargs)
        self.frame_list.append(self.env.render())
        return output

    def reset(self, *args, **kwargs):
        """Reset the base environment, eventually clear the frame_list, and collect a frame."""
        result = self.env.reset(*args, **kwargs)

        if self.reset_clean:
            self.frame_list = []
        self.frame_list.append(self.env.render())

        return result

    def render(self):
        """Returns the collection of frames and, if pop_frames = True, clears it."""
        frames = self.frame_list
        if self.pop_frames:
            # Rebind instead of clearing in place so the returned list keeps
            # its contents.
            self.frame_list = []

        return frames


class StateStackVec(gym.Wrapper):
    """Stack the most recent ``num_stack`` observations of a vectorized env.

    Observations are kept in a ring buffer along axis 1 of ``stack_memory``
    (axis 0 indexes the sub-environments). ``get_observation`` returns the
    buffer rotated so that frames run from oldest to newest along axis 1.

    :param env: the environment to wrap; must expose a truthy ``is_vector``
        and ``vec_nums``
    :param num_stack: number of observations kept per sub-environment
    """

    def __init__(self, env, num_stack):
        super().__init__(env)
        self.num_stack = num_stack
        assert env.is_vector

        # Placeholder until ``reset`` replaces it with real observations.
        self.stack_memory = np.zeros((env.vec_nums, self.num_stack))
        # ``start`` is the oldest slot and ``end`` the next slot to write.
        # Both advance together, so they are always equal.
        self.start = 0
        self.empty = True
        self.end = 0

        # Insert the stack axis at position 1 of the wrapped bounds.
        low = np.repeat(
            np.expand_dims(self.observation_space.low, axis=1), self.num_stack, axis=1
        )
        high = np.repeat(
            np.expand_dims(self.observation_space.high, axis=1), self.num_stack, axis=1
        )
        self.observation_space = Box(
            low=low, high=high, dtype=self.observation_space.dtype
        )

    def get_observation(self):
        """Return the stacked frames ordered oldest to newest along axis 1."""
        # Rotate along the stack axis (1). Rotating along axis 0 would
        # permute the sub-environments instead of the frames.
        return np.concatenate(
            (
                self.stack_memory[:, self.start :],
                self.stack_memory[:, : self.start],
            ),
            axis=1,
        )

    def step(self, action):
        """Step the wrapped env, push the new observation and return the stack.

        :return: the wrapped step result with its observation replaced by the
            stacked observation
        """
        rlt = self.env.step(action)
        # Overwrite the oldest slot with the newest observation.
        self.stack_memory[:, self.end] = rlt[0]
        self.end += 1
        self.start += 1
        if self.end >= self.num_stack:
            self.end = 0
        if self.start >= self.num_stack:
            self.start = 0
        return (self.get_observation(), *rlt[1:])

    def reset(self, **kwargs):
        """Reset the wrapped env and fill every stack slot with the first observation.

        :return: ``(stacked_observation, info)``
        """
        obs, info = self.env.reset(**kwargs)
        # Every slot holds the same frame, so the ring position is irrelevant
        # for the observation returned here.
        self.stack_memory = np.repeat(obs[:, np.newaxis, ...], self.num_stack, axis=1)
        self.empty = False

        return self.get_observation(), info

class NoopResetEnv(gym.Wrapper):
    """
    Sample initial states by taking random number of no-ops on reset.
    No-op is assumed to be action 0.
    :param env: Environment to wrap
    :param noop_max: Maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Reset the environment, then run between 1 and ``noop_max`` no-op steps.

        ``info['noops_num']`` is set to the zero-based index of each no-op
        step as it is taken. If a no-op step ends the episode, the
        environment is reset again and ``info`` is replaced by the new
        reset's info.

        :return: ``(observation, info)`` after the no-op steps
        """
        obs, info = self.env.reset(**kwargs)
        # Upper bound is exclusive, hence ``noop_max + 1``.
        noops = np.random.randint(1, self.noop_max + 1)
        assert noops > 0

        for ni in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            info['noops_num'] = ni
            if done:
                obs, info = self.env.reset(**kwargs)
        return obs, info

