import copy
import json
import shutil
from demos.ipc.src.common.utils import get_task_centric_domain
from python_utils.string_utils import get_markup_from_text
from pathlib import Path
from typing import Optional
from llm_utils.openai_api.chat_factory import ChatFactory
from llm_utils.prompt_generation.prompt import Prompt
from llm_utils.textgen_api.textgen_api import TextGenApi
from natsort import natsorted
from pddl.parser.domain import DomainParser


def construct_correction_dict(sample_dir: Path, domain_str: str, approach: str) -> dict:
    """Build an empty correction template covering every action of a PDDL domain.

    The template has three top-level keys:
    ``"actions"`` (one entry per action name, each with empty ``"params"``,
    ``"precond"`` and ``"effects"`` lists), ``"predicates"`` and
    ``"any-other-mistakes"`` (both empty lists).

    Args:
        sample_dir: Directory of one generated sample. For ``approach == "hi-tamp"``
            it must contain ``function_mapping.json``.
        domain_str: PDDL domain text, parsed with ``pddl``'s ``DomainParser``.
        approach: Generation approach name. For ``"hi-tamp"`` every action entry
            additionally gets a ``"skill"`` field looked up by action name in
            ``function_mapping.json`` (``None`` when the action is not mapped).

    Returns:
        The correction template as a plain dict.

    Raises:
        AssertionError: If ``approach == "hi-tamp"`` and ``function_mapping.json`` is missing.
    """
    is_hi_tamp = approach == "hi-tamp"

    skill_of = None
    if is_hi_tamp:
        mapping_path = sample_dir / "function_mapping.json"
        assert mapping_path.exists()
        # NOTE(review): assumed to be a JSON object mapping action name -> skill; confirm.
        skill_of = json.loads(mapping_path.read_text())

    parsed_domain = DomainParser()(domain_str)

    def _blank_entry(action_name: str) -> dict:
        # One empty correction slot per action; "skill" only exists for hi-tamp.
        entry = {"params": [], "precond": [], "effects": []}
        if is_hi_tamp:
            entry["skill"] = skill_of.get(action_name, None)
        return entry

    return {
        "actions": {act.name: _blank_entry(act.name) for act in parsed_domain.actions},
        "predicates": [],
        "any-other-mistakes": [],
    }


def _resolve_domain_file(task_dir: Path) -> Path:
    """Return ``task_dir / "domain.pddl"``, creating it from a ``*domain*.pddl`` file if it is missing.

    Raises:
        FileNotFoundError: If no usable domain file exists in ``task_dir``.
    """
    domain_file = task_dir / "domain.pddl"
    if not domain_file.exists():
        # Some runs name the domain differently; normalise to domain.pddl.
        alt_domain_file = next(task_dir.glob("*domain*.pddl"), None)
        if alt_domain_file is not None:
            shutil.copy(alt_domain_file, domain_file)
    if not domain_file.is_file():
        raise FileNotFoundError(f"No domain PDDL file found in {task_dir}")
    return domain_file


def _extract_single_json(text: str, task_name: str):
    """Parse the single fenced ``json`` block contained in ``text``.

    Raises:
        ValueError: If ``text`` does not contain exactly one ``json`` block.
    """
    json_cells = get_markup_from_text(text, ["json"])
    if len(json_cells) != 1:
        raise ValueError(
            f"Expected exactly one JSON block in the reviewer response for {task_name!r}, "
            f"found {len(json_cells)}"
        )
    return json.loads(json_cells[0])


