{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "libgomp: Invalid value for environment variable OMP_NUM_THREADS\n",
      "/home/jovyan/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the repo root: the relative paths used below ('./data/tokenizers/...',\n",
    "# './data/configs/...') are resolved against this directory. Presumably this also\n",
    "# makes the `src.gena_lm` package importable -- confirm for your kernel setup.\n",
    "os.chdir('/home/jovyan/dnalm/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory of the pre-trained RMT model; config.json and model_best.pth are read from it below.\n",
    "rmt_model_path = (\n",
    "    Path('/home/jovyan/dnalm/model_hub')\n",
    "    / 'rmt_bert_base_lastln_t2t_1000G_seglen_512_len_3992_maxnsegm_8_msz_10_bptt-1_bs256_lr_1e-05_wd_1e-04_fp16_O2'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment config saved alongside the checkpoint. read_text() opens and closes the\n",
    "# file in one call, so no file handle stays open for the lifetime of the kernel;\n",
    "# JSON is decoded as UTF-8 explicitly rather than with the locale default.\n",
    "exp_config = json.loads((rmt_model_path / 'config.json').read_text(encoding='utf-8'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_seq_len: 3992\n",
      "model_cfg: ./bert_configs/L12-H768-A12-V32k-preln-lastln.json\n",
      "model_cls: modeling_rmt:RMTEncoderForMaskedLM\n",
      "backbone_cls: modeling_bert:BertForMaskedLM\n",
      "input_size: 512\n",
      "num_mem_tokens: 10\n",
      "max_n_segments: 8\n",
      "tokenizer: /home/jovyan/dnalm/data/tokenizers/t2t_1000h_multi_32k\n"
     ]
    }
   ],
   "source": [
    "# Show the training settings relevant to rebuilding the model, one `key: value` per line.\n",
    "for k in ('input_seq_len', 'model_cfg', 'model_cls', 'backbone_cls',\n",
    "          'input_size', 'num_mem_tokens', 'max_n_segments', 'tokenizer'):\n",
    "    print(k, exp_config[k], sep=': ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, AutoTokenizer\n",
    "\n",
    "# Same directory as exp_config['tokenizer'], given relative to the repo root chosen by os.chdir.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    './data/tokenizers/t2t_1000h_multi_32k/'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gena_lm.modeling_bert import BertForTokenClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbone: a BERT token classifier built from the architecture config alone (not\n",
    "# from_pretrained); its weights are filled in from the RMT checkpoint further down.\n",
    "model_cls = BertForTokenClassification\n",
    "model_cfg = AutoConfig.from_pretrained(\n",
    "    './data/configs/L12-H768-A12-V32k-preln-lastln.json'\n",
    ")\n",
    "model = model_cls(config=model_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.gena_lm.modeling_rmt import RMTEncoderForLetterLevelTokenClassification\n",
    "\n",
    "# Memory/segmentation sizes are copied from the training run; the other options are fixed here.\n",
    "rmt_config = {k: exp_config[k] for k in ('num_mem_tokens', 'max_n_segments', 'input_size')}\n",
    "rmt_config.update(bptt_depth=-1, sum_loss=True, tokenizer=tokenizer)\n",
    "\n",
    "# Wrap the BERT backbone in the RMT encoder for letter-level token classification.\n",
    "rmt_cls = RMTEncoderForLetterLevelTokenClassification\n",
    "model = rmt_cls(model, **rmt_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "missing: ['model.classifier.weight', 'model.classifier.bias']\n",
      "unexpected_k: ['model.cls.predictions.bias', 'model.cls.predictions.transform.dense.weight', 'model.cls.predictions.transform.dense.bias', 'model.cls.predictions.transform.LayerNorm.weight', 'model.cls.predictions.transform.LayerNorm.bias', 'model.cls.predictions.decoder.weight', 'model.cls.predictions.decoder.bias']\n"
     ]
    }
   ],
   "source": [
    "# load pre-trained weights (torch.load unpickles the file -- only load trusted checkpoints)\n",
    "ckpt = torch.load(str(rmt_model_path / 'model_best.pth'), map_location='cpu')\n",
    "missing_k, unexpected_k = model.load_state_dict(ckpt[\"model_state_dict\"], strict=False)\n",
    "print(f'missing: {missing_k}')\n",
    "print(f'unexpected_k: {unexpected_k}')\n",
    "# strict=False is needed because the checkpoint carries the MLM head (model.cls.*) and\n",
    "# lacks the token-classification head (model.classifier.*), which keeps its fresh init.\n",
    "# Any other mismatch means the architecture does not match the checkpoint: fail loudly.\n",
    "assert all(k.startswith('model.classifier.') for k in missing_k), missing_k\n",
    "assert all(k.startswith('model.cls.') for k in unexpected_k), unexpected_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the whole wrapper in inference mode (training=False on every submodule, which\n",
    "# disables dropout). eval() returns the module itself, so this just rebinds `model`.\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_token_ids = {'input_ids': tokenizer.pad_token_id, 'token_type_ids': 0, 'attention_mask': 0}\n",
    "pad_to_divisible_by = 1\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Right-pad tokenized samples and stack them into a batch of tensors.\n",
    "\n",
    "    Args:\n",
    "        batch: non-empty list of mappings, each with 'input_ids', 'token_type_ids'\n",
    "            and 'attention_mask' (1-D int sequences) and, optionally, 'labels'.\n",
    "\n",
    "    Returns:\n",
    "        dict of tensors. Each feature is padded with ``pad_token_ids[key]`` up to the\n",
    "        longest 'input_ids' in the batch, rounded up to a multiple of\n",
    "        ``pad_to_divisible_by``. 'labels' is collated only when the first sample\n",
    "        has it (every sample must then have it).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``batch`` is empty.\n",
    "    \"\"\"\n",
    "    if not batch:\n",
    "        raise ValueError('collate_fn: empty batch')\n",
    "    feature_keys = ['input_ids', 'token_type_ids', 'attention_mask']\n",
    "    max_seq_len = max(len(el['input_ids']) for el in batch)\n",
    "    # round up (ceil) to the next multiple of pad_to_divisible_by\n",
    "    max_seq_len = -(-max_seq_len // pad_to_divisible_by) * pad_to_divisible_by\n",
    "    padded_batch = {}\n",
    "    for k in feature_keys:\n",
    "        rows = [\n",
    "            np.concatenate([\n",
    "                el[k],\n",
    "                np.array([pad_token_ids[k]] * max(0, max_seq_len - len(el[k])), dtype=np.int64),\n",
    "            ])\n",
    "            for el in batch\n",
    "        ]\n",
    "        padded_batch[k] = torch.from_numpy(np.stack(rows))\n",
    "    # labels are optional so the same collator also serves inference-only batches\n",
    "    if 'labels' in batch[0]:\n",
    "        padded_batch['labels'] = torch.tensor([el['labels'] for el in batch])\n",
    "    return padded_batch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 952])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Synthetic 7,600-letter DNA string. The recorded output shows it tokenizes to 952 tokens,\n",
    "# more than one input_size=512 segment, so the RMT has to process several segments.\n",
    "seq = 'ATGC'*1900\n",
    "input_features = tokenizer([seq], return_tensors='pt')\n",
    "input_features['input_ids'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jovyan/envs/py38_cu114/lib/python3.8/site-packages/transformers/modeling_utils.py:881: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Feed exactly the features inspected in the previous cell instead of tokenizing `seq` again.\n",
    "with torch.no_grad():\n",
    "    out = model(**input_features, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['logits', 'hidden_states', 'hidden_states_0', 'hidden_states_1', 'loss_0', 'loss_1', 'loss'])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keys with a numeric suffix (hidden_states_0, loss_1, ...) are presumably per-segment\n",
    "# results -- here two segments for the 952-token input; confirm against modeling_rmt.\n",
    "out.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Segment layout (token 0 is always `CLS`):\n",
    "\n",
    "```\n",
    "segm 1: CLS MEM MEM MEM ... MEM SEQUENCE_part_1 SEP PAD PAD ...\n",
    "segm 2: CLS MEM MEM MEM ... MEM SEQUENCE_part_2 SEP PAD PAD ...\n",
    "...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 512, 768])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last entry of segment 0's hidden states (presumably the final layer). The output shows\n",
    "# (batch=1, 512 = input_size positions incl. CLS/MEM/SEP/PAD, 768 = hidden size of H768).\n",
    "out['hidden_states_0'][-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 512, 768])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same view for segment 1: it also spans the full input_size of 512 positions, so it\n",
    "# presumably includes PAD positions after the shorter remainder of the sequence.\n",
    "out['hidden_states_1'][-1].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gena_ibynb",
   "language": "python",
   "name": "gena_ibynb"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
