import json
import numpy as np
from habitat import Env
from habitat.core.agent import Agent
from habitat.core.dataset import Dataset
from tqdm import trange, tqdm
import os
import re
import torch
import cv2
import imageio
from habitat.utils.visualizations import maps
import random
import time
import logging
import collections
import datetime
import gzip
import random

from navid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from navid.conversation import conv_templates, SeparatorStyle
from navid.model.builder import load_pretrained_model
from navid.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from instruct_manipulate import convert_keywords_to_uppercase, convert_to_mask,malicious_prompt
import torch
import numpy as np
import os
import json
import time
from tqdm import trange
import logging
from navid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from navid.conversation import conv_templates, SeparatorStyle
from navid.model.builder import load_pretrained_model
from navid.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from image_manipulate import defocus_blur, spatter, motion_blur, flare_effect, foreign_object, black_out, add_noise
from instruct_manipulate import convert_keywords_to_uppercase, convert_to_mask,malicious_prompt

# Human-readable name for every ``manipulation_type`` id handled by
# ``NaVid_Agent.act_image_manipulation``. Not referenced in the visible part of
# this file; presumably used elsewhere to label per-corruption results.
# Keep this in sync with the ids handled in act_image_manipulation.
VISUALIZATION_MANIPULATION_NAME_MAP = {
    0: "defocus_blur",
    1: "color_jitter",
    2: "spatter",
    3: "speckle_noise",
    4: "low_lighting",
    5: "narrow_horizontal_fov",
    6: "motion_blur",
    7: "random_affine_transform",
    8: "flare_effect",
    9: "low_lighting_gradient_new",
    10: "foreign_object",
    11: "overexposure",
    12: "blackout",  # implemented by image_manipulate.black_out
    13: "add_noise",  # handled by act_image_manipulation but previously missing here
}

def evaluate_agent(config, split_id, dataset, model_path, result_path) -> None:
    """Run a NaVid agent over every episode of ``dataset`` and log per-episode metrics.

    For each episode the agent acts until Habitat reports the episode is over.
    STOP (action 0) is forced when either condition holds:
    - ``distance_to_goal`` has stayed unchanged for more than
      ``config.EVAL.EARLY_STOP_ROTATION`` consecutive steps;
    - more than ``config.EVAL.EARLY_STOP_STEPS`` steps have been taken.

    The final metrics in ``target_key`` (plus the episode id) are written to
    ``<result_path>/log/stats_<episode_id>.json``. That directory is created
    by ``NaVid_Agent.__init__``.

    The Habitat ``Env`` is always closed on exit, including on error, so the
    simulator is released.

    NOTE(review): ``agent.reset()`` is called without ``current_run=0``, so
    ``NaVid_Agent.reset`` never writes the per-episode GIF on this path.
    Confirm this is intended.
    """
    env = Env(config.TASK_CONFIG, dataset)
    try:
        agent = NaVid_Agent(model_path, result_path)

        num_episodes = len(env.episodes)
        EARLY_STOP_ROTATION = config.EVAL.EARLY_STOP_ROTATION
        EARLY_STOP_STEPS = config.EVAL.EARLY_STOP_STEPS

        target_key = {"distance_to_goal", "success", "spl", "path_length", "oracle_success"}
        log_dir = os.path.join(result_path, "log")

        for _ in trange(num_episodes, desc=config.EVAL.IDENTIFICATION + "-{}".format(split_id)):
            obs = env.reset()
            iter_step = 0
            agent.reset()

            # Number of consecutive steps without any change in distance-to-goal
            # (e.g. the agent keeps rotating in place).
            continuse_rotation_count = 0
            last_dtg = None  # sentinel: the first reading always counts as a change
            while not env.episode_over:
                info = env.get_metrics()

                if info["distance_to_goal"] != last_dtg:
                    last_dtg = info["distance_to_goal"]
                    continuse_rotation_count = 0
                else:
                    continuse_rotation_count += 1

                action = agent.act(obs, info, env.current_episode.episode_id)

                if continuse_rotation_count > EARLY_STOP_ROTATION or iter_step > EARLY_STOP_STEPS:
                    action = {"action": 0}

                iter_step += 1
                obs = env.step(action)

            info = env.get_metrics()
            result_dict = {k: info[k] for k in target_key if k in info}
            result_dict["id"] = env.current_episode.episode_id

            stats_path = os.path.join(log_dir, "stats_{}.json".format(env.current_episode.episode_id))
            with open(stats_path, "w") as f:
                json.dump(result_dict, f, indent=4)
    finally:
        env.close()




