import torch

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    greedy_generation,
    beam_action_generation,
)
from embodied_cd.common.print_utils import *


class GPTAgent(BaseAgent, FewShotMixIn):
    """Few-shot prompted agent that keeps a running text prompt.

    Each episode starts with :meth:`reset`, which builds the prompt header
    (environment description, sampled few-shot examples, and the task goal).
    Every call to :meth:`get_action` / :meth:`forward` then appends the new
    observation to that prompt, queries the model, and appends the chosen
    action, so the prompt grows into a full interaction transcript.

    NOTE(review): ``_load_env_prompt``, ``_call_model`` and
    ``sample_few_shot_prompt`` are not defined in this class; they are
    presumably provided by ``BaseAgent`` / ``FewShotMixIn`` -- confirm there.
    """

    name = "gpt"

    def __init__(
        self,
        model=None,
        env_name: str = "virtualhome",
        num_few_shot: int = 2,
        perturb: bool = False,
    ):
        """Store the model handle and load environment-specific prompt data.

        Args:
            model: Model handle kept on ``self.model``; it is not used
                directly in this class.
            env_name: Environment identifier used to look up the environment
                prompt, the action format, and the action dictionary.
            num_few_shot: Number of few-shot examples requested from
                ``sample_few_shot_prompt`` on every :meth:`reset`.
            perturb: When true, :meth:`forward` randomizes the preprocessed
                state before it is added to the prompt.
        """
        super().__init__()

        self.model = model
        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        self.action_format = PromptTemplate.load_env_action_format(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        # The pool starts empty; reset() refuses to run until it has been
        # populated (its error message points at load_few_shot_pool()).
        self.few_shot_prompt = None
        self.few_shot_pool = None
        self.num_few_shot = num_few_shot

        self.perturb = perturb

    def get_action(self, obs: str) -> str:
        """Append ``obs`` to the prompt, query the model, and record the action.

        Relies on ``self.prompt``, which this class only assigns in
        :meth:`reset`.

        Args:
            obs: Observation text appended to the running prompt.

        Returns:
            The model output with leading/trailing spaces and ``>``
            characters stripped.
        """
        # "> " marks the position where the model is expected to continue.
        self.prompt += obs + "\n> "

        action = self._call_model(self.prompt)
        # strip(" >") removes any run of spaces and '>' from both ends, e.g.
        # an echoed "> " prefix in the completion.
        action = action.strip(" >")

        # Record the action so the next observation follows it in the prompt.
        self.prompt += f"{action}\n"

        return action

    def reset(self, task, goal) -> None:
        """Start a new episode: sample few-shot examples and rebuild the prompt.

        Args:
            task: Not used by this method.
            goal: Goal description; passed to ``sample_few_shot_prompt`` and
                embedded in the prompt as ``"Your task is to: <goal>."``.

        Raises:
            AssertionError: If ``self.few_shot_pool`` is empty or ``None``.
        """
        # Raised explicitly instead of using `assert` so the check is not
        # stripped under `python -O`; AssertionError is kept so the exception
        # type seen by existing callers is unchanged.
        if not self.few_shot_pool:
            raise AssertionError("Make sure to call load_few_shot_pool() first.")

        self.few_shot_prompt = self.sample_few_shot_prompt(goal, k=self.num_few_shot)

        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are some examples.\n\n"
            + self.few_shot_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def forward(
        self,
        instruction,
        state,
        history,
        success,
    ):
        """Preprocess ``state`` (optionally perturbing it) and pick an action.

        Args:
            instruction: Only inspected when ``self.perturb`` is true, to
                choose how the state is randomized.
            state: Raw state; passed through ``PromptTemplate.preprocess``.
            history: Currently unused.
            success: Currently unused.

        Returns:
            The action returned by :meth:`get_action` for the processed state.

        Raises:
            NotImplementedError: If ``self.perturb`` is true and
                ``instruction`` contains none of "Turn", "Open", "Place",
                or "Put".
        """
        state = PromptTemplate.preprocess(state)
        if self.perturb:
            if "Turn" in instruction or "Open" in instruction:
                state = PromptTemplate.randomize(state)
            elif "Place" in instruction or "Put" in instruction:
                # The meaning of the second argument is defined by
                # PromptTemplate.randomize, which is not visible here.
                state = PromptTemplate.randomize(state, 0.5)
            else:
                raise NotImplementedError(
                    f"No perturbation rule for instruction: {instruction!r}"
                )

        return self.get_action(state)
