from utils import *
import time
import tiktoken  # Import tiktoken library
from langchain.callbacks import get_openai_callback

class Planner:
    def __init__(
        self,
        # openai_key,
        memory,
        mode = 0,
        rules_dir=None, 
        traj_memory_dir=None,
        model_name="gpt-4o-mini",
        temperature=0, # default 0
        choice_num=1, # default 1
    ):
        """Create a planner backed by a chat LLM.

        Args:
            memory: Agent memory object (required, must not be None). Other
                methods of this class read ``memory.inventory`` and call
                ``memory.search_workflows`` and
                ``memory.reset_current_environment_information``.
            mode: Prompt variant used by ``get_workflow``
                (0: agent + world model, 1: agent with rules,
                2: agent with rules and state).
            rules_dir: Path handed to ``load_json_file`` to load the rules;
                when None the rules default to an empty dict.
            traj_memory_dir: Path handed to ``load_json_file`` to load the
                trajectory memory; when None it defaults to an empty dict.
            model_name: Forwarded to ``ChatOpenAI`` as ``model_name``.
            temperature: Forwarded to ``ChatOpenAI`` as ``temperature``.
            choice_num: Forwarded to ``ChatOpenAI`` as ``n``.

        Raises:
            AssertionError: If ``memory`` is None.
        """
        # Validate before building the LLM client / tokenizer so a bad call
        # fails fast. Raised explicitly instead of via `assert` so the check
        # survives `python -O`; the exception type and message are kept for
        # backward compatibility.
        if memory is None:
            raise AssertionError("Please input memory")

        self.memory = memory
        self.mode = mode

        self.llm = ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            n=choice_num,
        )
        # Tokenizer used by `_count_tokens` for prompt-size diagnostics.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

        self.rules = {} if rules_dir is None else load_json_file(rules_dir)
        self.traj_memory = (
            {} if traj_memory_dir is None else load_json_file(traj_memory_dir)
        )
    
    def _format_feedback_or_suggestion(self, item):
        """Render a feedback/suggestion value as a single prompt string.

        Args:
            item: A list of entries, a string, or anything else.

        Returns:
            str: List entries joined with ``"; "``, the string unchanged, or
            ``""`` for any other type (including None).
        """
        if isinstance(item, list):
            # Stringify each entry: parsed LLM output may contain numbers or
            # nested structures, and str.join raises TypeError on non-strings.
            return "; ".join(str(entry) for entry in item)
        if isinstance(item, str):
            return item
        return ""

    def _extract_between_brackets(self, s):
        """Return the span from the first ``[`` to the last ``]``, inclusive.

        Yields an empty string when either bracket is absent or when the
        last ``]`` does not come after the first ``[``.
        """
        opening = s.find('[')
        closing = s.rfind(']')
        if opening == -1 or closing == -1 or opening >= closing:
            return ""
        return s[opening:closing + 1]

    def _extract_between_curly_brackets(self, s):
        """Return the span from the first ``{`` to the last ``}``, inclusive.

        Yields an empty string when either brace is absent or when the last
        ``}`` does not come after the first ``{``.
        """
        opening = s.find('{')
        closing = s.rfind('}')
        if opening == -1 or closing == -1 or opening >= closing:
            return ""
        return s[opening:closing + 1]

    def _count_tokens(self, text):
        """Counts the number of tokens in a given text.

        Args:
            text (str): The text to count tokens for. Non-string values are
                serialised to JSON before being tokenized.

        Returns:
            int: The number of tokens produced by ``self.tokenizer``.
        """
        # Only serialise non-strings: json.dumps on a str wraps it in quotes
        # and escapes newlines / non-ASCII characters, which inflates the
        # token count relative to the text that is actually sent.
        if not isinstance(text, str):
            text = json.dumps(text, ensure_ascii=False)
        return len(self.tokenizer.encode(text))

    def get_workflow(self, task_information, check_result, initial_state, running_dataset, testModel, max_retries=5):
        """Ask the LLM for a workflow for the given task, retrying on errors.

        The system/query prompts are selected by ``self.mode``
        (0: agent + world model, 1: agent with rules,
        2: agent with rules and state).

        Args:
            task_information: Dict describing the task; ``task_information["task"]``
                is read when ``'withtrajmemory'`` occurs in ``testModel``.
            check_result: Result of a previous failed attempt. When empty, a
                fresh plan is requested; otherwise its optional ``"feedback"``
                and ``"suggestion"`` entries are added to the query.
            initial_state: Dict rendered into the query for modes 1 and 2.
            running_dataset: Used in the token-usage log file name.
            testModel: Used in the token-usage log file name; also enables
                reference plans from memory when it contains
                ``'withtrajmemory'``.
            max_retries: Remaining attempts. Any exception raised while
                planning consumes one attempt.

        Returns:
            The parsed workflow (a list in the list-based response formats;
            whatever ``extract_actions`` returns in the dict-based one), or an
            empty dict ``{}`` once the retries are exhausted.
        """
        # `<= 0` (not `== 0`) so that a negative budget cannot recurse forever.
        if max_retries <= 0:
            log_info("************Failed to get workflow. Consider updating your prompt.************\n\n")
            return {}

        try:
            if self.mode == 0:
                #### CASE 1: without rules###########
                # NOTE(review): this loads the "w_rules" prompt but never
                # substitutes "<rules>" -- confirm this is intended.
                structured_action_system = load_prompt("structured_action_system_w_rules_w_multisingle")
            elif self.mode == 1:
                #### CASE 2: with rules###########
                rules_act = self.rules
                rules_string = dict_to_prompt(rules_act)
                structured_action_system = load_prompt("structured_action_system_w_rules_w_multisingle").replace("<rules>", rules_string)
            elif self.mode == 2:
                #### CASE 3: with rules and state###########
                rules_act = self.rules
                rules_string = dict_to_prompt(rules_act)
                structured_action_system = load_prompt("structured_action_system_w_rulesandstate").replace("<rules>", rules_string)
            else:
                raise NotImplementedError

            task_information_string = dict_to_prompt(task_information)
            if 'withtrajmemory' in testModel:
                reference_plan_string = list_dict_to_prompt(self.memory.search_workflows(task_information["task"]))
            else:
                reference_plan_string = ''

            if self.mode == 0:
                structured_action_query = load_prompt("structured_action_query").format(
                    task_information=task_information_string, 
                    inventory=self.memory.inventory,
                    reference_plan=reference_plan_string
                )
            elif self.mode == 1 or self.mode == 2:
                initial_state_string = dict_to_prompt(initial_state)
                structured_action_query = load_prompt("structured_action_query_w_rulesandstate").format(
                    task_information=task_information_string, 
                    initial_state=initial_state_string,
                    reference_plan=reference_plan_string
                )
            else:
                raise NotImplementedError

            if len(check_result) == 0:
                structured_action_query += "\nPlan your workflow. Remember to follow the response format."
            else:
                # Re-planning: tell the model why the previous workflow failed.
                feedback = self._format_feedback_or_suggestion(check_result.get("feedback", ""))
                suggestion = self._format_feedback_or_suggestion(check_result.get("suggestion", ""))
                structured_action_query += f"""The previous workflow failed. 
                The reason for the failure: {feedback}.
                A suggested recommendations: {suggestion}. 
                Re-plan your workflow. Remember to follow the response format."""
                
            messages = [
                SystemMessage(content=structured_action_system),
                HumanMessage(content=structured_action_query)
            ]

            # Local estimate of the prompt size, printed for diagnostics only.
            token_length = self._count_tokens(structured_action_system) + self._count_tokens(structured_action_query)
            print(f'total token length is {token_length}')

            # The callback records the token usage / cost of the call.
            with get_openai_callback() as cb:
                llm_response = self.llm(messages)
            tokens_used = {
                "stage": "agent planning",
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_tokens": cb.total_tokens,
                "Total Cost (USD)": cb.total_cost
            }
            time.sleep(0.2)
            # Usage logging is best-effort: a missing or unwritable log
            # directory must not throw away a valid LLM response and trigger
            # another paid request through the retry path below.
            try:
                with open(f'/home/**/Workspace/MP5/MP5_agent/agent/task_result/[tokenuse]_{running_dataset}_{testModel}.log', 'a') as f:
                    f.write(json.dumps(tokens_used) + '\n')
            except OSError as log_error:
                log_info(f"Could not write token usage log: {log_error}")

            workflow_dict = llm_response.content
            log_info(f"Create Workflow Result: {workflow_dict}")
            
            # Keep only the outermost {...} span of the reply before parsing.
            workflow_dict = self._extract_between_curly_brackets(workflow_dict) 
            workflow_dict = fix_and_parse_json(workflow_dict)
            workflow_dict = update_dict(workflow_dict)

            if "workflow on explanation" in workflow_dict:
                if (workflow_dict["workflow on explanation"] and 
                    workflow_dict["workflow on task"] and
                    len(workflow_dict["workflow on task"]) == 1 and
                    workflow_dict["workflow on explanation"][-1] == workflow_dict["workflow on task"][0]):
                    # The single task step merely repeats the last explanation
                    # step, so do not append it a second time.
                    workflow = workflow_dict["workflow on explanation"]
                else:
                    workflow = workflow_dict["workflow on explanation"] + workflow_dict["workflow on task"]
            elif isinstance(workflow_dict['workflow'], dict):
                # mode 2
                workflow = extract_actions(workflow_dict)
            elif isinstance(workflow_dict['workflow'], list):
                # mode 0 & 1
                workflow = workflow_dict['workflow']
            else:
                raise NotImplementedError

            return workflow
        except Exception as e:
            # Any failure (LLM call, parsing, unexpected format) costs one
            # retry; the environment information in memory is reset first.
            log_info(f"Error arises in Plan Workflow part: {e} Trying again!\n\n")
            self.memory.reset_current_environment_information()

            return self.get_workflow(
                task_information, 
                check_result, 
                initial_state, 
                running_dataset, 
                testModel, 
                max_retries=max_retries - 1
            )
    
    def get_workflow_for_search(self, task_information, check_result, initial_state, running_dataset, testModel, planner_search_alg = 'MPC', max_retries=5):
        """Sample candidate workflows from the LLM for a search-based planner.

        Uses ``self.llm.generate`` so that every generation returned for the
        prompt is parsed into its own candidate.

        Args:
            task_information: Dict describing the task, rendered into the query.
            check_result: Result of a previous failed attempt. When empty, a
                fresh plan is requested; otherwise its optional ``"feedback"``
                and ``"suggestion"`` entries are added to the query.
            initial_state: Dict rendered into the query.
            running_dataset: Used in the token-usage log file name.
            testModel: Used in the token-usage log file name.
            planner_search_alg: ``'MPC'`` (whole-workflow prompt) or ``'MCTS'``
                (one-step prompt; replies carry ``'next action'`` and
                ``'success'``).
            max_retries: Remaining attempts. Any exception raised while
                planning consumes one attempt.

        Returns:
            ``(workflow_list, success_list)`` on success. ``success_list`` is
            only filled from MCTS-style replies.
            NOTE(review): when the retries are exhausted an empty dict ``{}``
            is returned instead of a tuple (kept for backward compatibility),
            so callers that unpack the result must handle that case.
        """
        # `<= 0` (not `== 0`) so that a negative budget cannot recurse forever.
        if max_retries <= 0:
            log_info("************Failed to get workflow. Consider updating your prompt.************\n\n")
            return {}
        try:
            #### CASE 2: with rules###########
            rules_act = self.rules
            rules_string = dict_to_prompt(rules_act)
            if planner_search_alg == 'MPC':
                structured_action_system = load_prompt("structured_action_system_w_rules_w_multisingle").replace("<rules>", rules_string)
            elif planner_search_alg == 'MCTS':
                structured_action_system = load_prompt("structured_action_system_w_rules_w_multisingle_onestep").replace("<rules>", rules_string)
            else:
                # Previously a bare `NotImplementedError` expression (never
                # raised), which surfaced later as an unrelated NameError.
                raise NotImplementedError(f"Unsupported planner_search_alg: {planner_search_alg!r}")

            task_information_string = dict_to_prompt(task_information)
            reference_plan_string = list_dict_to_prompt(self.traj_memory) # TODO task_information["task"]
            initial_state_string = dict_to_prompt(initial_state)
            structured_action_query = load_prompt("structured_action_query_w_rulesandstate").format(
                task_information=task_information_string, 
                initial_state=initial_state_string,
                reference_plan=reference_plan_string
            )

            if len(check_result) == 0:
                structured_action_query += "\nPlan your workflow. Remember to follow the response format."
            else:
                # Re-planning: tell the model why the previous workflow failed.
                feedback = self._format_feedback_or_suggestion(check_result.get("feedback", ""))
                suggestion = self._format_feedback_or_suggestion(check_result.get("suggestion", ""))
                structured_action_query += f"""The previous workflow failed. 
                The reason for the failure: {feedback}.
                A suggested recommendations: {suggestion}. 
                Re-plan your workflow. Remember to follow the response format."""
                
            messages = [
                SystemMessage(content=structured_action_system),
                HumanMessage(content=structured_action_query)
            ]

            ####################
            # Token usage statistics
            ####################
            with get_openai_callback() as cb:
                llm_response = self.llm.generate(messages = [messages])
            # Collect the token statistics recorded by the callback.
            tokens_used = {
                "stage": "agent planning",
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_tokens": cb.total_tokens,
                "Total Cost (USD)": cb.total_cost
            }
            time.sleep(0.2)
            # Usage logging is best-effort: a missing or unwritable log
            # directory must not throw away valid LLM responses and trigger
            # another paid request through the retry path below.
            try:
                with open(f'/home/**/Workspace/MP5/MP5_agent/agent/task_result/[tokenuse]_{running_dataset}_{testModel}.log', 'a') as f:
                    f.write(json.dumps(tokens_used) + '\n')
            except OSError as log_error:
                log_info(f"Could not write token usage log: {log_error}")
            
            workflow_list = []
            logprobs = []
            success_list = []
            for generation in llm_response.generations[0]:
                # Keep only the outermost {...} span of the reply before parsing.
                message_content = self._extract_between_curly_brackets(generation.message.content)

                workflow_dict = fix_and_parse_json(message_content)
                workflow_dict = update_dict(workflow_dict)

                # `.get` instead of `['workflow']`: one-step (MCTS) replies
                # carry 'next action' and no 'workflow' key, and a KeyError
                # here used to make the MCTS branch below unreachable.
                plain_workflow = workflow_dict.get('workflow')

                if "workflow on explanation" in workflow_dict:
                    if (workflow_dict["workflow on explanation"] and 
                        workflow_dict["workflow on task"] and
                        len(workflow_dict["workflow on task"]) == 1 and
                        workflow_dict["workflow on explanation"][-1] == workflow_dict["workflow on task"][0]):
                        # The single task step merely repeats the last
                        # explanation step, so do not append it twice.
                        workflow_data = workflow_dict["workflow on explanation"]
                    else:
                        workflow_data = workflow_dict["workflow on explanation"] + workflow_dict["workflow on task"]
                elif isinstance(plain_workflow, dict):
                    # mode 2 (was assigned to an unused `workflow` variable,
                    # leaving `workflow_data` unbound or stale)
                    workflow_data = extract_actions(workflow_dict)
                elif isinstance(plain_workflow, list):
                    # mode 0 & 1
                    workflow_data = plain_workflow
                elif planner_search_alg == 'MPC':
                    workflow_data = plain_workflow
                elif planner_search_alg == 'MCTS':
                    workflow_data = workflow_dict.get('next action')
                    action_success = workflow_dict.get('success')
                    success_list.append(action_success)
                else:
                    raise NotImplementedError(f"Unsupported planner_search_alg: {planner_search_alg!r}")
                workflow_list.append(workflow_data)
                # `.get`: the 'logprobs' entry may be absent from the metadata.
                logprobs.append(generation.message.response_metadata.get('logprobs'))
            
            log_info(f"Create Workflow Result: {workflow_list}")
            log_info(f"Workflow logprobs Result: {logprobs}")

            return workflow_list, success_list
        except Exception as e:
            # Any failure (LLM call, parsing, unexpected format) costs one
            # retry; the environment information in memory is reset first.
            log_info(f"Error arises in Plan Workflow part: {e} Trying again!\n\n")
            self.memory.reset_current_environment_information()

            return self.get_workflow_for_search(
                task_information, 
                check_result, 
                initial_state,
                running_dataset, 
                testModel, 
                # Forward the algorithm: the retry used to fall back to 'MPC'.
                planner_search_alg=planner_search_alg,
                max_retries=max_retries - 1
            )

    def update_rules(self, rules_extra):
        """Replace the planner's current rules with ``rules_extra``.

        The given object is stored as-is (no copy, no merge with the
        previous rules).
        """
        self.rules = rules_extra

if __name__ == "__main__":
    # Script entry point. It currently only selects a model name; no Planner
    # is constructed and nothing else runs here.
    model_name = 'gpt-4-turbo'
    # model_name= "gpt-3.5-turbo"


# equip(obj): equip the object from the inventory. The first item in 'state feature'-'equipment' is the object that is currently equipped; do not equip it again!
# - obj: a string, the object to equip, e.g. "wooden pickaxe".