import argparse
import logging
from pathlib import Path
import shutil
from typing import Literal, Union

from dotenv import load_dotenv
from llm_utils import TextGenLLMConnections
from llm_utils.textgen_api.textgen_api import TextGenApi
from tp_lodge.task_planning.pddl_planner.hi_planner.out_of_retries_error import OutOfRetriesException
from python_utils.string_utils import remove_docstrings
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate

from demos.furniturebench.scripts.remote_se_motion_validator import FurnitureEnum, RemoteSEMotionValidator
from demos.furniturebench.scripts.fb_variable_parser import FBVariableParser
from state_estimation import VLMGrounder, PredicateOptimGrounder, ReplyBuffer
from tp_lodge.motion_planning.remote_motion_validator import RemoteMotionValidator
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.pddl_planner.hi_planner.nl_action_node import NLActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.pddl_parse_utils import parse_predicate


def _assembled_predicate() -> PDDLPredicate:
    """Build the predefined ``assembled`` predicate used by the generated domain skeletons."""
    return PDDLPredicate(
        definition=parse_predicate("(assembled ?obj1 - part ?obj2 - part)"),
        description="Evaluates whether obj1 has been assembled with obj2 the way they should be, e.g. when obj1 has been screwed into obj2, or obj1 has been put on top of obj2.",
        predefined=True,
        pred_type="other",
    )


