import numpy as np
from scipy.stats import hmean, ks_2samp


def get_matching_model_utility(eval_result_dict):
    """Summarise per-task evaluation logs and derive a Model Utility score.

    Args:
        eval_result_dict: Mapping from evaluation log file name (one of the
            keys of ``eval_task_dict`` below) to a dict of per-example
            statistics. The entries ``avg_gt_loss``, ``average_perturb_loss``,
            ``rougeL_recall``, ``avg_paraphrased_loss`` and ``kl_divergence``
            are read when present; each may be a list or a dict whose values
            are used.

    Returns:
        A dict keyed ``"<task> <metric>"`` for every task/metric pair, plus
        ``"Model Utility"``. Metrics of tasks that were not supplied stay as
        empty-list placeholders; metrics with no data are ``nan``.

    Raises:
        KeyError: If a file name in ``eval_result_dict`` is not a known task.
    """
    eval_task_dict = {
        "eval_real_author_wo_options.json": "Real Authors",
        "eval_real_world_wo_options.json": "Real World",
        "eval_log.json": "Retain",
        "eval_log_forget.json": "Forget",
    }
    metric_names = ("Probability", "ROUGE", "Truth Ratio", "KL Divergence")
    nan = float("nan")

    def _as_float_array(values):
        # Accept either a plain sequence or a dict (only its values matter).
        if isinstance(values, dict):
            values = list(values.values())
        return np.asarray(values, dtype=float)

    def _nanmean_or_nan(arr):
        # NaN-aware mean that yields nan for an empty array.
        return float(np.nanmean(arr)) if arr.size else nan

    # Every task/metric key exists up front, initialised to an empty list.
    output_result = {
        f"{label} {metric}": []
        for label in eval_task_dict.values()
        for metric in metric_names
    }

    for task_file, stats in eval_result_dict.items():
        label = eval_task_dict[task_file]

        # Probability: exp(-gt loss), normalised against the perturbed
        # answers whenever perturbation losses are available.
        gt_prob = np.exp(-_as_float_array(stats.get("avg_gt_loss", [])))
        perturb_loss = _as_float_array(stats.get("average_perturb_loss", []))
        if perturb_loss.size == 0:
            probability = _nanmean_or_nan(gt_prob)
        else:
            # Non-finite losses become +inf so that exp(-loss) is 0.
            finite_loss = np.where(
                np.isfinite(perturb_loss), perturb_loss, np.inf
            )
            perturb_mass = np.nansum(np.exp(-finite_loss), axis=-1)
            normaliser = np.clip(gt_prob + perturb_mass, 1e-12, None)
            probability = float(np.nanmean(gt_prob / normaliser))
        output_result[f"{label} Probability"] = probability

        # ROUGE: mean recall over examples.
        output_result[f"{label} ROUGE"] = _nanmean_or_nan(
            _as_float_array(stats.get("rougeL_recall", []))
        )

        # Truth Ratio: R = exp(mean perturbed loss - paraphrased loss).
        para_loss = _as_float_array(stats.get("avg_paraphrased_loss", []))
        if para_loss.size and perturb_loss.size:
            ratio = np.exp(np.nanmean(perturb_loss, axis=-1) - para_loss)
            ratio = np.where(np.isfinite(ratio), ratio, np.nan)
            inverse = 1.0 / np.clip(ratio, 1e-12, None)
            if "forget" in task_file:
                per_example = np.minimum(ratio, inverse)
            else:
                per_example = np.maximum(0.0, 1.0 - inverse)
            truth_ratio = float(np.nanmean(per_example))
        else:
            truth_ratio = nan
        output_result[f"{label} Truth Ratio"] = truth_ratio

        # KL Divergence: averaged as-is when the log provides it.
        output_result[f"{label} KL Divergence"] = _nanmean_or_nan(
            _as_float_array(stats.get("kl_divergence", []))
        )

    # Model Utility: harmonic mean of the finite, strictly positive scalar
    # metrics, leaving out every Forget and KL Divergence entry.
    utility_terms = [
        score
        for key, score in output_result.items()
        if "Forget" not in key
        and "KL Divergence" not in key
        and isinstance(score, (int, float))
        and np.isfinite(score)
        and score > 0
    ]
    output_result["Model Utility"] = hmean(utility_terms) if utility_terms else nan

    return output_result


