{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/proj2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import json \n",
    "from math_utils import *\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt \n",
    "\n",
    "from parser import *  \n",
    "from grader import *\n",
    "from tqdm import tqdm\n",
    "\n",
    "sys.path.append('../')  \n",
    "\n",
    "from utils import * "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_math(models, shot_types, dataset, base_dir, task='math'):\n",
    "    \"\"\"Grade sampled answers for each (shot_type, model) run and save the verdicts.\n",
    "\n",
    "    For every pair, reads {base_dir}/{task}/{model}/{task}_{shot_type}.jsonl,\n",
    "    grades every sampled output of every entry against the ground truth of\n",
    "    test[entry['idx']], prints the accuracy of each sample position, their\n",
    "    mean and the any-correct accuracy, then writes one\n",
    "    {'idx', 'is_correct'} record per entry to the matching *_scored.jsonl.\n",
    "    If that scored file already exists nothing is re-graded and\n",
    "    process_scored_file(scored_path, shot_type) is called instead.\n",
    "\n",
    "    Args:\n",
    "        models: Model directory names under {base_dir}/{task}/.\n",
    "        shot_types: Prompt-setting tags used in the file names.\n",
    "        dataset: List of test records, or path to a JSONL file holding them;\n",
    "            entries index into it through their 'idx' field.\n",
    "        base_dir: Root directory of the stored generations.\n",
    "        task: Task name used in paths and forwarded to extract_answer and\n",
    "            parse_ground_truth.\n",
    "\n",
    "    Returns:\n",
    "        None. Results are printed and written to disk.\n",
    "    \"\"\"\n",
    "    if isinstance(dataset, str):\n",
    "        with open(dataset, 'r', encoding='utf-8') as f:\n",
    "            test = [json.loads(line) for line in f]\n",
    "    else:\n",
    "        test = dataset\n",
    "\n",
    "    for shot_type in shot_types:\n",
    "        for model in models:\n",
    "            scored_path = f\"{base_dir}/{task}/{model}/{task}_{shot_type}_scored.jsonl\"\n",
    "            if os.path.exists(scored_path):\n",
    "                print(f\"Scored file already exists, skipping: {scored_path}\")\n",
    "                # NOTE(review): process_scored_file is not defined in this notebook;\n",
    "                # presumably it comes from one of the star imports - confirm.\n",
    "                process_scored_file(scored_path, shot_type)\n",
    "                continue\n",
    "\n",
    "            file_path = f\"{base_dir}/{task}/{model}/{task}_{shot_type}.jsonl\"\n",
    "            if not os.path.exists(file_path):\n",
    "                print(f\"File not found: {file_path}\")\n",
    "                continue\n",
    "\n",
    "            with open(file_path, 'r', encoding='utf-8') as f:\n",
    "                data = [json.loads(line) for line in f]\n",
    "\n",
    "            # scores[i] collects the verdicts of the i-th sample of every entry.\n",
    "            # It grows on demand so any number of samples per entry is supported.\n",
    "            scores = []\n",
    "            any_correct = []\n",
    "            scored_entries = []\n",
    "\n",
    "            for entry in data:\n",
    "                idx = entry[\"idx\"]\n",
    "\n",
    "                # Look the alternatives up lazily: indexing resps[0] up front\n",
    "                # fails on an empty list even when model_outputs is present.\n",
    "                if \"model_outputs\" in entry:\n",
    "                    outputs = entry[\"model_outputs\"]\n",
    "                else:\n",
    "                    resps = entry.get(\"resps\")\n",
    "                    outputs = resps[0] if resps else None\n",
    "                if outputs is None:\n",
    "                    outputs = entry.get('code')\n",
    "                if outputs is None:\n",
    "                    # No generations stored: the entry counts as incorrect.\n",
    "                    outputs = []\n",
    "                elif isinstance(outputs, str):\n",
    "                    # A bare string is one sample, not a sequence of characters.\n",
    "                    outputs = [outputs]\n",
    "\n",
    "                preds = [extract_answer(mo, task) for mo in outputs]\n",
    "                _, gt = parse_ground_truth(test[idx], task)\n",
    "\n",
    "                sample_results = []\n",
    "\n",
    "                for i, pred in enumerate(preds):\n",
    "                    # Fall back to the next check only when the previous one rejects.\n",
    "                    result = math_equal_process((None, pred, gt))\n",
    "                    if not result:\n",
    "                        result = process_results(gt, [outputs[i]])\n",
    "                        if not result:\n",
    "                            pred = extract_answer(pred, task)\n",
    "                            result = math_equal_process((None, pred, gt))\n",
    "\n",
    "                    sample_results.append(bool(result))\n",
    "                    if i == len(scores):\n",
    "                        scores.append([])\n",
    "                    scores[i].append(bool(result))\n",
    "\n",
    "                any_correct.append(any(sample_results))\n",
    "                scored_entries.append({\n",
    "                    \"idx\": idx,\n",
    "                    \"is_correct\": sample_results\n",
    "                })\n",
    "\n",
    "            num_questions = len(data)\n",
    "            # Entries with fewer samples still count in the denominator, so a\n",
    "            # missing sample is treated as a wrong answer.\n",
    "            rep_accuracies = [sum(rep) / num_questions for rep in scores]\n",
    "            avg_rep_acc = sum(rep_accuracies) / len(rep_accuracies) if rep_accuracies else 0.0\n",
    "            any_correct_acc = sum(any_correct) / num_questions if num_questions > 0 else 0.0\n",
    "\n",
    "            for i, acc_val in enumerate(rep_accuracies):\n",
    "                print(f\"{i}:  {acc_val:.3f}\")\n",
    "            print(f\"{shot_type} Avg:  {avg_rep_acc:.3f}\")\n",
    "            print(f\"Any_correct:  {any_correct_acc:.3f}\\n\")\n",
    "\n",
    "            with open(scored_path, 'w', encoding='utf-8') as f:\n",
    "                for e in scored_entries:\n",
    "                    f.write(json.dumps(e, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "            print(f\"Saved scored file: {scored_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scored file already exists, skipping: ../../../ReFeri/result/math500/gpt-4o-mini/math500_zero_scored.jsonl\n",
      "0:  0.764\n",
      "1:  0.766\n",
      "2:  0.742\n",
      "3:  0.774\n",
      "4:  0.776\n",
      "zero Avg: 0.764\n",
      "Any_correct: 0.874\n",
      "\n",
      "Scored file already exists, skipping: ../../../ReFeri/result/math500/gpt-4o-mini/math500_few_scored.jsonl\n",
      "0:  0.758\n",
      "1:  0.750\n",
      "2:  0.752\n",
      "3:  0.754\n",
      "4:  0.746\n",
      "few Avg: 0.752\n",
      "Any_correct: 0.858\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Benchmark whose stored generations are scored in this cell.\n",
    "task = 'math500'\n",
    "\n",
    "# Load the reference problems: one JSON record per line of the test file.\n",
    "dataset_path = f\"../data/{task}/test.jsonl\"\n",
    "with open(dataset_path, 'r', encoding='utf-8') as f:\n",
    "    dataset = list(map(json.loads, f))\n",
    "\n",
    "# Model directory names looked up under base_dir by evaluate_math.\n",
    "models = [\"gpt-4o-mini\"]\n",
    "base_dir = \"../result\"\n",
    "\n",
    "# Score both prompt settings for every listed model.\n",
    "evaluate_math(\n",
    "    models=models,\n",
    "    shot_types=[\"zero\", \"few\"],\n",
    "    dataset=dataset,\n",
    "    base_dir=base_dir,\n",
    "    task=task,\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "proj2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
