from PIL import Image
import json
from dotenv import load_dotenv
from natsort import natsorted
from state_estimation import VLMGrounder
from tqdm import tqdm
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_domain_syntax import parse_formula
from pathlib import Path

from llm_utils import TextGenApi


def _load_predictions(predictions_file: Path) -> dict:
    """Load previously saved per-state predictions, or return ``{}`` if the file does not exist.

    The file maps each state hash to a mapping of stringified formula -> predicted value
    (as written by ``_save_predictions``). Formula keys were stringified on save, so they are
    parsed back with ``parse_formula`` before being reused as ``prev_grounded`` input.
    """
    if not predictions_file.exists():
        return {}
    saved = json.loads(predictions_file.read_text(encoding="utf-8"))
    return {
        state_hash: {parse_formula(formula, only_variables=False): value for formula, value in grounded.items()}
        for state_hash, grounded in saved.items()
    }


def _save_predictions(predictions_file: Path, predictions: dict) -> None:
    """Write ``predictions`` to ``predictions_file`` as JSON, replacing any previous version.

    Formula keys are converted with ``str`` because JSON object keys must be strings. The JSON
    is first written to a sibling ``.tmp`` file and then moved over the target with
    ``Path.replace``, so an interruption mid-write cannot leave a truncated checkpoint behind
    (which would make ``_load_predictions`` fail on the next, resuming, run).
    """
    serializable = {
        state_hash: {str(formula): value for formula, value in grounded.items()}
        for state_hash, grounded in predictions.items()
    }
    tmp_file = predictions_file.with_name(predictions_file.name + ".tmp")
    tmp_file.write_text(json.dumps(serializable, indent=4), encoding="utf-8")
    tmp_file.replace(predictions_file)


def main():
    """Ground the domain predicates of every recorded state of every planning sample with a VLM.

    For each ``sample-*`` directory under
    ``results/<domain>/<data-LLM model dir>/hi-tamp/planning-with-pred-learning`` (natural-sorted),
    load its PDDL domain (``domain.json``), generated problem (``generated-problem.json``) and
    recorded states (``reply_buffer/states.json`` with images in ``reply_buffer/images``), and ask a
    ``VLMGrounder`` for the predicate values of each state.

    Results go to ``results/<domain>/predicate-learning-eval/<VLM model dir>/vlm/<sample>/predictions.json``
    and are checkpointed after every state; states already present in that file are skipped, so
    an interrupted run resumes where it stopped.

    Raises:
        ValueError: if a non-initial state has no ``executed_skill``, or its predecessor state has
            not been grounded before it (states must be listed after their predecessor).
    """
    root_dir = Path(__file__).parent.parent.parent
    domain = "fb-lamp"
    vlm_llm = "gpt4.1-mini"  # model used by the grounder
    data_llm = "gpt4.1-mini"  # model whose planning results (sample dirs) are evaluated

    textgen_api = TextGenApi.default(vlm_llm)

    data_dir = root_dir / "data" / domain
    # Identical for every sample, so read it once instead of once per sample.
    domain_knowledge = (data_dir / "domain_knowledge.txt").read_text(encoding="utf-8")

    connection = TextGenApi.default(data_llm).connections.connections[0].model_dir
    lodge_result_dirs = (root_dir / "results" / domain / connection / "hi-tamp/planning-with-pred-learning").glob("sample-*")

    out_dir = (
        root_dir
        / "results"
        / domain
        / "predicate-learning-eval"
        / textgen_api.connections.connections[0].model_dir
        / "vlm"
    )
    out_dir.mkdir(parents=True, exist_ok=True)

    for lodge_result_dir in natsorted(lodge_result_dirs):
        print(f"Processing {lodge_result_dir}")
        lodge_domain = PDDLDomain.loads((lodge_result_dir / "domain.json").read_text(encoding="utf-8"))
        lodge_problem = PDDLProblem.loads((lodge_result_dir / "generated-problem.json").read_text(encoding="utf-8"))

        reply_buffer = json.loads((lodge_result_dir / "reply_buffer/states.json").read_text(encoding="utf-8"))
        images_dir = lodge_result_dir / "reply_buffer/images"

        out_sample_dir = out_dir / lodge_result_dir.name
        out_sample_dir.mkdir(exist_ok=True)

        grounder = VLMGrounder(domain_knowledge=domain_knowledge, textgen_api=textgen_api, out_dir=out_sample_dir)

        # Resume support: predictions saved by an earlier run are loaded and those states skipped.
        out_predictions_file = out_sample_dir / "predictions.json"
        predictions = _load_predictions(out_predictions_file)

        for state_hash, state in tqdm(reply_buffer["states"].items()):
            if state_hash in predictions:
                continue

            grounding_kwargs = {
                "variables": state["variables"],
                "image": Image.open(images_dir / f"state_{state_hash}.png"),
                "predicates": lodge_domain.predicates,
                "types": lodge_domain.types,
                "objects": lodge_problem.objects,
            }

            prev_state_hash = state["prev_state_hash"]
            if prev_state_hash is not None:
                # Non-initial state: also pass the transition context (skill, previous image
                # and the previous state's predictions) to the grounder.
                executed_skill = state["executed_skill"]
                if executed_skill is None:
                    raise ValueError(
                        f"State {state_hash} in {lodge_result_dir} has a previous state but no executed_skill"
                    )
                if prev_state_hash not in predictions:
                    raise ValueError(
                        f"State {state_hash} in {lodge_result_dir} refers to previous state {prev_state_hash}, "
                        "which has not been grounded yet; states must be listed after their predecessor"
                    )
                grounding_kwargs.update(
                    executed_skill=executed_skill,
                    prev_image=Image.open(images_dir / f"state_{prev_state_hash}.png"),
                    prev_grounded=predictions[prev_state_hash],
                )

            predictions[state_hash] = grounder.ground_predicates_of_state(**grounding_kwargs)

            # Checkpoint after every state so an interrupted run loses at most one grounding call.
            _save_predictions(out_predictions_file, predictions)


if __name__ == "__main__":
    # Load variables from a local .env file into the process environment before main() builds
    # the TextGenApi clients (presumably they read credentials/config from the environment — confirm).
    load_dotenv()
    main()
