#!/usr/bin/env python3
# answer_extraction_match_modality.py for extracting answers in contradictory setting

import csv, ast, re, json, editdistance
import collections
import numpy as np
from tqdm import tqdm
import argparse

# Default paths used when --in / --out are not given on the command line.
INPUT_CSV  = "/path/to/input.csv"
RESULT_CSV = "/path/to/output.csv"

# A single option letter as a whole word, any case (used with fullmatch on
# the whole stripped reply, e.g. "B" or "c").
LETTER_RE  = re.compile(r"\b([A-D])\b", re.I)
# An uppercase option letter immediately followed by ")", "." or ":",
# e.g. "A)", "B.", "C:". Deliberately case-sensitive: a lowercase "a." / "a:"
# is far more often the English article than an answer. The class must not
# contain "|" (a literal pipe there made "|A|" in markdown tables match).
CHOICE_RE = re.compile(r"\b([A-D])[).:]")

def clean_pred(pred, option_dict):
    """
    Return the part of a model reply that follows its answer marker, stripped.

    Markers are tried in priority order; for the first marker present, the
    reply is cut after that marker's LAST occurrence. If no marker is present
    the whole reply is returned, stripped of surrounding whitespace.

    The markdown-decorated "Answer" variants are tried before the plain
    'Answer:' because '*Answer:*' and '**Answer:**' contain 'Answer:' as a
    substring. Testing the plain form first would leave stray '*' characters
    at the start of the result and break exact/substring option matching.

    The text's case is preserved; callers lower-case it themselves.
    ``option_dict`` is not used; it is kept only for call-site compatibility.
    """
    answer_inds = (
        '**Answer:**', '**Answer**:', '*Answer:*', '*Answer*:', 'Answer:',
        'Reasoning:', 'Option',
    )
    for token in answer_inds:
        if token in pred:
            # rpartition keeps only the text after the last occurrence.
            return pred.rpartition(token)[2].strip()
    return pred.strip()

# Phrases meaning the model rejected every option; checked on the cleaned,
# lower-cased reply.
_NONE_CUES = (
    "none of the above", "none of the options", "no correct answer",
    "none are correct",  "none is correct",    "no option is correct",
    "none of them",      "none apply",         "no valid option",
)

# "answer <letter>" / "output <letter>" cues. The keyword matches in any case.
# After an explicit ":" or "-" separator the letter may be either case; with
# no separator it must be uppercase, because otherwise ordinary prose such as
# "to answer a question" or "output a single letter" is read as choice A.
_ANSWER_CUE_RE = re.compile(r"\b(?i:answer)\s*(?:[:\-]\s*([A-Da-d])|([A-D]))\b")
_OUTPUT_CUE_RE = re.compile(r"\b(?i:output)\s*(?:[:\-]\s*([A-Da-d])|([A-D]))\b")
# "option <letter>" cues stay fully case-insensitive: "option b" almost always
# names a choice, so the article problem above does not arise.
_OPTION_CUE_RE = re.compile(r"\boption\s*[:\-]?\s*([A-D])\b", re.I)


def _valid_label(match, letters):
    """Return the upper-cased letter captured by *match* if it is in *letters*.

    Returns None when *match* is None or the captured letter is not a key.
    Works for patterns with several alternative groups: the first group that
    participated in the match is used.
    """
    if match is None:
        return None
    label = next(g for g in match.groups() if g is not None).upper()
    return label if label in letters else None


def select_mc_option(pred, option_dict):
    """
    Map a free-form model reply onto one of the options.

    option_dict: {'A': 'braid', 'B':'paid', 'C':'elephant', 'D':'confusable'}
    returns one of the keys of option_dict (the regex cues below assume the
    keys are the uppercase letters 'A'-'D'), or the string 'None' when the
    reply explicitly rejects every option or option_dict is empty.

    Cues are tried in this order; the first that yields a valid key wins:
      1. a "none of the above"-style phrase                 -> 'None'
      2. "answer: X" (see _ANSWER_CUE_RE)
      3. "output: X" (see _OUTPUT_CUE_RE)
      4. "option X", only if every such mention names the same letter
      5. a letter directly followed by ")", "." or ":" (CHOICE_RE)
      6. the whole reply being a single letter (LETTER_RE)
      7. the cleaned reply equal to an option's text
      8. exactly one non-empty option text occurring in the cleaned reply
      9. the option text with the smallest edit distance to the cleaned reply
    Regex cues (2-6) run on the raw reply; text cues (1, 7-9) run on the
    reply cleaned by clean_pred and lower-cased.
    """
    letters = list(option_dict.keys())

    pred_raw = pred
    pred = clean_pred(pred, option_dict).lower()

    if any(cue in pred for cue in _NONE_CUES):
        return "None"

    for cue_re in (_ANSWER_CUE_RE, _OUTPUT_CUE_RE):
        label = _valid_label(cue_re.search(pred_raw), letters)
        if label is not None:
            return label

    # Several different "option X" mentions are ambiguous, so they only decide
    # the answer when they all agree.
    option_labels = {
        hit.upper() for hit in _OPTION_CUE_RE.findall(pred_raw)
        if hit.upper() in letters
    }
    if len(option_labels) == 1:
        return option_labels.pop()

    label = _valid_label(CHOICE_RE.search(pred_raw), letters)
    if label is None:
        label = _valid_label(LETTER_RE.fullmatch(pred_raw.strip()), letters)
    if label is not None:
        return label

    # Lower-cased option texts, in option_dict order (same order as letters).
    texts = {lab: str(val).lower() for lab, val in option_dict.items()}

    for lab, text in texts.items():
        if pred == text:
            return lab

    # Skip empty texts: "" is a substring of every reply and would otherwise
    # count as a hit for every prediction.
    hits = [lab for lab, text in texts.items() if text and text in pred]
    if len(hits) == 1:
        return hits[0]

    if not texts:
        # Nothing to choose from; np.argmin raises on an empty sequence.
        return "None"
    dists = [editdistance.eval(pred, text) for text in texts.values()]
    return letters[int(np.argmin(dists))]

