import os
import os.path as osp
import time
from collections import defaultdict

import gym
import numpy as np
import torch
from rlf.exp_mgr.viz_utils import save_mp4
from rlf.il.traj_mgr import TrajSaver
from rlf.policies.base_policy import get_empty_step_info
from rlf.rl import utils
from rlf.rl.envs import get_vec_normalize, make_vec_envs
from tqdm import tqdm

from iq_learn.utils.utils import gen_frame, logging, save_video


def eval_print(
    env_interface,
    args,
    alg_env_settings,
    policy,
    vec_norm,
    total_num_steps,
    mode,
    eval_envs,
    log,
):
    """Evaluate ``policy`` and log the mean of every returned statistic.

    Each entry ``k`` of the evaluation result is logged through
    ``log.log_vals`` under the key ``eval_<mode>_<k>`` at step
    ``total_num_steps``.

    Args:
        env_interface: forwarded to ``evaluate``.
        args: run configuration; ``args.evaluation_mode`` is set to ``True``
            for the duration of the call and reset to ``False`` afterwards.
        alg_env_settings: forwarded to ``evaluate``.
        policy: policy to evaluate.
        vec_norm: observation normaliser forwarded to ``evaluate`` (may be
            ``None``).
        total_num_steps: step count used both as the evaluation ``num_steps``
            and as the logging step.
        mode: label printed and embedded in the logged keys.
        eval_envs: evaluation environments to reuse, or ``None`` to let
            ``evaluate`` create them.
        log: logger exposing ``log_vals``.

    Returns:
        The evaluation environments returned by ``evaluate`` so the caller can
        reuse them for the next evaluation.
    """
    print("Evaluating " + mode)
    args.evaluation_mode = True
    try:
        eval_info, eval_envs = evaluate(
            args,
            alg_env_settings,
            policy,
            vec_norm,
            env_interface,
            total_num_steps,
            mode,
            eval_envs,
            log,
            None,  # no trajectory saver for periodic evaluations
        )

        log.log_vals(
            {"eval_%s_%s" % (mode, k): np.mean(v) for k, v in eval_info.items()},
            total_num_steps,
        )
    finally:
        # Always leave evaluation mode, even if evaluation or logging failed,
        # so a caught error does not leave the shared args in the wrong state.
        args.evaluation_mode = False
    return eval_envs


def train_eval(
    envs,
    alg_env_settings,
    policy,
    args,
    log,
    total_num_steps,
    env_interface,
    train_eval_envs,
):
    """Run a ``"train"``-mode evaluation and return the evaluation envs.

    The normaliser is looked up on the training ``envs`` via
    ``get_vec_normalize`` and passed on to ``eval_print`` together with the
    remaining arguments.

    Returns:
        Whatever ``eval_print`` returns, i.e. the evaluation environments to
        reuse on the next call.
    """
    return eval_print(
        env_interface=env_interface,
        args=args,
        alg_env_settings=alg_env_settings,
        policy=policy,
        vec_norm=get_vec_normalize(envs),
        total_num_steps=total_num_steps,
        mode="train",
        eval_envs=train_eval_envs,
        log=log,
    )


def full_eval(
    envs,
    policy,
    log,
    checkpointer,
    env_interface,
    args,
    alg_env_settings,
    create_traj_saver_fn,
    vec_norm,
    updater=None,
):
    """Run the ``"final"`` evaluation, close the environments and return stats.

    Args:
        envs: evaluation environments to use, or ``None`` to let ``evaluate``
            create them.
        policy: policy to evaluate.
        log: logger forwarded to ``evaluate``.
        checkpointer: unused here; kept for interface compatibility.
        env_interface: forwarded to ``evaluate``.
        args: run configuration; ``args.evaluation_mode`` is ``True`` only
            while the evaluation runs.
        alg_env_settings: forwarded to ``evaluate``.
        create_traj_saver_fn: factory for the trajectory saver, forwarded to
            ``evaluate``.
        vec_norm: observation normaliser forwarded to ``evaluate``.
        updater: optional updater forwarded to ``evaluate``.

    Returns:
        The statistics dictionary produced by ``evaluate``.
    """
    args.evaluation_mode = True
    try:
        ret_info, envs = evaluate(
            args,
            alg_env_settings,
            policy,
            vec_norm,
            env_interface,
            0,  # num_steps: the final evaluation is reported at step 0
            "final",
            envs,
            log,
            create_traj_saver_fn,
            updater,
        )
    finally:
        # Reset the flag even when the evaluation raises.
        args.evaluation_mode = False
    envs.close()

    return ret_info


