from typing import List, Optional, Union

from embodied_cd.common.agent import BaseAgent


class MotivAgent(BaseAgent):
    """Agent for motivation experiments (comparing with & without thinking annotations).

    Builds a plain-text prompt from the environment prompt, optional few-shot
    examples, the instruction, the current state and an optional "Think"
    annotation, then queries the model for the next action.
    """
    name = "motiv"

    def __init__(self, model: str, env_name: str = 'virtualhome'):
        super().__init__()

        # `_prepare_model` / `_load_env_prompt` are inherited (presumably from
        # BaseAgent) — their exact return types are not visible here.
        self.model = self._prepare_model(model)
        self.env_prompt = self._load_env_prompt(env_name)

    def reset(self, task, goal):
        """No per-episode state to reset; `task` and `goal` are ignored."""
        return

    def get_action(
        self,
        instruction: str,
        state: str,
        think: Optional[str] = None,
        few_shot_examples: Optional[Union[str, List[str]]] = None,
    ):
        """Build the action prompt and query the model.

        Args:
            instruction       : instruction of the task
            state             : current state
            think             : think annotation; the "Think:" line is only
                                added when this is a string
            few_shot_examples : few-shot examples, either a single string or a
                                list of strings (joined with newlines)

        Returns:
            Whatever `self._call_model(query, instruct=False)` returns.
        """
        query = self.env_prompt

        examples = few_shot_examples
        if isinstance(examples, (list, tuple)):
            examples = "\n".join(examples)
        if isinstance(examples, str) and examples:
            # Append (not replace) so the environment prompt is preserved.
            query += f"\nFollowing are the examples:\n{examples}"

        query += f"\nInstruction: {instruction}\nState: {state}"
        if isinstance(think, str):
            query += f"\nThink: {think}"
        query += "\nAction:"
        return self._call_model(query, instruct=False)

    def forward(
        self,
        instruction: str,
        state: str,
        history: str,
        few_shot_examples: Optional[Union[str, List[str]]] = None,
    ):
        """Entry point: delegate to `get_action`.

        NOTE(review): `history` is forwarded as the `think` annotation (this
        matches the original positional call) — confirm that is intended.
        """
        return self.get_action(
            instruction,
            state,
            think=history,
            few_shot_examples=few_shot_examples,
        )

