{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "import argparse\n",
    "import csv\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from minicons import cwe\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from whic_utils import load_whic, pairwise_context, pairwise_direction\n",
    "from whic_torch import WHiCProbe\n",
    "from approximation_model import NonLinearApproximator\n",
    "\n",
    "from sklearn.metrics import f1_score, classification_report\n",
    "\n",
    "from paths import auth1_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_probe = WHiCProbe.load_from_checkpoint(f'{auth1_path}/makesense_logs/whic/whic_probe_original_12.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_probe.eval()\n",
    "for param in trained_probe.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "LAYER = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def directional(probe = trained_probe, layer = LAYER, approximation_mode = 'original'):\n",
    "    test_directional = pairwise_direction('test')\n",
    "\n",
    "    directional = DataLoader(test_directional[0], num_workers = 4, batch_size = 100)\n",
    "    neg_directional = DataLoader(test_directional[1], num_workers = 4, batch_size = 100)\n",
    "\n",
    "    positive = []\n",
    "    for d in directional:\n",
    "        inputs, labels = probe._build_batch(d, approximation_mode=approximation_mode)\n",
    "        predicted = (probe(inputs).squeeze().sigmoid() >= 0.5).int().tolist()\n",
    "        positive.extend(predicted)\n",
    "        \n",
    "    negative = []\n",
    "    for d in neg_directional:\n",
    "        inputs, labels = probe._build_batch(d, approximation_mode=approximation_mode)\n",
    "        predicted = (probe(inputs).squeeze().sigmoid() >= 0.5).int().tolist()\n",
    "        negative.extend(predicted)\n",
    "\n",
    "    return positive, negative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WHiCProbe(\n",
       "  (encoder): Sequential(\n",
       "    (0): Linear(in_features=1536, out_features=256, bias=True)\n",
       "    (1): ReLU()\n",
       "  )\n",
       "  (decoder): Linear(in_features=256, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trained_probe.to('cpu')\n",
    "trained_probe.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "directional_p, directional_n = directional(trained_probe, LAYER, 'original')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6485)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(torch.tensor(directional_p) == torch.tensor(directional_n)).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3396674692630768"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "((torch.tensor(directional_p) == 1) * (torch.tensor(directional_n) == 0)).float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3396674692630768"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "directional_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trained_probe.layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p_te, n_te = pairwise_direction('test')\n",
    "# test = p_te + n_te\n",
    "test = load_whic('test')\n",
    "test_dl = DataLoader(test, batch_size = 128, num_workers = 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(gpus='1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04e6d17189b4469aa64599e834e3687f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'test_f1': 0.7551692724227905, 'test_loss': 0.5825377106666565}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'test_loss': 0.5825377106666565, 'test_f1': 0.7551692724227905}]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.test(trained_probe, test_dl, ckpt_path='best')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}