import os
import os.path as osp
import copy

import torch
from transformers import DynamicCache
from saycanpay.lm_planner_base import CanModel, PayModel

from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    greedy_generation,
    beam_action_generation,
)
from embodied_cd.common.print_utils import *


class ActCanPayAgent(BaseAgent, FewShotMixIn):
    name = "actcanpay"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        env_name: str = "virtualhome",
        decoding_strategy: _Type_Decoding = "beam-action",
        num_few_shot: int = 2,
        context_window: int = 10,
        perturb: bool = False,
    ):
        """Set up environment prompt data, the language model and scoring state.

        Args:
            model: If a ``str``, it is passed to ``self._prepare_model`` and the
                result is stored as ``self.model``. Otherwise it is stored
                unchanged. Other methods of this class read
                ``self.model.model``, ``self.model.tokenizer`` and
                ``self.model.device``.
            tokenizer: Stored as ``self.tokenizer`` only when ``model`` is not
                a string.
            env_name: Environment key used to load the environment prompt, the
                action format, the action dictionary and the admissible
                action list.
            decoding_strategy: Stored as ``self.decoding_strategy``; it is not
                read by any method shown in this class.
            num_few_shot: Stored as ``self.num_few_shot``.
            context_window: Stored as ``self.context_window``.
            perturb: Stored as ``self.perturb``.
        """
        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        # NOTE(review): ``PromptTemplate`` is not imported by name in the
        # visible imports; presumably it arrives through the star import of
        # ``print_utils`` -- confirm.
        self.action_format = PromptTemplate.load_env_action_format(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        if isinstance(model, str):
            # NOTE(review): ``self.tokenizer`` is not assigned on this path
            # unless ``_prepare_model`` does so itself -- confirm.
            self.model = self._prepare_model(model)
        else:
            self.model, self.tokenizer = model, tokenizer

        self.decoding_strategy = decoding_strategy

        # Candidate actions that ``_call_model`` scores at every step.
        self.admissible_actions = self._load_admissible_actions(env_name)
        # Running prompt; rebuilt by ``reset`` and extended by ``get_action``.
        self.prompt = ""
        self.goal = None

        # ``reset`` asserts that ``few_shot_prompt`` is truthy, so it has to be
        # populated before an episode starts.
        self.few_shot_prompt = None
        self.few_shot_pool = None
        self.num_few_shot = num_few_shot

        self.context_window = context_window
        self.perturb = perturb

        # Filled in by ``load_can_pay_model``; all three stay ``None`` until then.
        self.decoding_score = None
        self.can_model = None
        self.pay_model = None

    def get_action(self, obs):
        self.prompt += obs + "\n> "
        action = self._call_model(self.prompt)
        self.prompt += f" {action}\n"
        return action

    def reset(self, task, goal):
        assert self.few_shot_prompt, "Make sure to call load_few_shot_prompt() first."

        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are some examples.\n\n"
            + self.few_shot_prompt
            + "\n\nHere is the task.\n\n"
            + f"Your task is to: {self.goal}.\n"
        )

    def forward(self, instruction, state, history, few_shot_examples=None):
        """Return the next action for ``state`` by delegating to ``get_action``.

        ``instruction``, ``history`` and ``few_shot_examples`` are accepted but
        not used in this method; only ``state`` is forwarded, as the
        observation.
        """
        return self.get_action(state)

    def _call_model(self, prompt):
        """Return the admissible action with the highest combined score.

        Each action in ``self.admissible_actions`` is scored by summing the
        language model's log-probabilities of its tokens, one position at a
        time, as a continuation of ``prompt``. When ``self.decoding_score``
        contains ``"can"`` and/or ``"pay"``, the corresponding log-scores from
        ``_compute_can_score`` / ``_compute_pay_score`` are added.

        Args:
            prompt: Full prompt text the action should continue.

        Returns:
            The best-scoring entry of ``self.admissible_actions``.

        Raises:
            ValueError: If ``self.admissible_actions`` is empty.
        """
        if not self.admissible_actions:
            raise ValueError(
                "No admissible actions are loaded; cannot select an action."
            )

        tokenizer = self.model.tokenizer

        # ``decoding_score`` is ``None`` until ``load_can_pay_model`` has been
        # called; treat that as "language-model score only" instead of failing
        # on the membership test. The tests are loop-invariant, so do them once.
        decoding_score = self.decoding_score or ()
        use_can = "can" in decoding_score
        use_pay = "pay" in decoding_score

        root_output = self._generate(
            prompt,
            past_key_values=DynamicCache(),
            max_new_tokens=1,
        )

        # Remove the last cached position so that follow-up calls recompute it.
        # NOTE(review): presumably this guards against the prompt's final token
        # changing once action text is appended and re-tokenized -- confirm.
        root_output.past_key_values.crop(-1)

        # The distribution for the first action token is the same for every
        # candidate, so compute it once.
        root_log_probs = torch.log_softmax(root_output.scores[0][0], dim=-1)

        action_log_probs = torch.zeros(len(self.admissible_actions))
        can_scores = torch.zeros_like(action_log_probs)
        pay_scores = torch.zeros_like(action_log_probs)

        for action_idx, action in enumerate(self.admissible_actions):
            action_tokens = tokenizer.encode(action, add_special_tokens=False)
            if not action_tokens:
                # An action with no tokens would otherwise keep log-prob 0
                # (probability 1) and outrank every real action.
                action_log_probs[action_idx] = float("-inf")
                continue

            action_log_prob = 0.0
            position_log_probs = root_log_probs
            action_prompt = prompt
            last_position = len(action_tokens) - 1

            for i, token in enumerate(action_tokens):
                action_log_prob += position_log_probs[token].item()
                if i < last_position:
                    # Append the token just scored and query the distribution
                    # for the next position. The root cache is deep-copied so
                    # each call starts from the same cached prefix.
                    # ``_generate`` already runs under ``torch.no_grad()``.
                    action_prompt += tokenizer.decode(token)
                    output = self._generate(
                        action_prompt,
                        past_key_values=copy.deepcopy(root_output.past_key_values),
                        max_new_tokens=1,
                    )
                    position_log_probs = torch.log_softmax(
                        output.scores[0][0], dim=-1
                    )

            if use_can:
                can_scores[action_idx] = self._compute_can_score(prompt, action)
            if use_pay:
                pay_scores[action_idx] = self._compute_pay_score(prompt, action)

            action_log_probs[action_idx] = action_log_prob

        # All three terms are log-scores, so they combine by addition.
        total_scores = action_log_probs + can_scores + pay_scores
        return self.admissible_actions[total_scores.argmax().item()]

    def _generate(self, prompt, past_key_values, max_new_tokens):
        model, tokenizer = self.model.model, self.model.tokenizer

        inputs = tokenizer(prompt, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(model.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                do_sample=False,
                top_k=None,
                top_p=None,
                temperature=None,
                max_new_tokens=max_new_tokens,
                past_key_values=past_key_values,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1,
                return_dict_in_generate=True,
                output_scores=True,
                repetition_penalty=1.0,
            )
        return output

    def _load_admissible_actions(self, env_name):
        admissible_action_path = (
            f"externals/saycanpay/{env_name}_admissible_actions.txt"
        )
        with open(admissible_action_path, "r") as f:
            admissible_actions = f.read().splitlines()
        return admissible_actions

    def _compute_can_score(self, prompt, action):
        can_prompt = self._build_can_pay_prompt(prompt, action)
        with torch.no_grad():
            can_score = self.can_model(can_prompt, action)
            can_score = torch.log(can_score + 1e-3).item()
        return can_score

    def _compute_pay_score(self, prompt, action):
        pay_prompt = self._build_can_pay_prompt(prompt, action)
        with torch.no_grad():
            pay_score = self.pay_model(pay_prompt, action)
            pay_score = torch.log(pay_score + 1e-3).item()
        return pay_score

    def _build_can_pay_prompt(self, prompt, action):
        prompts = prompt.split("Here is the task.")[-1].strip("\n").split("\n")[:-1]
        observations = [p for p in prompts if p.startswith("(")]
        actions = [p.strip("> ") for p in prompts if p.startswith(">")]

        context = ["[Goal]", self.goal, "[Initial State]", observations[0]]
        for i in range(len(actions)):
            context.append(f"[Step {i + 1}] {actions[i]}")
        can_pay_prompt = f"{' '.join(context)} [NXT] {action}"
        return can_pay_prompt

    def load_can_pay_model(self, checkpoint_dir, decoding_score):
        self.decoding_score = decoding_score

        if "can" in decoding_score:
            print("Loading Can Model")
            can_ckpt = os.path.join(checkpoint_dir, "can.ckpt")
            self.can_model = CanModel.load_from_checkpoint(can_ckpt)
            self.can_model.eval().to(self.model.device)

        if "pay" in decoding_score:
            print("Loading Pay Model")
            pay_ckpt = os.path.join(checkpoint_dir, "pay.ckpt")
            self.pay_model = PayModel.load_from_checkpoint(pay_ckpt)
            self.pay_model.eval().to(self.model.device)