def _draw_overlay_text(frame, text, org):
    """Draw one line of debug text on ``frame`` at pixel position ``org``.

    OpenCV is imported here, at the point of use, so it is only required when
    an overlay is actually drawn.
    """
    import cv2

    return cv2.putText(
        img=frame,
        text=text,
        org=org,
        fontFace=cv2.FONT_HERSHEY_DUPLEX,
        fontScale=0.5,
        color=(128, 0, 0),
        thickness=1,
        lineType=cv2.LINE_AA,
    )


def _make_in_dist_action_counter(env_name, action_space):
    """Build the counter of in-distribution actions for ``env_name``.

    Returns:
        A callable taking the batch of actions as a NumPy array and returning
        how many of them are considered in distribution.

    Raises:
        NotImplementedError: if no rule is defined for ``env_name``.
    """
    if env_name.startswith("MiniGrid"):
        # For MiniGrid, an action is in distribution if its index is below 4.
        return lambda action: np.sum(action < 4)

    if env_name.startswith(("MBRLmaze2d", "FetchPickAndPlaceDiffHoldoutTS150")):
        bound = 0.1
    elif env_name.startswith("FetchPushEnvCustomTS500"):
        bound = 0.05
    else:
        raise NotImplementedError(
            "Action distribution check not implemented for this environment"
        )

    # The box is loop invariant, so build it once rather than on every step.
    action_constraint = gym.spaces.Box(
        low=-bound, high=bound, shape=action_space.shape
    )

    def count_in_box(action):
        inside = np.all(
            (action >= action_constraint.low) & (action <= action_constraint.high),
            axis=1,
        )
        return inside.sum()

    return count_in_box


