import math
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from tqdm import tqdm


@dataclass
class Completion:
    """A completion produced for a prompt, plus optional log-probability data."""

    # Prompt text this completion was generated for.
    prompt: str
    # Generated completion text.
    text: str
    # Free-form metadata; ``total_answer_prob`` reads the 'total_answer_prob' key.
    meta: Optional[Dict] = None
    # List of token -> log-probability mappings; ``mc_answer_probs`` reads
    # only the first entry. Presumably one mapping per generated position --
    # TODO confirm against the completers that fill this in.
    logprobs: Optional[List[Dict[str, float]]] = None
    # NOTE(review): not read anywhere in this class; presumably the
    # log-probability of each entry in ``tokens`` -- confirm with producers.
    token_logprobs: Optional[List[Union[None, float]]] = None
    # NOTE(review): not read anywhere in this class; presumably the token strings.
    tokens: Optional[List[str]] = None

    def __repr__(self):
        return f'{self.__class__.__name__}: Prompt: {self.prompt}, Completion Text: {self.text}'

    @property
    def mc_answer_probs(self) -> Optional[Dict[str, float]]:
        """Probabilities derived from the first ``logprobs`` entry.

        Each token is whitespace-stripped and its log-probability is
        exponentiated. Tokens that are equal after stripping collapse into
        one key (the later one wins).

        Returns:
            A mapping of stripped token to probability, or ``None`` when
            ``logprobs`` is ``None`` or empty.
        """
        # An empty list has no first entry; treat it like missing logprobs
        # instead of raising IndexError.
        if not self.logprobs:
            return None
        first_position = self.logprobs[0]
        return {token.strip(): math.exp(logprob) for token, logprob in first_position.items()}

    @property
    def total_answer_prob(self) -> Optional[Any]:
        """Value stored under ``'total_answer_prob'`` in ``meta``, or ``None``
        when ``meta`` is missing/empty or lacks that key."""
        if not self.meta:
            return None
        return self.meta.get('total_answer_prob')

    def to_dict(self, include_logprobs: bool = False) -> Dict[str, Any]:
        """Convert the completion to a plain dictionary.

        Args:
            include_logprobs: When true, add the raw ``logprobs`` under the
                'logprobs' key.

        Returns:
            A dict with 'prompt', 'completion' and 'meta' keys. 'answer_probs'
            is added only when ``logprobs`` holds exactly one entry.
        """
        ret: Dict[str, Any] = dict(prompt=self.prompt, completion=self.text, meta=self.meta)
        if self.logprobs and len(self.logprobs) == 1:
            ret['answer_probs'] = self.mc_answer_probs
        if include_logprobs:
            ret['logprobs'] = self.logprobs
        return ret


class Completer(ABC):
    """Abstract interface for objects that turn prompts into completions.

    Concrete subclasses must provide ``__call__`` together with the ``name``,
    ``info`` and ``multiple_choice`` properties; until all four are defined
    the class cannot be instantiated.
    """

    @abstractmethod
    def __call__(self, prompts: List[str], *args, **kwargs) -> List[Completion]:
        """Return the completions for ``prompts``.

        Extra positional and keyword arguments are implementation-defined.
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """String name of this completer."""

    @property
    @abstractmethod
    def info(self) -> Dict[str, Any]:
        """Dictionary of information about this completer; the keys are
        defined by the implementing subclass."""

    @property
    @abstractmethod
    def multiple_choice(self) -> bool:
        """Boolean flag supplied by the subclass.

        NOTE(review): presumably marks completers that answer multiple-choice
        prompts -- confirm against the implementations.
        """


class CompleterWrapper(Completer):
    """Base class for objects that wrap another ``Completer``.

    The string representation and the ``name``, ``info`` and
    ``multiple_choice`` properties are forwarded to the wrapped completer.
    ``__call__`` is not implemented here, so subclasses have to supply it.
    """

    def __init__(self, completer: Completer) -> None:
        """Store ``completer`` as the wrapped instance."""
        self.completer = completer

    def __repr__(self) -> str:
        """Return the wrapped completer's string form."""
        wrapped = self.completer
        return str(wrapped)

    @property
    def name(self) -> str:
        """Name of the wrapped completer."""
        wrapped = self.completer
        return wrapped.name

    @property
    def info(self) -> Dict[str, Any]:
        """Info dictionary of the wrapped completer."""
        wrapped = self.completer
        return wrapped.info

    @property
    def multiple_choice(self) -> bool:
        """``multiple_choice`` flag of the wrapped completer."""
        wrapped = self.completer
        return wrapped.multiple_choice


