import json
import random
import re
from typing import (
    Any,
    Callable,
    Dict,
    List,
)

from langchain.tools import BaseTool
from typing_extensions import TypedDict

# Maps the type names that appear in tool parameter/return specifications to
# the Python type that validate_inputs()/validate_outputs() hand to
# check_type() for an isinstance-based check.
# NOTE(review): "null" maps to the value None, not to a type, so it cannot be
# passed to isinstance() directly -- confirm how "null" specs are meant to be
# validated.
PRIMITIVE_TYPES = {
    "integer": int,
    "number": float,
    "string": str,
    "boolean": bool,
    "array": List,
    "object": Dict,
    "null": None,
}


def insert_indent(s: str, indent: str = "\t", insert_first: bool = True) -> str:
    """Indent the lines of ``s`` and end the result with a single newline.

    Trailing newlines and spaces are stripped from ``s`` before indenting.

    Args:
        s: The text to indent.
        indent: The string placed in front of each line.
        insert_first: When false, the first line is left without ``indent``;
            every following line is still indented.

    Returns:
        The indented text, terminated by exactly one ``"\\n"``.
    """
    lines = s.rstrip("\n ").split("\n")
    body = ("\n" + indent).join(lines)
    if insert_first:
        body = indent + body
    return body + "\n"


def create_str(
    data: Dict[str, str], include_desc: bool = True, indent: str = "\t"
) -> str:
    """Render one spec entry as ``name (type[, optional]): description``.

    Args:
        data: Spec entry. ``"name"`` is always read; ``"description"`` is read
            when ``include_desc`` is true; ``"type"`` and ``"required"`` are
            optional keys.
        include_desc: When true, the line is prefixed with ``"{indent}- "`` and
            suffixed with ``": {description}"``. When false, both are omitted.
        indent: Indentation used in the bullet prefix.

    Returns:
        The formatted single-line string (no trailing newline).
    """
    label = data["name"]
    if include_desc:
        bullet = f"{indent}- "
        description = f": {data['description']}"
    else:
        bullet = description = ""

    type_note = ""
    if "type" in data:
        kind = data["type"]
        # An entry is treated as required unless it explicitly says otherwise.
        if not data.get("required", True):
            kind += ", optional"
        type_note = f" ({kind})"
    return f"{bullet}{label}{type_note}{description}"


class ArgParameter(TypedDict):
    """The input parameter for a tool.

    One entry of the parameter list that ``validate_inputs`` checks a tool's
    inputs against.
    """

    # Key under which the argument is looked up in the inputs dict.
    name: str
    # Type name; validate_inputs uses it as a key into PRIMITIVE_TYPES.
    type: str
    # Free-text description of the parameter.
    description: str
    # Whether the argument must be present; validate_inputs treats a missing
    # "required" key as False.
    required: bool


class ArgReturn(TypedDict):
    """The return value for a tool.

    One entry of the return list that ``validate_outputs`` checks a tool's
    outputs against.
    """

    # Key that must be present in the outputs dict.
    name: str
    # Type name; validate_outputs uses it as a key into PRIMITIVE_TYPES.
    type: str
    # Free-text description of the returned value.
    description: str


class ArgException(TypedDict):
    """The exception for a tool"""

    # Identifier of the exception.
    # NOTE(review): presumably the exception class name reported in tool
    # error messages -- confirm against the tool specifications.
    name: str
    # Free-text description of the exception.
    description: str


class DummyToolWithMessage(BaseTool):
    """Tool that is run when invalid situation is encountered by agent.

    Both the sync and async entry points return whatever ``_get_message``
    produces, so a subclass can change the reply by overriding that one hook.
    """

    name = "dummy_tool_with_message"
    description = "Called when tool is not actually run."

    def _get_message(self, msg: str) -> str:
        """Return ``msg`` unchanged; subclasses may override to build a reply."""
        return msg

    def _run(self, msg: str) -> str:
        """Synchronous entry point: return ``self._get_message(msg)``."""
        return self._get_message(msg)

    async def _arun(self, msg: str) -> str:
        """Asynchronous entry point: return ``self._get_message(msg)``."""
        return self._get_message(msg)