def main(args: argparse.Namespace, suffix: str) -> None:
    """Run one planning sample for the ``fb-lamp`` domain.

    Reads the instruction, domain knowledge and function stubs from
    ``<script dir>/../data/fb-lamp``, builds a motion validator plus PDDL
    domain/problem skeletons according to ``args.method``, resets the
    validator environment and runs the ``NLActionNode`` planner. All outputs
    go to ``<script dir>/../out/fb-lamp/<suffix>``.

    Args:
        args: Parsed command-line arguments. ``args.method`` selects the setup
            (``"fixed"``, ``"none"``, ``"learn"`` or ``"learn-finetune"``) and
            ``args.no_docstrings`` strips docstrings from the function stubs.
        suffix: Name of the per-run output sub-directory.

    Raises:
        FileNotFoundError: If the sample data directory does not exist.
        ValueError: If ``args.method`` is not one of the supported methods.
    """
    load_dotenv()

    domain = "fb-lamp"
    data_dir = Path(__file__).parent.parent / "data"
    sample_dir = data_dir / domain
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if not sample_dir.is_dir():
        raise FileNotFoundError(f"Sample directory not found: {sample_dir}")

    out_dir = data_dir.parent / "out" / domain / suffix
    out_dir.mkdir(exist_ok=True, parents=True)

    # Address of the remote motion validation server.
    ip = "localhost"
    port = 8800

    instruction = (sample_dir / "instruction.txt").read_text().strip()
    domain_knowledge = (sample_dir / "domain_knowledge.txt").read_text().strip()

    textgen_api = TextGenApi(
        connections=TextGenLLMConnections(
            [
                # TextGenLLMConnections.default("openrouter:meta-llama/llama-4-maverick").connections[0],
                TextGenLLMConnections.default("gpt4.1-mini").connections[0],
                TextGenLLMConnections.default("gpt4.1").connections[0],
            ]
        )
    )

    function_stubs = (sample_dir / "function_stubs.py").read_text()
    if args.no_docstrings:
        function_stubs = remove_docstrings(function_stubs)

    method: Literal["learn", "learn-finetune", "fixed", "none"] = args.method
    if method == "fixed":
        # Both skeletons are loaded from the hand-written JSON files.
        motion_validator = RemoteMotionValidator(ip=ip, port=port, furniture=FurnitureEnum.LAMP)
        problem_skeleton = PDDLProblem.loads((sample_dir / "problem-skeleton.json").read_text())
        domain_skeleton = PDDLDomain.loads((sample_dir / "domain-skeleton.json").read_text())
    elif method == "none":
        # The problem skeleton comes from JSON; the domain only contains the
        # predefined `assembled` predicate, and the initial state is described
        # in natural language as part of the instruction.
        motion_validator = RemoteMotionValidator(ip=ip, port=port, furniture=FurnitureEnum.LAMP)
        instruction += (
            "\n Initial State: The lamp assembly task starts with the lamp base, lamp bulb, and lamp hood placed on the table. "
            "The arm is hovering above the lamp base. Translate this into a PDDL initial state description."
        )
        problem_skeleton = PDDLProblem.loads((sample_dir / "problem-skeleton.json").read_text())
        domain_skeleton = PDDLDomain(
            operators=[],
            predicates=[_assembled_predicate()],
            types={"part": "object", "robot": "object", "table": "object"},
        )
    elif method in ("learn", "learn-finetune"):
        # Objects and types are taken from the variables reported by the
        # state-estimation motion validator.
        reply_buffer = ReplyBuffer(buffer_dir=out_dir / "reply_buffer", var_parser=FBVariableParser())
        grounder = PredicateOptimGrounder(
            domain_knowledge=domain_knowledge,
            code_api_file=Path(__file__).parent / "llm_code_interface.py",
            out_dir=out_dir / "grounding_code",
            textgen_api=textgen_api,
            reply_buffer=reply_buffer,
        )
        vlm_grounder = VLMGrounder(
            textgen_api=textgen_api, domain_knowledge=domain_knowledge, out_dir=out_dir / "vlm_grounder"
        )
        motion_validator = RemoteSEMotionValidator(
            grounder=grounder,
            vlm_grounder=vlm_grounder,
            reply_buffer=reply_buffer,
            ip=ip,
            port=port,
            # "fb-lamp" -> "lamp"; dashes in the furniture name become underscores.
            furniture=FurnitureEnum(domain.split("fb-")[1].replace("-", "_")),
        )

        # Query the validator once and reuse the result for the grounder
        # annotation and for both skeletons.
        variables = motion_validator.get_variables()
        grounder.set_namespace_annotation(variables)

        problem_skeleton = PDDLProblem(objects=[v.pddl_object for v in variables])
        domain_skeleton = PDDLDomain(
            operators=[],
            predicates=[_assembled_predicate()],
            types={},
        ).copy_with(types={v.pddl_type: "object" for v in variables})
    else:
        # Fail here rather than with an UnboundLocalError further down.
        raise ValueError(f"Unsupported method: {method!r}")

    # Reuse the stored initial environment hash if this sample was run before;
    # otherwise store the hash of the freshly reset environment.
    hash_file = out_dir / "init-env.hash"
    env_hash = hash_file.read_text() if hash_file.is_file() else None
    motion_validator.reset(seed=1, init_hash=env_hash)
    if not hash_file.is_file():
        hash_file.write_text(motion_validator.get_env_hash())

    storage = SharedActionNodeStorage(
        domain_knowledge=domain_knowledge,
        function_stubs=function_stubs,
        n_env_retries=10,
        n_val_retries=10,
        use_domain_knowledge=True,
        use_ai_plan=False,
        use_ai_plan_for_llm_planning=False,
        ai_planner_kwargs={
            # "alias": "lama-first",
            "alias": "seq-sat-fdss-2018",
            "search_time_limit": 30,
        },
    )

    planner = NLActionNode(
        storage=storage,
        motion_validator=motion_validator,
        textgen_api=textgen_api,
        out_dir=out_dir,
        domain=domain_skeleton,
        problem_skeleton=problem_skeleton,
        instruction=instruction,
    )
    storage.reset(root_plan_dir=out_dir, clear_backup=False)

    # planner.delete_env_hashes()

    planner.plan(stdout_level=logging.INFO)


if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser(description="Run the lodge script.")
    parser.add_argument(
        "--method",
        type=str,
        # "learn-finetune" is handled by main() and must be selectable here too.
        choices=["learn", "learn-finetune", "fixed", "none"],
        required=True,
        help="Setup used to build the motion validator and the PDDL skeletons.",
    )
    parser.add_argument("--no-docstrings", action="store_true", help="Remove docstrings from function stubs.")
    # Parse once; the same arguments are used for every sample.
    cli_args = parser.parse_args()

    # Run three independent samples; a sample that runs out of retries is
    # reported and the remaining samples are still executed.
    for i in range(1, 4):
        try:
            main(cli_args, suffix=f"sample-{i}")
        except OutOfRetriesException as e:
            print(e)
            traceback.print_exc()
