import json
import logging
import re
import time
import traceback
from string import Template

import backoff
import openai

from .base import LMAgent
from .datatypes import Action
from .openai_api import OpenAIClient
from .prompt_template.Template import QueryGeneratorTemplate

# Module-level logger. "Root" is an ordinary named logger, not logging's
# actual root logger (that one is returned by getLogger() with no argument).
LOGGER = logging.getLogger(name="Root")


class OpenAIMultiTurnLMAgent(LMAgent):
    """Chat-model agent that decides whether retrieval is needed and proposes queries.

    ``act`` renders a ``string.Template`` prompt, sends it to the model as a
    single user message through ``call_lm``, and parses the reply with
    ``parser_results``.
    """

    def __init__(self, api_type, config):
        """Create the agent.

        Args:
            api_type: ``"openai"`` or ``"azure"``. Only ``"openai"`` builds an
                ``OpenAIClient`` here.
            config: Forwarded to ``LMAgent.__init__`` and, for ``"openai"``,
                to ``OpenAIClient``.

        Raises:
            ValueError: If ``api_type`` is not a supported value.
        """
        super().__init__(config)
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if api_type not in ("azure", "openai"):
            raise ValueError(
                f"Unsupported api_type {api_type!r}; expected 'azure' or 'openai'"
            )

        if api_type == "openai":
            self.api = OpenAIClient(config)
        # NOTE(review): no client is built for "azure". Unless LMAgent or the
        # caller sets ``self.api``, call_lm raises a RuntimeError for it.

        self.usage_profiles = []  # one usage dict appended per successful call_lm
        self.max_try = 3  # number of call_lm attempts made by act()

    @backoff.on_exception(
        backoff.fibo,
        # Retry only transient failures. In openai>=1.0 ``openai.Timeout`` is
        # ``httpx.Timeout`` (a config class, not an exception), which made the
        # except clause itself raise TypeError; and ``openai.APIError`` is the
        # base of *all* API errors, including non-retryable 4xx ones such as
        # BadRequestError (e.g. context-window overflow).
        # https://platform.openai.com/docs/guides/error-codes/python-library-error-types
        (
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.RateLimitError,
            openai.InternalServerError,
        ),
        max_tries=6,  # bound retries so a persistent outage surfaces to act()
    )
    def call_lm(self, messages):
        """Send ``messages`` to the chat API and return ``(text, usage)``.

        Args:
            messages: Chat messages, e.g. ``[{"role": "user", "content": ...}]``.

        Returns:
            A tuple ``(text, usage)``. ``text`` is the first choice's content
            with every ``"`` replaced by ``'``; ``usage`` holds
            ``promptTokens``, ``completionTokens``, ``totalTokens`` and
            ``costTimeMillis`` (wall-clock time of the API call). The usage
            dict is also appended to ``self.usage_profiles``.

        Raises:
            RuntimeError: If no API client is configured.
        """
        api = getattr(self, "api", None)
        if api is None:
            raise RuntimeError(
                "No API client configured; only api_type='openai' creates one"
            )

        start = time.perf_counter()
        # assumes chat_sync returns an openai ChatCompletion-like object — TODO confirm
        js = api.chat_sync(messages=messages)
        elapsed_ms = int((time.perf_counter() - start) * 1000)

        # The SDK types ``message.content`` and ``usage`` as Optional.
        response = js.choices[0].message.content or ""
        token_usage = js.usage

        usage = {
            "promptTokens": token_usage.prompt_tokens if token_usage else 0,
            "completionTokens": token_usage.completion_tokens if token_usage else 0,
            "totalTokens": token_usage.total_tokens if token_usage else 0,
            "costTimeMillis": elapsed_ms,
        }

        # NOTE(review): presumably normalised so the text can be embedded in a
        # double-quoted context downstream — confirm with consumers.
        final_response = response.replace('"', "'")
        self.usage_profiles.append(usage)

        return final_response, usage

    def parser_results(self, response):
        """Parse the model reply into ``(retrieval_necessity, queries)``.

        The reply is expected to contain ``Retrieval Necessity: <word>`` and a
        ``Query For Search Engine:`` marker; everything after the marker is
        read as newline-separated queries.

        Returns:
            ``(necessity, queries)`` where ``necessity`` is ``True`` only for a
            case-insensitive ``yes`` and ``queries`` holds the non-empty,
            whitespace-stripped query lines. ``(False, [])`` when either
            section is missing or no non-empty query line is present.
        """
        text = response or ""
        necessity_match = re.search(r'Retrieval Necessity:\s*(\w+)', text)
        queries_match = re.search(r'Query For Search Engine:\s*([\s\S]+)', text)

        if necessity_match is None or queries_match is None:
            return False, []

        queries = [line.strip() for line in queries_match.group(1).splitlines()]
        queries = [query for query in queries if query]
        if not queries:
            return False, []

        # Anything other than "yes" (e.g. "No") means retrieval is not needed;
        # previously such answers leaked through as a truthy string.
        retrieval_necessity = necessity_match.group(1).lower() == "yes"
        return retrieval_necessity, queries

    def act(self, template: str, **kwargs):
        """Render ``template`` with ``kwargs``, query the model and parse the reply.

        ``template`` uses ``string.Template`` ``$name`` placeholders;
        ``substitute`` raises ``KeyError`` for a missing value before any model
        call is made.

        ``call_lm`` is attempted up to ``self.max_try`` times; each failure is
        logged with its traceback before the next attempt.

        Returns:
            ``(retrieval_necessity, queries)`` from ``parser_results``, or
            ``(False, [])`` if every attempt failed.
        """
        llm_query = Template(template).substitute(kwargs)
        messages = [{"role": "user", "content": llm_query}]

        for attempt in range(1, self.max_try + 1):
            try:
                lm_output, _usage = self.call_lm(messages)
            except Exception:  # e.g. context-window overflow or exhausted retries
                LOGGER.exception(
                    "Error while calling generator agent (attempt %d/%d)",
                    attempt,
                    self.max_try,
                )
                continue
            return self.parser_results(lm_output)

        return False, []

