import json
import unittest

import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestInputEmbeds(unittest.TestCase):
    """Compare ``/generate`` output for text prompts vs. precomputed input embeddings.

    One server is launched for the whole class. Each prompt is sent both as raw
    text and as embeddings computed locally with a Hugging Face copy of the same
    model; with greedy sampling (temperature 0) the completions should match.
    """

    # Per-request HTTP timeout (seconds) used by ``send_request``.
    request_timeout = 30

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Load the local reference tokenizer/model before launching the server:
        # if setUpClass raises, unittest does not call tearDownClass, so a failure
        # after the launch would leave the server process running.
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]
        # NOTE(review): the radix (prefix) cache is disabled, presumably because
        # embedding-only requests carry no token ids to key it on — confirm.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-radix"],
        )

    def generate_input_embeddings(self, text):
        """Return the reference model's input embeddings for ``text`` as nested lists.

        The result holds one inner list (the embedding vector) per prompt token,
        which is the form sent in the ``input_embeds`` field of a request.
        """
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        # Drop only the batch dimension. A bare ``squeeze()`` would also drop the
        # sequence dimension for a one-token prompt and yield a flat vector.
        # ``detach()`` so the list conversion is independent of autograd tracking.
        return embeddings[0].detach().tolist()

    def _make_payload(self, **inputs):
        """Build a ``/generate`` payload with greedy sampling and the given inputs.

        ``inputs`` is either ``text=...`` or ``input_embeds=...``.
        """
        return {
            "model": self.model,
            **inputs,
            "sampling_params": {"temperature": 0, "max_new_tokens": 50},
        }

    def send_request(self, payload):
        """POST ``payload`` to ``/generate``.

        Returns the decoded JSON body on HTTP 200; otherwise a dict with a single
        ``"error"`` key describing the status code and response text.
        """
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=self.request_timeout,
        )
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def _assert_generated(self, response):
        """Fail with the server's error message unless ``response`` has generated text."""
        self.assertIn("text", response, msg=response.get("error"))

    def test_text_based_response(self):
        """Generate from each text prompt, print the response, and check it succeeded."""
        for text in self.texts:
            with self.subTest(text=text):
                response = self.send_request(self._make_payload(text=text))
                print(
                    f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
                self._assert_generated(response)

    def test_embedding_based_response(self):
        """Generate from each prompt's embeddings, print the response, and check it succeeded."""
        for text in self.texts:
            with self.subTest(text=text):
                embeddings = self.generate_input_embeddings(text)
                response = self.send_request(
                    self._make_payload(input_embeds=embeddings)
                )
                print(
                    f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
                self._assert_generated(response)

    def test_compare_text_vs_embedding(self):
        """Check that text input and its embeddings yield the same greedy completion."""
        for text in self.texts:
            with self.subTest(text=text):
                text_payload = self._make_payload(text=text)
                embed_payload = self._make_payload(
                    input_embeds=self.generate_input_embeddings(text)
                )
                text_response = self.send_request(text_payload)
                embed_response = self.send_request(embed_payload)
                # Print both responses before asserting so they are visible on failure.
                print(
                    f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
                )
                print(
                    f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
                )
                self._assert_generated(text_response)
                self._assert_generated(embed_response)
                self.assertEqual(text_response["text"], embed_response["text"])

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)


if __name__ == "__main__":
    # Running the file directly hands control to unittest's command-line runner,
    # which loads and runs the TestCase classes defined in this module.
    unittest.main(module=__name__)