class InvalidTool(DummyToolWithMessage):
    """Tool that is run when invalid tool name is encountered by agent."""

    name = "invalid_tool"
    description = "Called when tool name is invalid."
    # Names of the tools the agent is allowed to call; listed in the error.
    available_tools: List[str]

    def _get_message(self, tool_name: str) -> str:
        """Build the JSON error reply for an unknown tool name.

        Args:
            tool_name: The (invalid) tool name the agent tried to use.

        Returns:
            A JSON object string of the form
            ``{"error": "InvalidRequestException: ..."}``.
        """
        available_tool_string = ", ".join(self.available_tools)
        err_msg = (
            f"{tool_name} is not a valid tool, try another one. "
            f"You should use available tools in [{available_tool_string}] only. "
            "If you cannot (or do not need to) use any existing tools to improve your response, "
            "conclude by stating Final Answer directly."
        )

        # Serialize with json.dumps so that quotes or backslashes inside the
        # agent-supplied tool name cannot break the JSON object. For ordinary
        # names the output is identical to plain string interpolation.
        return json.dumps(
            {"error": f"InvalidRequestException: {err_msg}"}, ensure_ascii=False
        )


def special_fix_for_json_obj(s: str, pos: int) -> str:
    """Try to repair the JSON text ``s`` at the decoder's error position.

    Only the characters at ``pos`` are touched, so identical text elsewhere in
    the string (for example the word ``True`` inside a string value) is left
    intact. Callers are expected to re-parse and call again for later errors.

    Supported repairs:
        * a stray backslash, e.g. ``{"xxx\\_yyy": "zzz"}`` -> the backslash is
          dropped;
        * Python literals ``True`` / ``False`` -> ``true`` / ``false``.

    Args:
        s: The JSON text that failed to parse.
        pos: Index in ``s`` at which the decoder reported the error.

    Returns:
        The repaired string, or ``""`` when no repair applies.
    """
    if not 0 <= pos < len(s):
        return ""
    if s[pos] == "\\":  # fix for {"xxx\_yyy": "zzz"}
        if pos + 1 >= len(s):
            # A lone trailing backslash escapes nothing; nothing to repair.
            return ""
        return s[:pos] + s[pos + 1 :]
    if s.startswith("True", pos):
        return s[:pos] + "true" + s[pos + 4 :]
    if s.startswith("False", pos):
        return s[:pos] + "false" + s[pos + 5 :]
    return ""


def get_first_json_object_str(
    s: str, enable_check: bool = True, *args: Any, **kwargs: Any
) -> str:
    """Get the first JSON object from a string.

    If ``s`` contains a fenced block of the form ```` ```json {...} ``` ````
    (the ``json`` tag is optional), only the object inside the first such
    fence is considered. Otherwise the object is assumed to start at the
    beginning of ``s`` (after leading whitespace).

    The text is parsed repeatedly; on each decode error
    ``special_fix_for_json_obj`` gets a chance to repair it, and when no
    repair applies the text is cut off at the error position.

    Args:
        s: Raw text expected to contain a JSON object.
        enable_check: When true, re-raise the parsing error if the final
            candidate is still not valid JSON. When false, return the last
            repaired (untruncated) text instead.
        *args: Extra positional arguments forwarded to ``json.loads``.
        **kwargs: Extra keyword arguments forwarded to ``json.loads``.

    Returns:
        The JSON text of the first object, or the best-effort text when
        ``enable_check`` is false and validation fails.
    """
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", s, flags=re.DOTALL)
    if fenced is not None:
        s = fenced.group(1)

    # The object is assumed to start right after any leading whitespace.
    candidate = s.lstrip()
    fallback = candidate  # returned when the final check fails and is disabled

    parsed = False
    while not parsed:
        try:
            json.loads(candidate, *args, **kwargs)
        except json.JSONDecodeError as err:
            repaired = special_fix_for_json_obj(candidate, err.pos)
            if not repaired:
                # No known repair: keep only the text before the error.
                candidate = candidate[: err.pos].rstrip()
                break
            candidate = fallback = repaired
        else:
            parsed = True

    try:
        # Confirm the (possibly truncated) candidate really is valid JSON.
        json.loads(candidate, *args, **kwargs)
    except Exception:
        if enable_check:
            raise
        return fallback
    return candidate


def load_dict(tool_input: str) -> dict[str, Any]:
    """Load a dictionary from a JSON string.

    Args:
        tool_input: Text expected to be a single JSON object.

    Returns:
        The parsed object as a ``dict``.

    Raises:
        ValueError: If the text is not valid JSON (with a dedicated hint for
            invalid backslash escapes) or if it parses to something other
            than a JSON object.
    """
    # TODO: improve this, this seems to be the most robust way for now
    try:
        # strict=False lets raw control characters (newlines, tabs, ...)
        # appear inside JSON strings. It does NOT enable single quotes or
        # comments; those still fail and are reported below.
        params = json.loads(tool_input, strict=False)
    except Exception as e:
        if str(e).startswith("Invalid \\escape"):
            raise ValueError(
                "Invalid syntax for \\', you may need to use \\\\' instead"
            ) from e
        raise ValueError(
            f"Invalid format: {e}. Please format it as a JSON object with necessary fields matching the tool's argument specifications. Do not include backticks to wrap the JSON object."
        ) from e
    # Valid JSON that is not an object (list, number, string, ...) would break
    # callers that rely on the declared dict return type.
    if not isinstance(params, dict):
        raise ValueError(
            f"Invalid format: expected a JSON object, got {type(params).__name__}. Please format it as a JSON object with necessary fields matching the tool's argument specifications. Do not include backticks to wrap the JSON object."
        )
    return params