def evaluate(
    args,
    alg_env_settings,
    policy,
    true_vec_norm,
    env_interface,
    num_steps,
    mode,
    eval_envs,
    log,
    create_traj_saver_fn,
    updater=None,
):
    """Roll out ``policy`` in the evaluation environments and collect stats.

    Episodes are run until ``num_processes * args.num_eval`` of them have
    finished. Observation normalisation is applied manually through
    ``true_vec_norm._obfilt`` (without updating its statistics), so the
    evaluation environments themselves must not be normalised.

    Depending on flags in ``args`` the function additionally:

    * saves trajectories (``eval_save``),
    * renders frames and writes videos (``num_render``, ``render_succ_fails``),
    * counts in-distribution actions (``add_ooc_action_estimation``),
    * measures steps until the goal is found (``add_reaching_goal_speed``),
    * records WidowX torques to ``torque.pt`` (``widowx_save_torque``),
    * computes proximity-function rewards when ``updater`` is given and
      ``args.alg == 'dpf'``.

    NOTE(review): the dpf, goal-speed and torque branches read ``infos[0]``
    and test ``done`` as a single truth value, so they presumably expect a
    single evaluation environment - confirm against the callers.

    Args:
        args: run configuration.
        alg_env_settings: forwarded to ``make_vec_envs``.
        policy: policy to evaluate; switched to eval mode during the rollout
            and back to train mode before returning.
        true_vec_norm: normaliser providing ``_obfilt``, or ``None`` for no
            normalisation.
        env_interface: environment interface used for env creation/rendering.
        num_steps: training step of this evaluation; offsets the env seed and
            tags saved videos.
        mode: label used in the names of success/failure videos.
        eval_envs: environments to reuse, or ``None`` to create new ones.
        log: logger exposing ``log_video``.
        create_traj_saver_fn: factory taking a directory and returning a
            trajectory saver; required when ``args.eval_save`` is set.
        updater: optional updater whose ``modules[0]`` is the proximity
            function used in the ``dpf`` branch.

    Returns:
        ``(ret_info, eval_envs)``: a dict of averaged episode statistics (plus
        the optional extras above) and the environments that were used.

    Raises:
        ValueError: if ``args.render_succ_fails`` is set with more than one
            evaluation process.
        NotImplementedError: for ``args.action_input`` in the dpf branch, or
            an environment without an in-distribution action rule.
    """
    if args.eval_num_processes is None:
        num_processes = args.num_processes
    else:
        num_processes = args.eval_num_processes

    if eval_envs is None:
        args.force_multi_proc = False
        eval_envs = make_vec_envs(
            args.env_name,
            args.seed + num_steps,
            num_processes,
            args.gamma,
            args.device,
            True,
            env_interface,
            args,
            alg_env_settings,
            set_eval=True,
        )

    assert get_vec_normalize(eval_envs) is None, "Norm is manually applied"

    if true_vec_norm is not None:
        obfilt = true_vec_norm._obfilt
    else:

        def obfilt(x, update):
            return x

    ep_stats = defaultdict(list)

    obs = eval_envs.reset()

    hidden_states = {}
    for k, dim in policy.get_storage_hidden_states().items():
        hidden_states[k] = torch.zeros(num_processes, dim).to(args.device)
    eval_masks = torch.zeros(num_processes, 1, device=args.device)

    frames = []

    policy.eval()
    if args.eval_save and create_traj_saver_fn is not None:
        traj_saver = create_traj_saver_fn(
            osp.join(args.traj_dir, args.env_name, args.prefix)
        )
    else:
        # The message must be one string; a tuple would be shown as a tuple.
        assert not args.eval_save, (
            "Cannot save evaluation without "
            "specifying the eval saver creator function"
        )

    total_num_eval = num_processes * args.num_eval

    # ``num_render`` of None means "render everything"; an infinite limit lets
    # every comparison below work without special-casing None.
    render_limit = float("inf") if args.num_render is None else args.num_render

    # Measure the number of episodes completed
    pbar = tqdm(total=total_num_eval)
    evaluated_episode_count = 0
    n_succs = 0
    n_fails = 0
    succ_frames = []
    fail_frames = []
    # Use the resolved process count: ``args.eval_num_processes`` may be None.
    if args.render_succ_fails and num_processes > 1:
        raise ValueError(
            """
                Can only render successes and failures when the number of
                processes is 1.
                """
        )

    if render_limit > 0:
        # Frame of the initial state, before any action is taken.
        frames.extend(
            get_render_frames(
                eval_envs,
                env_interface,
                None,
                None,
                None,
                None,
                None,
                args,
                evaluated_episode_count,
            )
        )

    use_dpf = updater is not None and args.alg == 'dpf'
    if use_dpf:
        total_reward = 0
        total_reward_list = []
        dis_prox_func = updater.modules[0]
        dis_prox_func.prox_func.eval()

    if args.add_reaching_goal_speed:
        # count the number of steps taken to reach the goal
        step_count = 0
        step_count_list = []

    if args.widowx_save_torque:
        torque_buffer = []
        torque_tmp_buffer = []
        cube_pos_buffer = []
        cube_pos_tmp = None
        ep_found_goal_buffer = []

    track_ooc = args.add_ooc_action_estimation and not args.env_name.startswith(
        "WidowXLiftCube"
    )
    if track_ooc:
        total_action_taken = 0
        in_distribution_action_taken = 0
        # Built on the first step so an unsupported environment still fails
        # at the same point of the rollout as before.
        count_in_dist = None

    def prox_input(raw_obs):
        """Normalise ``raw_obs`` and return a copy of its default observation."""
        filtered = obfilt(utils.ob_to_np(raw_obs), update=False)
        filtered = utils.ob_to_tensor(filtered, args.device)
        return utils.get_def_obs(filtered).clone()

    while evaluated_episode_count < total_num_eval:
        step_info = get_empty_step_info()
        with torch.no_grad():
            act_obs = obfilt(utils.ob_to_np(obs), update=False)
            act_obs = utils.ob_to_tensor(act_obs, args.device)

            ac_info = policy.get_action(
                utils.get_def_obs(act_obs),
                utils.get_other_obs(obs),
                hidden_states,
                eval_masks,
                step_info,
            )

            hidden_states = ac_info.hxs

        if track_ooc:
            action = ac_info.take_action.cpu().numpy()
            if count_in_dist is None:
                count_in_dist = _make_in_dist_action_counter(
                    args.env_name, eval_envs.action_space
                )
            total_action_taken += action.shape[0]
            in_distribution_action_taken += count_in_dist(action)

        if args.widowx_save_torque:
            # get the torque action from widowx env and save it
            torque_action = eval_envs.unwrapped.envs[0].action_transform(
                ac_info.take_action[0]
            )
            torque_tmp_buffer.append(torque_action)
            # also save the cube position, once per episode (first step only)
            if cube_pos_tmp is None:
                cube_pos_tmp = obs[0, 0:3]
                cube_pos_buffer.append(cube_pos_tmp.cpu())

        # Observe reward and next obs
        next_obs, _, done, infos = eval_envs.step(ac_info.take_action)

        if args.widowx_save_torque and done.any():
            torque_buffer.append(torque_tmp_buffer)
            torque_tmp_buffer = []
            cube_pos_tmp = None

            ep_found_goal_buffer.append(infos[0]['ep_found_goal'])
        if args.add_reaching_goal_speed:
            step_count += 1

        # Values shown in the dpf overlay; they stay None/0 when the branch
        # below does not compute them.
        cur_prox = None
        diff_prox_reward = None
        reward = None
        uncert_pen = 0
        if use_dpf:
            cur_state = prox_input(obs)
            if args.action_input:
                raise NotImplementedError("Action input not supported")
            cur_action = None
            next_state = prox_input(next_obs)
            next_action = None

            cur_prox = dis_prox_func._get_prox(cur_state, cur_action, args.pf_clip)
            next_prox = dis_prox_func._get_prox(next_state, next_action, args.pf_clip)
            if args.pf_reward_type == 'nofinal':
                diff_prox_reward = (next_prox - cur_prox)
                final_prox_reward = torch.zeros(diff_prox_reward.shape).to(args.device)

                constant_pen = 0
                if args.pf_uncert and args.pf_n_nets > 1:
                    next_uncert = dis_prox_func._get_prox_uncert(next_state, next_action)
                    uncert_pen = dis_prox_func.args.pf_uncert_scale * next_uncert
                reward = (
                    diff_prox_reward + final_prox_reward - uncert_pen + constant_pen
                ) * args.pf_reward_scale
                # Note: the training-time reward normalisation is skipped here.
                total_reward += reward.item()
                if done:
                    total_reward_list.append(total_reward)
                    total_reward = 0

        if args.add_reaching_goal_speed and done:
            # Only successful episodes contribute to the average.
            if infos[0]['ep_found_goal']:
                step_count_list.append(step_count)
            step_count = 0

        if args.eval_save:
            finished_count = traj_saver.collect(
                obs, next_obs, done, ac_info.take_action, infos
            )
        else:
            finished_count = sum(int(d) for d in done)

        pbar.update(finished_count)
        evaluated_episode_count += finished_count

        # Mask is 0 for environments whose episode just ended, 1 otherwise.
        eval_masks = torch.tensor(
            [[0.0] if done_ else [1.0] for done_ in done],
            dtype=torch.float32,
            device=args.device,
        )

        if args.render_succ_fails:
            should_render = n_succs < render_limit or n_fails < render_limit
        else:
            should_render = evaluated_episode_count < render_limit

        if should_render:
            cur_frame = get_render_frames(
                eval_envs,
                env_interface,
                obs,
                next_obs,
                ac_info.take_action,
                eval_masks,
                infos,
                args,
                evaluated_episode_count,
            )
            if args.add_reaching_goal_speed:
                cur_frame[0] = _draw_overlay_text(
                    cur_frame[0], f"step_count: {step_count}", (10, 100)
                )

            if args.add_widowx_cube_height:
                cur_frame[0] = _draw_overlay_text(
                    cur_frame[0], f"cube height: {next_obs[0][2]}", (10, 100)
                )
            if use_dpf:
                cur_frame[0] = _draw_overlay_text(
                    cur_frame[0],
                    f"prox_value: {cur_prox.item():.10f}",
                    (10, 120),
                )
                # The reward lines only exist for the 'nofinal' reward type.
                if diff_prox_reward is not None:
                    cur_frame[0] = _draw_overlay_text(
                        cur_frame[0],
                        f"diff_prox_reward: {diff_prox_reward.item():.10f}",
                        (10, 140),
                    )
                    # ``uncert_pen`` is the int 0 when uncertainty is off and
                    # a tensor otherwise; float() handles both.
                    cur_frame[0] = _draw_overlay_text(
                        cur_frame[0],
                        f"uncertainty_penalty: {float(uncert_pen):.10f}",
                        (10, 160),
                    )
                    cur_frame[0] = _draw_overlay_text(
                        cur_frame[0],
                        f"reward=diff_prox_reward-uncertainty: {reward.item():.10f}",
                        (10, 180),
                    )
            frames.extend(cur_frame)
        obs = next_obs

        step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)
        for k, v in step_log_vals.items():
            ep_stats[k].extend(v)

        if "ep_success" in step_log_vals and args.render_succ_fails:
            # An episode just ended: file its frames under success or failure.
            is_succ = step_log_vals["ep_success"][0]
            if is_succ == 1.0:
                if n_succs < render_limit:
                    succ_frames.extend(frames)
                n_succs += 1
            else:
                if n_fails < render_limit:
                    fail_frames.extend(frames)
                n_fails += 1
            frames = []

    pbar.close()

    if args.widowx_save_torque:
        # transform each episode's torque list into a tensor
        torque_buffer = [torch.stack(ep_torques) for ep_torques in torque_buffer]

        if args.alg in ['bc', 'dpf-deep']:
            # Filter successful indices
            success_indices = [
                i for i, success in enumerate(ep_found_goal_buffer) if success
            ]
        else:
            success_indices = list(range(len(ep_found_goal_buffer)))

        # Keep only the selected trajectories
        saved_buffer = {
            'torque': [torque_buffer[i] for i in success_indices],
            'cube_pos': [cube_pos_buffer[i] for i in success_indices],
            'ep_found_goal': torch.tensor(
                [ep_found_goal_buffer[i] for i in success_indices]
            ),
        }

        torque_path = osp.join(args.log_dir, args.env_name, args.prefix, "torque.pt")
        # torch.save does not create missing parent directories.
        os.makedirs(osp.dirname(torque_path), exist_ok=True)
        torch.save(saved_buffer, torque_path)
        logging("Torque buffer saved to %s" % torque_path)

    if args.eval_save:
        traj_saver.save()

    ret_info = {}

    if track_ooc:
        ret_info['total_action_taken'] = total_action_taken
        ret_info['in_distribution_action_taken'] = in_distribution_action_taken
        if total_action_taken > 0:
            ret_info['in_dist_action_ratio'] = (
                in_distribution_action_taken / total_action_taken
            )
        else:
            ret_info['in_dist_action_ratio'] = 0.0

    print(" Evaluation using %i episodes:" % len(ep_stats["r"]))
    for k, v in ep_stats.items():
        print(" - %s: %.5f" % (k, np.mean(v)))
        ret_info[k] = np.mean(v)

    if use_dpf:
        # get correlation between success and reward
        ep_success_list = ep_stats["ep_found_goal"]
        corr = np.corrcoef(ep_success_list, total_reward_list)
        ret_info['corr'] = corr[0, 1]
        print(f"Correlation between success and reward: {corr[0, 1]}")

    if args.add_reaching_goal_speed:
        # No successful episode: report NaN explicitly rather than letting
        # np.mean([]) emit a RuntimeWarning.
        avg_steps = np.mean(step_count_list) if step_count_list else np.nan
        print(f"Average number of steps taken to reach the goal: {avg_steps}")
        ret_info['avg_steps_to_goal'] = avg_steps

    if args.render_succ_fails:
        # Render the success and failures to two separate files.
        save_frames(succ_frames, "succ_" + mode, num_steps, args)
        save_frames(fail_frames, "fail_" + mode, num_steps, args)
    elif render_limit > 0:
        save_dir = osp.join(args.vid_dir, args.env_name, args.prefix)
        os.makedirs(save_dir, exist_ok=True)
        save_file = save_video(save_dir, np.array(frames), episode_id=num_steps)
        if save_file is not None:
            log.log_video(save_file, num_steps, args.vid_fps)

    # Switch policy back to train mode
    policy.train()

    return ret_info, eval_envs


