import matplotlib.pyplot as plt
import numpy as np
import os
import json
import traceback
import time
import pandas as pd

from .data import get_data, save_data

import openai
openai.api_key = os.getenv("OPENAI_API_KEY")


from .utils import (
    map_to_list,
    extract_list,
    extract_dict,
    list_of_lists_to_string,
    find_character_position,
    extract_slash
    )

from .gym_agent import CustomEnv, LLMAgent

from .fixers import pad_rows_to_max_length

from vllm import LLM, SamplingParams

def calculate_path_length(actions):
    """Measure the path described by a sequence of movement actions.

    Only the four movement actions 'move_up', 'move_down', 'move_left' and
    'move_right' are recognised. Every other entry is ignored, including
    entries that are not strings (for example a nested list or a dict from
    a malformed generation), so a bad entry never raises here.

    Args:
        actions: Iterable of action names, or None (treated as no actions).

    Returns:
        A tuple ``(path_length, path)`` where ``path_length`` is the number
        of unit steps taken (the Manhattan length of the walk, counting
        back-tracking) and ``path`` is a list with one
        ``{"x": dx, "y": dy}`` dict per recognised action, in order.
    """
    # Coordinate change produced by each recognised action.
    move_delta = {
        'move_up': (0, 1),
        'move_down': (0, -1),
        'move_left': (-1, 0),
        'move_right': (1, 0)
    }

    if actions is None:
        return 0, []

    path = []
    for action in actions:
        # Check the type before the dict lookup: an unhashable entry would
        # otherwise raise TypeError instead of simply being skipped.
        if not isinstance(action, str):
            continue
        delta = move_delta.get(action)
        if delta is None:
            continue
        dx, dy = delta
        path.append({"x": dx, "y": dy})

    # Every recognised move is exactly one tile, so the summed Manhattan
    # distance between consecutive positions equals the number of moves.
    return len(path), path

