from utilsextra import *
import tiktoken  # Import tiktoken library


class Planner:
    """LLM-backed planner that asks a chat model for the agent's next action.

    The system prompt embeds a rules dictionary (optionally loaded from a
    JSON file); the user prompt embeds the current state plus, on a retry,
    feedback about why the previous action failed.
    """

    def __init__(
        self,
        rules_dir=None,
        traj_memory_dir=None,
        model_name="gpt-4o-mini",
        temperature=0,  # default 0
        choice_num=1,  # default 1
    ):
        """Create the chat model and tokenizer and load optional JSON data.

        Args:
            rules_dir: Path handed to ``load_json_file`` for the rules that are
                injected into the system prompt, or None for no rules.
            traj_memory_dir: Path handed to ``load_json_file`` for trajectory
                memory, or None. Stored on ``self.traj_memory``; not read
                anywhere else in this class.
            model_name: Model name passed to ``ChatOpenAI``.
            temperature: Sampling temperature passed to ``ChatOpenAI``.
            choice_num: Value passed as ``n`` to ``ChatOpenAI``.
                NOTE(review): ``get_workflow`` only reads a single
                ``.content`` from the response, so extra choices are
                presumably unused — confirm.
        """
        self.llm = ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            n=choice_num,
            # Ask the API for a JSON object so the reply can be parsed.
            response_format={"type": "json_object"},
        )

        # cl100k_base is only used for the prompt-size printout.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

        self.rules = {} if rules_dir is None else load_json_file(rules_dir)
        self.traj_memory = (
            {} if traj_memory_dir is None else load_json_file(traj_memory_dir)
        )

    def _format_feedback_or_suggestion(self, item):
        """Render a feedback/suggestion value as a single string.

        Lists and tuples are joined with "; ". Non-string elements are
        converted with ``str`` so they do not raise TypeError. Strings are
        returned unchanged; any other value yields "".
        """
        if isinstance(item, str):
            return item
        if isinstance(item, (list, tuple)):
            return "; ".join(str(part) for part in item)
        return ""

    @staticmethod
    def _extract_between(s, open_char, close_char):
        """Return ``s`` from the first ``open_char`` through the last ``close_char``.

        Returns "" when either delimiter is missing or the last closer comes
        before the first opener.
        """
        start_index = s.find(open_char)
        end_index = s.rfind(close_char)
        # start >= 0 and start < end also implies end != -1.
        if 0 <= start_index < end_index:
            return s[start_index:end_index + 1]
        return ""

    def _extract_between_brackets(self, s):
        """Return the span of ``s`` from the first '[' to the last ']', or ""."""
        return self._extract_between(s, '[', ']')

    def _extract_between_curly_brackets(self, s):
        """Return the span of ``s`` from the first '{' to the last '}', or ""."""
        return self._extract_between(s, '{', '}')

    def _count_tokens(self, text):
        """Count the tokens of ``text`` with the cl100k_base tokenizer.

        Args:
            text (str | object): The text to count. Strings are encoded as-is.
                Any other object is JSON-serialized first.

        Returns:
            int: The number of tokens.
        """
        # JSON-encoding a str would add quotes and escape sequences
        # (e.g. "\\n"), inflating the count, so only serialize non-strings.
        if not isinstance(text, str):
            text = json.dumps(text)
        return len(self.tokenizer.encode(text))

    def _build_prompts(self, inital_state, check_result):
        """Return the ``(system_prompt, query)`` strings for one planning call."""
        rules_string = dict_to_prompt(self.rules)
        system_prompt = load_prompt(
            "structured_action_system_w_rules_w_multisingle"
        ).replace("<rules>", rules_string)

        query = load_prompt("structured_action_query").format(
            inital_state=inital_state
        )

        if not check_result:
            # First attempt (also covers None): no failure context to add.
            query += "\nGive you next step. Remember to follow the response format."
        else:
            feedback = self._format_feedback_or_suggestion(check_result.get("feedback", ""))
            suggestion = self._format_feedback_or_suggestion(check_result.get("suggestion", ""))
            # Start on a fresh line, like the branch above. Build the text
            # explicitly so source-code indentation does not leak into the
            # prompt.
            query += (
                "\nThe previous action failed.\n"
                f"The reason for the failure: {feedback}.\n"
                f"A suggested recommendations: {suggestion}.\n"
                "Re-generate your next action. Remember to follow the response format."
            )
        return system_prompt, query

    def get_workflow(self, inital_state, check_result, max_retries=5):
        """Ask the LLM for the next action, retrying on any failure.

        Args:
            inital_state: Value substituted for ``{inital_state}`` in the
                ``structured_action_query`` prompt template.
            check_result: Falsy (``{}`` / ``None``) on a first attempt.
                Otherwise a mapping with optional ``"feedback"`` and
                ``"suggestion"`` entries (a str or a list of items) that
                describe why the previous action failed.
            max_retries: Maximum number of attempts. A value <= 0 makes no
                attempt.

        Returns:
            The ``"next action"`` entry of the parsed JSON reply, or ``{}`` if
            every attempt failed.
        """
        for _ in range(max_retries):
            try:
                system_prompt, query = self._build_prompts(inital_state, check_result)
                messages = [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=query),
                ]

                # Debug aid: report the prompt size before calling the model.
                token_length = self._count_tokens(system_prompt) + self._count_tokens(query)
                print(f'total token length is {token_length}')

                # ``invoke`` replaces the deprecated ``llm(messages)`` call.
                llm_response = self.llm.invoke(messages)
                prediction_json = fix_and_parse_json(llm_response.content)
                return prediction_json["next action"]
            except Exception as e:
                # Deliberately broad: unparsable JSON, a missing "next action"
                # key and errors raised by the LLM call all trigger another
                # attempt.
                log_info(f"Error arises in Decision Making part: {e} Trying again!\n\n")

        log_info("************Failed to get workflow. Consider updating your prompt.************\n\n")
        return {}

    def update_rules(self, rules_extra):
        """Replace the rules used to build the system prompt."""
        self.rules = rules_extra



if __name__ == "__main__":
    # Script entry point. Other option tried: "gpt-3.5-turbo".
    # NOTE(review): this name is bound but never used; running the file as a
    # script does not construct a Planner or call the LLM.
    model_name = "gpt-3.5-turbo-instruct"


# equip(obj): equip an object from the inventory. The first item in 'state feature' -> 'equipment' is the object currently equipped; do not equip it again!
# - obj: a string naming the object to equip, e.g. "wooden pickaxe".