import argparse
import ast
from collections import defaultdict
import json
import shutil
import subprocess
from dotenv import load_dotenv
from natsort import natsorted
from demos.ipc.src.common.function_mapping import map_functions
from pathlib import Path
from typing import List
from llm_utils.textgen_api.textgen_api import TextGenApi
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain


def run_lodge(args, subcall_kwargs: List[str]):
    """Launch the lodge script that belongs to ``args.domain`` and wait for it.

    Domains whose name starts with ``"fb-"`` run the furniturebench script with a
    fixed set of flags (``subcall_kwargs`` is not forwarded in that case). All
    other domains run the ipc script with ``subcall_kwargs`` appended, where every
    entry is split on single spaces into separate command-line tokens.

    Args:
        args: Parsed CLI namespace; only ``args.domain`` is read.
        subcall_kwargs: Extra arguments for the ipc script, e.g. ``"--llm foo"``.

    Raises:
        AssertionError: If the subprocess exits with a non-zero return code.
    """
    demos_root = Path(__file__).parent.parent

    if args.domain.startswith("fb-"):
        script = demos_root / "furniturebench/scripts/run_lodge.py"
        cmd = ["python", str(script), "--method", "learn", "--no-docstrings"]
    else:
        script = demos_root / "ipc/scripts/01-run_lodge.py"
        cmd = ["python", str(script)]
        for kwarg in subcall_kwargs:
            # "flag value" -> ["flag", "value"]; an entry without a space stays a single token
            cmd.extend(kwarg.split(" "))

    response = subprocess.run(cmd)
    print("Response:", response)
    # stdout/stderr are not captured by the call above, so both placeholders are normally empty
    assert response.returncode == 0, "Failed to generate template files.\n%s\n%s" % (
        (response.stdout or b"").decode("utf-8"),
        (response.stderr or b"").decode("utf-8"),
    )