def benchmark(model,
            total_episodes,
            world_map_fixed,
            world_map_fixed_with_chars,
            tileset_used_dict,
            walkable_tiles_list,
            object_tiles_list,
            objective_tile_dict,
            astar_path_length,
            client,
            method="cot"):

    
    print("Generating Actions...")
    except_done = False
    whole_exception = 0
    
    frames = [] 
    episodes = 0
    
    if not hasattr(benchmark, "_reflexion_memory"):
        benchmark._reflexion_memory = {}
    try:
        while not except_done:

            folder_path = "/"


            env = CustomEnv(walkable_tiles_list, world_map_fixed, world_map_fixed_with_chars, object_tiles_list, "#")
            
            agent = LLMAgent()
            state = env.reset()

            reward_feedback = "This is your first objective"
            reward_design = {
                "Each action you take will deduct following reward so that you take minimum amount of actions to complete objective. For example you take 10 actions then you will recieve -10 rewards": -1,
                "You are 8 tiles away from objective thus objective is incomplete": -100,
                "You are 5 to 8 tiles away from objective thus objective is incomplete": -50,
                "You are 3 to 5 tiles away from objective": +25,
                "You are 1 to 3 tiles away from objective": +50,
                "You are 1 tile away or your are on the objective tile from objective": +100,
                "You have completed the objective": +200,
            }
            
            done = False
            orig_world_map_fixed_with_chars = world_map_fixed_with_chars
            while not done:
                total_actions = {}
                prev_reward = 0
                all_rewards = 0
                all_llm_path_length = 0
                all_wrong_action_generated = 0
                all_total_actions_taken = 0
                all_generation_errors = 0
                all_total_achieved_objectives = 0
                all_total_1tilewindow_achieved_objectives = 0
                all_total_5tilewindow_achieved_objectives = 0
                total_llm_paths = []
                protagonist_position = find_character_position(orig_world_map_fixed_with_chars, "@")
                reward_this_objective = {}
                for i in range(len(objective_tile_dict)):
                    total_actions[list(objective_tile_dict.keys())[i]] = []
                    reward_this_objective[list(objective_tile_dict.keys())[i]] = []
                    for j in range(total_episodes):
                        
                        print("\n")
                        print(f"OBJECTIVE: {list(objective_tile_dict.keys())[i]}")
                        print(f"EPISODE: {j+1}")
                        print("\n")
                        reward = 0
                        llm_path_length = 0
                        total_actions_taken = 0
                        wrong_action_generated = 0
                        generation_error = 0
                        total_achieved_objectives = 0
                        total_1tilewindow_achieved_objectives = 0
                        total_5tilewindow_achieved_objectives = 0
                        llm_paths = []
                    
                        action_system = f"You are a great planner in a 2D game. You plan actions for the protagonist of the game to achieve all objects. You are given objectives, tiles and the position of tiles to achieve the objectives. You have the following options as actions: 'move_up', 'move_down', 'move_right', and 'move_left'. Generate a sequence of actions that will achieve the objective. Only return the sequence of actions from the options."
                        
                        if j > 0:
                            if (distance_from_objective[0] == 0 and distance_from_objective[1] == 0):
                                reward = prev_reward
                                break
                            if i ==0:
                                action_prompt = f"Let's say you are given a 2D tile map of a 2D game:\n{world_map_fixed_with_chars}\n The tile map was created using the following tile to character mapping which has information about all the tiles:\n{tileset_used_dict}\n You are also provided with a set of objectives:\n{objective_tile_dict}\nwalkable tiles:\n{walkable_tiles_list}\n and interactive object tiles:\{object_tiles_list}\n. The character '@' is the protagonist of the story and you are controlling it. The current position of protagonist is {prev_protagonist_position}. The rewards will be given as follows:\n{reward_design}\n{reward_feedback}. You are also given information about your previous try for this objective. You generated the following sequence of actions:\n{total_actions[list(objective_tile_dict.keys())[i]]}\n These actions took protagonist from coordinates {prev_protagonist_position} to {protagonist_position} which was {distance_from_objective} distance away from objective (the objective is at the tile and the position {list(objective_tile_dict.values())[i]}). This previous try gave you {reward_this_objective[list(objective_tile_dict.keys())[i]]} rewards. Taking this information into your context, create a sequence of actions for the agent to complete the objective which is to reach the tile, at the tile and the position: {list(objective_tile_dict.values())[i]}. Strictly return a Python dictionary with the entry as 'action'. Only return Python dictionary with one entry like 'action': ['move_up', 'move_down'.. etc.]. Do not return it in a Python response."
                                protagonist_position = prev_protagonist_position
                            else:                                
                                action_prompt = f"Let's say you are given a 2D tile map of a 2D game:\n{world_map_fixed_with_chars}\n The tile map was created using the following tile to character mapping which has information about all the tiles:\n{tileset_used_dict}\n You are also provided with a set of objectives:\n{objective_tile_dict}\nwalkable tiles:\n{walkable_tiles_list}\n and interactive object tiles:\{object_tiles_list}\n. The character '@' is the protagonist of the story and you are controlling it. The current position of protagonist is {prev_protagonist_position}. The rewards will be given as follows:\n{reward_design}\n{reward_feedback}. You are also given information about your previous try for all objectives. You generated the following sequence of actions:\n{total_actions[list(objective_tile_dict.keys())[i]]}\n These actions took protagonist from coordinates {prev_protagonist_position} to {protagonist_position} which was {distance_from_objective} distance away from objective (the objective is at the tile and the position {list(objective_tile_dict.values())[i]}). This previous try gave you {reward_this_objective[list(objective_tile_dict.keys())[i]]} rewards. Taking this information into your context, create a sequence of actions for the agent to complete the objective which is to reach the tile, at the tile and the position: {list(objective_tile_dict.values())[i]}. Strictly return a Python dictionary with the entry as 'action'. Only return Python dictionary with one entry like 'action': ['move_up', 'move_down'.. etc.]. Do not return it in a Python response."
                                protagonist_position = prev_protagonist_position
                        else: 
                            if i ==0:
                                action_prompt = f"Let's say you are given a 2D tile map of a 2D game:\n{world_map_fixed_with_chars}\n The tile map was created using the following tile to character mapping which has information about all the tiles:\n{tileset_used_dict}\n You are also provided with a set of objectives:\n{objective_tile_dict}\nwalkable tiles:\n{walkable_tiles_list}\n and interactive object tiles:\{object_tiles_list}\n. The character '@' is the protagonist of the story and you are controlling it. The current position of protagonist is {protagonist_position}. The rewards will be given as follows:\n{reward_design}\n{reward_feedback}. Taking this information into your context, create a sequence of actions for the agent to complete the objective which is to reach the tile, at the tile and the position: {list(objective_tile_dict.values())[i]}. Strictly return a Python dictionary with the entry as 'action'. Only return Python dictionary with one entry like 'action': ['move_up', 'move_down'.. etc.]. Do not return it in a Python response."
                            else:
                                action_prompt = f"Let's say you are given a 2D tile map of a 2D game:\n{world_map_fixed_with_chars}\n The tile map was created using the following tile to character mapping which has information about all the tiles:\n{tileset_used_dict}\n You are also provided with a set of objectives:\n{objective_tile_dict}\nwalkable tiles:\n{walkable_tiles_list}\n and interactive object tiles:\{object_tiles_list}\n. The character '@' is the protagonist of the story and you are controlling it. The current position of protagonist is {protagonist_position}. The rewards will be given as follows:\n{reward_design}\n{reward_feedback}. Taking this information into your context, create a sequence of actions for the agent to complete the objective which is to reach the tile, at the tile and the position: {list(objective_tile_dict.values())[i]}. Strictly return a Python dictionary with the entry as 'action'. Only return Python dictionary with one entry like 'action': ['move_up', 'move_down'.. etc.]. Do not return it in a Python response."
                    
                        action_exception = 0
                        action_done = False
                        while not action_done:
                            try:
                                if method == "Hierachical":
                                    # 获取当前位置和目标位置的坐标
                                    current_pos = protagonist_position  # (row, col)
                                    objective_pos = extract_list(str(list(objective_tile_dict.values())[i]))  # [tile, x, y]
                                    
                                    # 获取地图尺寸
                                    map_lines = world_map_fixed_with_chars.strip().split('\n')
                                    map_height = len(map_lines)
                                    map_width = len(map_lines[0]) if map_lines else 0
                                    
                                    # 从地图中提取障碍物位置 - 基于walkable_tiles判断
                                    obstacle_positions = []
                                    for x, line in enumerate(map_lines):
                                        for y, char in enumerate(line):
                                            # 如果不在walkable_tiles中且不是主角，则为障碍物
                                            if char not in walkable_tiles_list and char != '@':
                                                obstacle_positions.append(f"({x},{y})")
                                    
                                    obstacles_str = " and ".join(obstacle_positions) if obstacle_positions else ""
                                    
                                    # 构建task_description - current_pos是(row, col)，需要转换为(x, y)
                                    if obstacles_str:
                                        task_description = f"You are in a {map_height} by {map_width} world. There are obstacles that you have to avoid at: {obstacles_str}. Go from ({current_pos[0]},{current_pos[1]}) to ({objective_pos[1]},{objective_pos[2]})"
                                    else:
                                        task_description = f"You are in a {map_height} by {map_width} world. Go from ({current_pos[0]},{current_pos[1]}) to ({objective_pos[1]},{objective_pos[2]})"
                                    
                                    # High level planning prompt - 严格按照txt格式
                                    high_level_plan_prompt = f""" 
                                    You are a path planner. Your task is to plan a feasible, obstacle-free path for a single agent in a given environment. The environment is a {map_height}x{map_width} grid from (0,0) to ({map_width-1},{map_height-1}). The path should be a series of key anchor points.

                                    Please find two feasible intermediate anchor points for Agent 1 and provide their coordinates.

                                    Identify exactly two key turning points along this path. These are the points where the path changes direction to navigate around obstacles or to head toward the goal.

                                    Anchor Point Selection Strategy: The path does not have to be the single shortest path. Instead, explore multiple valid paths and select a random but feasible one. The anchor points should be key turning points or located at important positions around obstacles.

                                    Please strictly follow the following format to output the list of anchor points:

                                    <trajectory for agent1> = [(start_x, start_y), (anchor_1_x, anchor_1_y), (anchor_2_x, anchor_2_y), (end_x, end_y)]

                                    Directly output the result.

                                    Here are some examples that show different valid paths:

                                    Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (2,1). Go from (0,1) to (3,4)
                                    <trajectory for agent1> = [(0,1), (1,2), (2,3), (3,4)]

                                    Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (1,5) and (1,2). Go from (5,4) to (0,5)
                                    <trajectory for agent1> = [(5,4), (3,3), (1,1), (0,5)]

                                    Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (0,3), (2,5) and (5,2). Go from (4,2) to (0,5)
                                    <trajectory for agent1> = [(4,2), (2,4), (1,5), (0,5)]

                                    Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (3,5), (4,2), (3,3) and (0,4). Go from (1,5) to (3,1)
                                    <trajectory for agent1> = [(1,5), (2,4), (4,1), (3,1)]

                                    Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (2,5), (5,2), (0,4), (1,4) and (0,1). Go from (4,2) to (1,2) 
                                    <trajectory for agent1> = [(4,2), (3,3), (2,3), (1,2)]

                                    Here is the question:

                                    Task: {task_description}
                                    <trajectory for agent1> ="""

                                    # 生成high-level trajectory
                                    sampling_params = SamplingParams(max_tokens=1024, temperature=0, n=1)
                                    waypoint_messages = [
                                        {"role": "user", "content": high_level_plan_prompt}
                                    ]
                                    waypoint_response = client.chat(messages=waypoint_messages, sampling_params=sampling_params)
                                    
                                    # 解析trajectory格式: 直接找[(x1,y1), (x2,y2), ...]结构
                                    response_text = waypoint_response[0].outputs[0].text.strip()
                                    trajectory_waypoints = []
                                    
                                    # 提取trajectory中的坐标点 - 直接找[]结构
                                    import re
                                    trajectory_match = re.search(r'\[(.*?)\]', response_text)
                                    if trajectory_match:
                                        coords_str = trajectory_match.group(1)
                                        # 解析坐标对
                                        coord_matches = re.findall(r'\((\d+),(\d+)\)', coords_str)
                                        trajectory_waypoints = [(int(x), int(y)) for x, y in coord_matches]
                                    
                                    # 构建路径序列
                                    if len(trajectory_waypoints) >= 4:  # 起点 + 2个中间点 + 终点
                                        path_sequence = trajectory_waypoints
                                    elif len(trajectory_waypoints) >= 2:  # 至少有起点和终点
                                        path_sequence = trajectory_waypoints
                                    else:
                                        # 如果解析失败，直接从当前位置到目标位置
                                        path_sequence = [
                                            (current_pos[0], current_pos[1]),  # 当前位置(x,y) - current_pos是(row,col)需要转换
                                            (objective_pos[1], objective_pos[2])  # 目标位置(x,y)
                                        ]
                                    
                                    # 为每个路径段生成最优动作序列
                                    all_actions = []
                                    
                                    for seg_idx in range(len(path_sequence) - 1):
                                        start_point = path_sequence[seg_idx]
                                        end_point = path_sequence[seg_idx + 1]
                                        
                                        # 构建该路径段的task_description
                                        if obstacles_str:
                                            segment_task_description = f"You are in a {map_height} by {map_width} world. There are obstacles that you have to avoid at: {obstacles_str}. Go from ({start_point[0]},{start_point[1]}) to ({end_point[0]},{end_point[1]})"
                                        else:
                                            segment_task_description = f"You are in a {map_height} by {map_width} world. Go from ({start_point[0]},{start_point[1]}) to ({end_point[0]},{end_point[1]})"
                                        
                                        # Low level prompt - 严格按照txt格式
                                        low_level_prompt = f"""Provide a sequence of actions to navigate a world to reach a goal similarly to the examples below. (0,0) is located in the upper-left corner and (M, N) lies in the M row and N column.
                                        Here are some examples:
                                        ###
                                        Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (2,1). Go from (0,1) to (3,4)
                                        Actions: right right right down down down 
                                        ###
                                        Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (1,5) and (1,2). Go from (5,4) to (0,5)
                                        Actions: up up up up up right 
                                        ###
                                        Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (0,3), (2,5) and (5,2). Go from (4,2) to (0,5)
                                        Actions: up up up right right up right 
                                        ###
                                        Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (3,5), (4,2), (3,3) and (0,4). Go from (1,5) to (3,1)
                                        Actions: left left left left down down 
                                        ###
                                        Task: You are in a 10 by 10 world. There are obstacles that you have to avoid at: (2,5), (5,2), (0,4), (1,4) and (0,1). Go from (4,2) to (1,2)
                                        Actions: up up up 

                                        Task: {segment_task_description}
                                        Actions: """

                                        # 生成该路径段的动作
                                        sampling_params = SamplingParams(max_tokens=1024, temperature=0, n=1)
                                        segment_messages = [
                                            {"role": "user", "content": low_level_prompt}
                                        ]
                                        segment_response = client.chat(messages=segment_messages, sampling_params=sampling_params)
                                        
                                        # 解析动作序列输出 - response直接是动作序列: right right down down
                                        response_text = segment_response[0].outputs[0].text.strip()
                                        
                                        # 直接分割response获取动作序列
                                        segment_actions = response_text.split()
                                        
                                        # 转换动作格式: right -> move_right, up -> move_up等
                                        action_mapping = {
                                            'right': 'move_right',
                                            'left': 'move_left', 
                                            'up': 'move_up',
                                            'down': 'move_down'
                                        }
                                        
                                        valid_segment_actions = []
                                        for action in segment_actions:
                                            if action in action_mapping:
                                                valid_segment_actions.append(action_mapping[action])
                                            elif action in ['move_up', 'move_down', 'move_left', 'move_right']:
                                                valid_segment_actions.append(action)
                                        
                                        all_actions.extend(valid_segment_actions)
                                    
                                    # 构建最终的动作字典
                                    if all_actions:
                                        action_dict = {"action": all_actions}
                                        print(f"Hierarchical planning: {len(path_sequence)} waypoints, {len(all_actions)} total actions")
                                    else:
                                        # 如果没有生成任何动作，使用简单的备用方案
                                        # action_dict = {"action": ["move_up"]}
                                        raise ValueError("Hierarchical: empty or invalid action list after filtering.")                                    
                                if method == "baseline":
                                    action_prompt = action_prompt + "\n\n Keep your answer short. You must output a formated list of actions to ensure that answer is in the format of ['move_up','move_down','move_left','move_right']."
                                    sampling_params = SamplingParams(max_tokens=2048,temperature=0,n=1)
                                    messages=[
                                        # Set an optional system message. This sets the behavior of the
                                        # assistant and can be used to provide specific instructions for
                                        # how it should behave throughout the conversation.
                                        {
                                            "role": "system",
                                            "content": action_system
                                        },
                                        # Set a user message for the assistant to respond to.
                                        {
                                            "role": "user",
                                            "content": action_prompt,
                                        }
                                    ]
                                    response = client.chat(messages=messages, sampling_params=sampling_params)
                                    action_dict = extract_dict(response[0].outputs[0].text)
                                
                                print("Action: \n")
                                print(action_dict["action"])
                                print("\n")
                                action_done = True

                            except Exception as e:
                                generation_error += 1
                                tb = traceback.format_exc()
                                print(f"Exception raised: {e}\n {tb}")
                                action_exception += 1
                                reward -= 1
                                reward_feedback = ""
                                reward_feedback = "Your previous objectives reward feedback is: "
                                reward_feedback += f"You are given a regret(negative reward) of -1 points an error that was a cause of wrong generation."
                                if action_exception >= 10:
                                    action_done = True
                                    reward -= astar_path_length
                                    reward -= 100
                                    reward_feedback += f"You were very far from the objective tile so you were also given a regret(negative reward) of -100 points and objective was INCOMPLETE"
                                    
                                continue
                        if action_exception >= 10:
                            action_done = True
                            continue
                        total_actions[list(objective_tile_dict.keys())[i]].append(action_dict["action"])
                        
                        _llm_path_length, llm_path = calculate_path_length(action_dict["action"])
                        llm_paths.append(llm_path)
                        llm_path_length += _llm_path_length
                        total_actions_taken += len(action_dict["action"])
                        
                        try:
                            for action_str in action_dict["action"]:
                                action = agent.action(action_str)
                                state, _r, done, _ = env.step(action)
                                
                                #frame = env.render(mode='rgb_array')  # Capture the frame
                                #frames.append(frame)  # Append the frame
                                time.sleep(0.01)
                        except Exception as e:
                            wrong_action_generated += 1
                            generation_error += 1
                            reward -= 1
                            reward_feedback = ""
                            reward_feedback = "Your previous objectives reward feedback is: "
                            reward_feedback += f"You are given a regret(negative reward) of -0.5 points an error that was a cause of wrong generation."
                            tb = traceback.format_exc()
                            print(f"Exception raised: {e}\n {tb}")
                    
                    
                        current_state = list_of_lists_to_string(state)
                        
                        print(current_state)
                        print("\n")

                        world_map_fixed_with_chars = current_state

                        
                        for k, value in enumerate(objective_tile_dict.values()):
                            if k == i:
                                objective_pos = extract_list(str(value))
                        prev_protagonist_position = protagonist_position
                        protagonist_position = find_character_position(world_map_fixed_with_chars, "@")
                        print("\n")
                        print(f"protagonist_position: {protagonist_position}")
                        print(f"objective_position: [{objective_pos[1]},{objective_pos[2]}]")
                        

                        distance_from_objective = (abs(objective_pos[1] - protagonist_position[0]), abs(objective_pos[2] - protagonist_position[1]))
                        print(f"distance from current objective: [{distance_from_objective[0]}, {distance_from_objective[1]}]") 
                        print("\n")

                        reward_feedback = ""
                        reward_feedback = "Your previous objectives reward feedback is: "
                        reward -= len(action_dict["action"])
                        reward_feedback += f"You took {len(action_dict['action'])} actions for the objective so you were given a regret(negative reward) of -{len(action_dict['action'])} points. "
                        if (distance_from_objective[0] > 8 or distance_from_objective[1] > 8):
                            reward -= 100
                            reward_feedback += f"You were very far from the objective tile so you were given a regret(negative reward) of -100 points and objective was INCOMPLETE"
                        if (distance_from_objective[0] > 5 and distance_from_objective[0] < 8) or (distance_from_objective[1] > 5 and distance_from_objective[1] < 8):
                            reward -= 50
                            reward_feedback += f"You were far from the objective tile so you were given a regret(negative reward) of -50 points and objective was INCOMPLETE"
                        if (distance_from_objective[0] <= 5 and distance_from_objective[0] > 3) and (distance_from_objective[1] <= 5 and distance_from_objective[1] > 3):
                            reward += 25
                            reward_feedback += f"You were close to the objective tile so you were given a reward of 25 points"
                        if (distance_from_objective[0] < 3 and distance_from_objective[0] > 1) and (distance_from_objective[1] < 3 and distance_from_objective[1] > 1):
                            reward += 50
                            reward_feedback += f"You were very close to the objective tile so you were given a reward of 50 points"

                        if (distance_from_objective[0] <= 1) and (distance_from_objective[1] > 1 and distance_from_objective[1] <= 5):
                            total_5tilewindow_achieved_objectives += 1
                            reward += 50
                            reward_feedback += f"You were very close to the objective tile so you were given a reward of 50 points"
                        if (distance_from_objective[1] <= 1) and (distance_from_objective[0] > 1 and distance_from_objective[0] <= 5):
                            total_5tilewindow_achieved_objectives += 1
                            reward += 50
                            reward_feedback += f"You were very close to the objective tile so you were given a reward of 50 points"

                        if (distance_from_objective[0] <= 1 and distance_from_objective[1] <= 1):# or check_discriptions['choices'][0]['message']['content'] == "Complete":
                            
                            if (distance_from_objective[0] == 0 and distance_from_objective[1] == 0):# and check_discriptions['choices'][0]['message']['content'] == "Complete":
                                reward += 200
                                total_achieved_objectives += 1
                                reward_feedback += f"You were by the objective tile and you COMPLETED the objective so you were given a reward of 200 points"
                            else:
                                total_1tilewindow_achieved_objectives += 1
                                reward += 100
                                reward_feedback += f"You were by the objective tile so you were given a reward of 100 points"
                        
                        reward_this_objective[list(objective_tile_dict.keys())[i]].append(reward)
                        
                        print("\n")
                        print(f"EPISODE REWARDS uptill now: {reward}")
                        print("\n")
                        prev_reward = reward
                    all_rewards += reward
                    all_llm_path_length += llm_path_length
                    all_wrong_action_generated += wrong_action_generated
                    all_generation_errors += generation_error
                    all_total_achieved_objectives += total_achieved_objectives
                    all_total_1tilewindow_achieved_objectives += total_1tilewindow_achieved_objectives
                    all_total_5tilewindow_achieved_objectives += total_5tilewindow_achieved_objectives
                    all_total_actions_taken += total_actions_taken
                    total_llm_paths.append(llm_paths)
                    print("\n")
                    print(f"All REWARDS uptill now: {all_rewards}")
                    print("\n")

                print("\n")
                print(f"TOTAL REWARD for EPISODE: {all_rewards}")
                episodes += 1
                
                done = True

            #with imageio.get_writer(f'./outputs/benchmark/{EXPERIMENT}_{generation}.mp4', fps=10) as video:
            #    for frame in frames:
            #        video.append_data(frame)

            except_done = True
        
    except Exception as e:
        tb = traceback.format_exc()
        print(f"Exception raised: {e}\n {tb}")
        whole_exception += 1
        except_done = True
        pass
    

    total_possible_rewards = (len(objective_tile_dict)*200) - astar_path_length
    total_normalised_rewards = all_rewards/(total_possible_rewards)
    
    return all_rewards, total_possible_rewards, total_normalised_rewards, all_llm_path_length, all_total_actions_taken, total_llm_paths, all_wrong_action_generated, all_generation_errors, all_total_achieved_objectives, all_total_1tilewindow_achieved_objectives, all_total_5tilewindow_achieved_objectives


