{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "import experiment3_doc_experiment\n",
    "reload(experiment3_doc_experiment)\n",
    "from experiment3_doc_experiment import *\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mag_diff_knn</th>\n",
       "      <th>mag_area_cosine_knn</th>\n",
       "      <th>vendi_cosine_knn</th>\n",
       "      <th>neg_mean_cosine_knn</th>\n",
       "      <th>stds_div_zero_knn</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.888458</td>\n",
       "      <td>0.813125</td>\n",
       "      <td>0.654083</td>\n",
       "      <td>0.882958</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>CNN Dailymail</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.016968</td>\n",
       "      <td>0.018515</td>\n",
       "      <td>0.026399</td>\n",
       "      <td>0.018896</td>\n",
       "      <td>0.026286</td>\n",
       "      <td>CNN Dailymail</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.966250</td>\n",
       "      <td>0.826667</td>\n",
       "      <td>0.619667</td>\n",
       "      <td>0.917250</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>Big Patent</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.010815</td>\n",
       "      <td>0.020541</td>\n",
       "      <td>0.027114</td>\n",
       "      <td>0.017205</td>\n",
       "      <td>0.026286</td>\n",
       "      <td>Big Patent</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.945417</td>\n",
       "      <td>0.830667</td>\n",
       "      <td>0.837708</td>\n",
       "      <td>0.724500</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>EdinburghNLP</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.012653</td>\n",
       "      <td>0.021415</td>\n",
       "      <td>0.021747</td>\n",
       "      <td>0.026520</td>\n",
       "      <td>0.026286</td>\n",
       "      <td>EdinburghNLP</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.993167</td>\n",
       "      <td>0.797208</td>\n",
       "      <td>0.766333</td>\n",
       "      <td>0.882292</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>Arvix Abstracts</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.005635</td>\n",
       "      <td>0.022699</td>\n",
       "      <td>0.024885</td>\n",
       "      <td>0.019529</td>\n",
       "      <td>0.026286</td>\n",
       "      <td>Arvix Abstracts</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      mag_diff_knn  mag_area_cosine_knn  vendi_cosine_knn  \\\n",
       "mean      0.888458             0.813125          0.654083   \n",
       "std       0.016968             0.018515          0.026399   \n",
       "mean      0.966250             0.826667          0.619667   \n",
       "std       0.010815             0.020541          0.027114   \n",
       "mean      0.945417             0.830667          0.837708   \n",
       "std       0.012653             0.021415          0.021747   \n",
       "mean      0.993167             0.797208          0.766333   \n",
       "std       0.005635             0.022699          0.024885   \n",
       "\n",
       "      neg_mean_cosine_knn  stds_div_zero_knn          dataset  \n",
       "mean             0.882958           0.666667    CNN Dailymail  \n",
       "std              0.018896           0.026286    CNN Dailymail  \n",
       "mean             0.917250           0.666667       Big Patent  \n",
       "std              0.017205           0.026286       Big Patent  \n",
       "mean             0.724500           0.666667     EdinburghNLP  \n",
       "std              0.026520           0.026286     EdinburghNLP  \n",
       "mean             0.882292           0.666667  Arvix Abstracts  \n",
       "std              0.019529           0.026286  Arvix Abstracts  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores={}\n",
    "datasets=[\n",
    "        \"cnn_dailymail___3.0.0_16384\",\n",
    "        \"big_patent___a_16384\",\n",
    "        \"EdinburghNLP_-_xsum_16384\",\n",
    "        \"gfissore_-_arxiv-abstracts-2021_16384\"\n",
    "        ]\n",
    "for d in datasets:\n",
    "    #scores[d] = prediction_task_documents(d, results=None, n_samples = 200, n_size=500, n_dims=0)\n",
    "    scores[d] = pd.read_csv(\"./doc_text/\"+d+\"_pred_doc\"+\".csv\")\n",
    "    #prediction_task_documents(d, results=None, n_samples = 200, n_size=300, n_dims=384)\n",
    "    #read_files(d, n_samples = 200, n_size=300, n_dims=0)\n",
    "get_prediction_results(scores)[[\"mag_diff_knn\",\"mag_area_cosine_knn\", \"vendi_cosine_knn\", \"neg_mean_cosine_knn\",\"stds_div_zero_knn\", \"dataset\"]]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
