import json
import re
from jinja2 import Template
from decision_oaif.agents.agent import Agent

import sglang as sgl


# Reference: https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/quick_start/local_example_chat.py
@sgl.function
def multi_turn_message(s, message_1):
    """Run one user/assistant exchange on the SGLang program state ``s``.

    ``message_1`` is appended as the user turn; the assistant turn is then
    generated (at most 256 tokens, temperature 0.3) and stored in the state
    under the key ``"answer_1"``.
    """
    s += sgl.user(message_1)
    # Name the generation expression separately so the sampling settings are
    # easy to spot; the result is retrieved later via state["answer_1"].
    reply = sgl.gen("answer_1", max_tokens=256, temperature=0.3)
    s += sgl.assistant(reply)

class SGLangServerAgent(Agent):
    """Agent that chooses its next action by querying an SGLang server.

    On every step the agent renders a Jinja2 prompt template with the task,
    the current observation, the candidate actions and the previously seen
    (observation, action) pairs, sends the prompt as a single user turn via
    ``multi_turn_message``, and parses the generated text into a
    ``(reason, action)`` pair with ``parse_reason_action_fn``.
    """

    # Used when no server URL is supplied; this is the address that earlier
    # versions of this class always connected to.
    DEFAULT_SERVER_URL = "http://localhost:30000"

    def __init__(
        self,
        server_url,
        prompt_template_file,
        verbose=0,
        debug=False,
        parse_reason_action_fn=None,
    ):
        """Load the prompt template and set up the SGLang backend.

        Args:
            server_url: Base URL of the SGLang server, e.g.
                ``http://localhost:30000/``. A falsy value falls back to
                ``DEFAULT_SERVER_URL``.
            prompt_template_file: Path to the Jinja2 template file used to
                build the prompt.
            verbose: ``0`` prints nothing; ``1`` prints the observation,
                candidate actions, reason and action; values above ``1``
                additionally print the raw generated response.
            debug: When true, drop into ``pdb`` after every prediction.
            parse_reason_action_fn: Callable that takes the generated text
                and returns a ``(reason, action)`` pair. It must be provided
                before ``predict_reason_action`` is called.
        """
        super().__init__()
        self.server_url = server_url  # sglang server URL, e.g. http://localhost:30000/
        self.verbose = verbose
        self.debug = debug
        self.parse_reason_action_fn = parse_reason_action_fn
        # Explicit encoding so the template is decoded identically on every
        # platform instead of depending on the locale default.
        with open(prompt_template_file, "r", encoding="utf-8") as file:
            self.prompt_template = Template(file.read())

        # Connect to the server the caller asked for rather than a hard-coded
        # address. The trailing slash is dropped so that request paths joined
        # onto the base URL do not end up with a doubled "//".
        base_url = (server_url or self.DEFAULT_SERVER_URL).rstrip("/")
        self._backend = sgl.RuntimeEndpoint(base_url)
        # Still install it as the process-wide default, as before, for any
        # code that relies on a default backend being configured.
        sgl.set_default_backend(self._backend)

    def predict_reason_action(self, task, observation, candidate_actions, reward=""):
        """Query the model for the next action and record it in the history.

        Args:
            task: Task description passed to the prompt template.
            observation: Current observation passed to the prompt template.
            candidate_actions: Actions the model may choose from, passed to
                the prompt template.
            reward: Optional reward information passed to the prompt template.

        Returns:
            The ``(reason, action)`` pair produced by
            ``parse_reason_action_fn`` from the generated text.

        Raises:
            TypeError: If ``parse_reason_action_fn`` is not callable.
        """
        # Fail before contacting the server: without a parser the generated
        # text could not be used anyway.
        if not callable(self.parse_reason_action_fn):
            raise TypeError(
                "parse_reason_action_fn must be a callable that maps the "
                "generated text to a (reason, action) pair"
            )

        # Only the observation and action of each past step are exposed to
        # the template. The history itself is not defined in this class
        # (presumably maintained by the Agent base class).
        observation_action_history = [
            {"observation": entry["observation"], "action": entry["action"]}
            for entry in self.observation_reason_action_history
        ]
        input_data = {
            "mode": "input",
            "task": task,
            "reward": reward,
            "observation_action_history": observation_action_history,
            "observation": observation,
            "candidate_actions": candidate_actions,
        }

        input_prompt = self.prompt_template.render(**input_data)
        # Pass this agent's own backend explicitly so that several agents
        # pointing at different servers do not interfere through the
        # process-wide default backend.
        state = multi_turn_message.run(message_1=input_prompt, backend=self._backend)

        generated_text = state["answer_1"]
        reason, action = self.parse_reason_action_fn(generated_text)

        if self.verbose > 0:
            print(f"\n OBSERVATION: {observation}")
            if self.verbose > 1:
                print(f"\n RESPONSE: {generated_text}")
            print(f"\n CANDIDATE ACTIONS: {candidate_actions}")
            print(f"\n REASON: {reason}")
            print(f"\n ACTION: {action}")

        if self.debug:
            import pdb

            # Pause so the prediction can be inspected before it is recorded.
            pdb.set_trace()

        # Update the history with the new observation, reason, and action
        self.update_history(observation, reason, action)
        return reason, action
