import re
import json
from Pipeline.initial_prompt import initial_prompt

class Prompter:
    """Builds MAPF prompts for an LLM and converts its answers into DT actions.

    Two coordinate frames and two action vocabularies are involved:

    * the Decision Transformer (DT) / environment frame, where actions are
      integer codes (1=RIGHT, 2=DOWN, 3=LEFT, 4=UP, anything else=WAIT), and
    * the GPT frame (see ``dt_to_gpt_coords`` for the coordinate mapping),
      where actions are the strings "UP", "DOWN", "LEFT", "RIGHT", "WAIT".
    """

    # DT integer action code -> GPT action string (any other code means WAIT).
    _DT_TO_GPT = {1: "RIGHT", 2: "DOWN", 3: "LEFT", 4: "UP"}
    # GPT action string (upper-case) -> DT integer action code.
    _GPT_TO_DT = {"RIGHT": 1, "DOWN": 2, "LEFT": 3, "UP": 4, "WAIT": 0}
    # One bracketed group such as ["UP", "LEFT"]; non-greedy so that several
    # groups on the same line are matched separately.
    _BRACKET_RE = re.compile(r"\[(.*?)\]")
    # Characters stripped from each comma-separated entry of a bracketed group.
    _TOKEN_STRIP_CHARS = "'\" []\t"

    def __init__(self, num_agents, size):
        """
        num_agents: number of agents; a parsed action list is only accepted
            if it has exactly this many entries
        size: side length of the square grid (size x size)
        """
        self.size = size
        self.num_agents = num_agents
        # Not read anywhere in this class; presumably used by other pipeline
        # code — verify before removing.
        self.prompt_file_path = "./Pipeline/prompt.txt"
        self.output_file_path = "./Pipeline/GPT_output.txt"

    def suggest_action(self, env, llm, prev_action=None, dynamic=False):
        """
        Query the LLM for the agents' next actions and return them as DT codes.

        This is a high level function called in the testing_dt file.

        env: instance of the MAPF gym environment
        llm: instance of the LLM used for assistance; must provide
            ``get_gpt_output(input_prompt=...)`` returning the response text
        prev_action: dictionary of actions taken by the backbone Decision
            Transformer (currently unused)
        dynamic: flag to generate different prompts reflecting the changes in
            the env (currently unused)

        Returns a dict {agent_number (1-based): DT action code}. It is empty
        if no action list with one entry per agent was found in the reply.
        """
        # The module-level ``initial_prompt`` (question-answer examples) is
        # the prefix; the current state of ``env`` is appended to it.
        prompt = self.generate_prompt(initial_prompt, env)

        # LLM reply: a free-text description ending with the action list.
        gpt_output = llm.get_gpt_output(input_prompt=prompt)

        agent_actions = self.parse_actions_v4(gpt_output)
        return self.make_action_dict(agent_actions)

    def generate_prompt(self, initial_prompt, env):
        """
        Append the current-state question to ``initial_prompt`` and return it.

        The appended text gives the static obstacles, the agents' current
        positions and their goals (all converted to the GPT frame) and asks
        for one action per agent in a bracketed list.

        initial_prompt: prompt prefix (the question-answer examples)
        env: MAPF environment providing getObstacleMap(), getPositions()
            and getGoals()
        """
        initial_prompt += f"""

Consider the above question-answer examples and give me the next actions which would lead agents towards their goal positions. Display the actions at the end of the response. Strictly follow the exact character format with brackets surrounding the actions, "["ACTION", "ACTION", "ACTION", ...]".
   	 
We have the same problem, but with {self.num_agents} unique agents positioned on a two-dimensional {self.size}*{self.size} grid environment now.
These are the coordinates for static obstacles: {self.get_coord_grid(env)}
These are the coordinates for agents' current positions in order: {self.dt_to_gpt_coords(tuple(env.getPositions()))}
These are the coordinates for agents' goals in order: {self.dt_to_gpt_coords(tuple(env.getGoals()))}

Considering the above specifications and restrictions, what should be the agents' actions for the next time-step?

Take into consideration that we have to move towards the goal, thus we shouldn't be waiting if we can take an action that gets an agent closer towards its goal. If an action would collide an agent into a static obstacle, prefer actions that aim to move around that obstacle while also making sure those directions wouldn't collide into another obstacle. Give me the set of actions in order for all {self.num_agents} agents.

    	"""
        return initial_prompt

    def convert_action(self, action_dict):
        """
        Given DT actions, convert to GPT actions.

        Since the notations for our pipeline and GPT are different, we have to
        make this conversion.

        action_dict: either a sequence of DT action codes in agent order, or
            a dict {agent_id: DT action code}. For a dict the codes are taken
            from its values in iteration order. Iterating the dict itself
            would yield the agent ids rather than the actions.

        Returns a list of GPT action strings; codes other than 1-4 become
        'WAIT'.
        """
        actions = action_dict.values() if isinstance(action_dict, dict) else action_dict
        return [self._dt_to_gpt_action(action) for action in actions]

    def _dt_to_gpt_action(self, action):
        """Return the GPT action string for one DT action code ('WAIT' if unknown)."""
        # Compare with == rather than a dict lookup, so array/tensor scalars
        # that compare equal to an int code keep mapping as before.
        for code, name in self._DT_TO_GPT.items():
            if action == code:
                return name
        return "WAIT"

    def get_coord_grid(self, env):
        """
        Return the static-obstacle coordinates in the GPT frame.

        Reads env.getObstacleMap() (DT frame, indexed [row][col] over
        self.size x self.size). It collects every cell equal to 1, scanning
        row by row, and converts each one with ``dt_to_gpt_coords``.
        """
        obstacle_map = env.getObstacleMap()
        obstacle_cells = [
            (row, col)
            for row in range(self.size)
            for col in range(self.size)
            if obstacle_map[row][col] == 1
        ]
        return self.dt_to_gpt_coords(obstacle_cells)

    def dt_to_gpt_coords(self, coords):
        """
        Convert coordinates from DT's frame to GPT's frame.

        Each point (x, y) becomes (y, size - 1 - x). ``coords`` is an iterable
        of points indexable as point[0], point[1]. Returns a list of 2-tuples.
        """
        last_index = self.size - 1
        return [(point[1], last_index - point[0]) for point in coords]

    def make_action_dict(self, agent_actions):
        """
        Map GPT-frame action strings to DT action codes.

        The result is keyed by 1-based agent number, in the order of
        ``agent_actions``. Matching ignores case and surrounding whitespace,
        so "up" or " Up " count as "UP". Any unrecognised entry, and "WAIT",
        maps to 0.
        """
        return {
            agent_num: self._GPT_TO_DT.get(str(action).strip().upper(), 0)
            for agent_num, action in enumerate(agent_actions, start=1)
        }

    def parse_actions_v4(self, gpt_response_str):
        """
        Extract the per-agent action list from the LLM output.

        Every bracketed group on every line (e.g. ["UP", "LEFT", "WAIT"]) is
        split on commas. Each entry is stripped of quotes, brackets and
        whitespace, then upper-cased. Only groups with exactly ``num_agents``
        entries are candidates.

        The last candidate whose entries are all known actions
        (UP/DOWN/LEFT/RIGHT/WAIT) is returned. This stops an echoed
        coordinate list such as [(0, 1), (2, 3)] from overriding it. If no
        such group exists, the last candidate of the right length is
        returned instead.

        Returns a list of action strings, or [] if no candidate was found.
        """
        valid_actions = None
        fallback = []
        for line in gpt_response_str.splitlines():
            # A line may contain several bracketed groups (e.g. a coordinate
            # list followed by the action list), so look at all of them.
            for match in self._BRACKET_RE.finditer(line):
                candidate = [
                    token.strip(self._TOKEN_STRIP_CHARS).upper()
                    for token in match.group(1).split(",")
                ]
                if len(candidate) != self.num_agents:
                    continue
                if all(token in self._GPT_TO_DT for token in candidate):
                    valid_actions = candidate
                else:
                    fallback = candidate
        return valid_actions if valid_actions is not None else fallback