def save_frames(frames, mode, num_steps, args):
    """Write ``frames`` as an mp4 under ``vid_dir/env_name/prefix``.

    The file name is ``<tag>_<steps>_<mode>`` where ``<tag>`` is the name of
    the directory containing ``args.load_file`` (omitted when ``load_file`` is
    empty or has no directory component) and ``<steps>`` is ``num_steps``
    formatted by ``utils.human_format_int``.

    Args:
        frames: sequence of rendered frames; nothing is written when empty.
        mode: label appended to the file name.
        num_steps: training step used in the file name.
        args: configuration providing ``vid_dir``, ``load_file``,
            ``env_name``, ``prefix`` and ``vid_fps``.

    Returns:
        The path (directory joined with the save name) handed to ``save_mp4``,
        or ``None`` when there were no frames.
    """
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(args.vid_dir, exist_ok=True)

    if len(frames) == 0:
        return None

    add = ""
    if args.load_file != "":
        path_parts = args.load_file.split("/")
        # A bare file name has no parent directory to use as a tag; indexing
        # [-2] on it would raise IndexError.
        if len(path_parts) >= 2:
            add = path_parts[-2] + "_"

    save_name = "%s%s_%s" % (add, utils.human_format_int(num_steps), mode)
    save_dir = osp.join(args.vid_dir, args.env_name, args.prefix)

    save_mp4(frames, save_dir, save_name, fps=args.vid_fps, no_frame_drop=True)
    return osp.join(save_dir, save_name)


