import ast
import re
from typing import Dict, Optional, Tuple
from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.chat_factory import ChatFactory
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from llm_utils.textgen_api.textgen_api import TextGenApi
from pddl.core import Action, Domain

from tp_lodge.utils.pddl_utils import get_markup_from_text

# Few-shot prompt used by `map_function`. It is filled via `str.format` with the
# placeholders {user_info}, {action_pddl} and {functions}; any literal braces added
# to this text must therefore be doubled. Arguments in the example answers are bare
# action-parameter names, which is the form `map_function` parses.
PROMPT = """
Your task is to map a PDDL action to a python function stub.
This includes mapping the action name to the function name and parameterizing the function with the action's variables.
If no mapping is possible, return "null".

Q: 
Action PDDL:
(:action go-to
    :parameters (?a - agent ?from - furniture_appliance ?to - furniture_appliance)
    :precondition (at-agent ?a ?from)
    :effect (and (at-agent ?a ?to) (not (at-agent ?a ?from)))
)
Functions:
put_an_object_on_or_in_a_furniture_piece_or_an_appliance(object_id, furniture_or_appliance), go_to_a_furniture_piece_or_an_appliance(furniture_or_appliance), pick_up_an_object_on_or_in_a_furniture_piece_or_an_appliance(object_id, furniture_or_appliance)

A:
```python
go_to_a_furniture_piece_or_an_appliance(to)
```

Q:
Action PDDL:
(:action go_to_object
    :parameters (?a - agent ?f - furniture_appliance ?obj - household_object)
    :precondition (at-agent ?a ?f)
    :effect (and (at-agent ?a ?f) (at-object ?obj ?f))
)
Functions:
put_an_object_on_or_in_a_furniture_piece_or_an_appliance(object_id, furniture_or_appliance), go_to_a_furniture_piece_or_an_appliance(furniture_or_appliance), pick_up_an_object_on_or_in_a_furniture_piece_or_an_appliance(object_id, furniture_or_appliance)

A: 
null

Now, please map the following action to a python function stub:

Additional user info: {user_info}

Q:
Action PDDL:
{action_pddl}
Functions:
{functions}

A:
"""

def map_functions(
    textgen_api: TextGenApi,
    domain: Domain,
    functions: Dict[str, ast.FunctionDef],
    existing_mapping: Dict[str, Optional[Dict]],
    confirm: bool = False,
) -> Dict:
    """Map every action of ``domain`` onto one of ``functions`` with the help of the LLM.

    Actions already present in ``existing_mapping`` (including those mapped to ``None``)
    are kept unchanged; they are neither re-queried nor re-confirmed.

    Args:
        textgen_api: LLM client, passed through to ``map_function``.
        domain: PDDL domain whose actions are mapped.
        functions: Candidate function stubs. They are identified by ``FunctionDef.name``,
            not by their dict key.
        existing_mapping: Previously computed mappings, ``action name -> None`` or
            ``{"name": <function name>, "arg_mapping": [...]}``. It is not modified.
        confirm: If true, the user is asked interactively about every new mapping:
            'y'/Enter accepts, 'n' re-queries the LLM, 'd' stores ``None``. When no mapping
            was found, Enter/'y' accepts that, and any other text is passed to the LLM as
            additional info for a new attempt.

    Returns:
        A new dict of the same shape as ``existing_mapping`` with an entry per action.
    """
    mappings = existing_mapping.copy()
    for action in domain.actions:
        if action.name in mappings:
            continue

        info = ""
        while True:
            response = map_function(textgen_api, action, functions, user_info=info)
            if response is None:
                mappings[action.name] = None
            else:
                f_name, arg_mapping = response
                used_names = [m.get("name", None) for m in mappings.values() if m is not None]
                if f_name in used_names:
                    # we already have an operator that probably will be more accurate. Just use that one
                    mappings[action.name] = None
                    break

                # Look the stub up by its declared name, like `map_function` does; the keys
                # of `functions` are not guaranteed to equal the function names.
                func = next(f for f in functions.values() if f.name == f_name)
                n_args = sum(idx is not None for idx in arg_mapping)
                f_n_args = len(func.args.args)
                if n_args != f_n_args:
                    print(f"Action {action.name} has {n_args} arguments, but function {f_name} has {f_n_args} arguments. Retrying...")
                    continue
                mappings[action.name] = {
                    "arg_mapping": arg_mapping,
                    "name": f_name,
                }

            if not confirm:
                break

            if mappings[action.name] is None:
                answer = input("No mapping found for action %s. Press Enter to confirm or type the action to use" % action.name)
                if answer.lower() in ("y", ""):
                    break
                # Retry, passing the user's text to the LLM as additional info.
                del mappings[action.name]
                info = answer
                continue

            info = ""
            print_mapping(action, {action.name: mappings[action.name]})
            # Re-prompt on unrecognized input instead of silently accepting the mapping.
            while True:
                choice = input(f"Confirm mapping for action {action.name}? (Y/d/n) ").lower()
                if choice in ("y", "", "d", "n"):
                    break
                print("Please answer 'y', 'd' or 'n'.")
            if choice == "n":
                del mappings[action.name]
                continue
            if choice == "d":
                mappings[action.name] = None
            break
    return mappings


