import unittest

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestSRTEngineWithQuantArgs(CustomTestCase):

    def test_1_quantization_args(self):

        # we only test fp8 because other methods are currenly depend on vllm. We can add other methods back to test after vllm depency is resolved.
        quantization_args_list = [
            # "awq",
            "fp8",
            # "gptq",
            # "marlin",
            # "gptq_marlin",
            # "awq_marlin",
            # "bitsandbytes",
            # "gguf",
        ]

        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        for quantization_args in quantization_args_list:
            engine = sgl.Engine(
                model_path=model_path, random_seed=42, quantization=quantization_args
            )
            engine.generate(prompt, sampling_params)
            engine.shutdown()

    def test_2_torchao_args(self):

        # we don't test int8dq because currently there is conflict between int8dq and capture cuda graph
        torchao_args_list = [
            # "int8dq",
            "int8wo",
            "fp8wo",
            "fp8dq-per_tensor",
            "fp8dq-per_row",
        ] + [f"int4wo-{group_size}" for group_size in [32, 64, 128, 256]]

        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        for torchao_config in torchao_args_list:
            engine = sgl.Engine(
                model_path=model_path, random_seed=42, torchao_config=torchao_config
            )
            engine.generate(prompt, sampling_params)
            engine.shutdown()


if __name__ == "__main__":
    unittest.main()