def get_render_frames(
    eval_envs,
    env_interface,
    obs,
    next_obs,
    action,
    masks,
    infos,
    args,
    evaluated_episode_count,
):
    """Render the evaluation environments and return the frames as a list.

    When ``args.render_metric`` is set and ``obs`` is given, the transition
    (``obs``, ``action``, ``next_obs``, ``info``, ``next_mask``) is passed to
    ``eval_envs.render`` as extra keyword arguments, moved to the CPU first.
    A tensor result is converted to a NumPy array, and any non-list result is
    wrapped in a one-element list.

    ``evaluated_episode_count`` is accepted but not used.

    Raises:
        EOFError: re-raised from ``eval_envs.render`` after printing a hint.
    """
    metric_kwargs = {}
    if args.render_metric and obs is not None:
        metric_kwargs = {
            "obs": utils.ob_to_cpu(obs),
            "action": action.cpu(),
            "next_obs": utils.ob_to_cpu(next_obs),
            "info": infos,
            "next_mask": masks.cpu(),
        }

    try:
        rendered = eval_envs.render(
            **env_interface.get_render_args(), **metric_kwargs
        )
    except EOFError:
        print("This problem can likely be fixed by setting --eval-num-processes 1")
        raise

    if isinstance(rendered, torch.Tensor):
        rendered = rendered.cpu().numpy()

    return rendered if isinstance(rendered, list) else [rendered]







