import logging
from statistics import harmonic_mean
from joblib import Parallel
from typing import Union
import numpy as np
import copy
from collections import defaultdict
from pddl.logic.base import Not
import inspect
import pandas as pd
from sklearn.metrics import f1_score, confusion_matrix, fbeta_score
from tp_lodge.utils.pddl_utils import get_predicate_evaluation
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate
from state_estimation.motion_validation.reply_buffer import State
from state_estimation.predicate_grounder import GroundingCallable
from state_estimation.motion_validation.reply_buffer import ReplyBuffer


logger = logging.getLogger(__name__)



def compute_decision_metric_all(eval: Union[list[dict[str, int]], pd.DataFrame], zero_division: int = 1) -> float:
    """Score grounder decisions against VLM labels, pooled over all predicates.

    Every sample provides a ``"grounder"`` value (the grounding function's
    decision) and a ``"vlm"`` value (the reference label). Samples whose
    ``"vlm"`` value is missing (``None``/NaN) fall back to the grounder's
    decision, so they count as agreement. Both columns are interpreted via
    ``bool``.

    :param eval: Samples as a list of dicts, a DataFrame, or anything
        ``pd.DataFrame`` accepts. The input is not modified.
    :param zero_division: Forwarded to :func:`sklearn.metrics.f1_score` for
        classes whose F1 is undefined.
    :return: Macro-averaged F1 over classes 0/1, plus ``1e-6`` times the
        fraction of positive grounder decisions.
    :raises ValueError: If ``eval`` contains no samples.
    """
    df_pred = eval if isinstance(eval, pd.DataFrame) else pd.DataFrame(eval)
    if df_pred.empty:
        # pd.DataFrame([]) has no columns at all; fail with a clear message
        # instead of a KeyError on "vlm".
        raise ValueError("compute_decision_metric_all requires at least one sample")

    # Unlabeled samples take the grounder's decision. The result is built as a
    # new Series so the caller's data is never mutated.
    vlm = df_pred["vlm"]
    predicate_true = vlm.where(vlm.notna(), df_pred["grounder"]).astype(bool)
    predicate_pred = df_pred["grounder"].astype(bool)

    f1_avg = f1_score(predicate_true, predicate_pred, average='macro', labels=[0, 1], zero_division=zero_division)

    # Tiny bonus that prefers positive predictions between otherwise (nearly) equal F1 scores.
    eps = 1e-6
    return f1_avg + eps * np.mean(predicate_pred)
def compute_decision_metric(eval: Union[list[dict[str, int]], pd.DataFrame], zero_division: int = 1) -> dict[str, float]:
    """Score grounder decisions against VLM labels separately for each predicate.

    Samples are grouped by their ``"pred"`` value. Within each group the
    ``"grounder"`` decisions are compared against the ``"vlm"`` labels, where a
    missing label (``None``/NaN) falls back to the grounder's decision.

    :param eval: Samples as a list of dicts, a DataFrame, or anything
        ``pd.DataFrame`` accepts. The input is not modified.
    :param zero_division: Forwarded to :func:`sklearn.metrics.f1_score` for
        classes whose F1 is undefined.
    :return: Mapping from each distinct ``"pred"`` value to its macro F1 over
        classes 0/1, plus ``1e-6`` times its fraction of positive grounder
        decisions. The mapping is empty when there are no samples.
    """
    df = eval if isinstance(eval, pd.DataFrame) else pd.DataFrame(eval)
    if df.empty:
        # pd.DataFrame([]) has no columns at all, so there is no "pred" column to group by.
        return {}

    # Tiny bonus that prefers positive predictions between otherwise (nearly) equal F1 scores.
    eps = 1e-6
    metric_per_pred = {}
    for pred in df["pred"].unique():
        df_pred = df[df["pred"] == pred]

        # Unlabeled samples take the grounder's decision (new Series, no in-place write).
        vlm = df_pred["vlm"]
        predicate_true = vlm.where(vlm.notna(), df_pred["grounder"]).astype(bool)
        predicate_pred = df_pred["grounder"].astype(bool)

        f1_avg = f1_score(predicate_true, predicate_pred, average='macro', labels=[0, 1], zero_division=zero_division)
        metric_per_pred[pred] = f1_avg + eps * np.mean(predicate_pred)
    return metric_per_pred


