"""This file contains the SGL programs used for unit testing."""

import json
import re
import time

import numpy as np

import sglang as sgl
from sglang.utils import download_and_cache_file, read_jsonl


def test_few_shot_qa():
    """Check capital-city few-shot QA through both ``run`` and ``run_batch``."""

    @sgl.function
    def few_shot_qa(s, question):
        # Fixed three-example preamble, then the caller's question; the answer
        # is generated greedily up to the end of the line.
        s += "The following are questions with answers.\n\n"
        for country, capital in (
            ("France", "Paris"),
            ("Germany", "Berlin"),
            ("Italy", "Rome"),
        ):
            s += f"Q: What is the capital of {country}?\n"
            s += f"A: {capital}\n"
        s += "Q: " + question + "\n"
        s += "A:" + sgl.gen("answer", stop="\n", temperature=0)

    single = few_shot_qa.run(question="What is the capital of the United States?")
    single_answer = single["answer"]
    assert "washington" in single_answer.strip().lower(), f"answer: {single_answer}"

    batch_questions = [
        "What is the capital of Japan?",
        "What is the capital of the United Kingdom?",
        "What is the capital city of China?",
    ]
    states = few_shot_qa.run_batch(
        [{"question": q} for q in batch_questions],
        temperature=0.1,
    )
    answers = [state["answer"].strip().lower() for state in states]
    assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"


def test_mt_bench():
    """Run a two-turn MT-Bench style chat and check the resulting message count."""

    @sgl.function
    def answer_mt_bench(s, question_1, question_2):
        # Turn 1 uses the expression-style role helpers ...
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question_1)
        s += sgl.assistant(sgl.gen("answer_1"))
        # ... turn 2 uses the context-manager style, so both APIs are exercised.
        with s.user():
            s += question_2
        with s.assistant():
            s += sgl.gen("answer_2")

    run_kwargs = {
        "question_1": (
            "Compose an engaging travel blog post about a recent trip to Hawaii, "
            "highlighting cultural experiences and must-see attractions."
        ),
        "question_2": "Rewrite your previous response. Start every sentence with the letter A.",
        "temperature": 0.7,
        "max_new_tokens": 64,
    }
    state = answer_mt_bench.run(**run_kwargs)
    # Either 4 or 5 messages are accepted -- presumably depending on whether
    # the backend keeps the system prompt as a separate message.
    assert len(state.messages()) in [4, 5]


def test_select(check_answer):
    """Check ``sgl.select`` over True/False/Unknown for three statements.

    Args:
        check_answer: when truthy, require the expected label for each
            statement; otherwise only require that one of the three options
            was selected.
    """

    @sgl.function
    def true_or_false(s, statement):
        s += "Determine whether the statement below is True, False, or Unknown.\n"
        # NOTE(review): "Pairs" looks like a typo for "Paris"; left unchanged
        # because it is part of the prompt sent to the model.
        s += "Statement: The capital of France is Pairs.\n"
        s += "Answer: True\n"
        s += "Statement: " + statement + "\n"
        s += "Answer:" + sgl.select("answer", ["True", "False", "Unknown"])

    cases = [
        ("The capital of Germany is Berlin.", "True"),
        ("The capital of Canada is Tokyo.", "False"),
        ("Purple is a better color than green.", "Unknown"),
    ]
    for statement, expected in cases:
        state = true_or_false.run(
            statement=statement,
        )
        if check_answer:
            assert state["answer"] == expected, state.text()
        else:
            assert state["answer"] in ["True", "False", "Unknown"]


def test_decode_int():
    """Check that ``sgl.gen_int`` produces two well-known integer facts."""

    @sgl.function
    def decode_int(s):
        for prefix, var_name in (
            ("The number of hours in a day is ", "hours"),
            ("The number of days in a year is ", "days"),
        ):
            s += prefix + sgl.gen_int(var_name) + "\n"

    state = decode_int.run(temperature=0.1)
    expected = {"hours": 24, "days": 365}
    for var_name, value in expected.items():
        assert int(state[var_name]) == value, state.text()


