import os
from typing import Callable

import gym
from gym.wrappers.monitoring import video_recorder


class EpisodeRecorder(gym.Wrapper):
    """Wrap an environment and record selected episodes to video files.

    On every ``reset`` the wrapper calls ``record_video_trigger`` with the
    current episode index. When the trigger returns a true value, a new
    ``VideoRecorder`` is started and one frame is captured after the reset
    and after each subsequent ``step``. A recording ends when ``step``
    reports ``done`` or once more than ``video_length`` frames have been
    captured (the frame taken at reset counts as the first one), whichever
    happens first.

    The episode index starts at 0 and is incremented each time ``step``
    returns ``done``.
    """

    def __init__(
        self,
        env: gym.Env,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        name_prefix: str = "rl-video",
    ):
        """Set up the recorder state and create the output folder.

        Args:
            env: Environment to wrap.
            video_folder: Directory the videos are written to. It is made
                absolute and created if it does not exist.
            record_video_trigger: Called with the episode index at each
                reset; a true return value starts a recording.
            video_length: Frame threshold after which a recording is
                stopped even if the episode has not finished.
            name_prefix: Prefix of the video base name, which has the form
                ``{name_prefix}-ep-{episode_id}``.
        """
        super().__init__(env)

        self.env = env

        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.episode_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0

        # Set last on purpose: an instance whose construction failed part-way
        # never gets this flag, so ``close()`` / ``__del__`` leave it (and the
        # caller's environment) alone.
        self._closed = False

    def reset(self, **kwargs):
        """Reset the wrapped environment and start recording if triggered.

        Args:
            **kwargs: Forwarded unchanged to the wrapped ``env.reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        obs = self.env.reset(**kwargs)
        if self._video_enabled():
            self.start_video_recorder()
        return obs

    def start_video_recorder(self) -> None:
        """Close any running recording and start one for the current episode."""
        self.close_video_recorder()

        video_name = f"{self.name_prefix}-ep-{self.episode_id}"
        base_path = os.path.join(self.video_folder, video_name)
        self.video_recorder = video_recorder.VideoRecorder(
            env=self.env, base_path=base_path, metadata={"episode_id": self.episode_id}
        )

        # Capture the frame that follows the reset as the first video frame.
        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self) -> bool:
        """Return the trigger's decision for the current episode index."""
        return self.record_video_trigger(self.episode_id)

    def step(self, action):
        """Step the wrapped environment, capturing a frame while recording.

        Args:
            action: Action forwarded to the wrapped ``env.step``.

        Returns:
            The ``(obs, rew, done, info)`` tuple from the wrapped
            environment, unchanged.
        """
        obs, rew, done, info = self.env.step(action)

        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            # Stop at episode end or once the frame threshold is exceeded.
            if done or self.recorded_frames > self.video_length:
                print(f"Saving video to {self.video_recorder.path}")
                self.close_video_recorder()

        if done:
            self.episode_id += 1

        return obs, rew, done, info

    def close_video_recorder(self) -> None:
        """Finalize the active recording, if any, and reset the frame count."""
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        # Same idle value as in ``__init__``; ``start_video_recorder`` sets
        # the count to 1 again before it is next read.
        self.recorded_frames = 0

    def close(self) -> None:
        """Finalize any active recording and close the wrapped environment.

        Safe to call more than once: only the first call on a fully
        constructed instance does any work.
        """
        # A missing flag means ``__init__`` did not complete; treat that the
        # same as "already closed".
        if getattr(self, "_closed", True):
            return
        self._closed = True
        try:
            # Finish the video first so it is finalized even if closing the
            # environment raises.
            self.close_video_recorder()
        finally:
            super().close()

    def __del__(self):
        """Release resources on garbage collection if not already closed."""
        self.close()
