{
 "cells": [
  {
   "cell_type": "code",
   "id": "709b0cf0-3da5-41ba-9082-d866676c19fe",
   "metadata": {},
   "source": [
    "from typing import List, Tuple, Dict\n",
    "from bert_score import score as bert_score\n",
    "from collections import Counter\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "def is_position_match(pred_pos: Tuple[int, int], gold_pos: Tuple[int, int]) -> bool:\n",
    "    return pred_pos == gold_pos \n",
    "\n",
    "def evaluate_missing_prediction(\n",
    "    gold: Dict[str, List],\n",
    "    pred: Dict[str, List],\n",
    "    model_type: str = \"roberta-large\",  \n",
    "    lang: str = \"en\", \n",
    ") -> Dict[str, float]:\n",
    "\n",
    "    gold_positions = gold[\"positions\"]\n",
    "    gold_texts = gold[\"texts\"]\n",
    "    pred_positions = pred[\"positions\"]\n",
    "    pred_texts = pred[\"texts\"]\n",
    "\n",
    "    if len(gold_positions) == 0:\n",
    "        if len(pred_positions) == 0:\n",
    "            return {\n",
    "                \"precision_pos\": 1.0,\n",
    "                \"recall_pos\": 1.0,\n",
    "                \"f1_pos\": 1.0,\n",
    "                \"redundancy_rate\": 0.0,\n",
    "                \"text_score_position_aware\": 1.0 \n",
    "            }\n",
    "        else:\n",
    "            return {\n",
    "                \"precision_pos\": 0.0,\n",
    "                \"recall_pos\": 1.0,\n",
    "                \"f1_pos\": 0.0,\n",
    "                \"redundancy_rate\": 1.0,\n",
    "                \"text_score_position_aware\": 0.0\n",
    "            }\n",
    "\n",
    "    matched = []\n",
    "    unmatched_pred_indices = set(range(len(pred_positions)))\n",
    "    matched_texts_pred = []\n",
    "    matched_texts_gold = []\n",
    "\n",
    "    for i, gpos in enumerate(gold_positions):\n",
    "        matched_flag = False\n",
    "        for j, ppos in enumerate(pred_positions):\n",
    "            if is_position_match(ppos, gpos):\n",
    "                matched.append((i, j))\n",
    "                matched_texts_gold.append(gold_texts[i])\n",
    "                matched_texts_pred.append(pred_texts[j])\n",
    "                unmatched_pred_indices.discard(j)\n",
    "                matched_flag = True\n",
    "                break  \n",
    "        if not matched_flag:\n",
    "            matched_texts_gold.append(gold_texts[i])\n",
    "            matched_texts_pred.append(\"\")\n",
    "\n",
    "    tp = len(matched)\n",
    "    fp = len(unmatched_pred_indices)\n",
    "    fn = len(gold_positions) - tp\n",
    "\n",
    "    precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n",
    "    recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n",
    "    f1_pos = (2 * precision_pos * recall_pos / (precision_pos + recall_pos)) if (precision_pos + recall_pos) > 0 else 0.0\n",
    "    redundancy_rate = fp / (tp + fp) if (tp + fp) > 0 else 0.0\n",
    "\n",
    "    P, R, F1 = bert_score(matched_texts_pred, matched_texts_gold, lang=lang, model_type=model_type, verbose=False)\n",
    "    text_score = float(F1.mean().item()) \n",
    "\n",
    "    return {\n",
    "        \"precision_pos\": round(precision_pos, 4),\n",
    "        \"recall_pos\": round(recall_pos, 4),\n",
    "        \"f1_pos\": round(f1_pos, 4),\n",
    "        \"redundancy_rate\": round(redundancy_rate, 4),\n",
    "        \"text_score_position_aware\": round(text_score, 4),\n",
    "    }\n",
    "from bert_score import score as bert_score\n",
    "from tqdm import tqdm\n",
    "\n",
    "def batch_evaluate(\n",
    "    gold_list: List[Dict[str, List]],\n",
    "    pred_list: List[Dict[str, List]],\n",
    "    model_type: str = \"roberta-large\",\n",
    "    lang: str = \"en\",\n",
    "    show_progress: bool = True,\n",
    ") -> Dict[str, float]:\n",
    "    assert len(gold_list) == len(pred_list), \"gold/pred inconsistent\"\n",
    "\n",
    "    metrics_sum = {\n",
    "        \"precision_pos\": 0.0,\n",
    "        \"recall_pos\": 0.0,\n",
    "        \"f1_pos\": 0.0,\n",
    "        \"redundancy_rate\": 0.0,\n",
    "    }\n",
    "\n",
    "    all_gold_texts = []\n",
    "    all_pred_texts = []\n",
    "    per_sample_matched_counts = []\n",
    "    n = len(gold_list)\n",
    "\n",
    "    iterable = zip(gold_list, pred_list)\n",
    "    if show_progress:\n",
    "        iterable = tqdm(iterable, total=n, desc=\"Evaluating\")\n",
    "\n",
    "    for gold, pred in iterable:\n",
    "        gold_pos, pred_pos = gold[\"positions\"], pred[\"positions\"]\n",
    "        gold_texts, pred_texts = gold[\"texts\"], pred[\"texts\"]\n",
    "\n",
    "        matched = []\n",
    "        for i, g in enumerate(gold_pos):\n",
    "            for j, p in enumerate(pred_pos):\n",
    "                if g == p:\n",
    "                    matched.append((i, j))\n",
    "                    break\n",
    "\n",
    "        matched_gold_texts = [gold_texts[i] for i, _ in matched]\n",
    "        matched_pred_texts = [pred_texts[j] for _, j in matched]\n",
    "\n",
    "        all_gold_texts.extend(matched_gold_texts)\n",
    "        all_pred_texts.extend(matched_pred_texts)\n",
    "        per_sample_matched_counts.append(len(matched))\n",
    "\n",
    "        true_positives = len(matched)\n",
    "        precision_pos = true_positives / len(pred_pos) if pred_pos else 0.0\n",
    "        recall_pos = true_positives / len(gold_pos) if gold_pos else 0.0\n",
    "        f1_pos = (\n",
    "            2 * precision_pos * recall_pos / (precision_pos + recall_pos)\n",
    "            if precision_pos + recall_pos > 0 else 0.0\n",
    "        )\n",
    "        redundancy_rate = (len(pred_pos) - true_positives) / len(pred_pos) if pred_pos else 0.0\n",
    "\n",
    "        metrics_sum[\"precision_pos\"] += precision_pos\n",
    "        metrics_sum[\"recall_pos\"] += recall_pos\n",
    "        metrics_sum[\"f1_pos\"] += f1_pos\n",
    "        metrics_sum[\"redundancy_rate\"] += redundancy_rate\n",
    "\n",
    "    if all_gold_texts:\n",
    "        _, _, F1 = bert_score(\n",
    "            all_pred_texts, all_gold_texts,\n",
    "            lang=lang, model_type=model_type, verbose=False, device=\"cuda\"\n",
    "        )\n",
    "\n",
    "        f1_list = F1.tolist()\n",
    "        idx = 0\n",
    "        sample_text_scores = []\n",
    "        for count in per_sample_matched_counts:\n",
    "            if count == 0:\n",
    "                sample_text_scores.append(None)\n",
    "            else:\n",
    "                sample_f1 = sum(f1_list[idx:idx + count]) / len(gold_list[sample_text_scores.__len__()][\"texts\"]) \n",
    "                sample_text_scores.append(sample_f1)\n",
    "                idx += count\n",
    "        avg_text_score = round(sum(f for f in sample_text_scores if f is not None) / n, 4)\n",
    "    else:\n",
    "        avg_text_score = None\n",
    "\n",
    "    avg_metrics = {k: round(v / n, 4) for k, v in metrics_sum.items()}\n",
    "    avg_metrics[\"text_score_position_aware\"] = avg_text_score\n",
    "\n",
    "    return avg_metrics\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c2e228e6-8dd4-4d1b-9f68-d31523ed69bd",
   "metadata": {},
   "source": [
    "import pandas as pd\n",
    "# Read ScaleQM+_test.json into a DataFrame; later cells access row[\"messages\"][1][\"content\"].\n",
    "df = pd.read_json(\"ScaleQM+_test.json\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "96c3c70d-7bf8-4ba3-aab3-4833899be1b5",
   "metadata": {},
   "source": [
    "import re\n",
    "import pickle\n",
    "from collections import OrderedDict\n",
    "\n",
    "def extract_steps(text):\n",
    "    pattern = r\"(?:Step\\s*|step)(\\d+):\\n(.*?)(?=(?:Step\\s*|step)\\d+:\\n|</incomplete_solution>)\"\n",
    "    matches = re.finditer(pattern, text, re.DOTALL | re.IGNORECASE)\n",
    "\n",
    "    steps = OrderedDict()\n",
    "    for match in matches:\n",
    "        step_num = match.group(1)\n",
    "        step_content = match.group(2).strip()\n",
    "        step_label = f\"step{step_num}\"\n",
    "        steps[step_label] = step_content\n",
    "\n",
    "    return steps\n",
    "\n",
    "# Count the gaps between consecutive steps over the whole dataset.\n",
    "cnt = 0\n",
    "for i in range(len(df)):\n",
    "    steps = extract_steps(df.iloc[i][\"messages\"][1][\"content\"])\n",
    "    # k parsed steps give k - 1 gaps; clamp at 0 so a row without any parsed\n",
    "    # step does not subtract one from the total.\n",
    "    cnt += max(len(steps) - 1, 0)\n",
    "print(cnt)\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load a trusted results-sim.pkl.\n",
    "with open('results-sim.pkl', 'rb') as f:\n",
    "    loaded_results = pickle.load(f)\n",
    "print(len(loaded_results))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "0dafcd7f-ca03-4050-a03a-352ee58ae30b",
   "metadata": {},
   "source": [
    "# use next step as similarity reference\n",
    "t = []\n",
    "for i in range(len(df)):\n",
    "    temp = list(extract_steps(df.iloc[i][\"messages\"][1][\"content\"]).values())\n",
    "    # The first step has no gap in front of it, so only steps 2..k are references.\n",
    "    t.extend(temp[1:])\n",
    "print(len(t))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3c303ba0-d1ae-47dc-8e2c-3257fc7cda35",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "from bert_score import score as bert_score\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "# Score every loaded result against its reference step in one BERTScore call;\n",
    "# only the F1 values are used by the following cells.\n",
    "P, R, F1 = bert_score(\n",
    "    loaded_results,\n",
    "    t,\n",
    "    lang='en',\n",
    "    model_type=\"roberta-large\",\n",
    "    verbose=False,\n",
    "    device=\"cuda\",\n",
    ")\n",
    "f1_list = F1.tolist()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d488ecde-bbca-42c6-acfb-2c56396651e9",
   "metadata": {},
   "source": [
    "# A result is kept as a predicted missing step when its BERTScore F1 against\n",
    "# the following step is below this value.\n",
    "SIMILARITY_THRESHOLD = 0.95\n",
    "\n",
    "cnt = 0  # offset of the current row inside f1_list / loaded_results\n",
    "pred_list = []\n",
    "for i in range(len(df)):\n",
    "    temp = list(extract_steps(df.iloc[i][\"messages\"][1][\"content\"]).values())\n",
    "    # A row with k steps owns k - 1 consecutive entries of f1_list. Clamp at 0:\n",
    "    # a row without parsed steps must not move the offset backwards, which\n",
    "    # would misalign every following row.\n",
    "    n_gaps = max(len(temp) - 1, 0)\n",
    "    predict = {\n",
    "        \"positions\": [],\n",
    "        \"texts\": []\n",
    "    }\n",
    "    for j in range(cnt, cnt + n_gaps):\n",
    "        if f1_list[j] < SIMILARITY_THRESHOLD:\n",
    "            # Position is the pair of neighbouring step indices inside this row.\n",
    "            predict[\"positions\"].append((j - cnt, j - cnt + 1))\n",
    "            predict[\"texts\"].append(loaded_results[j])\n",
    "    pred_list.append(predict)\n",
    "    cnt += n_gaps"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e64c1e13-fc6f-461f-a549-c983652f2f58",
   "metadata": {},
   "source": [
    "import pickle\n",
    "# Load the gold annotations that batch_evaluate compares pred_list against.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load a trusted gold.pkl.\n",
    "with open('gold.pkl', 'rb') as f:\n",
    "    gold = pickle.load(f)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ea7751a1-52ed-46ec-bd83-5d90604dea1d",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "# Evaluate all predictions with the default model_type / lang of batch_evaluate.\n",
    "results = batch_evaluate(gold, pred_list)\n",
    "print(results)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c4a34026-80c2-4a0d-a295-aed74c64b286",
   "metadata": {},
   "source": [
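    "# Results of batch_evaluate; the first column appears to be the similarity threshold used when building pred_list\n",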
    "# 1    {'precision_pos': 0.2096, 'recall_pos': 0.7964, 'f1_pos': 0.3176, 'redundancy_rate': 0.7904, 'text_score_position_aware': 0.7572}\n",
    "# 0.95 {'precision_pos': 0.2342, 'recall_pos': 0.7507, 'f1_pos': 0.3406, 'redundancy_rate': 0.7657, 'text_score_position_aware': 0.7122}\n",
    "# 0.9  {'precision_pos': 0.2481, 'recall_pos': 0.5921, 'f1_pos': 0.3258, 'redundancy_rate': 0.7437, 'text_score_position_aware': 0.559}\n",
    "# 0.85 {'precision_pos': 0.1934, 'recall_pos': 0.3017, 'f1_pos': 0.2119, 'redundancy_rate': 0.6387, 'text_score_position_aware': 0.2826}\n",
    "# 0.8  {'precision_pos': 0.0747, 'recall_pos': 0.0865, 'f1_pos': 0.0718, 'redundancy_rate': 0.3251, 'text_score_position_aware': 0.0784}"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b08ed802-f6fd-4bd9-886c-d5cbefe14835",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
