{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Place this notebook in the same directory as your config file and experiment folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=0\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.append('<path to the repo>/jiant')   # --->  --->  --->  --->  --->  --->  ---> # CHANGE THIS\n",
    "\n",
    "# the command-line arguments you would pass to python ../jiant/main.py, with quotes ('\") removed and each argument as a separate list item\n",
    "cl_arguments = [\n",
    "    '--config_file', 'srl_bayes_layer1.conf', # --->  --->  --->  --->  --->  --->  ---> # CHANGE THIS\n",
    "    '--overrides', 'exp_name = srl_bayes_layer1, run_name = run1, project_dir = .'       # CHANGE THIS\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the jiant training environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:52:06 PM: Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "03/25 09:52:06 PM: Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "03/25 09:52:06 PM: Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "03/25 09:52:06 PM: Loading config from srl_bayes_layer1_256_dr0.conf\n",
      "03/25 09:52:06 PM: Config overrides: exp_name = srl_bayes_layer1_256_dr0, run_name = run1, clf_first_layer_clip_var = 100000, project_dir = .\n",
      "03/25 09:52:07 PM: Git branch: master\n",
      "03/25 09:52:07 PM: Git SHA: 292cf0081a0e5d7c4a0b1bcfd299e2a31a8e28c4\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      " SamplingMultiTaskTrainerBayes \n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:52:07 PM: Parsed args: \n",
      "{\n",
      "  \"allow_missing_task_map\": 1,\n",
      "  \"allow_untrained_encoder_parameters\": 1,\n",
      "  \"classifier\": \"bayes_mlp\",\n",
      "  \"classifier_dropout\": \"0.\",\n",
      "  \"clf_first_layer_clip_var\": 100000,\n",
      "  \"clf_num_hid_layers\": 1,\n",
      "  \"data_dir\": \"/opt/lena-voita/probing/jiant/data\",\n",
      "  \"do_target_task_training\": 0,\n",
      "  \"edges-srl-ontonotes\": {\n",
      "    \"classifier_dropout\": 0.0\n",
      "  },\n",
      "  \"edges-tmpl-bayes\": {\n",
      "    \"classifier_dropout\": 0.0,\n",
      "    \"classifier_hid_dim\": 256,\n",
      "    \"classifier_span_pooling\": \"attn\",\n",
      "    \"max_vals\": 250,\n",
      "    \"pair_attn\": 0,\n",
      "    \"span_classifier_loss_fn\": \"softmax\",\n",
      "    \"val_interval\": 1000\n",
      "  },\n",
      "  \"elmo_layer_index\": 1,\n",
      "  \"exp_dir\": \"./srl_bayes_layer1_256_dr0/\",\n",
      "  \"exp_name\": \"srl_bayes_layer1_256_dr0\",\n",
      "  \"input_module\": \"elmo\",\n",
      "  \"keep_all_checkpoints\": 1,\n",
      "  \"local_log_path\": \"./srl_bayes_layer1_256_dr0/run1/log.log\",\n",
      "  \"lr\": 0.001,\n",
      "  \"lr_decay_factor\": 0.999999,\n",
      "  \"lr_patience\": 5,\n",
      "  \"max_seq_len\": 512,\n",
      "  \"min_epochs\": 500,\n",
      "  \"patience\": 100,\n",
      "  \"pretrain_tasks\": \"edges-srl-ontonotes\",\n",
      "  \"remote_log_name\": \"srl_bayes_layer1_256_dr0__run1\",\n",
      "  \"run_dir\": \"./srl_bayes_layer1_256_dr0/run1\",\n",
      "  \"run_name\": \"run1\",\n",
      "  \"sent_enc\": \"none\",\n",
      "  \"sep_embs_for_skip\": 1,\n",
      "  \"target_tasks\": \"edges-srl-ontonotes\",\n",
      "  \"tokenizer\": \"MosesTokenizer\",\n",
      "  \"train_type\": \"SamplingMultiTaskTrainerBayes\",\n",
      "  \"word_embs_file\": \"\",\n",
      "  \"write_preds\": \"val,test\"\n",
      "}\n",
      "03/25 09:52:07 PM: Saved config to ./srl_bayes_layer1_256_dr0/run1/params.conf\n",
      "03/25 09:52:07 PM: Using random seed 1234\n",
      "03/25 09:52:07 PM: Using GPU 0\n",
      "03/25 09:52:07 PM: Loading tasks...\n",
      "03/25 09:52:07 PM: Writing pre-preprocessed tasks to ./srl_bayes_layer1_256_dr0/\n",
      "03/25 09:52:07 PM: \tCreating task edges-srl-ontonotes from scratch.\n",
      "03/25 09:52:12 PM: Read=231480, Skip=21590, Total=253070 from /opt/lena-voita/probing/jiant/data/edges/ontonotes/srl/train.json.retokenized.MosesTokenizer\n",
      "03/25 09:52:13 PM: Read=32486, Skip=2811, Total=35297 from /opt/lena-voita/probing/jiant/data/edges/ontonotes/srl/development.json.retokenized.MosesTokenizer\n",
      "03/25 09:52:13 PM: Read=23800, Skip=2915, Total=26715 from /opt/lena-voita/probing/jiant/data/edges/ontonotes/srl/test.json.retokenized.MosesTokenizer\n",
      "03/25 09:52:16 PM: \tTask 'edges-srl-ontonotes': |train|=231480 |val|=32486 |test|=23800\n",
      "03/25 09:52:16 PM: \tFinished loading tasks: edges-srl-ontonotes.\n",
      "03/25 09:52:16 PM: Loading token dictionary from ./srl_bayes_layer1_256_dr0/vocab.\n",
      "03/25 09:52:16 PM: \tLoaded vocab from ./srl_bayes_layer1_256_dr0/vocab\n",
      "03/25 09:52:16 PM: \tVocab namespace edges-srl-ontonotes_labels: size 66\n",
      "03/25 09:52:16 PM: \tVocab namespace tokens: size 30004\n",
      "03/25 09:52:16 PM: \tVocab namespace chars: size 103\n",
      "03/25 09:52:16 PM: \tFinished building vocab.\n",
      "03/25 09:52:16 PM: \tTask 'edges-srl-ontonotes', split 'train': Found preprocessed copy in ./srl_bayes_layer1_256_dr0/preproc/edges-srl-ontonotes__train_data\n",
      "03/25 09:52:16 PM: \tTask 'edges-srl-ontonotes', split 'val': Found preprocessed copy in ./srl_bayes_layer1_256_dr0/preproc/edges-srl-ontonotes__val_data\n",
      "03/25 09:52:16 PM: \tTask 'edges-srl-ontonotes', split 'test': Found preprocessed copy in ./srl_bayes_layer1_256_dr0/preproc/edges-srl-ontonotes__test_data\n",
      "03/25 09:52:16 PM: \tFinished indexing tasks\n",
      "03/25 09:52:30 PM: \tCreating trimmed pretraining-only version of edges-srl-ontonotes train.\n",
      "03/25 09:53:36 PM: \tCreating trimmed target-only version of edges-srl-ontonotes train.\n",
      "03/25 09:54:39 PM: \t  Training on edges-srl-ontonotes\n",
      "03/25 09:54:39 PM: \t  Evaluating on edges-srl-ontonotes\n",
      "03/25 09:54:39 PM: \tFinished loading tasks in 151.246s\n",
      "03/25 09:54:39 PM: \t Tasks: ['edges-srl-ontonotes']\n",
      "03/25 09:54:39 PM: Building model...\n",
      "03/25 09:54:39 PM: \tNot using character embeddings!\n",
      "03/25 09:54:39 PM: Classifiers:{'@pretrain@': 0, 'edges-srl-ontonotes': 1}\n",
      "03/25 09:54:39 PM: Loading ELMo from files:\n",
      "03/25 09:54:39 PM: ELMO_OPT_PATH = https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json\n",
      "03/25 09:54:39 PM: \tUsing full ELMo! (separate scalars/task)\n",
      "03/25 09:54:39 PM: ELMO_WEIGHTS_PATH = https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5\n",
      "03/25 09:54:39 PM: Initializing ELMo\n",
      "03/25 09:55:02 PM: Initializing parameters\n",
      "03/25 09:55:02 PM: Done initializing parameters; the following parameters are using their default initialization from their code\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.input_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_projection.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.input_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_projection.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.input_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_projection.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.input_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_projection.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._char_embedding_weights\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.weight\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_5.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_5.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_6.bias\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_6.weight\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.gamma\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.0\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.1\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.2\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.gamma\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.0\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.1\n",
      "03/25 09:55:02 PM:    _text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.2\n",
      "03/25 09:55:02 PM: \tTask 'edges-srl-ontonotes' params: {\n",
      "  \"cls_type\": \"bayes_mlp\",\n",
      "  \"d_hid\": 256,\n",
      "  \"pool_type\": \"max\",\n",
      "  \"d_proj\": 512,\n",
      "  \"shared_pair_attn\": 0,\n",
      "  \"attn\": 0,\n",
      "  \"d_hid_attn\": 512,\n",
      "  \"dropout\": 0.0,\n",
      "  \"cls_loss_fn\": \"softmax\",\n",
      "  \"cls_span_pooling\": \"attn\",\n",
      "  \"edgeprobe_cnn_context\": 0,\n",
      "  \"edgeprobe_symmetric\": 0,\n",
      "  \"use_classifier\": \"edges-srl-ontonotes\",\n",
      "  \"clf_num_hid_layers\": 1,\n",
      "  \"clf_first_layer_clip_var\": 100000\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constructing BayesMLP\n",
      "Applying dropout 0.0\n",
      "Using intermediate size (hidden dim / rank) 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:04 PM: Model specification:\n",
      "03/25 09:55:04 PM: MultiTaskModel(\n",
      "  (sent_encoder): SentenceEncoder(\n",
      "    (_text_field_embedder): ElmoTextFieldEmbedder(\n",
      "      (token_embedder_elmo): ElmoTokenEmbedderWrapper(\n",
      "        (_elmo): Elmo(\n",
      "          (_elmo_lstm): _ElmoBiLm(\n",
      "            (_token_embedder): _ElmoCharacterEncoder(\n",
      "              (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))\n",
      "              (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))\n",
      "              (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))\n",
      "              (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))\n",
      "              (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))\n",
      "              (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))\n",
      "              (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))\n",
      "              (_highways): Highway(\n",
      "                (_layers): ModuleList(\n",
      "                  (0): Linear(in_features=2048, out_features=4096, bias=True)\n",
      "                  (1): Linear(in_features=2048, out_features=4096, bias=True)\n",
      "                )\n",
      "              )\n",
      "              (_projection): Linear(in_features=2048, out_features=512, bias=True)\n",
      "            )\n",
      "            (_elmo_lstm): ElmoLstm(\n",
      "              (forward_layer_0): LstmCellWithProjection(\n",
      "                (input_linearity): Linear(in_features=512, out_features=16384, bias=False)\n",
      "                (state_linearity): Linear(in_features=512, out_features=16384, bias=True)\n",
      "                (state_projection): Linear(in_features=4096, out_features=512, bias=False)\n",
      "              )\n",
      "              (backward_layer_0): LstmCellWithProjection(\n",
      "                (input_linearity): Linear(in_features=512, out_features=16384, bias=False)\n",
      "                (state_linearity): Linear(in_features=512, out_features=16384, bias=True)\n",
      "                (state_projection): Linear(in_features=4096, out_features=512, bias=False)\n",
      "              )\n",
      "              (forward_layer_1): LstmCellWithProjection(\n",
      "                (input_linearity): Linear(in_features=512, out_features=16384, bias=False)\n",
      "                (state_linearity): Linear(in_features=512, out_features=16384, bias=True)\n",
      "                (state_projection): Linear(in_features=4096, out_features=512, bias=False)\n",
      "              )\n",
      "              (backward_layer_1): LstmCellWithProjection(\n",
      "                (input_linearity): Linear(in_features=512, out_features=16384, bias=False)\n",
      "                (state_linearity): Linear(in_features=512, out_features=16384, bias=True)\n",
      "                (state_projection): Linear(in_features=4096, out_features=512, bias=False)\n",
      "              )\n",
      "            )\n",
      "          )\n",
      "          (_dropout): Dropout(p=0.0)\n",
      "          (scalar_mix_0): ScalarMix(\n",
      "            (scalar_parameters): ParameterList(\n",
      "                (0): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "                (1): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "                (2): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "            )\n",
      "          )\n",
      "          (scalar_mix_1): ScalarMix(\n",
      "            (scalar_parameters): ParameterList(\n",
      "                (0): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "                (1): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "                (2): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
      "            )\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (_highway_layer): TimeDistributed(\n",
      "      (_module): Highway(\n",
      "        (_layers): ModuleList()\n",
      "      )\n",
      "    )\n",
      "    (_phrase_layer): NullPhraseLayer()\n",
      "    (_dropout): Dropout(p=0.2)\n",
      "  )\n",
      "  (edges-srl-ontonotes_mdl): EdgeClassifierModule(\n",
      "    (proj1): Conv1dGroupNJ (1024 -> 256)\n",
      "    (proj2): Conv1dGroupNJ (1024 -> 256)\n",
      "    (span_extractor1): SelfAttentiveSpanExtractor(\n",
      "      (_global_attention): TimeDistributed(\n",
      "        (_module): Linear(in_features=256, out_features=1, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (span_extractor2): SelfAttentiveSpanExtractor(\n",
      "      (_global_attention): TimeDistributed(\n",
      "        (_module): Linear(in_features=256, out_features=1, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (classifier): Classifier(\n",
      "      (classifier): BayesMLP(\n",
      "        (dropout): Dropout(p=0.0)\n",
      "        (relu): ReLU()\n",
      "        (initial_linear): LinearGroupNJ (512 -> 256)\n",
      "        (intermediate_linears): ModuleList()\n",
      "        (last_linear): LinearGroupNJ (256 -> 66)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      ")\n",
      "03/25 09:55:04 PM: Model parameters:\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._char_embedding_weights: Non-trainable parameter, count 4192 with torch.Size([262, 16])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.weight: Non-trainable parameter, count 512 with torch.Size([32, 16, 1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.bias: Non-trainable parameter, count 32 with torch.Size([32])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.weight: Non-trainable parameter, count 1024 with torch.Size([32, 16, 2])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.bias: Non-trainable parameter, count 32 with torch.Size([32])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.weight: Non-trainable parameter, count 3072 with torch.Size([64, 16, 3])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.bias: Non-trainable parameter, count 64 with torch.Size([64])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.weight: Non-trainable parameter, count 8192 with torch.Size([128, 16, 4])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.bias: Non-trainable parameter, count 128 with torch.Size([128])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.weight: Non-trainable parameter, count 20480 with torch.Size([256, 16, 5])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.bias: Non-trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_5.weight: Non-trainable parameter, count 49152 with torch.Size([512, 16, 6])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_5.bias: Non-trainable parameter, count 512 with torch.Size([512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_6.weight: Non-trainable parameter, count 114688 with torch.Size([1024, 16, 7])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_6.bias: Non-trainable parameter, count 1024 with torch.Size([1024])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.weight: Non-trainable parameter, count 8388608 with torch.Size([4096, 2048])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.bias: Non-trainable parameter, count 4096 with torch.Size([4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.weight: Non-trainable parameter, count 8388608 with torch.Size([4096, 2048])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.bias: Non-trainable parameter, count 4096 with torch.Size([4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.weight: Non-trainable parameter, count 1048576 with torch.Size([512, 2048])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.bias: Non-trainable parameter, count 512 with torch.Size([512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.input_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.bias: Non-trainable parameter, count 16384 with torch.Size([16384])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_projection.weight: Non-trainable parameter, count 2097152 with torch.Size([512, 4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.input_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.bias: Non-trainable parameter, count 16384 with torch.Size([16384])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_projection.weight: Non-trainable parameter, count 2097152 with torch.Size([512, 4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.input_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.bias: Non-trainable parameter, count 16384 with torch.Size([16384])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_projection.weight: Non-trainable parameter, count 2097152 with torch.Size([512, 4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.input_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.weight: Non-trainable parameter, count 8388608 with torch.Size([16384, 512])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.bias: Non-trainable parameter, count 16384 with torch.Size([16384])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_projection.weight: Non-trainable parameter, count 2097152 with torch.Size([512, 4096])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.gamma: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.0: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.1: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.2: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.gamma: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.0: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.1: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tsent_encoder._text_field_embedder.token_embedder_elmo._elmo.scalar_mix_1.scalar_parameters.2: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.weight_mu: Trainable parameter, count 262144 with torch.Size([256, 1024, 1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.weight_logvar: Trainable parameter, count 262144 with torch.Size([256, 1024, 1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.bias_mu: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.bias_logvar: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.z_mu: Trainable parameter, count 1024 with torch.Size([1024])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj1.z_logvar: Trainable parameter, count 1024 with torch.Size([1024])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.weight_mu: Trainable parameter, count 262144 with torch.Size([256, 1024, 1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.weight_logvar: Trainable parameter, count 262144 with torch.Size([256, 1024, 1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.bias_mu: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.bias_logvar: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.z_mu: Trainable parameter, count 1024 with torch.Size([1024])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.proj2.z_logvar: Trainable parameter, count 1024 with torch.Size([1024])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.span_extractor1._global_attention._module.weight: Trainable parameter, count 256 with torch.Size([1, 256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.span_extractor1._global_attention._module.bias: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.span_extractor2._global_attention._module.weight: Trainable parameter, count 256 with torch.Size([1, 256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.span_extractor2._global_attention._module.bias: Trainable parameter, count 1 with torch.Size([1])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.z_mu: Trainable parameter, count 512 with torch.Size([512])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.z_logvar: Trainable parameter, count 512 with torch.Size([512])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.weight_mu: Trainable parameter, count 131072 with torch.Size([256, 512])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.weight_logvar: Trainable parameter, count 131072 with torch.Size([256, 512])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.bias_mu: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.initial_linear.bias_logvar: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.z_mu: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.z_logvar: Trainable parameter, count 256 with torch.Size([256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.weight_mu: Trainable parameter, count 16896 with torch.Size([66, 256])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.weight_logvar: Trainable parameter, count 16896 with torch.Size([66, 256])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.bias_mu: Trainable parameter, count 66 with torch.Size([66])\n",
      "03/25 09:55:04 PM: \tedges-srl-ontonotes_mdl.classifier.classifier.last_linear.bias_logvar: Trainable parameter, count 66 with torch.Size([66])\n",
      "03/25 09:55:04 PM: Total number of parameters: 94953198 (9.49532e+07)\n",
      "03/25 09:55:04 PM: Number of trainable parameters: 1352334 (1.35233e+06)\n",
      "03/25 09:55:04 PM: Finished building model in 25.680s\n",
      "03/25 09:55:04 PM: Will run the following steps for this experiment:\n",
      "Training model on tasks: edges-srl-ontonotes \n",
      "Evaluating model on tasks: edges-srl-ontonotes \n",
      "\n"
     ]
    }
   ],
   "source": [
    "from jiant.__main__ import *\n",
    "\n",
    "# Unrolled copy of `main()` from jiant/__main__.py. Running it at notebook top level\n",
    "# keeps the tasks, vocabulary and model available for the cells below.\n",
    "cl_args = handle_arguments(cl_arguments)\n",
    "args = config.params_from_file(cl_args.config_file, cl_args.overrides)\n",
    "\n",
    "DEFAULT_TRAIN_TYPE = \"SamplingMultiTaskTrainer\"\n",
    "train_type = args.get('train_type', DEFAULT_TRAIN_TYPE)\n",
    "if train_type != DEFAULT_TRAIN_TYPE:\n",
    "    # pad a non-default trainer type with blank lines so it stands out\n",
    "    print(\"\\n\\n\\n\", train_type, \"\\n\\n\\n\")\n",
    "\n",
    "check_arg_name(args)  # check for deprecated argument names\n",
    "args, seed = initial_setup(args, cl_args)\n",
    "\n",
    "# -- tasks --\n",
    "log.info(\"Loading tasks...\")\n",
    "t_start = time.time()\n",
    "pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)\n",
    "tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda t: t.name)\n",
    "log.info(\"\\tFinished loading tasks in %.3fs\", time.time() - t_start)\n",
    "log.info(\"\\t Tasks: {}\".format([t.name for t in tasks]))\n",
    "\n",
    "# -- model --\n",
    "log.info(\"Building model...\")\n",
    "t_start = time.time()\n",
    "model = build_model(args, vocab, word_embs, tasks)\n",
    "log.info(\"Finished building model in %.3fs\", time.time() - t_start)\n",
    "\n",
    "# optionally start TensorBoard in the background\n",
    "if cl_args.tensorboard:\n",
    "    _run_background_tensorboard(os.path.join(args.run_dir, \"tensorboard\"), cl_args.tensorboard_port)\n",
    "\n",
    "check_configurations(args, pretrain_tasks, target_tasks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### create trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:04 PM: Training...\n",
      "03/25 09:55:04 PM: patience = 100\n",
      "03/25 09:55:04 PM: val_interval = 1000\n",
      "03/25 09:55:04 PM: max_vals = 1000\n",
      "03/25 09:55:04 PM: cuda_device = 0\n",
      "03/25 09:55:04 PM: grad_norm = 5.0\n",
      "03/25 09:55:04 PM: grad_clipping = None\n",
      "03/25 09:55:04 PM: lr_decay = 0.99\n",
      "03/25 09:55:04 PM: min_lr = 1e-06\n",
      "03/25 09:55:04 PM: keep_all_checkpoints = 1\n",
      "03/25 09:55:04 PM: val_data_limit = 5000\n",
      "03/25 09:55:04 PM: max_epochs = -1\n",
      "03/25 09:55:04 PM: min_epochs = 500\n",
      "03/25 09:55:04 PM: dec_val_scale = 250\n",
      "03/25 09:55:04 PM: training_data_fraction = 1\n",
      "03/25 09:55:04 PM: type = adam\n",
      "03/25 09:55:04 PM: parameter_groups = None\n",
      "03/25 09:55:04 PM: Number of trainable parameters: 1352334\n",
      "03/25 09:55:04 PM: infer_type_and_cast = True\n",
      "03/25 09:55:04 PM: Converting Params object to dict; logging of default values will not occur when dictionary parameters are used subsequently.\n",
      "03/25 09:55:04 PM: CURRENTLY DEFINED PARAMETERS: \n",
      "03/25 09:55:04 PM: lr = 0.001\n",
      "03/25 09:55:04 PM: amsgrad = True\n",
      "03/25 09:55:04 PM: type = reduce_on_plateau\n",
      "03/25 09:55:04 PM: Converting Params object to dict; logging of default values will not occur when dictionary parameters are used subsequently.\n",
      "03/25 09:55:04 PM: CURRENTLY DEFINED PARAMETERS: \n",
      "03/25 09:55:04 PM: mode = max\n",
      "03/25 09:55:04 PM: factor = 0.999999\n",
      "03/25 09:55:04 PM: patience = 5\n",
      "03/25 09:55:04 PM: threshold = 0.0001\n",
      "03/25 09:55:04 PM: threshold_mode = abs\n",
      "03/25 09:55:04 PM: verbose = True\n",
      "03/25 09:55:04 PM: type = adam\n",
      "03/25 09:55:04 PM: parameter_groups = None\n",
      "03/25 09:55:04 PM: Number of trainable parameters: 1352334\n",
      "03/25 09:55:04 PM: infer_type_and_cast = True\n",
      "03/25 09:55:04 PM: Converting Params object to dict; logging of default values will not occur when dictionary parameters are used subsequently.\n",
      "03/25 09:55:04 PM: CURRENTLY DEFINED PARAMETERS: \n",
      "03/25 09:55:04 PM: lr = 0.001\n",
      "03/25 09:55:04 PM: amsgrad = True\n",
      "03/25 09:55:04 PM: type = reduce_on_plateau\n",
      "03/25 09:55:04 PM: Converting Params object to dict; logging of default values will not occur when dictionary parameters are used subsequently.\n",
      "03/25 09:55:04 PM: CURRENTLY DEFINED PARAMETERS: \n",
      "03/25 09:55:04 PM: mode = max\n",
      "03/25 09:55:04 PM: factor = 0.999999\n",
      "03/25 09:55:04 PM: patience = 5\n",
      "03/25 09:55:04 PM: threshold = 0.0001\n",
      "03/25 09:55:04 PM: threshold_mode = abs\n",
      "03/25 09:55:04 PM: verbose = True\n"
     ]
    }
   ],
   "source": [
    "from allennlp.training.optimizers import Optimizer\n",
    "from allennlp.training.learning_rate_schedulers import LearningRateScheduler\n",
    "\n",
    "# Pretraining branch of jiant's main(): build the trainer for the pretrain tasks and\n",
    "# attach an optimizer / LR scheduler to it, without starting training.\n",
    "log.info(\"Training...\")\n",
    "phase = 'pretrain'\n",
    "single_task = len(pretrain_tasks) == 1\n",
    "stop_metric = pretrain_tasks[0].val_metric if single_task else \"macro_avg\"\n",
    "should_decrease = pretrain_tasks[0].val_metric_decreases if single_task else False\n",
    "trainer, _, opt_params, schd_params = build_trainer(\n",
    "    args, [], model, args.run_dir, should_decrease, phase=phase, train_type=train_type\n",
    ")\n",
    "\n",
    "# Aliases read by later cells (the checkpoint restore uses `phase`/`tasks`,\n",
    "# the evaluation loop uses `self` and `batch_size`).\n",
    "self = trainer\n",
    "tasks = pretrain_tasks\n",
    "batch_size = args.batch_size\n",
    "to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]\n",
    "train_params = to_train\n",
    "optimizer_params = copy.deepcopy(opt_params)\n",
    "scheduler_params = schd_params\n",
    "validation_interval = trainer._val_interval\n",
    "\n",
    "task_infos, metric_infos = trainer._setup_training(\n",
    "    tasks, batch_size, train_params, optimizer_params, scheduler_params, phase\n",
    ")\n",
    "\n",
    "# optimizer and LR scheduler built from the (copied) configs, then attached to the trainer\n",
    "optimizer = Optimizer.from_params(train_params, optimizer_params)\n",
    "scheduler = LearningRateScheduler.from_params(optimizer, copy.deepcopy(scheduler_params))\n",
    "trainer._optimizer, trainer._scheduler = optimizer, scheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load specific checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03/25 09:55:05 PM: Found checkpoint state_pretrain_val_400.th. Loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overriding state_pretrain_val_417.th, reading state_pretrain_val_400.th instead\n"
     ]
    }
   ],
   "source": [
    "# checkpoint file to load instead of the one the trainer would pick by itself\n",
    "CHECKPOINT_NAME = 'state_pretrain_val_400.th'  # CHANGE THIS\n",
    "n_step, should_stop = trainer._restore_checkpoint(phase, tasks, override_suffix=CHECKPOINT_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute some metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = trainer._model\n",
    "# Switch to evaluation mode (equivalent to model.eval()).\n",
    "# WARNING: this disables dropout and other similar stuff\n",
    "model.train(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32f303fb875c4d55adb17ad694df3baf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=11574.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'edges-srl-ontonotes_sumloss_train': 165872.97469596053, 'edges-srl-ontonotes_num_examples_train': 598983.0, 'edges-srl-ontonotes_acc_train': 0.9129007047222994}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12a342f21c314e65ade0d5c6bef480bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1625.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'edges-srl-ontonotes_sumloss_val': 27050.371629461646, 'edges-srl-ontonotes_num_examples_val': 83362.0, 'edges-srl-ontonotes_acc_val': 0.9003622811681343}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "309e7dade7dc4bc59f28859521e39d98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1190.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "{'edges-srl-ontonotes_sumloss_test': 19904.1856617108, 'edges-srl-ontonotes_num_examples_test': 61716.0, 'edges-srl-ontonotes_acc_test': 0.9025212305754479}\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "from jiant.trainer import BasicIterator, move_to_device\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "task = tasks[0]\n",
    "\n",
    "\n",
    "def evaluate_split(task, eval_data, name_suffix):\n",
    "    \"\"\"Run the model once over `eval_data` (no shuffling, no gradients) and return its metrics.\n",
    "\n",
    "    Returns a dict with the keys '<task.name>_sumloss_<name_suffix>' (batch loss weighted by the\n",
    "    number of spans in the batch, summed over batches), '<task.name>_num_examples_<name_suffix>'\n",
    "    (number of counted spans) and '<task.name>_acc_<name_suffix>' (span-weighted mean accuracy,\n",
    "    NaN if no span was counted). Reads the notebook globals `self` (the trainer) and `batch_size`.\n",
    "    \"\"\"\n",
    "    sumloss_key = \"{}_sumloss_{}\".format(task.name, name_suffix)\n",
    "    count_key = \"{}_num_examples_{}\".format(task.name, name_suffix)\n",
    "    acc_key = \"{}_acc_{}\".format(task.name, name_suffix)\n",
    "    metrics = {sumloss_key: 0.0, count_key: 0.0, acc_key: 0.0}\n",
    "\n",
    "    batches = BasicIterator(batch_size, instances_per_epoch=eval_data.size)(\n",
    "        eval_data, num_epochs=1, shuffle=False\n",
    "    )\n",
    "    batches = move_to_device(batches, self._cuda_device)\n",
    "    n_batches = math.ceil(eval_data.size / batch_size)\n",
    "\n",
    "    for batch in tqdm(batches, total=n_batches):\n",
    "        # spans whose start index is -1 are not counted (presumably padding -- TODO confirm)\n",
    "        n_spans = torch.sum((batch['span1s'][..., 0] != -1).long()).data.cpu().numpy()\n",
    "        with torch.no_grad():\n",
    "            out = self._forward(batch, task=task)\n",
    "        # assumes out['loss'] and out['acc'] are per-span means over the batch -- TODO confirm\n",
    "        metrics[sumloss_key] += n_spans * out[\"loss\"].data.cpu().numpy()\n",
    "        metrics[count_key] += n_spans\n",
    "        metrics[acc_key] += n_spans * out[\"acc\"].data.cpu().numpy()\n",
    "\n",
    "    if metrics[count_key] > 0:\n",
    "        metrics[acc_key] /= metrics[count_key]\n",
    "    else:\n",
    "        metrics[acc_key] = float(\"nan\")  # empty split: avoid 0/0\n",
    "    return metrics\n",
    "\n",
    "\n",
    "# One dict for ALL splits: later cells read both the train loss (data codelength) and the\n",
    "# test accuracy from `all_val_metrics`, so it must not be reset between splits.\n",
    "split_data = [('train', task.train_data), ('val', task.val_data), ('test', task.test_data)]\n",
    "all_val_metrics = {}\n",
    "for name_suffix, eval_data in split_data:\n",
    "    split_metrics = evaluate_split(task, eval_data, name_suffix)\n",
    "    print(split_metrics)\n",
    "    all_val_metrics.update(split_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Here we start computing resulting metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np  # np.log2 is used for the uniform codelength below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task constants, one row per probing task:\n",
    "#   short key, number of examples, number of labels, jiant task name\n",
    "# (presumably training-set sizes: the srl value equals `num_examples_train` printed above)\n",
    "TASK_TABLE = [\n",
    "    ('pos',     2070382, 48, 'pos-ontonotes'),\n",
    "    ('dep',      203919, 49, 'dep-ud-ewt'),\n",
    "    ('coref',    207830,  2, 'coref-ontonotes'),\n",
    "    ('srl',      598983, 66, 'srl-ontonotes'),\n",
    "    ('ner',      128738, 18, 'ner-ontonotes'),\n",
    "    ('nonterm', 1851590, 30, 'nonterminal-ontonotes'),\n",
    "    ('relsem',     6851, 19, 'rel-semeval'),\n",
    "]\n",
    "\n",
    "num_examples = {key: n for key, n, _, _ in TASK_TABLE}\n",
    "num_labels = {key: k for key, _, k, _ in TASK_TABLE}\n",
    "full_task_name = {key: name for key, _, _, name in TASK_TABLE}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# short task key used to index num_examples / num_labels / full_task_name (e.g. 'pos', 'srl')\n",
    "task = 'srl'    # CHANGE THIS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test accuracy (of variational probe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (variational probe):  90.25\n"
     ]
    }
   ],
   "source": [
    "test_acc = all_val_metrics['edges-{}_acc_test'.format(full_task_name[task])]\n",
    "print(\"Test accuracy (variational probe): \", round(test_acc * 100, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Codelength"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Uniform codelength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform code: each of the num_examples labels costs log2(num_labels) bits.\n",
    "n_labels = num_labels[task]\n",
    "uniform_codelength = num_examples[task] * np.log2(n_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Variational codelength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Conv1dGroupNJ (1024 -> 256), Conv1dGroupNJ (1024 -> 256), LinearGroupNJ (512 -> 256), LinearGroupNJ (256 -> 66)]\n"
     ]
    }
   ],
   "source": [
    "from jiant.trainer import BayesianLayers\n",
    "\n",
    "# the modules whose KL divergence makes up the model part of the codelength\n",
    "bayes_modules = [module for module in BayesianLayers.get_kl_modules(model)]\n",
    "print(bayes_modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model codelength: summed KL divergence of the Bayesian layers;\n",
    "# data codelength: summed training loss computed by the evaluation loop.\n",
    "# NOTE(review): uniform_codelength is in bits (log2); confirm the KL terms and the loss are\n",
    "# in bits too and not nats -- otherwise the compression ratio is off by a factor of ln(2).\n",
    "kl_terms = [layer.kl_divergence() for layer in bayes_modules]\n",
    "model_codelength = float(sum(kl_terms))\n",
    "data_codelength = all_val_metrics['edges-{}_sumloss_train'.format(full_task_name[task])]\n",
    "variational_codelength = data_codelength + model_codelength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variational codelength: 228.34 kbits\n",
      "Compression: 15.48 \n"
     ]
    }
   ],
   "source": [
    "kbits = round(variational_codelength / 1024, 2)\n",
    "print(\"Variational codelength: {} kbits\".format(kbits))\n",
    "compression = round(uniform_codelength / variational_codelength, 2)  # > 1: shorter than uniform\n",
    "print(\"Compression: {} \".format(compression))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pruned architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# `%matplotlib inline` instead of `%pylab inline`: %pylab also star-imports numpy into the\n",
    "# namespace, replacing builtins such as `sum` that other cells of this notebook call.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To decide which weights can be pruned, we follow the original paper [Bayesian Compression for Deep Learning](https://arxiv.org/abs/1705.08665): we determine the pruning threshold for each layer by manually inspecting the distribution of its log alphas, since there are usually two well-separated clusters (signal and noise)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First pass with dummy thresholds (one per Bayesian layer, however many the task's probe has):\n",
    "# we only need the log alphas here, to choose the real thresholds from their histograms.\n",
    "masks, alphas = BayesianLayers.get_masks(bayes_modules,\n",
    "                                         [0] * len(bayes_modules),\n",
    "                                         return_log_alpha=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 3., 14., 23.,  5.,  1.,  1.,  1.,  0.,  1.,  6., 26., 69., 82.,\n",
       "        72., 77., 54., 33., 30.,  8.,  6.]),\n",
       " array([-10.87374878,  -9.9440877 ,  -9.01442661,  -8.08476553,\n",
       "         -7.15510445,  -6.22544336,  -5.29578228,  -4.3661212 ,\n",
       "         -3.43646011,  -2.50679903,  -1.57713795,  -0.64747686,\n",
       "          0.28218422,   1.2118453 ,   2.14150639,   3.07116747,\n",
       "          4.00082855,   4.93048964,   5.86015072,   6.7898118 ,\n",
       "          7.71947289]),\n",
       " <a list of 20 Patch objects>)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAARbklEQVR4nO3df4xlZX3H8fenrIiiFZBxXUHcNRAsbSLqhGK1jWXRIhh221iCadptJdnY1lZqG11rorb2D0hbrU2sZivWbWMVROhu/E1XjGlSVwcElB+WBRfdzbI7WvBnokW//eOehXG4w5ydufcOz/J+JZN7fu757jNnPvPMc8+5J1WFJKk9P7fSBUiSlsYAl6RGGeCS1CgDXJIaZYBLUqNWTfJgJ554Yq1du3aSh5Sk5t1www3fqqqp+csnGuBr165lZmZmkoeUpOYluWfYcodQJKlRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpURO9E1PSw63d8vEl77vnsgtGWIlaYw9ckhrVK8CT/FmSW5N8NcmHkhyTZF2SXUl2J7kyydHjLlaS9JBFAzzJScCfAtNV9UvAUcDFwOXAO6vqVOA+4JJxFipJ+ll9h1BWAU9Isgp4IrAfOAe4ulu/Ddg4+vIkSQtZNMCrah/wd8A3GAT3d4AbgPur6oFus73AScP2T7I5yUySmdnZ2dFULUnqNYRyPLABWAc8AzgWOK/vAapqa1VNV9X01NTDPo9ckrREfYZQzgW+XlWzVfV/wDXAi4DjuiEVgJOBfWOqUZI0RJ8A/wZwdpInJgmwHrgNuB54ZbfNJmD7eEqUJA3TZwx8F4M3K28EvtLtsxV4I/D6JLuBpwJXjLFOSdI8ve7ErKq3Am+dt/hu4KyRVyRJ6sU7MSWpUX4WivQY5WewtM8euCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIa5YdZSQ1bzgdSqX32wCWpUX0eanx6kpvmfH03yaVJTkhyXZI7u9fjJ1GwJGmgzyPVvlZVZ1bVmcALgB8C1wJbgJ1VdRqws5uXJE3I4Q6hrAfuqqp7gA3Atm75NmDjKAuTJD2yww3wi4EPddOrq2p/N30vsHrYDkk2J5lJMjM7O7vEMiVJ8/UO8CRHAxcCH5m/rqoKqGH7VdXWqpququmpqaklFypJ+lmH0wN/OXBjVR3o5g8kWQPQvR4cdXGSpIUdznXgr+Kh4ROAHcAm4LLudfsI65Ka4bXYWim9euBJjgVeClwzZ/FlwEuT3Amc281LkiakVw+8qn4APHXesm8zuCpFkrQCvBNTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktSovk/kOS7J1UnuSHJ7khcmOSHJdUnu7F6PH3exkqSH9O2Bvwv4VFU9B3gucDuwBdhZVacBO7t5SdKELBrgSZ4C/BpwBUBV/biq7gc2ANu6zbYBG8dVpCTp4fr0wNcBs8C/JPlykvd1DzleXVX7u23uBVYP2znJ5iQzSWZmZ2dHU7UkqVeArwKeD7ynqp4H/IB5wyVVVUAN27mqtlbVdFVNT01NLbdeSVKnT4DvBfZW1a5u/moGgX4gyRqA7vXgeEqUJA2zaIBX1b3AN5Oc3i1aD9wG7AA2dcs2AdvHUqEkaahVPbf7E+CDSY4G7gb+gEH4X5XkEuAe4KLxlChJGqZXgFfVTcD0kFXrR1uOJKkv78SUpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RG9b2VXpIetHbLx5e1/57LLhhRJY9t9sAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSo3pdRphkD/A94CfAA1U1neQE4EpgLbAH
uKiq7htPmZKk+Q6nB/7rVXVmVR16Ms8WYGdVnQbsZN6T6iVJ47WcIZQNwLZuehuwcfnlSJL66hvgBXwmyQ1JNnfLVlfV/m76XmD1yKuTJC2o7630L66qfUmeBlyX5I65K6uqktSwHbvA3wxwyimnLKtYSdJDevXAq2pf93oQuBY4CziQZA1A93pwgX23VtV0VU1PTU2NpmpJ0uIBnuTYJE8+NA28DPgqsAPY1G22Cdg+riIlSQ/XZwhlNXBtkkPb/3tVfSrJl4CrklwC3ANcNL4yJUnzLRrgVXU38Nwhy78NrB9HUZKkxXknpiQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDWqd4AnOSrJl5N8rJtfl2RXkt1Jrkxy9PjKlCTNdzg98NcBt8+Zvxx4Z1WdCtwHXDLKwiRJj6xXgCc5GbgAeF83H+Ac4Opuk23AxnEUKEkarm8P/B+ANwA/7eafCtxfVQ9083uBk4btmGRzkpkkM7Ozs8sqVpL0kEUDPMkrgINVdcNSDlBVW6tquqqmp6amlvJPSJKGWNVjmxcBFyY5HzgG+HngXcBxSVZ1vfCTgX3jK1OSNN+iPfCqelNVnVxVa4GLgc9W1e8A1wOv7DbbBGwfW5WSpIdZznXgbwRen2Q3gzHxK0ZTkiSpjz5DKA+qqs8Bn+um7wbOGn1JkqQ+vBNTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1GF9FookjcLaLR9f8r57LrtghJW0zR64JDXKAJekRhngktQoA1ySGtXnocbHJPlikpuT3Jrkr7rl65LsSrI7yZVJjh5/uZKkQ/r0wH8EnFNVzwXOBM5LcjZwOfDOqjoVuA+4ZHxlSpLm6/NQ46qq73ezj+u+CjgHuLpbvg3YOJYKJUlD9boOPMlRwA3AqcC7gbuA+6vqgW6TvcBJC+y7GdgMcMoppyy3XkmPcV5D/pBeb2JW1U+q6kzgZAYPMn5O3wNU1daqmq6q6ampqSWWKUma77CuQqmq+4HrgRcCxyU51IM/Gdg34tokSY+gz1UoU0mO66afALwUuJ1BkL+y22wTsH1cRUqSHq7PGPgaYFs3Dv5zwFVV9bEktwEfTvI3wJeBK8ZYpyRpnkUDvKpuAZ43ZPndDMbDJUkrwDsxJalRBrgkNcoAl6RGGeCS1CifyCOxvLv7pJViD1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEZ5Hfgilnt98JH2BBBJjx72wCWpUQa4JDXKAJekRvV5pNozk1yf5LYktyZ5Xbf8hCTXJbmzez1+/OVKkg7p0wN/APjzqjoDOBv44yRnAFuAnVV1GrCzm5ckTciiAV5V+6vqxm76ewweaHwSsAHY1m22Ddg4riIlSQ93WGPgSdYyeD7mLmB1Ve3vVt0LrF5gn81JZpLMzM7OLqNUSdJcvQM8yZOAjwKXVtV3566rqgJq2H5VtbWqpqtqempqalnFSpIe0ivAkzyOQXh/sKqu6RYfSLKmW78GODieEiVJw/S5CiXAFcDtVfWOOat2AJu66U3A9tGXJ0laSJ9b6V8E/C7wlSQ3dcv+ErgMuCrJJcA9wEXjKVGSNMyiAV5V/wVkgdXrR1uOJKkv78SUpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGPSaeSr/cJ8tL0qORPXBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUqD6PVHt/koNJvjpn2QlJrktyZ/d6/HjLlCTN16cH/gHgvHnLtgA7q+o0YGc3L0maoEUDvKo+D/zvvMUbgG3d9DZg44jrkiQtYqlj4Kuran83fS+wekT1SJJ6WvabmFVVQC20PsnmJDNJZmZnZ5d7OElSZ6kBfiDJGoDu9eBCG1bV1qqa
rqrpqampJR5OkjTfUgN8B7Cpm94EbB9NOZKkvvpcRvgh4L+B05PsTXIJcBnw0iR3Aud285KkCVr042Sr6lULrFo/4lokSYfhMfF54JIEy3s2wJ7LLhhhJaPhrfSS1CgDXJIaZYBLUqMMcElqlG9i6ojhw6s1To/GN0DtgUtSo+yBj9mj8be2pCODPXBJapQBLkmNMsAlqVEGuCQ1qpk3Mb1ETJJ+lj1wSWqUAS5JjTLAJalRBrgkNWpZb2ImOQ94F3AU8L6q8tFqI9TqG7fLuYO01f+ztBKW3ANPchTwbuDlwBnAq5KcMarCJEmPbDlDKGcBu6vq7qr6MfBhYMNoypIkLWY5QygnAd+cM78X+OX5GyXZDGzuZr+f5GvLOOZSnAh8a8LHPFxHVI25fMyVLOyIascVZI2j8WCNI/iZeNawhWO/kaeqtgJbx32chSSZqarplTp+H9Y4GtY4GtY4GpOocTlDKPuAZ86ZP7lbJkmagOUE+JeA05KsS3I0cDGwYzRlSZIWs+QhlKp6IMlrgU8zuIzw/VV168gqG50VG745DNY4GtY4GtY4GmOvMVU17mNIksbAOzElqVEGuCQ16ogI8CS/neTWJD9NMj1v3ZuS7E7ytSS/scD+65Ls6ra7sntTdpz1Xpnkpu5rT5KbFthuT5KvdNvNjLOmIcd+W5J9c+o8f4HtzuvadneSLROu8W+T3JHkliTXJjluge0m3o6LtUuSx3fnwe7u3Fs7ibrmHP+ZSa5Pclv3s/O6Idu8JMl35pwDb5lkjV0Nj/i9y8A/du14S5LnT7i+0+e0z01Jvpvk0nnbjK8dq6r5L+AXgNOBzwHTc5afAdwMPB5YB9wFHDVk/6uAi7vp9wJ/OMHa/x54ywLr9gAnrlCbvg34i0W2Oapr02cDR3dtfcYEa3wZsKqbvhy4/NHQjn3aBfgj4L3d9MXAlRP+/q4Bnt9NPxn4nyE1vgT42KTPvcP53gHnA58EApwN7FrBWo8C7gWeNal2PCJ64FV1e1UNu8NzA/DhqvpRVX0d2M3gIwAelCTAOcDV3aJtwMZx1jvv2BcBH5rE8cZgRT9Ooao+U1UPdLNfYHAvwqNBn3bZwOBcg8G5t747HyaiqvZX1Y3d9PeA2xncXd2aDcC/1sAXgOOSrFmhWtYDd1XVPZM64BER4I9g2O3+80/SpwL3zwmCYduMy68CB6rqzgXWF/CZJDd0H0kwaa/t/ix9f5Ljh6zv076T8moGPbFhJt2OfdrlwW26c+87DM7FieuGb54H7Bqy+oVJbk7yySS/ONHCBhb73j2azsGLWbgzNpZ2bOaZmEn+E3j6kFVvrqrtk65nMT3rfRWP3Pt+cVXtS/I04Lokd1TV5ydRI/Ae4O0MfoDezmCo59WjOnZffdoxyZuBB4APLvDPjLUdW5bkScBHgUur6rvzVt/IYDjg+917IP8BnDbhEpv43nXvm10IvGnI6rG1YzMBXlXnLmG3Prf7f5vBn12rup7QSD4SYLF6k6wCfgt4wSP8G/u614NJrmXwp/nITt6+bZrkn4GPDVk19o9T6NGOvw+8Alhf3YDjkH9jrO04RJ92ObTN3u5ceAqDc3FikjyOQXh/sKqumb9+bqBX1SeS/FOSE6tqYh8i1eN792j5SI+XAzdW1YH5K8bZjkf6EMoO4OLuHf91DH7rfXHuBt0P/fXAK7tFm4BJ9OjPBe6oqr3DViY5NsmTD00zeMPuqxOo69Dx544j/uYCx17Rj1PI4IEibwAurKofLrDNSrRjn3bZweBcg8G599mFfgGNQzfefgVwe1W9Y4Ftnn5oXD7JWQzyYmK/ZHp+73YAv9ddjXI28J2q2j+pGudY8K/psbbjSr1jO8ovBgGzF/gRcAD49Jx1b2ZwRcDXgJfPWf4J4Bnd9LMZBPtu4CPA4ydQ8weA18xb9gzgE3Nqurn7upXBkMEk2/TfgK8AtzD4IVkzv8Zu/nwGVzDctQI17mYw/nlT9/Xe+TWuVDsO
axfgrxn8sgE4pjvXdnfn3rMn3HYvZjA8dsuc9jsfeM2h8xJ4bddmNzN4k/hXJlzj0O/dvBrD4MEyd3Xn6/Qka+xqOJZBID9lzrKJtKO30ktSo470IRRJOmIZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalR/w9JJqPvhGhmjQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "LAYER_TO_INSPECT = 2  # index into bayes_modules / alphas; repeat for every layer\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(alphas[LAYER_TO_INSPECT], bins=20)\n",
    "ax.set(title=\"log alpha, {}\".format(bayes_modules[LAYER_TO_INSPECT]),\n",
    "       xlabel=\"log alpha\", ylabel=\"count\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the two clusters appear to be separated at about -4.\n",
    "\n",
    "For each layer, look at the distribution and set the threshold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one log-alpha threshold per Bayesian layer, in the order of `bayes_modules`\n",
    "thresholds = [0, 0, -4, -3]\n",
    "assert len(thresholds) == len(bayes_modules), \"need exactly one threshold per Bayesian layer\"\n",
    "\n",
    "masks, alphas = BayesianLayers.get_masks(bayes_modules,\n",
    "                                         thresholds,  # now we know them\n",
    "                                         return_log_alpha=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pruned architecture: (709+937)-48-140\n"
     ]
    }
   ],
   "source": [
    "# units kept by each mask, in the order of `bayes_modules`:\n",
    "# (span projection 1 + span projection 2) - first classifier layer - last classifier layer\n",
    "kept = [sum(mask) for mask in masks]\n",
    "print(\"Pruned architecture: ({}+{})-{}-{}\".format(*kept))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jiant",
   "language": "python",
   "name": "jiant"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
