{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import numpy as np\n",
    "import dataloader\n",
    "from train_classifier import Model\n",
    "import criteria\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader, SequentialSampler, TensorDataset\n",
    "\n",
    "from BERT.tokenization import BertTokenizer\n",
    "from BERT.modeling import BertForSequenceClassification, BertConfig\n",
    "\n",
    "\n",
    "class USE(object):\n",
    "    def __init__(self, cache_path):\n",
    "        super(USE, self).__init__()\n",
    "        os.environ['TFHUB_CACHE_DIR'] = cache_path\n",
    "        module_url = \"https://tfhub.dev/google/universal-sentence-encoder-large/3\"\n",
    "        self.embed = hub.Module(module_url)\n",
    "        config = tf.ConfigProto()\n",
    "        config.gpu_options.allow_growth = True\n",
    "        self.sess = tf.Session(config=config)\n",
    "        self.build_graph()\n",
    "        self.sess.run([tf.global_variables_initializer(), tf.tables_initializer()])\n",
    "\n",
    "    def build_graph(self):\n",
    "        self.sts_input1 = tf.placeholder(tf.string, shape=(None))\n",
    "        self.sts_input2 = tf.placeholder(tf.string, shape=(None))\n",
    "\n",
    "        sts_encode1 = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)\n",
    "        sts_encode2 = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)\n",
    "        self.cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)\n",
    "        clip_cosine_similarities = tf.clip_by_value(self.cosine_similarities, -1.0, 1.0)\n",
    "        self.sim_scores = 1.0 - tf.acos(clip_cosine_similarities)\n",
    "\n",
    "    def semantic_sim(self, sents1, sents2):\n",
    "        scores = self.sess.run(\n",
    "            [self.sim_scores],\n",
    "            feed_dict={\n",
    "                self.sts_input1: sents1,\n",
    "                self.sts_input2: sents2,\n",
    "            })\n",
    "        return scores\n",
    "\n",
    "def pick_most_similar_words_batch(src_words, sim_mat, idx2word, ret_count=10, threshold=0.):\n",
    "    \"\"\"\n",
    "    embeddings is a matrix with (d, vocab_size)\n",
    "    \"\"\"\n",
    "    sim_order = np.argsort(-sim_mat[src_words, :])[:, 1:1 + ret_count]\n",
    "    sim_words, sim_values = [], []\n",
    "    for idx, src_word in enumerate(src_words):\n",
    "        sim_value = sim_mat[src_word][sim_order[idx]]\n",
    "        mask = sim_value >= threshold\n",
    "        sim_word, sim_value = sim_order[idx][mask], sim_value[mask]\n",
    "        sim_word = [idx2word[id] for id in sim_word]\n",
    "        sim_words.append(sim_word)\n",
    "        sim_values.append(sim_value)\n",
    "    return sim_words, sim_values\n",
    "\n",
    "\n",
    "class NLI_infer_BERT(nn.Module):\n",
    "    def __init__(self,\n",
    "                 pretrained_dir,\n",
    "                 nclasses,\n",
    "                 max_seq_length=128,\n",
    "                 batch_size=32):\n",
    "        super(NLI_infer_BERT, self).__init__()\n",
    "        self.model = BertForSequenceClassification.from_pretrained(pretrained_dir, num_labels=nclasses).cuda()\n",
    "        # construct dataset loader\n",
    "        self.dataset = NLIDataset_BERT(pretrained_dir, max_seq_length=max_seq_length, batch_size=batch_size)\n",
    "\n",
    "    def text_pred(self, text_data, batch_size=32):\n",
    "\n",
    "        # transform text data into indices and create batches\n",
    "        dataloader = self.dataset.transform_text(text_data, batch_size=batch_size)\n",
    "        embs = []\n",
    "        #         for input_ids, input_mask, segment_ids in tqdm(dataloader, desc=\"Evaluating\"):\n",
    "        for input_ids, input_mask, segment_ids in dataloader:\n",
    "            input_ids = input_ids.cuda()\n",
    "            input_mask = input_mask.cuda()\n",
    "            segment_ids = segment_ids.cuda()\n",
    "            embs.append(self.model.bert.embeddings(input_ids,torch.zeros_like(input_ids)))\n",
    "        return embs\n",
    "\n",
    "\n",
    "class InputFeatures(object):\n",
    "    \"\"\"A single set of features of data.\"\"\"\n",
    "\n",
    "    def __init__(self, input_ids, input_mask, segment_ids):\n",
    "        self.input_ids = input_ids\n",
    "        self.input_mask = input_mask\n",
    "        self.segment_ids = segment_ids\n",
    "\n",
    "\n",
    "class NLIDataset_BERT(Dataset):\n",
    "    \"\"\"\n",
    "    Dataset class for Natural Language Inference datasets.\n",
    "\n",
    "    The class can be used to read preprocessed datasets where the premises,\n",
    "    hypotheses and labels have been transformed to unique integer indices\n",
    "    (this can be done with the 'preprocess_data' script in the 'scripts'\n",
    "    folder of this repository).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 pretrained_dir,\n",
    "                 max_seq_length=128,\n",
    "                 batch_size=32):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            data: A dictionary containing the preprocessed premises,\n",
    "                hypotheses and labels of some dataset.\n",
    "            padding_idx: An integer indicating the index being used for the\n",
    "                padding token in the preprocessed data. Defaults to 0.\n",
    "            max_premise_length: An integer indicating the maximum length\n",
    "                accepted for the sequences in the premises. If set to None,\n",
    "                the length of the longest premise in 'data' is used.\n",
    "                Defaults to None.\n",
    "            max_hypothesis_length: An integer indicating the maximum length\n",
    "                accepted for the sequences in the hypotheses. If set to None,\n",
    "                the length of the longest hypothesis in 'data' is used.\n",
    "                Defaults to None.\n",
    "        \"\"\"\n",
    "        self.tokenizer = BertTokenizer.from_pretrained(pretrained_dir, do_lower_case=True)\n",
    "        self.max_seq_length = max_seq_length\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def convert_examples_to_features(self, examples, max_seq_length, tokenizer):\n",
    "        \"\"\"Loads a data file into a list of `InputBatch`s.\"\"\"\n",
    "\n",
    "        features = []\n",
    "        for (ex_index, text_a) in enumerate(examples):\n",
    "            tokens_a = tokenizer.tokenize(' '.join(text_a))\n",
    "\n",
    "            # Account for [CLS] and [SEP] with \"- 2\"\n",
    "            if len(tokens_a) > max_seq_length - 2:\n",
    "                tokens_a = tokens_a[:(max_seq_length - 2)]\n",
    "\n",
    "            tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
    "            segment_ids = [0] * len(tokens)\n",
    "\n",
    "            input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "            # The mask has 1 for real tokens and 0 for padding tokens. Only real\n",
    "            # tokens are attended to.\n",
    "            input_mask = [1] * len(input_ids)\n",
    "\n",
    "            # Zero-pad up to the sequence length.\n",
    "            padding = [0] * (max_seq_length - len(input_ids))\n",
    "            input_ids += padding\n",
    "            input_mask += padding\n",
    "            segment_ids += padding\n",
    "\n",
    "            assert len(input_ids) == max_seq_length\n",
    "            assert len(input_mask) == max_seq_length\n",
    "            assert len(segment_ids) == max_seq_length\n",
    "\n",
    "            features.append(\n",
    "                InputFeatures(input_ids=input_ids,\n",
    "                              input_mask=input_mask,\n",
    "                              segment_ids=segment_ids))\n",
    "        return features\n",
    "\n",
    "    def transform_text(self, data, batch_size=32):\n",
    "        # transform data into seq of embeddings\n",
    "        eval_features = self.convert_examples_to_features(data,\n",
    "                                                          self.max_seq_length, self.tokenizer)\n",
    "\n",
    "        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)\n",
    "        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)\n",
    "        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)\n",
    "        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)\n",
    "\n",
    "        # Run prediction for full data\n",
    "        eval_sampler = SequentialSampler(eval_data)\n",
    "        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)\n",
    "\n",
    "        return eval_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "source = \"ag\"\n",
    "texts, labels = dataloader.read_corpus(\"adv_data/imdb_\" + source)\n",
    "data = list(zip(texts, labels))\n",
    "model = NLI_infer_BERT(\"checkpoints/imdb\", nclasses = 2, max_seq_length=256)\n",
    "embed = model.text_pred\n",
    "f1_emb = []\n",
    "for idx, (text, true_label) in enumerate(data):\n",
    "    f1_emb.append(embed([text])[0].data.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts, labels = dataloader.read_corpus(\"adv_data/imdb_imdb\")\n",
    "data = list(zip(texts, labels))\n",
    "model = NLI_infer_BERT(\"checkpoints/imdb\", nclasses = 2, max_seq_length=256)\n",
    "embed = model.text_pred\n",
    "f2_emb = []\n",
    "for idx, (text, true_label) in enumerate(data):\n",
    "    f2_emb.append(embed([text])[0].data.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts, labels = dataloader.read_corpus(\"adv_data/imdb\")\n",
    "data = list(zip(texts, labels))\n",
    "model = NLI_infer_BERT(\"checkpoints/imdb\", nclasses = 2, max_seq_length=256)\n",
    "embed = model.text_pred\n",
    "emb = []\n",
    "for idx, (text, true_label) in enumerate(data):\n",
    "    emb.append(embed([text])[0].data.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 256, 768)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4682010492243438\n"
     ]
    }
   ],
   "source": [
    "tau1 = 0\n",
    "for i in range(len(f1_emb)):\n",
    "    f1_x = f1_emb[i][0]\n",
    "    f2_x = f2_emb[i][0]\n",
    "    x = emb[i][0][0]\n",
    "    delta1 = (f1_x - x).flatten()\n",
    "    delta2 = (f2_x - x).flatten()\n",
    "    tau = (np.dot(delta1, delta2))**2/ (np.linalg.norm(delta1) * np.linalg.norm(delta2))**2\n",
    "    tau1+= tau\n",
    "print(tau1/len(f1_emb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
