import torch

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    greedy_generation,
    beam_action_generation,
)
from embodied_cd.common.print_utils import *


class FewShotAgent(BaseAgent, FewShotMixIn):
    """Agent that chooses actions by prompting a language model with few-shot examples.

    The running prompt lives in ``self.prompt``: ``reset()`` builds it from the
    environment prompt, the sampled few-shot examples and the goal, and every
    ``get_action()`` call appends the observation, a ``> `` marker and the
    chosen action.
    """

    name = "fewshot"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        env_name: str = "virtualhome",
        decoding_strategy: _Type_Decoding = "beam-action",
        num_few_shot: int = 2,
        context_window: int = 10,
        perturb: bool = False,
    ):
        """Set up the agent.

        Args:
            model: Either a string, which is handed to ``_prepare_model``, or
                a model object that is stored as-is together with
                ``tokenizer``.
            tokenizer: Tokenizer stored alongside a non-string ``model``.
            env_name: Environment name used to load the environment prompt,
                the action format and the action dictionary.
            decoding_strategy: Decoding used for non-GPT models;
                ``get_action`` handles ``"beam-action"`` and ``"greedy"``.
            num_few_shot: Number of few-shot examples requested in ``reset``.
            context_window: Maximum number of recorded actions kept in the
                prompt before ``forward`` drops the oldest transition.
            perturb: When true, ``forward`` randomizes the preprocessed state.
        """
        super().__init__()

        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        self.action_format = PromptTemplate.load_env_action_format(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        if isinstance(model, str):
            # NOTE(review): this branch does not assign self.tokenizer here;
            # presumably _prepare_model (or the base class) handles it -- confirm.
            self.model = self._prepare_model(model)
        else:
            self.model, self.tokenizer = model, tokenizer

        self.decoding_strategy = decoding_strategy

        # Filled in later: the pool must be loaded before reset() is called,
        # and the prompt is sampled from it on every reset().
        self.few_shot_prompt = None
        self.few_shot_pool = None
        self.num_few_shot = num_few_shot

        self.context_window = context_window
        self.perturb = perturb

    def get_action(self, obs):
        """Append ``obs`` to the prompt, query the model and record its action.

        Args:
            obs: Observation text appended to the running prompt.

        Returns:
            The chosen action string. For the GPT models it is the response
            with surrounding spaces and ``>`` removed; otherwise it is the
            first line of the stripped generation response.

        Raises:
            NotImplementedError: If a non-GPT model is used with a decoding
                strategy other than ``"beam-action"`` or ``"greedy"``.
        """
        uses_api_model = self.model in ["gpt-4o-mini", "gpt-4o"]

        # Reject unsupported strategies before the prompt is modified, so a
        # failed call does not leave a dangling "> " marker behind.
        if not uses_api_model and self.decoding_strategy not in (
            "beam-action",
            "greedy",
        ):
            raise NotImplementedError(
                f"Unsupported decoding strategy: {self.decoding_strategy!r}"
            )

        self.prompt += obs + "\n> "

        if uses_api_model:
            action = self._call_model(self.prompt)
            action = action.strip(" >")
        else:
            with torch.no_grad():
                if self.decoding_strategy == "beam-action":
                    object_list = PromptTemplate.get_object_list(obs)
                    generation_output = beam_action_generation(
                        self.model,
                        self.tokenizer,
                        self.prompt,
                        self.action_format,
                        object_list,
                        env_name=self.env_name,
                    )
                else:  # "greedy" -- validated above
                    generation_output = generation(
                        self.model,
                        self.tokenizer,
                        self.prompt,
                        **BaseAgent.default_gen_params,
                    )
            action = generation_output.response
            # Keep only the first line of the response as the action.
            action = action.strip().split("\n")[0]

        self.prompt += f"{action}\n"

        return action

    def reset(self, task, goal):
        """Start a new episode: sample few-shot examples and rebuild the prompt.

        Args:
            task: Not used by this method.
            goal: Goal text; used to sample the few-shot prompt and inserted
                into the task line of the prompt.

        Raises:
            AssertionError: If the few-shot pool has not been loaded.
        """
        assert self.few_shot_pool, "Make sure to call load_few_shot_pool() first."

        self.few_shot_prompt = self.sample_few_shot_prompt(goal, k=self.num_few_shot)

        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are some examples.\n\n"
            + self.few_shot_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def _task_action_count(self):
        """Return the number of ``>`` markers after the last task header."""
        return self.prompt.split("Here is the task.")[-1].count(">")

    def forward(
        self,
        instruction,
        state,
        history,
        success,
    ):
        """Prepare the prompt for the next step and return the next action.

        Args:
            instruction: Instruction text; only inspected when ``perturb`` is
                enabled, to choose how the state is randomized.
            state: Raw state, preprocessed before being used as observation.
            history: Not used by this method.
            success: Whether the previous action succeeded. When false, the
                last recorded transition is removed from the prompt.

        Returns:
            The action returned by ``get_action`` for the processed state.

        Raises:
            NotImplementedError: If ``perturb`` is enabled and ``instruction``
                contains none of "Turn", "Open", "Place" or "Put".
        """
        # The line arithmetic below assumes every transition occupies exactly
        # two prompt lines (observation, then "> action") -- TODO confirm that
        # preprocessed states are always single-line.
        if not success and self._task_action_count() > 0:
            # Roll back the last transition. Guarded so that a failure flag
            # received before any action was recorded cannot strip the
            # "Your task is to: ..." line from the prompt.
            transitions = self.prompt.strip("\n").split("\n")[:-2]
            self.prompt = "\n".join(transitions) + "\n"

        state = PromptTemplate.preprocess(state)
        if self.perturb:
            if "Turn" in instruction or "Open" in instruction:
                state = PromptTemplate.randomize(state)
            elif "Place" in instruction or "Put" in instruction:
                state = PromptTemplate.randomize(state, 0.5)
            else:
                raise NotImplementedError

        context_count = self._task_action_count()
        if context_count > self.context_window:
            lines = self.prompt.strip("\n").split("\n")
            # Index of the oldest transition's first line. Computed from the
            # front so that a zero offset cannot turn into "the whole list",
            # and clamped so a miscount cannot produce a negative index.
            first_transition = max(len(lines) - 2 * context_count, 0)
            header = lines[:first_transition]
            kept = lines[first_transition + 2 :]
            # Join as one list so the header and the first kept observation
            # stay on separate lines.
            self.prompt = "\n".join(header + kept) + "\n"

        return self.get_action(state)