from .agent import LLMAgent
from pathlib import Path
import dspy

# Task-framing text passed as the ``context`` input of ``InteractionBeliefs``
# (see ``SocialAgent.generate_interaction_beliefs``).
INTERACTION_BELIEF_CONTEXT = """
You are a Minecraft agent.
You just had a conversation with another agent based on a task you are trying to solve.
Based on the contents of the conversation and the previous beliefs, you have to create a set of beliefs that can help you complete the task.
"""

# Instruction text used as the ``prefix`` of the ``interaction_beliefs``
# output field of ``InteractionBeliefs``.
INTERACTION_BELIEF_PREFIX = """
The new interaction beliefs should encapsulate useful information from the conversation that can help you complete the task given the conversation and your previous beliefs.
Aim to create a maximum of 5 beliefs. Beliefs should be concise and relevant to the task.
"""

# Task-framing text passed as the ``context`` input of ``PartnerBeliefs``
# (see ``SocialAgent.generate_partner_beliefs``).
PARTNER_BELIEF_CONTEXT = """
You are a Minecraft agent.
You just had a conversation with another agent based on a task you are trying to solve.
Based on the contents of the conversation and the previous beliefs, you have to create a set of beliefs that represent your perception of the other agent.
"""

# Instruction text used as the ``prefix`` of the ``partner_beliefs`` output
# field of ``PartnerBeliefs``.
PARTNER_BELIEF_PREFIX = """
The new partner beliefs should contain your perception of the other agent based on the conversation and your previous beliefs. Beliefs should be informative of the other agent's state.
"""

class PartnerBeliefs(dspy.Signature):
    # DSPy signature: (context, previous partner beliefs, latest conversation)
    # -> updated beliefs about the conversation partner.
    # NOTE(review): DSPy derives a signature's instructions from its class
    # docstring, so adding a docstring here would change the LLM prompt;
    # keep documentation in comments. Field names and declaration order are
    # presumably also rendered into the prompt — treat them as prompt text.
    # When called via SocialAgent.generate_partner_beliefs, ``context`` is
    # PARTNER_BELIEF_CONTEXT.
    context: str = dspy.InputField(desc="The context for generating the partner beliefs.")
    # NOTE(review): declared as dict here, but
    # SocialAgent.generate_partner_beliefs annotates this argument as str —
    # confirm which type callers actually pass.
    previous_partner_beliefs: dict = dspy.InputField(desc="The previous beliefs about the partner.")
    latest_conversation: str = dspy.InputField(desc="The latest conversation with the partner.")

    # The single output field. It has no annotation, so DSPy's default field
    # type (presumably str) applies — confirm. Its prefix is the instruction
    # text PARTNER_BELIEF_PREFIX rather than a short label.
    partner_beliefs = dspy.OutputField(desc="New partner beliefs.", prefix=PARTNER_BELIEF_PREFIX)

class InteractionBeliefs(dspy.Signature):
    # DSPy signature: (task, context, previous interaction beliefs, latest
    # conversation) -> updated task-relevant beliefs drawn from the interaction.
    # NOTE(review): DSPy derives a signature's instructions from its class
    # docstring, so adding a docstring here would change the LLM prompt;
    # keep documentation in comments. Field names and declaration order are
    # presumably also rendered into the prompt — treat them as prompt text.
    task: str = dspy.InputField(desc="The task to be accomplished.")
    # When called via SocialAgent.generate_interaction_beliefs, this is
    # INTERACTION_BELIEF_CONTEXT.
    context: str = dspy.InputField(desc="The context for generating the interaction beliefs.")
    # NOTE(review): declared as dict here, but
    # SocialAgent.generate_interaction_beliefs annotates this argument as
    # str — confirm which type callers actually pass.
    previous_interaction_beliefs: dict = dspy.InputField(desc="Previous interaction beliefs from past conversations.")
    latest_conversation: str = dspy.InputField(desc="The latest conversation with the partner.")

    # The single output field. It has no annotation, so DSPy's default field
    # type (presumably str) applies — confirm. Its prefix is the instruction
    # text INTERACTION_BELIEF_PREFIX (which asks for at most 5 beliefs).
    interaction_beliefs = dspy.OutputField(desc="New interaction beliefs.", prefix=INTERACTION_BELIEF_PREFIX)

