import copy
import os
import sys
import json
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from torch.nn import functional as F
import torch
import random
import numpy as np
import imageio
from MAPF_Envs.mapf_gym import MAPFEnv
from Dec_Transformer.decision_transformer import Embeddings_for_Atari, GPT_for_DT, CfgNode
from Pipeline.prompting import Prompter
from Pipeline.llm import LLM
from Pipeline.initial_prompt import initial_prompt
import argparse
import matplotlib.pyplot as plt


def make_name(n, s, d, id, extension, dirname, extra=""):
    """
    Build the file path for one environment description.

    The file name encodes the number of agents (n), grid size (s),
    obstacle density (d) and environment id; a non-empty `extra` tag is
    inserted, preceded by an underscore, just before the extension.
    """
    tag = "" if extra == "" else f"_{extra}"
    stem = f"{n}_agents_{s}_size_{d}_density_id_{id}{tag}{extension}"
    return dirname + '/' + stem


class DT(object):
    '''
    This class encapsulates all functionalities for running inferences of 
    Decision Transformer on MAPF environments

    Input

    llm: Object of LLM() class (only used when USE_LLM is truthy)
    grid_size: grid_size of current MAPF Gym env
    config: Decision transformer configuration
    model_path: DT weights save path
    device: GPU num or CPU
    USE_LLM: In case of inferencing DT+LLM
    use_llm_at: dict keyed by grid size; the LLM is consulted on steps strictly
                greater than use_llm_at[grid_size]
    target_rtg: initial return-to-go assigned to every agent
    render_gif: Render and store GIFs
    stop_llm_at: dict keyed by grid size; the LLM is consulted on steps up to and
                 including stop_llm_at[grid_size]
    '''

    def __init__(self, llm, grid_size, config, model_path, device, USE_LLM, use_llm_at, 
                    target_rtg=20, render_gif=False, stop_llm_at=None):
        self.grid_size = grid_size
        self.llm = llm
        
        self.emb = Embeddings_for_Atari(config).to(device)
        self.gpt = GPT_for_DT(config).to(device)

        # map_location makes checkpoints saved on one device loadable on another
        # (e.g. GPU-trained weights on a CPU-only machine).
        self.emb.load_state_dict(torch.load(f"{model_path}/model_embedding", map_location=device))
        self.emb.eval()
        self.gpt.load_state_dict(torch.load(f"{model_path}/model_gpt", map_location=device))
        self.gpt.eval()

        # Per-agent copies of the two networks, filled by create_copies().
        self.embs = {}
        self.gpts = {}
        self.USE_LLM = USE_LLM

        self.device = device
        
        # Per-agent rollout history, keyed by 1-based agent id.
        self.actions = {}
        self.rtgs = {}
        self.all_states = {}
        self.time_start = {}
        self.reward_sum = {}
        self.render_gif = render_gif
        self.target_rtg = target_rtg
        # NOTE(review): run_pipeline builds CfgNode(max_timesteps=...) (plural) while
        # this reads `max_timestep` (singular) - confirm CfgNode exposes this name.
        self.max_timesteps = config.max_timestep
        self.step_size = config.step_size
        # A fresh dict per instance; a shared `{}` default would leak between instances.
        self.stop_llm_at = {} if stop_llm_at is None else stop_llm_at

        self.use_llm_at = use_llm_at
        # agent_id -> (row, col) of every goal relocated by change_goal().
        self.goal_changes = {}

        # For making the same dynamic changes
        try:
            with open('results/133k_dynamic_4_80/goal_changes.txt') as goal_file:
                self.goal_changes_read = json.load(goal_file)
        except (OSError, ValueError):
            # Missing/unreadable file or invalid JSON: fall back to random changes.
            self.goal_changes_read = None

    def create_copies(self):
        """
        Creates multiple copies of Decision transformer for multiple agents.

        Must be called after set_env(), which defines self.num_agents.
        """
        for agent in range(1, self.num_agents + 1):
            self.gpts[agent] = copy.deepcopy(self.gpt)
            self.gpts[agent].eval()
            self.embs[agent] = copy.deepcopy(self.emb)
            self.embs[agent].eval()

    def set_env(self, gym):
        """
        Attach the environment `gym` and reset the per-episode path bookkeeping.
        """
        self.num_agents = gym.num_agents
        self.size = gym.SIZE
        self.env = gym
        self.agent_paths = {}
        # Every agent starts with a path length of 1; index is agent_id - 1.
        self.agent_path_lengths = np.ones((self.num_agents))
        self.agents_reached = [False]*self.num_agents

    @staticmethod
    def _sample_action(logits):
        """Sample one action index from the softmax over logits[:, -1, :]."""
        probs = F.softmax(logits[:, -1, :], dim=-1)
        # .item() requires exactly one sampled element, so it doubles as a shape check.
        return torch.multinomial(probs, num_samples=1).item()

    @torch.no_grad()
    def step_all_parallel(self, show_distance_inside_fov, prompter, is_it_first_step=None, last_step_dynamic=False, suggest_actions_with_llm=False):
        """
        Advances the state of the environment by a single step across all agents.
        
        Function has 2 parts
        1. Getting actions for all agents.
        2. Executing actions in the environment. These updates for agents happen sequentially, 
                            i.e., one by one, because action of one agent changes the environment.

        When USE_LLM and suggest_actions_with_llm are both truthy, the sampled actions are
        offered to `prompter.suggest_action`; its answer replaces them only if it contains
        the same number of entries.

        Returns the dict {agent_id: action} that was executed.
        """
        max_timesteps = self.max_timesteps
        step_size = self.step_size
        action_dict = {}
        for agent in range(1, self.num_agents + 1):
            if is_it_first_step:
                new_o = self.env.get_goal_in_fov_format(agent, show_distance_inside_fov=show_distance_inside_fov)
                self.rtgs[agent] = [self.target_rtg]

                state = np.array(new_o[0])
                state_before_unsqueeze = copy.deepcopy(torch.tensor(np.array([state]), dtype=torch.float32).to(self.device))
                state = torch.tensor(np.array([state]), dtype=torch.float32).to(self.device).unsqueeze(0)

                # pick up next action
                rtgs_emb, states_emb, actions_emb = self.embs[agent](
                    rtgs=torch.tensor(self.rtgs[agent], dtype=torch.float32).to(self.device).unsqueeze(0).unsqueeze(-1).type(torch.float32),
                    states=state.unsqueeze(0),
                    actions=None)

                logits = self.gpts[agent](
                    rtgs_emb=rtgs_emb,
                    states_emb=states_emb,
                    actions_emb=actions_emb,
                    timesteps=torch.zeros((1, 1), dtype=torch.int64).to(self.device))

                action_dict[agent] = self._sample_action(logits)

                # First step: start this agent's history from scratch.
                self.actions[agent] = [action_dict[agent]]
                self.time_start[agent] = 0
                self.all_states[agent] = state_before_unsqueeze
                self.reward_sum[agent] = 0

            else: 
                # Only the window built in the LAST iteration is fed to the model; the
                # loop advances `t` in strides of step_size from the previous start.
                for t in range(self.time_start[agent], len(self.all_states[agent]), step_size):
                    batch_states = self.all_states[agent].unsqueeze(0)[:, t:t+step_size]
                    # A dummy action 0 is appended for the step whose action is being predicted.
                    batch_actions = torch.tensor(self.actions[agent]+[0], dtype=torch.long).to(self.device).unsqueeze(0)[:, t:t+step_size]
                    batch_rtgs = torch.tensor(self.rtgs[agent], dtype=torch.long).to(self.device).unsqueeze(0).unsqueeze(-1)[:, t:t+step_size].type(torch.float32)
                self.time_start[agent] = t

                rtgs_emb, states_emb, actions_emb = self.embs[agent](
                    rtgs=batch_rtgs,
                    states=batch_states,
                    actions=batch_actions)
                
                logits = self.gpts[agent](
                    rtgs_emb=rtgs_emb,
                    states_emb=states_emb,
                    actions_emb=actions_emb,
                    timesteps=torch.arange(min(t, max_timesteps), min(t, max_timesteps) + len(states_emb)).unsqueeze(0).to(self.device))
                
                action_dict[agent] = self._sample_action(logits)
                self.actions[agent] += [action_dict[agent]]

        if self.USE_LLM and suggest_actions_with_llm:
            try:
                new_action_dict = prompter.suggest_action(self.env, self.llm, action_dict, dynamic=last_step_dynamic)
                if len(new_action_dict) == len(action_dict):
                    action_dict = new_action_dict
            except Exception:
                # Best effort: keep the DT actions when the LLM call or its parsing fails.
                # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
                print(f"Error while parsing GPT output")

        for agent in range(1, self.num_agents + 1):
            _, _, _, on_goal, _ = self.env._step((agent, action_dict[agent]))
            if not on_goal:
                # Path length stops growing once the agent has reached its goal once.
                if not self.agents_reached[agent-1]:
                    self.agent_path_lengths[agent-1] += 1
            else:
                self.agents_reached[agent-1] = True

        # Observations are collected only after every agent has moved.
        for agent in range(1, self.num_agents + 1):
            
            new_o = self.env.get_goal_in_fov_format(agent, show_distance_inside_fov=show_distance_inside_fov)
            state = np.array(new_o[0])
            state = torch.tensor(np.array([state]), dtype=torch.float32).to(self.device)
            self.all_states[agent] = torch.cat([self.all_states[agent], state], dim=0)

            reward = self.env.individual_rewards[agent-1]
            self.rtgs[agent] += [self.rtgs[agent][-1] - reward]
            self.reward_sum[agent] += reward
                    
        return action_dict


    def find_path(self, show_distance_inside_fov, goal_change_agent_ratio=1/4, obs_change_ratio=1/4, 
                    gif_path='./trajectories/',id=1, max_step=192, grid_size=10, env_name=None, exp_name=None):
        """
        Run a full environment to completion, or until max_step steps

        At the step given by size_to_step_dynamic[grid_size] the environment is changed:
        goals are either replayed from the recorded changes for `env_name` (when
        available) or relocated for a random subset of travelling agents, and then
        obstacles are moved via change_obstacles().

        Returns (solution, actions): object arrays holding one {agent_id: position}
        dict and one {agent_id: action} dict per step.

        Raises RuntimeError when the step counter reaches max_step; the partial
        trajectory is saved under `{gif_path}/bad/` first.
        """
        
        # Determines when to introduce dynamic change in the environment.
        size_to_step_dynamic = {10: 5, 20: 10, 40: 30, 80: 50}

        solution = []
        step = 0
        action_arrs = []
        prompter =  Prompter(self.env.num_agents, grid_size)
        last_step_dynamic = False

        try:
            goals_to_change = self.goal_changes_read[env_name]
        except (TypeError, KeyError):
            # No recorded file (None) or no entry for this environment.
            goals_to_change = None
        
        while ((not self.env._complete()) and step < max_step):
            # The LLM window is only looked up when the LLM is enabled, so a DT-only
            # run does not need use_llm_at / stop_llm_at entries for this grid size.
            suggest_actions_with_llm = bool(
                self.USE_LLM
                and self.use_llm_at[grid_size] < step <= self.stop_llm_at[grid_size])

            timestep = {agent: self.env.world.getPos(agent)
                        for agent in range(1, self.env.num_agents + 1)}
            solution.append(timestep)

            # The LLM is never consulted on the very first step.
            action_arr_one_episode = self.step_all_parallel(
                is_it_first_step=(step == 0),
                show_distance_inside_fov=show_distance_inside_fov,
                prompter=prompter,
                last_step_dynamic=last_step_dynamic,
                suggest_actions_with_llm=(step != 0 and suggest_actions_with_llm))

            action_arrs.append(action_arr_one_episode)
            last_step_dynamic = False

            step += 1
            
            if step == size_to_step_dynamic[grid_size]:
                if goals_to_change is None:
                    travelling_agents = self.get_travelling_agents()
                    change_goal_for = min(int(self.env.num_agents * goal_change_agent_ratio),
                                          len(travelling_agents))
                    random.shuffle(travelling_agents)
                    for change_goal_agent_id in travelling_agents[:change_goal_for]:
                        self.change_goal(change_goal_agent_id)
                
                else:
                    # Replay previously recorded goal changes (JSON keys are strings).
                    for agent_to_change, new_goal in goals_to_change.items():
                        agent_to_change = int(agent_to_change)
                        old_x, old_y = self.env.world.agent_goals[agent_to_change-1]  
                        self.env.world.goals[old_x][old_y] = 0
                        new_i, new_j = new_goal[0], new_goal[1]

                        self.env.world.goals[new_i][new_j] = agent_to_change
                        self.env.world.agent_goals[agent_to_change-1] = (new_i, new_j)
                    
                self.change_obstacles(obs_change_ratio)
                last_step_dynamic = True

        # NOTE(review): an episode that completes exactly on step `max_step` is still
        # reported as a failure here; kept as-is so success metrics stay comparable.
        if step == max_step:
            os.makedirs(f"{gif_path}/bad", exist_ok=True)
            np.save(f"{gif_path}/bad/ep_{id}.npy", np.array(solution))
            np.save(f"{gif_path}/bad/ep_{id}_actions.npy", np.array(action_arrs))
            
            if self.render_gif:
                create_gifs(solution, gif_path+f"/bad/ep_{id}.gif", 
                                            env_name, self.env, grid_size=grid_size, 
                                            exp_name=exp_name, gif_path=gif_path)
            raise RuntimeError(f"Episode {id} reached max_step={max_step}")

        # Overwrite the last recorded frame with the final positions of ALL agents
        # (the trajectory length is left unchanged). Skipped when no step was taken.
        if solution:
            final_frame = solution[-1]
            for agent in range(1, self.env.num_agents + 1):
                final_frame[agent] = self.env.world.getPos(agent)
        return np.array(solution), np.array(action_arrs)

    def get_travelling_agents(self):
        """
        Travelling agents means those agents who have not found their goals yet.
        We want to change the goals for only those agents who have not found their goals yet.

        Returns the list of 1-based agent ids whose position differs from their goal.
        """
        world = self.env.world
        return [agent for agent in range(1, self.env.num_agents + 1)
                if world.getPos(agent) != world.getGoal(agent)]

    def _random_free_cell(self, debug=False):
        """Rejection-sample a cell (i, j) where both world.state and world.goals are 0."""
        world = self.env.world
        while True:
            new_i = random.randint(0, len(world.state) - 1)
            # Column bound comes from the row itself so non-square grids work too.
            new_j = random.randint(0, len(world.state[new_i]) - 1)
            if world.state[new_i][new_j] == 0 and world.goals[new_i][new_j] == 0:
                return new_i, new_j
            if debug:
                print("Invalid random, changing")

    def change_goal(self, agent_id, debug=False):
        """
        Change goal for one agent with ID 'agent_id'.
        Places the goal randomly, on open spots where there are no agents/obstacles

        The new location is also recorded in self.goal_changes[agent_id].
        """
        world = self.env.world
        # Clear every cell currently marked as this agent's goal.
        for i in range(len(world.goals)):
            for j in range(len(world.goals[i])):
                if world.goals[i][j] == agent_id:
                    if debug:
                        print(f"Old goal for agent: {(i, j)}")
                    world.goals[i][j] = 0

        new_i, new_j = self._random_free_cell(debug)

        world.goals[new_i][new_j] = agent_id
        world.agent_goals[agent_id-1] = (new_i, new_j)
        if debug:
            print(f"new goal for {agent_id}: {(new_i, new_j)}")
            print(f"goal changed for one agent: {world.goals}")
        self.goal_changes[agent_id] = (new_i, new_j)
    
    def change_obstacles(self, obs_change_ratio=1/4 , min_obs=3, max_obs=30, debug=False):
        """
        Changes x number of obstacles
        where x = obs_change_ratio * total number of obstacles.

        Obstacles are the cells of world.state equal to -1. Returns the number of
        obstacles selected for relocation, or None when obs_change_ratio is 0.
        `min_obs` and `max_obs` are accepted but currently unused.
        """        
        if obs_change_ratio == 0:
            return None

        world = self.env.world
        obstacles = [(i, j)
                     for i in range(len(world.state))
                     for j in range(len(world.state[i]))
                     if world.state[i][j] == -1]
        
        if debug:
            print(f"obstacles: {obstacles}")
        
        random.shuffle(obstacles)
        
        num_obs = int(len(obstacles) * obs_change_ratio)
        
        if debug:
            print(f"shuffled obstacles: {obstacles}")

        # Slicing (rather than indexing up to num_obs) tolerates ratios above 1.
        for old_i, old_j in obstacles[:num_obs]:
            # The destination is picked before the old cell is freed, so an obstacle
            # never "moves" onto its own cell.
            new_i, new_j = self._random_free_cell(debug)

            world.state[old_i][old_j] = 0
            world.state[new_i][new_j] = -1
            
            if debug:            
                print(f"Old: {(old_i, old_j)}")
                print(f"new: {(new_i, new_j)}")
                print(f"Obstacles changed: {world.state}")

        return num_obs