def test_decode_json_regex():
    """Build a JSON object field by field with regex-constrained values.

    The generated object is captured in the ``json_output`` variable, parsed
    with :func:`json.loads`, and its ``name``/``population`` field types are
    checked.
    """

    @sgl.function
    def decode_json(s):
        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

        s += "Generate a JSON object to describe the basic city information of Paris.\n"
        s += "Here are the JSON object:\n"

        # NOTE: we recommend using dtype gen or whole regex string to control the output

        with s.var_scope("json_output"):
            s += "{\n"
            s += '  "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
            s += '  "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += '  "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
            s += '  "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
            s += "}"

    state = decode_json.run(temperature=0.0)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.JSONDecodeError:
        # Show the offending text before propagating the parse error.
        print("JSONDecodeError", raw_json)
        raise
    assert isinstance(parsed["name"], str)
    assert isinstance(parsed["population"], int)


def test_decode_json():
    """Build a JSON object from ``dtype``-constrained generations and parse it.

    Mixes the ``gen_string``/``gen_int`` helpers with ``gen(dtype=...)`` so
    both spellings are exercised, then checks the parsed field types.
    """

    @sgl.function
    def decode_json(s):
        s += "Generate a JSON object to describe the basic city information of Paris.\n"

        # (key, factory for the generation expression) in output order.
        fields = [
            ("name", sgl.gen_string),
            ("population", sgl.gen_int),
            ("area", lambda: sgl.gen(dtype=int)),
            ("country", sgl.gen_string),
            ("timezone", lambda: sgl.gen(dtype=str)),
        ]
        last = len(fields) - 1

        with s.var_scope("json_output"):
            s += "{\n"
            for position, (key, make_gen) in enumerate(fields):
                # Every field except the last is followed by a comma.
                separator = ",\n" if position < last else "\n"
                s += f'  "{key}": ' + make_gen() + separator
            s += "}"

    state = decode_json.run(max_new_tokens=64)
    raw_json = state["json_output"]
    try:
        parsed = json.loads(raw_json)
    except json.JSONDecodeError:
        # Show the offending text before propagating the parse error.
        print("JSONDecodeError", raw_json)
        raise
    assert isinstance(parsed["name"], str)
    assert isinstance(parsed["population"], int)


def test_expert_answer(check_answer=True):
    """Chain two generations, feeding the first (``expert``) into the second.

    Args:
        check_answer: when truthy, require "paris" (case-insensitive) to
            appear somewhere in the final transcript.
    """

    @sgl.function
    def expert_answer(s, question):
        s += "Question: " + question + "\n"
        s += (
            "A good person to answer this question is"
            + sgl.gen("expert", stop=[".", "\n"])
            + ".\n"
        )
        # s["expert"] is read only after the line above has been appended.
        s += (
            f"For example,{s['expert']} would answer that "
            + sgl.gen("answer", stop=".")
            + "."
        )

    state = expert_answer.run(question="What is the capital of France?", temperature=0.1)

    if not check_answer:
        return
    transcript = state.text()
    assert "paris" in transcript.lower(), f"Answer: {transcript}"


def test_tool_use():
    """Let the model emit a ``calculate(...)`` expression and evaluate it locally.

    The generated expression is evaluated in Python and the result is written
    into the ``answer`` variable, which must equal ``lhs * rhs``.
    """

    def calculate(expression):
        # NOTE(review): `eval` executes model-generated text as arbitrary
        # Python. Tolerable only against a trusted local test model; prefer an
        # arithmetic-only (e.g. `ast`-based) evaluator if this is reused.
        result = eval(expression)
        return f"{result}"

    @sgl.function
    def tool_use(s, lhs, rhs):
        s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
        s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
        s += f"Question: What is the product of {lhs} and {rhs}?\n"
        s += "Answer: The answer is calculate(" + sgl.gen("expression", stop=")") + ") = "
        with s.var_scope("answer"):
            s += calculate(s["expression"])

    lhs, rhs = 257, 983
    state = tool_use(lhs=lhs, rhs=rhs, temperature=0)
    assert int(state["answer"]) == lhs * rhs


