import copy
from python_utils.string_utils import get_markup_from_text
from tp_lodge.utils.pddl_domain_syntax import parse_formula
import json
from pathlib import Path

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from llm_utils.textgen_api.textgen_api import TextGenApi
from demos.ipc.scripts.run_lodge import _get_domain
from demos.ipc.src.household.household_motion_validator import HouseholdMotionValidator
from demos.ipc.src.planning_benchmark_sample_generator import PlanningBenchmarkSampleGenerator
from tp_lodge.motion_planning.local_motion_validator import LocalMotionValidator, PDDLDomain
from tp_lodge.utils.pddl_parse_utils import PDDLPredicate
from tp_lodge.utils.pddl_utils import get_valid_predicates


# Prompt template for LLM goal generation. It is %-formatted with exactly two values,
# in this order: the domain's PDDL text (placed inside the ```pddl fence) and the
# natural-language task instruction to translate.
# NOTE(review): "`pddl` tags" presumably means a ```pddl fenced block, since the reply is
# parsed with get_markup_from_text(..., ["pddl"]) — confirm that helper's expected markup.
prompt = """
You are a helpful assistant that generates PDDL goals for planning tasks.

```pddl
%s
```

Translate the following task instruction into a PDDL goal:
%s

Return the goal in PDDL format, without any additional text or explanation, wrapped in `pddl` tags.
"""


def _format_domain_nl(domain_description: str, action_model: dict) -> str:
    """Render the natural-language domain description that is written to ``domain.nl``.

    Args:
        domain_description: Free-text description of the whole domain.
        action_model: Mapping from action name (e.g. ``"Pick Up"``) to a dict holding a
            ``"desc"`` string and an ``"extra_info"`` list of strings.

    Returns:
        The domain description followed by a numbered list of actions separated by blank
        lines. Action names are lower-cased and spaces become hyphens (``pick-up``).
    """
    action_lines = []
    for idx, (action, info) in enumerate(action_model.items(), start=1):
        name = action.lower().replace(" ", "-")
        # Join with spaces so the main description and each extra sentence stay separated;
        # plain concatenation put no space between the description and the first extra.
        text = " ".join([info["desc"], *info["extra_info"]])
        action_lines.append(f"{idx}. {name}: {text}")
    actions = "\n\n".join(action_lines)
    return f"\n{domain_description}\n\nThe domain includes the following actions:\n{actions}\n"


def _problem_out_file(out_dir: Path, index: int) -> Path:
    """Return the output path of the problem generated from generator sample *index*.

    Sample 0 is written as ``p_example.pddl``; sample ``i > 0`` as ``p{i+1:02d}.pddl``.
    NOTE(review): numbering therefore starts at ``p02`` and no ``p01.pddl`` is written —
    presumably so file numbers line up with sample indices; confirm this is intended.
    """
    if index == 0:
        return out_dir / "p_example.pddl"
    return out_dir / f"p{index + 1:02d}.pddl"


def _generate_goal(llm, domain_skeleton, instruction: str):
    """Ask *llm* to translate *instruction* into a PDDL goal and parse the answer.

    The prompt embeds the PDDL text of *domain_skeleton*. The raw reply is echoed to
    stdout, its first ``pddl`` markup block is extracted, and that text is parsed with
    the skeleton's predicate definitions as the known predicates.

    Raises:
        AssertionError: if the first response content is not a TextMessageContent.
        ValueError: if the reply contains no ``pddl`` block.
    """
    response = llm.do_call(
        Chat([UserMessage([TextMessageContent(prompt % (str(domain_skeleton.to_pddl()), instruction))])])
    )

    content = response.content[0]
    assert isinstance(content, TextMessageContent), "Expected a TextMessageContent response"
    response_text = content.text.strip()
    print(response_text)

    blocks = get_markup_from_text(response_text, ["pddl"])
    if not blocks:
        # Fail with the offending reply instead of an opaque IndexError.
        raise ValueError(f"LLM response contains no pddl block:\n{response_text}")

    return parse_formula(
        blocks[0], only_variables=False, known_predicates=[p.definition for p in domain_skeleton.predicates]
    )


def main():
    """Generate the household benchmark: ``domain.nl`` plus one PDDL problem per sample.

    Loads ``domain_skeleton.json`` from the IPC demo data and extends it with extra
    object/furniture predicates. It also reads ``action_model.json`` and ``domain_desc.txt``
    from the LLMs-World-Models-for-Planning prompts. From these it writes ``domain.nl``
    into the output directory. Then, for every sample of
    :class:`PlanningBenchmarkSampleGenerator`, it resets the motion validator's environment
    to the sample state and injects init predicates. It asks an LLM to translate the
    instruction into a goal and writes the resulting problem file. Problem files that
    already exist are skipped, so an interrupted run can be resumed.
    """
    task = "household"

    data_dir = Path(__file__).parent.parent.parent.parent.parent / "demos/ipc/data" / task
    out_dir = Path(__file__).parent.parent.parent / "data/domains" / task
    guan_dir = Path(__file__).parent.parent.parent.parent / "LLMs-World-Models-for-Planning/prompts/household"
    # The output directory may not exist yet on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Extra predicates added to the skeleton so the LLM can use them in goals.
    new_preds = [
        PDDLPredicate(definition=parse_formula(p), description="", predefined=True)
        for p in [
            "(pickupable ?x - household_object)",
            "(stackable ?x - household_object)",
            "(object-clear ?x - household_object)",
            "(flat-surface ?x - furniture_appliance)",
            "(toggleable ?x - household_object)",
            "(sliceable ?x - household_object)",
            "(washable ?x - household_object)",
        ]
    ]

    generator = PlanningBenchmarkSampleGenerator(task)
    motion_validator = HouseholdMotionValidator()
    domain_skeleton = PDDLDomain.from_json(
        json.loads((data_dir / "domain_skeleton.json").read_text(encoding="utf-8"))
    )
    domain_skeleton.predicates.extend(new_preds)

    action_model = json.loads((guan_dir / "action_model.json").read_text(encoding="utf-8"))
    domain_description = (guan_dir / "domain_desc.txt").read_text(encoding="utf-8").strip()
    (out_dir / "domain.nl").write_text(_format_domain_nl(domain_description, action_model), encoding="utf-8")

    # generate problems
    llm = TextGenApi.default(connection="claude4-sonnet")
    for i in range(len(generator)):
        pddl_out_file = _problem_out_file(out_dir, i)
        if pddl_out_file.exists():
            continue

        instruction, env_state, problem_skeleton = generator.generate(i)

        # Point the validator's environment at this sample's state before injecting init
        # predicates (presumably inject_init_predicates derives facts from it — verify).
        assert isinstance(motion_validator, LocalMotionValidator)
        assert hasattr(motion_validator.env, "_init_env_state")
        motion_validator.env._init_env_state = copy.deepcopy(env_state)
        motion_validator.env._set_state(new_state=env_state)
        problem_skeleton = motion_validator.inject_init_predicates(domain=domain_skeleton, problem=problem_skeleton)

        pddl_goal = _generate_goal(llm, domain_skeleton, instruction)
        print("Generated PDDL Goal:")
        print(pddl_goal)

        problem = problem_skeleton.copy_with(
            initial_state=get_valid_predicates(problem_skeleton.initial_state),
            goal_state=pddl_goal,
        )
        pddl_out_file.write_text(str(problem.to_pddl()), encoding="utf-8")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
