import argparse
import ast
import json
from pathlib import Path
import re
import shutil
from pddl.parser.domain import DomainParser
from llm_utils.textgen_api.textgen_api_connections import TextGenLLMConnections
from pddl.core import Requirements
from tp_lodge.utils.pddl_domain_syntax import _parentheses_groups
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
import argparse
import os
from pathlib import Path
import subprocess

from dotenv import load_dotenv
from llm_utils.textgen_api.textgen_api import TextGenApi

# Root of the demos tree: the parent of this script's directory. The benchmark-specific
# data/result trees (``furniturebench``/``ipc``) live below it, and the Guan checkout is
# resolved relative to its parent (see ``_guan_dir``).
demos_root_dir = Path(__file__).parent.parent


def _demo_root_dir(args):
    """Return the benchmark-family directory for ``args.domain``.

    Domains whose name starts with ``"fb"`` belong to ``furniturebench``; every other
    domain is looked up under ``ipc``.
    """
    if args.domain.startswith("fb"):
        return demos_root_dir / "furniturebench"
    return demos_root_dir / "ipc"


def _out_dir(args):
    """Return the directory Guan results are written to for this run.

    Layout: ``<demo root>/results/<domain>/<model_dir>/guan/<docstrings|no-docstrings>``,
    where ``model_dir`` comes from the first connection of ``TextGenApi.default(args.llm)``.
    """
    first_connection = TextGenApi.default(args.llm).connections.connections[0]
    variant = "no-docstrings" if args.no_docstrings else "docstrings"
    results_root = _demo_root_dir(args) / "results" / args.domain
    return results_root / first_connection.model_dir / "guan" / variant


def _guan_dir():
    """Return the third-party Guan checkout (``3rdparty/LLMs-World-Models-for-Planning``).

    The checkout sits next to the demos tree, i.e. under the parent of ``demos_root_dir``.
    """
    repo_parent = demos_root_dir.parent
    return repo_parent.joinpath("3rdparty", "LLMs-World-Models-for-Planning")


def _guan_data_dir(args):
    """Return the per-domain prompt directory (``prompts/<domain>``) inside the Guan checkout."""
    prompts_root = _guan_dir() / "prompts"
    return prompts_root / args.domain


def _action_model_from_stubs(function_stubs):
    """Build the ``action_model.json`` payload for Guan from parsed function-stub AST nodes.

    Each top-level (async) function becomes one action: its name with underscores replaced
    by spaces, mapped to ``{"desc": <raw docstring>, "extra_info": []}``. Other top-level
    nodes (imports, assignments, module docstring) are ignored.

    Raises:
        ValueError: if a stub function has no docstring.
    """
    action_model = {}
    for func in function_stubs:
        if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        # clean=False returns the docstring text exactly as written in the stub file.
        desc = ast.get_docstring(func, clean=False)
        if desc is None:
            raise ValueError(f"Function stub {func.name!r} has no docstring; Guan needs one per action")
        action_model[func.name.replace("_", " ")] = {"desc": desc, "extra_info": []}
    return action_model


def run_guan(args):
    """Run the Guan baseline for ``args.domain`` and store its output under ``_out_dir(args)``.

    Steps:
      1. Load the domain skeleton and function stubs from ``<demo root>/data/<domain>``.
      2. Write the inputs Guan reads (``action_model.json``, ``domain_desc.txt``,
         ``hierarchy_requirements.json``) into ``_guan_data_dir(args)``.
      3. Run Guan's ``run_all.py`` with the checkout's pixi Python, requesting 3 samples.
         Its stdout/stderr are not captured and stream straight to the terminal.

    Raises:
        ValueError: if ``args.no_docstrings`` is set (actions are described by stub
            docstrings) or a function stub lacks a docstring.
        RuntimeError: if ``run_all.py`` exits with a non-zero status.
    """
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if args.no_docstrings:
        raise ValueError("The Guan baseline requires docstrings; --no-docstrings is not supported")

    load_dotenv()

    exp_dir = _out_dir(args)

    data_dir = _demo_root_dir(args) / "data" / args.domain
    domain_skeleton = PDDLDomain.loads((data_dir / "domain_skeleton.json").read_text()).to_pddl()
    function_stubs = ast.parse((data_dir / "function_stubs.py").read_text()).body

    guan_root = _guan_dir()
    guan_python_path = guan_root / ".pixi/envs/default/bin/python"

    guan_data_dir = _guan_data_dir(args)
    guan_data_dir.mkdir(parents=True, exist_ok=True)
    (guan_data_dir / "action_model.json").write_text(json.dumps(_action_model_from_stubs(function_stubs)))
    shutil.copyfile(data_dir / "domain_knowledge.txt", guan_data_dir / "domain_desc.txt")
    (guan_data_dir / "hierarchy_requirements.json").write_text(
        json.dumps(
            {
                # Types whose parent is ``object`` (or missing) get an empty parent list;
                # every other type lists its single parent.
                "hierarchy": {
                    t: [] if parent_type is None or parent_type == "object" else [parent_type]
                    for t, parent_type in domain_skeleton.types.items()
                },
                # Keep only the text after the ':' of each requirement
                # (assumes requirements render like ":strips").
                "requirements": [str(r).split(":")[1] for r in domain_skeleton.requirements],
            }
        )
    )

    response = subprocess.run(
        [
            guan_python_path,
            guan_root / "run_all.py",
            "--llm",
            args.llm,
            "--domain",
            args.domain,
            "--out-dir",
            exp_dir,
            "--n-samples",
            "3",
        ],
        # Snapshot of the environment taken after load_dotenv() above.
        env=os.environ.copy(),
        cwd=guan_root,
    )

    # Output is not captured (it already went to the terminal), so only the exit code is reported.
    if response.returncode != 0:
        raise RuntimeError(f"Guan run_all.py failed with exit code {response.returncode}")


