from datetime import datetime
import re
from collections import defaultdict
from itertools import product
import logging
from pathlib import Path
from llm_utils import TextGenApi, Prompt
from PIL import Image
from typing import Dict, List, Optional
from pddl.core import Constant, Predicate
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate


logger = logging.getLogger(__name__)


class VLMGrounder:
    """Ground PDDL predicates in a scene by asking a vision-language model.

    For each lifted predicate, every combination of typed objects (with
    pairwise distinct object names) is rendered as a ground atom string such
    as ``on_top(cube, table)``. The atoms are listed in a prompt together with
    the scene image, the per-object variables ("pose estimations") and the
    domain knowledge. The model's reply is then scanned for
    ``<atom>: True.`` / ``<atom>: False.`` / ``<atom>: Unknown.`` statements.
    """

    # Connection used for grounding calls unless overridden per call.
    DEFAULT_CONNECTION_ID = "gpt-4.1-2025-04-14"
    # Number of model calls made before giving up on still-ungrounded atoms.
    DEFAULT_MAX_ATTEMPTS = 5
    # Matches "<atom>: True." etc. The negative lookbehind stops "on(a, b)"
    # from matching inside a longer name such as "is_on(a, b): True.".
    _STATUS_PATTERN = r"(?<!\w)%s: (False|True|Unknown)\."

    def __init__(self, domain_knowledge: str, textgen_api: TextGenApi, out_dir: Path) -> None:
        """Create a grounder.

        Args:
            domain_knowledge: Free text inserted into every grounding prompt.
            textgen_api: Client whose ``do_call`` sends chats to the model.
            out_dir: Directory receiving one ``<timestamp>.chat`` dump per
                grounding call; created (including parents) if missing.
        """
        self.domain_knowledge = domain_knowledge
        self.textgen_api = textgen_api
        out_dir.mkdir(parents=True, exist_ok=True)
        self.out_dir = out_dir

    @staticmethod
    def _var_to_str(var) -> str:
        """Render one variable value for the prompt.

        Floats get three decimals, and lists/tuples are rendered recursively
        as ``[a, b, ...]``. Every other value falls back to ``str()``, so
        unsupported types (dicts, ``None``, ...) are shown as they are instead
        of becoming ``None``, which previously broke joins inside lists.
        """
        if isinstance(var, str):
            return var
        if isinstance(var, (list, tuple)):
            return f"[{', '.join(VLMGrounder._var_to_str(v) for v in var)}]"
        if isinstance(var, float):
            return f"{var:.3f}"
        return str(var)

    @classmethod
    def _find_status(cls, atom: str, text: str) -> Optional[str]:
        """Look up the model's verdict for ``atom`` in ``text``.

        Returns ``"true"``, ``"false"`` or ``"unknown"``. Returns ``None``
        when the reply does not state the atom in the expected form.
        """
        match = re.search(cls._STATUS_PATTERN % re.escape(atom), text)
        if match is None and atom.endswith("()"):
            # Nullary atoms are also accepted without parentheses,
            # e.g. "handempty: True.".
            match = re.search(cls._STATUS_PATTERN % re.escape(atom[:-2]), text)
        return None if match is None else match.group(1).lower()

    def ground_predicates_of_state(
        self,
        variables: List[Dict],
        image: Image.Image,
        predicates: List[PDDLPredicate],
        types: Dict[str, str],
        objects: List[Constant],
        prev_image: Optional[Image.Image] = None,
        executed_skill: Optional[str] = None,
        prev_grounded: Optional[dict[Predicate, Optional[bool]]] = None,
        *,
        max_attempts: int = DEFAULT_MAX_ATTEMPTS,
        connection_id: str = DEFAULT_CONNECTION_ID,
    ) -> dict[Predicate, Optional[bool]]:
        """Ask the model for the truth value of every ground atom of ``predicates``.

        Args:
            variables: Entries with a ``"name"`` and a ``"value"`` mapping. Every
                value entry except ``bounding_box`` is listed in the prompt.
            image: Image of the current scene (sent as ``"scene"``).
            predicates: Lifted predicates to ground. Their ``definition.terms``
                type tags select which objects fill each argument.
            types: Not used by this method.
            objects: Constants available for grounding, grouped by ``type_tag``.
            prev_image: Image before ``executed_skill`` (sent as ``"prior_scene"``).
            executed_skill: Skill executed between ``prev_image`` and ``image``.
            prev_grounded: Previous grounding results. When given, the
                ``vlm_grounder_t_lt_0.xml`` prompt is used and also receives
                ``executed_skill`` and these prior values, so both
                ``executed_skill`` and ``prev_image`` are then required.
            max_attempts: Maximum number of model calls. After each call the
                model is re-prompted with the atoms it has not answered yet.
            connection_id: Connection id passed to ``textgen_api.do_call``.

        Returns:
            A mapping from each ground ``Predicate`` to ``True``/``False``, or
            ``None`` when the model answered ``Unknown``. The predicate
            arguments are ``Constant``s built from the object names only.

        Raises:
            AssertionError: If ``prev_grounded`` is given without
                ``executed_skill``/``prev_image``, or if some atoms are still
                ungrounded after ``max_attempts`` calls.
        """
        if prev_grounded is not None:
            assert executed_skill is not None, "executed_skill is required when prev_grounded is given"
            assert prev_image is not None, "prev_image is required when prev_grounded is given"

        logger.info("Grounding predicates %s for state...", ", ".join(p.name for p in predicates))
        # ':' is not allowed in Windows file names; the id names the chat dump.
        grounding_id = datetime.now().isoformat().replace(":", "-")

        objects_per_type: dict = defaultdict(list)
        for obj in objects:
            objects_per_type[obj.type_tag].append(obj)

        prompt_lines: list[str] = []
        # Ground atom string -> (original predicate name, argument object names).
        # Kept so results can be converted back without re-parsing the strings.
        atoms: dict[str, tuple[str, tuple[str, ...]]] = {}
        for predicate in predicates:
            # NOTE(review): only one type tag per term is used (the first one
            # `type_tags` yields); terms with several types only see objects of
            # that one type.
            term_types = [list(term.type_tags)[0] for term in predicate.definition.terms]
            prompt_lines.append(f"# Grounding {predicate.name}: {predicate.description}")
            # The prompt spells predicate names with "_" instead of "-".
            atom_name = predicate.name.replace("-", "_")
            for args in product(*(objects_per_type[t] for t in term_types)):
                arg_names = tuple(a.name for a in args)
                if len(set(arg_names)) != len(arg_names):
                    # Atoms that use the same object twice are not grounded.
                    continue
                atom = f"{atom_name}({', '.join(arg_names)})"
                prompt_lines.append(atom)
                atoms[atom] = (predicate.name, arg_names)

        if not atoms:
            # Nothing to ask: avoid a paid model call whose answer is unused.
            logger.info("No ground atoms to evaluate; skipping the model call.")
            return {}

        pose_blocks = []
        for var in variables:
            entries = "\n".join(
                f"- {key}: {self._var_to_str(value)}"
                for key, value in var["value"].items()
                if key != "bounding_box"
            )
            pose_blocks.append(f"{var['name']}\n{entries}")

        prompts_dir = Path(__file__).parent / "prompts"
        if prev_grounded is not None:
            prompt = Prompt.load_from_file(prompts_dir / "vlm_grounder_t_lt_0.xml")
            prompt.replace_all(
                skill=executed_skill,
                prior_ground_predicates="\n".join(
                    "- %s: %s" % (k, "Unknown" if v is None else str(v))
                    for k, v in prev_grounded.items()
                ),
            )
        else:
            prompt = Prompt.load_from_file(prompts_dir / "vlm_grounder.xml")
        prompt.replace_all(
            predicates="\n".join(prompt_lines),
            domain_knowledge=self.domain_knowledge,
            pose_estimations="\n".join(pose_blocks),
        )

        chat = prompt.to_chat(images={"scene": image, "prior_scene": prev_image})
        chat_file = self.out_dir / f"{grounding_id}.chat"

        results: dict[str, Optional[bool]] = {}
        pending = list(atoms)
        for attempt in range(1, max_attempts + 1):
            response = self.textgen_api.do_call(chat, connection_id=connection_id)
            chat = chat.add_message(response)
            # Dump the whole conversation after every call for inspection.
            chat_file.write_text(str(chat), encoding="utf-8")
            text = response.content[0].text

            for atom in pending:
                status = self._find_status(atom, text)
                if status is not None:
                    results[atom] = None if status == "unknown" else status == "true"

            pending = [atom for atom in pending if atom not in results]
            if not pending:
                break
            logger.info("Still need to ground %d predicates...", len(pending))
            if attempt < max_attempts:
                # Re-prompt with the exact strings that are still missing.
                chat = chat.add_user_text(
                    "Still grounding predicates. Ensure you use these exact strings:\n%s"
                    % ("\n".join(pending))
                )

        if pending:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept for existing callers.
            raise AssertionError(f"Could not ground predicates: {', '.join(pending)}")

        grounded: dict[Predicate, Optional[bool]] = {}
        for atom, status in results.items():
            name, arg_names = atoms[atom]
            grounded[Predicate(name, *[Constant(a) for a in arg_names])] = status
        return grounded