class ThrottledCompleter(CompleterWrapper):
    """Wrapper that batches prompts, spaces out calls and retries failures.

    Prompts are sent to the wrapped completer in slices of
    ``completion_batch_size``. Before each batch the wrapper waits until at
    least ``pause_seconds`` have passed since the previous batch finished.
    A batch whose call raises is retried after ``retry_pause_seconds``.
    """

    def __init__(self, completer: Completer, pause_seconds: int = 0, completion_batch_size: int = 1, retry_pause_seconds: int = 0) -> None:
        """
        Args:
            completer: The completer to wrap.
            pause_seconds: Minimum gap between the end of one batch and the
                start of the next.
            completion_batch_size: Number of prompts sent per call.
            retry_pause_seconds: Pause before retrying a call that raised.
                A falsy value disables retries (a single attempt is made).
        """
        super().__init__(completer)
        self.iteration_pause_seconds = pause_seconds
        self.retry_pause_seconds = retry_pause_seconds
        # Retries are only enabled when a retry pause has been configured.
        self.max_tries = 5 if retry_pause_seconds else 1
        # Wall-clock time (time.time()) at which the last batch finished.
        self.last_called_time = 0
        self.completion_batch_size = completion_batch_size

    def throttled_complete(self, prompts: List[str], *args, **kwargs) -> List[Completion]:
        """Complete one batch, honouring the pause and retry settings.

        Any exception raised by the wrapped completer is printed and the call
        is retried, up to ``max_tries`` attempts in total.

        Returns:
            Whatever the first successful call returned (an empty result is
            returned as-is), or an empty list when every attempt raised.
        """
        # Time still owed from the pause after the previous batch.
        wait_for_seconds = self.last_called_time + self.iteration_pause_seconds - time.time()
        completions: List[Completion] = []
        for _ in range(self.max_tries):
            if wait_for_seconds > 0:
                time.sleep(wait_for_seconds)
            try:
                completions = self.completer(prompts, *args, **kwargs)
            except Exception as e:
                # Best-effort: report the failure and fall through to the
                # next attempt (if any) after the retry pause.
                print(e)
                print(f'Waiting for {self.retry_pause_seconds} seconds')
                wait_for_seconds = self.retry_pause_seconds
                completions = []
            else:
                # A call that did not raise is final, even when it returned
                # nothing. Looping on emptiness would never end, because only
                # failed calls count towards the attempt limit.
                break
        self.last_called_time = time.time()
        return completions

    def __call__(self, prompts: List[str], *args, progress_bar: bool = False, **kwargs) -> List[Completion]:
        """Complete ``prompts`` batch by batch.

        Args:
            prompts: Prompts to complete.
            progress_bar: Show a tqdm progress bar while iterating batches.
            *args, **kwargs: Forwarded to the wrapped completer on each call.

        Returns:
            The concatenated results of all batches. A batch whose attempts
            all failed adds nothing, so the result can be shorter than
            ``prompts``.
        """
        batch_size = self.completion_batch_size
        n_batches = math.ceil(len(prompts) / batch_size)
        completions: List[Completion] = []
        # unit_scale is given the batch size so that tqdm scales the displayed
        # batch counts by it.
        for i in tqdm(range(n_batches), disable=not progress_bar, desc=f'Completing: {self.name}', unit_scale=batch_size):
            prompt_batch = prompts[i * batch_size:(i + 1) * batch_size]
            completions.extend(self.throttled_complete(prompt_batch, *args, **kwargs))
        return completions


class CompletionLogger(CompleterWrapper):
    """Wrapper that prints every completion returned by the wrapped completer."""

    def __call__(self, prompts: List[str], *args, **kwargs) -> List[Completion]:
        """Run the wrapped completer, print its completions, and return them.

        Each completion is printed via ``print`` (so its ``__repr__``/``__str__``
        is used), followed by a '---' separator line.
        """
        completions = self.completer(prompts, *args, **kwargs)
        print('---COMPLETIONS---')
        # Iterate the completions themselves rather than zipping with the
        # prompts: the prompt was never used, and zip would silently stop
        # logging if more completions than prompts came back.
        for completion in completions:
            print(completion)
            print('---')
        return completions

