{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import json\n",
    "\n",
    "import collections as C\n",
    "import itertools as I"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Denominator of every reported accuracy, keyed by benchmark name\n",
    "# (presumably the number of problems in each benchmark).\n",
    "benchmark_size = {'proofnet': 374, 'con-nf': 961}\n",
    "\n",
    "# Each per-problem result list is padded up to this many trials.\n",
    "EVAL_MAX_K = 8\n",
    "# Values of k for the @k columns: the first trial only, and all EVAL_MAX_K trials.\n",
    "eval_ks = [1, EVAL_MAX_K]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Benchmark: proofnet\n",
      "Method\tTypecheck@1\tBEq@1\tTypecheck@8\tBEq@8\n",
      "RA\t0.5721925133689839\t0.12299465240641712\t0.7727272727272727\t0.18181818181818182\t\n",
      "PDA\t0.14705882352941177\t0.00267379679144385\t0.24331550802139038\t0.0213903743315508\t\n",
      "RA+R\t0.7299465240641712\t0.232620320855615\t0.8048128342245989\t0.31283422459893045\t\n",
      "ICL(4o)\t0.4358288770053476\t0.07219251336898395\t0.6631016042780749\t0.12834224598930483\t\n",
      "LW\t0.44919786096256686\t0.0855614973262032\t0.4919786096256685\t0.09893048128342247\t\n",
      "RA-R\t0.5213903743315508\t0.11497326203208556\t0.713903743315508\t0.1657754010695187\t\n",
      "MMA\t0.12566844919786097\t0.01871657754010695\t0.22994652406417113\t0.029411764705882353\t\n",
      "MMA(L)\t0.10962566844919786\t0.0213903743315508\t0.23529411764705882\t0.026737967914438502\t\n",
      "ICL(D)\t0.4037433155080214\t0.09893048128342247\t0.5106951871657754\t0.10962566844919786\t\n",
      "\n",
      "Benchmark: con-nf\n",
      "Method\tTypecheck@1\tBEq@1\tTypecheck@8\tBEq@8\n",
      "RA\t0.20499479708636836\t0.11446409989594172\t0.2861602497398543\t0.16857440166493237\t\n",
      "PDA\t0.043704474505723206\t0.01040582726326743\t0.10613943808532779\t0.036420395421436005\t\n",
      "RA+R\t0.6045785639958376\t0.44849115504682624\t0.7211238293444329\t0.5535900104058272\t\n",
      "ICL(4o)\t0.09781477627471384\t0.014568158168574402\t0.20707596253902186\t0.04162330905306972\t\n",
      "LW\t0.2809573361082206\t0.009365244536940686\t0.37669094693028093\t0.01040582726326743\t\n",
      "RA-R\t0.08116545265348596\t0.030176899063475548\t0.11966701352757544\t0.045785639958376693\t\n",
      "MMA\t0.036420395421436005\t0.019771071800208116\t0.08740894901144641\t0.043704474505723206\t\n",
      "MMA(L)\t0.03329864724245578\t0.01768990634755463\t0.08012486992715921\t0.045785639958376693\t\n",
      "ICL(D)\t0.09365244536940687\t0.02809573361082206\t0.16233090530697192\t0.04266389177939646\t\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print one tab-separated table per benchmark: one row per entry of the benchmark's\n",
    "# eval directory, with Typecheck@k and BEq@k columns for each k in eval_ks.\n",
    "for eval_root in [\n",
    "    './eval_proofnet',\n",
    "    './eval_con-nf',\n",
    "    ]:\n",
    "    # './eval_<name>' -> '<name>'; the same string indexes benchmark_size below.\n",
    "    benchmark_name = osp.basename(osp.normpath(eval_root)).split('eval_')[-1]\n",
    "    print(f'Benchmark: {benchmark_name}')\n",
    "    print('Method\\t' + '\\t'.join(I.chain(\n",
    "        *[[f'Typecheck@{k}', f'BEq@{k}'] for k in eval_ks]\n",
    "        )))\n",
    "    for exp in os.listdir(eval_root):\n",
    "        print(exp.replace('internlm2_math_base_7b_full_', '').replace('_epoch_1', ''), end='\\t')\n",
    "\n",
    "        # Load experiment results once per experiment: the file is the same for every k.\n",
    "        try:\n",
    "            with open(osp.join(eval_root, exp, 'autoformalization.json'), 'r') as f:\n",
    "                content = json.load(f)\n",
    "        except (OSError, ValueError) as e:\n",
    "            # Missing/unreadable file (OSError) or malformed JSON (json.JSONDecodeError is a\n",
    "            # ValueError): report it a single time, which also terminates this row.\n",
    "            print(f'CollectStat({benchmark_name}/{exp}): Failed with {e}')\n",
    "            continue\n",
    "\n",
    "        # Parse experiment results: for every key of the JSON object, one flag per trial\n",
    "        # for typecheck success and for each equivalence direction (PQ and QP).\n",
    "        results = C.defaultdict(lambda : C.defaultdict(list))\n",
    "        for k, rs in content.items():\n",
    "            for trial in rs:\n",
    "                if 'typecheck_result' in trial and trial['typecheck_result']['is_success']:\n",
    "                    results[k]['typecheck'].append(True)\n",
    "                    results[k]['pq'].append('equivcheck_results_PQ' in trial and trial['equivcheck_results_PQ']['is_success'])\n",
    "                    results[k]['qp'].append('equivcheck_results_QP' in trial and trial['equivcheck_results_QP']['is_success'])\n",
    "                else:\n",
    "                    # A trial that does not typecheck counts as failed in both directions.\n",
    "                    results[k]['typecheck'].append(False)\n",
    "                    results[k]['pq'].append(False)\n",
    "                    results[k]['qp'].append(False)\n",
    "        # Pad every list to EVAL_MAX_K entries; trials that are absent count as failures.\n",
    "        for per_key in results.values():\n",
    "            for flags in per_key.values():\n",
    "                flags += [False] * (EVAL_MAX_K - len(flags))\n",
    "\n",
    "        # Compute stats: a key counts as solved @k if any of its first k trials succeeded.\n",
    "        for try_k in eval_ks:\n",
    "            acc_typecheck_successes = []\n",
    "            acc_bidirectional_equivalence = []\n",
    "\n",
    "            for per_key in results.values():\n",
    "                r = {\n",
    "                    kk : vv[:try_k] for kk, vv in per_key.items()\n",
    "                }\n",
    "                acc_typecheck_successes.append(\n",
    "                    any(r['typecheck'])\n",
    "                )\n",
    "                # BEq requires both directions to succeed within the same trial.\n",
    "                acc_bidirectional_equivalence.append(\n",
    "                    any(x and y for x, y in zip(r['pq'], r['qp']))\n",
    "                )\n",
    "\n",
    "            # Divide by the fixed benchmark size, not by the number of keys in the file.\n",
    "            print(\n",
    "                sum(acc_typecheck_successes) / benchmark_size[benchmark_name],\n",
    "                sum(acc_bidirectional_equivalence) / benchmark_size[benchmark_name],\n",
    "                sep='\\t', end='\\t')\n",
    "        print()\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "default",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
