import os
import os.path as osp
from abc import ABC, abstractmethod
import json
import random
from typing import Union, List

import torch
from openai import OpenAI
from transformers import pipeline
 
from embodied_cd.common.dataset_utils import VirtualHomeDataset


class BaseAgent(ABC):
    """Abstract base for agents backed by an OpenAI chat model or a local
    Hugging Face ``text-generation`` pipeline.

    Subclasses implement :meth:`reset` and :meth:`forward`. ``_call_model``
    reads ``self.model``: a string is treated as an OpenAI model name, anything
    else is called like a ``transformers`` pipeline (see ``_prepare_model``).
    """

    # Greedy-decoding settings. NOTE: ``_call_model`` passes these same values
    # to the local pipeline explicitly; keep the two in sync.
    default_gen_params = {
        "do_sample": False,
        "top_k": None,
        "top_p": None,
        "temperature": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
    }

    # Model names that ``_prepare_model`` returns unchanged (queried through the
    # OpenAI API) instead of loading them as a local pipeline.
    api_models = ("gpt-4o-mini", "gpt-4o")

    @abstractmethod
    def reset(self, task, goal):
        """Reset the agent for ``task`` / ``goal`` (implemented by subclasses)."""

    @abstractmethod
    def forward(self, obs):
        """Compute the agent's output for observation ``obs`` (implemented by subclasses)."""

    @staticmethod
    def _base_env_name(env_name):
        """Remove ``continual_`` from ``env_name`` so continual variants reuse
        the prompt files of their base environment."""
        return env_name.replace("continual_", "")

    def _load_few_shot_prompt(self, env_name, agent_name, task, k=2):
        """Return ``k`` randomly sampled few-shot examples for ``agent_name``/``task``.

        Candidates are the values of every entry in
        ``<cwd>/prompts/few_shot/<env>.json`` whose key starts with
        ``"<agent_name>_<task>"``.

        Raises:
            ValueError: if fewer than ``k`` candidates match (or ``k`` is
                negative, via ``random.sample``).
        """
        env_name = self._base_env_name(env_name)
        filepath = osp.join(os.getcwd(), "prompts", "few_shot", f"{env_name}.json")
        prefix = f"{agent_name}_{task}"

        with open(filepath, "r", encoding="utf-8") as f:
            all_prompts = json.load(f)

        prompts = [value for key, value in all_prompts.items() if key.startswith(prefix)]
        # random.sample would also raise ValueError here, but without saying
        # which prefix/file came up short -- usually a typo in agent_name/task.
        if len(prompts) < k:
            raise ValueError(
                f"Requested {k} few-shot prompts with prefix {prefix!r}, "
                f"but only {len(prompts)} found in {filepath}"
            )
        return random.sample(prompts, k)

    def _load_env_prompt(self, env_name):
        """Load environment information (e.g. action list) as a prompt.

        Returns the ``"env_prompt"`` entry of ``configs/<env>/env_prompt.json``,
        resolved relative to the current working directory.
        """
        env_name = self._base_env_name(env_name)
        filepath = osp.join("configs", env_name, "env_prompt.json")
        with open(filepath, "r", encoding="utf-8") as f:
            return json.load(f)["env_prompt"]

    def _convert_to_chat(self, user: List[str], assistant: List[str]):
        """Interleave user/assistant turns into a chat message list.

        ``user[i]`` is paired with ``assistant[i]`` for each assistant turn, and
        ``user[-1]`` is appended as the final, unanswered user message, so
        ``user`` is expected to hold exactly one more entry than ``assistant``.
        """
        chat_query = []
        for i, answer in enumerate(assistant):
            chat_query.append({"role": "user", "content": user[i]})
            chat_query.append({"role": "assistant", "content": answer})
        chat_query.append({"role": "user", "content": user[-1]})
        return chat_query

    def _convert_to_completion(self, user: str):
        """Wrap ``user`` in the dataset's query/response special tokens for
        completion-style (non-chat) prompting."""
        tokens = VirtualHomeDataset.additional_special_tokens
        return f"{tokens['query']} {user}\n{tokens['response']}"

    def _call_openai(self, prompt):
        """Send ``prompt`` as a single user message to OpenAI model ``self.model``
        and return the reply text (generation stops at the first newline)."""
        client = getattr(self, "_openai_client", None)
        if client is None:
            # Build the client once per agent instead of once per call.
            client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
            self._openai_client = client
        completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model,
            temperature=0,
            max_tokens=80,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop="\n",
        )
        return completion.choices[0].message.content

    def _call_model(self, prompt, instruct: bool = False):
        """Generate a greedy response to ``prompt`` (at most 80 new tokens).

        Args:
            prompt: prompt text.
            instruct: local pipeline only -- send the prompt as one user chat
                message and return the assistant reply; otherwise treat the
                output as a raw-text continuation.

        Returns:
            OpenAI model: the message content. Local, non-instruct: the first
            line of the stripped continuation after the echoed prompt. Local,
            instruct: the content of the reply message.
        """
        if isinstance(self.model, str):
            return self._call_openai(prompt)

        query = self._prepare_query(prompt, instruct=instruct)
        generated = self.model(
            query,
            do_sample=False,
            top_k=None,
            top_p=None,
            temperature=None,
            max_new_tokens=80,
            pad_token_id=self.model.tokenizer.eos_token_id,
            num_beams=1,
            repetition_penalty=1.0,
        )[0]["generated_text"]

        if instruct:
            # Chat input comes back as the conversation with the reply appended
            # as the last message.
            return generated[-1]["content"]
        # Text output echoes the prompt first; keep only the first line of
        # what follows it.
        continuation = generated[len(query):]
        return continuation.strip().split("\n")[0]

    def _prepare_query(self, query, instruct: bool = False):
        """Return ``query`` as a one-message user chat when ``instruct`` is set,
        otherwise unchanged."""
        if instruct:
            return [{"role": "user", "content": query}]
        return query

    def _prepare_model(self, model, tokenizer=None):
        """Return ``model`` unchanged if it names an OpenAI API model (see
        ``api_models``); otherwise load it as an 8-bit quantized
        ``text-generation`` pipeline placed with ``device_map="auto"``."""
        if model in self.api_models:
            return model

        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            model_kwargs={
                # "torch_dtype": torch.bfloat16,
                "quantization_config": {"load_in_8bit": True},
            },
            device_map="auto",
        )
