import os
import jax
import numpy as np
from PIL import Image

from src.envs.ogc.ogc import make_level_generator, OGC
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer

def main(num_frames=30, output_dir="saved_frames", seed=0, tile_size=32):
    """Sample levels, render each one, and save the images as PNG files.

    For every frame a level is drawn from the level generator, the OGC
    environment is reset to that level, and the resulting ``maze_map`` is
    rendered with ``OvercookedVisualizer._render_grid`` and written to
    ``<output_dir>/frame_<index>.png``.

    Args:
        num_frames: Number of levels to sample and render.
        output_dir: Directory that receives the PNG files; created if missing.
        seed: Integer seed for the initial JAX PRNG key.
        tile_size: Value forwarded as ``tile_size`` to the renderer.
    """
    # NOTE(review): the (5, 5) arguments are presumably the grid dimensions
    # and must match between the generator and the environment; the meaning
    # of the generator's third argument (0) is not visible here -- confirm.
    sample_level = make_level_generator(5, 5, 0)
    env = OGC(5, 5)

    os.makedirs(output_dir, exist_ok=True)

    rng = jax.random.PRNGKey(seed)
    for i in range(num_frames):
        # JAX keys are single-use: derive separate keys for level sampling
        # and for the environment reset instead of passing one key to both.
        rng, level_rng, reset_rng = jax.random.split(rng, 3)

        level = sample_level(level_rng)
        _obs, state = env.reset_env_to_level(
            reset_rng, level, env.default_params
        )

        # Move the map to host memory as a NumPy array before rendering.
        grid = np.asarray(state.maze_map)
        frame = OvercookedVisualizer._render_grid(
            grid,
            tile_size=tile_size,
            highlight_mask=None,
            agent_dir_idx=state.agent_dir_idx,
            agent_inv=state.agent_inv,
        )

        # Convert to a uint8 image and save it under a zero-padded index.
        img = Image.fromarray(frame.astype(np.uint8))
        img.save(os.path.join(output_dir, f"frame_{i:03d}.png"))


if __name__ == "__main__":
    main()
