"""
Tests for chat completions served through the `/invocations` endpoint.

Usage:
python3 -m unittest test_sagemaker_server.TestSageMakerServer.test_chat_completion
"""

import json
import unittest

import requests

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestSageMakerServer(CustomTestCase):
    """Exercise chat completions through the server's ``/invocations`` route.

    One server process is launched for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``. Each ``run_*`` helper posts a chat payload to
    ``{base_url}/invocations`` and checks the structure of the reply, both for
    regular JSON responses and for streamed ``data: ...`` lines.
    """

    # Per-request timeout in seconds, so that a wedged server fails the test
    # instead of hanging the whole test run.
    request_timeout = 120

    @classmethod
    def setUpClass(cls):
        """Launch the server under test and load the tokenizer for the model."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        try:
            cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the already-launched server must be cleaned up here.
            kill_process_tree(cls.process.pid)
            raise

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def _build_payload(self, logprobs, parallel_sample_num, stream=False):
        """Return the request body shared by the streaming and non-streaming runs.

        Args:
            logprobs: Number of top logprobs to request, or ``None`` for none.
            parallel_sample_num: Value sent as ``n`` (number of choices).
            stream: When true, also request streaming with usage included.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            "temperature": 0,
            "logprobs": logprobs is not None and logprobs > 0,
            "top_logprobs": logprobs,
            "n": parallel_sample_num,
        }
        if stream:
            payload["stream"] = True
            payload["stream_options"] = {"include_usage": True}
        return payload

    def _auth_headers(self):
        """Return the bearer-token header built from the class API key."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @staticmethod
    def _assert_usage_positive(usage):
        """Assert that every token counter in ``usage`` is greater than zero."""
        assert usage["prompt_tokens"] > 0
        assert usage["completion_tokens"] > 0
        assert usage["total_tokens"] > 0

    @staticmethod
    def _assert_top_logprobs(choice, logprobs):
        """Assert the first content entry of ``choice`` has ``logprobs`` top logprobs."""
        top_logprobs = choice["logprobs"]["content"][0]["top_logprobs"]
        assert isinstance(top_logprobs, list)
        assert isinstance(top_logprobs[0]["token"], str)
        ret_num_top_logprobs = len(top_logprobs)
        assert (
            ret_num_top_logprobs == logprobs
        ), f"{ret_num_top_logprobs} vs {logprobs}"

    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Post a non-streaming chat request and validate the JSON response.

        Args:
            logprobs: Number of top logprobs to request, or ``None`` for none.
            parallel_sample_num: Expected number of returned choices.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            AssertionError: If the response does not have the expected shape.
        """
        http_response = requests.post(
            f"{self.base_url}/invocations",
            json=self._build_payload(logprobs, parallel_sample_num),
            headers=self._auth_headers(),
            timeout=self.request_timeout,
        )
        # Surface HTTP failures directly rather than as a KeyError below.
        http_response.raise_for_status()
        response = http_response.json()

        first_choice = response["choices"][0]
        if logprobs:
            self._assert_top_logprobs(first_choice, logprobs)

        assert len(response["choices"]) == parallel_sample_num
        assert first_choice["message"]["role"] == "assistant"
        assert isinstance(first_choice["message"]["content"], str)
        assert response["id"]
        assert response["created"]
        self._assert_usage_positive(response["usage"])

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Post a streaming chat request and validate every streamed chunk.

        The first chunk seen for each choice index must carry the assistant
        role; later chunks must carry string content (and top logprobs when
        requested). Chunks that carry ``usage`` are checked for positive
        token counts.

        Args:
            logprobs: Number of top logprobs to request, or ``None`` for none.
            parallel_sample_num: Number of choice indices expected in the stream.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            AssertionError: If a chunk does not have the expected shape or a
                choice index never appears.
        """
        sse_prefix = "data: "
        # Choice indices whose first (role-carrying) chunk has been seen.
        started_indices = set()

        # The context manager releases the connection even if an assertion
        # fails before the stream is fully consumed.
        with requests.post(
            f"{self.base_url}/invocations",
            json=self._build_payload(logprobs, parallel_sample_num, stream=True),
            stream=True,
            headers=self._auth_headers(),
            timeout=self.request_timeout,
        ) as response:
            response.raise_for_status()

            for raw_line in response.iter_lines():
                line = raw_line.decode("utf-8")
                # Strip only the leading prefix; replacing every occurrence
                # would also alter generated text that contains "data: ".
                if line.startswith(sse_prefix):
                    line = line[len(sse_prefix) :]
                if len(line) < 1 or line == "[DONE]":
                    continue
                print(f"value: {line}")
                chunk = json.loads(line)

                usage = chunk.get("usage")
                if usage is not None:
                    self._assert_usage_positive(usage)
                    continue

                choice = chunk["choices"][0]
                index = choice.get("index")
                delta = choice.get("delta")

                if index not in started_indices:
                    assert delta["role"] == "assistant"
                    started_indices.add(index)
                    continue

                if logprobs:
                    assert choice.get("logprobs")
                    self._assert_top_logprobs(choice, logprobs)

                assert isinstance(delta["content"], str)
                assert chunk["id"]
                assert chunk["created"]

        for index in range(parallel_sample_num):
            assert (
                index in started_indices
            ), f"index {index} is not found in the response"

    def test_chat_completion(self):
        """Run the non-streaming check with and without logprobs, for n=1 and n=2."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        """Run the streaming check with and without logprobs, for n=1 and n=2."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)


# Allow running this file directly as a script: unittest.main() runs the
# tests defined in this module.
if __name__ == "__main__":
    unittest.main()
