import os
import os.path as osp
import time
import random

import json
import torch
from transformers import DynamicCache
from openai import OpenAI

from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.agent import BaseAgent
from embodied_cd.common.mixin import FewShotMixIn
from embodied_cd.trl.models.core import (
    _Type_Decoding,
    generation,
    greedy_generation,
    beam_token_generation,
    generation_using_cache,
    beam_token_generation_using_cache,
)


class ReflexionAgent(BaseAgent, FewShotMixIn):
    """Prompt-based agent that thinks before acting and learns from reflections.

    Each call to ``get_action`` appends the observation, a generated "think"
    line and a generated action to a running text prompt (``self.prompt``).
    ``self_reflect`` asks the model to critique the current (failed)
    trajectory; the resulting trajectory+reflection strings are prepended to
    the prompt of subsequent episodes by ``reset``.
    """

    name = "reflexion"

    # Decoding strategies handled by ``get_action``.
    _SUPPORTED_DECODING = ("beam-action", "greedy")

    def __init__(
        self,
        model=None,
        tokenizer=None,
        env_name: str = "virtualhome",
        decoding_strategy: _Type_Decoding = "beam-action",
        feedback_window_size: int = 8,
        use_cache=True,
    ):
        """Build the agent.

        Args:
            model: A model object, or a string that is passed to
                ``self._prepare_model`` to obtain the model.
            tokenizer: Tokenizer used together with ``model``.
            env_name: Environment whose prompt text and action dictionary
                are loaded.
            decoding_strategy: ``"beam-action"`` (beam decoding via
                ``beam_token_generation*`` over ``self.action_dict`` and the
                objects parsed from the observation) or ``"greedy"``.
            feedback_window_size: Maximum number of self-reflections kept
                after each ``self_reflect`` call; ``<= 0`` keeps none.
            use_cache: Reuse ``self.prompt_cache`` across generation calls
                within an episode.
        """
        super().__init__()

        self.env_name = env_name
        self.env_prompt = self._load_env_prompt(env_name)
        self.action_dict = PromptTemplate.load_env_action_dict(env_name)

        # Assign the tokenizer unconditionally so the attribute always exists.
        # NOTE(review): when ``model`` is a string, ``_prepare_model`` may set
        # ``self.tokenizer`` itself — confirm in BaseAgent.
        self.tokenizer = tokenizer
        if isinstance(model, str):
            self.model = self._prepare_model(model)
        else:
            self.model = model
        self.few_shot_prompt = None
        self.prompt_cache = DynamicCache()

        # Trajectory + reflection strings injected into future episodes' prompts.
        self.self_reflection = []
        # Few-shot reflection examples as ONE string (see load_oracle_reflection).
        # Initialized to "" so self_reflect works even if none were loaded.
        self.oracle_reflection = ""
        self.feedback_window_size = feedback_window_size
        self.use_cache = use_cache

        self.decoding_strategy = decoding_strategy

    def _cache_length(self):
        """Return the token count held by ``self.prompt_cache`` (debug output only)."""
        get_seq_length = getattr(self.prompt_cache, "get_seq_length", None)
        if callable(get_seq_length):
            return get_seq_length()
        # Fall back to the private counter the original debug prints read.
        return getattr(self.prompt_cache, "_seen_tokens", None)

    def _generate_text(self, prompt, use_cache):
        """Run free-form generation on ``prompt``.

        When ``use_cache`` is true, ``self.prompt_cache`` is passed in and
        replaced by the cache returned from ``generation_using_cache``.
        """
        with torch.no_grad():
            if not use_cache:
                return generation(
                    self.model,
                    self.tokenizer,
                    prompt,
                    **BaseAgent.default_gen_params,
                )
            output = generation_using_cache(
                self.model,
                self.tokenizer,
                prompt,
                prompt_cache=self.prompt_cache,
                **BaseAgent.default_gen_params,
            )
        self.prompt_cache = output.prompt_cache
        return output

    def _generate_action(self, obs):
        """Generate the action for the current ``self.prompt`` per ``decoding_strategy``."""
        if self.decoding_strategy == "greedy":
            return self._generate_text(self.prompt, use_cache=False)

        # "beam-action": candidate objects are parsed from the observation.
        object_list = PromptTemplate.get_object_list(obs)
        with torch.no_grad():
            if self.use_cache:
                return beam_token_generation_using_cache(
                    self.model,
                    self.tokenizer,
                    self.prompt,
                    self.action_dict,
                    object_list,
                    prompt_cache=self.prompt_cache,
                )
            return beam_token_generation(
                self.model,
                self.tokenizer,
                self.prompt,
                self.action_dict,
                object_list,
            )

    def get_action(self, obs):
        """Generate a think line and then an action for observation ``obs``.

        Appends ``obs``, the think line and the action to ``self.prompt``.

        Returns:
            The first line of the generated action with ``[`` and ``]`` removed.

        Raises:
            ValueError: If ``self.decoding_strategy`` is not supported. This is
                checked before ``self.prompt`` is modified.
        """
        # Without this check an unknown strategy would silently reuse the
        # think output as the action.
        if self.decoding_strategy not in self._SUPPORTED_DECODING:
            raise ValueError(
                f"Unsupported decoding_strategy {self.decoding_strategy!r}; "
                f"expected one of {self._SUPPORTED_DECODING}"
            )

        self.prompt += obs + "\n> think: "

        start_time = time.perf_counter()
        print("Before Think: ", self._cache_length())
        think_output = self._generate_text(self.prompt, use_cache=self.use_cache)
        think = think_output.response.split("\n")[0].strip()

        self.prompt += think + "\nOK.\n> "
        print("After Think: ", self._cache_length())
        print(f"Think: {think} ({time.perf_counter() - start_time:.2f}s)")

        start_time = time.perf_counter()
        action_output = self._generate_action(obs)
        action = action_output.response.strip().split("\n")[0]
        print(self._cache_length())
        print(f"Action: {action} ({time.perf_counter() - start_time:.2f}s)")
        self.prompt += f"{action}\n"

        # The prompt keeps the bracketed form; the caller gets it without brackets.
        return action.replace("[", "").replace("]", "")

    def reset(self, task, goal):
        """Start a new episode for ``goal`` and rebuild the prompt from scratch.

        The prompt contains the environment description, the few-shot
        examples, any stored self-reflections, and the task line. ``task`` is
        not used here. The prompt cache is replaced with an empty one.
        """
        assert self.few_shot_prompt, "Make sure to call load_few_shot_prompt() first."

        self.goal = goal
        self.prompt = (
            "Interact with a household to solve a task. "
            + self.env_prompt
            + "\n\nHere are some examples.\n\n"
            + self.few_shot_prompt
        )

        if self.self_reflection:
            self.prompt += (
                "\n\nHere are previous failed trajectories with reflection\n\n"
            )
            self.prompt += "\n\n".join(self.self_reflection)

        self.prompt += "\n\nHere is the task.\n\n" + f"Your task is to: {self.goal}.\n"
        self.prompt_cache = DynamicCache()

    def forward(
        self,
        instruction,
        state,
        history,
        few_shot_examples=None,
    ):
        """Delegate to ``get_action(state)``.

        ``instruction``, ``history`` and ``few_shot_examples`` are ignored.
        """
        return self.get_action(state)

    def load_self_reflection(self, dataset_dir):
        """Replace ``self.self_reflection`` with the list in ``self_reflection.json``."""
        path = osp.join(dataset_dir, "self_reflection.json")
        with open(path, "r", encoding="utf-8") as f:
            self.self_reflection = json.load(f)

    def load_oracle_reflection(self, dataset_dir, k=2):
        """Sample up to ``k`` examples from ``oracle_reflection.json`` as one string.

        The sampled entries are joined with blank lines into
        ``self.oracle_reflection``. If the file holds fewer than ``k``
        entries, all of them are used instead of raising.
        """
        path = osp.join(dataset_dir, "oracle_reflection.json")
        with open(path, "r", encoding="utf-8") as f:
            examples = json.load(f)
        # random.sample raises ValueError when k exceeds the population size.
        sample_size = min(k, len(examples))
        self.oracle_reflection = "\n\n".join(random.sample(examples, sample_size))

    def self_reflect(self):
        """Ask the model to reflect on the current trajectory and store the result.

        The trajectory is the part of ``self.prompt`` after the last
        "Here is the task." marker. The stored entry is
        ``"<trajectory>\\n\\nReflection: <first generated line>"``. Only the
        most recent ``feedback_window_size`` entries are kept (none if the
        window size is ``<= 0``).
        """
        traj = self.prompt.split("Here is the task.")[-1].strip("\n")

        prompt = (
            "Analyse the failed trajectory and reflect on what mistake you made and what you should have done.\n"
            + "Look at the goal, observation, think and action carefully to identify the beginning of the mistake."
            + "\n\nHere are some examples.\n\n"
            + self.oracle_reflection
            + "\n\nHere is the *failed* trajectory.\n\n"
            + traj
            + "\n\nReflection: "
        )

        generation_output = self._generate_text(prompt, use_cache=False)
        reflection = generation_output.response.split("\n")[0].strip()
        self.self_reflection.append(traj + "\n\nReflection: " + reflection)

        # `lst[-0:]` would keep the whole list, so handle non-positive sizes explicitly.
        window = self.feedback_window_size
        self.self_reflection = self.self_reflection[-window:] if window > 0 else []
