
import logging
import os
import time

import requests as _requests
from efficiency_benchmark.dependencies.lm_eval.base import BaseLM
from tqdm import tqdm

logger = logging.getLogger(__name__)


def textsynth_completion(**kwargs):
    """POST a request, retrying with a growing delay until it goes through.

    All keyword arguments are forwarded unchanged to ``requests.post``.
    Whenever the call raises a ``requests`` ``RequestException`` the traceback
    is printed, the function sleeps for the current delay, and the delay is
    multiplied by 1.5 before the next attempt. There is no retry limit, so
    this only returns once a request completes without raising.

    Returns:
        The value returned by ``requests.post`` for the first attempt that
        does not raise.
    """
    delay = 3  # seconds slept after the first failure
    while True:
        try:
            response = _requests.post(**kwargs)
        except _requests.exceptions.RequestException:
            # Imported lazily: only needed on the failure path.
            import traceback

            traceback.print_exc()
            time.sleep(delay)
            delay = delay * 1.5
        else:
            return response


class TextSynthLM(BaseLM):
    """``BaseLM`` implementation backed by the TextSynth HTTP API.

    Only ``loglikelihood`` and ``greedy_until`` are functional; they talk to
    the remote ``/logprob`` and ``/completions`` endpoints respectively. The
    tokenizer- and device-related members, as well as ``_model_call`` and
    ``_model_generate``, raise ``NotImplementedError``.
    """

    def __init__(self, engine, truncate=False):
        """
        :param engine: str
            Name of the TextSynth engine; it is inserted into the request URL
            as ``/v1/engines/<engine>/...``.
        :param truncate: bool
            Stored on the instance. Nothing in this class reads it.
            NOTE(review): presumably meant to control truncation of
            over-long inputs -- confirm against callers.
        :raises KeyError:
            If the ``TEXTSYNTH_API_SECRET_KEY`` environment variable is unset.
        """
        super().__init__()

        self.engine = engine
        self.truncate = truncate
        self.api_url = "https://api.textsynth.com"
        # The API key is read from the environment so it never lives in code.
        self.api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        """Not available: this class has no tokenizer access."""
        raise NotImplementedError()

    @property
    def max_length(self):
        """Hard-coded maximum sequence length (2048)."""
        return 2048

    @property
    def max_gen_toks(self):
        """Value sent as ``max_tokens`` for every completion request (256)."""
        return 256

    @property
    def batch_size(self):
        """Not implemented; requests in this class are sent one at a time."""
        raise NotImplementedError()

    @property
    def device(self):
        """Not implemented; all computation happens behind the remote API."""
        raise NotImplementedError()

    def tok_encode(self, string: str):
        """Not implemented for this backend."""
        raise NotImplementedError()

    def tok_decode(self, tokens):
        """Not implemented for this backend."""
        raise NotImplementedError()

    def _post(self, endpoint, payload):
        """POST ``payload`` to ``endpoint`` of the configured engine; return the decoded JSON body."""
        response = textsynth_completion(
            url=self.api_url + "/v1/engines/" + self.engine + "/" + endpoint,
            headers={"Authorization": "Bearer " + self.api_key},
            json=payload,
        )
        return response.json()

    def loglikelihood(self, requests):
        """Score each ``(context, continuation)`` pair via the ``/logprob`` endpoint.

        :param requests:
            Iterable of ``(context, continuation)`` pairs.
        :return:
            List of ``(logprob, is_greedy)`` tuples, one per request and in
            request order, taken from the API response.
        :raises AssertionError:
            If a response lacks the ``logprob`` field. The offending response
            is logged first.
        """
        res = []
        for context, continuation in tqdm(requests):
            resp = self._post(
                "logprob",
                {"context": context, "continuation": continuation},
            )
            if "logprob" not in resp:
                message = f"The following response does not contain `logprob`. Got:\n{resp}"
                logger.error(message)
                # Raised explicitly rather than via `assert False`, which is
                # stripped under `python -O` and would silently drop the entry.
                raise AssertionError(message)
            res.append((resp["logprob"], resp["is_greedy"]))
        return res

    def loglikelihood_rolling(self, requests):
        """Unsupported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "`loglikelihood_rolling` is currently not supported due to lack of "
            "input tokenization support from TextSynth."
        )

    def greedy_until(self, requests):
        """Generate a completion for every request via the ``/completions`` endpoint.

        :param requests:
            Sequence whose items are indexable; item ``[0]`` is sent as the
            prompt and item ``[1]`` as the ``stop`` value.
        :return:
            List of generated strings (the ``text`` field of each response),
            in request order. An empty/falsy ``requests`` yields ``[]``.
        :raises AssertionError:
            If a response lacks the ``text`` field. The offending response is
            logged first.
        """
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            until = request[1]
            resp = self._post(
                "completions",
                {
                    "prompt": inp,
                    "max_tokens": self.max_gen_toks,
                    # NOTE(review): top_k=1 presumably yields greedy decoding.
                    "top_k": 1,
                    "stop": until,
                },
            )
            if "text" not in resp:
                # Single f-string: previously the second fragment lacked the
                # `f` prefix, so the literal text "{resp}" was logged.
                message = f"The following response does not contain generated `text`. Got:\n{resp}"
                logger.error(message)
                # Explicit raise survives `python -O`, unlike `assert False`.
                raise AssertionError(message)
            res.append(resp["text"])
        return res

    def _model_call(self, inps):
        """Not implemented; the public request methods call the API directly."""
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        """Not implemented; the public request methods call the API directly."""
        raise NotImplementedError()