def test_react():
    """Run a ReAct-style Thought/Action/Observation loop with ``sgl.select``.

    Up to four steps are generated; choosing "Finish" writes the final
    ``answer`` variable and stops the loop. If "Finish" is never chosen, the
    ``answer`` variable is never generated and the final lookup will not find it.
    """
    demonstration = """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""

    @sgl.function
    def react(s, question):
        s += demonstration
        s += "Question: " + question + "\n"

        for step in range(1, 5):
            s += f"Thought {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
            action_var = f"action_{step}"
            s += f"Action {step}: " + sgl.select(action_var, ["Search", "Finish"])

            if s[action_var] != "Search":
                # "Finish": capture the final answer and stop.
                s += " [" + sgl.gen("answer", stop="]") + "].\n"
                break

            s += " [" + sgl.gen(stop="]") + "].\n"
            s += f"Observation {step}:" + sgl.gen(stop=[".", "\n"]) + ".\n"

    state = react.run(
        question="What country does the creator of Linux live in?",
        temperature=0.1,
    )
    answer = state["answer"].lower()
    assert any(keyword in answer for keyword in ("finland", "states"))


def test_parallel_decoding():
    """Skeleton-of-thought: draft short tips, expand each in a fork, summarize."""
    max_tokens = 64
    fork_size = 5

    @sgl.function
    def parallel_decoding(s, topic):
        s += "Act as a helpful assistant.\n"
        s += "USER: Give some tips for " + topic + ".\n"
        s += f"ASSISTANT: Okay. Here are {fork_size} concise tips, each under 8 words:\n"

        # Skeleton: one short numbered line per tip (1-based).
        for tip_no in range(1, fork_size + 1):
            s += f"{tip_no}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"

        # Expand every tip in its own fork of the current state.
        forks = s.fork(fork_size)
        for idx in range(fork_size):
            tip_no = idx + 1
            forks[idx] += f"Now, I expand tip {tip_no} into a detailed paragraph:\nTip {tip_no}:"
            forks[idx] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"])
        forks.join()

        # Bring the expansions back into the main state and summarize.
        s += "Here are these tips with detailed explanation:\n"
        for idx in range(fork_size):
            s += f"Tip {idx + 1}:" + forks[idx]["detailed_tip"] + "\n"

        s += "\nIn summary," + sgl.gen("summary", max_tokens=512)

    state = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
    assert isinstance(state["summary"], str)


def test_parallel_encoding(check_answer=True):
    """Encode three context statements in parallel forks, then answer a question.

    Args:
        check_answer: when truthy, require "Noah" to appear in the answer.
    """
    max_tokens = 64

    @sgl.function
    def parallel_encoding(s, question, context_0, context_1, context_2):
        s += "USER: I will ask a question based on some statements.\n"
        s += "ASSISTANT: Sure. I will give the answer.\n"
        s += "USER: Please memorize these statements.\n"

        contexts = [context_0, context_1, context_2]

        def statement_text(i):
            # Text appended to fork i: one numbered statement per fork.
            return f"Statement {i}: " + contexts[i] + "\n"

        forks = s.fork(len(contexts))
        forks += statement_text
        forks.join(mode="concate_and_append")

        s += "Now, please answer the following question. Do not list options."
        s += "\nQuestion: " + question + "\n"
        s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)

    state = parallel_encoding.run(
        question="Who is the father of Julian?",
        context_0="Ethan is the father of Liam.",
        context_1="Noah is the father of Julian.",
        context_2="Oliver is the father of Carlos.",
        temperature=0,
    )
    answer = state["answer"]

    if check_answer:
        assert "Noah" in answer


def test_image_qa():
    """Ask the model to describe ``example_image.png`` and expect a vehicle word.

    The image path is relative -- presumably resolved against the current
    working directory; confirm when running from elsewhere. The last chat
    message must mention "taxi" or "car", case-insensitively.
    """

    @sgl.function
    def image_qa(s, question):
        s += sgl.user(sgl.image("example_image.png") + question)
        s += sgl.assistant(sgl.gen("answer"))

    state = image_qa.run(
        question="Please describe this image in simple words.",
        temperature=0,
        max_new_tokens=64,
    )

    reply = state.messages()[-1]["content"]
    # Lower-case before matching (as the other text checks in this file do) so
    # a description that starts with "Taxi ..." or "Car ..." is accepted too.
    reply_lower = reply.lower()
    assert "taxi" in reply_lower or "car" in reply_lower, f"{reply}"


def test_stream():
    """Stream a chat completion twice and check that both streams yield text.

    The first run iterates the full program text (``text_iter()``); the
    second iterates only the ``answer`` variable (``text_iter("answer")``).
    Previously the collected output was discarded, so nothing was verified.
    """

    @sgl.function
    def qa(s, question):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer"))

    question = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, "
        "highlighting cultural experiences and must-see attractions."
    )

    state = qa(question=question, stream=True)
    full_text = "".join(state.text_iter())
    assert full_text, "streaming the full program text produced no output"

    state = qa(question=question, stream=True)
    answer_text = "".join(state.text_iter("answer"))
    assert answer_text, "streaming the 'answer' variable produced no output"


def test_regex():
    """Check that regex-constrained decoding produces a well-formed IPv4 address."""
    # Each octet is 0-255. The separators must be literal dots: an unescaped
    # "." would let the constraint accept any character, e.g. "8a8b8c8".
    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

    @sgl.function
    def regex_gen(s):
        s += "Q: What is the IP address of the Google DNS servers?\n"
        s += "A: " + sgl.gen(
            "answer",
            temperature=0,
            regex=regex,
        )

    state = regex_gen.run()
    answer = state["answer"]
    # fullmatch rather than match: constrained output must satisfy the whole
    # pattern, not merely start with a matching prefix.
    assert re.fullmatch(regex, answer), f"answer: {answer!r}"


def test_dtype_gen():
    """Check that ``dtype``-constrained generations parse as their declared types.

    ``int_res`` and ``float_res`` must convert with ``int()``/``float()``, and
    ``bool_res`` must be the literal text ``True`` or ``False``.

    Raises:
        ValueError: if ``int_res`` or ``float_res`` cannot be converted; the
            program state is printed first to aid debugging.
    """

    @sgl.function
    def dtype_gen(s):
        s += "Q: What is the full name of DNS?\n"
        # NOTE(review): "nams" below looks like a typo for "name"; left as-is
        # because it is part of the prompt sent to the model.
        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
        s += "Q: Which year was DNS invented?\n"
        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
        s += "Q: What is the value of pi?\n"
        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
        s += "Q: Is the sky blue?\n"
        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"

    state = dtype_gen.run()

    try:
        # int()/float() tolerate surrounding whitespace in the generated text.
        int(state["int_res"])
        float(state["float_res"])
    except ValueError:
        print(state)
        raise

    # bool() of any non-empty string is True (even "False"), so it can never
    # reject malformed output; compare against the expected literals instead.
    bool_text = state["bool_res"].strip()
    assert bool_text in ("True", "False"), f"bool_res: {state['bool_res']!r}"


def test_completion_speculative():
    """Check that API speculative execution reduces prompt-token usage.

    The same character-generation program runs with and without
    ``num_api_spec_tokens``. The default backend's ``token_usage.prompt_tokens``
    is recorded for each run, and the speculative run must use fewer.
    """

    def build_character(s):
        # Single shared program body, so the two variants are guaranteed to send
        # identical prompts and differ only in the speculative setting.
        s += "Construct a character within the following format:\n"
        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
        s += "\nPlease generate new Name, Birthday and Job.\n"
        s += (
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
        )
        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

    @sgl.function(num_api_spec_tokens=64)
    def gen_character_spec(s):
        build_character(s)

    @sgl.function
    def gen_character_no_spec(s):
        build_character(s)

    token_usage = sgl.global_config.default_backend.token_usage

    token_usage.reset()
    gen_character_spec().sync()
    usage_with_spec = token_usage.prompt_tokens

    token_usage.reset()
    gen_character_no_spec().sync()
    usage_with_no_spec = token_usage.prompt_tokens

    assert (
        usage_with_spec < usage_with_no_spec
    ), f"{usage_with_spec} vs {usage_with_no_spec}"


def test_chat_completion_speculative():
    """Smoke-test API speculative execution with a chat-style program.

    Only checks that the program runs to completion via ``sync()``; no
    generated value is asserted.
    """
    format_example = "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"

    @sgl.function(num_api_spec_tokens=256)
    def gen_character_spec(s):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user("Construct a character within the following format:")
        s += sgl.assistant(format_example)
        s += sgl.user("Please generate new Name, Birthday and Job.\n")
        # All three fields are generated within one assistant turn.
        s += sgl.assistant(
            "Name:"
            + sgl.gen("name", stop="\n")
            + "\nBirthday:"
            + sgl.gen("birthday", stop="\n")
            + "\nJob:"
            + sgl.gen("job", stop="\n")
        )

    program_state = gen_character_spec()
    program_state.sync()


def test_hellaswag_select():
    """Benchmark the accuracy of sgl.select on the HellaSwag dataset.

    The same few-shot ``sgl.select`` program runs over the first
    ``num_questions`` validation examples twice: once with a plain
    ``run_batch`` and once with ``generator_style=True``. The two modes must
    agree on accuracy (within 0.05) and wall-clock latency (within 1 second).

    Returns:
        ``(accuracy, latency)`` of the first (non-generator) run.
    """

    def get_one_example(lines, i, include_answer):
        # "<activity_label>: <ctx> ", optionally followed by the gold ending.
        ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
        if include_answer:
            ret += lines[i]["endings"][lines[i]["label"]]
        return ret

    def get_few_shot_examples(lines, k):
        # The first k examples with answers, each followed by a blank line.
        return "".join(get_one_example(lines, i, True) + "\n\n" for i in range(k))

    # Read data
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = 200
    num_shots = 20
    # NOTE(review): the few-shot examples are lines[:num_shots], which are also
    # the first questions asked, so those questions see their own answer in the
    # prompt and accuracy is somewhat inflated. Left unchanged because callers
    # may compare the returned accuracy against a fixed threshold.
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    eval_lines = lines[:num_questions]
    questions = [get_one_example(lines, i, False) for i in range(len(eval_lines))]
    choices = [line["endings"] for line in eval_lines]
    labels = [line["label"] for line in eval_lines]
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    #####################################
    ########## SGL Program End ##########
    #####################################

    def run_and_score(generator_style):
        """Run the batch once; return (accuracy, latency in seconds)."""
        tic = time.time()
        rets = few_shot_hellaswag.run_batch(
            arguments,
            temperature=0,
            num_threads=64,
            progress_bar=True,
            generator_style=generator_style,
        )
        # Result iteration is inside the timed region: with generator_style=True
        # results are presumably produced lazily while iterating.
        preds = [choices[i].index(ret["answer"]) for i, ret in enumerate(rets)]
        latency = time.time() - tic
        accuracy = np.mean(np.array(preds) == np.array(labels))
        return accuracy, latency

    accuracy, latency = run_and_score(generator_style=False)

    # Test generator style of run_batch
    accuracy_gen, latency_gen = run_and_score(generator_style=True)
    assert np.abs(accuracy_gen - accuracy) < 0.05
    assert np.abs(latency_gen - latency) < 1

    return accuracy, latency


def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    The question below would normally get a short answer. Generation is forced
    to at least MIN_TOKENS (and at most MAX_TOKENS) tokens, and the answer is
    re-tokenized with the served model's tokenizer to confirm the lower bound.
    """
    from sglang.srt.hf_transformers_utils import get_tokenizer

    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    MIN_TOKENS, MAX_TOKENS = 64, 128

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    tokenizer = get_tokenizer(model_path)

    state = convo_1.run()
    answer = state["answer"]
    # NOTE(review): tokenizer.encode may add special tokens (e.g. BOS) depending
    # on the tokenizer, and those would count toward the total -- confirm
    # whether that slack is intended.
    num_tokens = len(tokenizer.encode(answer))
    assert (
        num_tokens >= MIN_TOKENS
    ), f"Generated {num_tokens} tokens, min required: {MIN_TOKENS}. Text: {answer}"
