import asyncio
import unittest

import openai
import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestCacheReport(CustomTestCase):
    """Check the cached-token counts reported by a server launched with
    ``--enable-cache-report``.

    A single server is shared by every test in the class. It is started with
    ``--chunked-prefill-size=40`` so that long prompts are prefilled in several
    chunks.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server and calibrate ``cls.min_cached``.

        ``min_cached`` is the threshold below which a first request's reported
        ``cached_tokens`` is accepted as "only (part of) the chat template was
        cached". It is derived from two one-character chat requests issued here.

        If anything fails after the server has been launched, the server is
        killed before the exception propagates.
        """
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Initial value; replaced by the calibration below.
        cls.min_cached = 5
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--chunked-prefill-size=40",
                "--enable-cache-report",
            ],
        )
        # unittest does not run tearDownClass when setUpClass raises, so any
        # failure past this point (client creation, a request error, the
        # calibration assert) would otherwise leave the server process running.
        try:
            cls.client = openai.Client(
                api_key="EMPTY", base_url=f"{cls.base_url}/v1"
            )
            cls.aclient = openai.AsyncClient(
                api_key="EMPTY", base_url=f"{cls.base_url}/v1"
            )

            # run_openai only reads `client` and `model`, which are class
            # attributes at this point, so the class can stand in for `self`.
            usage = cls.run_openai(cls, "1").usage
            # we can assume that our request is of size 1, plus the total template size
            # ideally we would like to know the begin size / end size of the template to be more precise
            total_template_size = usage.prompt_tokens - 1
            print(f"template size: {total_template_size}")
            usage2 = cls.run_openai(cls, "2").usage
            cached2 = usage2.prompt_tokens_details.cached_tokens
            # "1" and "2" differ in their only non-template token, so the second
            # request can share at most the template with the first one.
            assert cached2 <= total_template_size
            cls.min_cached = max(cached2, total_template_size - cached2)
        except BaseException:
            kill_process_tree(cls.process.pid)
            raise

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree launched in setUpClass."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """POST a fixed prompt to the native ``/generate`` endpoint.

        Args:
            return_logprob: forwarded as the ``return_logprob`` request field.
            top_logprobs_num: forwarded as the ``top_logprobs_num`` request field.
            n: number of samples; temperature is 0 when ``n == 1``, else 0.5.

        Returns:
            The raw ``requests.Response`` from the server.
        """
        return requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    # NOTE(review): model-specific token id; its meaning is not
                    # visible from this file — confirm against the test model.
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )

    def run_openai(self, message):
        """Send ``message`` as a single user turn via the synchronous OpenAI client.

        Returns:
            The chat completion response (greedy decoding, up to 100 new tokens).
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    async def run_openai_async(self, message):
        """Async counterpart of ``run_openai`` using the shared AsyncClient."""
        response = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    def cache_report_openai(self, message):
        """Send ``message`` twice and check the cache report of both requests.

        The first request must report fewer than ``min_cached`` cached tokens.
        The second must report exactly ``prompt_tokens - 1`` cached tokens.

        Returns:
            The ``cached_tokens`` reported for the first request.
        """
        response = self.run_openai(message)
        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai first request cached_tokens: {first_cached_tokens}")
        assert first_cached_tokens < self.min_cached

        response = self.run_openai(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai second request cached_tokens: {cached_tokens}")
        assert cached_tokens > 0
        assert cached_tokens == int(response.usage.prompt_tokens) - 1
        return first_cached_tokens

    async def cache_report_openai_async(self, message):
        """Send ``message`` once and return ``(cached_tokens, prompt_tokens)``."""
        response = await self.run_openai_async(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens = int(response.usage.prompt_tokens)
        return cached_tokens, prompt_tokens

    def test_generate(self):
        """A repeated native ``/generate`` request reports all but one prompt token cached."""
        print("=" * 100)
        meta_info = self.run_decode().json()["meta_info"]
        cached_tokens = int(meta_info["cached_tokens"])
        print(f"sglang first request cached_tokens: {cached_tokens}")
        print(f"sglang first request prompt_tokens: {int(meta_info['prompt_tokens'])}")
        # can't assure to be 0: depends on the initialisation request / if a template is used with the model
        assert cached_tokens < self.min_cached

        meta_info = self.run_decode().json()["meta_info"]
        cached_tokens = int(meta_info["cached_tokens"])
        prompt_tokens = int(meta_info["prompt_tokens"])
        print(f"sglang second request cached_tokens: {cached_tokens}")
        print(f"sglang second request prompt_tokens: {prompt_tokens}")
        assert cached_tokens == prompt_tokens - 1

    def test_cache_split_prefill_openai(self):
        """A prompt longer than the chunked-prefill size is still reported correctly."""
        print("=" * 100)
        self.cache_report_openai(
            "€ This is a very long and unique text that should not be already cached, the twist is"
            " that it should be longer than the chunked-prefill-size, so it should be split among"
            " several prefill requests. Still, it shouldn't be cached"
        )

    def test_cache_report_openai(self):
        """Unrelated prompts share only the template; related prompts share their prefix."""
        print("=" * 100)
        # warm up the cache, for the template
        self.run_openai("Introduce the capital of France.")

        first_cached_tokens_1 = self.run_openai(
            "How many sparrow do you need to lift a coconut?"
        ).usage.prompt_tokens_details.cached_tokens

        usage_2 = self.run_openai("* sing something about cats").usage
        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
        # first request may not have 0 cached tokens, but if they only have the template in common they
        # should be the same once the cache is warmed up
        assert first_cached_tokens_1 == first_cached_tokens_2

        resp = self.run_openai("* sing something about cats and dogs")
        print(resp.usage)

        resp = self.run_openai("* sing something about cats, please")
        print(resp.usage)
        # This prompt extends "* sing something about cats", so it should reuse
        # that earlier prompt up to a margin of min_cached tokens.
        assert (
            resp.usage.prompt_tokens_details.cached_tokens
            >= usage_2.prompt_tokens - self.min_cached
        )

    def test_cache_report_openai_async(self):
        """Identical prompts scheduled in the same batch do not reuse each other's cache."""
        print("=" * 100)

        async def run_test():
            task0 = asyncio.create_task(
                self.cache_report_openai_async(
                    "first request, to start the inference and let the next two request be started in the same batch"
                )
            )
            await asyncio.sleep(0.05)  # to force the first request to be started first
            task1 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            task2 = asyncio.create_task(
                self.cache_report_openai_async(
                    "> can the same batch parallel request use the cache?"
                )
            )
            result0, result1, result2 = await asyncio.gather(task0, task1, task2)

            cached_tokens0, prompt_tokens0 = result0
            cached_tokens1, prompt_tokens1 = result1
            cached_tokens2, prompt_tokens2 = result2

            print(
                f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
            )
            print(
                f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
            )
            print(
                f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
            )

            # No request may reuse another's prompt: the first one runs alone, and the
            # next two are in the same batch, so all three report the same cached count.
            # If an optimisation were added that delays requests sharing a prefix so that
            # they start at different times (to maximise cache hits), this would not hold.
            assert cached_tokens1 == cached_tokens2 == cached_tokens0

        asyncio.run(run_test())


if __name__ == "__main__":
    # Executed as a script: hand this module to the standard unittest runner.
    unittest.main(module="__main__")
