import torch

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    greedy_generation,
    beam_action_generation,
)


class ReActSayCanAgent(BaseAgent, FewShotMixIn):
    """Prompt-based agent that generates a thought, then an action, each step.

    The agent keeps a running text prompt in ``self.prompt``. ``reset`` builds
    it from the environment prompt, sampled few-shot examples and the goal;
    ``get_action`` appends one transition per step (observation, thought,
    an ``OK.`` line and the chosen action); ``forward`` rolls back a failed
    transition and trims old transitions before asking for the next action.
    """

    name = "react_saycan"

    # Number of prompt lines one transition occupies as written by
    # ``get_action``: observation, "> think: ...", "OK.", "> <action>".
    # NOTE(review): this only holds if the preprocessed observation is a
    # single line -- confirm against PromptTemplate.preprocess.
    _LINES_PER_TRANSITION = 4

    # Decoding strategies that ``get_action`` knows how to run.
    _SUPPORTED_DECODING = ("beam-action", "greedy")

    def __init__(
        self,
        model=None,
        tokenizer=None,
        env_name: str = "virtualhome",
        decoding_strategy: _Type_Decoding = "beam-action",
        num_few_shot: int = 2,
        context_window: int = 8,
        perturb: bool = False,
    ):
        """Store the model, prompt resources and generation settings.

        Args:
            model: Either a string handed to ``self._prepare_model`` or an
                already constructed model object used as-is.
            tokenizer: Tokenizer paired with ``model``; only used when
                ``model`` is not a string.
            env_name: Environment name used to load the environment prompt,
                action format and action dictionary.
            decoding_strategy: ``"beam-action"`` or ``"greedy"``; selects how
                the action is generated in ``get_action``.
            num_few_shot: Number of few-shot examples sampled in ``reset``.
            context_window: Maximum number of past transitions kept in the
                prompt before ``forward`` drops the oldest one.
            perturb: When true, ``forward`` passes the state through
                ``PromptTemplate.randomize``.
        """
        super().__init__()

        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        self.action_format = PromptTemplate.load_env_action_format(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        if isinstance(model, str):
            # NOTE(review): the tokenizer is not assigned on this path;
            # presumably _prepare_model sets self.tokenizer -- confirm.
            self.model = self._prepare_model(model)
        else:
            self.model, self.tokenizer = model, tokenizer

        self.decoding_strategy = decoding_strategy

        # Populated later: the pool must be loaded before ``reset`` is called,
        # and the prompt is sampled from it on every ``reset``.
        self.few_shot_prompt = None
        self.few_shot_pool = None
        self.num_few_shot = num_few_shot

        self.context_window = context_window
        self.perturb = perturb

    def get_action(self, obs, can_list):
        """Append ``obs`` to the prompt and generate a thought and an action.

        Args:
            obs: Observation text appended to the running prompt.
            can_list: Forwarded unchanged to ``beam_action_generation`` when
                the ``"beam-action"`` strategy is used; ignored for
                ``"greedy"``.

        Returns:
            The first line of the generated action text, stripped.

        Raises:
            NotImplementedError: If ``self.decoding_strategy`` is not one of
                the supported strategies.
        """
        # Fail before touching the prompt: previously an unknown strategy fell
        # through both branches and the thought was returned as the action.
        if self.decoding_strategy not in self._SUPPORTED_DECODING:
            raise NotImplementedError(
                f"Unsupported decoding strategy: {self.decoding_strategy!r}"
            )

        self.prompt += obs + "\n> think: "
        with torch.no_grad():
            generation_output = generation(
                self.model,
                self.tokenizer,
                self.prompt,
                max_length=200,
                **BaseAgent.default_gen_params,
            )
        # Only the first generated line is kept as the thought.
        think = generation_output.response.split("\n")[0].strip()
        self.prompt += think + "\nOK.\n> "

        with torch.no_grad():
            if self.decoding_strategy == "beam-action":
                object_list = PromptTemplate.get_object_list(obs)
                generation_output = beam_action_generation(
                    self.model,
                    self.tokenizer,
                    self.prompt,
                    self.action_format,
                    object_list,
                    env_name=self.env_name,
                    can_list=can_list,
                )
            else:  # "greedy"
                generation_output = generation(
                    self.model,
                    self.tokenizer,
                    self.prompt,
                    **BaseAgent.default_gen_params,
                )
        # Only the first line of the response is used as the action.
        action = generation_output.response.strip().split("\n")[0]
        self.prompt += f"{action}\n"

        return action

    def reset(self, task, goal):
        """Start a new episode by rebuilding the prompt for ``goal``.

        Args:
            task: Not used by this method.
            goal: Goal text; used to sample few-shot examples and inserted
                into the prompt.

        Raises:
            AssertionError: If the few-shot pool has not been loaded.
        """
        assert self.few_shot_pool, "Make sure to call load_few_shot_pool() first."

        self.few_shot_prompt = self.sample_few_shot_prompt(goal, k=self.num_few_shot)

        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are some examples.\n\n"
            + self.few_shot_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def forward(
        self,
        instruction,
        state,
        history,
        info,
    ):
        """Update the prompt from the last step's outcome and pick an action.

        Args:
            instruction: Not used by this method.
            state: Raw state; preprocessed (and optionally perturbed) before
                being appended to the prompt as the observation.
            history: Not used by this method.
            info: Mapping with keys ``"success"`` and ``"can_list"``.

        Returns:
            The action string returned by ``get_action``.

        Raises:
            NotImplementedError: If ``perturb`` is set and ``env_name`` is
                neither ``"virtualhome"`` nor ``"alfred"``.
        """
        success, can_list = info["success"], info["can_list"]
        step = self._LINES_PER_TRANSITION

        if not success:
            # Drop the last transition's lines so the failed step is not kept.
            # NOTE(review): this runs even if no transition has been appended
            # yet, in which case it removes header lines -- confirm callers
            # never report failure on the first step.
            transitions = self.prompt.strip("\n").split("\n")[:-step]
            self.prompt = "\n".join(transitions) + "\n"

        state = PromptTemplate.preprocess(state)
        if self.perturb:
            # presumably the second argument is a perturbation rate -- verify
            # against PromptTemplate.randomize.
            if self.env_name == "virtualhome":
                state = PromptTemplate.randomize(state, 0.5)
            elif self.env_name == "alfred":
                state = PromptTemplate.randomize(state, 0.3)
            else:
                raise NotImplementedError

        # Each transition contributes two ">" markers ("> think:" and "> ").
        context_count = self.prompt.split("Here is the task.")[-1].count(">") // 2
        if context_count > self.context_window:
            lines = self.prompt.strip("\n").split("\n")
            # Index of the first transition line; clamped so that explicit
            # indices behave like the negative slices they replace while also
            # handling the single-transition case (where "-4 * n + 4" is 0).
            header_end = max(len(lines) - step * context_count, 0)
            kept_start = max(len(lines) - step * context_count + step, 0)
            header, kept = lines[:header_end], lines[kept_start:]
            # Join as one list so the header and the remaining transitions
            # stay separated by a newline.
            self.prompt = "\n".join(header + kept) + "\n"

        return self.get_action(state, can_list)
