#!/usr/bin/env python3
# answer_extraction.py

import csv, ast, re, json, editdistance
import numpy as np
from tqdm import tqdm
import argparse

# Default locations used by the CLI when --in / --out are not supplied.
INPUT_CSV  = "/path/to/input.csv"
RESULT_CSV = "/path/to/output.csv"

# A lone option letter (any case) delimited by word boundaries.
LETTER_RE  = re.compile(r"\b([A-D])\b", re.I)
# An upper-case option letter directly followed by ')', '.' or ':'.
# Inside a character class '|' is a literal pipe, not alternation, so it must
# not be listed; otherwise "A|" would be accepted as a choice marker.
CHOICE_RE = re.compile(r"\b([A-D])[).:]")

def clean_pred(pred, option_dict):
    """
    Strip a leading 'Answer:'-style marker and surrounding whitespace.

    Markers are tried in priority order; for the first one that occurs in
    `pred`, only the text after its LAST occurrence is kept.  Letter case is
    preserved -- callers lowercase the result themselves if they need to.

    Args:
        pred: raw model output text.
        option_dict: unused; kept so existing callers keep working.

    Returns:
        The text following the marker (or all of `pred` when no marker is
        present), with leading/trailing whitespace removed.
    """
    # Markdown-decorated variants must be tested before the bare 'Answer:',
    # because 'Answer:' is a substring of '*Answer:*' and '**Answer:**' and
    # would otherwise win and leave stray '*' characters in the result.
    answer_inds = [
        '**Answer:**', '**Answer**:', '*Answer:*', '*Answer*:',
        'Answer:', 'Reasoning:', 'Option',
    ]
    for token in answer_inds:
        if token in pred:
            return pred.split(token)[-1].strip()
    return pred.strip()

def select_mc_option(pred, option_dict):
    """
    Map a free-text model prediction onto one multiple-choice option label.

    The heuristics are applied in this order and the first that fires wins:
      1. an explicit "none of the above"-style phrase      -> 'None'
      2. "answer <letter>" / "output <letter>" in the raw text
      3. "option <letter>" when exactly one distinct letter is mentioned
      4. a letter followed by ')', '.' or ':' (CHOICE_RE)
      5. the raw text is just a single letter (LETTER_RE)
      6. the cleaned text equals one option value
      7. exactly one option value is contained in the cleaned text
      8. the option value with the smallest edit distance to the cleaned text

    Args:
        pred: raw model output; None is treated as an empty string.
        option_dict: label -> option text, e.g.
            {'A': 'braid', 'B': 'paid', 'C': 'elephant', 'D': 'confusable'}

    Returns:
        One of the keys of `option_dict`, or the string 'None' when the
        prediction rejects all options, is empty after cleaning, or there
        are no options to choose from.
    """
    letters = list(option_dict.keys())
    values = list(option_dict.values())

    # A missing CSV cell may arrive as None; treat it as an empty answer.
    pred_raw = "" if pred is None else str(pred)
    pred = clean_pred(pred_raw, option_dict).lower()

    # explicit "none" cues
    none_cues = (
        "none of the above", "none of the options", "no correct answer",
        "none are correct",  "none is correct",    "no option is correct",
        "none of them",      "none apply",         "no valid option",
    )
    if any(cue in pred for cue in none_cues):
        return "None"

    # "answer: X" has priority over "output: X"; only the first match of each
    # pattern is considered, and it must be a label that actually exists.
    keyword_patterns = (
        r"\banswer\s*[:\-]?\s*([A-D])\b",
        r"\boutput\s*[:\-]?\s*([A-D])\b",
    )
    for pattern in keyword_patterns:
        m = re.search(pattern, pred_raw, flags=re.I)
        if m:
            lab = m.group(1).upper()
            if lab in letters:
                return lab

    # "option X": trusted only when a single distinct option is mentioned,
    # since text that discusses several options is ambiguous.
    mentioned = re.findall(r"\boption\s*[:\-]?\s*([A-D])\b", pred_raw, flags=re.I)
    distinct = {found.upper() for found in mentioned if found.upper() in letters}
    if len(distinct) == 1:
        return distinct.pop()

    # "X)", "X." or "X:" anywhere, then a bare single-letter reply.
    for m in (CHOICE_RE.search(pred_raw), LETTER_RE.fullmatch(pred_raw.strip())):
        if m:
            lab = m.group(1).upper()
            if lab in letters:
                return lab

    # text equals one of the option VALUES
    for lab, val in option_dict.items():
        if pred == str(val).lower():
            return lab

    # option value contained uniquely in pred
    hits = [lab for lab, val in option_dict.items() if str(val).lower() in pred]
    if len(hits) == 1:
        return hits[0]

    # Nothing to compare: an empty prediction would otherwise be credited to
    # the shortest option, and argmin over zero options raises ValueError.
    if not pred or not letters:
        return "None"

    # -------------------------------------------- fuzzy fallback --------------------------------- #
    dists = [editdistance.eval(pred, str(val).lower()) for val in values]
    return letters[int(np.argmin(dists))]

