import csv
import os
import re
import sys

import numpy as np

if not __package__ and __name__ == "__main__":  # @debug
    # When this file is executed directly (not as part of a package), append the
    # directory two levels above this file's folder to sys.path so the `utils`
    # package imported below can be resolved.
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


import json
from collections import defaultdict

from utils.file_utils import find_files


def llm_response_to_scalar(content: str) -> int | float:
    """Turn a raw LLM critique into a binary score.

    Args:
        content: The raw critique text, searched for an ``EVALUATION:`` label.

    Returns:
        1 if the parsed label is ``SUCCESS``, 0 for any other parsed label,
        and ``np.nan`` when no label could be parsed.
    """
    label = parse_eval(content)
    if label:
        return 1 if label == "SUCCESS" else 0
    # Nothing parsable -> NaN so the value can be filtered out downstream.
    return np.nan


def map_eval_to_score(eval: str) -> int | float:
    # If eval is None, empty, or not provided, return np.nan so that it doesn't count.
    if not eval:
        return np.nan
    if eval == "SUCCESS":
        return 1
    elif eval == "PARTIAL SUCCESS":
        return 0
    elif eval == "PARTIAL FAILURE":
        return 0
    elif eval == "FAILURE":
        return 0
    else:
        return np.nan


def parse_eval(content: str) -> str:
    """Return the upper-cased verdict that follows ``EVALUATION:`` in *content*.

    Matching is done with ``EVAL_CRITERIA_REGEX`` (case-insensitive). An empty
    string means no recognised verdict was found.
    """
    found = EVAL_CRITERIA_REGEX.search(content)
    return found.group(1).upper() if found else ""


# Verdict labels the critic is expected to emit after "EVALUATION:".
eval_criteria = ["SUCCESS", "PARTIAL SUCCESS", "FAILURE", "PARTIAL FAILURE"]

# Case-insensitive pattern capturing whichever label follows "EVALUATION:".
# Labels are regex-escaped so their characters (including spaces) match literally.
EVAL_CRITERIA_REGEX = re.compile(
    r"EVALUATION:\s*(" + "|".join(re.escape(label) for label in eval_criteria) + ")",
    re.IGNORECASE,
)