def match_type(value: Any, expected_type: type) -> bool:
    """Return whether ``value`` is acceptable for ``expected_type``.

    Rules, in order:
        * ``expected_type`` of ``None`` (the entry used for ``"null"``)
          matches only the value ``None``;
        * ``bool`` values never match ``int`` or ``float``, even though
          ``bool`` is a subclass of ``int`` in Python;
        * otherwise a plain ``isinstance`` check decides;
        * additionally, an ``int`` is accepted where ``float`` is expected.
    """
    if expected_type is None:
        # None is not a type, so isinstance() cannot be used for it.
        return value is None
    if isinstance(value, bool) and expected_type in (int, float):
        return False
    if isinstance(value, expected_type):
        return True
    # allow int to be float
    return isinstance(value, int) and expected_type == float


def check_type(name: str, value: Any, expected_type: type) -> None:
    """Raise ``ValueError`` if ``value`` does not match ``expected_type``.

    Args:
        name: Field name, used only in the error message.
        value: The value to check.
        expected_type: The type to check against; ``None`` means the value
            must be ``None``.

    Raises:
        ValueError: If ``match_type(value, expected_type)`` is false.
    """
    if not match_type(value, expected_type):
        if expected_type is None:
            expected_name = "None"
        else:
            # Not every type-like object (e.g. some typing aliases on older
            # Python versions) exposes __name__; fall back to its str().
            expected_name = getattr(expected_type, "__name__", str(expected_type))
        raise ValueError(
            f"Invalid type for {name}: {type(value).__name__}, "
            f"expected {expected_name}."
        )


def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]) -> None:
    """Validate the inputs for a tool.

    Args:
        parameters: The tool's parameter specifications.
        inputs: The arguments supplied for the call, keyed by parameter name.

    Raises:
        ValueError: If ``inputs`` is not a dict, a supplied value has the
            wrong type, a required parameter is missing, or ``inputs``
            contains names that no specification declares.
    """
    if not isinstance(inputs, dict):
        raise ValueError(
            f"Invalid type for inputs: {type(inputs).__name__}, expected dict."
        )
    remains = set(inputs)  # supplied names not yet matched to a specification
    for param in parameters:
        name = param["name"]
        remains.discard(name)
        if name in inputs:
            expected_type = PRIMITIVE_TYPES[param["type"]]
            check_type(name, inputs[name], expected_type)  # type: ignore
        elif param.get("required", False):
            raise ValueError(f"Missing required parameter: {name}.")
    if remains:
        # Sorted list keeps the message deterministic and in the same format
        # as the "Unexpected returns" message of validate_outputs.
        raise ValueError(f"Unexpected parameters: {sorted(remains, key=str)}.")


def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]) -> None:
    """Validate the outputs for a tool.

    An output consisting solely of an ``"error"`` key is treated as a raised
    exception and is accepted without further checks.

    Args:
        returns: The tool's return specifications.
        outputs: The values produced by the tool, keyed by return name.

    Raises:
        ValueError: If ``outputs`` is not a dict, a declared return is
            missing or has the wrong type, or ``outputs`` contains names that
            no specification declares.
    """
    if not isinstance(outputs, dict):
        raise ValueError(
            f"Invalid type for outputs: {type(outputs).__name__}, expected dict."
        )
    if list(outputs.keys()) == ["error"]:  # raised exception
        return  # no need to validate
    remains = set(outputs)  # produced names not yet matched to a specification
    for ret in returns:
        name = ret["name"]
        remains.discard(name)
        if name not in outputs:
            raise ValueError(f"Missing return: {name}.")
        expected_type: type = PRIMITIVE_TYPES[ret["type"]]  # type: ignore
        check_type(name, outputs[name], expected_type)
    if remains:
        # Sort so the message does not depend on set iteration order.
        raise ValueError(f"Unexpected returns: {sorted(remains, key=str)}.")