def get_model_utility(eval_result_dict):
    """Aggregate per-task evaluation logs into metrics and a Model Utility score.

    Args:
        eval_result_dict: Mapping from evaluation log file name (one of the
            keys of ``eval_task_dict`` below) to a dict holding the
            per-example entries ``rougeL_recall``, ``avg_paraphrased_loss``,
            ``average_perturb_loss`` and ``kl_divergence``. Logs whose name
            contains ``"eval_log"`` must also hold ``avg_gt_loss``.
            ``average_perturb_loss`` is presumably shaped
            (num_examples, num_perturbations) -- TODO confirm against caller.

    Returns:
        A dict keyed ``"<task> <metric>"`` for every task/metric pair, plus
        ``"Model Utility"``: the harmonic mean of all computed metrics except
        the Forget and KL ones. Metrics of tasks absent from
        ``eval_result_dict`` keep their empty-list placeholder and are left
        out of Model Utility; if nothing qualifies, Model Utility is ``nan``.

    Raises:
        KeyError: If a file name is not a known task or a required entry is
            missing from a log.
    """
    eval_task_dict = {
        "eval_real_author_wo_options.json": "Real Authors",
        "eval_real_world_wo_options.json": "Real World",
        "eval_log.json": "Retain",
        "eval_log_forget.json": "Forget",
    }
    metrics = ["ROUGE", "Probability", "Truth Ratio", "KL Divergence"]

    # Every task/metric key exists up front, initialised to an empty list.
    output_result = {
        f"{task_name} {metric}": []
        for task_name in eval_task_dict.values()
        for metric in metrics
    }

    # k is the log file name, v its per-example statistics.
    for k, v in eval_result_dict.items():
        task_name = eval_task_dict[k]

        # getting Probability
        if "eval_log" in k:
            # Retain / Forget logs: plain mean of exp(-ground-truth loss).
            gt_probs = np.exp(-1 * np.array(v["avg_gt_loss"]))
            avg_gt_prob = np.mean(gt_probs)
        else:
            # Other logs: the paraphrased loss plays the role of the "true"
            # answer, normalised against the perturbed answers.
            true_loss = np.array(v["avg_paraphrased_loss"], dtype=float)
            perts_2d = np.array(v["average_perturb_loss"], dtype=float)

            true_prob = np.exp(-true_loss)
            pert_probs = np.exp(-perts_2d)
            # Non-finite probabilities (e.g. from NaN losses) contribute
            # nothing to the per-example sum.
            pert_probs = np.where(np.isfinite(pert_probs), pert_probs, 0.0)
            pert_sum = pert_probs.sum(axis=-1)

            denom = np.clip(true_prob + pert_sum, 1e-12, None)
            avg_gt_prob = float(np.nanmean(true_prob / denom))
        output_result[f"{task_name} Probability"] = avg_gt_prob

        # getting ROUGE
        output_result[f"{task_name} ROUGE"] = np.array(v["rougeL_recall"]).mean()

        # getting Truth Ratio: R = exp(mean perturbed loss - paraphrased loss)
        pert_2d = np.array(v["average_perturb_loss"], dtype=float)
        pert_means = np.nanmean(pert_2d, axis=-1)
        para = np.array(v["avg_paraphrased_loss"], dtype=float)
        rtruth = np.exp(pert_means - para)

        if "forget" in k:
            # NOTE: the forget log reports the raw mean of R here, not
            # mean(min(R, 1 / R)).
            paraphrased_perturb_ratio = float(np.nanmean(rtruth))
        else:
            paraphrased_perturb_ratio = float(
                np.nanmean(np.maximum(0, 1 - 1 / np.clip(rtruth, 1e-12, None)))
            )
        output_result[f"{task_name} Truth Ratio"] = paraphrased_perturb_ratio

        # getting KL Divergence
        output_result[f"{task_name} KL Divergence"] = np.mean(v["kl_divergence"])

    # Model Utility: harmonic mean over the non-Forget, non-KL metrics.
    # Placeholder lists (tasks that were not supplied) are skipped so that a
    # partial set of logs does not feed a ragged sequence into hmean.
    model_utility_cands = [
        val
        for name, val in output_result.items()
        if "Forget" not in name and "KL" not in name and not isinstance(val, list)
    ]
    output_result["Model Utility"] = (
        hmean(model_utility_cands) if model_utility_cands else float("nan")
    )
    return output_result


def get_forget_quality(unlearn_result, retain_result):
    """Compare forget-set truth ratios of two models with a two-sample KS test.

    Args:
        unlearn_result: Evaluation results of the unlearned model; must hold
            an ``"eval_log_forget.json"`` entry with ``avg_paraphrased_loss``
            and ``average_perturb_loss``.
        retain_result: Evaluation results of the retain model, same layout.

    Returns:
        A 2-tuple ``(scores, truth_ratios)``. ``scores`` holds the KS p-value
        (under both ``"Forget Quality"`` and ``"KS Test PVal Forget"``) and
        the KS statistic; ``truth_ratios`` holds the per-example truth-ratio
        arrays of both models.
    """
    forget_key = "eval_log_forget.json"

    def _losses(forget_stats):
        # Paraphrased loss per example and perturbed loss averaged over the
        # last axis.
        paraphrased = np.array(forget_stats["avg_paraphrased_loss"])
        perturbed = np.array(forget_stats["average_perturb_loss"])
        return paraphrased, perturbed.mean(axis=-1)

    unlearn_forget = unlearn_result[forget_key]
    retain_forget = retain_result[forget_key]

    unlearn_para, unlearn_pert = _losses(unlearn_forget)
    retain_para, retain_pert = _losses(retain_forget)

    # Truth ratio: exp(mean perturbed loss - paraphrased loss).
    unlearn_truth_ratio = np.exp(unlearn_pert - unlearn_para)
    retain_truth_ratio = np.exp(retain_pert - retain_para)

    ks_result = ks_2samp(unlearn_truth_ratio, retain_truth_ratio)

    scores = {
        "Forget Quality": ks_result.pvalue,
        "KS Test PVal Forget": ks_result.pvalue,
        "KS Test Forget": ks_result.statistic,
    }
    truth_ratios = {
        "Unlearn Truth Ratio": unlearn_truth_ratio,
        "Retain Truth Ratio": retain_truth_ratio,
    }
    return scores, truth_ratios
