import asyncio
import logging
import time
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import openai

openai.util.logger.setLevel(logging.WARNING)  # type: ignore

import tiktoken

from lm_understanding.prompting.completion import Completer, Completion

# Model names treated as chat models; these are the same names that key the
# ChatCompleter price tables below.
CHATGPT_MODELS = ['gpt-4', 'gpt-3.5-turbo']

# A chat prompt: an ordered list of message dicts, each carrying a "role" and
# a "content" string (see `message`).
ChatPrompt = List[Dict[str, str]]


def message(role: str, content: str):
    """Build one chat message as a dict with ``role`` and ``content`` keys.

    Args:
        role: Speaker of the message (this module uses "system" and "user").
        content: Text of the message.

    Returns:
        ``{"role": role, "content": content}``, with keys in that order.
    """
    chat_message = dict(role=role, content=content)
    return chat_message


def to_messages(prompt: str) -> ChatPrompt:
    """Wrap a plain-text prompt into a two-message chat prompt.

    The result starts with a fixed system message ("You are a helpful
    assistant.") followed by ``prompt`` as the user message.

    Args:
        prompt: Text to send as the user turn.

    Returns:
        A list ``[system_message, user_message]``.
    """
    system_turn = message("system", "You are a helpful assistant.")
    user_turn = message("user", prompt)
    return [system_turn, user_turn]


class GPT3(Completer):
    """Completer backed by ``openai.Completion.create`` (text-completion API).

    All prompts passed to one call are sent in a single batched request.
    Total token usage reported by the API is accumulated in ``tokens_used``
    so that ``info`` can estimate the money spent.

    In multiple-choice mode the request is constrained to a single token and
    the first tokens of "Yes" and "No" receive a +100 logit bias.
    """

    # Price in dollars per token, keyed by model name.
    dollars_per_token = {
        'text-davinci-001': .02 / 1000,
        'text-davinci-002': .02 / 1000,
        'text-davinci-003': .02 / 1000,
        'gpt-3.5-turbo': .002 / 1000,
        'text-curie-001': .002 / 1000,
        'text-babbage-001': .0005 / 1000,
        # Ada is the cheapest tier ($0.0004 / 1K tokens); it was previously
        # listed at the davinci price, overstating spend 50x.
        'text-ada-001': .0004 / 1000,
    }

    def __init__(self, model_name='text-davinci-003', multiple_choice: bool = False, temperature: float = 0.0, max_tokens: int = 1024) -> None:
        """Store request parameters and load the tokenizer for ``model_name``.

        Args:
            model_name: Model to query; also selects the tiktoken encoding.
            multiple_choice: If true, completions are restricted to Yes/No.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.
        """
        self.completion_params: Dict[str, Any] = dict(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            logprobs=1,
        )
        self.tokens_used = 0
        self._multiple_choice = multiple_choice
        self.enc = tiktoken.encoding_for_model(model_name)

    @property
    def multiple_choice(self) -> bool:
        """Whether completions are restricted to the Yes/No options."""
        return self._multiple_choice

    @classmethod
    def model_names(cls):
        """Return the model names that have an entry in the price table."""
        return list(cls.dollars_per_token.keys())

    @property
    def name(self) -> str:
        """Name of the model being queried."""
        return self.completion_params['model']

    @property
    def info(self) -> Dict[str, Any]:
        """Request parameters plus usage statistics.

        ``dollars_spent`` is ``tokens_used`` times the model's per-token
        price. Raises ``KeyError`` if the model has no price entry.
        """
        return dict(
            **self.completion_params,
            multiple_choice=self.multiple_choice,
            tokens_used=self.tokens_used,
            dollars_spent=self.dollars_per_token[self.completion_params['model']] * self.tokens_used
        )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}: ({str(self.completion_params)})'

    def __call__(self, prompts: List[str]) -> List[Completion]:
        """Complete every prompt in one batched API request.

        Args:
            prompts: Prompts to complete.

        Returns:
            One ``Completion`` per prompt, pairing prompts with the returned
            choices in order. Each carries the text, the request parameters
            and finish reason as metadata, and the returned log-probabilities.
            An empty ``prompts`` list yields ``[]`` without calling the API.
        """
        if not prompts:
            # Nothing to complete; avoid sending a request with an empty batch.
            return []
        options = ['Yes', 'No'] if self.multiple_choice else None
        # Work on a copy so per-call overrides never leak into the defaults.
        completion_params = deepcopy(self.completion_params)
        if options is not None:
            # Strongly bias the first token of each option and ask for a
            # single token, with log-probs for as many tokens as options.
            completion_params['logit_bias'] = {self.enc.encode(o)[0]: 100 for o in options}
            completion_params['logprobs'] = len(options)
            completion_params['max_tokens'] = 1
        response = openai.Completion.create(
            prompt=prompts,
            **completion_params
        )
        self.tokens_used += response['usage']['total_tokens']  # type: ignore
        return [
            Completion(
                prompt,
                choice['text'],
                dict(
                    completion_params=completion_params,
                    finish_reason=choice['finish_reason']
                ),
                # Entries of top_logprobs may be None; those are dropped.
                logprobs=[dict(lp) for lp in choice['logprobs']['top_logprobs'] if lp is not None],
                token_logprobs=choice['logprobs']['token_logprobs'],
                tokens=choice['logprobs']['tokens'],
            ) for prompt, choice in zip(prompts, response['choices'])  # type: ignore
        ]
   