def run_with_input_validation(
    run_func: Callable[..., Any],
    inputs: dict[str, str],
    tool: BaseTool,
    raw_inputs: dict[str, Any],
    **kwargs: Any,
) -> Any:
    """Run a tool with inputs, where the format of raw inputs is validated.

    Args:
        run_func: Callable invoked as ``run_func(inputs, **kwargs)`` when
            validation succeeds.
        inputs: The inputs passed on to ``run_func``.
        tool: The tool whose ``parameters`` specification is validated
            against.
        raw_inputs: The parsed arguments to validate.
        **kwargs: Forwarded to ``run_func`` or to the error tool's ``run``.

    Returns:
        The result of ``run_func``, or, if validation fails, the result of
        running ``DummyToolWithMessage`` with a JSON error message.
    """
    try:
        validate_inputs(tool.parameters, raw_inputs)  # type: ignore
    except Exception as e:
        # json.dumps escapes quotes/backslashes in the exception text so the
        # reply stays a valid JSON object.
        error_json = json.dumps(
            {"error": f"InvalidRequestException: {e}"}, ensure_ascii=False
        )
        return DummyToolWithMessage().run(error_json, **kwargs)
    return run_func(inputs, **kwargs)


async def arun_with_input_validation(
    run_func: Callable[..., Any],
    inputs: dict[str, str],
    tool: BaseTool,
    raw_inputs: dict[str, Any],
    **kwargs: Any,
) -> Any:
    """Asynchronously run a tool with inputs, validating the raw inputs first.

    Args:
        run_func: Awaitable callable invoked as
            ``await run_func(inputs, **kwargs)`` when validation succeeds.
        inputs: The inputs passed on to ``run_func``.
        tool: The tool whose ``parameters`` specification is validated
            against.
        raw_inputs: The parsed arguments to validate.
        **kwargs: Forwarded to ``run_func`` or to the error tool's ``arun``.

    Returns:
        The awaited result of ``run_func``, or, if validation fails, the
        awaited result of ``DummyToolWithMessage().arun`` with a JSON error
        message.
    """
    try:
        validate_inputs(tool.parameters, raw_inputs)  # type: ignore
    except Exception as e:
        # json.dumps escapes quotes/backslashes in the exception text so the
        # reply stays a valid JSON object.
        error_json = json.dumps(
            {"error": f"InvalidRequestException: {e}"}, ensure_ascii=False
        )
        return await DummyToolWithMessage().arun(error_json, **kwargs)
    return await run_func(inputs, **kwargs)


def find_toolkit_spec(name: str, toolkits: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the first toolkit whose ``"toolkit"`` entry equals ``name``.

    Args:
        name: The toolkit name to look for.
        toolkits: Toolkit specifications, each with a ``"toolkit"`` key.

    Raises:
        ValueError: If no toolkit in ``toolkits`` has that name.
    """
    found = next((spec for spec in toolkits if spec["toolkit"] == name), None)
    if found is None:
        raise ValueError(f"Toolkit not found: {name}.")
    return found


def format_toolkit_dict(
    toolkit: Dict[str, Any],
    namekey: str,
    add_risks: bool = True,
    indent: str = "\t",
    use_simple_tool_desc: bool = False,
) -> str:
    """Format a toolkit specified as a dictionary into a string.

    Args:
        toolkit: Toolkit specification. Reads ``toolkit[namekey]``,
            ``toolkit["tools"]`` and, optionally, ``toolkit["risks"]``.
        namekey: Key under which the toolkit's display name is stored.
        add_risks: When true and the toolkit has a ``"risks"`` entry, append
            the risks in randomly shuffled order.
        indent: Indentation placed before every bullet.
        use_simple_tool_desc: When true, each tool is rendered as
            ``name: summary``; otherwise as
            ``name(params) -> returns : summary``.

    Returns:
        The multi-line description, every line terminated by a newline.
    """
    pieces = [f"<{toolkit[namekey]}> toolkit with following tool APIs:\n"]
    for tool in toolkit["tools"]:
        summary = tool["summary"].rstrip("\n ")
        tool_name = tool["name"]
        if use_simple_tool_desc:
            pieces.append(f"{indent}* {tool_name}: {summary}\n")
            continue
        arg_names = ", ".join(f"{p['name']}" for p in tool["parameters"])
        ret_names = ", ".join(f"{r['name']}" for r in tool["returns"])
        pieces.append(
            f"{indent}* {tool_name}({arg_names}) -> {ret_names} : {summary}\n"
        )

    if add_risks and "risks" in toolkit:
        pieces.append("Potential Risky Outcomes:\n")
        # Shuffle a copy so the caller's list is left untouched.
        shuffled = list(toolkit["risks"])
        random.shuffle(shuffled)
        pieces.extend(f"{indent}* {risk}\n" for risk in shuffled)
    return "".join(pieces)