def create_gifs(trajectories, gif_filename, env_name, env, grid_size, exp_name, gif_path):
    """
    Render one PNG per time step and assemble them into a GIF.

    trajectories: iterable of {agent_id: (row, col)} dicts, one per time step
    gif_filename: path of the GIF to write
    env_name: environment name; its third '_'-separated field is parsed as the grid size
    env: environment providing getObstacleMap() and getGoals()
    grid_size: number of cells drawn along each axis
    exp_name: experiment name, used for the frame directory gif_images/{exp_name}/{env_name}
    gif_path: accepted for interface compatibility; not used here
    """
    obstacle_map = env.getObstacleMap()
    goals = env.getGoals()

    obstacles = np.argwhere(obstacle_map == 1).tolist()

    frame_dir = f"gif_images/{exp_name}/{env_name}"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(frame_dir, exist_ok=True)
    frame = 0

    # Colours are cycled, so more than 8 agents no longer raises IndexError.
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'orange']

    def color_of(agent_id):
        return colors[(agent_id - 1) % len(colors)]

    size = int(env_name.split('_')[2])
    # Same schedule as DT.find_path's size_to_step_dynamic; other sizes keep the old
    # default of 10.
    # NOTE(review): `env` is not stepped while rendering, so the refresh below reads
    # the same maps as the initial read above - confirm whether pre-change
    # obstacles/goals should be recorded during the episode instead.
    change_frame = {10: 5, 20: 10, 40: 30, 80: 50}.get(size, 10)

    with imageio.get_writer(gif_filename, mode='I', duration=2) as writer:
        # Iterate over each time step
        for step in trajectories:
            # Create a blank grid
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.set_xlim(-0.5, grid_size-0.5)
            ax.set_ylim(-0.5, grid_size-0.5)
            ax.set_xticks(np.arange(0, grid_size, 1))
            ax.set_yticks(np.arange(0, grid_size, 1))
            ax.grid(True)
            
            if frame == change_frame:
                obstacle_map = env.getObstacleMap()
                obstacles = np.argwhere(obstacle_map == 1).tolist()
                goals = env.getGoals()
 
            # Draw grid cells
            for x in range(grid_size):
                for y in range(grid_size):
                    rect = plt.Rectangle((x - 0.5, y - 0.5), 1, 1, linewidth=1, edgecolor='gray', facecolor='white')
                    ax.add_patch(rect)

            # Plot the obstacles (column on the x axis, row on the y axis)
            for obstacle in obstacles:
                obs_rect = plt.Rectangle((obstacle[1] - 0.5, obstacle[0] - 0.5), 1, 1, linewidth=1, edgecolor='gray', facecolor='gray')
                ax.add_patch(obs_rect)

            # Plot each agent's position inside the grid cell
            for agent_id, pos in step.items():
                ax.plot(pos[1], pos[0], 'o', color=color_of(agent_id), markersize=10, label=f'Agent {agent_id}')
                ax.text(pos[1], pos[0], f'{agent_id}', color=color_of(agent_id), fontsize=12, ha='center', va='center')

            # Plot the goal positions for each agent as stars
            for agent_id, goal in enumerate(goals, start=1):
                ax.plot(goal[1], goal[0], '*', color=color_of(agent_id), markersize=15)

            ax.legend()
            ax.set_title(f'Time Step {frame}')

            # Row 0 is drawn at the top, matching array indexing.
            plt.gca().invert_yaxis()
            frame_file = f'{frame_dir}/{frame}.png'
            plt.savefig(frame_file)
            writer.append_data(imageio.v2.imread(frame_file))
            plt.close(fig)
            frame += 1


