import math
import os
import time
from pathlib import Path
from typing import Callable

import gym
import imageio


class EpisodeRecorder(gym.Wrapper):
    """Record selected episodes of a gym environment (old ``render(mode=...)`` API) to MP4.

    At every ``reset`` the ``record_video_trigger`` is called with the current
    episode index. When it returns True, a frame is captured with
    ``env.render(mode="rgb_array")`` right after the reset and after every
    step. The frames are written with ``imageio`` to
    ``<video_folder>/<name_prefix>-ep-<episode_id>.mp4`` when the episode
    ends (``done``), when ``video_length`` steps have been captured, on the
    next ``reset``, or on ``close``.

    Optional housekeeping: at reset time, regular files in ``video_folder``
    whose mtime is older than ``remove_older_seconds`` are deleted. This runs
    at most once every ``remove_job_interval_seconds``. The default of
    ``math.inf`` disables it. Note that it deletes *any* regular file in the
    folder, not only videos written by this wrapper.

    Args:
        env: Environment to wrap; must support ``render(mode="rgb_array")``.
        video_folder: Output directory; created if it does not exist.
        record_video_trigger: Called with the episode index at each reset;
            return True to record that episode.
        video_length: Maximum number of steps captured per video. A video
            holds at most ``video_length + 1`` frames (the post-reset frame
            included).
        fps: Frame rate of the written video.
        name_prefix: Prefix of the video file names.
        remove_older_seconds: Age threshold (seconds, by mtime) for the
            housekeeping job.
        remove_job_interval_seconds: Minimum number of seconds between two
            housekeeping runs.
    """

    def __init__(
        self,
        env: gym.Env,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 200,
        fps: int = 4,
        name_prefix: str = "rl-video",
        remove_older_seconds: int = 12 * 60 * 60,
        remove_job_interval_seconds: float = math.inf
    ):

        super().__init__(env)

        self.env = env

        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        # Incremented each time the wrapped env reports ``done``.
        self.episode_id = 0
        self.video_length = video_length
        self.fps = fps

        self.recording = False
        # Number of frames buffered for the current video.
        self.recorded_frames = 0

        self.current_video_frames = None
        self.current_video_path = None
        self.render_mode = 'rgb_array'

        # housekeeping
        self.remove_older_seconds = remove_older_seconds
        self.remove_job_interval_seconds = remove_job_interval_seconds
        # The first housekeeping run happens one full interval after construction.
        self.remove_job_last_time_seconds = time.time()

    def _remove_old_videos(self) -> None:
        """Delete stale regular files from ``video_folder`` (rate-limited)."""
        now = time.time()
        if now - self.remove_job_last_time_seconds < self.remove_job_interval_seconds:
            return

        self.remove_job_last_time_seconds = now
        for path in Path(self.video_folder).glob('*'):
            # Path.unlink cannot remove directories, so leave them alone.
            if not path.is_file():
                continue
            try:
                age_seconds = now - path.stat().st_mtime
            except FileNotFoundError:
                # The file disappeared between glob() and stat(), e.g. it was
                # removed by another recorder sharing this folder.
                continue
            if age_seconds > self.remove_older_seconds:
                path.unlink(missing_ok=True)

    def reset(self, **kwargs):
        """Reset the env and start recording if the trigger selects this episode.

        Keyword arguments are forwarded unchanged to the wrapped ``env.reset``.
        """
        self._remove_old_videos()
        if self.recording:
            # The previous episode ended without ``done`` (e.g. a manual reset).
            # Save it now rather than silently discarding its frames or
            # appending the next episode's frames to it.
            self.finish_recording()
        obs = self.env.reset(**kwargs)
        if self._video_enabled():
            self.start_video_recorder()
        return obs

    def start_video_recorder(self) -> None:
        """Start a new video whose first frame is the current render."""
        video_name = f"{self.name_prefix}-ep-{self.episode_id}.mp4"
        self.current_video_path = os.path.join(self.video_folder, video_name)
        self.current_video_frames = [self.env.render(mode=self.render_mode)]
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self) -> bool:
        """Ask the user-supplied trigger whether the current episode is recorded."""
        return self.record_video_trigger(self.episode_id)

    def step(self, action):
        """Step the env, capture a frame if recording, and track episode ends."""
        obs, rew, done, info = self.env.step(action)

        if self.recording:
            self.current_video_frames.append(self.env.render(mode=self.render_mode))
            self.recorded_frames += 1
            # Stop at the end of the episode or once the length cap is exceeded.
            if done or self.recorded_frames > self.video_length:
                self.finish_recording()

        if done:
            self.episode_id += 1

        return obs, rew, done, info

    def finish_recording(self) -> None:
        """Write buffered frames (if recording) and clear all recording state."""
        try:
            if self.recording:
                print(f"Saving video to {self.current_video_path}")
                # macro_block_size=2: presumably so imageio/ffmpeg only needs
                # even frame sizes instead of resizing to multiples of its
                # default block size (16) — TODO confirm intent.
                imageio.mimsave(self.current_video_path, self.current_video_frames, fps=self.fps, macro_block_size=2)
        finally:
            # Reset even if the write failed, so the wrapper does not stay
            # stuck in the "recording" state and retry on every close/GC.
            self.recording = False
            self.current_video_path = None
            self.current_video_frames = None
            self.recorded_frames = 0

    def close(self) -> None:
        """Save any in-progress video, then close the wrapped env."""
        # Flush first so a failing env.close() cannot lose the recording.
        try:
            self.finish_recording()
        finally:
            super().close()

    def __del__(self):
        # Only flush a pending video here. Calling close() would close the
        # wrapped env a second time after an explicit close(). Read the flag
        # via __dict__ to avoid gym.Wrapper.__getattr__ forwarding (which can
        # recurse) when __init__ did not complete.
        if self.__dict__.get("recording", False):
            self.finish_recording()
