from langchain.schema import AgentAction, AgentFinish

from toolemu.utils.my_typing import SimulatedObservation


def _extract_final_answer(log):
    """Return the text after the last "Final Answer:" marker in ``log``.

    If the marker is absent the whole log is returned, so callers never hit
    an IndexError on a log that lacks the marker.
    """
    return log.split("Final Answer:")[-1]


def _render_live_step(
        step,
        include_agent_log,
        include_simulator_log,
        include_final_answer,
):
    """Render one (AgentAction/AgentFinish, observation) step as text."""
    agent_output = step[0]
    is_action = isinstance(agent_output, AgentAction)
    if not is_action and not include_final_answer:
        # A non-action step is only rendered when the final answer is wanted.
        return ""

    if include_agent_log:
        text = f"Thought: {agent_output.log}\n"
    elif is_action:
        # Include only the action itself.
        text = (
            f"Action: {agent_output.tool}\n"
            f"Action Input: {agent_output.tool_input}\n"
        )
    else:
        assert isinstance(agent_output, AgentFinish)
        text = f"Final Answer:{_extract_final_answer(agent_output.log)}\n"

    if is_action:
        # step[1] is only read for action steps.
        observation = step[1]
        if include_simulator_log and isinstance(observation, SimulatedObservation):
            text += f"Simulator Thought:\n{observation.log}\n"
        else:
            # NOTE(review): this branch reads ``.observation`` even when the
            # object is not a SimulatedObservation -- confirm every
            # observation passed in exposes that attribute.
            text += f"Observation:\n{observation.observation}\n"
    return text


def _render_dumped_step(
        step,
        is_last,
        include_agent_log,
        include_simulator_log,
        include_final_answer,
):
    """Render one dumped (action fields, observation fields) step as text."""
    text = ""
    if not is_last or include_final_answer:
        if include_agent_log:
            text += f"Thought: {step[0][-1]}\n"
        elif not is_last:
            # Include only the action itself.
            text += f"Action: {step[0][0]}\nAction Input: {step[0][1]}\n"
        else:
            text += f"Final Answer:{_extract_final_answer(step[0][-1])}\n"

    if not is_last:
        if include_simulator_log:
            text += f"Simulator Thought:\n{step[1][-1]}\n"
        else:
            # NOTE(review): this path emits "Observation: " on one line while
            # the live-object path emits "Observation:\n"; kept as-is because
            # downstream consumers may depend on either form.
            text += f"Observation: {step[1][0]}\n"
    return text


def construct_trajec(
        traject_dict,
        include_input=True,
        include_agent_log=True,
        include_simulator_log=True,
        include_final_answer=True,
        is_dumped_dict=False,
):
    """Render a trajectory dict as a single text transcript.

    Args:
        traject_dict: Mapping with an ``"intermediate_steps"`` sequence and,
            when ``include_input`` is true, an ``"input"`` entry.
        include_input: Prefix the transcript with a ``User Input:`` line.
        include_agent_log: Emit each step's log as a ``Thought:`` entry
            instead of the bare action / final answer.
        include_simulator_log: Emit the simulator log (``Simulator Thought:``)
            instead of the plain observation.
        include_final_answer: Also render the final (non-action) step.
        is_dumped_dict: If true, each step is an indexable pair of dumped
            fields and the last step is treated as the final answer;
            otherwise each step is a tuple of
            (AgentAction/AgentFinish, observation object).

    Returns:
        The transcript string; every emitted entry ends with a newline.
    """
    parts = []
    if include_input:
        parts.append(f"User Input: {traject_dict['input']}\n")

    steps = traject_dict["intermediate_steps"]
    if is_dumped_dict:
        # In the dumped format the final step is identified by position only.
        last_index = len(steps) - 1
        for i, step in enumerate(steps):
            parts.append(
                _render_dumped_step(
                    step,
                    i == last_index,
                    include_agent_log,
                    include_simulator_log,
                    include_final_answer,
                )
            )
    else:
        for step in steps:
            parts.append(
                _render_live_step(
                    step,
                    include_agent_log,
                    include_simulator_log,
                    include_final_answer,
                )
            )
    return "".join(parts)


def construct_simple_trajec(traject_dict, **kwargs):
    """Render a trajectory without agent logs or simulator logs.

    Forwards to ``construct_trajec`` with ``include_agent_log`` and
    ``include_simulator_log`` fixed to ``False``; all other keyword arguments
    are passed through unchanged. Supplying either of those two flags in
    ``kwargs`` raises ``TypeError`` (duplicate keyword argument).
    """
    simple_trajec = construct_trajec(
        traject_dict,
        include_agent_log=False,
        include_simulator_log=False,
        **kwargs
    )
    return simple_trajec


def get_used_tools_in_trajec(
        traject_dict,
        is_dumped_dict=True,
):
    """Collect the set of tool names used in a trajectory.

    Args:
        traject_dict: Mapping with an ``"intermediate_steps"`` sequence.
        is_dumped_dict: If true (the default), each step is an indexable pair
            whose first element starts with the tool name, and the last step
            is skipped. Otherwise the tool name is read from every step whose
            first element is an ``AgentAction``.

    Returns:
        A set of the tool names found.
    """
    steps = traject_dict["intermediate_steps"]
    if is_dumped_dict:
        # The last dumped step is excluded by position.
        final_index = len(steps) - 1
        return {
            entry[0][0]
            for position, entry in enumerate(steps)
            if position != final_index
        }
    return {
        entry[0].tool
        for entry in steps
        if isinstance(entry[0], AgentAction)
    }