class NaVid_Agent(Agent):
    """Habitat navigation agent driven by a NaVid video-language model.

    Every ``act*`` call appends the current RGB frame to ``self.rgb_list``.
    If low-level actions planned earlier are still queued in
    ``self.pending_action_list``, the next one is returned without querying
    the model. Otherwise the model is prompted with the instruction and the
    whole frame history. Its text answer is parsed by ``extract_result`` and
    expanded into at most three discrete actions:
    0 = stop, 1 = move forward, 2 = turn left, 3 = turn right.

    When ``require_map`` is True, each step also renders an RGB + top-down-map
    frame with captions into ``self.topdown_map_list``. ``reset`` can save
    that list as a GIF.
    """

    # Marker strings spliced around each image placeholder in the prompt tokens.
    _VIDEO_START_SPECIAL_TOKEN = "<video_special>"
    _VIDEO_END_SPECIAL_TOKEN = "</video_special>"
    _IMAGE_START_TOKEN = "<image_special>"
    _IMAGE_END_TOKEN = "</image_special>"
    _NAVIGATION_SPECIAL_TOKEN = "[Navigation]"
    _IMAGE_SEPARATOR = "<image_sep>"

    # Maps ``act_instruct``'s ``type`` argument to the variant index passed to
    # ``malicious_prompt``.
    # NOTE(review): 'mal_7' -> 6 and 'mal_8' -> 7 break the 'mal_<k>' -> k
    # pattern of the other entries. They are preserved as-is; confirm against
    # instruct_manipulate.malicious_prompt.
    _MALICIOUS_PROMPT_VARIANTS = {"mal_0": 0, "mal_1": 1, "mal_2": 2, "mal_7": 6, "mal_8": 7}

    def __init__(self, model_path, result_path, require_map=True):
        """Load the NaVid model and create the output directories.

        Args:
            model_path: Checkpoint path passed to ``load_pretrained_model``.
            result_path: Output root. ``log/`` and ``video/`` are created
                inside it.
            require_map: When True, per-step annotated map frames are recorded.
        """
        print("Initialize NaVid")

        self.result_path = result_path
        self.require_map = require_map
        self.conv_mode = "vicuna_v1"
        os.makedirs(self.result_path, exist_ok=True)
        os.makedirs(os.path.join(self.result_path, "log"), exist_ok=True)
        os.makedirs(os.path.join(self.result_path, "video"), exist_ok=True)

        self.model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, None, self.model_name)

        print("Initialization Complete")

        # The instruction text is substituted for '{}' via str.format.
        self.promt_template = (
            "Imagine you are a robot programmed for navigation tasks. "
            "You have been given a video of historical observations and an image of the current observation <image>. "
            "Your assigned task is: '{}'. "
            "Analyze this series of images to decide your next move, which could involve turning left or right by a specific degree or moving forward a certain distance."
        )

        self.history_rgb_tensor = None
        self.rgb_list = []
        self.topdown_map_list = []

        self.count_id = 0
        self.reset()

    # ------------------------------------------------------------------ #
    # Frame encoding
    # ------------------------------------------------------------------ #
    def _encode_new_frames(self, rgb_list):
        """Preprocess frames not yet in ``history_rgb_tensor`` and append them.

        Frames are matched by position. The first ``len(history)`` entries of
        ``rgb_list`` are treated as already encoded, so ``rgb_list`` must only
        grow between calls. ``reset`` clears both together.
        Does nothing when there are no new frames.
        """
        already_encoded = 0 if self.history_rgb_tensor is None else self.history_rgb_tensor.shape[0]
        new_frames = rgb_list[already_encoded:]
        if len(new_frames) == 0:
            return

        video = self.image_processor.preprocess(
            np.asarray(new_frames), return_tensors='pt')['pixel_values'].half().cuda()
        if self.history_rgb_tensor is None:
            self.history_rgb_tensor = video
        else:
            self.history_rgb_tensor = torch.cat((self.history_rgb_tensor, video), dim=0)

    def process_images_2(self, rgb_list):
        """Encode any new frames and return the cumulative history.

        This variant tolerates an empty ``rgb_list``. It prints a warning and
        returns None instead of failing.

        Args:
            rgb_list: All RGB frames observed so far in the episode.

        Returns:
            ``[history_tensor]`` (half precision, on CUDA), or None when no
            frames are available.
        """
        if not rgb_list:
            print("Warning: Empty rgb_list provided to process_images")
            return None

        self._encode_new_frames(rgb_list)
        if self.history_rgb_tensor is None:
            print("Warning: No images to process and no history tensor")
            return None
        return [self.history_rgb_tensor]

    def process_images(self, rgb_list):
        """Encode any new frames and return ``[history_tensor]``.

        If every frame is already encoded, the existing history is returned
        unchanged. Previously, ``preprocess`` was called on an empty batch in
        that case.
        """
        self._encode_new_frames(rgb_list)
        return [self.history_rgb_tensor]

    # ------------------------------------------------------------------ #
    # Model inference
    # ------------------------------------------------------------------ #
    def _image_prefix(self):
        """Return the image placeholder text that is prepended to a prompt."""
        if self.model.config.mm_use_im_start_end:
            return DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n'
        return DEFAULT_IMAGE_TOKEN + '\n'

    @staticmethod
    def _stop_str(conv):
        """Return the separator that ends an assistant turn in ``conv``."""
        return conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    def _marker_ids(self, text):
        """Tokenize a marker string and drop the first id, returning a CUDA tensor.

        The dropped id is presumably the BOS token the tokenizer prepends.
        """
        return self.tokenizer(text, return_tensors="pt").input_ids[0][1:].cuda()

    def _build_input_ids(self, prompt):
        """Tokenize ``prompt`` and splice NaVid's special markers around each image slot.

        Each ``IMAGE_TOKEN_INDEX`` placeholder produced by
        ``tokenizer_image_token`` is replaced by this sequence:
        video_start, image_sep, <placeholder>, video_end, image_start,
        image_end, [Navigation].

        Returns:
            A ``(1, seq_len)`` CUDA id tensor.
        """
        video_start = self._marker_ids(self._VIDEO_START_SPECIAL_TOKEN)
        video_end = self._marker_ids(self._VIDEO_END_SPECIAL_TOKEN)
        image_start = self._marker_ids(self._IMAGE_START_TOKEN)
        image_end = self._marker_ids(self._IMAGE_END_TOKEN)
        navigation = self._marker_ids(self._NAVIGATION_SPECIAL_TOKEN)
        image_sep = self._marker_ids(self._IMAGE_SEPARATOR)

        remaining = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').cuda()
        pieces = []
        # Match on IMAGE_TOKEN_INDEX, the exact value tokenizer_image_token
        # inserted, rather than a hard-coded literal.
        positions = torch.where(remaining == IMAGE_TOKEN_INDEX)[0]
        while positions.numel() > 0:
            idx = positions[0]
            pieces.extend([
                remaining[:idx],
                video_start,
                image_sep,
                remaining[idx:idx + 1],
                video_end,
                image_start,
                image_end,
                navigation,
            ])
            remaining = remaining[idx + 1:]
            positions = torch.where(remaining == IMAGE_TOKEN_INDEX)[0]
        if remaining.numel() > 0:
            pieces.append(remaining)
        return torch.cat(pieces, dim=0).unsqueeze(0)

    def _generate(self, input_ids, imgs, cur_prompt, stopping_criteria):
        """Run sampled generation (temperature 0.2, at most 1024 new tokens)."""
        with torch.inference_mode():
            self.model.update_prompt([[cur_prompt]])
            return self.model.generate(
                input_ids,
                images=imgs,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

    def _decode_output(self, input_ids, output_ids, stop_str):
        """Decode only the newly generated tokens, stripping ``stop_str`` and whitespace."""
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        # Guard the empty separator case: ``s[:-0]`` would wipe the output.
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        return outputs.strip()

    def predict_inference(self, prompt):
        """Generate the model's answer to ``prompt`` over the current frame history.

        The prompt is wrapped in the ``self.conv_mode`` conversation template,
        with the image placeholder prepended.

        Returns:
            The decoded, stripped answer text.
        """
        cur_prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', '')
        qs = self._image_prefix() + prompt.replace('<image>', '')

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)

        input_ids = self._build_input_ids(conv.get_prompt())
        stop_str = self._stop_str(conv)
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        imgs = self.process_images(self.rgb_list)
        # History blackout (``_apply_history_blackout``) is not applied on this
        # path; the original call here was commented out.

        output_ids = self._generate(input_ids, imgs, cur_prompt, stopping_criteria)
        return self._decode_output(input_ids, output_ids, stop_str)

    def _apply_history_blackout(self, imgs):
        """Zero out a random half of the historical frames in ``imgs[0]``.

        All frames except the last are historical. ``num_historical // 2`` of
        them are chosen with ``random.sample`` and zeroed in a copy. The last
        (current) frame is never changed. ``imgs[0]`` is replaced by the new
        tensor; ``self.history_rgb_tensor`` itself is left untouched.

        NOTE(review): frames are zeroed after ``image_processor``
        preprocessing. If that step normalizes pixels, a zero tensor is the
        normalization mean rather than a black image. Confirm this is the
        intended ablation.
        """
        if not imgs or imgs[0] is None:
            return imgs

        video_tensor = imgs[0]
        num_historical = video_tensor.shape[0] - 1
        num_to_replace = num_historical // 2
        if num_to_replace <= 0:
            return imgs

        blackout_indices = random.sample(range(num_historical), num_to_replace)
        historical = video_tensor[:-1].clone()
        historical[blackout_indices] = 0
        imgs[0] = torch.cat((historical, video_tensor[-1:]), dim=0)
        return imgs

    def predict_inference_with_conv(self, conv):
        """Generate the next turn for a prepared conversation, which may include the JUDGE role.

        Args:
            conv: Conversation whose last message is the (empty) turn to fill.

        Returns:
            The decoded, stripped answer text.

        Raises:
            ValueError: when ``vicuna_v2`` mode is used without a third role,
                or no frames are available.
        """
        if self.conv_mode == "vicuna_v2" and len(conv.roles) < 3:
            raise ValueError("vicuna_v2 mode requires USER, ASSISTANT, and JUDGE roles")

        # Unlike predict_inference, the image prefix goes in front of the
        # fully rendered conversation rather than inside the user turn.
        qs = self._image_prefix() + conv.get_prompt().replace('<image>', '')
        input_ids = self._build_input_ids(qs)
        stop_str = self._stop_str(conv)
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        imgs = self.process_images_2(self.rgb_list)
        if not imgs:
            raise ValueError("No images available for processing")

        # NOTE(review): unlike predict_inference, this path randomly zeroes half
        # of the historical frames on every call. Confirm the judge experiment
        # is meant to include this ablation.
        imgs = self._apply_history_blackout(imgs)

        cur_prompt = qs.replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', '')
        try:
            output_ids = self._generate(input_ids, imgs, cur_prompt, stopping_criteria)
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise
        return self._decode_output(input_ids, output_ids, stop_str)

    # ------------------------------------------------------------------ #
    # Answer parsing and action planning
    # ------------------------------------------------------------------ #
    def extract_result(self, output):
        """Parse a model answer into ``(action_id, amount)``.

        Action ids: 0 = stop, 1 = move forward, 2 = turn left, 3 = turn right.
        Keywords are matched case-sensitively, in the priority order
        stop > forward > left > right. For movement actions, the first
        (optionally negative) integer anywhere in the text is returned as a
        float.

        Returns:
            ``(action_id, amount)``. Returns ``(None, None)`` when no keyword
            matches, or when a movement keyword has no number.
        """
        if "stop" in output:
            return 0, None
        for keyword, action_id in (("forward", 1), ("left", 2), ("right", 3)):
            if keyword in output:
                number = re.search(r'-?\d+', output)
                if number is None:
                    return None, None
                return action_id, float(number.group())
        return None, None

    @staticmethod
    def _strip_trailing_period(text):
        """Drop trailing whitespace and periods from a model answer before parsing.

        This replaces the former ``text[:-1]``, which removed the last
        character unconditionally. That corrupted answers with no trailing
        period (e.g. "stop" -> "sto", "forward 75" -> "forward 7").
        """
        return text.rstrip().rstrip('.')

    def _queue_actions(self, action_index, num):
        """Expand a parsed action into pending steps and return the first one.

        Steps are computed as follows:
        - Forward: ``int(num / 25)`` steps.
        - Left/right turns: ``int(num / 30)`` steps.
        Both are capped at 3. The divisors are presumably the simulator's
        forward step (cm) and turn angle (deg); TODO confirm against the task
        config. When nothing can be queued, a random action in {1, 2, 3} is
        used instead.
        """
        if action_index == 0:
            self.pending_action_list.append(0)
        elif action_index == 1:
            self.pending_action_list.extend([1] * min(3, int(num / 25)))
        elif action_index in (2, 3):
            self.pending_action_list.extend([action_index] * min(3, int(num / 30)))

        if action_index is None or not self.pending_action_list:
            self.pending_action_list.append(random.randint(1, 3))

        return {"action": self.pending_action_list.pop(0)}

    # ------------------------------------------------------------------ #
    # Visualization
    # ------------------------------------------------------------------ #
    def addtext(self, image, instuction, navigation):
        """Return ``image`` with a 150-px white caption band appended below it.

        ``instuction`` is word-wrapped to the image width and drawn first.
        ``navigation`` is drawn on a single line underneath. The band height
        is fixed, so very long instructions may overflow it. The band is
        allocated as ``(h + 150, w, 3)`` uint8, so a 3-channel uint8 image is
        expected.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale, thickness, color = 0.5, 2, (0, 0, 0)
        x = 10

        h, w = image.shape[:2]
        canvas = np.full((h + 150, w, 3), 255, np.uint8)
        canvas[:h, :w] = image

        text_height = cv2.getTextSize(instuction, font, font_scale, thickness)[0][1]
        y = h + (50 + text_height) // 2

        line = ""
        for word in instuction.split(' '):
            candidate = line + ' ' + word if line else word
            candidate_width = cv2.getTextSize(candidate, font, font_scale, thickness)[0][0]
            # Wrap only when there is already text on the line. A single
            # over-wide word is kept, instead of emitting an empty line first.
            if candidate_width > w - x and line:
                cv2.putText(canvas, line, (x, y), font, font_scale, color, thickness)
                line = word
                y += text_height + 5
            else:
                line = candidate

        if line:
            cv2.putText(canvas, line, (x, y), font, font_scale, color, thickness)

        y += text_height + 10
        return cv2.putText(canvas, navigation, (x, y), font, font_scale, color, thickness)

    def _record_frame(self, rgb, info):
        """Append ``rgb`` to the history and return the RGB + top-down-map image.

        Returns None when ``require_map`` is False.
        """
        self.rgb_list.append(rgb)
        if not self.require_map:
            return None
        top_down_map = maps.colorize_draw_agent_and_fit_to_height(info["top_down_map_vlnce"], rgb.shape[0])
        return np.concatenate((rgb, top_down_map), axis=1)

    def _log_frame(self, output_im, instruction, caption):
        """Store a captioned visualization frame when ``require_map`` is set."""
        if self.require_map:
            self.topdown_map_list.append(self.addtext(output_im, instruction, caption))

    def _pop_pending_action(self, instruction, output_im):
        """Return the next queued action as ``{"action": id}``, or None if the queue is empty."""
        if not self.pending_action_list:
            return None
        action = self.pending_action_list.pop(0)
        self._log_frame(output_im, instruction, "Pending action: {}".format(action))
        return {"action": action}

    # ------------------------------------------------------------------ #
    # Episode lifecycle and acting
    # ------------------------------------------------------------------ #
    def reset(self, current_run: int = -1):
        """Clear per-episode state, first saving the previous episode's GIF if requested.

        The GIF of ``topdown_map_list`` is written to
        ``<result_path>/video/<episode_id>_<unix_time>.gif``. This happens only
        when all of the following hold:
        - ``current_run == 0``;
        - ``require_map`` is True;
        - at least one frame was recorded.
        """
        if current_run == 0 and self.require_map and self.topdown_map_list:
            timestamp = int(time.time())
            output_video_path = os.path.join(
                self.result_path, "video", "{}_{}.gif".format(self.episode_id, timestamp))
            imageio.mimsave(output_video_path, self.topdown_map_list)

        self.history_rgb_tensor = None
        self.transformation_list = []
        self.rgb_list = []
        self.topdown_map_list = []
        self.last_action = None
        self.count_id += 1
        self.count_stop = 0
        self.pending_action_list = []

        self.first_forward = False

    def _act_with_model(self, observations, info):
        """Shared body of the ``act`` variants.

        Steps:
        1. Record the frame.
        2. If an action is queued, return it.
        3. Otherwise query the model and queue the parsed plan.
        """
        instruction = observations["instruction"]["text"]
        output_im = self._record_frame(observations["rgb"], info)

        pending = self._pop_pending_action(instruction, output_im)
        if pending is not None:
            return pending

        navigation = self.predict_inference(self.promt_template.format(instruction))
        self._log_frame(output_im, instruction, navigation)

        action_index, num = self.extract_result(self._strip_trailing_period(navigation))
        return self._queue_actions(action_index, num)

    def _manipulate_rgb(self, rgb, manipulation_type, intensity):
        """Apply corruption ``manipulation_type`` to ``rgb``; unknown ids return it unchanged.

        NOTE(review): the following functions are not imported at the top of
        this file:
        - color_jitter, speckle_noise, low_lighting, narrow_horizontal_fov,
        - random_affine_transform, low_lighting_gradient_new, overexposure.
        The lambdas resolve names only when called, so ids 1, 3, 4, 5, 7, 9
        and 11 raise NameError unless those functions are defined elsewhere in
        this module.
        """
        corruptions = {
            0: lambda: defocus_blur(rgb, intensity),
            1: lambda: color_jitter(rgb, intensity),
            2: lambda: spatter(rgb, intensity),
            3: lambda: speckle_noise(rgb, intensity),
            4: lambda: low_lighting(rgb, intensity),
            5: lambda: narrow_horizontal_fov(rgb, intensity),
            6: lambda: motion_blur(rgb, intensity),
            7: lambda: random_affine_transform(rgb, intensity),
            8: lambda: flare_effect(rgb, intensity),
            9: lambda: low_lighting_gradient_new(rgb, intensity),
            10: lambda: foreign_object(rgb, radius_height_ratio=0.2),  # ignores intensity
            11: lambda: overexposure(rgb, intensity),
            12: lambda: black_out(rgb, intensity),
            13: lambda: add_noise(rgb, intensity),
        }
        corrupt = corruptions.get(manipulation_type)
        if corrupt is None:
            return rgb
        if manipulation_type == 13:
            print(f'{self.episode_id}')  # debug trace kept from the original code
        return corrupt()

    # RGB corruption
    def act_image_manipulation(self, observations, info, episode_id, manipulation_type, intensity):
        """Corrupt the RGB observation, then act exactly like ``act``.

        ``observations["rgb"]`` is replaced in place by the corrupted image.
        The exception is ``str(episode_id) == '66'``, which is always left
        clean. NOTE(review): the reason for exempting episode 66 is not
        visible here.

        Returns:
            ``{"action": id}``.
        """
        self.episode_id = episode_id
        if str(self.episode_id) != '66':
            observations["rgb"] = self._manipulate_rgb(observations["rgb"], manipulation_type, intensity)
        return self._act_with_model(observations, info)

    def act(self, observations, info, episode_id):
        """Return the next ``{"action": id}`` for the current observation."""
        self.episode_id = episode_id
        return self._act_with_model(observations, info)

    def act_judge(self, observations, info, episode_id, judge_count):
        """Act using an assistant proposal that is reviewed by a JUDGE turn.

        The model first proposes an action as ASSISTANT. It is then asked, as
        JUDGE, to confirm or correct that action. The judge's
        "FINAL ACTION: ..." line is executed. If no such line is found, or the
        placeholder is echoed back, the initial proposal is used.

        Args:
            observations: Environment observations (RGB frame and instruction).
            info: Environment metrics, including the top-down map.
            episode_id: Current episode identifier.
            judge_count: Running count of judge overrides.

        Returns:
            ``({"action": id}, judge_count)``. ``judge_count`` is incremented
            when the judge's action id differs from the proposal's.
        """
        self.episode_id = episode_id
        instruction = observations["instruction"]["text"]
        output_im = self._record_frame(observations["rgb"], info)

        pending = self._pop_pending_action(instruction, output_im)
        if pending is not None:
            return pending, judge_count

        # Switch to the template that defines a JUDGE role. Always restore the
        # previous mode afterwards, so later predict_inference calls are not
        # affected.
        original_conv_mode = self.conv_mode
        self.conv_mode = "vicuna_v2"
        try:
            conv = conv_templates[self.conv_mode].copy()
            conv.append_message(conv.roles[0], self.promt_template.format(instruction))  # USER
            conv.append_message(conv.roles[1], None)  # ASSISTANT
            initial_response = self.predict_inference_with_conv(conv)

            conv.messages[-1][1] = initial_response
            initial_action_index, _ = self.extract_result(initial_response)

            judge_prompt = (
                f"Imagine you are JUDGE of a robot programmed for navigation tasks. "
                f"You have been given a video of historical observations and an image of the current observation <image>. "
                f"Your assigned task is: '{instruction}'. "
                f"Analyze this series of images to decide your next move, which could involve turning left or right by a specific degree or moving forward a certain distance. "
                f"The robot assistant did the action: {initial_response}\n\n"
                "Evaluate the assistant's proposed navigation action. Consider:\n"
                "1. Alignment with the original instruction.\n"
                "2. Reasonableness given the visual context.\n"
                "3. Appropriateness of movement parameters (distance/angle).\n"
                "Either confirm or suggest a correction. Provide only the final action command.\n"
                "Format your response exactly as:\n"
                "FINAL ACTION: [Original action or corrected action]"
            )
            conv.append_message(conv.roles[2], judge_prompt)  # JUDGE
            judge_response = self.predict_inference_with_conv(conv)
        finally:
            self.conv_mode = original_conv_mode

        match = re.search(r"FINAL ACTION\s*:\s*(.+)", judge_response, re.IGNORECASE)
        final_action = match.group(1).strip() if match else initial_response
        # If the model literally repeats the placeholder, fall back to the proposal.
        if final_action == "[Original action or corrected action]":
            final_action = initial_response

        self._log_frame(output_im, instruction, final_action)

        action_index, num = self.extract_result(final_action)
        if action_index != initial_action_index:
            judge_count += 1

        return self._queue_actions(action_index, num), judge_count

    def act_instruct(self, observations, info, episode_id, type='capitalize'):
        """Act like ``act``, after optionally perturbing the instruction text.

        For ``type`` values in ``_MALICIOUS_PROMPT_VARIANTS``,
        ``observations["instruction"]["text"]`` is replaced in place by
        ``malicious_prompt(text, variant)``. Episode '66' is exempt. Any other
        ``type`` leaves the instruction unchanged.

        NOTE(review): the default ``type='capitalize'`` is not handled and is
        a no-op. ``convert_keywords_to_uppercase`` is imported but not called
        in the visible code; confirm whether it should be.

        Returns:
            ``{"action": id}``.
        """
        self.episode_id = episode_id

        variant = self._MALICIOUS_PROMPT_VARIANTS.get(type)
        # Compare as str, as act_image_manipulation does, so int ids behave too.
        if str(self.episode_id) != '66' and variant is not None:
            observations["instruction"]["text"] = malicious_prompt(observations["instruction"]["text"], variant)
        # An earlier variant substituted dataset-provided alternative
        # instructions here; it was disabled and has been removed.

        return self._act_with_model(observations, info)
        
# RGB corruption
def agg_eval_agent(config, split_id, dataset, model_path, result_path, manipulation_type, intensity, num_runs: int = 3) -> None:
    """Run the RGB-corruption evaluation ``num_runs`` times and aggregate metrics per episode.

    Every run creates a fresh ``Env`` and ``NaVid_Agent`` and plays all episodes of
    ``dataset`` with ``agent.act_image_manipulation``. The action is forced to
    ``{"action": 0}`` once distance-to-goal has been unchanged for more than
    ``config.EVAL.EARLY_STOP_ROTATION`` consecutive steps, or once more than
    ``config.EVAL.EARLY_STOP_STEPS`` steps have been taken. Each episode's final
    metrics are written to ``<result_path>/run_<k>/log/stats_<episode_id>.json``.

    After all runs, the per-episode mean/std over runs and the overall mean/std
    of those per-episode means are written to
    ``<result_path>/aggregated/aggregated_metrics_<split_id>_<num_runs>_runs_<suffix>_<timestamp>.json``
    and printed.

    Args:
        config: Evaluation config. Reads ``TASK_CONFIG``, ``EVAL.IDENTIFICATION``,
            ``EVAL.EARLY_STOP_ROTATION`` and ``EVAL.EARLY_STOP_STEPS``.
        split_id: Dataset split identifier (used in progress labels and the file name).
        dataset: Dataset whose episodes are evaluated.
        model_path: Path to model weights, passed to ``NaVid_Agent``.
        result_path: Root directory for per-run logs and aggregated results.
        manipulation_type: Integer id of the RGB corruption. It is forwarded to
            ``act_image_manipulation`` and mapped to the output file-name suffix.
        intensity: Corruption intensity, forwarded to ``act_image_manipulation``.
        num_runs: Number of evaluation runs (default: 3).
    """
    # File-name suffix for each corruption id. These names are kept exactly as
    # before so existing result files and globs keep matching.
    suffix_by_type = {
        0: "defocus_blur",
        1: "color_jitter",
        2: "spatter",
        3: "speckle_noise",
        4: "low_lighting",
        5: "narrow_horizontal_fov",
        6: "motion_blur",
        7: "random_affine_transform",
        8: "flare_effect",
        9: "low_lighting_gradient",
        10: "foreign_object",
        11: "overexposure",
        12: "blackout",
    }
    # Unknown ids previously produced an empty suffix ("..._runs__<ts>.json"),
    # which made results for different corruptions indistinguishable.
    manipulation_file_suffix = suffix_by_type.get(manipulation_type, f"type{manipulation_type}")

    agg_path = os.path.join(result_path, "aggregated")
    os.makedirs(agg_path, exist_ok=True)

    # episode_id -> metric name -> list of values (one per run)
    all_metrics = collections.defaultdict(lambda: collections.defaultdict(list))
    target_key = {"distance_to_goal", "success", "spl", "path_length", "oracle_success"}

    print(f"\nRunning {num_runs} evaluations...")
    for run in range(num_runs):
        print(f"\nStarting run {run + 1}/{num_runs}")

        run_path = os.path.join(result_path, f"run_{run + 1}")
        os.makedirs(os.path.join(run_path, "log"), exist_ok=True)

        env = Env(config.TASK_CONFIG, dataset)
        agent = None
        # try/finally so the simulator and GPU memory are released even if an
        # episode raises part-way through the run.
        try:
            agent = NaVid_Agent(model_path, run_path)
            num_episodes = len(dataset.episodes)
            for _ in trange(num_episodes, desc=f"{config.EVAL.IDENTIFICATION}-{split_id}-run{run + 1}"):
                obs = env.reset()
                episode_id = env.current_episode.episode_id
                agent.reset(current_run=run)

                iter_step = 0
                continuse_rotation_count = 0  # consecutive steps with unchanged distance-to-goal
                last_dtg = 999
                while not env.episode_over:
                    info = env.get_metrics()
                    if info["distance_to_goal"] != last_dtg:
                        last_dtg = info["distance_to_goal"]
                        continuse_rotation_count = 0
                    else:
                        continuse_rotation_count += 1

                    action = agent.act_image_manipulation(obs, info, episode_id, manipulation_type, intensity)
                    if (continuse_rotation_count > config.EVAL.EARLY_STOP_ROTATION
                            or iter_step > config.EVAL.EARLY_STOP_STEPS):
                        action = {"action": 0}

                    iter_step += 1
                    obs = env.step(action)

                # Final metrics of this episode.
                info = env.get_metrics()
                metrics = {k: info[k] for k in target_key if k in info}
                metrics["id"] = episode_id
                metrics["run_number"] = run + 1

                for key in target_key:
                    if key in metrics:
                        all_metrics[episode_id][key].append(metrics[key])

                with open(os.path.join(run_path, "log", f"stats_{episode_id}.json"), "w") as f:
                    json.dump(metrics, f, indent=4)
        finally:
            env.close()
            del agent
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Per-episode mean/std across runs; overall stats are over per-episode means.
    aggregated_results = {}
    overall_metrics = {k: [] for k in target_key}

    for episode_id, metrics in all_metrics.items():
        episode_means = {}
        episode_stds = {}

        for key in target_key:
            if key in metrics and metrics[key]:
                values = metrics[key]
                mean_value = float(np.mean(values))
                std_value = float(np.std(values))
                episode_means[key] = mean_value
                episode_stds[key] = std_value
                overall_metrics[key].append(mean_value)

        aggregated_results[episode_id] = {
            "mean": episode_means,
            "std": episode_stds,
            "num_runs": len(next(iter(metrics.values())))
        }

    overall_summary = {
        key: {
            "mean": float(np.mean(values)),
            "std": float(np.std(values))
        }
        for key, values in overall_metrics.items()
        if values
    }

    timestamp = int(time.time())
    results = {
        "config": {
            "num_runs": num_runs,
            "split_id": split_id,
            "timestamp": timestamp,
            # Record the corruption setting so each result file is self-describing.
            # Non-int/float intensities are stringified so json.dump cannot fail here.
            "manipulation": manipulation_file_suffix,
            "intensity": intensity if isinstance(intensity, (int, float)) else str(intensity),
        },
        "episode_metrics": aggregated_results,
        "overall_summary": overall_summary
    }

    results_file = os.path.join(agg_path, f"aggregated_metrics_{split_id}_{num_runs}_runs_{manipulation_file_suffix}_{timestamp}.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=4)

    print(f"\nAggregated results saved to: {results_file}")
    print("\nOverall metrics summary:")
    for key, stats in overall_summary.items():
        print(f"{key:15s}: {stats['mean']:.4f} ± {stats['std']:.4f}")

# No corruption
def agg_eval_agent_normal(config, split_id, dataset, model_path, result_path, num_runs: int = 3,suffix="normal") -> None:
    """Run the uncorrupted evaluation ``num_runs`` times and aggregate metrics per episode.

    Every run creates a fresh ``Env`` and ``NaVid_Agent`` and plays all episodes of
    ``dataset`` with ``agent.act``. The action is forced to ``{"action": 0}`` once
    distance-to-goal has been unchanged for more than
    ``config.EVAL.EARLY_STOP_ROTATION`` consecutive steps, or once more than
    ``config.EVAL.EARLY_STOP_STEPS`` steps have been taken. Each episode's final
    metrics are written to ``<result_path>/run_<k>/log/stats_<episode_id>.json``.

    After all runs, the per-episode mean/std over runs and the overall mean/std
    of those per-episode means are written to
    ``<result_path>/aggregated/aggregated_metrics_<split_id>_<num_runs>_runs_<suffix>_<timestamp>.json``
    and printed.

    Args:
        config: Evaluation config. Reads ``TASK_CONFIG``, ``EVAL.IDENTIFICATION``,
            ``EVAL.EARLY_STOP_ROTATION`` and ``EVAL.EARLY_STOP_STEPS``.
        split_id: Dataset split identifier (used in progress labels and the file name).
        dataset: Dataset whose episodes are evaluated.
        model_path: Path to model weights, passed to ``NaVid_Agent``.
        result_path: Root directory for per-run logs and aggregated results.
        num_runs: Number of evaluation runs (default: 3).
        suffix: Label inserted into the aggregated result file name (default
            ``"normal"``). It does not change how the agent is driven.
    """
    agg_path = os.path.join(result_path, "aggregated")
    os.makedirs(agg_path, exist_ok=True)

    # episode_id -> metric name -> list of values (one per run)
    all_metrics = collections.defaultdict(lambda: collections.defaultdict(list))
    target_key = {"distance_to_goal", "success", "spl", "path_length", "oracle_success"}

    print(f"\nRunning {num_runs} evaluations...")
    for run in range(num_runs):
        print(f"\nStarting run {run + 1}/{num_runs}")

        run_path = os.path.join(result_path, f"run_{run + 1}")
        os.makedirs(os.path.join(run_path, "log"), exist_ok=True)

        env = Env(config.TASK_CONFIG, dataset)
        agent = None
        # try/finally so the simulator and GPU memory are released even if an
        # episode raises part-way through the run.
        try:
            agent = NaVid_Agent(model_path, run_path)
            print(agent.conv_mode)
            num_episodes = len(dataset.episodes)
            for _ in trange(num_episodes, desc=f"{config.EVAL.IDENTIFICATION}-{split_id}-run{run + 1}"):
                obs = env.reset()
                episode_id = env.current_episode.episode_id
                agent.reset(current_run=run)

                iter_step = 0
                continuse_rotation_count = 0  # consecutive steps with unchanged distance-to-goal
                last_dtg = 999
                while not env.episode_over:
                    info = env.get_metrics()
                    if info["distance_to_goal"] != last_dtg:
                        last_dtg = info["distance_to_goal"]
                        continuse_rotation_count = 0
                    else:
                        continuse_rotation_count += 1

                    # NOTE: this function calls ``agent.act`` regardless of ``suffix``.
                    # Earlier variants dispatched on ``suffix`` to
                    # ``agent.act_instruct(obs, info, episode_id, type=suffix)`` for
                    # instruction manipulations (e.g. 'capitalize', 'mal_0'); re-add
                    # that dispatch here if those experiments are needed again.
                    action = agent.act(obs, info, episode_id)
                    if (continuse_rotation_count > config.EVAL.EARLY_STOP_ROTATION
                            or iter_step > config.EVAL.EARLY_STOP_STEPS):
                        action = {"action": 0}

                    iter_step += 1
                    obs = env.step(action)

                # Final metrics of this episode.
                info = env.get_metrics()
                metrics = {k: info[k] for k in target_key if k in info}
                metrics["id"] = episode_id
                metrics["run_number"] = run + 1

                for key in target_key:
                    if key in metrics:
                        all_metrics[episode_id][key].append(metrics[key])

                with open(os.path.join(run_path, "log", f"stats_{episode_id}.json"), "w") as f:
                    json.dump(metrics, f, indent=4)
        finally:
            env.close()
            del agent
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Per-episode mean/std across runs; overall stats are over per-episode means.
    aggregated_results = {}
    overall_metrics = {k: [] for k in target_key}

    for episode_id, metrics in all_metrics.items():
        episode_means = {}
        episode_stds = {}

        for key in target_key:
            if key in metrics and metrics[key]:
                values = metrics[key]
                mean_value = float(np.mean(values))
                std_value = float(np.std(values))
                episode_means[key] = mean_value
                episode_stds[key] = std_value
                overall_metrics[key].append(mean_value)

        aggregated_results[episode_id] = {
            "mean": episode_means,
            "std": episode_stds,
            "num_runs": len(next(iter(metrics.values())))
        }

    overall_summary = {
        key: {
            "mean": float(np.mean(values)),
            "std": float(np.std(values))
        }
        for key, values in overall_metrics.items()
        if values
    }

    timestamp = int(time.time())
    results = {
        "config": {
            "num_runs": num_runs,
            "split_id": split_id,
            "timestamp": timestamp
        },
        "episode_metrics": aggregated_results,
        "overall_summary": overall_summary
    }

    results_file = os.path.join(agg_path, f"aggregated_metrics_{split_id}_{num_runs}_runs_{suffix}_{timestamp}.json")
    with open(results_file, "w") as f:
        json.dump(results, f, indent=4)

    print(f"\nAggregated results saved to: {results_file}")
    print("\nOverall metrics summary:")
    for key, stats in overall_summary.items():
        print(f"{key:15s}: {stats['mean']:.4f} ± {stats['std']:.4f}")

  

def _record_episode_rgb_gif(config, dataset, model_path, result_path, video_output_dir,
                            label, manipulation_type=None, intensity=1.0):
    """Play the first episode of *dataset* once and save the agent's RGB frames as a GIF.

    Uses ``NaVid_Agent.act`` when ``manipulation_type`` is None. Otherwise it uses
    ``act_image_manipulation(obs, info, episode_id, manipulation_type, intensity)``.
    The GIF is written to ``<video_output_dir>/<episode_id>_<label>_rgb_only.gif``.
    Errors are printed with a traceback instead of raised. The env is always
    closed and the CUDA cache emptied.
    """
    env = None
    agent_instance = None
    # Bound before anything that can raise, so the error message in the except
    # handler never hits an unbound name when Env/agent construction fails.
    current_episode_id = dataset.episodes[0].episode_id
    try:
        task_config = config.TASK_CONFIG.clone() if hasattr(config.TASK_CONFIG, 'clone') else config.TASK_CONFIG
        env = Env(config=task_config, dataset=dataset)
        agent_instance = NaVid_Agent(model_path=model_path, result_path=result_path)

        obs = env.reset()
        current_episode_id = env.current_episode.episode_id
        agent_instance.reset()  # clears rgb_list among other per-episode state

        iter_step = 0
        continuse_rotation_count = 0  # consecutive steps with unchanged distance-to-goal
        last_dtg = env.get_metrics().get("distance_to_goal", float('inf'))

        while not env.episode_over:
            info = env.get_metrics()

            current_dtg = info.get("distance_to_goal", float('inf'))
            if current_dtg != last_dtg:
                last_dtg = current_dtg
                continuse_rotation_count = 0
            else:
                continuse_rotation_count += 1

            # Both act paths append the observation the model saw to rgb_list.
            if manipulation_type is None:
                action = agent_instance.act(observations=obs, info=info, episode_id=current_episode_id)
            else:
                action = agent_instance.act_image_manipulation(
                    obs, info, current_episode_id, manipulation_type, intensity)

            if (continuse_rotation_count > config.EVAL.EARLY_STOP_ROTATION
                    or iter_step > config.EVAL.EARLY_STOP_STEPS):
                action = {"action": 0}

            iter_step += 1
            obs = env.step(action)

        if agent_instance.rgb_list:
            gif_filename = f"{current_episode_id}_{label}_rgb_only.gif"
            output_gif_path = os.path.join(video_output_dir, gif_filename)
            imageio.mimsave(output_gif_path, agent_instance.rgb_list)
            print(f"  SUCCESS: Saved RGB video to {output_gif_path}")
        else:
            print(f"  WARNING: No RGB frames recorded for episode {current_episode_id}, manipulation {label}.")

    except Exception as e:
        print(f"  ERROR running episode {current_episode_id} with manipulation {label}: {e}")
        import traceback
        traceback.print_exc()
    finally:
        if env is not None:
            env.close()
        # Drop the agent (and its model) before flushing the CUDA cache.
        del agent_instance
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def run_episode_55_visualizations(config, original_dataset, model_path, result_path,
                                  apply_manipulation=False, intensity=1.0):
    """Run episode "55" and save GIFs of the RGB frames seen by the agent.

    With the default ``apply_manipulation=False``, the episode is played once with
    ``NaVid_Agent.act`` and saved as ``<result_path>/video/55_normal_rgb_only.gif``.
    This is the same file the previous version produced. The previous version
    replayed that same unmanipulated episode once per manipulation and
    overwrote the file each time.

    With ``apply_manipulation=True``, the episode is played once per entry of
    ``VISUALIZATION_MANIPULATION_NAME_MAP``. Each play goes through
    ``act_image_manipulation`` at ``intensity`` and is saved as
    ``<result_path>/video/55_<manipulation_name>_rgb_only.gif``.

    Prints an error and returns early if episode "55" is not in ``original_dataset``.
    """
    episode_55_obj = next((ep for ep in original_dataset.episodes if ep.episode_id == "55"), None)
    if episode_55_obj is None:
        print("Error: Episode '55' not found in the provided dataset.")
        return

    # Single-episode dataset so Env.reset() always lands on episode 55.
    episode_55_dataset = Dataset()
    episode_55_dataset.episodes = [episode_55_obj]
    if getattr(original_dataset, 'config', None) is not None:
        episode_55_dataset.config = original_dataset.config

    video_output_dir = os.path.join(result_path, "video")
    os.makedirs(video_output_dir, exist_ok=True)

    if apply_manipulation:
        for manip_type, manip_name in VISUALIZATION_MANIPULATION_NAME_MAP.items():
            print(f"\nRunning episode 55 with manipulation: {manip_name} (type {manip_type}) - RGB only")
            _record_episode_rgb_gif(config, episode_55_dataset, model_path, result_path, video_output_dir,
                                    label=manip_name, manipulation_type=manip_type, intensity=intensity)
    else:
        print("\nRunning episode 55 without manipulation - RGB only")
        _record_episode_rgb_gif(config, episode_55_dataset, model_path, result_path, video_output_dir,
                                label="normal")

    print("\nFinished generating RGB-only visualizations for episode 55.")

def run_episode_55_topdown_map(config, original_dataset, model_path, result_path):
    """Run episode "55" once and save a GIF of the agent's top-down map frames.

    The agent is built with ``require_map=True`` and steps the episode with
    ``NaVid_Agent.act``. It forces ``{"action": 0}`` after more than
    ``config.EVAL.EARLY_STOP_ROTATION`` consecutive steps without a change in
    distance-to-goal, or after more than ``config.EVAL.EARLY_STOP_STEPS`` steps.
    The collected ``topdown_map_list`` is written to
    ``<result_path>/video/<episode_id>_topdown_map.gif``.

    Errors during the episode are printed with a traceback rather than raised.
    The env is always closed and the CUDA cache emptied. Prints an error and
    returns early if episode "55" is not in ``original_dataset``.
    """
    episode_55_obj = next((ep for ep in original_dataset.episodes if ep.episode_id == "55"), None)
    if episode_55_obj is None:
        print("Error: Episode '55' not found in the provided dataset.")
        return

    # Single-episode dataset so Env.reset() always lands on episode 55.
    episode_55_dataset = Dataset()
    episode_55_dataset.episodes = [episode_55_obj]
    if getattr(original_dataset, 'config', None) is not None:
        episode_55_dataset.config = original_dataset.config

    video_output_dir = os.path.join(result_path, "video")
    os.makedirs(video_output_dir, exist_ok=True)

    print("\nRunning episode 55 with top-down map visualization")

    env = None
    agent_instance = None
    # Bound before the try: the except-handler message uses it, and Env/agent
    # construction can fail before env.reset(). Previously that raised an
    # UnboundLocalError which hid the real exception.
    current_episode_id = episode_55_obj.episode_id
    try:
        task_config = config.TASK_CONFIG.clone() if hasattr(config.TASK_CONFIG, 'clone') else config.TASK_CONFIG
        env = Env(config=task_config, dataset=episode_55_dataset)

        # require_map=True so the agent records top-down map frames.
        agent_instance = NaVid_Agent(model_path=model_path, result_path=result_path, require_map=True)

        obs = env.reset()
        current_episode_id = env.current_episode.episode_id
        agent_instance.reset()  # clears topdown_map_list among other per-episode state

        iter_step = 0
        continuse_rotation_count = 0  # consecutive steps with unchanged distance-to-goal
        last_dtg = env.get_metrics().get("distance_to_goal", float('inf'))

        while not env.episode_over:
            info = env.get_metrics()

            current_dtg = info.get("distance_to_goal", float('inf'))
            if current_dtg != last_dtg:
                last_dtg = current_dtg
                continuse_rotation_count = 0
            else:
                continuse_rotation_count += 1

            action = agent_instance.act(
                observations=obs,
                info=info,
                episode_id=current_episode_id
            )

            if (continuse_rotation_count > config.EVAL.EARLY_STOP_ROTATION
                    or iter_step > config.EVAL.EARLY_STOP_STEPS):
                action = {"action": 0}

            iter_step += 1
            obs = env.step(action)

        if agent_instance.topdown_map_list:
            gif_filename = f"{current_episode_id}_topdown_map.gif"
            output_gif_path = os.path.join(video_output_dir, gif_filename)
            imageio.mimsave(output_gif_path, agent_instance.topdown_map_list)
            print(f"  SUCCESS: Saved top-down map video to {output_gif_path}")
        else:
            print(f"  WARNING: No top-down map frames recorded for episode {current_episode_id}.")

    except Exception as e:
        print(f"  ERROR running episode {current_episode_id} with top-down map visualization: {e}")
        import traceback
        traceback.print_exc()
    finally:
        if env is not None:
            env.close()
        # Drop the agent (and its model) before flushing the CUDA cache.
        del agent_instance
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nFinished generating top-down map visualization for episode 55.")