{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "from torch.utils.data import Dataset, DataLoader"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# Setting seed for reproducibility\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Setting device \n",
    "if torch.cuda.is_available():  # for nvidia GPUs etc.\n",
    "    device = torch.device('cuda')\n",
    "elif torch.backends.mps.is_available(): # for Apple Metal Performance Sharder (mps) GPUs\n",
    "    device = torch.device('mps')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "device"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4d81b2911604147b",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Prepare data"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "37f98cd1491b20d"
  },
  {
   "cell_type": "code",
   "source": [
    "# Melting into cue-resp df\n",
    "swow = pd.read_csv('../../data/embeds_train/SWOW-EN.R100.csv', usecols=['cue', 'R1', 'R2', 'R3'])\n",
    "swow = (\n",
    "    swow.melt(id_vars='cue', value_vars=['R1', 'R2', 'R3'], value_name='resp')\n",
    "    .drop(columns=['variable']).dropna(axis=0).astype(str)\n",
    "    .sample(frac=1, random_state=42).reset_index(drop=True)\n",
    ")\n",
    "\n",
    "swow"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "54324c33162d938c",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# Dropping resps with <5 occurrences\n",
    "print(len(swow.resp.unique()))\n",
    "resp_counts = swow.resp.value_counts().to_dict()\n",
    "swow = swow[swow.resp.map(lambda x: resp_counts[x] >= 5)]\n",
    "print(len(swow.resp.unique()))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "491baadefd38a6ca",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "class SWOWDat(Dataset):\n",
    "\n",
    "    def __init__(self, swow):\n",
    "\n",
    "        # Converting words to indices \n",
    "        self.cue_idxs = {cue: idx for idx, cue in enumerate(swow['cue'].unique())}\n",
    "        self.resp_idxs = {resp: idx for idx, resp in enumerate(swow['resp'].unique())}\n",
    "        self.n_cues, self.n_resps = len(self.cue_idxs), len(self.resp_idxs)\n",
    "        swow['cue'] = swow['cue'].map(self.cue_idxs)\n",
    "        swow['resp'] = swow['resp'].map(self.resp_idxs)\n",
    "\n",
    "        self.x = torch.tensor(swow['cue'].to_numpy())\n",
    "        self.y = torch.tensor(swow['resp'].to_numpy())\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.x[idx], self.y[idx] # CrossEntropyLoss is more efficient with target with class indices\n",
    "\n",
    "\n",
    "n_resps = len(swow['resp'].unique())\n",
    "swow_dat = SWOWDat(swow)\n",
    "swow_dataloader = DataLoader(swow_dat, batch_size=64, shuffle=True)"
   ],
   "metadata": {
    "collapsed": false,
    "is_executing": true
   },
   "id": "da84f5345135ebb8",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Training"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "872b40b218dfb29a"
  },
  {
   "cell_type": "code",
   "source": [
    "class Word2Vec(torch.nn.Module):\n",
    "    def __init__(self, n_cues, n_resps, n_dims):\n",
    "        super(Word2Vec, self).__init__()\n",
    "        self.cue_embeds = torch.nn.Embedding(n_cues, n_dims)\n",
    "        self.resp_embeds = torch.nn.Linear(n_dims, n_resps, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        cue_embed = self.cue_embeds(x)\n",
    "        logits = self.resp_embeds(cue_embed)\n",
    "        return logits\n",
    "\n",
    "w2v = Word2Vec(swow_dat.n_cues, swow_dat.n_resps, 300).to(device)\n",
    "print(w2v)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "efbe0a2969f81fda",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.0)\n",
    "optimizer = torch.optim.Adam(w2v.parameters())\n",
    "\n",
    "def train_loop(dataloader, model, loss_fn, optimizer):\n",
    "  size = len(dataloader.dataset)\n",
    "\n",
    "  for batch_idx, (X, y) in tqdm(enumerate(dataloader), total=len(dataloader)):\n",
    "    \n",
    "    # Compute prediction and loss\n",
    "    X, y = X.to(device), y.to(device)\n",
    "    pred = model(X)\n",
    "    loss = loss_fn(pred, y)\n",
    "\n",
    "    # Backpropogation \n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if batch_idx % 10000 == 0:\n",
    "        loss, current = loss.item(), batch_idx * len(X)\n",
    "        print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")\n",
    "\n",
    "epochs = 10\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_loop(swow_dataloader, w2v, loss_fn, optimizer)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f63871b8f06c4e39",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# extract input embeddings\n",
    "input_embeds = w2v.cue_embeds.weight.cpu().detach().numpy()\n",
    "input_embeds = pd.DataFrame(input_embeds, index=swow_dat.cue_idxs.keys())\n",
    "\n",
    "# extract output embeddings\n",
    "output_embeds = w2v.resp_embeds.weight.cpu().detach().numpy()\n",
    "output_embeds = pd.DataFrame(output_embeds, index=swow_dat.resp_idxs.keys())\n",
    "\n",
    "# Subsetting to only the words in psychNorms norms\n",
    "to_pull = set(\n",
    "    pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, low_memory=False, compression='zip').index\n",
    ")\n",
    "input_embeds = input_embeds.loc[input_embeds.index.isin(to_pull)].astype(float)\n",
    "output_embeds = output_embeds.loc[output_embeds.index.isin(to_pull)].astype(float)\n",
    "\n",
    "# Saving the embeddings\n",
    "input_embeds.to_csv('../../data/embeds/SGSoftMaxInput_SWOW.csv')\n",
    "output_embeds.to_csv('../../data/embeds/SGSoftMaxOutput_SWOW.csv')"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e58162670115092d",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