def _parse_action(data, action_name: str):
    ps = []
    for parameter in data["parameters"]:
        p = parameter.split(":")[0]
        ps.append(p)

    if "preconditions" in data:
        idx = re.search(r"\(", data["preconditions"]).start()
        precond_contents = list(_parentheses_groups(data["preconditions"][idx:]))
        assert len(precond_contents) == 1
        preconditions_l = precond_contents[0]
    else:
        preconditions_l = "()"

    if "effects" in data:
        idx = re.search(r"\(", data["effects"]).start()
        eff_contents = list(_parentheses_groups(data["effects"][idx:]))
        assert len(eff_contents) == 1
        effects_l = eff_contents[0]
    else:
        effects_l = "()"

    action_str = f"""
    (:action {action_name.lower().replace(" ", "-")}
        :parameters ({" ".join([p for p in ps])})
        :precondition {preconditions_l}
        :effect {effects_l}
    )
    """

    return action_str


def _predicate_from_line(line):
    """Return the first capture group of ``.*(\\(.*\\)).*:(.*)`` for one predicates-file line.

    Raises:
        ValueError: if the line does not match, instead of an opaque AttributeError.
    """
    match = re.match(r".*(\(.*\)).*:(.*)", line)
    if match is None:
        raise ValueError(f"Cannot parse predicate line: {line!r}")
    return match.group(1)


def _read_domain_types(args):
    """Return the type names from the Guan ``hierarchy_requirements.json``, plus ``smallReceptacle``.

    The hierarchy's parent information is not used: the names are emitted flat in ``(:types ...)``.
    """
    types_file = _guan_data_dir(args) / "hierarchy_requirements.json"
    types_data = json.loads(types_file.read_text())
    type_names = dict.fromkeys(types_data["hierarchy"])
    # NOTE(review): hard-coded extra type, presumably needed by some generated domains — confirm.
    type_names.setdefault("smallReceptacle")
    return list(type_names)


def _build_domain_str(args, details_dir, engine):
    """Assemble a PDDL domain string from Guan's per-sample outputs in ``details_dir``.

    Reads ``<engine>_pddl.json`` (actions) and ``<engine>_predicates.txt`` (one predicate per
    non-blank line), and declares the ADL requirement set.
    """
    actions_data = json.loads((details_dir / f"{engine}_pddl.json").read_text())
    actions = [_parse_action(action_dict, a_name) for a_name, action_dict in actions_data.items()]

    predicate_lines = (details_dir / f"{engine}_predicates.txt").read_text().splitlines()
    # Blank lines carry no predicate; skip them rather than failing the regex.
    predicates = [_predicate_from_line(line) for line in predicate_lines if line.strip()]

    types = _read_domain_types(args)

    return f"""
    (define (domain {args.domain})
        (:requirements {' '.join(map(str, Requirements.adl_requirements()))})
        (:types {" ".join(types)})
        (:predicates {" ".join(predicates)})
        {" ".join(actions)}
    )
    """


def _wait_for_valid_domain(domain_file):
    """Block until ``domain_file`` parses with ``DomainParser``, asking the user to fix it by hand."""
    while True:
        try:
            DomainParser()(domain_file.read_text())
            return
        # Broad on purpose: any parse failure is handed to the user for a manual fix.
        except Exception as e:
            print(f"Error parsing domain.pddl: {e}")
            input(f"Fix domain {domain_file} and press Enter to continue...")


def post_process_guan(args):
    """Make sure every Guan sample under ``_out_dir(args)`` has a parseable ``domain.pddl``.

    For each ``sample-*`` directory, visited in sorted order, a missing ``domain.pddl`` is
    regenerated from the files in its ``details`` subdirectory. Every domain is then parsed,
    and the user is prompted interactively until it parses.
    """
    engine = TextGenLLMConnections.default(args.llm).connections[0].model_dir

    # Sorted so samples (and any interactive fix prompts) are handled in a stable order.
    for sample_dir in sorted(_out_dir(args).glob("sample-*")):
        details_dir = sample_dir / "details"
        domain_file = sample_dir / "domain.pddl"

        if not domain_file.is_file():
            print(f"Domain file {domain_file} not found. Regenerating...")
            domain_file.write_text(_build_domain_str(args, details_dir, engine))

        _wait_for_valid_domain(domain_file)


def main(args):
    """Run Guan, post-process its samples into PDDL domains, then remove Guan's scratch inputs.

    The per-domain prompt directory written into the Guan checkout is deleted even when a
    step fails, so generated inputs are not left behind in the third-party tree.
    """
    try:
        run_guan(args)
        post_process_guan(args)
    finally:
        guan_data_dir = _guan_data_dir(args)
        # run_guan can fail before creating the directory; don't mask that error with
        # a FileNotFoundError from rmtree.
        if guan_data_dir.exists():
            shutil.rmtree(guan_data_dir)


if __name__ == "__main__":
    # CLI: both --domain and --llm are mandatory strings; --no-docstrings is a boolean flag.
    parser = argparse.ArgumentParser()
    for required_flag in ("--domain", "--llm"):
        parser.add_argument(required_flag, type=str, required=True)
    parser.add_argument("--no-docstrings", action="store_true", help="Use domain without docstrings")
    cli_args = parser.parse_args()
    main(cli_args)
