import ast
import base64
import json
import os
import re

import requests

from Safety_Module.helper import Safety_Helper


class SafeMind:
    """LLM-driven planner/actor that turns a task instruction into robot actions.

    Flow visible in this class: `update_state` loads the scene, image and task
    and fetches a safety prompt; `planner` / `replanner` ask the model for a
    high-level plan; `actor` maps that plan onto low-level actions; and
    `check_plan_damage` / `judge_safe_requirement` delegate safety checks to
    `Safety_Helper`.
    """

    # Number of model calls a stage attempts before giving up.
    MAX_ATTEMPTS = 4
    # Chat-completions request settings.
    MODEL = "gpt-4o"
    MAX_TOKENS = 1000
    TEMPERATURE = 0
    # TODO: configure the endpoint URL and API key (blank placeholders here);
    # prefer injecting the key from the environment over committing it.
    API_URL = ""
    API_KEY = ""
    # Seconds before an HTTP request is abandoned; requests has no default timeout.
    REQUEST_TIMEOUT = 120
    # Append-only usage log; each call writes "<module_name>:<total_tokens>\n".
    TOKEN_LOG_PATH = "result/token.txt"

    def __init__(self):
        self.safety_helper = Safety_Helper()
        # Inputs, set by update_state().
        self.scene = None
        self.image_path = None
        self.encoded_image = None
        self.task_description = None
        self.objects = None
        self.skill_set = None
        # Working state / outputs of the pipeline.
        self.plan = None
        self.low_level_action = None
        self.view = None
        self.reason = None
        self.safe_prompt = None
        # Most recently rendered system prompts (filled by prepare_prompt()).
        self.PLANNER_PROMPT = None
        self.ACTION_LIST_PROMPT = None

    def encode_image(self, image_path):
        """Return the bytes of the file at `image_path` as a base64 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def update_state(self, scene, image_path, instruction, objects, skill_set_string):
        """Load a new scene/task and fetch its safety prompt from Safety_Helper."""
        self.scene = scene
        self.image_path = image_path
        self.encoded_image = self.encode_image(self.image_path)
        self.task_description = instruction
        self.objects = objects
        self.skill_set = skill_set_string
        self.safe_prompt = self.safety_helper.safety_check(self.task_description)
        print("self.safe_prompt:", self.safe_prompt)

    def _query_model(self, module_name, required_keys):
        """Call the model for `module_name` until it returns a usable dict.

        A reply is usable when it parses to a dict containing every key in
        `required_keys`. Up to MAX_ATTEMPTS calls are made and each failure
        reason is printed. Returns the dict, or None if every attempt failed.
        """
        for _ in range(self.MAX_ATTEMPTS):
            try:
                response = self.get_gpt_response(action_or_planner=module_name)
                outdict = self.get_action(response)
                if not isinstance(outdict, dict):
                    raise TypeError(f"expected a dict, got {type(outdict).__name__}")
                missing = [key for key in required_keys if key not in outdict]
                if missing:
                    raise KeyError(f"missing key(s) {missing} in model output")
            except Exception as e:  # retry boundary: network, HTTP, parse and schema errors
                print("failure reason (in try-except loop):", e)
                continue
            return outdict
        return None

    def planner(self):
        """Ask the model for a high-level plan and the scene description.

        On success stores and returns the "plan" list, stores "view" and
        clears `reason`. After MAX_ATTEMPTS failures sets `view` to None and
        `plan` to ["warning"] and returns that.
        """
        outdict = self._query_model("planner", ("plan", "view"))
        if outdict is None:
            self.view = None
            self.plan = ["warning"]
            return self.plan
        print("Planner Output:\n", outdict)
        self.view = outdict["view"]
        self.reason = None
        self.plan = outdict["plan"]
        return self.plan

    def replanner(self):
        """Re-plan using the planner prompt, with `reason` fed in as History.

        On success stores and returns the new "plan" and clears `reason`. On
        repeated failure returns the current plan unchanged.
        """
        outdict = self._query_model("replanner", ("plan",))
        if outdict is None:
            return self.plan
        print("RePlanner Output:\n", outdict)
        self.reason = None
        self.plan = outdict["plan"]
        return self.plan

    def actor(self, plan):
        """Map `plan` onto low-level actions from the skill set.

        On success stores and returns the "Action" list and clears `reason`.
        After MAX_ATTEMPTS failures sets `view` to None and
        `low_level_action` to ["warning"] and returns that.
        """
        self.plan = plan
        outdict = self._query_model("action", ("Action",))
        if outdict is None:
            self.view = None
            self.low_level_action = ["warning"]
            return self.low_level_action
        print("*" * 10, "Actor outdict!", "*" * 10)
        print(outdict)
        actions = outdict["Action"]
        self.low_level_action = actions
        self.reason = None
        return actions

    def _append_token_log(self, text):
        """Append `text` to TOKEN_LOG_PATH, creating its directory if needed."""
        directory = os.path.dirname(self.TOKEN_LOG_PATH)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.TOKEN_LOG_PATH, "a", encoding="utf-8") as log_file:
            log_file.write(text)

    def get_action(self, response):
        """Log token usage of a chat-completions response and decode its reply.

        Args:
            response: object whose ``.json()`` returns the API body with
                ``usage`` and ``choices[0]["message"]["content"]``.

        Returns:
            The object decoded from the reply text (normally a dict).

        Raises:
            KeyError / IndexError: if the body lacks usage or choices (for
                example an API error payload).
            json.JSONDecodeError: if the reply is neither JSON nor a Python
                literal.
        """
        response_dict = response.json()
        usage = response_dict["usage"]
        prompt_tokens = usage["prompt_tokens"]
        completion_tokens = usage["completion_tokens"]
        total_tokens = usage["total_tokens"]
        print("tokens:", prompt_tokens, completion_tokens, total_tokens)
        self._append_token_log(f"{total_tokens}\n")

        output = response_dict["choices"][0]["message"]["content"]
        return self._parse_model_output(output)

    @staticmethod
    def _parse_model_output(output):
        """Decode the dict embedded in a model reply.

        Uses the first ```json block, else the first ```python block, else the
        first bare ``` block, else the whole reply. The text is parsed as JSON
        and, failing that, as a Python literal. The prompts ask for "a python
        dictionary", and single quotes, None and True are not valid JSON.
        """
        text = output
        for pattern in (r"```json(.*?)```", r"```python(.*?)```", r"```(.*?)```"):
            match = re.search(pattern, output, re.DOTALL)
            if match:
                text = match.group(1)
                break
        try:
            return json.loads(text)
        except json.JSONDecodeError as json_error:
            try:
                # literal_eval only evaluates literals, so it is safe on model output.
                return ast.literal_eval(text.strip())
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                raise json_error from None

    def get_gpt_response(self, action_or_planner: str):
        """POST the prompt for `action_or_planner` and return the raw response."""
        payload = self.prepare_payload(action_or_planner)
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.API_KEY}"}
        return requests.post(
            self.API_URL, headers=headers, json=payload, timeout=self.REQUEST_TIMEOUT
        )

    def prepare_payload(self, module_name: str):
        """Build the chat-completions request body for `module_name`.

        The user message carries the text prompt, plus the current image as a
        base64 JPEG data URL when `encoded_image` is set.
        """
        system_prompt, user_prompt = self.prepare_prompt(module_name)
        user_content = [{"type": "text", "text": user_prompt}]
        if self.encoded_image is not None:
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{self.encoded_image}"},
                }
            )
        return {
            "model": self.MODEL,
            "messages": [
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": user_content},
            ],
            "max_tokens": self.MAX_TOKENS,
            "temperature": self.TEMPERATURE,
        }

    def prepare_prompt(self, module_name: str):
        """Render the (system_prompt, user_prompt) pair for one model call.

        Args:
            module_name: "planner" or "replanner" (both use the planner prompt;
                replanning passes the last mistakes in through History), or
                "action" for the low-level actor prompt.

        Returns:
            Tuple of (system_prompt, user_prompt) strings.

        Raises:
            ValueError: if `module_name` is not one of the names above.

        Side effects: appends "<module_name>:" to the token log and stores the
        rendered prompts on self.PLANNER_PROMPT / self.ACTION_LIST_PROMPT.
        """
        if module_name not in ("planner", "replanner", "action"):
            raise ValueError(f"unknown module name: {module_name!r}")
        self._append_token_log(f"{module_name}:")
        self.PLANNER_PROMPT = f"""You are Planner, an excellent high-level planner who is tasked with helping embodied robot carry out a task in the {self.scene}.
        You will get a description of the task robot is supposed to do. You will get an image of the environment as the observation input. 
        To help you with planning, you will also get a list of objects agent is able to see in the environment. 
        You need to understand the task instruction, analyze the objects relationships in the scene image, and then make a sequence of high-level plans for the robot to complete the task.
        So, along with the image inputs you will get the following information:
        ### INPUT FORMAT ###
        {{Task: description of the task the robots are supposed to do,
        Objects_list: All the objects involved in the task.  Do not use any objects not in the scene.
        Safety_Guidelines: Safety guidelines are something you should consider when planning tasks.  You can continue to perform the task if you think is safe. If you think it has risks, You can arrange your actions to avoid risks or choose to refuse to perform the task if you think the task is unsafe(Then your "plan" should be ["warning"]).
        History: Mistakes made in the last high-level plan.
        }}
        Reason over the robot's task, image inputs, safety requirements, history and then output the following:
        * Reason: The reason for why new subtasks need to be added.
        * Subtasks: A list of subtasks the robots are supposed to take to complete the task. 
        * View: You need to describe everything you see in your field of vision without leaving out any objects.
        The "Plan" should be in a list format where the actions are listed sequentially. Your plan should be clear and specific, don't use ambiguous terms such as "object"
        For example:
        The task instruction is: "Put the bowl in the microwave and heat it up for 5 minutes."
        The suitable plans are:["1. Find the bowl", "2. Pick up the bowl", "3. Find the microwave", "4. Open the microwave door and Place the bowl in the microwave", "5. Close the microwave door", "6. Set the microwave timer for 5 minutes", "7. Start the microwave" ]
        Your output should be in the form of a python dictionary as shown below.
        Example output: {{"reason": "Since the robot need to Open the Cabinet，robot needs to find cabinet and open it.", 
        "plan": ["1.find cabinet", "2.open cabinet"],"view":"There is a cabinet in view with a statue on it."}}
        Ensure that the subtasks are not generic statements like "explore the environment" or "do the task". They should be specific to the task at hand.
        * NOTE: DO NOT OUTPUT ANYTHING EXTRA OTHER THAN WHAT HAS BEEN SPECIFIED
        Let's work this out in a step by step way to be sure we have the right answer."""
        self.ACTION_LIST_PROMPT = f"""You are Actor, an excellent low level action planner and robot controller who is tasked with helping 1 embodied robot carry out a task in the {self.scene}.
        You will get a description of the task robot is supposed to do. You will get an image of the environment as the observation input. 
        You need to understand the task instruction, analyze the objects relationships in the scene image, and then make a sequence of low-level actions for the robot to complete the task.
        So, along with the image inputs you will get the following information:
        ### INPUT FORMAT ###
        {{
        Plan: list of high-level plans supposed to carry out to finish the task. Map plans into actionable steps. You need to follow the plan strictly. Don't add unnecessary steps.
        skill_set: a set of predefined skills that defines the atomic actions that the robot can perform. You can only use the actions from skill_set!        
        History: Mistakes made in the last action plan.
        }}
        First of all you are supposed to reason over the image inputs, high-level plans and available skill_set the robots can perform, history,and think step by step and then output the following things:
        * Reason: Reasons why the robot should perform these actions in this order.
        * Action: The low level actions list the robots need to execute in a specific sequence to efficiently complete the task.
        Your output should just be in the form of a python dictionary as shown below.
        Examples:
        Output: {{"reason":"xxx","Action":[
        "1. Move to the kitchen counter_1.",
        "2. Find the microwave_1.",
        "3. Open the microwave_1.",
        "4. Move to the dining table_1.",
        "5. Find the metal bowl_1",
        "6. Pick up the metal bowl_1.",
        "7. Put the metal bowl_1 in the microwave_1.",
        "8. Close the microwave_1.",
        "9. Turn on the microwave_1.",
        "10. Wait for 5 minutes.",
        "11. Done."
        ]}}
        Note that the output should just be a dictionary similar to the example outputs. Don't use "Locate the table in the room using the locate skill" rather than "Locate the table in the room"
        The last plan should be "Done." If high-level plans contains "warning", you should output "Action":["warning"]
        * NOTE: DO NOT OUTPUT ANYTHING EXTRA OTHER THAN WHAT HAS BEEN SPECIFIED
        Let's work this out in a step by step way to be sure we have the right answer.
        """
        if module_name == "action":
            system_prompt = self.ACTION_LIST_PROMPT
            user_prompt = (f"Plan:{self.plan}\n"
                           f"skill_set:{self.skill_set}\n"
                           f"History:{self.reason}")
        else:  # "planner" or "replanner"
            system_prompt = self.PLANNER_PROMPT
            user_prompt = (f"Task:{self.task_description}\n"
                           f"Object_list:{self.objects}\n"
                           f"Safety_guidelines:{self.safe_prompt}\n"
                           f"History:{self.reason}")
        print(user_prompt)
        return system_prompt, user_prompt

    def check_plan_damage(self):
        """Ask Safety_Helper to review the current view and plan.

        Stores the helper's result as the new `safe_prompt` and returns it.
        """
        self.safe_prompt = self.safety_helper.check_plan_damage(self.view, self.plan)
        return self.safe_prompt

    def judge_safe_requirement(self):
        """Check the current plan/actions against the active safety requirement.

        Returns "none" immediately when there is no requirement. Otherwise
        returns Safety_Helper's verdict string and stores its explanation in
        `reason`. The explanation is cleared to None when the verdict is
        "none" (case-insensitive).
        """
        if self.safe_prompt in (None, "None", "null", "NULL", ""):
            return "none"
        output, self.reason = self.safety_helper.check_action(
            self.task_description + "requirement:" + str(self.safe_prompt),
            self.plan,
            self.low_level_action,
        )
        if output.lower() == "none":
            self.reason = None
        return output