class ChatCompleter(Completer):
    """Asynchronous, rate-limited completer for ``openai.ChatCompletion``.

    Prompts are dispatched as concurrent asyncio tasks. Dispatch is throttled
    by a requests-per-minute window, an estimated tokens-per-minute window,
    and a semaphore that caps the number of in-flight requests. Failed or
    timed-out requests are retried with exponential back-off; a request that
    exhausts its tries yields placeholder completions with the text
    ``'invalid_completion'``.
    """

    # Dollars per prompt token, keyed by model name.
    dollars_per_token_prompt = {
        'gpt-3.5-turbo': .002 / 1000,
        'gpt-4': .03 / 1000,
    }
    # Dollars per completion token, keyed by model name.
    dollars_per_token_completion = {
        'gpt-3.5-turbo': .002 / 1000,
        'gpt-4': .06 / 1000,
    }

    def __init__(
                self,
                model_name: str = 'gpt-3.5-turbo',
                multiple_choice: bool = False,
                temperature: float = 0,
                max_tokens: int = 1024,
                n: int = 1,
                log_fn: Optional[Callable[[List[str]], None]] = None,
                verbose: bool = False,
                max_concurrent_tasks: int = 10,
                timeout_seconds: int = 600,
                max_tries: int = 4
            ) -> None:
        """Configure the completer.

        Args:
            model_name: Chat model to query; also selects the tiktoken encoding.
            multiple_choice: If true, completions are restricted to Yes/No.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.
            n: Number of completions requested per prompt.
            log_fn: Optional callback invoked with all completion texts
                collected so far each time a response is parsed.
            verbose: Print progress and completion texts.
            max_concurrent_tasks: Maximum number of in-flight requests.
            timeout_seconds: Per-attempt timeout for a request.
            max_tries: Attempts per request before giving up.
        """
        self.model_name = model_name
        self._multiple_choice = multiple_choice
        self.verbose = verbose
        self.completion_params: Dict[str, Any] = dict(
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            n=n
        )
        self.enc = tiktoken.encoding_for_model(model_name)
        # Base of the exponential back-off: pause = 3 * 4 ** try_no seconds.
        self.retry_pause_seconds = 3
        self.max_tries = max_tries
        self.log_fn = log_fn
        # Requests-per-minute and tokens-per-minute budgets for the limiter.
        self.rpm = 100 if model_name == 'gpt-4' else 3500
        self.tpm = 40000 if model_name == 'gpt-4' else 90000
        self.timeout_seconds = timeout_seconds
        self.max_concurrent_tasks = max_concurrent_tasks
        self._reset()

    def _reset(self):
        """Initialise usage counters, collected completions and limiter state."""
        self.prompt_tokens_used = 0
        self.completion_tokens_used = 0
        self.last_called_time = 0
        self.completions = []
        # Timestamps of the most recent `rpm` requests.
        self.request_time_queue = deque(maxlen=self.rpm)
        # (timestamp, estimated tokens) per recorded request.
        self.token_queue = []
        # Progress counters; complete_multiple() resets them per batch.
        # Initialising them here keeps parse_response() and async_complete()
        # usable on their own instead of failing with AttributeError.
        self.completed = 0
        self.requests = 0
        self.total_prompts = 0

    @property
    def name(self) -> str:
        """Name of the model being queried."""
        return self.model_name

    @property
    def multiple_choice(self) -> bool:
        """Whether completions are restricted to the Yes/No options."""
        return self._multiple_choice

    @property
    def next_ready_time(self) -> float:
        """Earliest wall-clock time at which another request may be sent.

        Reading this property also prunes ``token_queue`` entries older than
        60 seconds.
        """
        if len(self.request_time_queue) < self.rpm:
            request_ready = time.time()
        else:
            # Window is full: wait until the oldest tracked request is 60s old.
            request_ready = self.request_time_queue[0] + 60
        self.token_queue = [(added_time, tokens) for added_time, tokens in self.token_queue if added_time > time.time() - 60]
        tokens_in_the_last_minute = sum([t for _, t in self.token_queue])
        # The mean of recent requests stands in for the next request's size.
        mean_tokens = np.array([t for _, t in self.token_queue]).mean() if self.token_queue else 0
        if (tokens_in_the_last_minute + mean_tokens) > self.tpm:
            token_ready = self.token_queue[0][0] + 60
        else:
            token_ready = request_ready
        return max(request_ready, token_ready)

    async def rate_limiter(self):
        """Sleep until ``next_ready_time`` (returns at once if already past)."""
        wait_time = max(self.next_ready_time - time.time(), 0)
        await asyncio.sleep(wait_time)

    def tokens(self, prompt: ChatPrompt) -> int:
        """Estimate the token cost of a request for rate limiting.

        The estimate is (tokens in the stringified messages + ``max_tokens``)
        multiplied by ``n``.
        """
        tokenized = self.enc.encode(' '.join(str(m) for m in prompt))
        return (len(tokenized) + self.completion_params['max_tokens']) * self.completion_params['n']

    @property
    def info(self) -> Dict[str, Any]:
        """Request parameters plus usage statistics.

        ``dollars_spent`` prices prompt and completion tokens separately.
        Raises ``KeyError`` if the model has no entry in the price tables.
        """
        return dict(
            **self.completion_params,
            multiple_choice=self.multiple_choice,
            prompt_tokens_used=self.prompt_tokens_used,
            completion_tokens_used=self.completion_tokens_used,
            dollars_spent=self.dollars_per_token_prompt[self.model_name] * self.prompt_tokens_used + self.dollars_per_token_completion[self.model_name] * self.completion_tokens_used
        )

    def parse_response(self, response) -> List[str]:
        """Record usage from an API response and return its completion texts.

        Updates the completed-request counter and token usage, appends the
        texts to ``self.completions`` and, if set, calls ``log_fn`` with the
        full list collected so far.
        """
        self.completed += 1
        if self.verbose:
            print('Number of Completions:', self.completed)
        self.prompt_tokens_used += response['usage']['prompt_tokens']  # type: ignore
        self.completion_tokens_used += response['usage']['completion_tokens']  # type: ignore
        choices = response['choices']  # type: ignore
        completions = [choice['message']['content'] for choice in choices]
        if self.verbose:
            for c in completions:
                print(c)
        self.completions.extend(completions)
        if self.log_fn is not None:
            self.log_fn(self.completions)
        return completions

    async def async_complete(self, messages: ChatPrompt, options: Optional[List[str]] = None) -> List[Completion]:
        """Request completions for one chat prompt, retrying on failure.

        Args:
            messages: The chat prompt to send.
            options: If given, the first token of each option gets a +100
                logit bias and the completion is limited to one token.

        Returns:
            ``n`` completions. If every attempt fails, ``n`` placeholder
            completions with the text ``'invalid_completion'``.
        """
        # Work on a copy so per-call overrides never leak into the defaults.
        completion_params = deepcopy(self.completion_params)
        if options is not None:
            completion_params['logit_bias'] = {self.enc.encode(o)[0]: 100 for o in options}
            completion_params['max_tokens'] = 1

        request_no = self.requests
        self.requests += 1
        if self.verbose:
            print('Requests Made:', self.requests)
        for try_no in range(self.max_tries):
            if try_no > 0:
                # complete_multiple() records the first attempt before
                # scheduling this coroutine; retries record themselves.
                self.request_time_queue.append(time.time())
                self.token_queue.append((time.time(), self.tokens(messages)))
            try:
                response = await asyncio.wait_for(openai.ChatCompletion.acreate(
                    messages=messages,
                    **completion_params
                ), timeout=self.timeout_seconds)
                if self.verbose:
                    print(f'Request {request_no} completed / {self.total_prompts} Total Prompts')
                texts = self.parse_response(response)
                completions = [Completion(str(messages), t) for t in texts]
                if options:
                    # The chat API call here returns no log-probs, so
                    # synthesise them: 0 for the chosen option, -100 otherwise.
                    for completion in completions:
                        completion.logprobs = [{o: 0 if o == completion.text else -100 for o in options}]
                return completions
            except asyncio.TimeoutError:
                print(f"Request {request_no} timed out (try {try_no + 1}).")
            except Exception as e:
                # Best effort: report the failure and fall through to retry.
                print(e)
            if try_no + 1 >= self.max_tries:
                # No attempts left: give up now rather than sleeping through a
                # back-off period for a retry that will never happen.
                break
            retry_time = self.retry_pause_seconds * (4 ** try_no)
            print(f'Request {request_no} retrying in {retry_time} seconds (try {try_no + 1} / {self.max_tries})')
            await asyncio.sleep(retry_time)
            await self.rate_limiter()
        print(f'Max retries attempted, giving up on request {request_no}')
        return [Completion(str(messages), 'invalid_completion') for _ in range(completion_params['n'])]

    async def complete_multiple(self, prompts: List[ChatPrompt], options: Optional[List[str]] = None) -> List[List[Completion]]:
        """Complete many chat prompts concurrently under the rate limits.

        Args:
            prompts: Chat prompts to complete.
            options: Forwarded to ``async_complete`` for every prompt.

        Returns:
            One list of completions per prompt, in the order of ``prompts``.
        """
        self.completed = 0
        self.requests = 0
        self.total_prompts = len(prompts)
        semaphore = asyncio.Semaphore(self.max_concurrent_tasks)

        async def complete_and_release_semaphore(prompt, options):
            # The semaphore caps how many requests are in flight at once.
            async with semaphore:
                return await self.async_complete(prompt, options)

        tasks = []
        for prompt in prompts:
            # Pace task creation, then register the request with the limiter.
            await self.rate_limiter()
            self.request_time_queue.append(time.time())
            self.token_queue.append((time.time(), self.tokens(prompt)))
            tasks.append(asyncio.create_task(complete_and_release_semaphore(prompt, options)))
        return await asyncio.gather(*tasks)

    def __call__(self, prompts: List[Union[ChatPrompt, str]], **kwargs) -> List[Completion]:
        """Complete a batch of prompts synchronously.

        String prompts are wrapped with ``to_messages``; chat prompts are
        sent as given. Runs its own event loop through ``asyncio.run``.

        Args:
            prompts: Plain-text or chat prompts.
            **kwargs: Accepted but not used.

        Returns:
            All completions flattened into one list, grouped by prompt in
            input order (``n`` completions per prompt).
        """
        options = None if not self.multiple_choice else ['Yes', 'No']
        messages: List[ChatPrompt] = [to_messages(p) if isinstance(p, str) else p for p in prompts]
        completions: List[List[Completion]] = asyncio.run(self.complete_multiple(messages, options))
        return [c for cl in completions for c in cl]