{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "25dc1cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections, os, random, sys, math, email\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import torch\n",
    "\n",
    "from tqdm import tqdm\n",
    "from collections import Counter\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn import metrics\n",
    "\n",
    "from data_utils import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# https://amunategui.github.io/office-automation-part2/index.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b08abebf",
   "metadata": {},
   "source": [
    "In this notebook, we process `emails.csv` downloaded from Kaggle, train a logistic regression model on part of the data, and then use the model to generate conformal scores for the rest of the data.\n",
    "\n",
    "*0) Preprocessing*: We extract the sender names from the file names. We filter out senders who sent fewer than `min_emails` emails. We then map the remaining sender names to integer labels. We process the emails by extracting the text of the subject and body and concatenating them. We then map each email to a vector using the method selected by `vectorization_method`. With the Bag of Words model (`'BoW'`): first, we create a dictionary of the `vocab_size` most common words in the email corpus. Each word is mapped to an integer index. Each email is then transformed into a `vocab_size`-length array that contains the frequencies of each word (normalized to sum to 1). The `'glove'` option instead averages the GloVe embeddings of the words in each email, and the `'bert'` option encodes each email with Sentence-BERT. \n",
    "\n",
    "*1) Training logistic regression model*: We use `n_train` examples PER CLASS (aka sender) for training the model and `n_val` for computing the accuracy\n",
    "\n",
    "*2) Computing conformal scores*: We use the logistic regression model to compute softmax scores for the examples not used to train and validate the model (although maybe it's actually okay to use the validation data too?)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a2de5b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorization_method = 'bert' # 'BoW', 'glove', or 'bert'\n",
    "\n",
    "extract_extra_features = False # Whether to add in features of cts of '!', '?', and phone #s\n",
    "\n",
    "EMAILS_PATH = '/home/ANONYMIZED/code/class-conditional-conformal-datasets/data/emails.csv'\n",
    "GLOVE_DATASET_PATH = '/home/ANONYMIZED/code/class-conditional-conformal-datasets/data/glove.6B.300d.txt' # GloVE embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50c5fd5",
   "metadata": {},
   "source": [
    "## 0) Prepare data\n",
    "\n",
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2e0e12e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "emails_df = pd.read_csv(EMAILS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "55167097",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>file</th>\n",
       "      <th>message</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>allen-p/_sent_mail/1.</td>\n",
       "      <td>Message-ID: &lt;18782981.1075855378110.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>allen-p/_sent_mail/10.</td>\n",
       "      <td>Message-ID: &lt;15464986.1075855378456.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>allen-p/_sent_mail/100.</td>\n",
       "      <td>Message-ID: &lt;24216240.1075855687451.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>allen-p/_sent_mail/1000.</td>\n",
       "      <td>Message-ID: &lt;13505866.1075863688222.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>allen-p/_sent_mail/1001.</td>\n",
       "      <td>Message-ID: &lt;30922949.1075863688243.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>517396</th>\n",
       "      <td>zufferli-j/sent_items/95.</td>\n",
       "      <td>Message-ID: &lt;26807948.1075842029936.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>517397</th>\n",
       "      <td>zufferli-j/sent_items/96.</td>\n",
       "      <td>Message-ID: &lt;25835861.1075842029959.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>517398</th>\n",
       "      <td>zufferli-j/sent_items/97.</td>\n",
       "      <td>Message-ID: &lt;28979867.1075842029988.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>517399</th>\n",
       "      <td>zufferli-j/sent_items/98.</td>\n",
       "      <td>Message-ID: &lt;22052556.1075842030013.JavaMail.e...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>517400</th>\n",
       "      <td>zufferli-j/sent_items/99.</td>\n",
       "      <td>Message-ID: &lt;28618979.1075842030037.JavaMail.e...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>517401 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                             file  \\\n",
       "0           allen-p/_sent_mail/1.   \n",
       "1          allen-p/_sent_mail/10.   \n",
       "2         allen-p/_sent_mail/100.   \n",
       "3        allen-p/_sent_mail/1000.   \n",
       "4        allen-p/_sent_mail/1001.   \n",
       "...                           ...   \n",
       "517396  zufferli-j/sent_items/95.   \n",
       "517397  zufferli-j/sent_items/96.   \n",
       "517398  zufferli-j/sent_items/97.   \n",
       "517399  zufferli-j/sent_items/98.   \n",
       "517400  zufferli-j/sent_items/99.   \n",
       "\n",
       "                                                  message  \n",
       "0       Message-ID: <18782981.1075855378110.JavaMail.e...  \n",
       "1       Message-ID: <15464986.1075855378456.JavaMail.e...  \n",
       "2       Message-ID: <24216240.1075855687451.JavaMail.e...  \n",
       "3       Message-ID: <13505866.1075863688222.JavaMail.e...  \n",
       "4       Message-ID: <30922949.1075863688243.JavaMail.e...  \n",
       "...                                                   ...  \n",
       "517396  Message-ID: <26807948.1075842029936.JavaMail.e...  \n",
       "517397  Message-ID: <25835861.1075842029959.JavaMail.e...  \n",
       "517398  Message-ID: <28979867.1075842029988.JavaMail.e...  \n",
       "517399  Message-ID: <22052556.1075842030013.JavaMail.e...  \n",
       "517400  Message-ID: <28618979.1075842030037.JavaMail.e...  \n",
       "\n",
       "[517401 rows x 2 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emails_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41408193",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The sender is the top-level folder of each path in the `file` column\n",
    "sent_by = np.array([path.partition(\"/\")[0] for path in emails_df[\"file\"]])\n",
    "# Raw message strings, aligned row-for-row with sent_by\n",
    "messages = emails_df[\"message\"].to_numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a938763",
   "metadata": {},
   "source": [
    "#### Remove rare classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71a7f675",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('kaminski-v', 28465),\n",
       " ('dasovich-j', 28234),\n",
       " ('kean-s', 25351),\n",
       " ('mann-k', 23381),\n",
       " ('jones-t', 19950),\n",
       " ('shackleton-s', 18687),\n",
       " ('taylor-m', 13875),\n",
       " ('farmer-d', 13032),\n",
       " ('germany-c', 12436),\n",
       " ('beck-s', 11830),\n",
       " ('symes-k', 10827),\n",
       " ('nemec-g', 10655),\n",
       " ('scott-s', 8022),\n",
       " ('rogers-b', 8009),\n",
       " ('bass-e', 7823),\n",
       " ('sanders-r', 7329),\n",
       " ('campbell-l', 6490),\n",
       " ('shapiro-r', 6071),\n",
       " ('guzman-m', 6054),\n",
       " ('lay-k', 5937),\n",
       " ('lenhart-m', 5920),\n",
       " ('lokay-m', 5568),\n",
       " ('kitchen-l', 5546),\n",
       " ('haedicke-m', 5246),\n",
       " ('sager-e', 5200),\n",
       " ('love-p', 5002),\n",
       " ('arnold-j', 4898),\n",
       " ('fossum-d', 4796),\n",
       " ('perlingiere-d', 4778),\n",
       " ('lavorato-j', 4685),\n",
       " ('mcconnell-m', 4542),\n",
       " ('giron-d', 4220),\n",
       " ('skilling-j', 4139),\n",
       " ('shankman-j', 3856),\n",
       " ('hain-m', 3820),\n",
       " ('delainey-d', 3566),\n",
       " ('williams-w3', 3440),\n",
       " ('blair-l', 3415),\n",
       " ('mclaughlin-e', 3353),\n",
       " ('whalley-l', 3335),\n",
       " ('steffes-j', 3331),\n",
       " ('white-s', 3272),\n",
       " ('neal-s', 3268),\n",
       " ('hernandez-j', 3265),\n",
       " ('hyvl-d', 3210),\n",
       " ('allen-p', 3034),\n",
       " ('stclair-c', 3030),\n",
       " ('griffith-j', 2973),\n",
       " ('cash-m', 2969),\n",
       " ('watson-k', 2950),\n",
       " ('linder-e', 2805),\n",
       " ('rodrique-r', 2766),\n",
       " ('baughman-d', 2760),\n",
       " ('ward-k', 2611),\n",
       " ('hayslett-r', 2554),\n",
       " ('horton-s', 2470),\n",
       " ('buy-r', 2429),\n",
       " ('dean-c', 2429),\n",
       " ('parks-j', 2284),\n",
       " ('davis-d', 2249),\n",
       " ('grigsby-m', 2237),\n",
       " ('presto-k', 2204),\n",
       " ('lewis-a', 2191),\n",
       " ('keavey-p', 2177),\n",
       " ('dorland-c', 2127),\n",
       " ('mims-thurston-p', 2038),\n",
       " ('corman-s', 2025),\n",
       " ('maggi-m', 1991),\n",
       " ('shively-h', 1991),\n",
       " ('tholt-j', 1885),\n",
       " ('whalley-g', 1878),\n",
       " ('schoolcraft-d', 1859),\n",
       " ('hyatt-k', 1794),\n",
       " ('derrick-j', 1766),\n",
       " ('hodge-j', 1661),\n",
       " ('ruscitti-k', 1643),\n",
       " ('smith-m', 1642),\n",
       " ('salisbury-h', 1632),\n",
       " ('merriss-s', 1627),\n",
       " ('heard-m', 1623),\n",
       " ('may-l', 1600),\n",
       " ('geaccone-t', 1592),\n",
       " ('wolfe-j', 1587),\n",
       " ('quigley-d', 1568),\n",
       " ('weldon-c', 1566),\n",
       " ('zipper-a', 1563),\n",
       " ('fischer-m', 1498),\n",
       " ('gay-r', 1415),\n",
       " ('carson-m', 1400),\n",
       " ('thomas-p', 1293),\n",
       " ('ybarbo-p', 1291),\n",
       " ('stokley-c', 1252),\n",
       " ('ermis-f', 1230),\n",
       " ('stepenovitch-j', 1227),\n",
       " ('tycholiz-b', 1219),\n",
       " ('williams-j', 1213),\n",
       " ('sturm-f', 1169),\n",
       " ('lokey-t', 1156),\n",
       " ('kuykendall-t', 1120),\n",
       " ('saibi-e', 1116),\n",
       " ('keiser-k', 1113),\n",
       " ('martin-t', 1112),\n",
       " ('meyers-a', 1099),\n",
       " ('solberg-g', 1081),\n",
       " ('donoho-l', 1045),\n",
       " ('cuilla-m', 1029),\n",
       " ('storey-g', 1027),\n",
       " ('brawner-s', 1026),\n",
       " ('donohoe-t', 1015),\n",
       " ('mckay-j', 998),\n",
       " ('lucci-p', 997),\n",
       " ('ring-r', 994),\n",
       " ('causholli-m', 943),\n",
       " ('badeer-r', 877),\n",
       " ('whitt-m', 807),\n",
       " ('benson-r', 767),\n",
       " ('schwieger-j', 738),\n",
       " ('forney-j', 729),\n",
       " ('pereira-s', 725),\n",
       " ('semperger-c', 721),\n",
       " ('hendrickson-s', 719),\n",
       " ('ring-a', 706),\n",
       " ('mccarty-d', 691),\n",
       " ('mckay-b', 681),\n",
       " ('arora-h', 654),\n",
       " ('scholtes-d', 647),\n",
       " ('townsend-j', 646),\n",
       " ('pimenov-v', 642),\n",
       " ('staab-t', 621),\n",
       " ('gang-l', 590),\n",
       " ('richey-c', 582),\n",
       " ('gilbertsmith-d', 578),\n",
       " ('platter-p', 574),\n",
       " ('rapp-b', 563),\n",
       " ('zufferli-j', 557),\n",
       " ('harris-s', 548),\n",
       " ('crandell-s', 519),\n",
       " ('reitmeyer-j', 498),\n",
       " ('bailey-s', 478),\n",
       " ('holst-k', 463),\n",
       " ('king-j', 462),\n",
       " ('panus-s', 437),\n",
       " ('dickson-s', 395),\n",
       " ('quenet-j', 395),\n",
       " ('motley-m', 378),\n",
       " ('swerzbin-m', 355),\n",
       " ('sanchez-m', 256),\n",
       " ('south-s', 248),\n",
       " ('slinger-r', 132),\n",
       " ('phanis-s', 35)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the number of emails sent by each person. \n",
    "email_cts = Counter(sent_by).most_common()\n",
    "email_cts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "84dc3ce3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Excluding 41 senders who have sent fewer than 1000 emails each\n",
      "Excluding 24346 total emails\n"
     ]
    }
   ],
   "source": [
    "# Drop every sender who has sent fewer than `min_emails` emails\n",
    "min_emails = 1000\n",
    "excluded_senders = [name for name, ct in email_cts if ct < min_emails]\n",
    "\n",
    "print(f'Excluding {len(excluded_senders)} senders who have sent fewer than {min_emails} emails each')\n",
    "\n",
    "# Vectorized membership test over the whole array, instead of scanning the\n",
    "# excluded_senders list once per email in a Python loop.\n",
    "keep_idx = ~np.isin(sent_by, excluded_senders)\n",
    "\n",
    "print(f'Excluding {(~keep_idx).sum()} total emails')\n",
    "# sent_by and messages are filtered with the same mask so they stay aligned\n",
    "sent_by = sent_by[keep_idx]\n",
    "messages = messages[keep_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c38149a8",
   "metadata": {},
   "source": [
    "#### Clean data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0192bf15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode each sender name as an integer class label; `le` keeps the\n",
    "# name <-> integer mapping.\n",
    "le = LabelEncoder().fit(sent_by)\n",
    "labels = le.transform(sent_by)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dbefb08b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_message(message):\n",
    "    subject = message.split(\"Subject:\")[1].split(\"\\nMime\")[0]\n",
    "    txt = \" \".join([subject] + message.split(\"\\n\\n\")[1:]) # Join subject and body text\n",
    "    \n",
    "    txt = txt.lower() # Make everything lowercase\n",
    "    txt = re.sub(\"[^a-z']\", ' ', txt) # Get rid of any character that is not a letter or apostrophe\n",
    "    txt = re.sub('\\s+', ' ', txt) # Replace consecutive spaces with a single space\n",
    "    \n",
    "    return txt\n",
    "    \n",
    "# Clean all the messages\n",
    "clean_messages = [clean_message(m) for m in messages]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4b5ba4ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded precomputed BERT embeddings.\n",
      "CPU times: user 20.8 ms, sys: 947 ms, total: 968 ms\n",
      "Wall time: 2.66 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "def vectorize(messages, method):\n",
    "    if method == 'BoW': # Apply standard Bag of Words to convert text into vectors of word counts\n",
    "        vocab_size = 10000 # Can increase this\n",
    "\n",
    "        vectorizer = CountVectorizer(max_features=vocab_size)\n",
    "        token_cts = vectorizer.fit_transform(clean_messages)\n",
    "        \n",
    "        # Convert cts to frequencies\n",
    "        freqs = token_cts / token_cts.sum(axis=1)\n",
    "        \n",
    "        freqs = np.asarray(freqs)\n",
    "        freqs = np.nan_to_num(freqs) # Replace NaNs with 0\n",
    "        \n",
    "        return freqs\n",
    "    elif method == 'glove': # Get GloVe embeddings for each word in a message then compute average\n",
    "        \n",
    "        # Load GloVE embeddings\n",
    "        all_txt = ' '.join(messages)\n",
    "        list_of_words = set(all_txt.split())\n",
    "        embeddings_index = {}\n",
    "        f = open(GLOVE_DATASET_PATH)\n",
    "        word_counter = 0\n",
    "        for line in tqdm(f):\n",
    "            values = line.split()\n",
    "            word = values[0]\n",
    "            if word in list_of_words:\n",
    "                coefs = np.asarray(values[1:], dtype='float32')\n",
    "                embeddings_index[word] = coefs\n",
    "            word_counter += 1\n",
    "        f.close()\n",
    "        \n",
    "        embedding_size = len(values[1:])\n",
    "\n",
    "        # Loop through messages. For each word in message, map the word to its GloVE embedding. \n",
    "        # Then take average over words.\n",
    "        vectorized_msgs = np.zeros((len(messages), embedding_size))\n",
    "        for i, msg in enumerate(messages):\n",
    "            words = msg.split()\n",
    "            vec = np.zeros((embedding_size,))\n",
    "            ct = 0\n",
    "            for word in words: \n",
    "                if word in embeddings_index:\n",
    "                    vec += embeddings_index[word]\n",
    "                    ct += 1\n",
    "            if ct > 0:\n",
    "                vec /= ct\n",
    "            vectorized_msgs[i,:] = vec\n",
    "            \n",
    "        return vectorized_msgs\n",
    "    elif method == 'bert': # Encode text using Sentence-BERT (Reimers & Gurevych, 2019)\n",
    "        try:\n",
    "            vectorized_msgs = np.load('.cache/email_BERT_embeddings.npy')\n",
    "            print('Loaded precomputed BERT embeddings.')\n",
    "        except:\n",
    "            model = SentenceTransformer('all-MiniLM-L6-v2')\n",
    "            vectorized_msgs = model.encode(messages)\n",
    "            os.makedirs(\".cache\", exist_ok=True)\n",
    "            np.save('.cache/email_BERT_embeddings.npy', vectorized_msgs)\n",
    "            print('Saved BERT embeddings to .cache/email_BERT_embeddings.npy')\n",
    "        return vectorized_msgs\n",
    "    else:\n",
    "        raise Exception('Invalid vectorize() method')\n",
    "        \n",
    "vecs = vectorize(clean_messages, vectorization_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0d37742f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of classes: 109\n"
     ]
    }
   ],
   "source": [
    "num_classes = np.max(labels) + 1\n",
    "print('Number of classes:', num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68ccc9a3",
   "metadata": {},
   "source": [
    "### [Optional] Add in some additional features\n",
    "1. Count of number of '?'s\n",
    "2. Count of number of '!'s\n",
    "3. Count of number of phone numbers formatted as xxx.xxx.xxxx (or with '-' or ' ' in place of '.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5ebe3689",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scale_by_max(counts):\n",
    "    \"\"\"Divide non-negative counts by their maximum; an all-zero input stays all zero.\"\"\"\n",
    "    counts = np.asarray(counts)\n",
    "    # max(..., 1) avoids 0/0 = NaN when no message contains the pattern\n",
    "    return counts / max(counts.max(), 1)\n",
    "\n",
    "\n",
    "if extract_extra_features:\n",
    "    # Raw string so the regex escapes (\\d, \\., \\s) are not read as Python string escapes\n",
    "    phone_pattern = re.compile(r'\\d{3}[-\\.\\s]\\d{3}[-\\.\\s]\\d{4}')\n",
    "\n",
    "    # Count on the raw (uncleaned) messages, then normalize each feature by its max\n",
    "    qmark_ct = _scale_by_max([m.count('?') for m in messages])\n",
    "    emark_ct = _scale_by_max([m.count('!') for m in messages])\n",
    "    phone_number_ct = _scale_by_max([len(phone_pattern.findall(m)) for m in messages])\n",
    "\n",
    "    # Concatenate horizontally with vecs.\n",
    "    # NOTE: this overwrites `vecs`, so re-running this cell appends the three\n",
    "    # columns again; re-run the vectorize cell first.\n",
    "    vecs = np.concatenate((vecs,\n",
    "                           qmark_ct[:, np.newaxis],\n",
    "                           emark_ct[:, np.newaxis],\n",
    "                           phone_number_ct[:, np.newaxis]), axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9571aad",
   "metadata": {},
   "source": [
    "## 1) Train logistic regression model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7c878f26",
   "metadata": {},
   "outputs": [],
   "source": [
    "email_cts = Counter(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "060ee35b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(42, 28465),\n",
       " (13, 28234),\n",
       " (43, 25351),\n",
       " (57, 23381),\n",
       " (41, 19950),\n",
       " (80, 18687),\n",
       " (94, 13875),\n",
       " (22, 13032),\n",
       " (27, 12436),\n",
       " (4, 11830),\n",
       " (93, 10827),\n",
       " (66, 10655),\n",
       " (79, 8022),\n",
       " (72, 8009),\n",
       " (2, 7823),\n",
       " (77, 7329),\n",
       " (8, 6490),\n",
       " (82, 6071),\n",
       " (31, 6054),\n",
       " (49, 5937),\n",
       " (50, 5920),\n",
       " (53, 5568),\n",
       " (46, 5546),\n",
       " (32, 5246),\n",
       " (74, 5200),\n",
       " (55, 5002),\n",
       " (1, 4898),\n",
       " (24, 4796),\n",
       " (68, 4778),\n",
       " (48, 4685),\n",
       " (60, 4542),\n",
       " (28, 4220),\n",
       " (84, 4139),\n",
       " (81, 3856),\n",
       " (33, 3820),\n",
       " (16, 3566),\n",
       " (105, 3440),\n",
       " (5, 3415),\n",
       " (61, 3353),\n",
       " (102, 3335),\n",
       " (88, 3331),\n",
       " (103, 3272),\n",
       " (65, 3268),\n",
       " (36, 3265),\n",
       " (40, 3210),\n",
       " (0, 3034),\n",
       " (87, 3030),\n",
       " (29, 2973),\n",
       " (10, 2969),\n",
       " (99, 2950),\n",
       " (52, 2805),\n",
       " (71, 2766),\n",
       " (3, 2760),\n",
       " (98, 2611),\n",
       " (34, 2554),\n",
       " (38, 2470),\n",
       " (7, 2429),\n",
       " (15, 2429),\n",
       " (67, 2284),\n",
       " (14, 2249),\n",
       " (30, 2237),\n",
       " (69, 2204),\n",
       " (51, 2191),\n",
       " (44, 2177),\n",
       " (20, 2127),\n",
       " (64, 2038),\n",
       " (11, 2025),\n",
       " (56, 1991),\n",
       " (83, 1991),\n",
       " (95, 1885),\n",
       " (101, 1878),\n",
       " (78, 1859),\n",
       " (39, 1794),\n",
       " (17, 1766),\n",
       " (37, 1661),\n",
       " (73, 1643),\n",
       " (85, 1642),\n",
       " (76, 1632),\n",
       " (62, 1627),\n",
       " (35, 1623),\n",
       " (59, 1600),\n",
       " (26, 1592),\n",
       " (106, 1587),\n",
       " (70, 1568),\n",
       " (100, 1566),\n",
       " (108, 1563),\n",
       " (23, 1498),\n",
       " (25, 1415),\n",
       " (9, 1400),\n",
       " (96, 1293),\n",
       " (107, 1291),\n",
       " (90, 1252),\n",
       " (21, 1230),\n",
       " (89, 1227),\n",
       " (97, 1219),\n",
       " (104, 1213),\n",
       " (92, 1169),\n",
       " (54, 1156),\n",
       " (47, 1120),\n",
       " (75, 1116),\n",
       " (45, 1113),\n",
       " (58, 1112),\n",
       " (63, 1099),\n",
       " (86, 1081),\n",
       " (18, 1045),\n",
       " (12, 1029),\n",
       " (91, 1027),\n",
       " (6, 1026),\n",
       " (19, 1015)]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "email_cts.most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d0c06a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class sample sizes\n",
    "n_train = 500  # examples PER CLASS used to train the model\n",
    "n_val = 50  # only referenced by the disabled two-stage split below\n",
    "\n",
    "X, y = vecs, labels\n",
    "\n",
    "# Disabled alternative: hold out n_train + n_val per class, then split that into train / val\n",
    "# X_1, y_1, X_2, y_2 = split_X_and_y(X, y, n_train + n_val, num_classes, seed=0)\n",
    "# X_train, y_train, X_val, y_val = split_X_and_y(X_1, y_1, n_train, num_classes, seed=0)\n",
    "\n",
    "# Active split: validate on all data not used to train.\n",
    "# (split_X_and_y comes from data_utils; it is expected to put n_train examples\n",
    "# per class in the train set and the remainder in the second set.)\n",
    "X_train, y_train, X_val, y_val = split_X_and_y(X, y, n_train, num_classes, seed=0)\n",
    "\n",
    "# The same held-out set is later scored with predict_proba and saved\n",
    "X_2, y_2 = X_val, y_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "75e3ba71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LogisticRegression Accuracy: 0.376\n"
     ]
    }
   ],
   "source": [
    "# Fit a logistic regression classifier on the training split\n",
    "lr = LogisticRegression(max_iter=300).fit(X_train, y_train)\n",
    "\n",
    "# Report accuracy on the held-out split\n",
    "y_predict = lr.predict(X_val)\n",
    "val_accuracy = metrics.accuracy_score(y_val, y_predict)\n",
    "print(f\"LogisticRegression Accuracy: {val_accuracy:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bf57274",
   "metadata": {},
   "source": [
    "## BoW\n",
    "\n",
    "`vocab_size=10000`\n",
    "\n",
    "141 senders\n",
    "\n",
    "* With 100 train and 100 val, accuracy is 13.7%\n",
    "* With 250 train and 100 val, accuracy is 17.9%\n",
    "\n",
    "108 senders\n",
    "\n",
    "* With 500 train and 100 val, accuracy is 24.4%\n",
    "* With 800 train and 50 val, accuracy is 28.1%\n",
    "\n",
    "`vocab_size=20000` (increasing vocab size doesn't seem to help)\n",
    "\n",
    "108 senders\n",
    "\n",
    "* With 500 train and 100 val, accuracy is 24.2%\n",
    "* With 500 train and 50 val, accuracy is 26.2% (noise)\n",
    "\n",
    "## GloVE\n",
    "\n",
    "108 senders\n",
    "\n",
    "* With 800 train and 50 val, accuracy is 39.7%\n",
    "\n",
    "## BERT\n",
    "\n",
    "* [No extra features] With 800 train and 50 val, accuracy is 37.1%\n",
    "* [No extra features] With 500 train and 500 val, accuracy is 37.6%\n",
    "* [Extra features] With 800 train and 50 val, accuracy is 37.2%\n",
    "\n",
    "\n",
    "Conclusion: the extra features are not useful"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85482084",
   "metadata": {},
   "source": [
    "## 2) Get softmax scores\n",
    "\n",
    "These softmax scores are interesting because the signal is weak, so the probabilities are pretty uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f62a8ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "probs = lr.predict_proba(X_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f9ebc0ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(438555, 109)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "937e28cb",
   "metadata": {},
   "source": [
    "#### Save conformal scores and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f7e7e0bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved softmax scores and labels to /home/ANONYMIZED/code/class-conditional-conformal-datasets/notebooks/.cache folder\n"
     ]
    }
   ],
   "source": [
    "save_folder = \".cache\"\n",
    "\n",
    "# exist_ok=True makes this a no-op when the folder already exists, with no\n",
    "# separate check-then-create step\n",
    "os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "# Scores and labels are saved as separate arrays, so they must be row-aligned\n",
    "assert probs.shape[0] == y_2.shape[0], 'softmax scores and labels have different lengths'\n",
    "\n",
    "# File names record the vectorization method and the per-class training size\n",
    "softmax_path = os.path.join(save_folder, f'email_softmax_{vectorization_method}_ntrain={n_train}.npy')\n",
    "labels_path = os.path.join(save_folder, f'email_labels_{vectorization_method}_ntrain={n_train}.npy')\n",
    "np.save(softmax_path, probs)\n",
    "np.save(labels_path, y_2)\n",
    "\n",
    "print(f'Saved softmax scores and labels to {os.getcwd()}/{save_folder} folder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5d2a24c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'email_softmax_bert_ntrain=500.npy'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f'email_softmax_{vectorization_method}_ntrain={n_train}.npy'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "ff9c64c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(438555,)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b6a037d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
