import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1,
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2,
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1,
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


def parse_models(model_string):
    """Split a comma-separated string of model names into a list.

    Each entry has surrounding whitespace removed, and entries that are
    empty after stripping (e.g. produced by a trailing comma) are dropped.

    Args:
        model_string: Comma-separated model identifiers.

    Returns:
        The non-empty, stripped model names, in their original order.
    """
    stripped_names = (part.strip() for part in model_string.split(","))
    return [name for name in stripped_names if name]


class TestEvalAccuracyLarge(unittest.TestCase):
    """Nightly ``mgsm_en`` accuracy check across all nightly model groups.

    For every model, a dedicated server is launched with flags derived from
    the model's group (FP8 and/or tensor-parallel 2), the ``mgsm_en`` eval is
    run against it, and the server is stopped before the next model starts.
    """

    @classmethod
    def setUpClass(cls):
        # Each entry is (model names, is_fp8, is_tp2).
        cls.model_groups = [
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True),
        ]
        # Every model's server is launched on this same URL.
        cls.base_url = DEFAULT_URL_FOR_TEST

    def setUp(self):
        # Handle of the currently running server process, or None.
        self.process = None

    def tearDown(self):
        # Safety net: stop a server that is still running when the test ends.
        self._stop_server()

    def _stop_server(self):
        """Kill the running server, if any, and forget its handle.

        Resetting ``self.process`` makes this idempotent, so a server already
        stopped inside the test loop is not killed a second time in tearDown.
        """
        if self.process is None:
            return
        try:
            kill_child_process(self.process.pid)
        finally:
            self.process = None

    def launch_server(self, model, is_fp8, is_tp2):
        """Start a server for ``model`` and store its handle in ``self.process``.

        Args:
            model: Model name/path passed to ``popen_launch_server``.
            is_fp8: Add FP8-related flags (quantization and/or FP8 KV cache).
            is_tp2: Launch with ``--tp 2``.
        """
        other_args = ["--log-level-http", "warning", "--trust-remote-code"]
        if is_fp8:
            if "Llama-3" in model or "gemma-2" in model:
                # compressed-tensors checkpoints: only the KV cache dtype is set,
                # no explicit --quantization flag.
                other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
            elif "Qwen2-72B-Instruct-FP8" in model:
                # Original note: "bug" -- the fp8 KV cache is deliberately not
                # enabled for this model (details not recorded here).
                other_args.extend(["--quantization", "fp8"])
            else:
                other_args.extend(
                    ["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"]
                )
        if is_tp2:
            other_args.extend(["--tp", "2"])
        if "DeepSeek" in model:
            other_args.append("--enable-mla")

        self.process = popen_launch_server(
            model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    def test_mgsm_en_all_models(self):
        """Run ``mgsm_en`` on every model and require a score above 0.5."""
        for model_group, is_fp8, is_tp2 in self.model_groups:
            for model in model_group:
                with self.subTest(model=model):
                    try:
                        self.launch_server(model, is_fp8, is_tp2)

                        args = SimpleNamespace(
                            base_url=self.base_url,
                            model=model,
                            eval_name="mgsm_en",
                            num_examples=None,
                            num_threads=1024,
                        )

                        metrics = run_eval(args)
                        print(
                            f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n"
                        )
                        # Loose threshold. assertGreater (unlike a bare assert)
                        # is not stripped when Python runs with -O.
                        self.assertGreater(
                            metrics["score"],
                            0.5,
                            f"score={metrics['score']} <= 0.5",
                        )
                    finally:
                        # Always stop this model's server, even when the eval
                        # raises or the threshold check fails. subTest swallows
                        # the failure and the loop continues, and the next model
                        # is launched on the same base_url.
                        self._stop_server()


if __name__ == "__main__":
    # Allow running this file directly (python <file>) rather than via a test runner.
    unittest.main()
