

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List
import re  # added for regex stripping

# ---------------------------------------------------------------------------- #
# Helper functions                                                              #
# ---------------------------------------------------------------------------- #


def normalize(text: str) -> str:
    """Canonicalize a reason string so near-identical phrasings compare equal.

    The transformation, applied in this order, is:

    * drop surrounding whitespace;
    * remove one leading list marker made of digits followed by ``.``, ``)``
      or ``-`` (for example ``"3. "``, ``"2) "`` or ``"1 - "``);
    * lower-case what remains;
    * trim whitespace and the punctuation characters ``.,;!`` from both ends.
    """
    enum_marker = re.compile(r"^\s*\d+\s*[\.)-]\s*")
    body = enum_marker.sub("", text.strip())
    return body.lower().strip().strip(".,;! \t\n\r")


def load_annotation_reasons(path: Path) -> Dict[str, List[str]]:
    """Read a reaction_annotation.json-style file and return its decoded content.

    The file is read as UTF-8 and parsed as JSON; the result is returned
    unchanged. Callers use it as a mapping of video_id -> list of acceptable
    reason strings (the structure itself is not validated here).
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def load_predictions(pred_dir: Path) -> List[tuple[str, str, List[str]]]:
    """Return list of tuples (video_id, predicted_reason, correct_list)."""
    records_all: List[tuple[str, str, List[str]]] = []

    for json_path in pred_dir.glob("*.json"):
        try:
            with json_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            print(f"[WARN] Failed to parse {json_path.name}: {e}")
            continue

        # Some prediction files are a single object, others are a list of objects.
        if isinstance(data, list):
            recs_in_file = data
        else:
            recs_in_file = [data]

        for rec in recs_in_file:
            if not isinstance(rec, dict):
                print(f"[WARN] Unexpected record type in {json_path.name}: {type(rec).__name__}; skipping.")
                continue

            video_id: str | None = rec.get("video_id")
            if not video_id:
                video_id = json_path.stem  # fallback

            # Get predicted reason key (supports both pipeline variants)
            pred_reason: str | None = rec.get("final_reason") or rec.get("predicted_reason")
            if not pred_reason:
                print(f"[WARN] Missing predicted reason for video {video_id} in {json_path.name}; skipping.")
                continue

            correct_list: List[str] | None = rec.get("correct_reasons")
            records_all.append((video_id, pred_reason, correct_list))

    return records_all


# ------------------------------------------------------------------------- #
# Accuracy helpers
# ------------------------------------------------------------------------- #


def compute_accuracy_records(
    records: List[tuple[str, str, List[str]]],
    annot_map: Dict[str, List[str]] | None,
) -> tuple[int, int]:
    """Return (#correct, total) including duplicates."""
    correct = 0
    total = 0
    for vid, pred, correct_in_rec in records:
        total += 1

        if correct_in_rec is None and annot_map is not None:
            correct_in_rec = annot_map.get(vid)

        if not correct_in_rec:
            continue  # unknown ground truth

        pred_norm = normalize(pred)
        if any(pred_norm == normalize(gt) for gt in correct_in_rec):
            correct += 1
    return correct, total


def compute_accuracy_unique(
    pred_unique: Dict[str, str],
    annot_map: Dict[str, List[str]] | None,
    per_record_gt: Dict[str, List[str]],
) -> tuple[int, int]:
    correct = 0
    total = 0
    for vid, pred in pred_unique.items():
        gt_list = per_record_gt.get(vid)
        if not gt_list and annot_map is not None:
            gt_list = annot_map.get(vid)
        if not gt_list:
            continue
        total += 1
        if any(normalize(pred) == normalize(gt) for gt in gt_list):
            correct += 1
    return correct, total


def _load_single_file(path: Path) -> List[tuple[str, str, List[str] | None]]:
    """Parse one prediction JSON file into (video_id, predicted_reason, correct_list) tuples.

    The file may hold one JSON object or a list of objects. Non-dict records
    and records without ``final_reason``/``predicted_reason`` are skipped
    silently; ``video_id`` falls back to the file stem. A file that cannot be
    read or decoded yields ``[]`` after a warning.
    """
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError; other
        # exceptions indicate real bugs and should not be swallowed.
        print(f"[WARN] Could not read {path.name}: {e}")
        return []

    recs = data if isinstance(data, list) else [data]
    parsed: List[tuple[str, str, List[str] | None]] = []
    for rec in recs:
        if not isinstance(rec, dict):
            continue
        pred_reason = rec.get("final_reason") or rec.get("predicted_reason")
        if not pred_reason:
            continue
        vid = rec.get("video_id") or path.stem
        parsed.append((vid, pred_reason, rec.get("correct_reasons")))
    return parsed


def _score_file(
    preds: List[tuple[str, str, List[str] | None]],
    annot_map: Dict[str, List[str]] | None,
) -> tuple[float, float]:
    """Return (duplicate-inclusive accuracy, unique-video accuracy); 0.0 when nothing is scorable."""
    c_dup, t_dup = compute_accuracy_records(preds, annot_map)

    # Deduplicate by video id: the last prediction for a video wins, and the
    # last non-empty per-record ground truth for that video is kept.
    unique_map: Dict[str, str] = {}
    per_gt: Dict[str, List[str]] = {}
    for vid, pr, cor in preds:
        unique_map[vid] = pr
        if cor:
            per_gt[vid] = cor
    c_u, t_u = compute_accuracy_unique(unique_map, annot_map, per_gt)

    acc_dup = (c_dup / t_dup) if t_dup else 0.0
    acc_u = (c_u / t_u) if t_u else 0.0
    return acc_dup, acc_u


def main() -> None:
    """CLI entry point: report per-file prediction accuracy.

    For every ``*.json`` file in ``--pred_dir`` (in sorted order), print and
    write to ``--output`` one line with the duplicate-inclusive accuracy and
    the unique-video accuracy. Ground truth comes from each record's
    ``correct_reasons`` and, optionally, from ``--annot_file``.

    Raises:
        NotADirectoryError: ``--pred_dir`` is not a directory.
        FileNotFoundError: ``--annot_file`` was given but does not exist.
    """
    parser = argparse.ArgumentParser(description="Evaluate topic prediction accuracy.")
    parser.add_argument("--pred_dir", type=str, required=True, help="Directory with prediction JSON files.")
    parser.add_argument("--annot_file", type=str, default=None, help="Optional reaction_annotation.json for ground truth.")
    parser.add_argument("--output", type=str, default="reason_metric.txt", help="Output metric file name.")
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    out_path = Path(args.output)

    if not pred_dir.is_dir():
        raise NotADirectoryError(f"Prediction directory not found: {pred_dir}")

    annot_map: Dict[str, List[str]] | None = None
    if args.annot_file:
        annot_path = Path(args.annot_file)
        if not annot_path.is_file():
            raise FileNotFoundError(f"Annotation file not found: {annot_path}")
        annot_map = load_annotation_reasons(annot_path)

    # Allow --output to point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Evaluate each JSON file separately
    with out_path.open("w", encoding="utf-8") as fout:
        for json_path in sorted(pred_dir.glob("*.json")):
            acc_dup, acc_u = _score_file(_load_single_file(json_path), annot_map)
            line = f"{json_path.name}: dup {acc_dup:.4f} | unique {acc_u:.4f}"
            print(line)
            fout.write(line + "\n")


if __name__ == "__main__":
    main()