def post_process_lodge_output(task: str, engine: str, subdir: str):
    """Create ``domain.pddl`` and ``function_mapping.json`` for lodge result directories.

    Result directories are looked up below
    ``<root>/results/<task>/<engine_id>/hi-tamp`` where ``<root>`` is the
    ``furniturebench`` folder for tasks starting with ``"fb-"`` and the ``ipc``
    folder otherwise. Directories that already contain a ``domain.pddl`` are
    skipped. For every remaining directory this function

    1. copies ``generated-domain.json`` / ``generated-problem.json`` to
       ``domain.json`` / ``problem.json`` if those do not exist yet,
    2. loads the domain and keeps only leaf operators,
    3. derives an operator-name -> function mapping (function name plus argument
       index mapping) from each operator's single mapped skill call,
    4. passes the non-verified operators to ``map_functions`` to complete the mapping,
    5. writes ``function_mapping.json`` and ``domain.pddl``.

    Args:
        task: Task/domain name; also selects the data and results folders.
        engine: Name given to ``TextGenApi.default`` to resolve the model directory
            that is part of the results path.
        subdir: Glob pattern relative to the results directory. For non-``fb-``
            tasks ``/task-*`` is appended to it.

    Raises:
        AssertionError: If an operator's skill call cannot be mapped consistently
            or the resulting domain has duplicate operator names.
    """
    print("Post-processing lodge output for task:", task, "engine:", engine, "subdir:", subdir)
    load_dotenv()
    # textgen_api is handed to map_functions below; engine_id is only used to build the results path
    textgen_api = TextGenApi.default("gpt4.1-mini")
    engine_id = TextGenApi.default(engine).connections.connections[0].model_dir
    is_fb = task.startswith("fb-")

    root_dir = Path(__file__).parent.parent / ("furniturebench" if is_fb else "ipc")
    out_dirs = root_dir / "results" / task / engine_id / "hi-tamp"

    if not is_fb:
        # non-fb results are matched one level deeper, in "task-*" folders below subdir
        subdir = (Path(subdir) / "task-*").as_posix()

    for task_dir in natsorted(out_dirs.glob(subdir)):
        if not task_dir.is_dir():
            continue

        # an existing domain.pddl marks this directory as already post-processed
        domain_file = task_dir / "domain.pddl"
        if domain_file.exists():
            continue

        # generate domain
        # domain.json / problem.json are created from the generated-* files only when missing
        pddl_domain_file = task_dir / "domain.json"
        if not pddl_domain_file.is_file():
            shutil.copyfile(task_dir / "generated-domain.json", pddl_domain_file)
        if not (task_dir / "problem.json").is_file():
            shutil.copyfile(task_dir / "generated-problem.json", task_dir / "problem.json")
        pddl_domain = PDDLDomain.loads(pddl_domain_file.read_text())

        if pddl_domain is None:
            print("Failed to generate domain for:", task_dir)
            continue

        # filter hierarchical operators
        # op_hierarchy: parent operator id -> ids of the operators that name it as their parent
        op_hierarchy = defaultdict(list)
        for op in pddl_domain.operators:
            op_hierarchy[op.parent_operator_id].append(op.id)
        all_parent_op_ids = list(op_hierarchy.keys())
        # keep operators that map to exactly one skill call, or that have no skill
        # sequence (None) and are not the parent of any other operator
        pddl_leaf_domain = pddl_domain.copy_with(
            operators=[
                op
                for op in pddl_domain.operators
                if (op.mapped_skill_sequence is not None and len(op.mapped_skill_sequence) == 1)
                or (op.mapped_skill_sequence is None and op.id not in all_parent_op_ids)
            ]  # only leaf ops
            # operators=[op for op in pddl_domain.operators if op.id not in all_parent_op_ids]  # only leaf ops
        )
        # clean mapping and domain
        mappings_file = domain_file.parent / "function_mapping.json"
        print("Generating function mappings for domain:", mappings_file)

        # get operator -> lifted skill mapping + arg indexing
        # only operators with a non-empty skill sequence take part in the direct mapping
        operators = [op for op in pddl_leaf_domain.operators if op.mapped_skill_sequence]

        def map_operators(operators, mapping: dict):
            """Add ``{"name", "arg_mapping"}`` entries for ``operators`` to ``mapping`` in place.

            The single skill call of each operator is parsed with ``ast``; operators
            without a skill sequence are skipped. Operator names must not be mapped yet.
            """
            for operator in operators:
                assert operator.name not in mapping
                fct_list = operator.mapped_skill_sequence
                if fct_list is None or len(fct_list) == 0:
                    continue

                assert len(fct_list) == 1, f"Expected exactly one function for action {operator.name}, got {fct_list}"
                # the skill string is expected to be one call expression, e.g. "f(a, b=c)"
                function = ast.parse(fct_list[0]).body[0].value

                assert isinstance(function, ast.Call)
                # names passed to the call: positional arguments first, then keyword values
                function_args = [arg.id for arg in function.args] + [kw.value.id for kw in function.keywords]
                # operator parameters rendered as "<name with '-' -> '_'>_<first type tag>"
                action_args = [
                    arg.name.replace("-", "_") + "_" + list(arg.type_tags)[0] for arg in operator.definition.parameters
                ]

                # per operator parameter: its position in the call's argument list, or None if not passed
                arg_mapping = [function_args.index(a) if a in function_args else None for a in action_args]

                # the number of call arguments must equal the number of operator parameters that were matched
                assert len(function_args) == len(
                    [a for a in arg_mapping if a is not None]
                ), f"Argument mapping incomplete for action {operator.name}: {arg_mapping}"

                mapping[operator.name] = {"name": function.func.id, "arg_mapping": arg_mapping}

        def _map_hierarchy_level(op_id: str, operators, mapping: dict):
            """Map the not-yet-mapped operators whose parent is ``op_id``, then recurse into its children.

            Returns the operators mapped at this level followed by those returned for
            the child levels (``op_hierarchy[op_id]``). ``mapping`` is updated in place.
            """
            # operators whose name is already mapped (e.g. by an earlier level) are left out
            ops_for_level = [op for op in operators if op.parent_operator_id == op_id and op.name not in mapping]
            map_operators(ops_for_level, mapping)

            mapped_opds = ops_for_level.copy()
            for child_op_id in op_hierarchy[op_id]:
                child_mapped_ops = _map_hierarchy_level(child_op_id, operators, mapping)
                for child_mapped_op in child_mapped_ops:
                    f_name = mapping[child_mapped_op.name]["name"]
                    # NOTE(review): existing_f_names is computed but not used (see the disabled condition below)
                    existing_f_names = [mapping[op.name]["name"] for op in mapped_opds]
                    if f_name is not None:  #  and f_name not in existing_f_names:
                        mapped_opds.append(child_mapped_op)
            return mapped_opds

        cleaned_mappings = {}
        # "root" is used as the parent id of the top hierarchy level
        mapped_ops = _map_hierarchy_level("root", operators, cleaned_mappings)

        # llm to fill missing mappings
        # NOTE(review): every top-level node of function_stubs.py must have a ``name``
        # attribute (e.g. function definitions) for this dict comprehension to work
        functions_text = (root_dir / "data" / task / "function_stubs.py").read_text()
        functions = {fd.name: fd for fd in ast.parse(functions_text).body}
        # only the non-verified leaf operators are passed on as the domain to map
        cleaned_mappings = map_functions(
            textgen_api=textgen_api,
            domain=pddl_leaf_domain.copy_with(
                operators=[op for op in pddl_leaf_domain.operators if not op.verified]
            ).to_pddl(),
            functions=functions,
            existing_mapping=cleaned_mappings,
            confirm=True,
        )
        # add llm-mapped ops to mapped ops, mapped ops contains all operators valid (removes ones with duplicate names)
        for op in pddl_leaf_domain.operators:
            if op.verified:
                # assert op.name in [o.name for o in mapped_ops], f"Verified operator {op.name} not in mapped operators!" can happen if parent operator has been deleted -> dangling children
                continue
            # non-verified operators are added once, and only if they received a non-None mapping
            if cleaned_mappings.get(op.name, None) is not None and op.name not in [o.name for o in mapped_ops]:
                mapped_ops.append(op)

        mappings_file.write_text(json.dumps(cleaned_mappings, indent=4, sort_keys=True))

        # the written PDDL domain contains only the operators collected in mapped_ops
        leaf_domain = pddl_leaf_domain.copy_with(operators=mapped_ops)
        assert leaf_domain.has_unique_names()
        domain_file.write_text(str(leaf_domain.to_pddl()))


if __name__ == "__main__":
    # CLI entry point: run the lodge script for the given domain/LLM, then
    # post-process its results into domain.pddl / function_mapping.json files.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--domain", type=str, required=True)
    argparser.add_argument("--llm", type=str, required=True)
    argparser.add_argument("--suffix", type=str, required=False)
    args, unknown_args = argparser.parse_known_args()

    # Forward the known args as "flag value" strings and drop options that were
    # not given, so an omitted --suffix is not passed on as the literal text "None".
    known_options = {"--domain": args.domain, "--llm": args.llm, "--suffix": args.suffix}
    kwargs = [f"{flag} {value}" for flag, value in known_options.items() if value is not None]
    # Arguments this parser does not know are handed through unchanged
    kwargs.extend(unknown_args)

    run_lodge(args, subcall_kwargs=kwargs)

    # NOTE(review): post-processing uses the suffix as a glob pattern for the result
    # directories, so it still needs --suffix to be provided — confirm whether the
    # option should be required or get a default.
    post_process_lodge_output(task=args.domain, engine=args.llm, subdir=args.suffix)