def main(
    task: str,
    engine: str,
    approach: str,
    subdir: Optional[str] = None,
    reviewer_engine: str = "claude3.7-sonnet",
):
    """Ask a reviewer LLM to check each generated PDDL domain and store its corrections.

    Every task directory under ``results/<task>/<model_dir>/<approach>[/<subdir>]``
    that has no ``eval/human-correction.json`` yet is processed:

    1. The task-centric domain is sent to the reviewer model, together with the
       domain description, the skill descriptions and an empty correction stub.
    2. The conversation is saved to ``eval/chat.json`` and ``eval/chat.txt``.
    3. The single JSON block of the reply is saved to ``eval/ai-correction.json``.
    4. A copy with ``"generated": True`` is saved to ``eval/human-correction.json``.

    Args:
        task: Benchmark/domain name (e.g. ``"household"``). Selects the results and
            the 3rd-party prompt files.
        engine: Engine name passed to ``TextGenApi.default``. The results directory is
            keyed by the resolved model directory, not by this alias.
        approach: Generation approach. ``"hi-tamp"`` additionally puts per-action
            skills into the correction stub.
        subdir: Optional sub-directory of the approach results (e.g. a final version).
        reviewer_engine: Engine used to review the domains.

    Raises:
        FileNotFoundError: If the results directory or a task's domain file is missing.
        ValueError: If a reviewer reply does not contain exactly one JSON block.
    """
    textgen_api = TextGenApi.default(engine)
    # Results are stored under the resolved model directory, not the engine alias.
    engine = textgen_api.connections.connections[0].model_dir
    strong_textgen_api = TextGenApi.default(reviewer_engine)

    root_dir = Path(__file__).parent.parent
    out_dir = root_dir / "results" / task / engine / approach
    if not out_dir.exists():
        raise FileNotFoundError(f"Results directory does not exist: {out_dir}")

    trdparty_dir = root_dir.parent.parent / "3rdparty" / "LLMs-World-Models-for-Planning" / "prompts" / task
    skills = (trdparty_dir / "action_model.json").read_text(encoding="utf-8")
    domain_description = (trdparty_dir / "domain_desc.txt").read_text(encoding="utf-8")
    prompt_file = root_dir / "prompts" / "check-domain-correctness.xml"

    if subdir is not None:
        # Evaluate a specific version (e.g. a final one) stored in a sub-directory.
        out_dir = out_dir / subdir
    if not out_dir.is_dir():
        raise FileNotFoundError(f"Not a results directory: {out_dir}")

    for task_dir in natsorted(out_dir.iterdir()):
        if not task_dir.is_dir():
            continue

        eval_dir = task_dir / "eval"
        ai_correction_file = eval_dir / "ai-correction.json"
        human_correction_file = eval_dir / "human-correction.json"
        # An existing human-correction file marks the task as done; never overwrite it.
        if human_correction_file.exists():
            print("Already evaluated:", task_dir.name)
            continue

        print("Processing task dir:", task_dir.name)

        domain_file = _resolve_domain_file(task_dir)

        actions_to_skills = {}  # TODO: populate the action -> skill mapping
        domain = get_task_centric_domain(
            domain=domain_file.read_text(encoding="utf-8"),
            config=actions_to_skills,
            task=task,
            task_name=task_dir.name,
        )

        # Reload the template for each task. NOTE(review): presumably ``replace``
        # fills placeholders in place, so a loaded prompt cannot be reused.
        prompt = Prompt.load_from_file(prompt_file, nonce=None)
        prompt.replace("{domain}", domain)
        prompt.replace("{domain_description}", domain_description)
        prompt.replace("{skills}", skills)
        # Fill the stub for every approach so no literal "{stub}" reaches the model;
        # construct_correction_dict only adds the "skill" field for hi-tamp.
        correction_dict = construct_correction_dict(task_dir, domain, approach)
        prompt.replace("{stub}", json.dumps(correction_dict, indent=4, sort_keys=True))

        chat = prompt.to_chat()
        response = strong_textgen_api.do_call(chat)
        chat = chat.add_message(response)
        response_text = response.content[0].text

        print("Response:", response_text)

        # Persist the conversation before parsing so a malformed reply can still be inspected.
        eval_dir.mkdir(exist_ok=True)
        (eval_dir / "chat.json").write_text(json.dumps(ChatFactory().to_json(chat)), encoding="utf-8")
        (eval_dir / "chat.txt").write_text(str(chat), encoding="utf-8")

        json_data = _extract_single_json(response_text, task_dir.name)

        ai_correction_file.write_text(json.dumps(json_data, indent=4, sort_keys=True), encoding="utf-8")
        # Flag only the human-editable copy; the AI output stays exactly as returned.
        human_data = copy.deepcopy(json_data)
        human_data["generated"] = True
        human_correction_file.write_text(json.dumps(human_data, indent=4, sort_keys=True), encoding="utf-8")


if __name__ == "__main__":
    # Default experiment: review the final household domains generated by hi-tamp.
    main(
        task="household",
        engine="gpt4.1-mini",
        approach="hi-tamp",
        subdir="final-wo-dk-wo-ai",
    )