def run(model: str, total_episodes: int = 1, experiment_name: str = "exp_001", save_dir: str = "./outputs/", method: str = "cot") -> str:
    """Run the benchmark over the dataset rows and save one result record per row.

    A client is chosen from substrings of ``model``; when no hosted-provider
    substring matches, the model is loaded locally through vLLM. Each dataset
    row is then passed to ``benchmark`` and the returned metrics are handed to
    ``save_data``. A failure on one row is printed with its traceback and the
    loop moves on to the next row.

    Args:
        model: Model identifier. Only the part after the last ``/`` is used
            in the result file name.
        total_episodes: Forwarded unchanged to ``benchmark``.
        experiment_name: Suffix of the result file name.
        save_dir: Directory for results; created if it does not exist.
        method: Forwarded unchanged to ``benchmark``.

    Returns:
        The result file name, ``f"{model_name}_{experiment_name}"``.
    """
    # Each substring must be tested on its own: `"gpt" or "o1" in model` is
    # always truthy, and `("llama3" or "gemma" or "mixtral") in model` only
    # ever tests "llama3".
    # NOTE(review): matching is case-sensitive substring matching, so a
    # Hugging Face id such as "google/gemma-7b-it" is routed to Groq rather
    # than vLLM — confirm this is the intended routing.
    if "gpt" in model or "o1" in model:
        import openai

        # Keep the key from the environment instead of blanking it; falls back
        # to the empty string when the variable is unset.
        openai.api_key = os.getenv("OPENAI_API_KEY", "")
        # Placeholder: presumably `benchmark` reaches OpenAI through the
        # `openai` module itself rather than this object — TODO confirm.
        client = ""

    elif "claude" in model:
        import anthropic

        # The key has to be passed by keyword.
        client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

    elif any(name in model for name in ("llama3", "gemma", "mixtral")):
        from groq import Groq

        client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    else:
        # Local fallback only; previously this ran unconditionally and
        # replaced whichever hosted client had just been built.
        print("Using vllm")
        client = LLM(model=model, trust_remote_code=True, tensor_parallel_size=4, gpu_memory_utilization=0.95)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    bench_data = get_data()
    model_name = model.split('/')[-1]

    result_file_name = f"{model_name}_{experiment_name}"

    # NOTE(review): the first two rows are skipped and the printed row index
    # is relative to the slice, not to the dataset — confirm this is intended.
    for i, data in enumerate(bench_data[2:]):
        try:
            print(f"EVALUATING ROW {i}")
            print(f"WITH EXPERIMENT ID {data['experiment_id']}")

            str_world = data["environment"]
            print(f"World:\n{str_world}\n")
            char_tile_mapping = data["tile_mapping"]
            walkables = data["walkable_tiles"]
            interactive_object_tiles = data["interactive_object_tiles"]

            objective_tile_dict = data["objectives"]
            astar_path_length = data["path_length"]
            str_world = pad_rows_to_max_length(str_world)

            # The grid form is only needed to report the world's dimensions.
            grid_world = map_to_list(str_world)
            world_width = max(len(row) for row in grid_world)
            world_height = len(grid_world)
            print(f"Game dimensions: {world_width} x {world_height}")

            # The padded string map is passed for both map arguments.
            (
                rewards,
                total_possible_rewards,
                normalised_rewards,
                llm_path_length,
                total_actions_taken,
                llm_path,
                wrong_action_generated,
                generation_errors,
                total_achieved_objectives,
                total_1tilewindow_achieved_objectives,
                total_5tilewindow_achieved_objectives,
            ) = benchmark(
                model,
                total_episodes,
                str_world,
                str_world,
                char_tile_mapping,
                walkables,
                interactive_object_tiles,
                objective_tile_dict,
                astar_path_length,
                client,
                method,
            )

            model_results = {
                "experiment_id": data["experiment_id"],
                "agent_rewards": rewards,
                "total_possible_rewards": total_possible_rewards,
                "normalised_agent_rewards": normalised_rewards,
                "llm_path_length": llm_path_length,
                "total_actions_taken": total_actions_taken,
                "wrong_action_generated": wrong_action_generated,
                "generation_errors": generation_errors,
                "total_achieved_objectives": total_achieved_objectives,
                "total_1tilewindow_achieved_objectives": total_1tilewindow_achieved_objectives,
                "total_5tilewindow_achieved_objectives": total_5tilewindow_achieved_objectives,
                "Path": llm_path,
            }

            print(f"RESULT FOR EXPERIMENT ID {data['experiment_id']}:")
            for key, val in model_results.items():
                print(f"{key} : {val}")

            save_data(benchmark_data=model_results, file_name=result_file_name, save_dir=save_dir, save_json=True)

        except Exception as e:
            # Best-effort per row: report the failure and continue with the
            # remaining rows instead of aborting the whole run.
            tb = traceback.format_exc()
            print(f"Exception raised: {e}\n {tb}")

    return result_file_name