def visualize_minigrid_path(
    envs,
    policy,
    log,
    checkpointer,
    env_interface,
    args,
    alg_env_settings,
    create_traj_saver_fn,
    vec_norm,
    updater=None,
):
    """Render the agent's MiniGrid path for one episode and close the envs.

    Delegates to ``evaluate_minigrid_agent_path`` with ``num_steps`` 0 and
    mode ``"final"``; ``args.evaluation_mode`` is ``True`` only while that
    call runs.

    Args:
        envs: evaluation environments to use, or ``None`` to create them.
        policy: policy to roll out.
        log: logger, forwarded.
        checkpointer: unused here; kept for interface compatibility.
        env_interface: forwarded.
        args: run configuration.
        alg_env_settings: forwarded.
        create_traj_saver_fn: forwarded.
        vec_norm: observation normaliser, forwarded.
        updater: optional updater, forwarded.

    Returns:
        The dictionary returned by ``evaluate_minigrid_agent_path``.
    """
    args.evaluation_mode = True
    try:
        ret_info, envs = evaluate_minigrid_agent_path(
            args,
            alg_env_settings,
            policy,
            vec_norm,
            env_interface,
            0,
            "final",
            envs,
            log,
            create_traj_saver_fn,
            updater,
        )
    finally:
        # Reset the flag even when the rollout or rendering raises.
        args.evaluation_mode = False
    envs.close()

    return ret_info