def _lookup_role(row, picked):
    """
    Return the role that ``row["option_role_map"]`` assigns to letter *picked*.

    Returns "unknown" in these cases:
      - the column is missing;
      - the cell is empty or not valid JSON;
      - the JSON is not an object;
      - the letter is not mapped.
    """
    try:
        role_map = json.loads(row["option_role_map"])
    except (KeyError, TypeError, json.JSONDecodeError):
        # TypeError: csv.DictReader fills cells missing from a short row with
        # None, and json.loads(None) raises TypeError, not JSONDecodeError.
        return "unknown"
    if not isinstance(role_map, dict):
        # e.g. a JSON list or string, which has no .get
        return "unknown"
    return role_map.get(picked, "unknown")


def evaluate_modality_choice(in_csv, out_csv):
    """
    Annotate each row of *in_csv* with the model's pick and write *out_csv*.

    Each input row is expected to provide:
      - ``model_answer``: the raw model reply (passed to select_mc_option);
      - ``options``: a Python-literal dict of letter -> option text;
      - ``option_role_map``: a JSON object mapping letters to role labels.
        A missing or unusable map yields role "unknown".
    Two columns are added to every row:
      - ``picked``: a letter, or "None";
      - ``chosen_role``: "none" when the pick is "None".
    The distribution of chosen roles is printed.

    Returns:
        collections.Counter mapping chosen_role -> number of rows.
    """
    counts = collections.Counter()
    rows_out = []

    with open(in_csv, newline="", encoding="utf-8") as f:
        rdr = csv.DictReader(f)
        # Read the header while the file is open: for an empty file DictReader
        # retries reading it lazily, which would fail after the file closes.
        header = list(rdr.fieldnames or [])
        for row in rdr:
            picked = select_mc_option(row["model_answer"],
                                      ast.literal_eval(row["options"]))
            chosen_role = "none" if picked == "None" else _lookup_role(row, picked)

            counts[chosen_role] += 1
            row["picked"] = picked
            row["chosen_role"] = chosen_role
            rows_out.append(row)

    if rows_out:
        # Same column order as the annotated rows themselves.
        fieldnames = list(rows_out[0].keys())
    else:
        # No data rows: rows_out[0] would raise IndexError, so use the input
        # header plus the two annotation columns.
        fieldnames = header + [c for c in ("picked", "chosen_role") if c not in header]

    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows_out)

    print(f"✓ {in_csv}  →  {out_csv}")

    n = sum(counts.values())
    print("\n=== Model choice distribution over all files ===")
    for role, cnt in counts.items():
        pct = 100 * cnt / n if n else 0
        # str(): roles come from JSON and may be null or numbers, which the
        # "s" format spec rejects.
        print(f"{str(role):11s}: {cnt:5d}  ({pct:5.1f} %)")

    return counts

def main(argv=None):
    """
    Parse command-line arguments and run evaluate_modality_choice.

    argv: list of argument strings to parse; None means sys.argv[1:].
    Returns the role Counter produced by evaluate_modality_choice.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate model output in contradictory setting."
    )
    parser.add_argument("--in", dest="in_csv", default=INPUT_CSV,
                        help="input CSV with model outputs")
    parser.add_argument("--out", dest="out_csv", default=RESULT_CSV,
                        help="output CSV with picked answers & chosen option roles")
    args = parser.parse_args(argv)
    return evaluate_modality_choice(args.in_csv, args.out_csv)


if __name__ == "__main__":
    main()