import json

from rllm.parser.tool_parser.tool_parser_base import ToolParser
from rllm.tools.tool_base import ToolCall


class AppworldToolParser(ToolParser):
    """Tool parser that extracts tool calls delimited by ``<tool_call>`` ... ``</tool_call>``."""

    def __init__(self):
        """Initialize the tag delimiters used for tool calls and tool outputs.

        Takes no arguments.
        """
        self.tool_call_begin = "<tool_call>"
        self.tool_call_end = "</tool_call>"
        self.tool_output_begin = "<tool_response>"
        self.tool_output_end = "</tool_response>"

    def parse(self, model_response: str) -> list[ToolCall]:
        """Parse tool calls from model output.

        Args:
            model_response (str): Raw model output that may contain tool-call blocks.

        Returns:
            list[ToolCall]: One ToolCall per well-formed tool-call block, in order of
            appearance. Empty when the response is empty or has no parseable block.
        """
        if not model_response:
            return []
        tool_calls_dicts = self._parse_tool_call_blocks(model_response)
        return [ToolCall(name=tc["name"], arguments=tc["arguments"]) for tc in tool_calls_dicts]

    def _parse_tool_call_blocks(self, text: str) -> list[dict]:
        """Extract ``{"name", "arguments"}`` dicts from every tool-call block in ``text``.

        NOTE(review): assumes each block body is a JSON object carrying "name" and
        "arguments" keys, matching the keys ``parse`` consumes. Confirm against the
        AppWorld prompt format.

        Blocks whose body is not valid JSON, is not an object, or lacks either key are
        skipped. An unterminated trailing block ends the scan.
        """
        tool_calls = []
        cursor = 0
        while True:
            start = text.find(self.tool_call_begin, cursor)
            if start == -1:
                break
            payload_start = start + len(self.tool_call_begin)
            # Look for the closing tag only after the opening one, so a stray
            # earlier "</tool_call>" cannot produce a negative-length slice.
            end = text.find(self.tool_call_end, payload_start)
            if end == -1:
                break
            payload = text[payload_start:end].strip()
            cursor = end + len(self.tool_call_end)
            try:
                call = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(call, dict) or "name" not in call or "arguments" not in call:
                continue
            tool_calls.append({"name": call["name"], "arguments": call["arguments"]})
        return tool_calls

def parse_appworld_tool_calls(response):
    """Extract the code inside ``<code>...</code>`` from one or more model responses.

    Args:
        response: A single response string, or an iterable of response strings.
            ``None`` is treated as no responses.

    Returns:
        tuple[list[str], list[int]]: ``(actions, valids)``, one entry per response.
        ``actions[i]`` is the text between the first ``<code>`` tag and the next
        ``</code>`` tag after it. If no such block exists, it is the last 100
        characters of the response. ``valids[i]`` is 1 only when a code block was
        found AND the response contains both ``<think>`` and ``</think>``; else 0.
        The input is never mutated; new lists are returned.
    """
    if response is None:
        return [], []
    responses = [response] if isinstance(response, str) else list(response)

    start_tag = "<code>"
    end_tag = "</code>"
    actions = []
    valids = []

    for text in responses:
        start_idx = text.find(start_tag)
        # Search for the end tag only after the start tag, so a stray "</code>"
        # earlier in the text cannot yield an empty/garbled "valid" action.
        end_idx = -1 if start_idx == -1 else text.find(end_tag, start_idx + len(start_tag))

        if end_idx == -1:
            # No usable <code>...</code> block: keep a short tail for diagnostics.
            actions.append(text[-100:])
            valids.append(0)
            continue

        actions.append(text[start_idx + len(start_tag):end_idx])
        # A response is only valid if it also contains a <think>...</think> section.
        has_think = "<think>" in text and "</think>" in text
        valids.append(1 if has_think else 0)

    return actions, valids