{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "841a7f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "from tqdm import trange, tqdm\n",
    "from langagent.base_llm import InferenceLogger, HfChatModel, VALID_ROLES_PREFIX, DETERMINISTIC_TEMPERATURE\n",
    "from langagent.metrics import get_inference_cost_metrics\n",
    "from langagent.eval import ResultToTxtLine, ResultDictToJsonl\n",
    "import matplotlib.pyplot as plt\n",
    "from langagent.langreason.common import load_qa_dataset\n",
    "\n",
    "run_id = \"rest_cot\"\n",
    "dataset_name =\"math500\"  #\"gsm8k\" #\n",
    "run_id = f\"{dataset_name}_{run_id}\"\n",
    "model_name =  \"Meta-Llama-3-8B-Instruct\" #\"Qwen3-32B-AWQ\" #\"Meta-Llama-3-8B\" #\n",
    "root_dir = f\"{model_name}_results/{run_id}/\" #run_qwen2/\n",
    "assert os.path.exists(root_dir), f\"Root directory {root_dir} does not exist.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "234b1c40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result file Meta-Llama-3-8B-Instruct_results/math500_rest_cot/inferencelogger.log already exists. I will append to it. \n",
      "Efficiency metrics - Roles\n",
      "default :  {'num_calls': '100', 'input_tokens': '13911', 'output_tokens': '37922', 'total_hours': '0.30664499554369185'}\n",
      "dynamics :  {'num_calls': '0', 'input_tokens': '0', 'output_tokens': '0', 'total_hours': '0.0'}\n",
      "policy :  {'num_calls': '0', 'input_tokens': '0', 'output_tokens': '0', 'total_hours': '0.0'}\n",
      "evaluator :  {'num_calls': '0', 'input_tokens': '0', 'output_tokens': '0', 'total_hours': '0.0'}\n",
      "bn_eval :  {'num_calls': '0', 'input_tokens': '0', 'output_tokens': '0', 'total_hours': '0.0'}\n",
      "bn_entropy :  {'num_calls': '0', 'input_tokens': '0', 'output_tokens': '0', 'total_hours': '0.0'}\n",
      "{'num_calls': 100, 'input_tokens': 13911, 'output_tokens': 37922, 'total_hours': 0.31}\n"
     ]
    }
   ],
   "source": [
    "inference_logger, metrics = get_inference_cost_metrics(root_dir, return_metrics=['num_calls', 'input_tokens', 'output_tokens', 'total_hours'])   \n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "356b5e8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result file Meta-Llama-3-8B-Instruct_results/math500_rest_cot/resultdicttojsonl.jsonl already exists. Loading existing results.\n",
      "100\n"
     ]
    }
   ],
   "source": [
    "results_file = ResultDictToJsonl(run_id='', root_dir=root_dir, override=False)\n",
    "existing_results = results_file.results\n",
    "print(len(results_file.results))\n",
    "# full_dataset = load_qa_dataset(dataset_name)\n",
    "# metrics = get_accuracy(full_dataset, results_file)\n",
    "# print(\"Accuracy:\", metrics[\"accuracy\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "958451ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy( existing_results, exclude_idx=[], include_idx=[]):\n",
    "    num_correct = 0\n",
    "    num_total = 0\n",
    "    incorrect_idx = []\n",
    "    num_error = 0\n",
    "    for idx, rec in enumerate(existing_results):\n",
    "        if idx in exclude_idx:\n",
    "            continue\n",
    "        if include_idx and idx not in include_idx:\n",
    "            continue\n",
    "        num_total += 1\n",
    "        try:\n",
    "            if float(rec.get(\"label\", \"\")) == float(rec.get(\"truth\", \"\")):\n",
    "                num_correct += 1\n",
    "            else: \n",
    "                incorrect_idx.append(idx)\n",
    "                # print(\"Predicted:\", float(rec.get(\"label\", \"\")), \"Truth:\", float(rec.get(\"truth\", \"\")))\n",
    "        # catch TypeError, ValueError\n",
    "        except (TypeError, ValueError) as e:\n",
    "            num_error += 1\n",
    "            # print(f\"Error for record {rec}: {e}\")\n",
    "    print(f\"Correct #: {num_correct}; Incorrect #: {len(incorrect_idx)}; Errors #: {num_error};Total #: {num_total}\")\n",
    "    accuracy = num_correct / num_total if num_total > 0 else 0\n",
    "    return accuracy, num_correct, incorrect_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "6c3efb5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Correct #: 35; Incorrect #: 49; Errors #: 16;Total #: 100\n",
      "0.35\n"
     ]
    }
   ],
   "source": [
    "acc, num_correct, incorrect_idx = get_accuracy(existing_results) # include_idx=list(range(77))\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "31aeceed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Level: 1\n",
      "Correct #: 8; Incorrect #: 3; Errors #: 0;Total #: 11\n",
      "Accuracy (3/11): 0.7272727272727273\n",
      "Incorrect indices: [8, 47, 54]\n",
      "Level: 2\n",
      "Correct #: 11; Incorrect #: 11; Errors #: 0;Total #: 22\n",
      "Accuracy (11/22): 0.5\n",
      "Incorrect indices: [1, 15, 18, 22, 27, 35, 37, 51, 60, 71, 95]\n",
      "Level: 3\n",
      "Correct #: 8; Incorrect #: 10; Errors #: 4;Total #: 22\n",
      "Accuracy (10/22): 0.36363636363636365\n",
      "Incorrect indices: [2, 4, 10, 19, 38, 50, 57, 63, 75, 77]\n",
      "Level: 4\n",
      "Correct #: 4; Incorrect #: 13; Errors #: 4;Total #: 21\n",
      "Accuracy (13/21): 0.19047619047619047\n",
      "Incorrect indices: [17, 23, 33, 43, 52, 56, 62, 68, 78, 79, 81, 82, 99]\n",
      "Level: 5\n",
      "Correct #: 2; Incorrect #: 18; Errors #: 4;Total #: 24\n",
      "Accuracy (18/24): 0.08333333333333333\n",
      "Incorrect indices: [5, 11, 12, 13, 28, 30, 31, 40, 49, 55, 59, 61, 74, 91, 92, 93, 96, 97]\n",
      "Correct #: 33; Incorrect #: 55; Errors #: 12;Total #: 100\n",
      "Overall: (0.33, 33, [1, 2, 4, 5, 8, 10, 11, 12, 13, 15, 17, 18, 19, 22, 23, 27, 28, 30, 31, 33, 35, 37, 38, 40, 43, 47, 49, 50, 51, 52, 54, 55, 56, 57, 59, 60, 61, 62, 63, 68, 71, 74, 75, 77, 78, 79, 81, 82, 91, 92, 93, 95, 96, 97, 99])\n"
     ]
    }
   ],
   "source": [
    "exclude_idx = []\n",
    "with open(\"data/math500_float_answer_idx_by_level.jsonl\", \"r\") as f:\n",
    "    loaded_idx_by_level = json.load(f)\n",
    "\n",
    "max_idx = 100\n",
    "overall_num_correct = 0\n",
    "incorrect_idx = []\n",
    "for level in range(1, 6):\n",
    "    include_idx = loaded_idx_by_level.get(str(level), [])\n",
    "    if max_idx is not None:\n",
    "        include_idx = [idx for idx in include_idx if idx < max_idx]\n",
    "    print(\"Level:\", level)\n",
    "    acc, num_correct, incorrect_idx = get_accuracy(existing_results, exclude_idx=exclude_idx, include_idx=include_idx)\n",
    "    overall_num_correct += num_correct\n",
    "    print(f\"Accuracy ({len(incorrect_idx)}/{len(include_idx)}): {acc}\")\n",
    "    print(\"Incorrect indices:\", incorrect_idx)\n",
    "print(\"Overall:\", get_accuracy(existing_results, exclude_idx=exclude_idx, include_idx=list(range(max_idx)) if max_idx is not None else None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61edad65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 0.31645569620253167\n",
    "# 0.3227848101265823\n",
    "overall_num_correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4870a69",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
