import torch

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    beam_action_generation,
)


class ZeroShotAgent(BaseAgent):
    """Agent that chooses actions by querying a language model with a running prompt.

    The prompt is (re)built by :meth:`reset` and then extended on every
    :meth:`get_action` call with the observation and the chosen action.
    """

    name = "zeroshot"

    # Model identifiers that are answered through ``self._call_model`` instead
    # of the local generation helpers.
    _API_MODEL_NAMES = ("gpt-4o-mini", "gpt-4o")

    def __init__(
        self,
        model=None,
        tokenizer=None,
        env_name: str = "virtualhome",
        decoding_strategy: _Type_Decoding = "beam-action",
        perturb: bool = False,
    ):
        """Set up environment prompt data and the model used for decoding.

        Args:
            model: Either a string, which is handed to ``self._prepare_model``,
                or an already constructed model object stored as-is.
            tokenizer: Tokenizer stored alongside ``model``; only used when
                ``model`` is not a string.
            env_name: Environment name used to load the environment prompt,
                action format and action dictionary.
            decoding_strategy: ``"beam-action"`` or ``"greedy"``; selects the
                generation helper used by :meth:`get_action` for non-API
                models.
            perturb: When true, :meth:`forward` randomizes the preprocessed
                state before querying the model.
        """
        super().__init__()

        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        self.action_format = PromptTemplate.load_env_action_format(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        if isinstance(model, str):
            # NOTE(review): ``self.tokenizer`` is not assigned on this path;
            # presumably ``_prepare_model`` or ``BaseAgent`` takes care of it
            # - confirm against the base class.
            self.model = self._prepare_model(model)
        else:
            self.model, self.tokenizer = model, tokenizer

        self.decoding_strategy = decoding_strategy
        self.perturb = perturb

    def get_action(self, obs):
        """Append ``obs`` to the prompt, query the model and record its action.

        Args:
            obs: Observation text appended to the running prompt.

        Returns:
            The action produced by the model. For locally decoded models this
            is the first line of the stripped response.

        Raises:
            ValueError: If a locally decoded model is used with a
                ``decoding_strategy`` other than ``"beam-action"`` or
                ``"greedy"``.
        """
        self.prompt += obs + "\n> "

        if self.model in self._API_MODEL_NAMES:
            action = self._call_model(self.prompt)
        else:
            with torch.no_grad():
                if self.decoding_strategy == "beam-action":
                    object_list = PromptTemplate.get_object_list(obs)
                    generation_output = beam_action_generation(
                        self.model,
                        self.tokenizer,
                        self.prompt,
                        self.action_format,
                        object_list,
                        env_name=self.env_name,
                    )
                elif self.decoding_strategy == "greedy":
                    generation_output = generation(
                        self.model,
                        self.tokenizer,
                        self.prompt,
                        **BaseAgent.default_gen_params,
                    )
                else:
                    # Without this branch ``generation_output`` would be
                    # unbound below and surface as an opaque
                    # UnboundLocalError.
                    raise ValueError(
                        "Unsupported decoding strategy: "
                        f"{self.decoding_strategy!r} "
                        "(expected 'beam-action' or 'greedy')"
                    )
            # Keep only the first line of the generated text.
            action = generation_output.response
            action = action.strip().split("\n")[0]

        self.prompt += f"{action}\n"
        return action

    def reset(self, task, goal):
        """Store ``goal`` and rebuild the prompt from scratch for a new task.

        Args:
            task: Not used by this agent.
            goal: Goal text inserted into the task sentence of the prompt.
        """
        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def forward(self, instruction, state, history, success):
        """Preprocess ``state`` and return the next action from the model.

        Args:
            instruction: Instruction text; only inspected when ``perturb`` is
                enabled, to pick the randomization applied to the state.
            state: Raw state passed through ``PromptTemplate.preprocess``.
            history: Not used by this agent.
            success: When false, the last two prompt lines are removed before
                the new state is appended.

        Returns:
            The action returned by :meth:`get_action`.

        Raises:
            NotImplementedError: If ``perturb`` is enabled and ``instruction``
                contains none of "Turn", "Open", "Place" or "Put".
        """
        if not success:
            # Drop the last two prompt lines. With a single-line observation
            # these are the observation and "> action" lines written by the
            # previous get_action call - TODO confirm observations are always
            # single-line.
            transitions = self.prompt.strip("\n").split("\n")[:-2]
            self.prompt = "\n".join(transitions) + "\n"

        state = PromptTemplate.preprocess(state)
        if self.perturb:
            if "Turn" in instruction or "Open" in instruction:
                state = PromptTemplate.randomize(state)
            elif "Place" in instruction or "Put" in instruction:
                state = PromptTemplate.randomize(state, 0.5)
            else:
                raise NotImplementedError(
                    f"No perturbation defined for instruction: {instruction!r}"
                )

        return self.get_action(state)
