{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "46d204f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth answers for the AIME dataset: one integer per problem,\n",
    "# 30 entries in each list. The position in the list is the 0-based\n",
    "# problem index that predictions are compared against.\n",
    "# GT24 is the label set used by the scoring cell of this notebook.\n",
    "# GT25 is not referenced in the cells shown here -- presumably the 2025\n",
    "# counterpart of GT24 (TODO confirm).\n",
    "GT24 = [33, 23, 116, 809, 197, 385, 371, 601, 25, 55, 540, 45, 204, 699, 294, 110, 721, 315, 468, 902, 211, 80, 480, 236, 73, 113, 127, 104, 104, 321]\n",
    "GT25 = [70, 588, 16, 117, 279, 504, 821, 77, 62, 81, 259, 510, 204, 60, 735, 468, 49, 82, 106, 336, 293, 237, 610, 149, 907, 113, 19, 248, 104, 240]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f48ef58e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Entry 19 has abnormal length (60/64), padding with random values\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "import ast\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def extract_pred_lists(log_file):\n",
    "    \"\"\"Collect every parseable pred_list literal found in one log file.\n",
    "\n",
    "    The whole file is read and scanned for blocks made of a\n",
    "    '=== Final Answer of <n> ... ===' banner followed by a 'pred_list:'\n",
    "    label; the text captured after the label is parsed with\n",
    "    ast.literal_eval.\n",
    "\n",
    "    Args:\n",
    "        log_file: Path of the UTF-8 encoded log file to scan.\n",
    "\n",
    "    Returns:\n",
    "        A list holding one parsed value per matched block, in file order.\n",
    "        Captures that are not valid Python literals are reported on stdout\n",
    "        and left out of the result.\n",
    "    \"\"\"\n",
    "    with open(log_file, 'r', encoding='utf-8') as handle:\n",
    "        log_text = handle.read()\n",
    "\n",
    "    # DOTALL lets the lazy '.*?' cross line breaks between the banner and the\n",
    "    # label. The lazy capture group stops at the first position that is\n",
    "    # followed by optional whitespace and then a line break or the end of text.\n",
    "    block_re = re.compile(\n",
    "        r\"=+ Final Answer of \\d+.*?=+\\s+pred_list:\\s+(.*?)\\s*(?=\\n|$)\",\n",
    "        re.DOTALL,\n",
    "    )\n",
    "\n",
    "    parsed = []\n",
    "    for raw in block_re.findall(log_text):\n",
    "        try:\n",
    "            # literal_eval only accepts Python literals, never arbitrary code\n",
    "            value = ast.literal_eval(raw.strip())\n",
    "        except (SyntaxError, ValueError):\n",
    "            print(f\"Format error in file {log_file}: {raw}\")\n",
    "        else:\n",
    "            parsed.append(value)\n",
    "\n",
    "    return parsed\n",
    "\n",
    "\n",
    "# Main processing logic\n",
    "# Every log file is one run. A run is expected to contribute\n",
    "# SAMPLES_PER_FILE predictions for each of the NUM_PROBLEMS entries, so a\n",
    "# merged entry should hold TOTAL_SAMPLES predictions.\n",
    "log_files = [\n",
    "    './log/aime24_our1_4omini_4per_long1.log',\n",
    "    './log/aime24_our1_4omini_4per_long2.log',\n",
    "    './log/aime24_our1_4omini_4per_long3.log',\n",
    "    './log/aime24_our1_4omini_4per_long4.log'\n",
    "]\n",
    "NUM_PROBLEMS = 30      # entries expected in each log file\n",
    "SAMPLES_PER_FILE = 16  # predictions expected per entry within one file\n",
    "TOTAL_SAMPLES = SAMPLES_PER_FILE * len(log_files)  # 64 for the four files above\n",
    "\n",
    "# Seeded, local generator: filler values are identical on every\n",
    "# Restart & Run All and the global NumPy random state is not touched.\n",
    "pad_rng = np.random.default_rng(0)\n",
    "\n",
    "all_preds = []\n",
    "\n",
    "# Extract pred_list from each file\n",
    "for file in log_files:\n",
    "    file_preds = extract_pred_lists(file)\n",
    "    if len(file_preds) != NUM_PROBLEMS:\n",
    "        print(f\"Warning: {file} contains {len(file_preds)} entries (expected {NUM_PROBLEMS})\")\n",
    "    all_preds.append(file_preds)\n",
    "\n",
    "# Merge pred_list for the same data index\n",
    "combined_results = []\n",
    "for i in range(NUM_PROBLEMS):\n",
    "    combined = []\n",
    "    # Iterate over the runs actually loaded instead of a hard-coded count\n",
    "    for j, run_preds in enumerate(all_preds):\n",
    "        if i < len(run_preds):\n",
    "            combined.extend(run_preds[i])\n",
    "        else:\n",
    "            print(f\"Error: File {j+1} is missing entry {i+1}\")\n",
    "            combined.extend([''] * SAMPLES_PER_FILE)  # Fill missing data with empty values\n",
    "\n",
    "    if len(combined) < TOTAL_SAMPLES:\n",
    "        print(f\"Warning: Entry {i+1} has abnormal length ({len(combined)}/{TOTAL_SAMPLES}), padding with random values\")\n",
    "        # Fillers are drawn from [1000, 10000]; every label in GT24/GT25 is\n",
    "        # below 1000, so a filler can never be scored as a correct answer.\n",
    "        n_missing = TOTAL_SAMPLES - len(combined)\n",
    "        combined.extend(int(v) for v in pad_rng.integers(1000, 10001, size=n_missing))\n",
    "    elif len(combined) > TOTAL_SAMPLES:\n",
    "        # An over-long row would make combined_results ragged and break the\n",
    "        # later conversion to a rectangular array, so keep the first values only.\n",
    "        print(f\"Warning: Entry {i+1} has abnormal length ({len(combined)}/{TOTAL_SAMPLES}), truncating to {TOTAL_SAMPLES}\")\n",
    "        combined = combined[:TOTAL_SAMPLES]\n",
    "\n",
    "    combined_results.append(combined)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5ef6bdaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def func1(prediction):\n",
    "    \"\"\"Return the most frequent prediction of every sample.\n",
    "\n",
    "    Args:\n",
    "        prediction: NumPy array whose first axis indexes samples; each\n",
    "            prediction[k] holds the predictions made for sample k.\n",
    "\n",
    "    Returns:\n",
    "        A list with one element per sample: the most common value of that\n",
    "        row. On a tie the smallest value wins, because np.unique returns\n",
    "        sorted values and argmax picks the first maximum count.\n",
    "    \"\"\"\n",
    "    def _row_mode(row):\n",
    "        # distinct values (sorted) and how often each one occurs\n",
    "        distinct, freq = np.unique(row, return_counts=True)\n",
    "        return distinct[np.argmax(freq)]\n",
    "\n",
    "    return [_row_mode(prediction[k]) for k in range(prediction.shape[0])]\n",
    "\n",
    "def func2(prediction, ground_truth):\n",
    "    \"\"\"Flag, per sample, whether any prediction equals the ground truth.\n",
    "\n",
    "    Args:\n",
    "        prediction: NumPy array whose first axis indexes samples; each\n",
    "            prediction[k] holds the predictions made for sample k.\n",
    "        ground_truth: Indexable sequence with the expected answer of\n",
    "            sample k at position k.\n",
    "\n",
    "    Returns:\n",
    "        A list of ints, one per sample: 1 if at least one prediction of the\n",
    "        sample equals its ground truth, otherwise 0.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        1 if np.any(prediction[k] == ground_truth[k]) else 0\n",
    "        for k in range(prediction.shape[0])\n",
    "    ]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "58b3f5b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mode: [33, 2, 116, 404, 35, 12, 295, 600, 25, 66, 324, 0, 204, 991, 147, 5, 31, 15, 156, 243, 13, 240, 80, 236, 0, 13, 7, 18, 184, 169]\n",
      "Contains ground truth: [1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1]\n"
     ]
    }
   ],
   "source": [
    "# Build the array with dtype=object so every cell keeps its full value.\n",
    "# With the default inference the array can get a narrow fixed-width string\n",
    "# dtype (e.g. '<U3'); writing 4-digit placeholders into it would silently\n",
    "# truncate them to 3-digit numbers that look like genuine answers.\n",
    "prediction = np.array(combined_results, dtype=object)\n",
    "\n",
    "# Replace empty entries ('') with random integers between 1000 and 2000.\n",
    "# Every label in GT24 is below 1000, so a placeholder never counts as correct.\n",
    "# A seeded local generator keeps the placeholders reproducible on re-run.\n",
    "mask = (prediction == '')\n",
    "fill_rng = np.random.default_rng(0)\n",
    "prediction[mask] = fill_rng.integers(1000, 2000, size=int(mask.sum()))\n",
    "\n",
    "# Convert all predictions to integers\n",
    "prediction = prediction.astype(int)\n",
    "ground_truth = np.array(GT24)\n",
    "\n",
    "print(\"Mode:\", func1(prediction))\n",
    "print(\"Contains ground truth:\", func2(prediction, ground_truth))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "9345ad71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.16666666666666666\n",
      "0.6\n"
     ]
    }
   ],
   "source": [
    "## Acc and Pass\n",
    "# Accuracy: share of samples whose most frequent prediction equals the label.\n",
    "majority_answers = func1(prediction)\n",
    "n_correct = sum(label == answer for label, answer in zip(ground_truth, majority_answers))\n",
    "print(n_correct / len(ground_truth))\n",
    "# Pass rate: share of samples with at least one prediction equal to the label.\n",
    "hit_flags = func2(prediction, ground_truth)\n",
    "print(np.mean(hit_flags, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f387cc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d4e692",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b200065b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