def _write_json(path, payload):
    """Write `payload` to `path` as JSON, closing the file even if writing fails."""
    with open(path, 'w') as handle:
        handle.write(json.dumps(payload))


def run_pipeline(args):
    """
    The main function for inference.
    Takes in arguments indicating the environments we want to run DT on.
    Saves the results in the "results/{model_name}_{exp_name}" directory.

    Loads DT model weights
    Iterates through & runs DT on all the settings of each configuration 
        num_agents, grid_size, density.

    After every setting the aggregated metrics are rewritten to
    success_rates.txt, collision_rates.txt, makespan.txt and avg_steps.txt
    (plus goal_changes.txt when goal_change_agent_ratio > 0). Makespan, average
    steps and collision rate are averaged over successful episodes only, so a
    setting without any success appears only in success_rates.txt.

    Arguments are passed through command line for better control.
    """

    USE_LLM = args.USE_LLM
    exp_name = args.exp_name

    results_path = './results/'

    # Fall back to the CPU so the pipeline can at least start without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    environment_path = args.environment_path

    group_reward_at_end = args.group_reward_at_end
    model_name = args.model_name
    show_distance_inside_fov = args.show_distance_inside_fov 

    RENDER_GIF = args.RENDER_GIF

    if USE_LLM:
        llm_key = args.gpt_key
        llm = LLM(key=llm_key)
        # Give initial prompt to the LLM to explain the MAPF problem and solution.
        # result = llm.give_initial_prompt(query=initial_prompt)
    else:
        llm = None
    
    agent_setting = args.agent_setting
    size_setting = args.size_setting
    density_setting = args.density_setting
    iteration_setting = args.iteration_setting
    target_rtg = args.target_rtg

    goal_change_agent_ratio = args.goal_change_agent_ratio
    obs_change_ratio = args.obs_change_ratio
    max_steps_dt_config = args.max_steps_dt_config

    config = CfgNode(max_timesteps=max_steps_dt_config, step_size=args.step_size)    
    
    max_episode_length_dict = {10:255, 20:255, 40:255, 80:256}
    
    # The LLM is consulted for `run_llm_for` steps starting right after use_llm_at.
    run_llm_for = 5
    use_llm_at = {10:6, 20:11, 40:31, 80:51}
    stop_llm_at = {size: start + run_llm_for for size, start in use_llm_at.items()}

    success_rates = {}
    collision_rates = {}
    makespan = {}
    average_steps = {}
    all_env_goal_changes = {}

    output_dir = f"{results_path}/{model_name}_{exp_name}"

    for num_agents in agent_setting:
        print(f"Num_agents: {num_agents}")
        for size in size_setting:
            # Skip agent counts that are too large for the grid size.
            if size == 10 and num_agents > 16: continue
            if size == 20 and num_agents > 128: continue
            if size == 40 and num_agents > 512: continue
            max_episode_length = max_episode_length_dict[size]
            for density in density_setting:
                
                setting = f"{num_agents}_{size}_{density}"
                total_collision_rates_in_setting = 0
                makespan_in_setting = 0
                sum_avg_path_lengths = 0
                successful_episodes = 0
                total_episodes = 0
                for iteration in range(iteration_setting[0], iteration_setting[1]):
                    gif_path = f"./results/{model_name}_{exp_name}/{num_agents}_{size}_{density}"
                    env_name = f"{num_agents}_agents_{size}_size_{density}_density_id_{iteration}"

                    print(f"env_name: {env_name}")

                    # Created unconditionally: a directory left over from an earlier,
                    # partial run may lack the good/ and bad/ sub-directories.
                    os.makedirs(f"{gif_path}/bad", exist_ok=True)
                    os.makedirs(f"{gif_path}/good", exist_ok=True)
                    
                    # Loading the environment of a particular setting and making it a GYM env.
                    environment_data_filename = make_name(num_agents, size, density, iteration, ".npy",
                                                              environment_path, extra="environment")
                    world = np.load(environment_data_filename, allow_pickle=True)
                    env = MAPFEnv(num_agents=num_agents, world0=world[0], goals0=world[1],
                                      group_reward_at_end=group_reward_at_end, render_gif=RENDER_GIF)

                    # Initializes object of class DT which encapsulates all the helper functions for inference
                    dt = DT(llm=llm, grid_size=size, config=config, model_path=f"./Pipeline/DT_model/{model_name}", 
                                device=device, USE_LLM=USE_LLM, target_rtg=target_rtg, render_gif=RENDER_GIF,
                                use_llm_at=use_llm_at, stop_llm_at=stop_llm_at)
                    dt.set_env(env)
                    dt.create_copies()
                    try:
                        # Tries to find the paths for agents using DT and LLM
                        path, actions = dt.find_path(show_distance_inside_fov=show_distance_inside_fov, 
                                                                            goal_change_agent_ratio=goal_change_agent_ratio, 
                                                                            obs_change_ratio=obs_change_ratio,
                                                                            gif_path=gif_path, id=iteration, 
                                                                            max_step=max_episode_length, grid_size=size, 
                                                                            env_name=env_name, exp_name=exp_name)
                        all_env_goal_changes[env_name] = dt.goal_changes
                        average_path_length = sum(dt.agent_path_lengths)/num_agents
                        np.save(gif_path + f'/good/ep_{iteration}.npy', path)
                        np.save(gif_path + f'/good/ep_{iteration}_actions.npy', actions)
                        if RENDER_GIF:
                            create_gifs(path.tolist(), gif_path+f"/good/ep_{iteration}.gif", 
                                        env_name, env, grid_size=size, exp_name=exp_name, gif_path=gif_path)

                        # Release the per-agent model copies before the next episode.
                        del dt

                        successful_episodes += 1
                        total_collision_rates_in_setting += (env.collision_agent / len(path) )
                        makespan_in_setting += len(path)
                        sum_avg_path_lengths += average_path_length

                    except RuntimeError:
                        # find_path raises RuntimeError when the episode runs out of steps.
                        print(f"Failed Runtime error on {setting}_{iteration}")

                    total_episodes += 1
                
                if total_episodes == 0:
                    # Empty iteration range: nothing to report for this setting.
                    continue

                success_rates[setting] = successful_episodes/total_episodes
                if successful_episodes > 0:
                    makespan[setting] = makespan_in_setting/successful_episodes
                    average_steps[setting] = sum_avg_path_lengths/successful_episodes
                    collision_rates[setting] = total_collision_rates_in_setting/successful_episodes

                # Results are persisted even when every episode of the setting failed,
                # so a 0.0 success rate is never silently dropped.
                os.makedirs(output_dir, exist_ok=True)
                _write_json(f"{output_dir}/success_rates.txt", success_rates)
                _write_json(f"{output_dir}/collision_rates.txt", collision_rates)
                _write_json(f"{output_dir}/makespan.txt", makespan)
                _write_json(f"{output_dir}/avg_steps.txt", average_steps)
                
                if goal_change_agent_ratio > 0:
                    _write_json(f"{output_dir}/goal_changes.txt", all_env_goal_changes)


