import os
import numpy as np
from scipy.stats import pearsonr
import utils.path_utils
from utils.io_utils import load_file
from utils.debate_processing_utils import clean_debate
from typing import List, Dict


# Correlation between judges' initial beliefs (priors) and their belief updates (deltas)
def eval(
    all_priors: List[float],
    all_deltas: List[float],
    results_per_question: Dict[str, Dict[str, List]],
) -> tuple:
    """
    Report Pearson correlations between judge priors and belief deltas.

    Two levels of correlation are computed with ``scipy.stats.pearsonr`` and
    printed to stdout: one per question (as a sanity check) and one over the
    pooled ``all_priors`` / ``all_deltas`` vectors.

    NOTE: this function's name shadows the built-in ``eval``; it is kept
    unchanged so existing callers keep working.

    :param all_priors: list of judge priors pooled over all questions
    :type all_priors: list[float]
    :param all_deltas: list of judge deltas pooled over all questions
    :type all_deltas: list[float]
    :param results_per_question: dict mapping each question to a dict with
        a ``"prior"`` list and a ``"delta"`` list
    :type results_per_question: dict[str, dict[str, list[float]]]

    :return: tuple ``(pearson_per_question_r_list, pearson_per_question_p_list)``
        pearson_per_question_r_list: list of pearson r per question
        pearson_per_question_p_list: list of pearson p per question
        Both lists follow the iteration order of ``results_per_question``.
    """

    # Per-question correlation as a sanity check: the pooled correlation below
    # treats all questions as one big sample, which implicitly assumes that
    # differences between questions are negligible.
    pearson_per_question_r_list, pearson_per_question_p_list = [], []
    for question, values in results_per_question.items():
        priors, deltas = values["prior"], values["delta"]

        per_question_r, per_question_p = pearsonr(priors, deltas)
        pearson_per_question_r_list.append(per_question_r)
        pearson_per_question_p_list.append(per_question_p)
        print(
            f"Question {question}: Pearson r = {per_question_r:.3f}; Pearson p = {per_question_p: .3f}"
        )

    # Pooled Pearson correlation between all priors and all deltas.
    pearsonr_r_updates, pearsonr_p_updates = pearsonr(all_priors, all_deltas)
    print(
        f"Pearson r (betw all_initials and all_updates):{pearsonr_r_updates:.2f} ({pearsonr_p_updates:.3f})"
    )

    # Hand the per-question statistics back to the caller, as documented;
    # previously they were computed and then discarded.
    return pearson_per_question_r_list, pearson_per_question_p_list


if __name__ == "__main__":

    # Script entry point: load the conversation history of one debate run and
    # pass it to eval().
    #
    # Access debating data
    # Note: see data/notes/debate_processing_records.md for records of correlation evaluations.
    # ZH: currently just fill in the newest run data file name in here.
    #
    # project_root is the directory that contains this file, so the run data
    # is looked up under <this file's directory>/data/runs/<run name>/.
    project_root = os.path.dirname(os.path.abspath(__file__))
    debate_file_path = os.path.join(
        project_root,
        "data",
        "runs",
        "run-2025-04-05-144230",
        "conversation-history.json",
    )
    # load_file is a project helper; presumably it parses the JSON file into
    # Python objects -- TODO confirm against utils.io_utils.
    debata_data = load_file(debate_file_path)

    # NOTE(review): this call does not match eval() as defined in this file.
    # eval() is declared with three required parameters (all_priors,
    # all_deltas, results_per_question) but receives a single argument here,
    # and its docstring documents two return lists, not the four names
    # unpacked below. As written this line is expected to fail at runtime.
    # The loaded data presumably needs to be converted into priors/deltas
    # first (possibly via the imported clean_debate helper -- verify).
    skipped_indices, initial_labels, final_labels, ground_truth = eval(debata_data)
