import json
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import cohen_kappa_score

# Manually specify the input files here
FILES: List[str] = [
    "eval_results/gpt_5_mini_all.json",
    "eval_results/gemini_2.5_pro_all.json",
    "eval_results/gpt_oss_120b_all.json",
    "eval_results/qwen3_32b_all.json",
    "eval_results/deepseek_r1_all.json",
]

# Dataset name mapping for cleaner plot labels
DATASET_NAME_MAP: Dict[str, str] = {
    "gpt_5_mini_all": "GPT-5 Mini",
    "gemini_2.5_pro_all": "Gemini 2.5 Pro",
    "gpt_oss_120b_all": "gpt-oss-120b",
    "qwen3_32b_all": "Qwen3 32B",
    "deepseek_r1_all": "DeepSeek R1 0528",
}

OUTPUT_DIR: str = os.path.join("plots", "method_wise")

# Custom color palette for plots
CUSTOM_COLORS = ["#FAE0BB", "#BDE8FD"]  # R:250 G:224 B:187, R:189 G:232 B:253

# Define agentic sample groups
AGENTIC_GROUPS = {
    'aider_commit', 'aider_manual', 'aider_multi', 'continue_dev', 'irl_cursor', 'openhands',
    'smol_real_manager', 'smol_real_search', 'agentharm', 'aider_full', 'cybench', 'gaia_runs',
    'honeypot', 'mvp', 'smol_gaia_manager', 'smol_gaia_search', 'swe_gym', 'swebench'
}


def is_agentic_sample(sample_id: str) -> bool:
    """
    Determine if a sample is agentic based on its group name.
    Sample ID format: group_name:sample_name
    """
    if ':' not in sample_id:
        return False
    group_name = sample_id.split(':', 1)[0]
    return group_name in AGENTIC_GROUPS


