{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('Optimus/code/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import argparse\n",
    "import glob\n",
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "\n",
    "from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig\n",
    "from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector\n",
    "from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer\n",
    "from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer\n",
    "from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer\n",
    "from pytorch_transformers import BertForLatentConnector, BertTokenizer\n",
    "\n",
    "from collections import defaultdict\n",
    "from examples.big_ae.modules import VAE\n",
    "from examples.big_ae.utils import (TextDataset_Split, TextDataset_2Tokenizers, BucketingDataLoader)\n",
    "\n",
    "\n",
    "import pdb\n",
    "\n",
    "\n",
    "logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    "                    datefmt = '%m/%d/%Y %H:%M:%S',\n",
    "                    level = logging.INFO)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop\n",
    "\n",
    "ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())\n",
    "\n",
    "MODEL_CLASSES = {\n",
    "    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),\n",
    "    'bert': (BertConfig, BertForLatentConnector, BertTokenizer)\n",
    "}\n",
    "\n",
    "\n",
    "def set_seed(args):\n",
    "    np.random.seed(args.seed)\n",
    "    torch.manual_seed(args.seed)\n",
    "    if args.n_gpu > 0:\n",
    "        torch.cuda.manual_seed_all(args.seed)\n",
    "\n",
    "\n",
    "def load_and_cache_examples(args, tokenizer, evaluate=False):\n",
    "    if isinstance(tokenizer, list):\n",
    "        dataset = TextDataset_2Tokenizers(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)\n",
    "    else:\n",
    "        dataset = TextDataset_Split(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)\n",
    "    return dataset\n",
    "\n",
    "def build_dataload_and_cache_examples(args, tokenizer, evaluate=False):\n",
    "    if isinstance(tokenizer, list):\n",
    "        if not evaluate:\n",
    "            args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
    "            file_path=args.train_data_file\n",
    "        else:\n",
    "            args.batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)  \n",
    "            file_path=args.eval_data_file\n",
    "        dataloader = BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args, bucket=100, shuffle=False)\n",
    "    else:\n",
    "        pass \n",
    "    return dataloader\n",
    "\n",
    "\n",
    "def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):\n",
    "    \"\"\" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering\n",
    "        Args:\n",
    "            logits: logits distribution shape (vocabulary size)\n",
    "            top_k > 0: keep only top k tokens with highest probability (top-k filtering).\n",
    "            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).\n",
    "                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)\n",
    "        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317\n",
    "    \"\"\"\n",
    "    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear\n",
    "    top_k = min(top_k, logits.size(-1))  # Safety check\n",
    "    if top_k > 0:\n",
    "        # Remove all tokens with a probability less than the last token of the top-k\n",
    "        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]\n",
    "        logits[indices_to_remove] = filter_value\n",
    "\n",
    "    if top_p > 0.0:\n",
    "        sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n",
    "        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n",
    "\n",
    "        # Remove tokens with cumulative probability above the threshold\n",
    "        sorted_indices_to_remove = cumulative_probs > top_p\n",
    "        # Shift the indices to the right to keep also the first token above the threshold\n",
    "        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n",
    "        sorted_indices_to_remove[..., 0] = 0\n",
    "\n",
    "        indices_to_remove = sorted_indices[sorted_indices_to_remove]\n",
    "        logits[indices_to_remove] = filter_value\n",
    "    return logits\n",
    "\n",
    "\n",
    "def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):\n",
    "    \n",
    "    context = torch.tensor(context, dtype=torch.long, device=device)\n",
    "    context = context.unsqueeze(0).repeat(num_samples, 1)\n",
    "    generated = context\n",
    "    with torch.no_grad():\n",
    "        while True:\n",
    "        # for _ in trange(length):\n",
    "            inputs = {'input_ids': generated, 'past': past}\n",
    "#             inputs = {'input_ids': generated}\n",
    "            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)\n",
    "            next_token_logits = outputs[0][0, -1, :] / temperature\n",
    "            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)\n",
    "            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)\n",
    "            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)\n",
    "\n",
    "            # pdb.set_trace()\n",
    "            if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.eos_token_id or generated.size(1)>length:\n",
    "                break\n",
    "\n",
    "    return generated\n",
    "\n",
    "\n",
    "def latent_code_from_text(text, tokenizer_encoder, model_vae, args):\n",
    "    tokenized1 = tokenizer_encoder.encode(text)\n",
    "    tokenized1 = [101] + tokenized1 + [102]\n",
    "    coded1 = torch.Tensor([tokenized1])\n",
    "    coded1 =torch.Tensor.long(coded1)\n",
    "    with torch.no_grad():\n",
    "        x0 = coded1\n",
    "        x0 = x0.to(args.device)\n",
    "        pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]\n",
    "        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)\n",
    "        latent_z = mean.squeeze(1)  \n",
    "        coded_length = len(tokenized1)\n",
    "        return latent_z, coded_length\n",
    "\n",
    "def text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder):\n",
    "    past = latent_z\n",
    "    context_tokens = tokenizer_decoder.encode('<BOS>')\n",
    "\n",
    "    length = 128 # maximum length, but not used \n",
    "    out = sample_sequence_conditional(\n",
    "        model=model_vae.decoder,\n",
    "        context=context_tokens,\n",
    "        past=past,\n",
    "        length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence\n",
    "        temperature=args.temperature,\n",
    "        top_k=args.top_k,\n",
    "        top_p=args.top_p,\n",
    "        device=args.device,\n",
    "        decoder_tokenizer = tokenizer_decoder\n",
    "    )\n",
    "    text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)\n",
    "    text_x1 = text_x1.split()[1:-1]\n",
    "    text_x1 = ' '.join(text_x1)\n",
    "    return text_x1\n",
    "\n",
    "\n",
    "def interpolate(model_vae, tokenizer_encoder, tokenizer_decoder, args):\n",
    "    # and then in the main function         \n",
    "    latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)\n",
    "    latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)\n",
    "\n",
    "    result = defaultdict(str)\n",
    "\n",
    "    num_steps = args.num_interpolation_steps + 1\n",
    "    for step in range(num_steps+1):\n",
    "        latent_z = latent_z1 + (latent_z2 - latent_z1) * step * 1.0/num_steps\n",
    "        \n",
    "        text_interpolate = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)\n",
    "        result[step] = text_interpolate\n",
    "        print(text_interpolate)\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "def analogy(model_vae, tokenizer_encoder, tokenizer_decoder, args):\n",
    "        \n",
    "    latent_z1, coded_length1 = latent_code_from_text(args.sent_source, tokenizer_encoder, model_vae, args)\n",
    "    latent_z2, coded_length2 = latent_code_from_text(args.sent_target, tokenizer_encoder, model_vae, args)\n",
    "    latent_z3, coded_length3 = latent_code_from_text(args.sent_input, tokenizer_encoder, model_vae, args)\n",
    "    \n",
    "    result = defaultdict(str)\n",
    "\n",
    "    latent_z = latent_z3 + args.degree_to_target * (latent_z2 - latent_z1) \n",
    "    \n",
    "    text_analogy = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)\n",
    "    result[0] = text_analogy\n",
    "    print(text_analogy)\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "\n",
    "parser.add_argument(\"--train_data_file\", default=None, type=str, required=True,\n",
    "                    help=\"The input training data file (a text file).\")\n",
    "parser.add_argument(\"--eval_data_file\", default=None, type=str,\n",
    "                    help=\"An input evaluation data file to evaluate the perplexity on (a text file).\")\n",
    "parser.add_argument(\"--checkpoint_dir\", default=None, type=str, required=True,\n",
    "                    help=\"The directory where checkpoints are saved.\")\n",
    "parser.add_argument(\"--output_dir\", default=None, type=str, required=True,\n",
    "                    help=\"The output directory where the model predictions and checkpoints will be written.\")\n",
    "parser.add_argument(\"--dataset\", default='Snli', type=str, help=\"The dataset.\")\n",
    "\n",
    "## Variational auto-encoder\n",
    "parser.add_argument(\"--latent_size\", default=32, type=int, help=\"Latent space dimension.\")\n",
    "parser.add_argument(\"--total_sents\", default=10, type=int, help=\"Total sentences to test recontruction.\")\n",
    "parser.add_argument(\"--num_interpolation_steps\", default=10, type=int, help=\"Total sentences to test recontruction.\")\n",
    "parser.add_argument(\"--play_mode\", default=\"interpolation\", type=str,\n",
    "                    help=\"interpolation or reconstruction.\")\n",
    "\n",
    "\n",
    "## Encoder options\n",
    "parser.add_argument(\"--encoder_model_type\", default=\"bert\", type=str,\n",
    "                    help=\"The encoder model architecture to be fine-tuned.\")\n",
    "parser.add_argument(\"--encoder_model_name_or_path\", default=\"bert-base-cased\", type=str,\n",
    "                    help=\"The encoder model checkpoint for weights initialization.\")\n",
    "parser.add_argument(\"--encoder_config_name\", default=\"\", type=str,\n",
    "                    help=\"Optional pretrained config name or path if not the same as model_name_or_path\")\n",
    "parser.add_argument(\"--encoder_tokenizer_name\", default=\"\", type=str,\n",
    "                    help=\"Optional pretrained tokenizer name or path if not the same as model_name_or_path\")\n",
    "\n",
    "## Decoder options\n",
    "parser.add_argument(\"--decoder_model_type\", default=\"gpt2\", type=str,\n",
    "                    help=\"The decoder model architecture to be fine-tuned.\")\n",
    "parser.add_argument(\"--decoder_model_name_or_path\", default=\"bert-base-cased\", type=str,\n",
    "                    help=\"The decoder model checkpoint for weights initialization.\")\n",
    "parser.add_argument(\"--decoder_config_name\", default=\"\", type=str,\n",
    "                    help=\"Optional pretrained config name or path if not the same as model_name_or_path\")\n",
    "parser.add_argument(\"--decoder_tokenizer_name\", default=\"\", type=str,\n",
    "                    help=\"Optional pretrained tokenizer name or path if not the same as model_name_or_path\")\n",
    "\n",
    "\n",
    "parser.add_argument(\"--per_gpu_train_batch_size\", default=1, type=int,\n",
    "                    help=\"Batch size per GPU/CPU for training.\")\n",
    "parser.add_argument(\"--per_gpu_eval_batch_size\", default=1, type=int,\n",
    "                    help=\"Batch size per GPU/CPU for evaluation.\")\n",
    "parser.add_argument('--gloabl_step_eval', type=int, default=661,\n",
    "                    help=\"Evaluate the results at the given global step\")\n",
    "\n",
    "parser.add_argument(\"--max_seq_length\", default=512, type=int,\n",
    "                    help=\"Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length\")\n",
    "\n",
    "# Interact with users\n",
    "parser.add_argument(\"--interact_with_user_input\", action='store_true', help=\"Use user input to interact_with.\")\n",
    "parser.add_argument(\"--sent_source\", type=str, default=\"\")\n",
    "parser.add_argument(\"--sent_target\", type=str, default=\"\")\n",
    "parser.add_argument(\"--sent_input\", type=str, default=\"\")\n",
    "parser.add_argument(\"--degree_to_target\", type=float, default=\"1.0\")\n",
    "\n",
    "## Variational auto-encoder\n",
    "parser.add_argument(\"--nz\", default=32, type=int,\n",
    "                    help=\"Latent space dimension.\")\n",
    "\n",
    "parser.add_argument(\"--prompt\", type=str, default=\"\")\n",
    "parser.add_argument(\"--padding_text\", type=str, default=\"\")\n",
    "parser.add_argument(\"--length\", type=int, default=20)\n",
    "parser.add_argument(\"--temperature\", type=float, default=1.0)\n",
    "parser.add_argument(\"--top_k\", type=int, default=0)\n",
    "parser.add_argument(\"--top_p\", type=float, default=1.0)\n",
    "parser.add_argument(\"--no_cuda\", action='store_true',\n",
    "                    help=\"Avoid using CUDA when available\")\n",
    "parser.add_argument('--seed', type=int, default=42,\n",
    "                    help=\"random seed for initialization\")\n",
    "\n",
    "parser.add_argument(\"--block_size\", default=-1, type=int,\n",
    "                    help=\"Optional input sequence length after tokenization.\"\n",
    "                         \"The training dataset will be truncated in block of this size for training.\"\n",
    "                         \"Default to the model max input length for single sentence inputs (take into account special tokens).\")\n",
    "parser.add_argument(\"--do_lower_case\", action='store_true',\n",
    "                    help=\"Set this flag if you are using an uncased model.\")\n",
    "\n",
    "parser.add_argument(\"--use_philly\", action='store_true',\n",
    "                    help=\"Use Philly for computing.\")\n",
    "\n",
    "args = parser.parse_args(args=['--dataset', 'Debug',\n",
    "                         '--checkpoint_dir', 'output/LM/wikipedia/beta0.5/checkpoint-508523/',\n",
    "                            '--output_dir','.',\n",
    "                               '--encoder_model_type','bert',\n",
    "                               '--encoder_model_name_or_path','bert-base-cased',\n",
    "                               '--decoder_model_type','gpt2',\n",
    "                               '--decoder_model_name_or_path','gpt2',\n",
    "                               '--train_data_file','data/datasets/debug_data/train.txt',\n",
    "                               '--eval_data_file','data/datasets/debug_data/valid.txt',\n",
    "                               '--per_gpu_eval_batch_size','1',\n",
    "                               '--gloabl_step_eval','508523',\n",
    "                               '--block_size','100',\n",
    "                               '--max_seq_length','100',\n",
    "                               '--latent_size','32',\n",
    "                               '--interact_with_user_input',\n",
    "                               '--play_mode','interpolation',\n",
    "                               '--sent_source','This is an obviously long sentence . ',\n",
    "                               '--sent_target','test',\n",
    "                               '--num_interpolation_steps','10'])\n",
    "\n",
    "args.device = torch.device(\"cuda\" if torch.cuda.is_available() and not args.no_cuda else \"cpu\")\n",
    "args.n_gpu = torch.cuda.device_count()\n",
    "\n",
    "set_seed(args)\n",
    "\n",
    "\n",
    "args.encoder_model_type = args.encoder_model_type.lower()\n",
    "args.decoder_model_type = args.decoder_model_type.lower()\n",
    "\n",
    "\n",
    "global_step = args.gloabl_step_eval\n",
    "\n",
    "output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))\n",
    "output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step)) \n",
    "checkpoints = [ [output_encoder_dir, output_decoder_dir] ]\n",
    "logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
    "\n",
    "# Load a trained Encoder model and vocabulary that you have fine-tuned\n",
    "encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]\n",
    "model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)\n",
    "tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)\n",
    "\n",
    "model_encoder.to(args.device)\n",
    "if args.block_size <= 0:\n",
    "    args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model\n",
    "args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)\n",
    "\n",
    "# Load a trained Decoder model and vocabulary that you have fine-tuned\n",
    "decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]\n",
    "model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)\n",
    "tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)\n",
    "model_decoder.to(args.device)\n",
    "if args.block_size <= 0:\n",
    "    args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model\n",
    "args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)\n",
    "\n",
    "# Load full model\n",
    "output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step)) \n",
    "checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))\n",
    "\n",
    "# Chunyuan: Add Padding token to GPT2\n",
    "special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}\n",
    "num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)\n",
    "print('We have added', num_added_toks, 'tokens to GPT2')\n",
    "model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.\n",
    "assert tokenizer_decoder.pad_token == '<PAD>'\n",
    "\n",
    "\n",
    "# Evaluation\n",
    "model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)\n",
    "model_vae.load_state_dict(checkpoint['model_state_dict'])\n",
    "logger.info(\"Pre-trained Optimus is successfully loaded\")\n",
    "model_vae.to(args.device)\n",
    "model_vae.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.top_p=0\n",
    "args.top_k=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latent_codes(dataset):\n",
    "    return torch.cat([latent_code_from_text(text, tokenizer_encoder, model_vae, args)[0] for text in tqdm(dataset)]).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_shift(text, shift, coef=1.0):\n",
    "    return text_from_latent_code(\n",
    "            latent_code_from_text(text,tokenizer_encoder,model_vae,args)[0]+coef*shift,\n",
    "            model_vae, args, tokenizer_decoder\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('wikipedia.segmented.nltk.txt') as f:\n",
    "    wiki=[line.strip() for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_filtered=[text for text in wiki if len(text.split())<100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_subset=random.sample(wiki_filtered,100000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100000/100000 [13:02<00:00, 127.76it/s]\n"
     ]
    }
   ],
   "source": [
    "encoded_wiki=get_latent_codes(wiki_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_pca=PCA(random_state=42).fit(encoded_wiki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "wiki_pca_transf=wiki_pca.transform(encoded_wiki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "short_inds=[i for i, text in enumerate(wiki_subset) if len(text.split())<=20 and '<unk>' not in text.split()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TEXTS=20\n",
    "SUBSAMPLES_FOR_DIRECTION=5\n",
    "TEXTS_IN_SUBSAMPLE=5\n",
    "interp_texts_v1=np.random.RandomState(seed=1).choice(short_inds,size=NUM_TEXTS, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_ind=['--','-','','+','++']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from collections import Counter\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [06:58<00:00, 20.91s/it]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "texts_pca=[]\n",
    "models_pca=[]\n",
    "methods_pca=[]\n",
    "direction_inds_pca=[]\n",
    "\n",
    "\n",
    "for i in trange(20):\n",
    "    all_texts_for_direction=[]\n",
    "    \n",
    "    for j in interp_texts_v1:\n",
    "        text=wiki_subset[j]\n",
    "        shift_coefs=np.linspace(-5,5,num=5)\n",
    "\n",
    "        shifted_texts=[]\n",
    "        for shift_coef in shift_coefs:\n",
    "            new_latent_code=wiki_pca_transf[j].copy()\n",
    "            new_latent_code[i]+=shift_coef\n",
    "            new_latent_code=torch.from_numpy(wiki_pca.inverse_transform(new_latent_code[None,:])).float().cuda()\n",
    "            shifted_texts.append(text_from_latent_code(new_latent_code,model_vae, args, tokenizer_decoder))\n",
    "            \n",
    "        if Counter(shifted_texts).most_common(1)[0][1]<=3:\n",
    "            shifted_with_intensities=[]\n",
    "            for k, shifted in enumerate(shifted_texts):\n",
    "                shifted_with_intensities.append(f'{dir_ind[k]:4}{shifted}')\n",
    "            all_texts_for_direction.append('\\n'.join(shifted_with_intensities))\n",
    "            \n",
    "    for _ in range(SUBSAMPLES_FOR_DIRECTION):\n",
    "        inds=np.random.choice(len(all_texts_for_direction),size=TEXTS_IN_SUBSAMPLE,replace=False)\n",
    "        texts_pca.append('\\n\\n\\n'.join([all_texts_for_direction[j] for j in inds]))\n",
    "        models_pca.append('wiki')\n",
    "        methods_pca.append('pca')\n",
    "        direction_inds_pca.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pca=pd.DataFrame({'INPUT:text':texts_pca,\n",
    "                          'INPUT:model':models_pca,\n",
    "                          'INPUT:method':methods_pca,\n",
    "                          'INPUT:direction_index':direction_inds_pca})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "random_direction=np.random.randn(20,encoded_wiki.shape[1])\n",
    "random_direction/=np.linalg.norm(random_direction,axis=1,keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [08:04<00:00, 24.20s/it]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "texts_random=[]\n",
    "models_random=[]\n",
    "methods_random=[]\n",
    "direction_inds_random=[]\n",
    "\n",
    "for i in trange(20):\n",
    "    all_texts_for_direction=[]\n",
    "    \n",
    "    for j in interp_texts_v1:\n",
    "        text=wiki_subset[j]\n",
    "        shift_coefs=np.linspace(-10,10,num=5)\n",
    "\n",
    "        shifted_texts=[]\n",
    "        for shift_coef in shift_coefs:\n",
    "            new_latent_code=encoded_wiki[j].copy()\n",
    "            new_latent_code+=shift_coef*random_direction[i]\n",
    "            new_latent_code=torch.from_numpy(new_latent_code[None,:]).float().cuda()\n",
    "            shifted_texts.append(text_from_latent_code(new_latent_code,model_vae, args, tokenizer_decoder))\n",
    "            \n",
    "        if Counter(shifted_texts).most_common(1)[0][1]<=3:\n",
    "            shifted_with_intensities=[]\n",
    "            for k, shifted in enumerate(shifted_texts):\n",
    "                shifted_with_intensities.append(f'{dir_ind[k]:4}{shifted}')\n",
    "            all_texts_for_direction.append('\\n'.join(shifted_with_intensities))\n",
    "            \n",
    "    for _ in range(SUBSAMPLES_FOR_DIRECTION):\n",
    "        inds=np.random.choice(len(all_texts_for_direction),size=TEXTS_IN_SUBSAMPLE,replace=False)\n",
    "        texts_random.append('\\n\\n\\n'.join([all_texts_for_direction[j] for j in inds]))\n",
    "        models_random.append('wiki')\n",
    "        methods_random.append('random')\n",
    "        direction_inds_random.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_random=pd.DataFrame({'INPUT:text':texts_random,\n",
    "                          'INPUT:model':models_random,\n",
    "                          'INPUT:method':methods_random,\n",
    "                          'INPUT:direction_index':direction_inds_random})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_coords_by_variance=encoded_wiki.std(axis=0).argsort()[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [08:42<00:00, 26.15s/it]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "texts_coord=[]\n",
    "models_coord=[]\n",
    "methods_coord=[]\n",
    "direction_inds_coord=[]\n",
    "\n",
    "for i in trange(20):\n",
    "    coord=top_coords_by_variance[i]\n",
    "    \n",
    "    all_texts_for_direction=[]\n",
    "    \n",
    "    for j in interp_texts_v1:\n",
    "        text=wiki_subset[j]\n",
    "        shift_coefs=np.linspace(-10,10,num=5)\n",
    "\n",
    "        shifted_texts=[]\n",
    "        for shift_coef in shift_coefs:\n",
    "            new_latent_code=encoded_wiki[j].copy()\n",
    "            new_latent_code[coord]+=shift_coef\n",
    "            new_latent_code=torch.from_numpy(new_latent_code[None,:]).float().cuda()\n",
    "            shifted_texts.append(text_from_latent_code(new_latent_code,model_vae, args, tokenizer_decoder))\n",
    "            \n",
    "        if Counter(shifted_texts).most_common(1)[0][1]<=3:\n",
    "            shifted_with_intensities=[]\n",
    "            for k, shifted in enumerate(shifted_texts):\n",
    "                shifted_with_intensities.append(f'{dir_ind[k]:4}{shifted}')\n",
    "            all_texts_for_direction.append('\\n'.join(shifted_with_intensities))\n",
    "            \n",
    "    for _ in range(SUBSAMPLES_FOR_DIRECTION):\n",
    "        inds=np.random.choice(len(all_texts_for_direction),size=TEXTS_IN_SUBSAMPLE,replace=False)\n",
    "        texts_coord.append('\\n\\n\\n'.join([all_texts_for_direction[j] for j in inds]))\n",
    "        models_coord.append('wiki')\n",
    "        methods_coord.append('coordinate')\n",
    "        direction_inds_coord.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_coord=pd.DataFrame({'INPUT:text':texts_coord,\n",
    "                          'INPUT:model':models_coord,\n",
    "                          'INPUT:method':methods_coord,\n",
    "                          'INPUT:direction_index':direction_inds_coord})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pca.to_csv('data_pca_wiki.tsv',sep='\\t',index=False)\n",
    "df_random.to_csv('data_random_wiki.tsv',sep='\\t',index=False)\n",
    "df_coord.to_csv('data_coord_wiki.tsv',sep='\\t',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