class SocialAgent(LLMAgent):
    """Agent that updates its beliefs after a conversation with another agent.

    Each call runs two DSPy predictors over the latest conversation: one
    built from ``PartnerBeliefs`` (beliefs about the other agent) and one
    built from ``InteractionBeliefs`` (task-relevant beliefs drawn from the
    interaction).
    """

    def __init__(
        self,
        name: str,
        llm: str,
        temperature: float,
        request_timeout: int,
        resume: bool,
        ckpt_dir: str,
        chat_log: bool,
        execution_error: bool,
        logger
    ):
        """Initialise the agent.

        Args:
            name: Agent name; forwarded to ``LLMAgent`` and used as the
                sub-directory of ``ckpt_dir`` stored in ``self.path``.
            llm: Forwarded to ``LLMAgent.__init__``.
            temperature: Forwarded to ``LLMAgent.__init__``.
            request_timeout: Stored as ``self.request_timeout``; not used
                within this class.
            resume: NOTE(review): currently unused — neither stored nor
                forwarded. Presumably meant to drive ``restore()`` once that
                is implemented.
            ckpt_dir: Root checkpoint directory; ``self.path`` is
                ``ckpt_dir / name``.
            chat_log: Stored as ``self.chat_log``; not used within this class.
            execution_error: Stored as ``self.execution_error``; not used
                within this class.
            logger: Forwarded to ``LLMAgent.__init__``.
        """
        super().__init__(name, llm, temperature, logger)
        self.request_timeout = request_timeout
        # Join with pathlib instead of string formatting: f"{ckpt_dir}/{name}"
        # turns an empty ckpt_dir into the filesystem-root path "/<name>",
        # whereas Path("") / name stays relative to the working directory.
        self.path = Path(ckpt_dir) / name
        self.chat_log = chat_log
        self.execution_error = execution_error

    def generate_partner_beliefs(self, previous_partner_beliefs: str, latest_conversation: str):
        """Predict updated beliefs about the conversation partner.

        Args:
            previous_partner_beliefs: Beliefs about the partner before this
                conversation. NOTE(review): the ``PartnerBeliefs`` signature
                declares this input as ``dict`` — confirm the actual type.
            latest_conversation: The latest conversation with the partner.

        Returns:
            The prediction returned by ``dspy.Predict(PartnerBeliefs)``; the
            new beliefs are in its ``partner_beliefs`` field.
        """
        generate_partner_beliefs = dspy.Predict(PartnerBeliefs)
        partner_beliefs = generate_partner_beliefs(
            context=PARTNER_BELIEF_CONTEXT,
            previous_partner_beliefs=previous_partner_beliefs,
            latest_conversation=latest_conversation
        )
        return partner_beliefs

    def generate_interaction_beliefs(self, task: str, previous_interaction_beliefs: str, latest_conversation: str):
        """Predict updated task-relevant beliefs drawn from the conversation.

        Args:
            task: The task the agent is trying to accomplish.
            previous_interaction_beliefs: Interaction beliefs from past
                conversations. NOTE(review): the ``InteractionBeliefs``
                signature declares this input as ``dict`` — confirm the
                actual type.
            latest_conversation: The latest conversation with the partner.

        Returns:
            The prediction returned by ``dspy.Predict(InteractionBeliefs)``;
            the new beliefs are in its ``interaction_beliefs`` field.
        """
        generate_interaction_beliefs = dspy.Predict(InteractionBeliefs)
        interaction_beliefs = generate_interaction_beliefs(
            task=task,
            context=INTERACTION_BELIEF_CONTEXT,
            previous_interaction_beliefs=previous_interaction_beliefs,
            latest_conversation=latest_conversation
        )
        return interaction_beliefs

    def __call__(self, context: dict): # pyright: ignore
        """Update both belief sets from the latest conversation.

        Args:
            context: Mapping that must contain the keys ``"task"``,
                ``"previous_partner_beliefs"``,
                ``"previous_interaction_beliefs"`` and
                ``"latest_conversation"``; a missing key raises ``KeyError``.

        Returns:
            A ``(partner_beliefs, interaction_beliefs)`` tuple of the two
            predictions, computed one after the other.
        """
        partner_beliefs = self.generate_partner_beliefs(
            previous_partner_beliefs=context["previous_partner_beliefs"],
            latest_conversation=context["latest_conversation"]
        )
        interaction_beliefs = self.generate_interaction_beliefs(
            task=context["task"],
            previous_interaction_beliefs=context["previous_interaction_beliefs"],
            latest_conversation=context["latest_conversation"]
        )

        return partner_beliefs, interaction_beliefs

    def restore(self):
        """Restore agent state; not implemented yet, currently a no-op.

        NOTE(review): presumably meant to load state from ``self.path`` when
        ``resume`` is set — confirm the intended behavior.
        """
        # TODO: implement this
        pass
