{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose no CUDA devices to this process, so everything below runs on CPU.\n",
    "# This is set before the cell that imports torch.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.metrics import average_precision_score\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the project directory: the relative './data/...' paths used below are\n",
    "# resolved against the working directory (presumably this is also what makes the\n",
    "# `src.gena_lm` imports work - TODO confirm).\n",
    "os.chdir('/home/jovyan/dnalm/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory of the model; config.json is read from this directory below.\n",
    "rmt_model_path = Path('/home/jovyan/dnalm/runs/annotation_bert_large_shawerma_continued_from_best_5_classes_no_intergenic_combined_8k_bpe_full_up_to_250k_exon_level_choosing_middle_lr_big_wd_UNET_segmented_UNET_repeater_3_step/bert_large_512_lastln_t2t_1000G_bs256_lr_1e-04_fp16/model_1750000/rmt_seglen_512_len8192_maxnsegm_10000_msz_5_bptt-1_lr1e-05_AdamW_constant_with_warmup_wd1e-04_p10000_bs_it500000/run_1/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment config of the run; this must be the config.json that was shared\n",
    "# together with the model.\n",
    "# Read it inside a context manager so the file handle is closed right away\n",
    "# instead of being left open for the lifetime of the kernel.\n",
    "with (rmt_model_path / 'config.json').open('r') as config_file:\n",
    "    exp_config = json.load(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_seq_len: 8192\n",
      "model_cfg: ./data/configs/L24-H1024-A16-V32k-preln-lastln.json\n",
      "model_cls: src.gena_lm.modeling_rmt:RMTEncoderForLetterLevelTokenClassificationUNETsegmentedRepeater\n",
      "backbone_cls: src.gena_lm.modeling_bert:BertForLetterLevelTokenClassification\n",
      "input_size: 512\n",
      "num_mem_tokens: 5\n",
      "max_n_segments: 10000\n",
      "tokenizer: ./data/tokenizers/t2t_1000h_multi_32k/\n"
     ]
    }
   ],
   "source": [
    "# Show the experiment settings that are relevant for rebuilding the model.\n",
    "shown_keys = (\n",
    "    'input_seq_len', 'model_cfg', 'model_cls', 'backbone_cls',\n",
    "    'input_size', 'num_mem_tokens', 'max_n_segments', 'tokenizer',\n",
    ")\n",
    "for key in shown_keys:\n",
    "    print(key, exp_config[key], sep=': ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jovyan/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "# Take the tokenizer location from the experiment config instead of repeating\n",
    "# the literal path, so the tokenizer always matches the run that is loaded.\n",
    "# The configured path is relative and is resolved against the working directory.\n",
    "tokenizer = AutoTokenizer.from_pretrained(exp_config['tokenizer'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gena_lm.modeling_bert import BertForLetterLevelTokenClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the backbone model. Do not change its contents; only the\n",
    "# path to the file may be adjusted.\n",
    "backbone_cfg_path = './data/configs/L24-H1024-A16-V32k-preln-lastln-copy.json'\n",
    "model_cfg = AutoConfig.from_pretrained(backbone_cfg_path)\n",
    "\n",
    "# Classification head settings: 5 labels, multi-label problem type.\n",
    "model_cfg.num_labels = 5\n",
    "model_cfg.problem_type = 'multi_label_classification'\n",
    "\n",
    "# Build the backbone from the config alone; weights are loaded in a later cell.\n",
    "model_cls = BertForLetterLevelTokenClassification\n",
    "model = model_cls(config=model_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gena_lm.modeling_rmt import RMTEncoderForLetterLevelTokenClassificationUnetSegmentedEmbeddingOnly\n",
    "\n",
    "# Settings for the RMT wrapper: the first three values are read from the\n",
    "# experiment config, the remaining ones are fixed here.\n",
    "rmt_config = dict(\n",
    "    num_mem_tokens=exp_config['num_mem_tokens'],\n",
    "    max_n_segments=exp_config['max_n_segments'],\n",
    "    input_size=exp_config['input_size'],\n",
    "    bptt_depth=-1,\n",
    "    sum_loss=True,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Wrap the backbone. `model` is rebound to the wrapper, so re-running this cell\n",
    "# without rebuilding the backbone first would wrap the wrapper again.\n",
    "rmt_cls = RMTEncoderForLetterLevelTokenClassificationUnetSegmentedEmbeddingOnly\n",
    "model = rmt_cls(model, **rmt_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "missing: []\n",
      "unexpected_k: []\n"
     ]
    }
   ],
   "source": [
    "# Load the pre-trained weights onto the CPU.\n",
    "# The checkpoint path is derived from rmt_model_path, so the run directory is\n",
    "# spelled out in one place only and config and weights cannot drift apart.\n",
    "ckpt_path = rmt_model_path / 'model_best' / 'pytorch_model.bin'\n",
    "\n",
    "# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "ckpt = torch.load(str(ckpt_path), map_location='cpu')\n",
    "\n",
    "# strict=False reports mismatching keys instead of raising, so check the output:\n",
    "# both lists must be empty, otherwise checkpoint and model do not match.\n",
    "missing_k, unexpected_k = model.load_state_dict(ckpt, strict=False)\n",
    "print(f'missing: {missing_k}')  # must be empty: no model tensor left uninitialised\n",
    "print(f'unexpected_k: {unexpected_k}')  # must be empty: no checkpoint tensor left unused"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model into evaluation mode for inference.\n",
    "model = model.eval()\n",
    "# Half-precision cast, left disabled:\n",
    "# model.half()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy input: a four-letter motif repeated 1234 times.\n",
    "repeat_unit, n_repeats = 'ATGC', 1234\n",
    "seq = repeat_unit * n_repeats\n",
    "input_features = tokenizer(seq, return_tensors='pt')\n",
    "\n",
    "# Placeholder labels mask: simply reuse the attention mask (same tensor object).\n",
    "input_features['labels_mask'] = input_features['attention_mask']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 619])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the tokenized input: one sequence, with the token positions on the second axis.\n",
    "input_features['input_ids'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels_mask'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fields currently held in the model input.\n",
    "input_features.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model does not run without labels for now, so pass dummy ones; their\n",
    "# values do not change the prediction.\n",
    "# The shape (batch, tokens, num_labels) is taken from the input and the model\n",
    "# config, so it also fits inputs with more than one sequence or another label count.\n",
    "batch_size, n_tokens = input_features['input_ids'].shape\n",
    "input_features['labels'] = torch.ones((batch_size, n_tokens, model_cfg.num_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 619, 5])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the labels extend the input_ids shape by a trailing axis of size 5.\n",
    "input_features['labels'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels with the axis order reversed and a new leading axis added.\n",
    "# The reversal is written as an explicit permute: Tensor.T performs the same\n",
    "# reversal, but is deprecated for tensors that are not 2-D (labels is 3-D here).\n",
    "labels = input_features['labels']\n",
    "labels.permute(*reversed(range(labels.ndim)))[None, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 619])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The labels mask is the attention mask, so it has the same shape as input_ids.\n",
    "input_features['labels_mask'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jovyan/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/transformers/modeling_utils.py:1101: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Forward pass with gradient tracking switched off (inference only).\n",
    "# The autocast line below is a disabled float16-on-GPU variant.\n",
    "# with torch.autocast(device_type='cuda', dtype=torch.float16):\n",
    "with torch.no_grad():\n",
    "    out = model(**input_features, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['all_memory_embeddings', 'labels_segm', 'rmt_logits_masks_segm'])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Entries returned by the forward pass.\n",
    "out.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.1612, -0.3453, -0.0903,  ..., -0.7527,  0.7983,  0.5882],\n",
       "         [ 0.1987, -0.3132, -0.0816,  ..., -0.7530,  0.7953,  0.5135],\n",
       "         [ 0.1815, -0.3306, -0.0857,  ..., -0.7753,  0.8121,  0.4802],\n",
       "         [ 0.1868, -0.3838, -0.1109,  ..., -0.7682,  0.8198,  0.5138],\n",
       "         [ 0.1898, -0.4034, -0.1091,  ..., -0.7697,  0.8323,  0.5414]]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Memory embeddings at index 1 of the first axis\n",
    "# (presumably the second input segment - TODO confirm).\n",
    "out['all_memory_embeddings'][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 5, 1024])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the memory embeddings; the axis of size 5 matches num_mem_tokens.\n",
    "# Presumably (segments, batch, memory tokens, hidden size) - TODO confirm.\n",
    "out['all_memory_embeddings'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 512, 5])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the labels returned by the model; 512 matches input_size and 5 the label count.\n",
    "# Presumably (segments, batch, positions per segment, labels) - TODO confirm.\n",
    "out['labels_segm'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 512])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the returned masks: the labels_segm shape without the trailing label axis.\n",
    "out['rmt_logits_masks_segm'].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gena_ipynb",
   "language": "python",
   "name": "gena_ipynb"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
