{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from copy import deepcopy\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "import data\n",
    "import pickle as pkl\n",
    "# Load the dataset-overview table; the first CSV column becomes the index.\n",
    "df = pd.read_csv(\n",
    "    '../results/datasets_ovw.csv',\n",
    "    index_col=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llllll}\n",
      "\\toprule\n",
      "{} & Emotion & Financial phrasebank & Rotten tomatoes &    SST2 & Tweet (Hate) \\\\\n",
      "\\midrule\n",
      "Samples (train)         &   16000 &                 2313 &            8530 &   67349 &         9000 \\\\\n",
      "Samples (val)           &    2000 &                 1140 &            1066 &     872 &         1000 \\\\\n",
      "Unigrams                &   15165 &                 7169 &           16631 &   13887 &        18477 \\\\\n",
      "Bigrams                 &  106201 &                28481 &           93921 &   72501 &       106277 \\\\\n",
      "Trigrams                &  201404 &                39597 &          147426 &  108800 &       171768 \\\\\n",
      "Classes                 &       6 &                    3 &               2 &       2 &            2 \\\\\n",
      "Majority class fraction &    0.34 &                 0.62 &             0.5 &    0.56 &         0.58 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3437291/469346702.py:20: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(prep_for_printing(df).transpose().to_latex())\n"
     ]
    }
   ],
   "source": [
    "def prep_for_printing(df):\n",
    "    \"\"\"Return a display-ready copy of the dataset-overview frame.\n",
    "\n",
    "    Steps, in order: sort rows by 'n_train'; move 'num_classes' and then\n",
    "    'imbalance' to the last columns; cast every column whose name does not\n",
    "    contain 'imbalance' to int; round the remaining column(s) to two decimals\n",
    "    and store them as strings; rename columns and index via the dicts in the\n",
    "    `data` module; finally sort rows by the renamed index.\n",
    "\n",
    "    The input frame is left untouched because `sort_values` returns a new\n",
    "    object before any column is popped.\n",
    "    \"\"\"\n",
    "    df = df.sort_values('n_train')\n",
    "    df['num_classes'] = df.pop('num_classes')  # move num_classes to end\n",
    "    df['imbalance'] = df.pop('imbalance')  # move imbalance to end (after num_classes)\n",
    "    df = df.infer_objects()\n",
    "    # Snapshot the column labels so the loop is independent of the frame\n",
    "    # whose columns are being reassigned.\n",
    "    for col_name in list(df.columns):\n",
    "        if 'imbalance' not in col_name:\n",
    "            df[col_name] = df[col_name].astype(int)\n",
    "        else:\n",
    "            # Rounded to two decimals and stored as text (0.5 renders as '0.5').\n",
    "            df[col_name] = df[col_name].round(2).astype(str)\n",
    "    # NOTE: sort_index orders rows by the renamed index label, which\n",
    "    # supersedes the earlier sort by 'n_train'.\n",
    "    return df.rename(\n",
    "        columns=data.COLUMNS_RENAME_DICT,\n",
    "        index=data.DSETS_RENAME_DICT,\n",
    "    ).sort_index()\n",
    "\n",
    "# Set pandas' global float display format to use a thousands separator.\n",
    "pd.set_option('display.float_format', '{:,}'.format)\n",
    "# Preferred left-to-right order of the dataset columns when presenting the table.\n",
    "col_order = [\n",
    "    'Financial phrasebank',\n",
    "    'Rotten tomatoes',\n",
    "    'SST2',\n",
    "    'Emotion',\n",
    "    'Tweet (Hate)',\n",
    "]\n",
    "# print (rather than rich display) so the raw LaTeX source can be copied.\n",
    "print(\n",
    "    prep_for_printing(df)\n",
    "    .T\n",
    "    .to_latex()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Financial phrasebank</th>\n",
       "      <th>Rotten tomatoes</th>\n",
       "      <th>SST2</th>\n",
       "      <th>Emotion</th>\n",
       "      <th>Tweet (Hate)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Samples (train)</th>\n",
       "      <td>2313</td>\n",
       "      <td>8530</td>\n",
       "      <td>67349</td>\n",
       "      <td>16000</td>\n",
       "      <td>9000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Samples (val)</th>\n",
       "      <td>1140</td>\n",
       "      <td>1066</td>\n",
       "      <td>872</td>\n",
       "      <td>2000</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Unigrams</th>\n",
       "      <td>7169</td>\n",
       "      <td>16631</td>\n",
       "      <td>13887</td>\n",
       "      <td>15165</td>\n",
       "      <td>18477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Bigrams</th>\n",
       "      <td>28481</td>\n",
       "      <td>93921</td>\n",
       "      <td>72501</td>\n",
       "      <td>106201</td>\n",
       "      <td>106277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Trigrams</th>\n",
       "      <td>39597</td>\n",
       "      <td>147426</td>\n",
       "      <td>108800</td>\n",
       "      <td>201404</td>\n",
       "      <td>171768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Classes</th>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Majority class fraction</th>\n",
       "      <td>0.62</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.56</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.58</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        Financial phrasebank Rotten tomatoes    SST2 Emotion  \\\n",
       "Samples (train)                         2313            8530   67349   16000   \n",
       "Samples (val)                           1140            1066     872    2000   \n",
       "Unigrams                                7169           16631   13887   15165   \n",
       "Bigrams                                28481           93921   72501  106201   \n",
       "Trigrams                               39597          147426  108800  201404   \n",
       "Classes                                    3               2       2       6   \n",
       "Majority class fraction                 0.62             0.5    0.56    0.34   \n",
       "\n",
       "                        Tweet (Hate)  \n",
       "Samples (train)                 9000  \n",
       "Samples (val)                   1000  \n",
       "Unigrams                       18477  \n",
       "Bigrams                       106277  \n",
       "Trigrams                      171768  \n",
       "Classes                            2  \n",
       "Majority class fraction         0.58  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Select the dataset columns in presentation order (col_order is set in the previous code cell).\n",
    "prep_for_printing(df).transpose()[col_order]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Print, for each dataset, the fraction of trigram counts equal to 1 (these values were manually copied into the table)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this path is relative to the working directory, whereas the\n",
    "# CSV in the first cell is read from '../results/' -- confirm which is intended.\n",
    "# Pickle can execute arbitrary code on load; only open trusted result files.\n",
    "with open('results/datasets_ovw.pkl', 'rb') as counts_file:\n",
    "    counts = pkl.load(counts_file)\n",
    "def plot_counts(counts):\n",
    "    \"\"\"Plot a histogram (log-scaled y axis) of the values in ``counts[0]``.\n",
    "\n",
    "    Only the first element of `counts` is used, and it must expose\n",
    "    `.tolist()` (presumably a NumPy array or matrix of per-trigram counts --\n",
    "    confirm against the code that writes the pickle). Draws on the current\n",
    "    pyplot figure and returns None.\n",
    "    \"\"\"\n",
    "    x = np.array(counts[0].tolist()).squeeze()\n",
    "    plt.hist(x, bins=100)\n",
    "    plt.yscale('log')  # log-scale the bin heights\n",
    "    plt.xlabel('Count of occurrences of trigram in training dataset')\n",
    "    plt.ylabel('Count of trigrams')\n",
    "# plot_counts(counts['emotion_trigram'])\n",
    "# For each dataset in the overview index, print the fraction of entries in\n",
    "# its '<dataset>_trigram' count array that equal 1, separated by ' & '.\n",
    "for dset_name in df.index.values:\n",
    "    trigram_counts = np.array(counts[dset_name + '_trigram']).squeeze()\n",
    "    n_singletons = (trigram_counts == 1).sum()\n",
    "    frac_singletons = n_singletons / len(trigram_counts)\n",
    "    print(f'{frac_singletons:0.2f}', end=' & ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('.embgrams': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "d1f8dbe53c838cb92e1131e4b0e751bd3399f29814a546969a59bd338376d193"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
