import os
import openai
from utils.num_tokens_from_messages import num_tokens_from_messages
from utils.interact_with_local_llm import send_message_to_vicuna, send_message_to_llama
from arguments import *
import time
import random
import requests
from utils.scorer.score_functions import get_bert_score as get_score
from typing import List
from numpy.typing import NDArray

# The OpenAI key file is looked up in the directory the process was started
# from (the current working directory), not next to this source file.
cwd = os.getcwd()
openai_keys_file = os.path.join(cwd, "openai_keys.txt")


# Token budget per supported model name. LLMAgent.restrict_dialogue looks the
# active model up here, so a model missing from this table cannot be queried.
TOKEN_LIMIT_TABLE = {
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    "gpt-3.5-turbo-0301": 4096,
    "gpt-3.5-turbo": 4096,
    "gpt-35-turbo": 4096,
    "text-davinci-003": 4080,
    "code-davinci-002": 8001,
    "text-davinci-002": 2048,
    "vicuna-33B": 2048,
    "Llama2-70B-chat": 2048,
    "gpt-35-turbo-16k": 16384,
    "Llama2-13B-chat": 2048,
    "vicuna-13B": 2048,
    "Llama2-33B-chat": 2048,
    "Llama2-7B-chat": 2048,
}


class LLMAgent:
    """
    This agent uses LLM to generate actions.

    The conversation is kept in ``self.dialogue`` as a list of
    ``{"role": ..., "content": ...}`` dicts. Each call to ``communicate``
    appends the user message, queries the configured backend (OpenAI chat
    models, OpenAI completion models, or a local Vicuna/Llama server) and
    appends the assistant reply.
    """

    # Models served through the OpenAI chat-completions endpoint. ``query`` and
    # ``parse_response`` share this tuple so they can never disagree about how
    # a model is called and how its reply is decoded.
    _CHAT_MODELS = (
        'gpt-3.5-turbo-0301', 'gpt-3.5-turbo', 'gpt-35-turbo', 'gpt-35-turbo-16k',
        'gpt-4', 'gpt-4-0314', 'gpt-4-32k', 'gpt-4-32k-0314',
    )
    # Models served by the local helpers imported at the top of the file.
    _VICUNA_MODELS = ('vicuna-33B', 'vicuna-13B')
    _LLAMA_MODELS = ('Llama2-70B-chat', 'Llama2-13B-chat', 'Llama2-33B-chat', 'Llama2-7B-chat')

    def __init__(self, model = "gpt-3.5-turbo", scorer=get_score):
        """
        Args:
            model: Model name; must be a key of ``TOKEN_LIMIT_TABLE`` for
                ``communicate`` to work.
            scorer: Callable invoked as ``scorer(message, ref)`` by
                ``get_Semantic_Similarity``.
        """
        self.model = model
        self.dialogue = []
        self.agent_index = None
        self.message = ''

        self.openai_api_keys = self.load_openai_keys()

        self.scorer = scorer

    def set_scorer(self, scorer):
        """Replace the similarity scorer used by ``get_Semantic_Similarity``."""
        self.scorer = scorer

    @staticmethod
    def load_openai_keys():
        """
        Read the API keys from ``openai_keys_file``, one key per line.

        Surrounding whitespace (including ``\\r`` from CRLF files) is stripped
        and blank lines are skipped, so a trailing newline does not produce an
        empty key.

        Returns:
            List of non-empty key strings (possibly empty).

        Raises:
            OSError: If the key file cannot be opened.
        """
        with open(openai_keys_file, "r") as f:
            stripped = (line.strip() for line in f)
            return [key for key in stripped if key]

    def update_key(self):
        """
        Activate the next API key and move it to the back of the rotation.

        Does nothing when no keys were loaded, leaving ``openai.api_key``
        untouched (the local-model backends do not use it).
        """
        if not self.openai_api_keys:
            return
        curr_key = self.openai_api_keys.pop(0)
        openai.api_key = curr_key
        self.openai_api_keys.append(curr_key)

    def query(self, stop=None, temperature=0.0, top_p=0.95):
        """
        Send the current dialogue to the backend selected by ``self.model``.

        The dialogue is first trimmed to the model's token limit and the API
        key is rotated.

        Args:
            stop: Stop sequence(s); only forwarded to the completion endpoint.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling parameter.

        Returns:
            The raw backend response; decode it with ``parse_response``.
        """
        self.restrict_dialogue()
        self.update_key()

        if self.model in self._CHAT_MODELS:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=self.dialogue,
                temperature=temperature,
                top_p=top_p
            )
        elif self.model in self._VICUNA_MODELS:
            local_config = {'temperature': temperature, 'top_p': top_p, 'repetition_penalty': 1.1}
            response = send_message_to_vicuna(self.dialogue, local_config)

        elif self.model in self._LLAMA_MODELS:
            local_config = {'temperature': temperature, 'top_p': top_p, 'repetition_penalty': 1.1}
            response = send_message_to_llama(self.dialogue, local_config)

        else:
            # Plain completion models (e.g. text-davinci-003, code-davinci-002):
            # the whole dialogue list is stringified and sent as the prompt.
            response = openai.Completion.create(
                model=self.model,
                prompt=str(self.dialogue),
                max_tokens=1024,
                stop=stop,
                temperature=temperature,
                n=1,
                top_p=top_p
            )

        return response

    def parse_response(self, response):
        """
        Convert a raw backend response into a ``{"role", "content"}`` message.

        Args:
            response: Value returned by ``query`` for the current model.

        Returns:
            A dict suitable for appending to ``self.dialogue``.
        """
        if self.model in self._CHAT_MODELS:
            return dict(response["choices"][0]["message"])

        if self.model in self._VICUNA_MODELS or self.model in self._LLAMA_MODELS:
            # The local helpers' return value is used directly as the content.
            return {'role': 'assistant', 'content': response}

        # Completion models, e.g. 'text-davinci-003', 'code-davinci-002'.
        # NOTE(review): [2:] drops the first two characters of the completion,
        # presumably the leading "\n\n" these models tend to emit - confirm.
        return {'role': 'assistant', 'content': response["choices"][0]["text"][2:]}

    def restrict_dialogue(self):
        """
        Trim ``self.dialogue`` until it fits the model's token limit.

        While the token count is at or above the limit, the OLDEST messages
        are removed, up to two at a time. A leading ``system`` message, if
        present, and the most recent message are always kept, so trimming
        stops (possibly still over the limit) once nothing else can be
        removed.

        Raises:
            KeyError: If ``self.model`` is not in ``TOKEN_LIMIT_TABLE``.
        """
        limit = TOKEN_LIMIT_TABLE[self.model]

        # Index of the first message that may be dropped.
        first = 0
        if self.dialogue and self.dialogue[0].get("role") == "system":
            first = 1

        # TODO validate that the messages removed are obs and actions
        while num_tokens_from_messages(self.dialogue) >= limit:
            # Never drop the newest message: it is the one being answered.
            removable = len(self.dialogue) - first - 1
            if removable <= 0:
                break
            del self.dialogue[first:first + min(2, removable)]

    def save_dialogues(self, save_path):
        """Append the string form of the current dialogue as one line to ``save_path``."""
        # UTF-8 so non-ASCII model output cannot fail on platforms whose
        # default encoding is narrower.
        with open(save_path, "a", encoding="utf-8") as f:
            f.write(str(self.dialogue) + '\n')

    def communicate(self, content, parse_choice_tag=False, max_retries=None, retry_delay=1.0):
        """
        Send a user message and return the assistant's reply text.

        Failed attempts are retried with an exponentially growing pause
        (capped at 60 seconds); every retry also rotates to the next API key
        via ``query``.

        Args:
            content: Text of the user message.
            parse_choice_tag: Unused; kept so existing callers keep working.
            max_retries: Maximum number of retries after the first failure;
                ``None`` (default) retries indefinitely.
            retry_delay: Base pause in seconds before the first retry; use 0
                to retry immediately.

        Returns:
            The ``content`` field of the assistant message.

        Raises:
            ValueError: If ``self.model`` has no entry in ``TOKEN_LIMIT_TABLE``
                (retrying could never succeed).
            Exception: The last backend error once ``max_retries`` is exhausted.
        """
        if self.model not in TOKEN_LIMIT_TABLE:
            raise ValueError(f"Unsupported model: {self.model!r}")

        self.dialogue.append({"role": "user", "content": content})
        attempt = 0
        while True:
            try:
                raw_response = self.query()
                message = self.parse_response(raw_response)
                response = message["content"]
            except Exception as e:
                print(e)
                attempt += 1
                if max_retries is not None and attempt > max_retries:
                    raise
                print("retrying...")
                # Exponent is capped so the pause never overflows or exceeds 60s.
                time.sleep(min(retry_delay * 2 ** min(attempt - 1, 6), 60.0))
                continue

            # Only record the reply once it is known to be well-formed, so a
            # failed attempt never leaves a stray assistant message behind.
            self.message = message
            self.dialogue.append(message)
            return response

    def reset(self):
        """Clear the conversation state and reload the API keys from disk."""
        self.dialogue = []
        self.agent_index = None
        self.message = ''

        self.openai_api_keys = self.load_openai_keys()

    def get_Semantic_Similarity(self, message: List[str], ref: List[str]) -> List[NDArray]:
        """Return ``self.scorer(message, ref)``, the similarity of ``message`` against ``ref``."""
        return self.scorer(message, ref)
