import json
import sys
from itertools import product, combinations

import numpy as np

# Models whose outputs were annotated; results are read from
# "<model_name>/<annotator>.json".
MODEL_NAMES = [
    "gpt-4",
    "gemini",
    "claude-3",
]

# Two annotator cohorts. Inter-annotator agreement is computed only between
# annotators of the same cohort.
ANNOTATOR_GROUP_1 = ["A", "B"]
ANNOTATOR_GROUP_2 = [
    "C",
    "D",
    "E",
    "F",
    "G",
]


def load_pairs(path):
    """Load the JSON list of annotated pairs stored at ``path``.

    Returns the parsed JSON, or ``None`` when the file does not exist (the
    caller treats that as "no data for this model/annotator"). A file that
    exists but cannot be read or parsed also yields ``None``, with a warning
    printed to stderr, so that corruption is not silently mistaken for absence.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as err:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f"warning: skipping {path}: {err}", file=sys.stderr)
        return None


def summarize_pairs(pairs):
    """Extract correctness flags and annotator choices from one annotation file.

    Each ``pair`` is a dict. Pairs whose ``"annotated_pref"`` is missing or
    ``None`` are skipped. For the rest, ``ap = pair["annotated_pref"]`` is used
    as an index into ``pair["pref"]``.

    Returns a dict with:
      - ``"correct"``: list of ``pair["pref"][ap] == 0`` for each annotated
        pair, in file order;
      - ``"selected_index"``: dict mapping the pair's position in ``pairs`` to
        ``ap``. Keying by position (instead of appending to a list) keeps two
        annotators' choices aligned even when one of them skipped a pair.
    """
    correct = []
    selected_by_position = {}
    for position, pair in enumerate(pairs):
        ap = pair.get("annotated_pref")
        if ap is None:
            continue
        # NOTE(review): presumably pref[ap] == 0 means the model ranked the
        # annotator's chosen option first — confirm against the data format.
        correct.append(pair["pref"][ap] == 0)
        selected_by_position[position] = ap
    return {"correct": correct, "selected_index": selected_by_position}


def pairwise_agreement(selected_i, selected_j):
    """Fraction of commonly annotated positions where two annotators agree.

    ``selected_i`` and ``selected_j`` map pair position -> chosen index, as
    produced by ``summarize_pairs``. Only positions annotated by both are
    compared. Returns ``None`` when there is no such position.

    NOTE(review): comparing by position assumes every annotator's file lists
    the same pairs in the same order for a given model — TODO confirm.
    """
    shared = selected_i.keys() & selected_j.keys()
    if not shared:
        return None
    return float(np.mean([selected_i[pos] == selected_j[pos] for pos in shared]))


def annotator_agreement(model_annotator_perf, model_names, annotators):
    """Average agreement for every unordered pair of ``annotators``.

    ``model_annotator_perf`` maps ``(model_name, annotator)`` to the dict
    returned by ``summarize_pairs``. For each annotator pair, agreement is
    computed per model and then averaged over only the models where both
    annotators have overlapping annotations. A pair with no overlap on any
    model is omitted rather than reported as 0% agreement.

    Returns a list of floats in ``combinations(annotators, 2)`` order.
    """
    scores = []
    for annotator_i, annotator_j in combinations(annotators, 2):
        per_model = []
        for model_name in model_names:
            perf_i = model_annotator_perf.get((model_name, annotator_i))
            perf_j = model_annotator_perf.get((model_name, annotator_j))
            if perf_i is None or perf_j is None:
                continue
            agreement = pairwise_agreement(
                perf_i["selected_index"], perf_j["selected_index"]
            )
            if agreement is not None:
                per_model.append(agreement)
        if per_model:
            scores.append(float(np.mean(per_model)))
    return scores


def mean_with_binomial_se(outcomes):
    """Return ``(p, se)`` for a sequence of boolean outcomes, or ``None`` if empty.

    ``p`` is the observed rate and ``se = sqrt(p * (1 - p) / n)`` its binomial
    standard error.
    """
    n = len(outcomes)
    if n == 0:
        return None
    p = float(np.mean(outcomes))
    return p, float(np.sqrt(p * (1 - p) / n))


if __name__ == "__main__":
    # Correctness flags per model, pooled over both annotator groups.
    model_perf = {model_name: [] for model_name in MODEL_NAMES}
    # One averaged agreement score per annotator pair, per group.
    agreement_scores = []
    for annotators in [ANNOTATOR_GROUP_2, ANNOTATOR_GROUP_1]:
        model_annotator_perf = {}
        for model_name, annotator in product(MODEL_NAMES, annotators):
            pairs = load_pairs(f"{model_name}/{annotator}.json")
            if pairs is None:
                continue
            perf = summarize_pairs(pairs)
            model_annotator_perf[(model_name, annotator)] = perf
            model_perf[model_name].extend(perf["correct"])
        # Agreement is only computed within a group, never across groups.
        agreement_scores.extend(
            annotator_agreement(model_annotator_perf, MODEL_NAMES, annotators)
        )

    print("=" * 20, "MODEL PERFORMANCE", "=" * 20)
    global_perf = []
    for model_name in MODEL_NAMES:
        global_perf.extend(model_perf[model_name])
        stats = mean_with_binomial_se(model_perf[model_name])
        if stats is None:
            print(f"{model_name.upper()}: no annotations")
            continue
        mean_perf, sd_perf = stats
        print(f"{model_name.upper()}: {mean_perf * 100:.2f}% +- {sd_perf * 100:.2f}%")
    stats = mean_with_binomial_se(global_perf)
    if stats is None:
        print("GLOBAL: no annotations")
    else:
        mean_perf, sd_perf = stats
        print(f"GLOBAL: {mean_perf * 100:.2f}% +- {sd_perf * 100:.2f}%")

    print("=" * 20, "INTER-ANNOTATOR AGREEMENT", "=" * 20)
    if not agreement_scores:
        print("INTER-ANNOTATOR AGREEMENT: no overlapping annotations")
    else:
        mean_agreement = np.mean(agreement_scores)
        n_scores = len(agreement_scores)
        if n_scores > 1:
            # NOTE(review): the denominator n * (n - 1) / 2 counts pairs of
            # *pair scores*, which has no obvious statistical meaning. It is
            # kept from the original so the reported number does not silently
            # change — confirm the intended sample size.
            sd_agreement = np.sqrt(
                mean_agreement
                * (1 - mean_agreement)
                / (n_scores * (n_scores - 1) / 2)
            )
            print(
                f"INTER-ANNOTATOR AGREEMENT: {round(100 * mean_agreement, 2)}% +- {round(100 * sd_agreement, 2)}%"
            )
        else:
            # With a single score the denominator above would be zero.
            print(f"INTER-ANNOTATOR AGREEMENT: {round(100 * mean_agreement, 2)}%")