def list_of_ints(arg):
    """Parse a comma-separated string such as "4,8" into a list of ints."""
    return [int(piece) for piece in arg.split(',')]

def list_of_floats(arg):
    """Parse a comma-separated string such as "0,0.1" into a list of floats."""
    return [float(piece) for piece in arg.split(',')]


if __name__ == '__main__':
    def _list_of_numbers(arg):
        """Parse "0,0.1" into [0, 0.1]: ints stay ints, anything else becomes a float.

        Integer tokens are kept as ints because densities are interpolated into
        environment file names ("0" must not turn into "0.0").
        """
        numbers = []
        for token in arg.split(','):
            try:
                numbers.append(int(token))
            except ValueError:
                numbers.append(float(token))
        return numbers

    parser = argparse.ArgumentParser(description='Runs different kinds of pipelines for MAPF using DT and LLM.')
    
    parser.add_argument('--model_name', default='dt', help="Name of model weights to use")
    parser.add_argument('--exp_name', default='',
                            help="Experiment name for running multiple runs with same settings and models")
    parser.add_argument('--USE_LLM', action="store_true", help="Use LLM if passed")
    parser.add_argument('--gpt_key', help="Pass the GPT API Key")
    parser.add_argument('--environment_path', default="./MAPF_Envs/saved_environments/",
                            help="Saved environments are env descriptions in npy files.")
    parser.add_argument('--group_reward_at_end', action="store_true",
                            help="Used for experimenting different reward functions. This flag gives and additional same +ve reward to all agents if all reach goal")
    parser.add_argument('--show_distance_inside_fov', action="store_true",
                            help="Use for experimenting different ways of state representation. Specifically distance from goal")
    parser.add_argument('--RENDER_GIF', action="store_true", help="Renders & stores GIFs")
    parser.add_argument('--version_1', action="store_true", help="If true, will not pass DT actions to LLM")
    # Without `type`, a value given on the command line would arrive as a str and
    # break the tensor construction of the return-to-go.
    parser.add_argument('--target_rtg', type=float, default=20, help="Target Return-to-go for each agent")

    parser.add_argument('--agent_setting', type=list_of_ints, default=[4, 8], help="Run experiments on how many agents. e.g. [4, 8] represents running experiment with 4 agents and 8 agents")
    parser.add_argument('--size_setting', type=list_of_ints, default=[10], help="Denotes the different grid_sizes we want to run experiments on. e.g., [10, 40] represents running experiments on 10x10 and 40x40 grids")
    # Densities are fractional, so an int-only parser rejected values such as "0.1".
    parser.add_argument('--density_setting', type=_list_of_numbers, default=[0.1], help="Denotes the different obstacle densities we want to run experiments on. e.g., [0.1, 0.2] represents running experiments on 0.1 & 0.2 densities")
    parser.add_argument('--iteration_setting', type=list_of_ints, default=[91, 93], help="Range of iterations - list of 2 numbers start and end. Each agent_size_density setting has 100 envs, 80 are used for training the DT and 20 (numbered 80 to 100) can be used for testing.")

    # parser.add_argument('--agent_setting', type=list_of_ints, default=[8, 16, 32, 64])
    # parser.add_argument('--size_setting', type=list_of_ints, default=[20, 40, 80])
    # parser.add_argument('--density_setting', type=list_of_floats, default=[0, 0.1, 0.2])
    # parser.add_argument('--iteration_setting', type=list_of_ints, default=[80, 100])
    
    parser.add_argument('--goal_change_agent_ratio', type=float, default=0, help="Dynamic change which changes the goals of #{ratio*num_agents} agents at specified timesteps. e.g., If passed 1/2 and num_agents=8, goals of 4 agents will be change at a specified timestep (Check size_to_step_dynamic variable)")
    parser.add_argument('--obs_change_ratio', type=float, default=0, help="Dynamic change to the obstacles in the environment. Locations of #{ratio*total number of obstacles} obstacles will  be changed at a specified timestep")
    parser.add_argument('--max_steps_dt_config', type=int, default=256, help="Max steps from the Decision transformer training configurations.")
    # step_size is used as a range() stride, so it must be parsed as an int.
    parser.add_argument('--step_size', type=int, default=50, help="Step size for DT config")

    args = parser.parse_args()
    run_pipeline(args)

