from dataclasses import dataclass
import copy
from collections import defaultdict
import json
import random
from textwrap import indent
from pddl.core import And, Predicate, Formula
from pathlib import Path
from typing import Callable, List

from llm_utils.textgen_api import TextGenApi
from tp_lodge.utils.pddl_utils import (
    equalized_effects,
    get_valid_predicates,
    get_predicates_used_in_formula,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLPredicate

from state_estimation.predicate_grounder import PredicateGrounder, GroundingCallable
from state_estimation.predicate_optim_params import compute_decision_metric, PredicateOptimParams
from state_estimation.motion_validation.reply_buffer import ReplyBuffer
from state_estimation.se_variable import SEVariable
import logging

# Directory that contains this module file.
data_dir = Path(__file__).parent
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class PredicateOptimMetadata:
    """Bookkeeping for predicate optimisation that is persisted as a JSON file.

    Two pieces of state are tracked:

    * ``verified_predicates_state``: dict mapping a predicate name to the
      integer recorded via :meth:`update_verified_predicate_state`.
    * ``presented_reply_states``: set of tuples recorded via
      :meth:`add_presented_reply_state`.

    When ``always_save`` is true every mutation made through the setters or
    the ``update_*`` / ``add_*`` helpers rewrites the JSON file immediately.
    Mutating the containers returned by the properties in place does NOT
    trigger a save; call :meth:`force_save` in that case.

    This is intentionally a plain class: it has a hand-written ``__init__``
    and no dataclass fields, so a ``@dataclass`` decorator would only add an
    ``__eq__`` that treats all instances as equal and make them unhashable.
    """

    def __init__(self, metadata_file: Path, always_save: bool = True):
        """Create the metadata store and load ``metadata_file`` if it exists.

        Args:
            metadata_file: Location of the JSON file used for persistence.
            always_save: Write the file after every mutation helper/setter call.
        """
        self.metadata_file = metadata_file
        self.always_save = always_save
        self._verified_predicates_state = {}
        self._presented_reply_states = set()
        self._load()

    def _load(self):
        """Load metadata from file if it exists.

        Both keys are optional, so a file that lacks one of them keeps the
        other instead of being discarded as a whole. A file that cannot be
        interpreted (invalid JSON or unexpected structure) is logged and
        results in empty state.
        """
        if not self.metadata_file.exists():
            return
        try:
            metadata = json.loads(self.metadata_file.read_text())
            verified = metadata.get("verified_predicates_state", {})
            # JSON has no tuples/sets: entries were stored as lists, turn them back into hashable tuples.
            presented = {tuple(entry) for entry in metadata.get("presented_reply_states", [])}
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            logger.warning(f"Failed to load metadata from {self.metadata_file}: {e}")
            verified = {}
            presented = set()
        self._verified_predicates_state = verified
        self._presented_reply_states = presented

    def _save(self):
        """Save metadata to file, creating parent directories when needed."""
        metadata = {
            "verified_predicates_state": self._verified_predicates_state,
            # Sets are not JSON serialisable; tuples are written as JSON arrays.
            "presented_reply_states": list(self._presented_reply_states),
        }
        self.metadata_file.parent.mkdir(parents=True, exist_ok=True)
        self.metadata_file.write_text(json.dumps(metadata, indent=2))

    @property
    def verified_predicates_state(self):
        """Dict of predicate name -> recorded verification state."""
        return self._verified_predicates_state

    @verified_predicates_state.setter
    def verified_predicates_state(self, value):
        self._verified_predicates_state = value
        if self.always_save:
            self._save()

    @property
    def presented_reply_states(self):
        """Set of tuples recorded through :meth:`add_presented_reply_state`."""
        return self._presented_reply_states

    @presented_reply_states.setter
    def presented_reply_states(self, value):
        self._presented_reply_states = value
        if self.always_save:
            self._save()

    def update_verified_predicate_state(self, predicate_name: str, state_value: int):
        """Update a single predicate state and save if auto-save is enabled."""
        self._verified_predicates_state[predicate_name] = state_value
        if self.always_save:
            self._save()

    def add_presented_reply_state(self, state_tuple):
        """Add a presented reply state and save if auto-save is enabled."""
        self._presented_reply_states.add(state_tuple)
        if self.always_save:
            self._save()

    def force_save(self):
        """Force save metadata regardless of always_save setting."""
        self._save()


@dataclass
class VLMGrounderDivergenceException(Exception):
    """Record of one state in which the labels and a grounding function disagree.

    The dataclass-generated ``__init__`` does not call ``Exception.__init__``.
    When the instance is built with keyword arguments ``args`` would therefore
    stay empty, which leaves ``str(exc)`` blank and breaks ``copy``/``pickle``
    (both rebuild exceptions via ``cls(*args)``). ``__post_init__`` fills
    ``args`` with the field values to avoid that.
    """

    # Key of the state this divergence was observed in.
    state_hash: str
    # Predicates that are true according to the labels but not according to the grounder.
    vlm_predicates: set[Predicate]
    # Predicates that are true according to the grounder but neither true nor unknown in the labels.
    grounder_predicates: set[Predicate]
    # Variables of the state the grounder was evaluated on.
    variables: list[SEVariable]
    # Name of each referenced grounding -> formula it produced for this state.
    ref_evals: dict[str, Formula]

    def __post_init__(self) -> None:
        # Keep ``args`` in field order so that ``cls(*args)`` reconstructs an equal instance.
        super().__init__(
            self.state_hash,
            self.vlm_predicates,
            self.grounder_predicates,
            self.variables,
            self.ref_evals,
        )

    def __str__(self) -> str:
        return (
            f"VLM and grounder diverge in state {self.state_hash}: "
            f"{len(self.vlm_predicates)} predicate(s) only true for the VLM, "
            f"{len(self.grounder_predicates)} predicate(s) only true for the grounder"
        )


class PredicateOptimGrounder(PredicateGrounder):
    """Predicate grounder that checks grounding functions against a reply buffer.

    On top of the parent ``PredicateGrounder`` this class

    * evaluates a grounding function on every state stored in ``reply_buffer``
      and compares its output with the labelled predicates of that state
      (called "VLM predicates" throughout this file),
    * runs ``PredicateOptimParams().optim`` on the function (the "gradient-free
      refinement" step), and
    * asks for an LLM refinement via ``refine_grounder_function_for_predicate``
      when the function still diverges from the labels.

    Progress is tracked in a ``PredicateOptimMetadata`` stored at
    ``out_dir / "metadata.json"``.
    """

    def __init__(
        self,
        code_api_file: Path,
        out_dir: Path,
        reply_buffer: ReplyBuffer,
        textgen_api: TextGenApi,
        domain_knowledge: str,
        always_save_metadata: bool = True,
    ) -> None:
        """Initialise the parent grounder and load the persisted metadata.

        Args:
            code_api_file: Forwarded to ``PredicateGrounder``.
            out_dir: Forwarded to ``PredicateGrounder``; also holds ``metadata.json``.
            reply_buffer: Buffer of labelled states used for verification and optimisation.
            textgen_api: Forwarded to ``PredicateGrounder``.
            domain_knowledge: Forwarded to ``PredicateGrounder``.
            always_save_metadata: Persist the metadata after every change.
        """
        super().__init__(
            code_api_file=code_api_file, out_dir=out_dir, textgen_api=textgen_api, domain_knowledge=domain_knowledge
        )
        self.out_dir = out_dir
        self.reply_buffer = reply_buffer

        # Initialize externalized metadata with auto-save capability
        metadata_file = self.out_dir / "metadata.json"
        self.metadata = PredicateOptimMetadata(metadata_file, always_save=always_save_metadata)

    def _eval_grounder_function(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        func: GroundingCallable,
        count_similar: bool = True,
    ):
        """Run ``func`` on every buffered state and compare it with the labels.

        Args:
            predicate: Predicate whose grounding function is evaluated.
            existing_predicates: Predicates used to resolve the groundings referenced by ``func``.
            func: Grounding callable under test.
            count_similar: When false, states that point to a similar state are skipped.
                When true, they are evaluated against the labels of that similar state.

        Returns:
            A tuple ``(evaluations, differences)``. ``evaluations`` is a list of dicts with the
            keys ``"pred"``, ``"grounder"`` and ``"vlm"`` (one per labelled grounded predicate).
            ``differences`` holds one ``VLMGrounderDivergenceException`` per state in which the
            grounder and the labels disagree.

        Raises:
            ValueError: If a grounding referenced by ``func`` has no counterpart in
                ``existing_predicates``.
        """
        evaluations = []
        differences: list[VLMGrounderDivergenceException] = []
        for state_hash, state in self.reply_buffer.get_all_states().items():
            if state.similar_state is not None:
                # we don't want to evaluate duplicate/similar states
                if not count_similar:
                    continue
                label_predicates = self.reply_buffer.get_similar_state(state)[1].predicates
            else:
                label_predicates = state.predicates

            assert label_predicates is not None
            func_ground_predicates = func.ground(predicate=predicate.definition, variables=state.variables)
            all_preds = get_predicates_used_in_formula(And(*func_ground_predicates))
            valid_ground_predicates = get_valid_predicates(func_ground_predicates)

            # Only labels that belong to the predicate under test are relevant.
            vlm_predicates = {k: v for k, v in label_predicates.items() if predicate.name == k.name}
            assert len(vlm_predicates) <= len(func_ground_predicates)

            for pred in all_preds:
                if pred not in vlm_predicates:
                    # Not every predicate is labelled: child predicates are not evaluated on
                    # samples from a different branch.
                    continue
                evaluations.append(
                    {
                        "pred": str(pred),
                        "grounder": pred in valid_ground_predicates,
                        "vlm": vlm_predicates.get(pred, None),
                    }
                )

            valid_vlm_preds = [k for k, v in vlm_predicates.items() if v is True]
            unknown_vlm_preds = [k for k, v in vlm_predicates.items() if v is None]

            # Predicates labelled as unknown (None) never count against the grounder.
            grounder_only = set(valid_ground_predicates) - set(valid_vlm_preds + unknown_vlm_preds)
            vlm_only = set(valid_vlm_preds) - set(valid_ground_predicates)

            if grounder_only or vlm_only:
                ref_evals = {}
                for ref_name, ref_grounder in func.referenced_groundings.items():
                    # Referenced groundings use "_" where predicate names use "-".
                    ref_pred = next(
                        (p for p in existing_predicates if p.name.replace("-", "_") == ref_name),
                        None,
                    )
                    if ref_pred is None:
                        raise ValueError(
                            f"Referenced grounding '{ref_name}' does not match any predicate in existing_predicates."
                        )
                    ref_evals[ref_name] = And(
                        *ref_grounder.ground(predicate=ref_pred.definition, variables=state.variables)
                    )

                differences.append(
                    VLMGrounderDivergenceException(
                        state_hash=state_hash,
                        vlm_predicates=vlm_only,
                        grounder_predicates=grounder_only,
                        variables=state.variables,
                        ref_evals=ref_evals,
                    )
                )
        return evaluations, differences

    def _get_error_str_for_diffs(self, diffs) -> str:
        """Format divergences as a numbered, human-readable list for the refinement prompt.

        Args:
            diffs: Non-empty sequence of ``VLMGrounderDivergenceException``.

        Returns:
            One entry per divergence listing the VLM predicates, the grounder predicates,
            the evaluations of referenced predicates and the state variables.
        """
        assert len(diffs) > 0, "No differences found, but decision score is low."
        diff_str = []
        for idx, diff in enumerate(diffs, start=1):
            var_lines = [
                f"-{var.name}: {self.reply_buffer.var_parser.get_printable_for_llm(var).value}"
                for var in diff.variables
            ]
            vars_str = indent("\n".join(var_lines), "  ")

            vlm_predicates, grounder_predicates = equalized_effects(
                prior=[], post_a=diff.vlm_predicates, post_b=diff.grounder_predicates
            )
            ref_evals_str = indent("\n".join(f"- {k}: {v}" for k, v in diff.ref_evals.items()), "  ")
            diff_str.append(
                f"{idx}.\n"
                f"- VLM predicates: {And(*vlm_predicates)}\n"
                f"- Grounder predicates: {And(*grounder_predicates)}\n"
                f"- Referenced Predicate Evals:\n{ref_evals_str}\n"
                f"- Variables:\n"
                f"{vars_str}\n"
            )

        return "\n".join(diff_str)

    def _verify_grounder_function(
        self,
        func: GroundingCallable,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        count_similar: bool,
    ):
        """Evaluate ``func`` on the reply buffer and summarise the outcome.

        Returns:
            A tuple ``(syntax_ok, decision_scores, support, diffs, error)``:

            * ``syntax_ok``: false if evaluating raised ``ImportError``, ``NameError`` or
              ``TypeError``.
            * ``decision_scores``: result of ``compute_decision_metric`` on the evaluations.
            * ``support``: for every key of ``decision_scores``, the sum of the non-``None``
              label values of the evaluations whose ``"pred"`` equals that key.
            * ``diffs``: divergences reported by ``_eval_grounder_function``.
            * ``error``: ``None`` when every score is at least 1, otherwise a message.
        """
        syntax_ok = True
        error = None
        try:
            evaluations, diffs = self._eval_grounder_function(
                predicate, existing_predicates, func, count_similar=count_similar
            )
            decision_scores = compute_decision_metric(evaluations)
            support = {
                k: sum(p["vlm"] for p in evaluations if p["pred"] == k and p["vlm"] is not None)
                for k in decision_scores.keys()
            }
            if any(score < 1 for score in decision_scores.values()):
                error = "The grounding function does not have enough support."

        except (ImportError, NameError, TypeError) as e:
            error = f"The function does not align with the predicate. Invoking it returns {e}"
            logger.info(error)
            syntax_ok = False
            support = {}
            decision_scores = {}
            diffs = []

        return syntax_ok, decision_scores, support, diffs, error

    def get_grounder_function(
        self,
        predicate: PDDLPredicate,
        existing_predicates: List[PDDLPredicate],
        differentiable: bool = False,
        verify: bool = True,
    ) -> Callable:
        """Return the grounding callable for ``predicate``, verifying and refining it if needed.

        The callable is obtained from the parent class. Unless ``verify`` is false or the
        predicate was already verified for the current reply buffer size, it is then

        1. verified on the reply buffer,
        2. tuned with ``PredicateOptimParams().optim`` if its score is low or it was not
           tuned before, and
        3. if verification reported an error, refined through the LLM using up to three
           randomly chosen divergences that were not presented before. After a refinement
           attempt the method calls itself again to pick up the result.

        Args:
            predicate: Visual predicate to ground (``predicate.is_visual`` must hold).
            existing_predicates: Predicates that may be referenced by the grounding function.
            differentiable: Forwarded to the parent class.
            verify: Forwarded to the parent class; also enables the steps above.

        Returns:
            The grounding callable for ``predicate``.
        """
        assert predicate.is_visual
        g_callable = super().get_grounder_function(
            predicate=predicate,
            existing_predicates=existing_predicates,
            differentiable=differentiable,
            verify=verify,
        )

        if not verify:
            return g_callable
        if self.metadata.verified_predicates_state.get(predicate.name, 0) == len(self.reply_buffer):
            # the predicate grounder was already verified on the current reply buffer state
            return g_callable

        syntax_ok, decision_scores, _, _, error = self._verify_grounder_function(
            func=g_callable,
            predicate=predicate,
            existing_predicates=existing_predicates,
            count_similar=True,  # must be true to fine-tune on similar samples
        )

        # gradient-free refinement
        # NOTE: a minimum-support requirement (at least two samples per predicate) is currently disabled.
        has_enough_support = True
        did_not_try_optim = len(g_callable.hps) == 0 and len(g_callable.get_hps_names()) > 0
        has_low_score = any(score < 0.9 for score in decision_scores.values())
        if syntax_ok and has_enough_support and (has_low_score or did_not_try_optim):
            try:
                PredicateOptimParams().optim(
                    pddl_predicate=predicate, reply_buffer=self.reply_buffer, func=g_callable, factor=0.6
                )
            except TypeError as e:
                syntax_ok = False
                error = f"The function has a syntax error. Calling it returns '{e}'"

        if error is not None:
            # llm self-refinement: re-evaluate first, the optimisation above may have changed the scores
            syntax_ok, decision_scores, _, diffs, error = self._verify_grounder_function(
                func=g_callable, predicate=predicate, existing_predicates=existing_predicates, count_similar=True
            )
            # Only divergences that were not shown in an earlier refinement round are considered.
            diffs = [
                diff
                for diff in diffs
                if (diff.state_hash, predicate.name) not in self.metadata.presented_reply_states
            ]
            if not syntax_ok or (len(diffs) > 0 and any(score < 0.6 for score in decision_scores.values())):
                logger.info(
                    f"Grounding function for predicate {predicate.name} diverged:\n"
                    + "\n".join(["- %s: %.2f" % (k, v) for k, v in decision_scores.items()])
                )
                if syntax_ok:
                    assert len(diffs) > 0
                    random.shuffle(diffs)
                    subset_diffs = diffs[:3]
                    error = self._get_error_str_for_diffs(subset_diffs)

                assert error is not None
                logger.info(error)
                grounding_code_str, grounder_description, chat = self.refine_grounder_function_for_predicate(
                    predicate=predicate,
                    existing_predicates=existing_predicates,
                    differentiable=differentiable,
                    error=error,
                )

                # All remaining divergences are marked as presented, not only the sampled subset.
                for diff in diffs:
                    self.metadata.add_presented_reply_state((diff.state_hash, predicate.name))

                # Keep the previous version so it can be written back if the refinement is not better.
                old_g_callable = copy.deepcopy(g_callable)
                g_callable.update_callable(code=grounding_code_str, description=grounder_description, chat=chat)
                g_callable = super().get_grounder_function(  # to reinit the referenced groundings
                    predicate=predicate,
                    existing_predicates=existing_predicates,
                    differentiable=differentiable,
                )
                # try to again optimize it with gf-optim
                syntax_ok = self._verify_grounder_function(
                    func=g_callable,
                    predicate=predicate,
                    existing_predicates=existing_predicates,
                    count_similar=True,
                )[0]
                new_decision_scores = {}
                if syntax_ok:
                    try:
                        PredicateOptimParams().optim(
                            pddl_predicate=predicate, reply_buffer=self.reply_buffer, func=g_callable
                        )
                    except TypeError as e:
                        # Same guard as for the first optimisation: treat it as a broken function so
                        # the previous version is written back below instead of crashing here.
                        syntax_ok = False
                        logger.info(
                            f"Optimizing the refined grounding function for predicate {predicate.name} failed: {e}"
                        )
                    else:
                        syntax_ok, new_decision_scores, _, _, _ = self._verify_grounder_function(
                            func=g_callable,
                            predicate=predicate,
                            existing_predicates=existing_predicates,
                            count_similar=True,
                        )
                if syntax_ok and sum(new_decision_scores.values()) > sum(decision_scores.values()):
                    decision_scores = new_decision_scores
                    logger.info(
                        f"Refined grounding function for predicate {predicate.name} with LLM.\nNew decision scores:\n"
                        + "\n".join("- %s: %.2f" % (k, v) for k, v in decision_scores.items())
                    )
                else:
                    # The refinement did not improve the summed score: write the previous version back.
                    old_g_callable._save_to_disk()

                    if syntax_ok:
                        # no matter what, set we evaluated it so if the optim did not work we dont end up in a endless loop
                        self.metadata.update_verified_predicate_state(predicate.name, len(self.reply_buffer))

                return self.get_grounder_function(
                    predicate=predicate,
                    existing_predicates=existing_predicates,
                    differentiable=differentiable,
                )

        self.metadata.update_verified_predicate_state(predicate.name, len(self.reply_buffer))

        return g_callable
