import argparse
from enum import Enum

from pydantic import BaseModel, constr

import sglang as sgl
from sglang.srt.constrained import build_regex_from_object
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)


# fmt: off
@sgl.function
def regex_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_jump_forward,
    )
# fmt: on

json_jump_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r"""  "name": "[\w\d\s]*",\n"""
    + r"""  "country": "[\w\d\s]*",\n"""
    + r"""  "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r"""  "population": [-+]?[0-9]+,\n"""
    + r"""  "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)

# fmt: off
@sgl.function
def json_gen(s):
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_jump_forward,
    )
# fmt: on


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


@sgl.function
def character_gen(s):
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )


def main(args):
    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    state = regex_gen.run(temperature=0)

    print("=" * 20, "IP TEST", "=" * 20)
    print(state.text())

    state = json_gen.run(temperature=0)

    print("=" * 20, "JSON TEST", "=" * 20)
    print(state.text())

    state = character_gen.run(temperature=0)

    print("=" * 20, "CHARACTER TEST", "=" * 20)
    print(state.text())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = add_common_sglang_args_and_parse(parser)
    main(args)

# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.

# {
#   "name": "Hogwarts School of Witchcraft and Wizardry",
#   "country": "Scotland",
#   "latitude": 55.566667,
#   "population": 1000,
#   "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }

# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
