import atexit
import functools
import os
import warnings

import cv2
import numpy as np

# Registry of open cv2.VideoWriter objects keyed by the observation attribute
# name of the camera (a name ending in "_rgb"). Populated by
# init_video_writers and reset to an empty dict by release_video_writers.
video_writers = {}

def init_video_writers(obs, base_path="INPUT_YOUR_PATH", fps=15):
    """Create one MJPG ``.avi`` writer per ``*_rgb`` camera found on ``obs``.

    The call is a no-op once ``video_writers`` already holds a writer, so
    ``base_path`` and ``fps`` only take effect on the first call that
    registers something.

    Args:
        obs: Observation object. Every attribute whose name ends in ``_rgb``
            and whose value is a 3-D NumPy array (presumably H x W x 3 RGB)
            gets a writer; other values, such as ``None`` for a disabled
            camera, are skipped instead of crashing.
        base_path: Output directory; created if it does not exist.
        fps: Frame rate passed to each ``cv2.VideoWriter``.
    """
    global video_writers
    if video_writers:
        return
    os.makedirs(base_path, exist_ok=True)
    # The codec is the same for every camera, so build it once.
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    for cam in dir(obs):
        if not cam.endswith("_rgb"):
            continue
        img = getattr(obs, cam)
        if not isinstance(img, np.ndarray) or img.ndim != 3:
            continue
        # Only the frame size is needed here; a colour conversion does not
        # change height/width, so read them straight from the array.
        height, width = img.shape[:2]
        video_file = os.path.join(base_path, f"{cam}.avi")
        vw = cv2.VideoWriter(video_file, fourcc, fps, (width, height))
        if not vw.isOpened():
            # An unopened writer silently drops every frame; surface that.
            warnings.warn(
                f"cv2.VideoWriter could not open {video_file!r}; "
                f"frames for {cam!r} will not be recorded",
                RuntimeWarning,
                stacklevel=2,
            )
        video_writers[cam] = vw

def write_cam_frames(obs):
    """Append the current frame of every ``*_rgb`` camera on ``obs`` to its video.

    Writers are created lazily through ``init_video_writers`` on the first
    call. Cameras without a registered writer, and attributes that are not
    NumPy arrays (e.g. ``None``), are skipped.

    Args:
        obs: Observation object exposing camera frames as attributes whose
            names end in ``_rgb``. Frames that are not ``uint8`` are scaled
            by 255 (presumably they are normalised to [0, 1] -- confirm
            against the producer) and clipped to the valid byte range.
    """
    init_video_writers(obs)
    for cam in dir(obs):
        if not cam.endswith("_rgb"):
            continue
        writer = video_writers.get(cam)
        img = getattr(obs, cam)
        if writer is None or not isinstance(img, np.ndarray):
            continue
        if img.dtype != np.uint8:
            # Clip before the cast so out-of-range values saturate instead
            # of wrapping around when converted to uint8.
            img = np.clip(img * 255, 0, 255).astype(np.uint8)
        # OpenCV expects BGR channel order.
        writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

def release_video_writers():
    """Release every registered writer, then rebind the registry to a new empty dict."""
    global video_writers
    for cam in video_writers:
        video_writers[cam].release()
    video_writers = dict()

# Make sure release_video_writers runs at normal interpreter exit, so the
# writers get released even if the caller never does it explicitly.
atexit.register(release_video_writers)

def recording_step(original_step):
    """Wrap a step callable so each observation it returns is recorded.

    Args:
        original_step: Callable taking an action (plus any extra arguments)
            and returning an ``(obs, reward, done)`` triple.

    Returns:
        A callable with the same name/docstring as ``original_step`` that
        forwards its arguments, passes the resulting ``obs`` to
        ``write_cam_frames`` and returns the ``(obs, reward, done)`` tuple.
    """
    @functools.wraps(original_step)
    def new_step(action, *args, **kwargs):
        obs, reward, done = original_step(action, *args, **kwargs)
        write_cam_frames(obs)
        return obs, reward, done
    return new_step

def recording_get_observation(original_get_obs):
    """Wrap an observation getter so each observation it returns is recorded.

    Args:
        original_get_obs: Callable returning an observation object.

    Returns:
        A callable with the same name/docstring as ``original_get_obs`` that
        forwards its arguments, passes the resulting observation to
        ``write_cam_frames`` and returns it unchanged.
    """
    @functools.wraps(original_get_obs)
    def new_get_obs(*args, **kwargs):
        obs = original_get_obs(*args, **kwargs)
        write_cam_frames(obs)
        return obs
    return new_get_obs