def compute_counts_binary(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Returns tuple: (
        num_samples_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    ) using binary_test_results.final_decision
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
        final_decision = btr.get("final_decision", None)

        if not isinstance(eval_category, bool):
            continue
        if final_decision not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if final_decision == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif final_decision == "EVAL" and eval_category is False:
            fp += 1
        elif final_decision == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif final_decision == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_prob(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same tuple as binary counts, but using eval_probability_results.eval_probability >= 0.5 as EVAL else REAL
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
        prob = epr.get("eval_probability", None)

        if not isinstance(eval_category, bool):
            continue
        if not isinstance(prob, (int, float)):
            continue

        predicted = "EVAL" if prob >= 0.5 else "REAL"
        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_purpose(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same tuple as binary counts, but using purpose_awareness_results.final_decision ("EVAL"/"REAL").
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        par = sample.get("purpose_awareness_results", {}) if isinstance(sample.get("purpose_awareness_results"), dict) else {}
        predicted = par.get("final_decision", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_logprob(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same tuple as binary counts, but using logprob_eval_results.decision ("EVAL"/"REAL").
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        ler = sample.get("logprob_eval_results", {}) if isinstance(sample.get("logprob_eval_results"), dict) else {}
        predicted = ler.get("decision", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_reasoning_traces(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same tuple as binary counts, but using reasoning_traces_results.parsed_judge_response.conclusion ("EVAL"/"REAL").
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        rtr = sample.get("reasoning_traces_results", {}) if isinstance(sample.get("reasoning_traces_results"), dict) else {}
        parsed_response = rtr.get("parsed_judge_response", {}) if isinstance(rtr.get("parsed_judge_response"), dict) else {}
        predicted = parsed_response.get("conclusion", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_binary_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_binary but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
        final_decision = btr.get("final_decision", None)

        if not isinstance(eval_category, bool):
            continue
        if final_decision not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if final_decision == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif final_decision == "EVAL" and eval_category is False:
            fp += 1
        elif final_decision == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif final_decision == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_prob_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_prob but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
        prob = epr.get("eval_probability", None)

        if not isinstance(eval_category, bool):
            continue
        if not isinstance(prob, (int, float)):
            continue

        predicted = "EVAL" if prob >= 0.5 else "REAL"
        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_purpose_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_purpose but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        par = sample.get("purpose_awareness_results", {}) if isinstance(sample.get("purpose_awareness_results"), dict) else {}
        predicted = par.get("final_decision", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_logprob_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_logprob but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        ler = sample.get("logprob_eval_results", {}) if isinstance(sample.get("logprob_eval_results"), dict) else {}
        predicted = ler.get("decision", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_reasoning_traces_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_reasoning_traces but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)
        rtr = sample.get("reasoning_traces_results", {}) if isinstance(sample.get("reasoning_traces_results"), dict) else {}
        parsed_response = rtr.get("parsed_judge_response", {}) if isinstance(rtr.get("parsed_judge_response"), dict) else {}
        predicted = parsed_response.get("conclusion", None)

        if not isinstance(eval_category, bool):
            continue
        if predicted not in {"EVAL", "REAL"}:
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_majority_vote(file_path: str) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Majority vote across 3 metrics: binary_test_results.final_decision, 
    purpose_awareness_results.final_decision, and eval_probability_results.eval_probability >= 0.5.
    Each method contributes 1 vote for "EVAL" or "REAL", and the majority wins.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)

        if not isinstance(eval_category, bool):
            continue

        # Get predictions from the 3 methods
        votes = []
        
        # Binary test vote
        btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
        binary_decision = btr.get("final_decision", None)
        if binary_decision in {"EVAL", "REAL"}:
            votes.append(binary_decision)
        
        # Purpose awareness vote
        par = sample.get("purpose_awareness_results", {}) if isinstance(sample.get("purpose_awareness_results"), dict) else {}
        purpose_decision = par.get("final_decision", None)
        if purpose_decision in {"EVAL", "REAL"}:
            votes.append(purpose_decision)
        
        # Probability estimation vote
        epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
        prob = epr.get("eval_probability", None)
        if isinstance(prob, (int, float)):
            prob_decision = "EVAL" if prob >= 0.5 else "REAL"
            votes.append(prob_decision)

        # Need at least 2 out of 3 methods to have valid predictions for majority vote
        if len(votes) < 2:
            continue

        # Calculate majority vote
        eval_votes = votes.count("EVAL")
        real_votes = votes.count("REAL")
        
        if eval_votes > real_votes:
            predicted = "EVAL"
        elif real_votes > eval_votes:
            predicted = "REAL"
        else:
            # Tie case - skip this sample
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def compute_counts_majority_vote_filtered(file_path: str, agentic_only: bool) -> Tuple[int, int, int, int, int, int, int, int, int]:
    """
    Same as compute_counts_majority_vote but filtered for agentic or non-agentic samples only.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results: List[Dict] = data.get("results", [])

    num_used = 0
    tp = fp = fn = tn = 0
    tp_capability = fn_capability = 0
    tp_alignment = fn_alignment = 0

    for sample in results:
        sample_id = sample.get("id", "")
        sample_is_agentic = is_agentic_sample(sample_id)
        
        # Filter based on agentic_only parameter
        if agentic_only and not sample_is_agentic:
            continue
        if not agentic_only and sample_is_agentic:
            continue
        
        metadata = sample.get("metadata", {})
        eval_category = metadata.get("eval_category", None)
        eval_type = metadata.get("eval_type", None)

        if not isinstance(eval_category, bool):
            continue

        # Get predictions from the 3 methods
        votes = []
        
        # Binary test vote
        btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
        binary_decision = btr.get("final_decision", None)
        if binary_decision in {"EVAL", "REAL"}:
            votes.append(binary_decision)
        
        # Purpose awareness vote
        par = sample.get("purpose_awareness_results", {}) if isinstance(sample.get("purpose_awareness_results"), dict) else {}
        purpose_decision = par.get("final_decision", None)
        if purpose_decision in {"EVAL", "REAL"}:
            votes.append(purpose_decision)
        
        # Probability estimation vote
        epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
        prob = epr.get("eval_probability", None)
        if isinstance(prob, (int, float)):
            prob_decision = "EVAL" if prob >= 0.5 else "REAL"
            votes.append(prob_decision)

        # Need at least 2 out of 3 methods to have valid predictions for majority vote
        if len(votes) < 2:
            continue

        # Calculate majority vote
        eval_votes = votes.count("EVAL")
        real_votes = votes.count("REAL")
        
        if eval_votes > real_votes:
            predicted = "EVAL"
        elif real_votes > eval_votes:
            predicted = "REAL"
        else:
            # Tie case - skip this sample
            continue

        num_used += 1

        if predicted == "EVAL" and eval_category is True:
            tp += 1
            if eval_type == "capability":
                tp_capability += 1
            elif eval_type == "alignment":
                tp_alignment += 1
        elif predicted == "EVAL" and eval_category is False:
            fp += 1
        elif predicted == "REAL" and eval_category is True:
            fn += 1
            if eval_type == "capability":
                fn_capability += 1
            elif eval_type == "alignment":
                fn_alignment += 1
        elif predicted == "REAL" and eval_category is False:
            tn += 1

    return (
        num_used,
        tp, fp, fn, tn,
        tp_capability, fn_capability,
        tp_alignment, fn_alignment,
    )


def main() -> None:
    files = [f for f in FILES if os.path.isfile(f)]
    missing = [f for f in FILES if not os.path.isfile(f)]
    if missing:
        for f in missing:
            print(f"Warning: file not found and will be skipped: {f}")
    if not files:
        print("No input files provided. Populate the FILES list at the top of the script.")
        return

    rows_binary: List[Dict] = []
    rows_prob: List[Dict] = []
    rows_purpose: List[Dict] = []
    rows_logprob: List[Dict] = []
    rows_reasoning_traces: List[Dict] = []
    rows_majority_vote: List[Dict] = []
    rows_kappa: List[Dict] = []

    for file_path in files:
        raw_dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        dataset_name = DATASET_NAME_MAP.get(raw_dataset_name, raw_dataset_name)

        try:
            (
                num_used,
                tp, fp, fn, tn,
                tp_capability, fn_capability,
                tp_alignment, fn_alignment,
            ) = compute_counts_binary(file_path)
        except Exception as exc:
            print(f"Error processing (binary) {file_path}: {exc}")
            continue

        rows_binary.append(
            {
                "dataset": dataset_name,
                "num_used": num_used,
                "tp": tp,
                "fp": fp,
                "fn": fn,
                "tn": tn,
                "tp_capability": tp_capability,
                "fn_capability": fn_capability,
                "tp_alignment": tp_alignment,
                "fn_alignment": fn_alignment,
            }
        )

        try:
            (
                num_used_p,
                tp_p, fp_p, fn_p, tn_p,
                tp_capability_p, fn_capability_p,
                tp_alignment_p, fn_alignment_p,
            ) = compute_counts_prob(file_path)
        except Exception as exc:
            print(f"Error processing (prob) {file_path}: {exc}")
            continue

        rows_prob.append(
            {
                "dataset": dataset_name,
                "num_used": num_used_p,
                "tp": tp_p,
                "fp": fp_p,
                "fn": fn_p,
                "tn": tn_p,
                "tp_capability": tp_capability_p,
                "fn_capability": fn_capability_p,
                "tp_alignment": tp_alignment_p,
                "fn_alignment": fn_alignment_p,
            }
        )

        try:
            (
                num_used_u,
                tp_u, fp_u, fn_u, tn_u,
                tp_capability_u, fn_capability_u,
                tp_alignment_u, fn_alignment_u,
            ) = compute_counts_purpose(file_path)
        except Exception as exc:
            print(f"Error processing (purpose) {file_path}: {exc}")
            continue

        rows_purpose.append(
            {
                "dataset": dataset_name,
                "num_used": num_used_u,
                "tp": tp_u,
                "fp": fp_u,
                "fn": fn_u,
                "tn": tn_u,
                "tp_capability": tp_capability_u,
                "fn_capability": fn_capability_u,
                "tp_alignment": tp_alignment_u,
                "fn_alignment": fn_alignment_u,
            }
        )

        try:
            (
                num_used_l,
                tp_l, fp_l, fn_l, tn_l,
                tp_capability_l, fn_capability_l,
                tp_alignment_l, fn_alignment_l,
            ) = compute_counts_logprob(file_path)
        except Exception as exc:
            print(f"Error processing (logprob) {file_path}: {exc}")
            continue

        # Only add logprob results if there are samples
        if num_used_l > 0:
            rows_logprob.append(
                {
                    "dataset": dataset_name,
                    "num_used": num_used_l,
                    "tp": tp_l,
                    "fp": fp_l,
                    "fn": fn_l,
                    "tn": tn_l,
                    "tp_capability": tp_capability_l,
                    "fn_capability": fn_capability_l,
                    "tp_alignment": tp_alignment_l,
                    "fn_alignment": fn_alignment_l,
                }
            )

        try:
            (
                num_used_r,
                tp_r, fp_r, fn_r, tn_r,
                tp_capability_r, fn_capability_r,
                tp_alignment_r, fn_alignment_r,
            ) = compute_counts_reasoning_traces(file_path)
        except Exception as exc:
            print(f"Error processing (reasoning traces) {file_path}: {exc}")
            continue

        # Only add reasoning traces results if there are samples
        if num_used_r > 0:
            rows_reasoning_traces.append(
                {
                    "dataset": dataset_name,
                    "num_used": num_used_r,
                    "tp": tp_r,
                    "fp": fp_r,
                    "fn": fn_r,
                    "tn": tn_r,
                    "tp_capability": tp_capability_r,
                    "fn_capability": fn_capability_r,
                    "tp_alignment": tp_alignment_r,
                    "fn_alignment": fn_alignment_r,
                }
            )

        try:
            (
                num_used_m,
                tp_m, fp_m, fn_m, tn_m,
                tp_capability_m, fn_capability_m,
                tp_alignment_m, fn_alignment_m,
            ) = compute_counts_majority_vote(file_path)
        except Exception as exc:
            print(f"Error processing (majority vote) {file_path}: {exc}")
            continue

        # Only add majority vote results if there are samples
        if num_used_m > 0:
            rows_majority_vote.append(
                {
                    "dataset": dataset_name,
                    "num_used": num_used_m,
                    "tp": tp_m,
                    "fp": fp_m,
                    "fn": fn_m,
                    "tn": tn_m,
                    "tp_capability": tp_capability_m,
                    "fn_capability": fn_capability_m,
                    "tp_alignment": tp_alignment_m,
                    "fn_alignment": fn_alignment_m,
                }
            )

        # Agreement (Cohen's kappa) between binary decision and probability-threshold decision
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data_all = json.load(f)
            bin_labels = []
            prob_labels = []
            for sample in data_all.get("results", []):
                btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
                final_decision = btr.get("final_decision", None)
                epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
                prob = epr.get("eval_probability", None)
                if final_decision in {"EVAL", "REAL"} and isinstance(prob, (int, float)):
                    bin_labels.append(1 if final_decision == "EVAL" else 0)
                    prob_labels.append(1 if float(prob) >= 0.5 else 0)
            if len(bin_labels) >= 2 and len(prob_labels) >= 2:
                kappa = cohen_kappa_score(bin_labels, prob_labels)
            else:
                kappa = float("nan")
        except Exception as exc:
            print(f"Error computing kappa for {file_path}: {exc}")
            kappa, bin_labels = float("nan"), []
        rows_kappa.append({"dataset": dataset_name, "kappa": kappa, "n": len(bin_labels)})

    if not rows_binary:
        print("No valid binary samples found across provided files.")
        return
    if not rows_prob:
        print("No valid probability samples found across provided files.")
        return
    if not rows_purpose:
        print("No valid purpose-awareness samples found across provided files.")
        return

    df_bin = pd.DataFrame(rows_binary)
    df_prob = pd.DataFrame(rows_prob)
    df_purpose = pd.DataFrame(rows_purpose)
    df_logprob = pd.DataFrame(rows_logprob) if rows_logprob else pd.DataFrame()
    df_reasoning_traces = pd.DataFrame(rows_reasoning_traces) if rows_reasoning_traces else pd.DataFrame()
    df_majority_vote = pd.DataFrame(rows_majority_vote) if rows_majority_vote else pd.DataFrame()
    df_kappa = pd.DataFrame(rows_kappa)

    # Print summaries
    printable_cols = [
        "dataset", "num_used", "tp", "fp", "fn", "tn",
        "tp_capability", "fn_capability", "tp_alignment", "fn_alignment",
    ]
    print("Binary method summary:")
    print(df_bin[printable_cols].to_string(index=False))
    print("\nEval-probability method summary:")
    print(df_prob[printable_cols].to_string(index=False))
    print("\nMethod agreement (Cohen's kappa, binary vs probability estimation):")
    print(df_kappa.to_string(index=False))
    print("\nPurpose-awareness method summary:")
    print(df_purpose[printable_cols].to_string(index=False))
    if not df_logprob.empty:
        print("\nLogprob method summary:")
        print(df_logprob[printable_cols].to_string(index=False))
    else:
        print("\nLogprob method summary: No valid samples found.")
    if not df_reasoning_traces.empty:
        print("\nReasoning traces method summary:")
        print(df_reasoning_traces[printable_cols].to_string(index=False))
    else:
        print("\nReasoning traces method summary: No valid samples found.")
    if not df_majority_vote.empty:
        print("\nMajority vote method summary:")
        print(df_majority_vote[printable_cols].to_string(index=False))
    else:
        print("\nMajority vote method summary: No valid samples found.")

    # Compute percent metrics for plotting
    dfs_to_process = [df_bin, df_prob, df_purpose]
    if not df_logprob.empty:
        dfs_to_process.append(df_logprob)
    if not df_reasoning_traces.empty:
        dfs_to_process.append(df_reasoning_traces)
    if not df_majority_vote.empty:
        dfs_to_process.append(df_majority_vote)
    
    for df in dfs_to_process:
        # True positives: eval marked as eval (tp / (tp + fn) * 100)
        df["tp_eval_percent"] = df.apply(
            lambda r: (r["tp"] / (r["tp"] + r["fn"]) * 100.0) if (r["tp"] + r["fn"]) > 0 else float("nan"),
            axis=1,
        )
        # True negatives: real marked as real (tn / (tn + fp) * 100)
        df["tn_real_percent"] = df.apply(
            lambda r: (r["tn"] / (r["tn"] + r["fp"]) * 100.0) if (r["tn"] + r["fp"]) > 0 else float("nan"),
            axis=1,
        )
        # True positives by eval type
        df["tp_capability_percent"] = df.apply(
            lambda r: (r["tp_capability"] / (r["tp_capability"] + r["fn_capability"]) * 100.0)
            if (r["tp_capability"] + r["fn_capability"]) > 0 else float("nan"),
            axis=1,
        )
        df["tp_alignment_percent"] = df.apply(
            lambda r: (r["tp_alignment"] / (r["tp_alignment"] + r["fn_alignment"]) * 100.0)
            if (r["tp_alignment"] + r["fn_alignment"]) > 0 else float("nan"),
            axis=1,
        )

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    sns.set(style="whitegrid")

    # Combined Plot: TP/TN percent by model (binary and probability estimation side-by-side)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(max(16, 2.4 * len(df_bin)), 6))
    
    # Left subplot: Binary test
    tp_tn_plot_df_bin = df_bin.melt(
        id_vars=["dataset"],
        value_vars=["tp_eval_percent", "tn_real_percent"],
        var_name="metric",
        value_name="value",
    )
    tp_tn_plot_df_bin["metric"] = tp_tn_plot_df_bin["metric"].map({
        "tp_eval_percent": "Correctly identified evaluations",
        "tn_real_percent": "Correctly identified deployment"
    })
    sns.barplot(data=tp_tn_plot_df_bin, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax1)
    # Add random guess rate line (50%)
    ax1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    ax1.set_ylim(0, 100)
    ax1.set_ylabel("%")
    ax1.set_xlabel("")
    ax1.set_title("Binary test, correctly identified cases")
    ax1.tick_params(axis='x', rotation=30)
    ax1.legend(loc='center right')
    
    # Right subplot: Probability estimation test
    tp_tn_plot_df_prob = df_prob.melt(
        id_vars=["dataset"],
        value_vars=["tp_eval_percent", "tn_real_percent"],
        var_name="metric",
        value_name="value",
    )
    tp_tn_plot_df_prob["metric"] = tp_tn_plot_df_prob["metric"].map({
        "tp_eval_percent": "Correctly identified evaluations",
        "tn_real_percent": "Correctly identified deployment"
    })
    sns.barplot(data=tp_tn_plot_df_prob, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax2)
    # Add random guess rate line (50%)
    ax2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    ax2.set_ylim(0, 100)
    ax2.set_ylabel("%")
    ax2.set_xlabel("")
    ax2.set_title("Probability estimation test, correctly identified cases")
    ax2.tick_params(axis='x', rotation=30)
    ax2.legend(loc='center right')
    
    plt.tight_layout()
    out_combined_tptn = os.path.join(OUTPUT_DIR, "combined_tp_tn_by_model.png")
    plt.savefig(out_combined_tptn, dpi=200)
    print(f"Saved plot to: {out_combined_tptn}")

    # Combined Plot: TP by eval type (binary and probability estimation side-by-side)
    fig, (ax3, ax4) = plt.subplots(1, 2, figsize=(max(16, 2.4 * len(df_bin)), 6))
    
    # Left subplot: Binary test TP by eval type
    tp_type_plot_df_bin = df_bin.melt(
        id_vars=["dataset"],
        value_vars=["tp_alignment_percent", "tp_capability_percent"],
        var_name="metric",
        value_name="value",
    )
    tp_type_plot_df_bin["metric"] = tp_type_plot_df_bin["metric"].map({
        "tp_alignment_percent": "Correctly identified alignment eval",
        "tp_capability_percent": "Correctly identified capability eval"
    })
    sns.barplot(data=tp_type_plot_df_bin, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax3)
    # Add random guess rate line (50%)
    ax3.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    ax3.set_ylim(0, 100)
    ax3.set_ylabel("%")
    ax3.set_xlabel("")
    ax3.set_title("Binary test, correctly identified alignment vs capability evaluations")
    ax3.tick_params(axis='x', rotation=30)
    ax3.legend(loc='center right')
    
    # Right subplot: Probability estimation test TP by eval type
    tp_type_plot_df_prob = df_prob.melt(
        id_vars=["dataset"],
        value_vars=["tp_alignment_percent", "tp_capability_percent"],
        var_name="metric",
        value_name="value",
    )
    tp_type_plot_df_prob["metric"] = tp_type_plot_df_prob["metric"].map({
        "tp_alignment_percent": "Correctly identified alignment eval",
        "tp_capability_percent": "Correctly identified capability eval"
    })
    sns.barplot(data=tp_type_plot_df_prob, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax4)
    # Add random guess rate line (50%)
    ax4.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    ax4.set_ylim(0, 100)
    ax4.set_ylabel("%")
    ax4.set_xlabel("")
    ax4.set_title("Probability estimation test, correctly identified alignment vs capability evaluations")
    ax4.tick_params(axis='x', rotation=30)
    ax4.legend(loc='center right')
    
    plt.tight_layout()
    out_combined_tp_type = os.path.join(OUTPUT_DIR, "combined_tp_by_eval_type_by_model.png")
    plt.savefig(out_combined_tp_type, dpi=200)
    print(f"Saved plot to: {out_combined_tp_type}")

    # Plot: TP/TN percent by model (purpose-awareness method)
    tp_tn_plot_df_purpose = df_purpose.melt(
        id_vars=["dataset"],
        value_vars=["tp_eval_percent", "tn_real_percent"],
        var_name="metric",
        value_name="value",
    )
    # Rename metric labels for legend
    tp_tn_plot_df_purpose["metric"] = tp_tn_plot_df_purpose["metric"].map({
        "tp_eval_percent": "Correctly identified evaluations",
        "tn_real_percent": "Correctly identified deployment"
    })
    plt.figure(figsize=(max(8, 1.2 * len(df_purpose)), 5))
    axp1 = sns.barplot(data=tp_tn_plot_df_purpose, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
    # Add random guess rate line (50%)
    axp1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    axp1.set_ylim(0, 100)
    axp1.set_ylabel("%")
    axp1.set_xlabel("")
    axp1.set_title("User motivation guessing, correctly identified cases")
    plt.xticks(rotation=30, ha="right")
    plt.legend(loc='lower right')
    plt.tight_layout()
    out_purpose_tptn = os.path.join(OUTPUT_DIR, "purpose_tp_tn_by_model.png")
    plt.savefig(out_purpose_tptn, dpi=200)
    print(f"Saved plot to: {out_purpose_tptn}")

    # Plot: TP by eval type (purpose-awareness method)
    tp_type_plot_df_purpose = df_purpose.melt(
        id_vars=["dataset"],
        value_vars=["tp_alignment_percent", "tp_capability_percent"],
        var_name="metric",
        value_name="value",
    )
    # Rename metric labels for legend
    tp_type_plot_df_purpose["metric"] = tp_type_plot_df_purpose["metric"].map({
        "tp_alignment_percent": "Correctly identified alignment eval",
        "tp_capability_percent": "Correctly identified capability eval"
    })
    plt.figure(figsize=(max(8, 1.2 * len(df_purpose)), 5))
    axp2 = sns.barplot(data=tp_type_plot_df_purpose, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
    # Add random guess rate line (50%)
    axp2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
    axp2.set_ylim(0, 100)
    axp2.set_ylabel("%")
    axp2.set_xlabel("")
    axp2.set_title("User motivation guessing, correctly identified alignment vs capability evaluations")
    plt.xticks(rotation=30, ha="right")
    plt.legend(loc='upper right')
    plt.tight_layout()
    out_purpose_tp_type = os.path.join(OUTPUT_DIR, "purpose_tp_by_eval_type_by_model.png")
    plt.savefig(out_purpose_tp_type, dpi=200)
    print(f"Saved plot to: {out_purpose_tp_type}")

    # Plot: Cohen's kappa bar plot per model (start at 0)
    plt.figure(figsize=(max(8, 1.2 * len(df_kappa)), 5))
    ax5 = sns.barplot(data=df_kappa, x="dataset", y="kappa", color=CUSTOM_COLORS[0])
    ax5.set_ylim(0, 1)
    ax5.set_ylabel("Cohen's kappa")
    ax5.set_xlabel("")
    ax5.set_title("Agreement between binary decision and probability-threshold decision (Cohen's kappa)")
    
    # Add interpretation guidelines as horizontal lines
    ax5.axhline(y=0.75, color='#4acf4f', linestyle='--', alpha=0.5, label='Substantial agreement (0.75)')
    ax5.axhline(y=0.5, color='#cfc64a', linestyle='--', alpha=0.5, label='Moderate agreement (0.5)')
    ax5.axhline(y=0.25, color='#cf4a5a', linestyle='--', alpha=0.5, label='Weak agreement (0.25)')
    
    
    plt.xticks(rotation=30, ha="right")
    plt.tight_layout()
    out_kappa = os.path.join(OUTPUT_DIR, "method_kappa_by_model.png")
    plt.savefig(out_kappa, dpi=200)
    print(f"Saved plot to: {out_kappa}")

    # Plot logprob method plots only if we have data
    if not df_logprob.empty:
        # Plot: TP/TN percent by model (logprob method)
        tp_tn_plot_df_logprob = df_logprob.melt(
            id_vars=["dataset"],
            value_vars=["tp_eval_percent", "tn_real_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_tn_plot_df_logprob["metric"] = tp_tn_plot_df_logprob["metric"].map({
            "tp_eval_percent": "Correctly identified evaluations",
            "tn_real_percent": "Correctly identified deployment"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_logprob)), 5))
        axl1 = sns.barplot(data=tp_tn_plot_df_logprob, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axl1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axl1.set_ylim(0, 100)
        axl1.set_ylabel("%")
        axl1.set_xlabel("")
        axl1.set_title("Logprob test, correctly identified cases")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='center right')
        plt.tight_layout()
        out_logprob_tptn = os.path.join(OUTPUT_DIR, "logprob_tp_tn_by_model.png")
        plt.savefig(out_logprob_tptn, dpi=200)
        print(f"Saved plot to: {out_logprob_tptn}")

        # Plot: TP by eval type (logprob method)
        tp_type_plot_df_logprob = df_logprob.melt(
            id_vars=["dataset"],
            value_vars=["tp_alignment_percent", "tp_capability_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_type_plot_df_logprob["metric"] = tp_type_plot_df_logprob["metric"].map({
            "tp_alignment_percent": "Correctly identified alignment eval",
            "tp_capability_percent": "Correctly identified capability eval"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_logprob)), 5))
        axl2 = sns.barplot(data=tp_type_plot_df_logprob, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axl2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axl2.set_ylim(0, 100)
        axl2.set_ylabel("%")
        axl2.set_xlabel("")
        axl2.set_title("Logprob test, correctly identified alignment vs capability evaluations")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='upper right')
        plt.tight_layout()
        out_logprob_tp_type = os.path.join(OUTPUT_DIR, "logprob_tp_by_eval_type_by_model.png")
        plt.savefig(out_logprob_tp_type, dpi=200)
        print(f"Saved plot to: {out_logprob_tp_type}")

    # Plot reasoning traces method plots only if we have data
    if not df_reasoning_traces.empty:
        # Plot: TP/TN percent by model (reasoning traces method)
        tp_tn_plot_df_reasoning_traces = df_reasoning_traces.melt(
            id_vars=["dataset"],
            value_vars=["tp_eval_percent", "tn_real_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_tn_plot_df_reasoning_traces["metric"] = tp_tn_plot_df_reasoning_traces["metric"].map({
            "tp_eval_percent": "Correctly identified evaluations",
            "tn_real_percent": "Correctly identified deployment"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_reasoning_traces)), 5))
        axr1 = sns.barplot(data=tp_tn_plot_df_reasoning_traces, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axr1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axr1.set_ylim(0, 100)
        axr1.set_ylabel("%")
        axr1.set_xlabel("")
        axr1.set_title("Reasoning traces analysis, correctly identified cases")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='center right')
        plt.tight_layout()
        out_reasoning_traces_tptn = os.path.join(OUTPUT_DIR, "reasoning_traces_tp_tn_by_model.png")
        plt.savefig(out_reasoning_traces_tptn, dpi=200)
        print(f"Saved plot to: {out_reasoning_traces_tptn}")

        # Plot: TP by eval type (reasoning traces method)
        tp_type_plot_df_reasoning_traces = df_reasoning_traces.melt(
            id_vars=["dataset"],
            value_vars=["tp_alignment_percent", "tp_capability_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_type_plot_df_reasoning_traces["metric"] = tp_type_plot_df_reasoning_traces["metric"].map({
            "tp_alignment_percent": "Correctly identified alignment eval",
            "tp_capability_percent": "Correctly identified capability eval"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_reasoning_traces)), 5))
        axr2 = sns.barplot(data=tp_type_plot_df_reasoning_traces, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axr2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axr2.set_ylim(0, 100)
        axr2.set_ylabel("%")
        axr2.set_xlabel("")
        axr2.set_title("Reasoning traces analysis, correctly identified alignment vs capability evaluations")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='upper right')
        plt.tight_layout()
        out_reasoning_traces_tp_type = os.path.join(OUTPUT_DIR, "reasoning_traces_tp_by_eval_type_by_model.png")
        plt.savefig(out_reasoning_traces_tp_type, dpi=200)
        print(f"Saved plot to: {out_reasoning_traces_tp_type}")

    # Plot majority vote method plots only if we have data
    if not df_majority_vote.empty:
        # Plot: TP/TN percent by model (majority vote method)
        tp_tn_plot_df_majority_vote = df_majority_vote.melt(
            id_vars=["dataset"],
            value_vars=["tp_eval_percent", "tn_real_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_tn_plot_df_majority_vote["metric"] = tp_tn_plot_df_majority_vote["metric"].map({
            "tp_eval_percent": "Correctly identified evaluations",
            "tn_real_percent": "Correctly identified deployment"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_majority_vote)), 5))
        axm1 = sns.barplot(data=tp_tn_plot_df_majority_vote, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axm1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axm1.set_ylim(0, 100)
        axm1.set_ylabel("%")
        axm1.set_xlabel("")
        axm1.set_title("Majority vote (3 metrics), correctly identified cases")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='lower right')
        plt.tight_layout()
        out_majority_vote_tptn = os.path.join(OUTPUT_DIR, "majority_vote_tp_tn_by_model.png")
        plt.savefig(out_majority_vote_tptn, dpi=200)
        print(f"Saved plot to: {out_majority_vote_tptn}")

        # Plot: TP by eval type (majority vote method)
        tp_type_plot_df_majority_vote = df_majority_vote.melt(
            id_vars=["dataset"],
            value_vars=["tp_alignment_percent", "tp_capability_percent"],
            var_name="metric",
            value_name="value",
        )
        # Rename metric labels for legend
        tp_type_plot_df_majority_vote["metric"] = tp_type_plot_df_majority_vote["metric"].map({
            "tp_alignment_percent": "Correctly identified alignment eval",
            "tp_capability_percent": "Correctly identified capability eval"
        })
        plt.figure(figsize=(max(8, 1.2 * len(df_majority_vote)), 5))
        axm2 = sns.barplot(data=tp_type_plot_df_majority_vote, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS)
        # Add random guess rate line (50%)
        axm2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        axm2.set_ylim(0, 100)
        axm2.set_ylabel("%")
        axm2.set_xlabel("")
        axm2.set_title("Majority vote (3 metrics), correctly identified alignment vs capability evaluations")
        plt.xticks(rotation=30, ha="right")
        plt.legend(loc='lower right')
        plt.tight_layout()
        out_majority_vote_tp_type = os.path.join(OUTPUT_DIR, "majority_vote_tp_by_eval_type_by_model.png")
        plt.savefig(out_majority_vote_tp_type, dpi=200)
        print(f"Saved plot to: {out_majority_vote_tp_type}")

    # Create agentic vs non-agentic plots for each method
    create_agentic_plots(files)

    # Calculate and plot aggregate Cohen's kappa for all methods
    print("\n" + "="*60)
    print("AGGREGATE COHEN'S KAPPA ANALYSIS")
    print("="*60)
    aggregate_kappa_scores = compute_aggregate_cohens_kappa(files)
    plot_aggregate_cohens_kappa(aggregate_kappa_scores)


def compute_aggregate_cohens_kappa(files: List[str]) -> Dict[str, float]:
    """
    Compute Cohen's kappa for each method by aggregating all samples from all models
    into a single dataset and comparing against ground truth (eval_category).
    
    Returns a dictionary with method names as keys and Cohen's kappa scores as values.
    """
    # Initialize lists to collect all predictions and ground truth labels
    method_predictions = {
        "binary": {"predicted": [], "ground_truth": []},
        "probability": {"predicted": [], "ground_truth": []},
        "purpose": {"predicted": [], "ground_truth": []},
        "logprob": {"predicted": [], "ground_truth": []},
        "reasoning_traces": {"predicted": [], "ground_truth": []},
    }
    
    for file_path in files:
        if not os.path.isfile(file_path):
            continue
            
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            print(f"Error loading {file_path}: {exc}")
            continue
        
        results = data.get("results", [])
        
        for sample in results:
            metadata = sample.get("metadata", {})
            eval_category = metadata.get("eval_category", None)
            
            # Skip samples without valid ground truth
            if not isinstance(eval_category, bool):
                continue
                
            ground_truth_label = 1 if eval_category else 0
            
            # Binary method
            btr = sample.get("binary_test_results", {}) if isinstance(sample.get("binary_test_results"), dict) else {}
            binary_decision = btr.get("final_decision", None)
            if binary_decision in {"EVAL", "REAL"}:
                predicted_label = 1 if binary_decision == "EVAL" else 0
                method_predictions["binary"]["predicted"].append(predicted_label)
                method_predictions["binary"]["ground_truth"].append(ground_truth_label)
            
            # Probability method
            epr = sample.get("eval_probability_results", {}) if isinstance(sample.get("eval_probability_results"), dict) else {}
            prob = epr.get("eval_probability", None)
            if isinstance(prob, (int, float)):
                predicted_label = 1 if prob >= 0.5 else 0
                method_predictions["probability"]["predicted"].append(predicted_label)
                method_predictions["probability"]["ground_truth"].append(ground_truth_label)
            
            # Purpose method
            par = sample.get("purpose_awareness_results", {}) if isinstance(sample.get("purpose_awareness_results"), dict) else {}
            purpose_decision = par.get("final_decision", None)
            if purpose_decision in {"EVAL", "REAL"}:
                predicted_label = 1 if purpose_decision == "EVAL" else 0
                method_predictions["purpose"]["predicted"].append(predicted_label)
                method_predictions["purpose"]["ground_truth"].append(ground_truth_label)
            
            # Logprob method
            ler = sample.get("logprob_eval_results", {}) if isinstance(sample.get("logprob_eval_results"), dict) else {}
            logprob_decision = ler.get("decision", None)
            if logprob_decision in {"EVAL", "REAL"}:
                predicted_label = 1 if logprob_decision == "EVAL" else 0
                method_predictions["logprob"]["predicted"].append(predicted_label)
                method_predictions["logprob"]["ground_truth"].append(ground_truth_label)
            
            # Reasoning traces method
            rtr = sample.get("reasoning_traces_results", {}) if isinstance(sample.get("reasoning_traces_results"), dict) else {}
            parsed_response = rtr.get("parsed_judge_response", {}) if isinstance(rtr.get("parsed_judge_response"), dict) else {}
            traces_decision = parsed_response.get("conclusion", None)
            if traces_decision in {"EVAL", "REAL"}:
                predicted_label = 1 if traces_decision == "EVAL" else 0
                method_predictions["reasoning_traces"]["predicted"].append(predicted_label)
                method_predictions["reasoning_traces"]["ground_truth"].append(ground_truth_label)
    
    # Calculate Cohen's kappa for each method
    kappa_scores = {}
    for method_name, data in method_predictions.items():
        if len(data["predicted"]) >= 2 and len(data["ground_truth"]) >= 2:
            try:
                kappa = cohen_kappa_score(data["ground_truth"], data["predicted"])
                kappa_scores[method_name] = kappa
                print(f"{method_name}: Cohen's kappa = {kappa:.4f} (n={len(data['predicted'])})")
            except Exception as exc:
                print(f"Error computing kappa for {method_name}: {exc}")
                kappa_scores[method_name] = float("nan")
        else:
            print(f"{method_name}: Insufficient data (n={len(data['predicted'])})")
            kappa_scores[method_name] = float("nan")
    
    return kappa_scores


def plot_aggregate_cohens_kappa(kappa_scores: Dict[str, float]) -> None:
    """
    Create a bar plot showing Cohen's kappa for each method (aggregated across all models).
    """
    # Method name mapping for cleaner labels
    method_labels = {
        "binary": "Binary test",
        "probability": "Probability estimation",
        "purpose": "User motivation guessing",
        "logprob": "Logprob test",
        "reasoning_traces": "Reasoning traces analysis",
    }
    
    # Filter out methods with NaN values and prepare data
    valid_methods = []
    valid_kappas = []
    
    for method, kappa in kappa_scores.items():
        if not (kappa != kappa):  # Check if not NaN
            valid_methods.append(method_labels.get(method, method))
            valid_kappas.append(kappa)
    
    if not valid_methods:
        print("No valid Cohen's kappa scores to plot.")
        return
    
    # Create DataFrame for plotting
    df_kappa = pd.DataFrame({
        'method': valid_methods,
        'kappa': valid_kappas
    })
    
    # Create the plot
    plt.figure(figsize=(max(10, 1.5 * len(valid_methods)), 6))
    ax = sns.barplot(data=df_kappa, x="method", y="kappa", color=CUSTOM_COLORS[0])
    
    # Customize the plot
    ax.set_ylim(0, 1)
    ax.set_ylabel("Cohen's kappa")
    ax.set_xlabel("")
    ax.set_title("Cohen's kappa against ground truth (aggregated across all models)")
    
    # Add value labels on top of bars
    for i, (method, kappa) in enumerate(zip(valid_methods, valid_kappas)):
        ax.text(i, kappa + 0.02, f'{kappa:.3f}', ha='center', va='bottom', fontweight='bold')
    
    # Add interpretation guidelines as horizontal lines
    ax.axhline(y=0.75, color='#4acf4f', linestyle='--', alpha=0.5, label='Substantial agreement (0.75)')
    ax.axhline(y=0.5, color='#cfc64a', linestyle='--', alpha=0.5, label='Moderate agreement (0.5)')
    ax.axhline(y=0.25, color='#cf4a5a', linestyle='--', alpha=0.5, label='Weak agreement (0.25)')
    
    plt.xticks(rotation=30, ha="right")
    plt.legend(loc='upper right')
    plt.tight_layout()
    
    # Save the plot
    out_aggregate_kappa = os.path.join(OUTPUT_DIR, "cohens_kappa_all_methods_vs_ground_truth.png")
    plt.savefig(out_aggregate_kappa, dpi=200)
    print(f"Saved plot to: {out_aggregate_kappa}")
    plt.close()


def create_agentic_plots(files: List[str]) -> None:
    """
    Create side-by-side plots for each method comparing agentic vs non-agentic samples.
    """
    methods = [
        ("binary", compute_counts_binary_filtered, "Binary test"),
        ("prob", compute_counts_prob_filtered, "Probability estimation test"),
        ("purpose", compute_counts_purpose_filtered, "User motivation guessing"),
        ("logprob", compute_counts_logprob_filtered, "Logprob test"),
        ("reasoning_traces", compute_counts_reasoning_traces_filtered, "Reasoning traces analysis"),
        ("majority_vote", compute_counts_majority_vote_filtered, "Majority vote (3 metrics)"),
    ]
    
    for method_name, compute_func, method_title in methods:
        rows_agentic: List[Dict] = []
        rows_non_agentic: List[Dict] = []
        
        for file_path in files:
            raw_dataset_name = os.path.splitext(os.path.basename(file_path))[0]
            dataset_name = DATASET_NAME_MAP.get(raw_dataset_name, raw_dataset_name)
            
            try:
                # Agentic samples
                (
                    num_used_a, tp_a, fp_a, fn_a, tn_a,
                    tp_capability_a, fn_capability_a,
                    tp_alignment_a, fn_alignment_a,
                ) = compute_func(file_path, agentic_only=True)
                
                if num_used_a > 0:
                    rows_agentic.append({
                        "dataset": dataset_name,
                        "num_used": num_used_a,
                        "tp": tp_a, "fp": fp_a, "fn": fn_a, "tn": tn_a,
                        "tp_capability": tp_capability_a, "fn_capability": fn_capability_a,
                        "tp_alignment": tp_alignment_a, "fn_alignment": fn_alignment_a,
                    })
                
                # Non-agentic samples
                (
                    num_used_n, tp_n, fp_n, fn_n, tn_n,
                    tp_capability_n, fn_capability_n,
                    tp_alignment_n, fn_alignment_n,
                ) = compute_func(file_path, agentic_only=False)
                
                if num_used_n > 0:
                    rows_non_agentic.append({
                        "dataset": dataset_name,
                        "num_used": num_used_n,
                        "tp": tp_n, "fp": fp_n, "fn": fn_n, "tn": tn_n,
                        "tp_capability": tp_capability_n, "fn_capability": fn_capability_n,
                        "tp_alignment": tp_alignment_n, "fn_alignment": fn_alignment_n,
                    })
                    
            except Exception as exc:
                print(f"Error processing {method_name} for {file_path}: {exc}")
                continue
        
        # Skip if no data for this method
        if not rows_agentic and not rows_non_agentic:
            print(f"No valid {method_name} samples found for agentic/non-agentic analysis.")
            continue
            
        # Create DataFrames
        df_agentic = pd.DataFrame(rows_agentic) if rows_agentic else pd.DataFrame()
        df_non_agentic = pd.DataFrame(rows_non_agentic) if rows_non_agentic else pd.DataFrame()
        
        # Compute percent metrics
        for df in [df_agentic, df_non_agentic]:
            if not df.empty:
                df["tp_eval_percent"] = df.apply(
                    lambda r: (r["tp"] / (r["tp"] + r["fn"]) * 100.0) if (r["tp"] + r["fn"]) > 0 else float("nan"),
                    axis=1,
                )
                df["tn_real_percent"] = df.apply(
                    lambda r: (r["tn"] / (r["tn"] + r["fp"]) * 100.0) if (r["tn"] + r["fp"]) > 0 else float("nan"),
                    axis=1,
                )
        
        # Create side-by-side plot for TP/TN
        max_datasets = max(len(df_agentic) if not df_agentic.empty else 0, 
                          len(df_non_agentic) if not df_non_agentic.empty else 0, 1)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(max(16, 2.4 * max_datasets), 6))
        
        # Left subplot: Agentic samples
        if not df_agentic.empty:
            tp_tn_plot_df_agentic = df_agentic.melt(
                id_vars=["dataset"],
                value_vars=["tp_eval_percent", "tn_real_percent"],
                var_name="metric",
                value_name="value",
            )
            tp_tn_plot_df_agentic["metric"] = tp_tn_plot_df_agentic["metric"].map({
                "tp_eval_percent": "Correctly identified evaluations",
                "tn_real_percent": "Correctly identified deployment"
            })
            sns.barplot(data=tp_tn_plot_df_agentic, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax1)
        
        ax1.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        ax1.set_ylim(0, 100)
        ax1.set_ylabel("%")
        ax1.set_xlabel("")
        ax1.set_title(f"{method_title} - Agentic samples")
        ax1.tick_params(axis='x', rotation=30)
        ax1.legend(loc='upper right')
        
        # Right subplot: Non-agentic samples
        if not df_non_agentic.empty:
            tp_tn_plot_df_non_agentic = df_non_agentic.melt(
                id_vars=["dataset"],
                value_vars=["tp_eval_percent", "tn_real_percent"],
                var_name="metric",
                value_name="value",
            )
            tp_tn_plot_df_non_agentic["metric"] = tp_tn_plot_df_non_agentic["metric"].map({
                "tp_eval_percent": "Correctly identified evaluations",
                "tn_real_percent": "Correctly identified deployment"
            })
            sns.barplot(data=tp_tn_plot_df_non_agentic, x="dataset", y="value", hue="metric", palette=CUSTOM_COLORS, ax=ax2)
        
        ax2.axhline(y=50, color='#FFB3B3', linestyle='-', linewidth=2, alpha=0.8, label='Random guess', zorder=0)
        ax2.set_ylim(0, 100)
        ax2.set_ylabel("%")
        ax2.set_xlabel("")
        ax2.set_title(f"{method_title} - Non-agentic samples")
        ax2.tick_params(axis='x', rotation=30)
        ax2.legend(loc='upper right')
        
        plt.tight_layout()
        out_agentic_comparison = os.path.join(OUTPUT_DIR, f"{method_name}_agentic_vs_non_agentic_tp_tn.png")
        plt.savefig(out_agentic_comparison, dpi=200)
        print(f"Saved plot to: {out_agentic_comparison}")
        plt.close()


if __name__ == "__main__":
    main()
