import re, ast
import base64
import requests
import json, os
import pandas as pd
from pathlib import Path
import sys
from typing import Optional, List, Dict, Any, Union

# Make packages that live under the current working directory (such as the
# AI2Thor package imported below) importable by adding the cwd to sys.path.
# The cwd itself is added (no ".parent") when running as a .py script.
directory = Path.cwd().absolute()
sys.path.append(str(directory))

from AI2Thor.env_new import AI2ThorEnv
from AI2Thor.base_env import convert_dict_to_string
from AI2Thor.object_actions import get_closest_feasible_action, get_closest_object_id

# Read the OpenAI API key from ~/openai_key.json (field "my_openai_api_key")
# and build the HTTP headers shared by every chat-completions request below.
with open(os.path.join(os.path.expanduser("~"), "openai_key.json")) as json_file:
    key = json.load(json_file)
api_key = key["my_openai_api_key"]
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}


class MAPLE:
    """LLM planner that asks GPT-4 for subtasks and assigns them to agents.

    Attributes:
        env: the environment; its ``input_dict``, ``agent_names``,
            ``num_agents`` and ``object_dict`` are read by the methods below.
        planner_prompt: system prompt used when requesting subtasks.
        corrector_prompt: corrector system prompt (stored; not read by any
            method in this class).
        verifier_prompt: system prompt for the "verifier" module, or None.
        obs_summariser: optional callable mapping the list of observed object
            names to a string; when None the names are comma-joined.
        attempted_subtasks: subtasks already attempted, reported to the planner.
        obs_pre_string: prefix placed before the observation list in prompts.
    """

    # Action types whose argument names an object that must be resolved to a
    # concrete object id from env.object_dict.
    _OBJECT_ARG_ACTIONS = frozenset(
        {
            "PickupObject",
            "PutObject",
            "OpenObject",
            "CloseObject",
            "ToggleObjectOn",
            "ToggleObjectOff",
            "SliceObject",
            "NavigateTo",
        }
    )

    def __init__(
        self,
        env: "AI2ThorEnv",
        planner_prompt,
        corrector_prompt,
        obs_summariser=None,
        verifier_prompt=None,
    ):
        self.env = env
        self.planner_prompt = planner_prompt
        self.corrector_prompt = corrector_prompt
        # Read by prepare_prompt("verifier"); optional so existing callers work.
        self.verifier_prompt = verifier_prompt
        self.obs_summariser = obs_summariser
        self.attempted_subtasks = []
        self.obs_pre_string = "The robots can see "

    def get_planner_user_prompt(self):
        """Build the planner's user prompt.

        The planner needs the task, the agents' combined observations and the
        subtasks attempted so far.
        """
        user_prompt = f"Task: {self.env.input_dict['Task']}\n"
        user_prompt += f"Observations: {self._get_obs_string()} \n"
        user_prompt += f"Attempted Subtasks: {self._get_attempted_subtasks_string()}\n"
        return user_prompt

    def _get_obs_string(self):
        """Return all agents' observations, de-duplicated, as one sentence.

        Each agent's observation is stored in ``env.input_dict`` under
        ``"<agent name>'s observation"`` as the string form of a Python list,
        e.g. "['Mug_1', 'Fridge_1']", and is parsed with ``ast.literal_eval``.
        """
        combined = []
        for agent_id in range(self.env.num_agents):
            agent_name = self.env.agent_names[agent_id]
            obs_list = self.env.input_dict[agent_name + "'s observation"]
            combined += ast.literal_eval(obs_list)
        # De-duplicate while keeping first-seen order. A plain set orders string
        # items by hash, which varies between interpreter runs and would make
        # the prompt (and so the LLM output) non-reproducible.
        unique_obs = list(dict.fromkeys(combined))
        if self.obs_summariser:
            obs_str = self.obs_summariser(unique_obs)
        else:
            obs_str = ", ".join(unique_obs)
        return self.obs_pre_string + obs_str

    def _get_attempted_subtasks_string(self):
        """Render attempted subtasks as '{Subtask 1: <json>, Subtask 2: <json>}\\n'."""
        if not self.attempted_subtasks:
            return "{No subtasks attempted}\n"
        parts = [
            f"Subtask {number}: {json.dumps(subtask)}"
            for number, subtask in enumerate(self.attempted_subtasks, start=1)
        ]
        return "{" + ", ".join(parts) + "}\n"

    def get_plan(self) -> Union[str, Dict[int, List[List[str]]]]:
        """Query the planner and assign the resulting subtasks to agents.

        Returns:
            The string "Task Completed" when the planner reports the task is
            done; otherwise a dict mapping agent index to a list of executable
            subtasks (see ``random_subtask_assigner``).
        """
        subtasks = self.get_subtasks()
        if isinstance(subtasks, str):
            # "Task Completed" sentinel from _parse_subtasks: nothing to assign.
            return subtasks
        return self.random_subtask_assigner(subtasks)

    def get_subtasks(self) -> Union[str, Dict[str, Any]]:
        """Ask the LLM for subtasks and parse its reply (see _parse_subtasks)."""
        user_prompt = self.get_planner_user_prompt()
        response = self.get_gpt_response(self.planner_prompt, user_prompt)
        return self._parse_subtasks(response)

    def prepare_prompt(self, module_name: str, addendum: str):
        """Return the (system_prompt, user_prompt) pair for a module.

        Args:
            module_name: "planner" or "verifier".
            addendum: currently unused; kept for interface compatibility.

        Raises:
            ValueError: for an unsupported module_name, or for "verifier" when
                no verifier_prompt was supplied.
        """
        if module_name == "planner":
            return self.planner_prompt, self.get_planner_user_prompt()
        if module_name == "verifier":
            verifier_prompt = getattr(self, "verifier_prompt", None)
            if verifier_prompt is None:
                raise ValueError("prepare_prompt('verifier') requires a verifier_prompt")
            user_prompt = convert_dict_to_string(self.env.get_verifier_llm_input())
            return verifier_prompt, user_prompt
        raise ValueError(f"Unsupported module_name: {module_name!r}")

    def map2executable_action(self, action):
        """Map a natural-language action to the closest feasible env action.

        The text is first mapped by ``get_closest_feasible_action``. For
        actions that take an object argument, the object is then resolved
        against ``env.object_dict`` with ``get_closest_object_id``.
        Illustrative: "pick up the apple" -> "PickupObject(Apple_1)".
        """
        ex_act = get_closest_feasible_action(action)
        action_type = ex_act.split("(")[0]
        if action_type in self._OBJECT_ARG_ACTIONS:
            ex_act = get_closest_object_id(ex_act, self.env.object_dict)
        return ex_act

    def get_executable_subtasks(self, subtasks: Dict) -> List[List[str]]:
        """Map every natural-language step of every subtask to an env action.

        Args:
            subtasks: planner output shaped like
                {"Subtask 1": {"Step 1": "Navigate to the bread", ...}, ...}.

        Returns:
            One list per subtask (in dict order). Each list holds, for each step,
            whatever ``get_closest_feasible_action`` returns for that step's text.
            Every step is mapped, including a trailing "Complete" step.
        """
        return [
            [get_closest_feasible_action(step) for step in subtask.values()]
            for subtask in subtasks.values()
        ]

    def expected_num_steps(self, action_list):
        """Given the action list, calculate the expected number of steps.

        Not implemented yet; always returns None.
        """
        pass

    def random_subtask_assigner(self, subtasks):
        """Assign subtasks to agents round-robin.

        Subtask i (after conversion with ``get_executable_subtasks``) goes to
        agent ``i % num_agents``. If there are more subtasks than agents, some
        agents get several. If there are fewer, the leftover agents get an
        empty list.

        Returns:
            dict mapping every agent index in range(num_agents) to its list of
            executable subtasks.
        """
        num_agents = self.env.num_agents
        agent_assignments = {agent_id: [] for agent_id in range(num_agents)}
        for index, subtask in enumerate(self.get_executable_subtasks(subtasks)):
            agent_assignments[index % num_agents].append(subtask)
        return agent_assignments

    def prepare_payload(self, system_prompt, user_prompt):
        """Build the chat-completions request body.

        The payload consists of:
        * the system prompt (constant per module)
        * the user prompt (changes with the environment state)
        Reads no instance state.
        """
        payload = {
            "model": "gpt-4",
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": system_prompt},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt},
                    ],
                },
            ],
            "max_tokens": 2000,
            "temperature": 0.7,
        }
        return payload

    def _parse_subtasks(self, response):
        """Convert the GPT response into a dictionary of subtasks.

        Returns:
            "Task Completed" if the reply contains that phrase. Otherwise the
            parsed contents of the first ```json, ```python or bare ```
            fence, or of the whole reply if there is no fence.

        Raises:
            json.JSONDecodeError: if the content is neither valid JSON nor a
                Python literal.
        """
        output = response.json()["choices"][0]["message"]["content"]
        if "Task Completed" in output:
            return "Task Completed"
        # Prefer explicitly tagged fences, then any fence, then the raw text.
        for pattern in (r"```json(.*?)```", r"```python(.*?)```", r"```(.*?)```"):
            match = re.search(pattern, output, re.DOTALL)
            if match:
                payload = match.group(1)
                break
        else:
            payload = output
        try:
            return json.loads(payload)
        except json.JSONDecodeError as json_err:
            # ```python fences often hold a Python dict literal (single quotes),
            # which is not JSON. literal_eval parses literals without executing code.
            try:
                return ast.literal_eval(payload.strip())
            except (ValueError, SyntaxError, TypeError):
                raise json_err from None

    def get_gpt_response(self, system_prompt, user_prompt):
        """POST the prompts to the OpenAI chat-completions endpoint.

        Returns the raw ``requests.Response``. Note that no timeout is set.
        """
        payload = self.prepare_payload(system_prompt, user_prompt)
        response = requests.post(
            "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
        )
        return response


def get_gpt_response(planner_prompt, user_prompt):
    """Send one chat-completions request and return the assistant's text.

    Args:
        planner_prompt: system prompt.
        user_prompt: user prompt.

    Returns:
        The ``content`` string of the first choice in the API response.
    """
    # There is no ``self`` at module scope. MAPLE.prepare_payload reads no
    # instance state, so it is called unbound with a placeholder ``self``.
    payload = MAPLE.prepare_payload(None, planner_prompt, user_prompt)
    response = requests.post(
        "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
    )
    response = response.json()
    output = response["choices"][0]["message"]["content"]
    return output


# Sample planner user prompt (task, observations, attempted subtasks),
# written as adjacent literals that join into a single string.
user_prompt = (
    "Task: Put all the groceries in the fridge.\n"
    "Observations: The robots can see countertop, fridge, bread, apple, lettuce, microwave, table, tomato, spoon and bowl.\n"
    "Attempted Subtasks:\n"
    "{No subtasks attempted}"
)
