import asyncio
import copy
import logging
import time
import uuid
from contextlib import nullcontext
from typing import AsyncGenerator

import openai
import tiktoken
from tenacity import (
    after_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from log import LoggingComponent, get_logger

# Logger passed to tenacity's `after_log` hook; in the visible part of this
# file it is only referenced by the commented-out retry decorator on
# OpenAIWrapper.create.
tenacity_logger = get_logger("tenacity")
# Tokenizer used by OpenAIRateLimiter.num_tokens_consumed_by_request to
# estimate the size of chat messages.
CL100K_ENCODER = tiktoken.get_encoding("cl100k_base")
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably kept for other callers -- confirm before removing.
P50K_ENCODER = tiktoken.get_encoding("p50k_base")

class AsyncAPIResourcePatcher:
    """Wrap an async API resource so processed responses expose their headers.

    The resource's client is shallow-copied and its ``_process_response``
    hook is replaced, so every processed response additionally carries three
    attributes (see :meth:`insert_headers`): ``rate_limits``, ``openai`` and
    ``headers``.
    """

    # Milliseconds per supported duration unit. "ms" must come before "m" so
    # that the longer suffix wins when both match at the same position.
    _UNIT_MS = (("ms", 1), ("h", 3_600_000), ("m", 60_000), ("s", 1_000))

    # The annotation is a string so that defining the class does not need to
    # resolve the private ``openai._resource`` module path.
    def __init__(self, resource: "openai._resource.AsyncAPIResource"):
        """Build a patched twin of ``resource`` bound to a copied client.

        Args:
            resource: The resource to wrap; its class is re-instantiated on
                the patched client copy.
        """
        # Patch a shallow copy so the hook is not installed on the client
        # object the caller handed in.
        self.client = copy.copy(resource._client)
        self.old_process_response = self.client._process_response
        self.client._process_response = self._process_response_wrapper
        self.resource = resource.__class__(self.client)

    @classmethod
    def maybe_process_duration(cls, s):
        """Convert a rate-limit header value to an integer.

        Plain integers (e.g. ``"10000"``) are returned unchanged. Durations
        are returned in whole milliseconds; both single-unit values
        (``"20ms"``, ``"1.5s"``) and compound values (``"6m0s"``,
        ``"1h2m3s"``) are accepted.

        Args:
            s: The raw header value.

        Returns:
            The integer value, or the duration in milliseconds (truncated).

        Raises:
            ValueError: If the value is neither an integer nor a duration
                built from the units ``h``, ``m``, ``s`` and ``ms``.
        """
        text = s.strip()
        if not text or not text[-1].isalpha():
            # No trailing unit: a plain counter. Malformed values still
            # raise ValueError from int().
            return int(text)

        total_ms = 0.0
        number = ""
        pos = 0
        while pos < len(text):
            char = text[pos]
            if char.isdigit() or char in "+-.":
                number += char
                pos += 1
                continue
            for unit, scale in cls._UNIT_MS:
                if text.startswith(unit, pos):
                    break
            else:
                raise ValueError(f"unrecognised duration: {s!r}")
            # float("") raises ValueError when a unit has no number before it.
            total_ms += float(number) * scale
            number = ""
            pos += len(unit)
        return int(total_ms)

    @classmethod
    def insert_headers(cls, obj, headers):
        """Attach parsed response headers to ``obj``.

        Sets three attributes on ``obj``:

        * ``rate_limits``: ``x-ratelimit-*`` headers with the prefix removed
          and values converted by :meth:`maybe_process_duration`.
        * ``openai``: ``openai-*`` headers with the prefix removed;
          ``processing-ms`` is converted to ``int`` when present.
        * ``headers``: every remaining header, unchanged.

        Args:
            obj: Object that receives the attributes.
            headers: Mapping of header name to value.
        """
        rate_prefix = "x-ratelimit-"
        openai_prefix = "openai-"

        obj.rate_limits = {
            key[len(rate_prefix):]: cls.maybe_process_duration(value)
            for key, value in headers.items()
            if key.startswith(rate_prefix)
        }
        obj.openai = {
            key[len(openai_prefix):]: value
            for key, value in headers.items()
            if key.startswith(openai_prefix)
        }
        # The header is optional from this code's point of view: convert it
        # when present instead of failing the whole response with KeyError.
        if "processing-ms" in obj.openai:
            obj.openai["processing-ms"] = int(obj.openai["processing-ms"])
        obj.headers = {
            key: value
            for key, value in headers.items()
            if not key.startswith((rate_prefix, openai_prefix))
        }

    async def _process_response_wrapper(self, *args, **kwargs):
        """Run the original response processing, then attach the headers.

        Expects the raw response to be passed as the ``response`` keyword
        argument.
        """
        processed = await self.old_process_response(*args, **kwargs)
        self.insert_headers(processed, kwargs["response"].headers)
        return processed

    async def create(self, *args, **kwargs):
        """Forward to ``create`` on the patched resource."""
        return await self.resource.create(*args, **kwargs)
        

class OpenAIRateLimiter(LoggingComponent):
    """Client-side throttle for one model's token and request rate limits.

    The limiter keeps an estimate of how much of the per-minute token and
    request budgets is in use. The estimate is refreshed from the
    ``rate_limits`` attribute of every response (see :meth:`update`) and
    decays linearly between responses. :meth:`create` delays a request until
    sending it would keep the estimated load under ``loadfactor * limit``.
    """

    def __init__(
        self,
        resource,
        name="oai-ratelimit",
        token_limit=None,
        request_limit=None,
        token_loadfactor=0.9,
        request_loadfactor=0.9,
        token_margin=0.2,
    ):
        """Create a limiter that sends requests through ``resource``.

        Args:
            resource: Object whose awaitable ``create(**kwargs)`` performs
                the request; results must expose ``rate_limits`` and
                ``model``.
            name: Name handed to ``LoggingComponent``.
            token_limit: Tokens allowed per minute, if already known.
            request_limit: Requests allowed per minute, if already known.
            token_loadfactor: Fraction of the token limit that may be used.
            request_loadfactor: Fraction of the request limit that may be
                used.
            token_margin: Relative safety margin added to token estimates.
        """
        LoggingComponent.__init__(self, name)
        self.resource = resource

        # Parameters
        self.token_loadfactor = token_loadfactor
        self.request_loadfactor = request_loadfactor
        self.token_margin = token_margin

        # Internal state
        self.pending_lock = asyncio.Lock()
        self.return_event = asyncio.Event()
        self.tokens_inflight = 0
        self.requests_inflight = 0

        # External state
        self.token_limit = token_limit
        self.request_limit = request_limit
        self.last_update = None
        self.tokens_used = 0
        self.requests_used = 0

    @classmethod
    def num_tokens_consumed_by_request(
        cls, max_tokens, messages=None, n=1, margin=0.2, **_
    ):
        """Estimate how many tokens a request counts against the limit.

        Args:
            max_tokens: Completion budget of the request.
            messages: Chat messages (dicts), or ``None`` for requests
                without messages.
            n: Number of completions requested.
            margin: Relative safety margin applied to the chat estimate.
            **_: Remaining request parameters; ignored.

        Returns:
            The estimated token count. Without ``messages`` this is
            ``n * max_tokens + 1``.
        """
        if messages is None:
            # TODO: Why is this true?
            return n * max_tokens + 1

        # TODO: There seems to be a bug in the way OpenAI counts tokens
        # compare with, e.g.
        # https://github.com/shobrook/openlimit/blob/master/openlimit/utilities/token_counters.py
        num_tokens = max_tokens
        for message in messages:
            # Every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 4
            for key, value in message.items():
                # Only text can be tokenized here; non-string values (for
                # example ``content=None``) are skipped instead of crashing
                # the encoder, so they are not counted.
                if isinstance(value, str):
                    num_tokens += len(CL100K_ENCODER.encode(value))

                if key == "name":  # If there's a name, the role is omitted
                    num_tokens -= 1

        num_tokens += 2  # Every reply is primed with <im_start>assistant

        return n * num_tokens * (1 + margin)

    def estimate_deltas(self):
        """Return the budget regained since the last update.

        Returns:
            ``(delta_tokens, delta_requests)``, or ``(None, None)`` when no
            response has been seen yet. Limits are per minute, i.e. per
            6e10 ns of monotonic time.
        """
        if self.last_update is None:
            return None, None

        delta_ns = time.monotonic_ns() - self.last_update
        delta_tokens = delta_ns * self.token_limit / 6e10
        delta_requests = delta_ns * self.request_limit / 6e10
        return delta_tokens, delta_requests

    def estimate_current_load(self):
        """Return the estimated ``(token_load, request_load)`` right now.

        Usage recorded at the last update is reduced by the budget regained
        since then (never below zero); in-flight reservations are added on
        top.
        """
        delta_tokens, delta_requests = self.estimate_deltas()
        if delta_tokens is None:
            # No response seen yet, so nothing has been regained.
            delta_tokens = delta_requests = 0
        return (
            max(self.tokens_used - delta_tokens, 0) + self.tokens_inflight,
            max(self.requests_used - delta_requests, 0) + self.requests_inflight,
        )

    def estimate_delay(self, tokens):
        """Estimate the seconds until a ``tokens``-sized request would fit.

        The estimate starts from the decayed current load, so time that has
        already passed since the last update is not waited for again.

        Args:
            tokens: Estimated size of the request.

        Returns:
            Non-negative number of seconds.
        """
        token_max = self.token_loadfactor * self.token_limit
        if tokens > token_max:
            self.log.error(f"Request is too large {tokens} > {token_max}")

        request_max = self.request_loadfactor * self.request_limit
        token_load, request_load = self.estimate_current_load()
        token_wait = (token_load + tokens - token_max) * 60 / self.token_limit
        request_wait = (request_load + 1 - request_max) * 60 / self.request_limit
        return max(token_wait, request_wait, 0.0)

    def request_would_overload(self, tokens):
        """Return whether sending ``tokens`` now would exceed either budget.

        Returns ``False`` while the limits are still unknown, because there
        is nothing to compare against.
        """
        if self.token_limit is None or self.request_limit is None:
            return False

        token_load, request_load = self.estimate_current_load()
        return (
            token_load + tokens > self.token_loadfactor * self.token_limit
            or request_load + 1 > self.request_loadfactor * self.request_limit
        )

    def update(self, result):
        """Refresh limits and usage from a response.

        Args:
            result: Response exposing ``model`` and a ``rate_limits`` mapping
                with ``limit-tokens``, ``remaining-tokens``,
                ``limit-requests`` and ``remaining-requests``.
        """
        # Must be computed before last_update and the limits are overwritten.
        delta_tokens, delta_requests = self.estimate_deltas()
        rate_limits = result.rate_limits
        model = result.model
        self.log.info(
            f"{model}: tokens "
            f"{rate_limits['remaining-tokens']}/{rate_limits['limit-tokens']} "
            f"requests "
            f"{rate_limits['remaining-requests']}/{rate_limits['limit-requests']}"
        )

        self.token_limit = rate_limits["limit-tokens"]
        self.request_limit = rate_limits["limit-requests"]
        self.last_update = time.monotonic_ns()

        reported_tokens = self.token_limit - rate_limits["remaining-tokens"]
        reported_requests = self.request_limit - rate_limits["remaining-requests"]

        if delta_tokens is None:
            # First response: the server's numbers are all we have.
            self.tokens_used = reported_tokens
            self.requests_used = reported_requests
            return

        # Take the worst of the current estimate and the current
        # result. Allow for up to 1% clock drift without falling
        # permanently behind.
        self.tokens_used = max(
            self.tokens_used - delta_tokens * 1.01, reported_tokens
        )
        self.requests_used = max(
            self.requests_used - delta_requests * 1.01, reported_requests
        )

    async def create(self, uid=None, **kwargs):
        """Send a request once it fits within the rate limits.

        Requests queue on ``pending_lock``; the request at the head of the
        queue waits until :meth:`request_would_overload` is false, reserves
        its share, and is then sent.

        Args:
            uid: Identifier used in log messages; generated when omitted.
            **kwargs: Request parameters, forwarded to ``resource.create``
                and used for the token estimate.

        Returns:
            The result of ``resource.create``.

        Raises:
            ValueError: If the request alone exceeds the usable token
                budget. Such a request could never be admitted and would
                otherwise block the queue forever.
        """
        model = kwargs.get("model")
        tokens = self.num_tokens_consumed_by_request(**kwargs, margin=self.token_margin)

        uid = uid if uid is not None else uuid.uuid4().hex[:6]
        self.log.debug(f"request {uid} ({model}) created {tokens:.2f}")

        if self.token_limit is not None:
            token_max = self.token_loadfactor * self.token_limit
            if tokens > token_max:
                raise ValueError(
                    f"request {uid} ({model}) is too large: "
                    f"{tokens:.2f} > {token_max:.2f} tokens"
                )

        async with self.pending_lock:
            self.log.debug(f"request {uid} ({model}) is first in line")
            while self.request_would_overload(tokens):
                self.return_event.clear()
                wait_time = self.estimate_delay(tokens)
                self.log.debug(
                    f"request {uid} ({model}) waiting for at most {wait_time:.3f}s"
                )
                # Wake up either when another request returns (its response
                # refreshes the estimate) or when enough budget should have
                # been regained. wait_for cancels the waiter on timeout, so
                # no task is left behind.
                try:
                    await asyncio.wait_for(
                        self.return_event.wait(), timeout=wait_time
                    )
                except asyncio.TimeoutError:
                    pass

            # Reserve while still at the head of the queue.
            self.tokens_inflight += tokens
            self.requests_inflight += 1

        try:
            result = await self.resource.create(**kwargs)
        finally:
            # Release the reservation and wake the queue even when the
            # request fails or is cancelled; otherwise the capacity would be
            # lost for good.
            self.tokens_inflight -= tokens
            self.requests_inflight -= 1
            self.return_event.set()

        self.update(result)
        return result


class OpenAIWrapper(LoggingComponent):
    """Route requests through one rate limiter per model.

    The first request for a model is sent directly; its response supplies
    the model's limits and the name the model resolved to. Later requests
    for that model go through the corresponding ``OpenAIRateLimiter``.
    Requested names that resolve to the same model share one limiter.
    """

    def __init__(self, resource):
        """Wrap ``resource``, patching it for header capture if necessary.

        Args:
            resource: An ``AsyncAPIResourcePatcher``, or a resource that is
                wrapped in one.
        """
        LoggingComponent.__init__(self, "oai-wrapper")
        if not isinstance(resource, AsyncAPIResourcePatcher):
            resource = AsyncAPIResourcePatcher(resource)
        self.resource = resource
        self.rate_limiters = {}  # requested or resolved model name -> limiter
        self.init_events = {}  # model name -> event set when its first request ends
        self.aliases = {}  # requested model name -> resolved model name

    # NOTE(review): retries are currently disabled. The exception names below
    # use the ``openai.error`` namespace; confirm they exist in the installed
    # openai version before re-enabling.
    # @retry(
    #     retry=retry_if_exception_type(
    #         (
    #             openai.error.RateLimitError,
    #             openai.error.APIConnectionError,
    #             openai.error.Timeout,
    #             openai.error.ServiceUnavailableError,
    #             openai.error.APIError,
    #         )
    #     ),
    #     after=after_log(tenacity_logger, logging.ERROR),
    #     wait=wait_random_exponential(min=1, max=300),
    #     stop=stop_after_attempt(8),
    # )
    async def create(self, uid=None, **kwargs):
        """Send a request, rate-limited once the model's limits are known.

        Args:
            uid: Identifier used in log messages; generated when omitted.
            **kwargs: Request parameters; ``model`` selects the limiter.

        Returns:
            The result of the underlying ``create`` call.
        """
        uid = uid if uid is not None else uuid.uuid4().hex[:6]
        model = kwargs.get("model")

        # Loop because a wake-up does not guarantee a limiter exists: the
        # initialising request may have failed, and another waiter may
        # already have started a new initialisation.
        while True:
            limiter = self.rate_limiters.get(model)
            if limiter is not None:
                return await limiter.create(uid=uid, **kwargs)
            pending = self.init_events.get(model)
            if pending is None:
                break
            await pending.wait()

        # No await between the checks above and this assignment, so exactly
        # one caller becomes the initialiser for the model.
        init_event = self.init_events[model] = asyncio.Event()
        self.log.debug(f"{uid} running model {model} for the first time")
        try:
            result = await self.resource.create(**kwargs)

            actual_model = result.model
            if actual_model != model:
                self.log.debug(f"{uid} found {model} is an alias for {actual_model}")
                self.aliases[model] = actual_model

            limiter = self.rate_limiters.get(actual_model)
            if limiter is None:
                limiter = OpenAIRateLimiter(self.resource)
            # Also feed the response to a limiter that already existed (the
            # alias case) so its rate-limit headers are not discarded.
            limiter.update(result)
            self.rate_limiters[actual_model] = limiter
            self.rate_limiters[model] = limiter
        finally:
            # Always release waiters, even on failure; otherwise every
            # current and future caller for this model would block forever.
            self.init_events.pop(model, None)
            init_event.set()

        return result

