{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('2.1.2+cu121', '4.36.2', '12.1')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "from src import functional\n",
    "import src.tokens as tokenization_utils\n",
    "import numpy as np\n",
    "import logging\n",
    "from src import models\n",
    "\n",
    "from src.utils import logging_utils\n",
    "# Send log records at DEBUG level and above to stdout, using the project's\n",
    "# shared message and date formats.\n",
    "logging.basicConfig(\n",
    "    stream=sys.stdout,\n",
    "    level=logging.DEBUG,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Show the library versions this kernel is running with.\n",
    "(torch.__version__, transformers.__version__, torch.version.cuda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.summarize import main as summarize\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_cases': 538,\n",
      " 'post_essence_score': (3.3875051736831665, 0.3947960138320923),\n",
      " 'post_neighborhood_acc': (18.68, 23.16),\n",
      " 'post_neighborhood_diff': (11.53, 13.69),\n",
      " 'post_neighborhood_success': (83.98, 23.61),\n",
      " 'post_ngram_entropy': (629.89, 20.45),\n",
      " 'post_paraphrase_acc': (54.93, 40.43),\n",
      " 'post_paraphrase_diff': (41.71, 35.62),\n",
      " 'post_paraphrase_success': (86.34, 28.99),\n",
      " 'post_reference_score': (37.36, 12.75),\n",
      " 'post_rewrite_acc': (94.42, 22.95),\n",
      " 'post_rewrite_diff': (89.8, 26.32),\n",
      " 'post_rewrite_success': (97.58, 15.36),\n",
      " 'post_score': (88.92125386032497, nan),\n",
      " 'run_dir': '../results/ROME/run_000',\n",
      " 'time': (66.97480171453554, 14.196454427187613)}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'time': (66.97480171453554, 14.196454427187613),\n",
       "  'post_rewrite_success': (97.58, 15.36),\n",
       "  'post_rewrite_diff': (89.8, 26.32),\n",
       "  'post_paraphrase_success': (86.34, 28.99),\n",
       "  'post_paraphrase_diff': (41.71, 35.62),\n",
       "  'post_neighborhood_success': (83.98, 23.61),\n",
       "  'post_neighborhood_diff': (11.53, 13.69),\n",
       "  'post_rewrite_acc': (94.42, 22.95),\n",
       "  'post_paraphrase_acc': (54.93, 40.43),\n",
       "  'post_neighborhood_acc': (18.68, 23.16),\n",
       "  'post_ngram_entropy': (629.89, 20.45),\n",
       "  'post_reference_score': (37.36, 12.75),\n",
       "  'post_essence_score': (3.3875051736831665, 0.3947960138320923),\n",
       "  'post_score': (88.92125386032497, nan),\n",
       "  'run_dir': '../results/ROME/run_000',\n",
       "  'num_cases': 538}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarize(\n",
    "    dir_name=Path(\"../results/ROME\"),\n",
    "    runs=[\"run_000\"],\n",
    "    abs_path=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