# ---------------------------------------------------------------------------#
def evaluate(in_csv, out_csv):
    """
    Score every row of `in_csv` and write the annotated rows to `out_csv`.

    Each input row is read with these columns:
      * `options`      -- Python-literal dict of label -> option text
                          (parsed with ast.literal_eval)
      * `model_answer` -- raw model output
      * `gt_answer`    -- gold label (stripped and upper-cased before comparison)

    The output CSV repeats the input columns and adds `picked` (the label
    chosen by select_mc_option) and `correct` (1 or 0).  Accuracy and the
    share of 'None' picks are printed to stdout.

    Args:
        in_csv: path of the CSV to score.
        out_csv: path of the CSV to write; overwritten if it exists.
    """
    correct = 0
    total = 0
    none = 0
    rows_out = []

    with open(in_csv, newline="", encoding="utf-8") as f:
        rdr = csv.DictReader(f)
        for row in tqdm(rdr, desc="scoring"):
            opt_dict = ast.literal_eval(row["options"])
            model_pred = row["model_answer"]

            pick = select_mc_option(model_pred, opt_dict)
            gold = row["gt_answer"].strip().upper()

            if pick == "None":
                none += 1

            is_ok = pick == gold
            correct += int(is_ok)
            total += 1
            rows_out.append({**row,
                             "picked": pick,
                             "correct": int(is_ok)})
        # Remember the header so an input without data rows still yields a
        # well-formed (header-only) output file.
        input_fields = list(rdr.fieldnames or [])

    # Guard both ratios: an empty input must not raise ZeroDivisionError.
    acc = correct / total if total else 0
    none_rate = none / total if total else 0
    print(f"Accuracy: {acc:.2%}  ({correct}/{total})")
    print(f"None answers: ({none_rate:.2%})")

    # write details
    if rows_out:
        fieldnames = list(rows_out[0].keys())
    else:
        # dict.fromkeys de-duplicates while keeping order, in case the input
        # already carries a `picked` or `correct` column.
        fieldnames = list(dict.fromkeys([*input_fields, "picked", "correct"]))

    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows_out)

    print("results saved to", out_csv)

# ---------------------------------------------------------------------------#
if __name__ == "__main__":
    # Command-line entry point: both paths are optional and fall back to the
    # module-level INPUT_CSV / RESULT_CSV defaults.
    cli = argparse.ArgumentParser(description="Score Model MC outputs and write parsed CSV.")
    cli.add_argument(
        "--in",
        dest="in_csv",
        default=INPUT_CSV,
        help="input CSV with model outputs",
    )
    cli.add_argument(
        "--out",
        dest="out_csv",
        default=RESULT_CSV,
        help="output CSV with picked answers & correctness",
    )
    opts = cli.parse_args()

    evaluate(opts.in_csv, opts.out_csv)