import json
from google import genai

# Per-environment header that format_history() prepends to the interaction
# history once a previous plan exists; "{goal}" is replaced with the goal text.
FORMAT = dict(
    robot_navigation="## **Previous Plan to collect ball and move to goal location**\n",
    robot_arm="## **Previous plan to achieve the goal state**\n",
    mix_color="## **Previous Plan to {goal}**\n",
    block_stacking="## **Previous plan to achieve the goal state**\n",
)

def load_text(path):
    """Return the entire contents of the UTF-8 encoded text file at *path*."""
    with open(path, encoding='utf-8') as handle:
        contents = handle.read()
    return contents

class InfoSeekerPlanner():
    """Gemini-backed planner built around three JSON-returning prompt stages.

    * ``information_seeking`` asks the model for exploratory ``"Steps"``;
    * ``extract_info`` asks for ``"Information"`` distilled from the
      interaction history;
    * ``task_oriented_planning`` (alias ``task_oriented_plannig``) asks for a
      ``"Solution Plan"``.

    NOTE(review): this class only reads and prefixes ``action_history``;
    callers presumably append executed actions to it between calls — confirm.
    """

    def __init__(self, env_type, *, model="gemini-2.0-flash",
                 key_path="./gemini_key.json", prompt_dir="./prompt"):
        """Create the Gemini client and load all prompt templates.

        Args:
            env_type: Environment name. Selects the extraction template
                ``{prompt_dir}/extraction/{env_type}.txt`` and the history
                header in ``FORMAT``.
            model: Gemini model name passed to ``generate_content``.
            key_path: JSON file whose ``"key"`` field holds the API key.
            prompt_dir: Directory containing the prompt templates.
        """
        # LLM
        with open(key_path, 'r', encoding='utf-8') as f:
            gemini_key = json.load(f)
        self.client = genai.Client(api_key=gemini_key['key'])
        self.model = model
        # One {"prompt", "response"} entry per reply that parsed as JSON.
        self.llm_history = []
        # Free text substituted for {interaction_history} in the templates.
        self.action_history = ""
        # True once a plan has been produced: switches to the second seeking
        # prompt and enables the history header in format_history().
        self.previous_plan = False

        # Load prompt templates
        self.env_type = env_type
        self.planning_prompt = load_text(f"{prompt_dir}/planning.txt")
        self.seeking_prompt1 = load_text(f"{prompt_dir}/seeking1.txt")
        self.seeking_prompt2 = load_text(f"{prompt_dir}/seeking2.txt")
        self.extraction_prompt = load_text(
            f"{prompt_dir}/extraction/{self.env_type}.txt")

    def llm(self, prompt: str, force_json=True, key=None, max_retries=5):
        """Query Gemini for a JSON reply, retrying on any failure.

        Args:
            prompt: Full prompt text sent as ``contents``.
            force_json: Currently unused. Every request sets
                ``response_mime_type="application/json"``, and every reply is
                parsed with ``json.loads``. Kept for interface compatibility.
            key: If truthy, return ``data[key]`` instead of the whole parsed
                object. A reply missing this key counts as a failed attempt.
            max_retries: Maximum number of attempts (default 5).

        Returns:
            The parsed JSON (or its ``key`` field), or ``None`` when every
            attempt failed.
        """
        for _ in range(max_retries):
            try:
                response = self.client.models.generate_content(
                    model=self.model,
                    contents=prompt,
                    config={"response_mime_type": "application/json"},
                )
                output = response.text
                data = json.loads(output)
                # Logged before the key lookup, so a reply that lacks `key`
                # is still recorded even though the attempt is retried.
                self.llm_history.append({"prompt": prompt, "response": output})
                return data[key] if key else data
            except Exception as e:
                # Broad on purpose: API errors, invalid JSON and a missing or
                # unindexable key are all treated as retryable.
                print(f"Fail to generate plan or parse output, try again! {e}")
        print(f"LLM call failed after {max_retries} attempts; returning None.")
        return None

    def format_history(self, goal: str):
        """Prefix ``action_history`` with the env-specific "previous plan" header.

        This does nothing until a plan has been produced. Any ``{goal}`` in the
        header (used by ``mix_color``) is replaced with *goal*.

        NOTE(review): every call after the first plan prepends the header
        again. Confirm that callers reset ``action_history`` between rounds.
        """
        if self.previous_plan:
            header = FORMAT[self.env_type].replace('{goal}', goal)
            self.action_history = header + self.action_history

    def format_extract_prompt(self, goal: str):
        """Return the extraction template with ``{goal}`` filled in."""
        return self.extraction_prompt.replace('{goal}', goal)

    def information_seeking(self, domain_desc: str):
        """Ask the LLM for information-seeking ``"Steps"``.

        Before the first plan, the first seeking template is used with only
        ``{domain_desc}`` filled in. Afterwards, the second template also
        receives the interaction history.

        Returns:
            The ``"Steps"`` value, or ``None`` if the LLM call failed.
        """
        if self.previous_plan:
            prompt = (self.seeking_prompt2
                      .replace('{domain_desc}', domain_desc)
                      .replace('{interaction_history}', self.action_history))
        else:
            prompt = self.seeking_prompt1.replace('{domain_desc}', domain_desc)
        return self.llm(prompt, key="Steps")

    def extract_info(self, goal: str):
        """Extract relevant facts from the interaction history.

        Returns:
            A numbered list ``"1) ...\\n2) ..."`` built from the reply's
            ``"Information"`` items. Returns an empty string if the LLM call
            failed.
        """
        prompt = self.format_extract_prompt(goal).replace(
            '{interaction_history}', self.action_history)
        info_list = self.llm(prompt, key="Information")
        if info_list is None:
            # llm() exhausted its retries. Plan without extracted information
            # instead of crashing on len(None).
            return ""
        if isinstance(info_list, str):
            # A single string must not be enumerated character by character.
            info_list = [info_list]
        return '\n'.join(f"{i}) {info}" for i, info in enumerate(info_list, start=1))

    def task_oriented_planning(self, domain_desc: str, goal: str):
        """Produce a solution plan for *goal*.

        This prefixes the history header (after the first plan), extracts
        information from the history, fills the planning template, and asks
        the LLM for its ``"Solution Plan"``. It also marks that a plan exists,
        so later seeking calls use the second seeking template.

        Returns:
            The ``"Solution Plan"`` value, or ``None`` if the LLM call failed.
        """
        self.format_history(goal)
        information = self.extract_info(goal)
        prompt = (self.planning_prompt
                  .replace('{domain_desc}', domain_desc)
                  .replace('{interaction_history}', self.action_history)
                  .replace('{information}', information))
        self.previous_plan = True
        return self.llm(prompt, key="Solution Plan")

    # Backward-compatible alias for the original, misspelled method name.
    task_oriented_plannig = task_oriented_planning