import json
import re
from jinja2 import Template
from decision_oaif.agents.agent import Agent
from gradio_client import Client


class HFSpaceAgent(Agent):
    """Agent that picks its next action by querying a Hugging Face Space.

    Each step renders a Jinja2 prompt template, sends it as a single user
    message to the Space's ``/predict`` endpoint through ``gradio_client``,
    and hands the raw response to a caller-supplied parser that extracts a
    ``(reason, action)`` pair.
    """

    def __init__(self, space_id, prompt_template_file, verbose=0, debug=False, parse_reason_action_fn=None):
        """Load the prompt template and connect to the Space.

        Args:
            space_id: Identifier passed directly to ``gradio_client.Client``.
            prompt_template_file: Path to a text file holding the Jinja2
                prompt template.
            verbose: Console output level. ``0`` prints nothing, ``> 0``
                prints observation, candidate actions, reason and action,
                ``> 1`` additionally prints the raw Space response.
            debug: When true, drop into ``pdb`` after every prediction.
            parse_reason_action_fn: Callable taking the Space response and
                returning ``(reason, action)``. It is required by
                ``predict_reason_action`` even though it defaults to ``None``.
        """
        super().__init__()
        self.space_id = space_id
        self.verbose = verbose
        self.debug = debug
        self.parse_reason_action_fn = parse_reason_action_fn
        with open(prompt_template_file, "r") as file:
            self.prompt_template = Template(file.read())

        # Constructing the client here means instantiation depends on the
        # Space being reachable.
        self.client = Client(space_id)

    def predict_reason_action(self, task, observation, candidate_actions, reward=""):
        """Query the Space for the next action and record it in the history.

        Args:
            task: Task description forwarded to the prompt template.
            observation: Current observation forwarded to the template.
            candidate_actions: Available actions forwarded to the template.
            reward: Optional reward value forwarded to the template.

        Returns:
            The ``(reason, action)`` pair produced by
            ``parse_reason_action_fn`` from the Space response.

        Raises:
            TypeError: If no callable ``parse_reason_action_fn`` was supplied.
        """
        # Fail before the remote round trip instead of after it: the parser
        # defaults to None, and calling None would raise an opaque TypeError
        # only once the Space request had already completed.
        if not callable(self.parse_reason_action_fn):
            raise TypeError(
                "HFSpaceAgent requires a callable parse_reason_action_fn to "
                "turn the Space response into (reason, action)"
            )

        # Past reasons are left out of the prompt; only observation/action
        # pairs are passed on. The history attribute is expected to be
        # provided by the Agent base class (not visible in this file).
        observation_action_history = [
            {'observation': entry['observation'], 'action': entry['action']}
            for entry in self.observation_reason_action_history
        ]

        input_data = {
            'mode': 'input',
            'task': task,
            'reward': reward,
            'observation_action_history': observation_action_history,
            'observation': observation,
            'candidate_actions': candidate_actions,
        }
        input_prompt = self.prompt_template.render(**input_data)

        messages = [
            {"role": "user", "content": input_prompt}
        ]

        response = self.client.predict(
            messages,
            api_name="/predict"
        )

        reason, action = self.parse_reason_action_fn(response)
        if self.verbose > 0:
            if self.verbose > 1:
                # The raw response is only shown at the higher verbosity.
                print(f"\n RESPONSE: {response}")
            print(f"\n OBSERVATION: {observation}")
            print(f"\n CANDIDATE ACTIONS: {candidate_actions}")
            print(f"\n REASON: {reason}")
            print(f"\n ACTION: {action}")

        if self.debug:
            # Interactive stop so `reason` / `action` can be inspected (or
            # reassigned by hand) before they are written to the history.
            import pdb; pdb.set_trace()

        self.update_history(observation, reason, action)
        return reason, action