def map_function(
    textgen_api: TextGenApi,
    action: Action,
    functions: Dict[str, ast.FunctionDef],
    chat: Optional[Chat] = None,
    user_info: str = "",
    max_retries: int = 5,
):
    """Ask the LLM to map a single PDDL action onto one of ``functions``.

    On the first call (``chat is None``), ``PROMPT`` is filled in with the action, the
    function signatures and ``user_info``. Malformed or invalid answers are reported back
    to the LLM in the same chat, and the query is repeated at most ``max_retries`` times.

    Args:
        textgen_api: Client whose ``do_call(chat)`` returns the LLM reply.
        action: The PDDL action to map.
        functions: Candidate stubs, identified by ``FunctionDef.name``.
        chat: Conversation to continue; used internally for retries.
        user_info: Extra hint inserted into the prompt; only used when ``chat`` is None.
        max_retries: Number of corrective follow-ups to send before giving up.

    Returns:
        ``None`` if the LLM answers "null" (no mapping) or the retry budget is exhausted.
        Otherwise ``(func_name, arg_idcs)``, where ``arg_idcs[i]`` is the position in the
        function call of the i-th action parameter, or ``None`` if it is not passed.
    """
    func_names = [f.name for f in functions.values()]
    if chat is None:
        def map_f(f: ast.FunctionDef) -> str:
            return f"{f.name}({', '.join(arg.arg for arg in f.args.args)})"

        fcts = [map_f(f) for f in functions.values()]

        prompt = PROMPT.format(action_pddl=str(action), functions=", ".join(fcts), user_info=user_info)

        chat = Chat([UserMessage([TextMessageContent(prompt)])])

    print(f"Mapping action {action.name} to function stub with LLM...")
    response = textgen_api.do_call(chat)
    chat = chat.add_message(response)

    text = response.content[0].text.strip()

    # Match "null" as a whole word only, so function names that merely contain it
    # (e.g. "nullify_x") are not mistaken for a "no mapping" answer.
    if re.search(r"\bnull\b", text):
        print("No mapping found for action %s." % action.name)
        return None

    def retry(feedback: str, log: Optional[str] = None):
        """Send ``feedback`` to the LLM and query again, unless the retry budget is spent."""
        print(feedback if log is None else log)
        if max_retries <= 0:
            print(f"Giving up on mapping action {action.name}: retry limit reached.")
            return None
        return map_function(
            textgen_api, action, functions, chat.add_user_text(feedback), max_retries=max_retries - 1
        )

    snippets = get_markup_from_text(text, ["python"])
    if len(snippets) != 1:
        return retry("Response must contain exactly one code snippet.")
    code = snippets[0].strip()

    call = re.match(r"(\w+)\((.*)\)", code)
    if call is None:
        return retry(
            "Response does not contain a valid function call.",
            f"Response does not contain a valid function call: {code}",
        )
    func_name = call.group(1)
    if func_name not in func_names:
        return retry(f"Function name {func_name} not in known function names: {', '.join(func_names)}.")

    # Accept quoted names ("to") and PDDL-style hyphens ("obj-id") by normalizing the
    # arguments the same way as the action parameters below.
    args = [
        raw.strip().strip("\"'").replace("-", "_")
        for raw in call.group(2).split(",")
        if raw.strip()
    ]

    action_args = [p.name.replace("-", "_") for p in action.parameters]
    if any(arg not in action_args for arg in args):
        return retry(
            f"Use action parameters {action_args} instead of {args}.",
            f"Action parameters {action_args} do not match function arguments {args}.",
        )

    func = next(f for f in functions.values() if f.name == func_name)
    n_params = len(func.args.args)
    n_distinct = len(set(args))
    # Every function parameter needs exactly one action parameter, and no action
    # parameter may be passed twice.
    if len(args) != n_params or n_distinct != len(args):
        return retry(
            f"Function {func_name} expects {n_params} arguments, got {len(args)} ({n_distinct} distinct). No duplicates allowed.",
            f"Function {func_name} expects {n_params} arguments, got {len(args)} ({n_distinct} distinct).",
        )

    arg_idcs = [args.index(a) if a in args else None for a in action_args]

    n_mapped = sum(idx is not None for idx in arg_idcs)
    assert n_mapped == n_params, f"Function {func_name} expects {n_params} arguments, got {n_mapped} mapped from action {action.name} with args {action_args} to {args}."

    return func_name, arg_idcs


def print_mappings(domain: Domain, mappings: Dict[str, Dict]):
    """Print the mapping of each action in ``domain``, one line per action, in domain order."""
    for pddl_action in domain.actions:
        print_mapping(pddl_action, mappings)

def print_mapping(action: Action, mappings: Dict[str, Dict]):
    mapping = mappings[action.name]

    args = [None for _ in range(max([m for m in mapping["arg_mapping"] if m is not None]) + 1)]
    for arg_idx, action_arg in zip(mapping["arg_mapping"], action.parameters):
        if arg_idx is not None:
            args[arg_idx] = action_arg.name
    print(f"{action} -> {mapping['name']}({', '.join(args)})")