class PredicateOptimParams:
    """Random-search tuning of the numeric hyperparameters of a grounding function.

    Candidate hyperparameters are sampled around the defaults declared in the
    grounding function's signature and scored with
    :func:`compute_decision_metric_all` against the predicate labels stored in a
    :class:`ReplyBuffer`. The selected values are written back through
    ``func.update_hps``.
    """

    def _prepare_dataset(self, pddl_predicate: PDDLPredicate, reply_buffer: ReplyBuffer) -> list[dict]:
        """Collect the labeled samples for ``pddl_predicate`` from the reply buffer.

        For a state that has a ``similar_state``, the labels (and the returned
        ``state_hash``) come from the state returned by
        ``reply_buffer.get_similar_state``. Otherwise the state's own
        predicates are used. States with no label for this predicate's name are
        skipped.

        :param pddl_predicate: Predicate whose groundings are collected (matched by name).
        :param reply_buffer: Source of the states and their stored predicate labels.
        :return: One dict per state with keys ``"variables"`` (of the state itself),
            ``"label"`` (grounded predicate -> label, restricted to this predicate)
            and ``"state_hash"`` (of the state the labels were taken from).
        """
        dataset = []
        for state_hash, state in reply_buffer.get_all_states().items():
            if state.similar_state is not None:
                # Take labels (and hash) from the similar state.
                state_hash, sim_state = reply_buffer.get_similar_state(state)
                assert sim_state.predicates is not None
                label_predicates = sim_state.predicates
            else:
                label_predicates = state.predicates
            assert label_predicates is not None, "State must contain predicates"

            filtered_label_predicates = {
                p: p_eval for p, p_eval in label_predicates.items() if p.name == pddl_predicate.name
            }

            if not filtered_label_predicates:
                continue  # state not evaluated on predicate

            dataset.append({"variables": state.variables, "label": filtered_label_predicates, "state_hash": state_hash})

        return dataset

    def optim(
        self,
        pddl_predicate: PDDLPredicate,
        reply_buffer: ReplyBuffer,
        func: GroundingCallable,
        *,
        epochs: int = 80,
        factor: float = 0.5,
    ):
        """Optimize the grounding hyperparameters of a PDDL predicate by random search.

        The ``int``/``float``-annotated parameters of ``func.callable`` that follow
        the predicate's term arguments are sampled uniformly within
        ``[d - factor*|d|, d + factor*|d|]`` around their signature default ``d``.
        A default of 0 is therefore never varied.

        Among the best-scoring candidates, sufficiently robust ones are kept and
        the one closest to the defaults is chosen. It replaces the current
        hyperparameters only if it beats them on score, then robustness, then
        closeness.

        Returns early without changing anything when there are no tunable
        parameters or no labeled states.

        :param pddl_predicate: The PDDL predicate to optimize.
        :param reply_buffer: Buffer of states whose stored predicate labels serve as targets.
        :param func: The grounding function to use.
        :param epochs: Number of random candidates to evaluate, defaults to 80.
        :param factor: Relative half-width of the sampling interval around each default, defaults to 0.5.
        :raises ValueError: If ``epochs`` is smaller than 1.
        """
        logger.info(f"GF-optimization of {pddl_predicate.name}")
        dataset = self._prepare_dataset(pddl_predicate, reply_buffer)

        # The first `arity` parameters receive the predicate's terms. The numeric
        # parameters after them are the tunable hyperparameters.
        arity = len(pddl_predicate.definition.terms)
        hps = list(inspect.signature(func.callable).parameters.items())[arity:]
        hps = [p for p in hps if p[1].annotation in (float, int)]
        if not hps:
            # nothing to do
            return
        if not dataset:
            # Without labeled states there is nothing to score candidates against.
            logger.info(f"No labeled states for {pddl_predicate.name}, skipping GF-optimization")
            return
        if epochs < 1:
            raise ValueError(f"epochs must be at least 1, got {epochs}")

        # Signature defaults: centre of the search space and reference for closeness/robustness.
        original_vars = {name: param.default for name, param in hps}
        # Currently used hyperparameters (falls back to the defaults). Snapshot them:
        # func.update_hps() is called while sampling and may update func.hps in place,
        # which would otherwise silently overwrite this baseline.
        initial_vars = copy.deepcopy(func.hps) if len(func.hps) > 0 else original_vars

        # One uniform sampler per hyperparameter. `p=param` binds the current value
        # and avoids the late-binding-closure pitfall.
        # NOTE(review): int-annotated hyperparameters also receive float samples here;
        # confirm the grounding functions accept that.
        param_dist = {
            name: (lambda rng, p=param: rng.uniform(p - factor * abs(p), p + factor * abs(p)))
            for name, param in original_vars.items()
        }

        rng = np.random.default_rng(42)

        def get_objective(keywords):
            """Ground every sample with `keywords` and score the result against the labels."""
            # The hyperparameters are identical for every sample, so apply them once.
            func.update_hps(keywords, save=False)
            evals = []
            for sample in dataset:
                label = sample["label"]
                pred = func.ground(predicate=pddl_predicate.definition, variables=sample["variables"])
                assert len(label) == len(pred)
                pred_eval = get_predicate_evaluation(pred)

                for p, p_label in label.items():
                    evals.append({"pred": str(p), "grounder": pred_eval[p], "vlm": p_label})

            return compute_decision_metric_all(evals, zero_division=0)

        def robustness_from_samples(keywords, results, base_score=None, threshold=1e-6):
            """Estimate robustness of `keywords` from the already evaluated `results`.

            For every other sample whose score differs from `base_score` by more
            than `threshold`, the change of each parameter is measured relative to
            its signature default (absolute change when the default is 0). The
            minimum over parameters is taken. Returns the smallest such distance,
            or ``inf`` when no sample has a different score.

            :raises ValueError: If `base_score` is None and `keywords` is not in `results`.
            """
            # find the reference score
            if base_score is None:
                for score, kw in results:
                    if kw == keywords:
                        base_score = score
                        break
            if base_score is None:
                raise ValueError("Reference keywords not found in results")

            # measure distances to samples with a different outcome
            distances = []
            for score, kw in results:
                if kw == keywords:
                    continue

                if abs(score - base_score) > threshold:  # outcome differs
                    rel_dists = []
                    for name in keywords:
                        v0, v1 = keywords[name], kw[name]
                        # Normalize by the default, otherwise small values would look more
                        # robust for the same absolute change.
                        nominator = original_vars[name]
                        if nominator == 0:
                            rel = abs(v1 - v0)
                        else:
                            rel = abs(v1 - v0) / abs(nominator)
                        rel_dists.append(rel)
                    # NOTE(review): the min over parameters treats a sample as "close" if any
                    # single parameter is close; a max (L-inf) distance may be intended.
                    distances.append(min(rel_dists))

            return min(distances) if distances else float("inf")

        def eval_closeness(keywords):
            """Sum of per-parameter changes relative to the signature defaults (absolute when the default is 0)."""
            return sum(
                abs(keywords[name] - original_vars[name]) / abs(original_vars[name] if original_vars[name] != 0 else 1)
                for name in original_vars
            )

        # Random search over `epochs` candidates drawn from the seeded samplers.
        results = []
        for _ in range(epochs):
            keywords = {name: sampler(rng) for name, sampler in param_dist.items()}
            results.append((get_objective(keywords), keywords))

        # Lexicographic selection (not a sort):
        # 1. keep the candidates with the best score;
        # 2. keep those whose robustness reaches min(0.01, best robustness);
        # 3. take the first candidate closest to the defaults.
        best_score = max(results, key=lambda x: x[0])[0]
        results_w_robustness = [
            (score, robustness_from_samples(keywords, results), keywords)
            for score, keywords in results
            if score == best_score
        ]
        best_robustness = max(results_w_robustness, key=lambda x: x[1])[1]
        robustness_floor = min(0.01, best_robustness)
        r_results = [
            (score, robustness, eval_closeness(keywords), keywords)
            for score, robustness, keywords in results_w_robustness
            if robustness >= robustness_floor
        ]
        min_distance = min(r_results, key=lambda x: x[2])[2]

        best_score, best_robust, best_closeness, best_hps = next(x for x in r_results if x[2] == min_distance)

        # Evaluate the currently used hyperparameters with the same criteria.
        init_score = get_objective(initial_vars)
        initial_closeness = eval_closeness(initial_vars)
        initial_robust = robustness_from_samples(initial_vars, results, base_score=init_score)

        # Decide whether to replace initial parameters: score first, then robustness, then closeness.
        if (
            best_score > init_score
            or (best_score == init_score and best_robust > initial_robust)
            or (best_score == init_score and best_robust == initial_robust and best_closeness < initial_closeness)
        ):
            init_vars_str = ", ".join(f"{k}: {round(v, 3)}" for k, v in initial_vars.items())
            best_hps_str = ", ".join(f"{k}: {round(v, 3)}" for k, v in best_hps.items())
            logger.info(
                f"Hyperparameters for {pddl_predicate.definition.name}:\n"
                f"- Initial: {init_vars_str}, closeness {round(initial_closeness, 3)}, "
                f"robustness {round(initial_robust, 3)}, score {round(init_score, 3)}\n"
                f"- Best:    {best_hps_str}, closeness {round(best_closeness, 3)}, "
                f"robustness {round(best_robust, 3)}, score {round(best_score, 3)}"
            )
            func.update_hps(best_hps)
        else:
            # Re-apply the current values, this time with update_hps' default `save` argument.
            func.update_hps(initial_vars)
