from typing import List, Any, Dict
from utils.io_utils import extract_json_from_str, load_file


def _is_missing_judge_output(value: Any) -> bool:
    """Return True when a parsed judge output cannot be used as a number."""
    # 0 and 0.0 are legitimate values, so real numbers are never treated as
    # missing. bool is excluded explicitly because it is a subclass of int.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return False
    # None (or any other empty/falsy parse result) counts as missing.
    return not value


def clean_debate(
    debate_results: list[dict[str, Any]]
) -> tuple[
    list[float], list[float], dict[str, dict[str, list[float]]], dict[str, bool]
]:
    """
    Clean the debate results.

    Each entry is expected to provide ``entry["question"]["id"]``,
    ``entry["inital_judge_stance"]`` and ``entry["verdict"]``; the latter two
    are parsed with ``extract_json_from_str``.
    NOTE(review): the parser is assumed to return a number on success and
    ``None`` when nothing could be extracted -- confirm against utils.io_utils.

    :param debate_results: Whole history of debates and (judge) verdicts, one
        entry per debate, from which the judge's prior, posterior and delta
        are extracted
    :type debate_results: list[dict[str, Any]]

    :return:
        priors (list[float]): judge priors of all debates that were not skipped
        deltas (list[float]): judge deltas of all debates that were not
            skipped; delta = posterior - prior
        results_per_question (dict[str, dict[str, list[float]]]): for each
            question id, the lists stored under "prior", "posterior" and
            "delta"; a question whose debates were all skipped keeps empty lists
        skipped_indices (dict[str, bool]): maps the stringified position of
            every skipped debate to True
    """
    all_priors: list[float] = []
    all_deltas: list[float] = []
    # Judge initials and finals per question, so that correlations can be
    # computed per question as a sense check.
    results_per_question: dict[str, dict[str, list[float]]] = {}
    skipped_indices: dict[str, bool] = {}

    # One item = one debate conversation + one judge verdict.
    for idx, debate in enumerate(debate_results):
        # Each forecasting question has an assigned id. The entry is created
        # before the skip check, so skipped debates still register the question.
        question_key = debate["question"]["id"]
        per_question = results_per_question.setdefault(
            question_key, {"prior": [], "posterior": [], "delta": []}
        )

        # Parse both judge outputs. The prior comes from the initial stance
        # and the posterior from the final verdict.
        judge_prior = extract_json_from_str(debate["inital_judge_stance"])
        judge_post = extract_json_from_str(debate["verdict"])

        # A debate is skipped when either judge output is missing. Skipped
        # cases are to be ignored for truth-labeling too, hence the record.
        if _is_missing_judge_output(judge_prior) or _is_missing_judge_output(
            judge_post
        ):
            skipped_indices[str(idx)] = True
            continue

        # Delta represents how much the posterior updates on top of the prior.
        judge_delta = judge_post - judge_prior

        all_priors.append(judge_prior)
        all_deltas.append(judge_delta)

        per_question["prior"].append(judge_prior)
        per_question["posterior"].append(judge_post)
        per_question["delta"].append(judge_delta)

    return all_priors, all_deltas, results_per_question, skipped_indices
