{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "libgomp: Invalid value for environment variable OMP_NUM_THREADS\n",
      "/home/jovyan/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# standard library\n",
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# third-party\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fix torch's RNG so the random weight initialisation and the random dummy labels below are reproducible\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# project root: the relative ./data/... paths below resolve against it (and presumably the `src.` imports too)\n",
    "# override it with the DNALM_ROOT environment variable instead of editing the notebook\n",
    "PROJECT_ROOT = os.environ.get('DNALM_ROOT', '/home/jovyan/dnalm/')\n",
    "os.chdir(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run directory of the trained RMT model, relative to the current working directory (the project root)\n",
    "rmt_model_path = Path.cwd() / 'runs/annotation/bert_base_512_lastln_t2t_1000G_bs256_lr_1e-04_linear_fp16/model_2000000/rmt_seglen_512_len42000_maxnsegm_10000_msz_5_bptt-1_lr5e-05_AdamW_cosine_wd1e-04_p10_bs64_it50000/run_1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# experiment config saved in the run directory; `with` closes the file handle deterministically\n",
    "with (rmt_model_path / 'config.json').open('r') as config_file:\n",
    "    exp_config = json.load(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_seq_len: 42000\n",
      "model_cfg: ./data/configs/L12-H768-A12-V32k-preln-lastln.json\n",
      "model_cls: src.gena_lm.modeling_rmt:RMTEncoderForTokenClassification\n",
      "backbone_cls: src.gena_lm.modeling_bert:BertForTokenClassification\n",
      "input_size: 512\n",
      "num_mem_tokens: 5\n",
      "max_n_segments: 10000\n",
      "tokenizer: ./data/tokenizers/t2t_1000h_multi_32k/\n"
     ]
    }
   ],
   "source": [
    "# print the key experiment settings, one 'name: value' pair per line\n",
    "keys_to_show = ('input_seq_len', 'model_cfg', 'model_cls', 'backbone_cls',\n",
    "                'input_size', 'num_mem_tokens', 'max_n_segments', 'tokenizer')\n",
    "for key in keys_to_show:\n",
    "    print(key, exp_config[key], sep=': ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "# tokenizer path comes from the experiment config so it always matches the trained model\n",
    "tokenizer = AutoTokenizer.from_pretrained(exp_config['tokenizer'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backbone (letter-level) and sub-model (token-level) classes\n",
    "from src.gena_lm.modeling_bert import (\n",
    "    BertForLetterLevelTokenClassification,\n",
    "    BertForTokenClassification,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backbone model config: must be the one used in training (exp_config['model_cfg']); do not edit its contents, only its path\n",
    "model_cfg = AutoConfig.from_pretrained(exp_config['model_cfg'])\n",
    "# model_cfg.num_labels = 6\n",
    "# model_cfg.problem_type = 'multi_label_classification'\n",
    "model_cls = BertForLetterLevelTokenClassification\n",
    "model = model_cls(config=model_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# config of the small sub-model: do not edit the file's contents, only its path\n",
    "model_cfg_small = AutoConfig.from_pretrained('./data/configs/L12-H768-A12-V32k-preln-lastln-small.json')\n",
    "model_cfg_small.num_labels = 6  # 6 output classes, matching the last dimension of the dummy labels below\n",
    "model_cfg_small.problem_type = 'multi_label_classification'\n",
    "model_cls_small = BertForTokenClassification\n",
    "model_small = model_cls_small(config=model_cfg_small)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gena_lm.modeling_rmt import RMTEncoderForLetterLevelTokenClassification\n",
    "\n",
    "# RMT wrapper settings; memory and segment sizes come from the experiment config\n",
    "rmt_config = dict(\n",
    "    num_mem_tokens=exp_config['num_mem_tokens'],\n",
    "    max_n_segments=exp_config['max_n_segments'],\n",
    "    input_size=exp_config['input_size'],\n",
    "    bptt_depth=-1,\n",
    "    sum_loss=True,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "rmt_cls = RMTEncoderForLetterLevelTokenClassification\n",
    "model_rmt = rmt_cls(model, model_small, **rmt_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "missing: ['sub_model.bert.embeddings.position_ids', 'sub_model.bert.embeddings.word_embeddings.weight', 'sub_model.bert.embeddings.position_embeddings.weight', 'sub_model.bert.embeddings.token_type_embeddings.weight', 'sub_model.bert.embeddings.LayerNorm.weight', 'sub_model.bert.embeddings.LayerNorm.bias', 'sub_model.bert.encoder.layer.0.pre_attention_ln.weight', 'sub_model.bert.encoder.layer.0.pre_attention_ln.bias', 'sub_model.bert.encoder.layer.0.post_attention_ln.weight', 'sub_model.bert.encoder.layer.0.post_attention_ln.bias', 'sub_model.bert.encoder.layer.0.attention.self.query.weight', 'sub_model.bert.encoder.layer.0.attention.self.query.bias', 'sub_model.bert.encoder.layer.0.attention.self.key.weight', 'sub_model.bert.encoder.layer.0.attention.self.key.bias', 'sub_model.bert.encoder.layer.0.attention.self.value.weight', 'sub_model.bert.encoder.layer.0.attention.self.value.bias', 'sub_model.bert.encoder.layer.0.attention.output.dense.weight', 'sub_model.bert.encoder.layer.0.attention.output.dense.bias', 'sub_model.bert.encoder.layer.0.intermediate.dense.weight', 'sub_model.bert.encoder.layer.0.intermediate.dense.bias', 'sub_model.bert.encoder.layer.0.output.dense.weight', 'sub_model.bert.encoder.layer.0.output.dense.bias', 'sub_model.bert.encoder.layer.1.pre_attention_ln.weight', 'sub_model.bert.encoder.layer.1.pre_attention_ln.bias', 'sub_model.bert.encoder.layer.1.post_attention_ln.weight', 'sub_model.bert.encoder.layer.1.post_attention_ln.bias', 'sub_model.bert.encoder.layer.1.attention.self.query.weight', 'sub_model.bert.encoder.layer.1.attention.self.query.bias', 'sub_model.bert.encoder.layer.1.attention.self.key.weight', 'sub_model.bert.encoder.layer.1.attention.self.key.bias', 'sub_model.bert.encoder.layer.1.attention.self.value.weight', 'sub_model.bert.encoder.layer.1.attention.self.value.bias', 'sub_model.bert.encoder.layer.1.attention.output.dense.weight', 'sub_model.bert.encoder.layer.1.attention.output.dense.bias', 
'sub_model.bert.encoder.layer.1.intermediate.dense.weight', 'sub_model.bert.encoder.layer.1.intermediate.dense.bias', 'sub_model.bert.encoder.layer.1.output.dense.weight', 'sub_model.bert.encoder.layer.1.output.dense.bias', 'sub_model.bert.encoder.last_layer_ln.weight', 'sub_model.bert.encoder.last_layer_ln.bias', 'sub_model.classifier.weight', 'sub_model.classifier.bias', 'sub_model.embeddings.weight']\n",
      "unexpected_k: ['model.classifier.weight', 'model.classifier.bias']\n"
     ]
    }
   ],
   "source": [
    "# load pre-trained weights (torch.load unpickles the file, so only load trusted checkpoints)\n",
    "ckpt_path = rmt_model_path / 'model_best' / 'accelerate_state' / 'pytorch_model.bin'\n",
    "ckpt = torch.load(str(ckpt_path), map_location='cpu')\n",
    "missing_k, unexpected_k = model_rmt.load_state_dict(ckpt, strict=False)\n",
    "# both lists should be empty: missing tensors keep their random initialisation, unexpected ones are ignored\n",
    "# NOTE(review): the recorded output lists every sub_model.* tensor as missing, i.e. the sub-model is not loaded\n",
    "print(f'missing: {missing_k}')\n",
    "print(f'unexpected_k: {unexpected_k}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval mode for the whole RMT wrapper: calling .eval() on the backbone `model` alone\n",
    "# left the wrapped sub-model in training mode (e.g. dropout still active)\n",
    "model_rmt = model_rmt.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq = 'ATGC'*2900\n",
    "seq2 = 'ATGC'*900\n",
    "sequences = [seq, seq2]\n",
    "# no truncation: this tokenizer defines no max length, so truncation=True only emitted a warning and kept full sequences\n",
    "input_features = tokenizer(sequences, return_tensors='pt', padding=True)\n",
    "\n",
    "# naive label mask: 1 for every token id > 5, 0 otherwise -- assumes ids 0-5 are the special tokens (CLS/SEP/PAD...), TODO confirm\n",
    "input_features['labels_mask'] = (input_features['input_ids'] > 5).long()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'TTTTC'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: map token id 100 back to its string\n",
    "tokenizer.convert_ids_to_tokens(ids=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 1,  ..., 1, 1, 0],\n",
       "        [0, 1, 1,  ..., 0, 0, 0]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1 where the token id is > 5, 0 elsewhere\n",
    "input_features.get('labels_mask')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# length of each sequence in tokens without padding (CLS and SEP included)\n",
    "[len(ids) for ids in tokenizer(sequences)['input_ids']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1452])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, padded token length)\n",
    "input_features['input_ids'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels_mask'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tokenizer outputs plus the labels_mask added above\n",
    "input_features.data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random multi-label targets: one 0/1 flag per class (6 classes) at every token position;\n",
    "# randint's upper bound is exclusive, so (0, 2) yields {0, 1} -- the old (0, 6) produced values up to 5\n",
    "input_features['labels'] = torch.randint(0, 2, (len(sequences), input_features['input_ids'].shape[1], 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random multi-label targets at letter level (one row per character of the longest sequence);\n",
    "# randint's upper bound is exclusive, so the old randint(0, 1, ...) was all zeros\n",
    "input_features['letter_level_labels'] = torch.randint(0, 2, (len(sequences), len(seq), 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1452, 6])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, tokens, classes)\n",
    "input_features['labels'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 11600, 6])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, letters, classes)\n",
    "input_features['letter_level_labels'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 11600])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dummy per-letter values drawn from [2, 10); their meaning is presumed from the key name -- TODO confirm against the model code\n",
    "batch_size, n_letters = len(sequences), len(seq)\n",
    "input_features['embedding_repeater'] = torch.randint(low=2, high=10, size=(batch_size, n_letters))\n",
    "input_features['embedding_repeater'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 11600])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dummy letter-level token ids: every position holds id 21\n",
    "input_features['letter_level_tokens'] = torch.full((len(sequences), len(seq)), 21, dtype=torch.long)\n",
    "input_features['letter_level_tokens'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True for real letters, False for the padding after the end of the shorter sequence\n",
    "letter_lengths = torch.tensor([len(s) for s in sequences])\n",
    "lllm = torch.arange(len(seq)).unsqueeze(0) < letter_lengths.unsqueeze(1)\n",
    "input_features['letter_level_labels_mask'] = lllm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels_mask', 'labels', 'letter_level_labels', 'embedding_repeater', 'letter_level_tokens', 'letter_level_labels_mask'])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# every feature that is passed to model_rmt below\n",
    "input_features.data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "google\n",
      "torch.Size([1, 1450, 768])\n",
      "torch.Size([1, 11600, 6])\n",
      "google\n",
      "torch.Size([1, 450, 768])\n",
      "torch.Size([1, 11600, 6])\n"
     ]
    }
   ],
   "source": [
    "# inference only: autograd is disabled, so no graph is built for the long inputs\n",
    "with torch.no_grad():\n",
    "    out = model_rmt(**input_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['loss', 'logits'])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fields returned by the model\n",
    "out.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8631)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# scalar loss against the dummy labels\n",
    "out['loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 11600, 6])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (batch, letters, classes)\n",
    "out['logits'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# logits are already padded to the longest sequence (len(seq) letters), so they align 1:1 with the\n",
    "# letter-level labels; padding them again by len(seq) - len(seq2) only appended meaningless rows\n",
    "assert out.logits.shape == input_features['letter_level_labels'].shape, out.logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop padded positions per sequence: the logits are letter-level, so mask them with the letter-level mask\n",
    "# (the token-level input_ids mask has a different length and raises IndexError)\n",
    "[logits[mask].shape for logits, mask in zip(out.logits, input_features['letter_level_labels_mask'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-class probabilities for the first letter of the first sequence; the sub-model is configured with\n",
    "# problem_type='multi_label_classification', so use an independent sigmoid per class (softmax assumes exclusive classes)\n",
    "torch.sigmoid(out.logits[0, 0, :])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gena_ipynb",
   "language": "python",
   "name": "gena_ipynb"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