def evaluate_minigrid_agent_path(
    args,
    alg_env_settings,
    policy,
    true_vec_norm,
    env_interface,
    num_steps,
    mode,
    eval_envs,
    log,
    create_traj_saver_fn,
    updater=None,
):
    """Roll out one episode and save an image of the cells the agent visited.

    The policy is run in a single environment until one episode finishes,
    recording ``agent_pos`` after the reset and after every step. The grid is
    then drawn with an agent marker on every recorded cell and a blue overlay
    on a hard-coded expert path, and written to
    ``<vid_dir>/<env_name>/<prefix>/minigrid_traj.png``.

    ``mode``, ``log``, ``create_traj_saver_fn`` and ``updater`` are accepted
    for signature parity with ``evaluate`` but are not used.

    Args:
        args: run configuration.
        alg_env_settings: forwarded to ``make_vec_envs``.
        policy: policy to roll out; put in eval mode and restored to train
            mode before returning.
        true_vec_norm: normaliser providing ``_obfilt``, or ``None``.
        env_interface: environment interface used for env creation.
        num_steps: offsets the environment seed.
        eval_envs: environments to reuse, or ``None`` to create one.

    Returns:
        ``(ret_info, eval_envs)`` where ``ret_info`` is an empty dict.

    Raises:
        ValueError: if ``args.render_succ_fails`` is set with more than one
            evaluation process configured.
    """
    num_envs = 1

    if eval_envs is None:
        args.force_multi_proc = False
        eval_envs = make_vec_envs(
            args.env_name,
            args.seed + num_steps,
            num_envs,
            args.gamma,
            args.device,
            True,
            env_interface,
            args,
            alg_env_settings,
            set_eval=True,
        )

    # Imported here so the module does not require gym_minigrid or PIL unless
    # this visualisation is actually used.
    import math

    from gym_minigrid.minigrid import WorldObj
    from gym_minigrid.rendering import (
        downsample,
        fill_coords,
        highlight_img,
        point_in_rect,
        point_in_triangle,
        rotate_fn,
    )
    from PIL import Image

    assert get_vec_normalize(eval_envs) is None, "Norm is manually applied"

    if true_vec_norm is not None:
        obfilt = true_vec_norm._obfilt
    else:

        def obfilt(x, update):
            return x

    # ``eval_num_processes`` may be None, which cannot be compared with ``>``.
    if (
        args.render_succ_fails
        and args.eval_num_processes is not None
        and args.eval_num_processes > 1
    ):
        raise ValueError(
            """
                Can only render successes and failures when the number of
                processes is 1.
                """
        )

    # Collected as in ``evaluate``; currently not part of the returned info.
    ep_stats = defaultdict(list)

    obs = eval_envs.reset()

    hidden_states = {}
    for k, dim in policy.get_storage_hidden_states().items():
        hidden_states[k] = torch.zeros(num_envs, dim).to(args.device)
    eval_masks = torch.zeros(num_envs, 1, device=args.device)

    policy.eval()

    total_num_eval = 1

    # Measure the number of episodes completed
    pbar = tqdm(total=total_num_eval)
    evaluated_episode_count = 0

    # NOTE(review): the wrapper depth is hard-coded; presumably this reaches
    # the underlying MiniGrid env of the first environment - confirm.
    grid_env = eval_envs.unwrapped.envs[0].env.env.env.env

    agent_position_list = [grid_env.agent_pos]

    # Hard-coded expert route: along row y=1, then down column x=17.
    expert_path = [(i, 1) for i in range(1, 17)]
    expert_path += [(17, i) for i in range(1, 17)]

    while evaluated_episode_count < total_num_eval:
        step_info = get_empty_step_info()
        with torch.no_grad():
            act_obs = obfilt(utils.ob_to_np(obs), update=False)
            act_obs = utils.ob_to_tensor(act_obs, args.device)

            ac_info = policy.get_action(
                utils.get_def_obs(act_obs),
                utils.get_other_obs(obs),
                hidden_states,
                eval_masks,
                step_info,
            )

            hidden_states = ac_info.hxs

        next_obs, _, done, infos = eval_envs.step(ac_info.take_action)
        agent_position_list.append(grid_env.agent_pos)

        finished_count = sum(int(d) for d in done)

        pbar.update(finished_count)
        evaluated_episode_count += finished_count

        eval_masks = torch.tensor(
            [[0.0] if done_ else [1.0] for done_ in done],
            dtype=torch.float32,
            device=args.device,
        )

        obs = next_obs

        step_log_vals = utils.agg_ep_log_stats(infos, ac_info.extra)
        for k, v in step_log_vals.items():
            ep_stats[k].extend(v)

    pbar.close()

    ret_info = {}

    logging("start rendering")

    class ExpertGrid(WorldObj):
        """Overlay object that fills a tile with the expert-path colour."""

        def __init__(self):
            super().__init__('goal', 'grey')

        def can_overlap(self):
            return True

        def render(self, img):
            fill_coords(img, point_in_rect(0, 1, 0, 1), np.array([0, 102, 204]))

    expert_marker = ExpertGrid()

    def render_tile(
        obj,
        agent_dir=None,
        highlight=False,
        tile_size=32,
        subdivs=3,
        agent_color=(255, 0, 0),
        expert_here=False,
    ):
        """Render one tile, optionally with the expert overlay and the agent.

        The result is deliberately not stored in ``Grid.tile_cache``: that
        cache is never read here, and its key does not cover ``expert_here``,
        so writing to it would leave overlay tiles in the shared cache.
        """
        img = np.zeros(
            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
        )

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        if expert_here:
            expert_marker.render(img)

        # Overlay the agent on top
        if agent_dir is not None:
            tri_fn = point_in_triangle(
                (0.12, 0.19),
                (0.87, 0.50),
                (0.12, 0.81),
            )

            # Rotate the agent based on its direction
            tri_fn = rotate_fn(
                tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir
            )
            fill_coords(img, tri_fn, agent_color)

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        return downsample(img, subdivs)

    def to_cell_set(positions):
        """Convert an iterable of 2-element positions into a set of tuples."""
        cells = set()
        for pos in positions or ():
            arr = np.asarray(pos)
            # Anything that is not a 2-vector can never equal an (i, j) cell.
            if arr.shape == (2,):
                cells.add(tuple(arr.tolist()))
        return cells

    def render_grid(
        grid,
        tile_size=64,
        agent_position_list=None,
        agent_dir=0,
        highlight_mask=None,
        expert_path=None,
    ):
        """Render ``grid`` with agent markers and the expert-path overlay.

        :param grid: grid exposing ``width``, ``height`` and ``get(i, j)``
        :param tile_size: tile size in pixels
        :param agent_position_list: cells to draw the agent marker on
        :param agent_dir: direction used for every agent marker
        :param highlight_mask: boolean mask indexed ``[i, j]`` of cells to
            highlight
        :param expert_path: cells to draw the expert overlay on
        """
        if highlight_mask is None:
            highlight_mask = np.zeros(shape=(grid.width, grid.height), dtype=bool)

        # Sets give O(1) membership per cell instead of scanning both lists.
        agent_cells = to_cell_set(agent_position_list)
        expert_cells = to_cell_set(expert_path)

        # Compute the total grid size
        width_px = grid.width * tile_size
        height_px = grid.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for j in range(grid.height):
            for i in range(grid.width):
                tile_img = render_tile(
                    grid.get(i, j),
                    agent_dir=agent_dir if (i, j) in agent_cells else None,
                    highlight=highlight_mask[i, j],
                    tile_size=tile_size,
                    expert_here=(i, j) in expert_cells,
                )

                ymin = j * tile_size
                ymax = (j + 1) * tile_size
                xmin = i * tile_size
                xmax = (i + 1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

        return img

    try:
        frame = render_grid(
            grid_env.grid,
            agent_position_list=agent_position_list,
            expert_path=expert_path,
        )
    except EOFError:
        print("This problem can likely be fixed by setting --eval-num-processes 1")
        raise

    save_dir = osp.join(args.vid_dir, args.env_name, args.prefix)
    os.makedirs(save_dir, exist_ok=True)
    save_image_path = osp.join(save_dir, "minigrid_traj.png")

    Image.fromarray(frame).save(save_image_path)
    logging(f"image saved to {save_image_path}")

    # Switch policy back to train mode
    policy.train()

    return ret_info, eval_envs