def get_critique_round_scores(
    json_file_path: str,
    round_idx: int = 1,
) -> dict[str, dict[str, object]]:
    """Collect, per task, the score recorded for round ``round_idx`` of the critic-executor loop.

    Args:
        json_file_path (str): Path to a JSON file shaped like
            ``{task_id: {"scores": [{"state_idx", "score", "round", ...}, ...]}}``.
            Score entries may additionally carry a ``"critique_raw_response"``.
        round_idx (int, optional): Which round to read. Defaults to 1.

    Returns:
        dict[str, dict[str, object]]: ``{task_id: {"score": ...}}`` for every task
        that has an entry for ``round_idx``. Task ids are the JSON object keys,
        so they are strings. If that entry has a ``"critique_raw_response"``,
        the inner dict also holds:

        - that raw text;
        - the parsed ``"eval"`` label;
        - the derived ``"verifier_score"``.

    NOTE: only the first score entry whose ``round`` equals ``round_idx`` is used
    for each task; tasks without a ``"scores"`` key are skipped.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(json_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    scores = {}
    for task_id, task_data in data.items():
        if "scores" not in task_data:
            continue
        for score_entry in task_data["scores"]:
            if score_entry["round"] != round_idx:
                continue
            task_scores = {"score": score_entry["score"]}
            # The raw critique is optional: only read and parse it when present.
            # (Reading it unconditionally raised KeyError and made this check dead.)
            if "critique_raw_response" in score_entry:
                raw_response = score_entry["critique_raw_response"]
                # A JSON null is treated as "no label" -> eval "" -> verifier_score NaN.
                eval_label = parse_eval(raw_response or "")
                task_scores["critique_raw_response"] = raw_response
                task_scores["eval"] = eval_label
                task_scores["verifier_score"] = map_eval_to_score(eval_label)
            scores[task_id] = task_scores
            break
    return scores


# Experiment directory searched for per-run score files.
base_dir = "experiments/gemini-2.5-flash-preview-04-17/no_cot-expert-2025-05-06"

file_pattern = "scores_per_round.json"

# `find_files` is a project helper; presumably it returns a sequence of matching
# paths (it is indexed below) -- confirm against utils.file_utils.
all_json_files = find_files(base_dir, file_pattern)
if not all_json_files:
    # Fail with a descriptive message instead of an opaque IndexError on [0].
    raise FileNotFoundError(f"No files matching {file_pattern!r} found under {base_dir!r}")

# Eagerly load the first file. The __main__ section below rebinds `data`, so this
# value is not used there. The context manager closes the handle instead of leaking it.
with open(all_json_files[0], "r", encoding="utf-8") as _first_json_file:
    data = json.load(_first_json_file)

# Known domains; get_domain_from_path matches these as substrings of a file path.
domains = ["shopping", "reddit", "classifieds"]


def get_domain_from_path(path: str) -> str:
    """Return the first name in ``domains`` that occurs as a substring of *path*.

    Returns an empty string when none of the configured domains appear.
    """
    return next((name for name in domains if name in path), "")


def _is_missing(value: object) -> bool:
    """Return True if *value* cannot be used as a score (``None`` or NaN)."""
    return value is None or bool(np.isnan(value))


def _confusion_counts(entries: list[dict]) -> tuple[int, int, int, int]:
    """Return ``(TP, TN, FP, FN)``, treating a score of 1 as positive and 0 as negative.

    Entries whose gold or verifier score is neither 0 nor 1 fall into no cell.
    """
    tp = tn = fp = fn = 0
    for entry in entries:
        gold = entry["gold_score"]
        pred = entry["verifier_score"]
        if gold == 1 and pred == 1:
            tp += 1
        elif gold == 0 and pred == 0:
            tn += 1
        elif gold == 0 and pred == 1:
            fp += 1
        elif gold == 1 and pred == 0:
            fn += 1
    return tp, tn, fp, fn


def _print_metrics(title: str, entries: list[dict], na_count: int) -> None:
    """Print the confusion matrix, TPR, TNR, accuracy and NA count for *entries*."""
    tp, tn, fp, fn = _confusion_counts(entries)
    total = tp + tn + fp + fn
    tpr = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    tnr = tn / (tn + fp) if (tn + fp) > 0 else float("nan")
    accuracy = (tp + tn) / total if total > 0 else float("nan")
    print(title)
    print(f"  True Positives: {tp}")
    print(f"  True Negatives: {tn}")
    print(f"  False Positives: {fp}")
    print(f"  False Negatives: {fn}")
    print(f"  True Positive Rate (Recall): {tpr:.3f}")
    print(f"  True Negative Rate (Specificity): {tnr:.3f}")
    print(f"  Accuracy: {accuracy:.3f}")
    print(f"  Number of NAs: {na_count}")
    print()


if __name__ == "__main__":
    scores_per_domain = defaultdict(list)  # domain -> entries with usable gold & verifier scores
    na_count_per_domain = defaultdict(int)  # domain -> tasks with a missing/NaN score

    for json_file in all_json_files:
        domain = get_domain_from_path(json_file)
        data = get_critique_round_scores(json_file)
        for task_id, task_scores in data.items():
            gold_score = task_scores.get("score")
            # May be absent: get_critique_round_scores only sets it when a critique
            # response was recorded. np.isnan(None) would raise, so test for None first.
            verifier_score = task_scores.get("verifier_score")
            if _is_missing(gold_score) or _is_missing(verifier_score):
                na_count_per_domain[domain] += 1
                continue
            scores_per_domain[domain].append(
                {
                    "task_id": task_id,
                    "gold_score": gold_score,
                    "verifier_score": verifier_score,
                    # Keep the label from this very file for the CSV below. Re-looking it
                    # up by task_id re-read every file per row, could pick another file's
                    # entry, and could leave the variable unbound.
                    "verifier_eval": task_scores.get("eval", ""),
                }
            )

    # Report every domain seen, including ones where every task was NA.
    for domain in dict.fromkeys([*scores_per_domain, *na_count_per_domain]):
        _print_metrics(
            f"Domain: {domain}",
            scores_per_domain.get(domain, []),
            na_count_per_domain[domain],
        )

    # --- Overall metrics (across all domains) ---
    all_scores = [entry for scores in scores_per_domain.values() for entry in scores]
    # Sum NAs over every domain, not only those that also had valid entries.
    total_na = sum(na_count_per_domain.values())
    _print_metrics("Overall:", all_scores, total_na)

    # Write results to CSV
    csv_filename = f"{base_dir}/verifier_evals.csv"
    with open(csv_filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["task_id", "domain", "gold_score", "verifier_eval"])
        for domain, scores in scores_per_domain.items():
            for entry in scores:
                writer.writerow(
                    [entry["task_id"], domain, entry["gold_score"], entry["verifier_eval"]]
                )

    print(f"CSV written to {csv_filename}")
