{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chenjian/miniconda3/envs/laecgm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers import (GPT2Config, GPT2TokenizerFast,\n",
    "                          GPT2LMHeadModel, PretrainedConfig, EncoderDecoderModel)\n",
    "from transformers.modeling_outputs import BaseModelOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../utils\")\n",
    "import utils_dataset\n",
    "import yaml\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load mimic dataset!\n",
      "train size: 756259\n",
      "val size: 15434\n",
      "total size: 771693\n",
      "Apply Train-stage Transform!\n",
      "train dataset length:  756259\n",
      "Apply Val-stage Transform!\n",
      "val dataset length:  15434\n"
     ]
    }
   ],
   "source": [
    "# Corpus name and on-disk location handed to the dataset wrapper.\n",
    "dataset_name = 'mimic'\n",
    "data_path = '/yourpath/pretrain_data/'\n",
    "\n",
    "dataset = utils_dataset.ECG_TEXT_Dsataset(dataset_name=dataset_name, data_path=data_path)\n",
    "\n",
    "# Build both splits from the same wrapper: 'train' is requested first, then 'val'.\n",
    "train_dataset, val_dataset = (\n",
    "    dataset.get_dataset(train_test=split_name) for split_name in ('train', 'val')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sinus rhythm. normal ecg.\n",
      "sinus bradycardia. prolonged qt interval. borderline ecg.\n",
      "sinus rhythm.. inferior t wave changes are nonspecific. borderline ecg.\n"
     ]
    }
   ],
   "source": [
    "# Show the raw report text of validation samples 4, 1 and 2, in that order.\n",
    "for sample_idx in (4, 1, 2):\n",
    "    print(val_dataset[sample_idx]['raw_text'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from models.model import ECGCLIP\n",
    "\n",
    "# Encoder checkpoint, deserialised onto the CPU.\n",
    "# NOTE(review): torch.load may unpickle arbitrary objects -- only load trusted checkpoints.\n",
    "ckpt_path = '/yourpath/bestZeroShotAll_ckpt.pth'\n",
    "ckpt = torch.load(ckpt_path, map_location='cpu')\n",
    "\n",
    "# Read the config inside a context manager so the file handle is closed\n",
    "# deterministically instead of being left open for the life of the kernel.\n",
    "# NOTE(review): yaml.FullLoader is kept as-is; consider yaml.safe_load if the\n",
    "# config only contains plain scalars, lists and mappings.\n",
    "with open(\"/yourpath/finetune/config.yaml\", \"r\") as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# Build the encoder from the 'network' section of the config.\n",
    "encoder = ECGCLIP(config['network'])\n",
    "\n",
    "# strict=True makes any missing or unexpected checkpoint key an error; the\n",
    "# returned key-matching report is displayed as the cell output.\n",
    "encoder.load_state_dict(ckpt, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
       "  (embeddings): BertEmbeddings(\n",
       "    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
       "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (encoder): BertEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0-11): 12 x BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): BertPooler(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder.lm_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder_path = '/yourpath/distilgpt2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at /home/chenjian/multi-modal_ECG/distilgpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.2.crossattention.q_attn.bias', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.2.ln_cross_attn.bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_attn.bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.bias', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_attn.bias', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.4.crossattention.q_attn.bias', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.bias', 'transformer.h.4.ln_cross_attn.weight', 
'transformer.h.5.crossattention.c_attn.bias', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.5.crossattention.q_attn.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.5.ln_cross_attn.bias', 'transformer.h.5.ln_cross_attn.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Config of the encoder: <class 'model_gpt2.ERGPT2.__init__.<locals>.DummyEncoder'> is overwritten by shared encoder config: BertConfig {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": null,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.41.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load pretrained distillgpt2\n",
      "Description, Special token, Index\n",
      "bos_token, [BOS], 50257\n",
      "eos_token, <|endoftext|>, 50256\n",
      "unk_token, <|endoftext|>, 50256\n",
      "pad_token, [PAD], 50258\n"
     ]
    }
   ],
   "source": [
    "from model_gpt2 import ERGPT2\n",
    "\n",
    "model = ERGPT2(encoder=encoder, decoder_path=decoder_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "# Settings shared by both loaders: one worker process, no shuffling, and\n",
    "# drop_last=True.\n",
    "# NOTE(review): drop_last=True discards a trailing incomplete batch, so the last\n",
    "# few validation samples are never yielded -- confirm this is intended.\n",
    "shared_loader_kwargs = dict(num_workers=1, drop_last=True, shuffle=False)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=16, **shared_loader_kwargs)\n",
    "val_loader = DataLoader(val_dataset, batch_size=32, **shared_loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ERGPT2(\n",
       "  (encoder): ECGCLIP(\n",
       "    (downconv): Conv1d(512, 256, kernel_size=(1,), stride=(3,))\n",
       "    (att_pool_head): AttentionPool2d(\n",
       "      (mhsa): MultiheadAttention(\n",
       "        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "      (c_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "    )\n",
       "    (linear1): AttentionPool2d(\n",
       "      (mhsa): MultiheadAttention(\n",
       "        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "      (c_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "    )\n",
       "    (linear2): AttentionPool2d(\n",
       "      (mhsa): MultiheadAttention(\n",
       "        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "      (c_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "    )\n",
       "    (decode_t): Transformer(\n",
       "      (to_patch): Mlp(\n",
       "        (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (act): SiLU()\n",
       "        (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (drop): Dropout(p=0.2, inplace=False)\n",
       "        (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "      (dropout): Dropout(p=0.0, inplace=False)\n",
       "      (block0): TransformerBlock(\n",
       "        (attn): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Attention(\n",
       "            (attend): Softmax(dim=-1)\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "            (to_qkv): Linear(in_features=256, out_features=1536, bias=True)\n",
       "            (to_out): Sequential(\n",
       "              (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "              (1): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (droppath1): Identity()\n",
       "        (ff): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Mlp(\n",
       "            (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (act): SiLU()\n",
       "            (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (droppath2): Identity()\n",
       "      )\n",
       "      (block1): TransformerBlock(\n",
       "        (attn): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Attention(\n",
       "            (attend): Softmax(dim=-1)\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "            (to_qkv): Linear(in_features=256, out_features=1536, bias=True)\n",
       "            (to_out): Sequential(\n",
       "              (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "              (1): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (droppath1): DropPath()\n",
       "        (ff): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Mlp(\n",
       "            (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (act): SiLU()\n",
       "            (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (droppath2): DropPath()\n",
       "      )\n",
       "      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "      (head): Mlp(\n",
       "        (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (act): SiLU()\n",
       "        (fc2): Linear(in_features=256, out_features=768, bias=True)\n",
       "        (drop): Dropout(p=0.2, inplace=False)\n",
       "        (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (decode_e): Transformer(\n",
       "      (to_patch): Mlp(\n",
       "        (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (act): SiLU()\n",
       "        (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (drop): Dropout(p=0.2, inplace=False)\n",
       "        (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "      (dropout): Dropout(p=0.0, inplace=False)\n",
       "      (block0): TransformerBlock(\n",
       "        (attn): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Attention(\n",
       "            (attend): Softmax(dim=-1)\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "            (to_qkv): Linear(in_features=256, out_features=1536, bias=True)\n",
       "            (to_out): Sequential(\n",
       "              (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "              (1): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (droppath1): Identity()\n",
       "        (ff): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Mlp(\n",
       "            (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (act): SiLU()\n",
       "            (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (droppath2): Identity()\n",
       "      )\n",
       "      (block1): TransformerBlock(\n",
       "        (attn): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Attention(\n",
       "            (attend): Softmax(dim=-1)\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "            (to_qkv): Linear(in_features=256, out_features=1536, bias=True)\n",
       "            (to_out): Sequential(\n",
       "              (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "              (1): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (droppath1): DropPath()\n",
       "        (ff): PreNorm(\n",
       "          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (fn): Mlp(\n",
       "            (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (act): SiLU()\n",
       "            (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (droppath2): DropPath()\n",
       "      )\n",
       "      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "      (head): Mlp(\n",
       "        (fc1): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (act): SiLU()\n",
       "        (fc2): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (drop): Dropout(p=0.2, inplace=False)\n",
       "        (fc3): Linear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (ecg_encoder): ResNet(\n",
       "      (conv1): Conv1d(12, 64, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)\n",
       "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (layer1): Sequential(\n",
       "        (0): BasicBlock(\n",
       "          (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential()\n",
       "        )\n",
       "        (1): BasicBlock(\n",
       "          (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential()\n",
       "        )\n",
       "      )\n",
       "      (layer2): Sequential(\n",
       "        (0): BasicBlock(\n",
       "          (conv1): Conv1d(64, 128, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential(\n",
       "            (0): Conv1d(64, 128, kernel_size=(1,), stride=(2,), bias=False)\n",
       "            (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (1): BasicBlock(\n",
       "          (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential()\n",
       "        )\n",
       "      )\n",
       "      (layer3): Sequential(\n",
       "        (0): BasicBlock(\n",
       "          (conv1): Conv1d(128, 256, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential(\n",
       "            (0): Conv1d(128, 256, kernel_size=(1,), stride=(2,), bias=False)\n",
       "            (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (1): BasicBlock(\n",
       "          (conv1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential()\n",
       "        )\n",
       "      )\n",
       "      (layer4): Sequential(\n",
       "        (0): BasicBlock(\n",
       "          (conv1): Conv1d(256, 512, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential(\n",
       "            (0): Conv1d(256, 512, kernel_size=(1,), stride=(2,), bias=False)\n",
       "            (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (1): BasicBlock(\n",
       "          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
       "          (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (shortcut): Sequential()\n",
       "        )\n",
       "      )\n",
       "      (linear): Linear(in_features=512, out_features=10, bias=True)\n",
       "      (head): Identity()\n",
       "      (avgpool): AdaptiveAvgPool1d(output_size=1)\n",
       "    )\n",
       "    (avgpool): AvgPool1d(kernel_size=(1,), stride=(1,), padding=(0,))\n",
       "    (dropout1): Dropout(p=0.1, inplace=False)\n",
       "    (dropout2): Dropout(p=0.1, inplace=False)\n",
       "    (lm_model): BertModel(\n",
       "      (embeddings): BertEmbeddings(\n",
       "        (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "        (position_embeddings): Embedding(512, 768)\n",
       "        (token_type_embeddings): Embedding(2, 768)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (encoder): BertEncoder(\n",
       "        (layer): ModuleList(\n",
       "          (0-11): 12 x BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (pooler): BertPooler(\n",
       "        (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (activation): Tanh()\n",
       "      )\n",
       "    )\n",
       "    (proj_t): Sequential(\n",
       "      (0): Linear(in_features=768, out_features=256, bias=True)\n",
       "      (1): GELU(approximate='none')\n",
       "      (2): Linear(in_features=256, out_features=256, bias=True)\n",
       "    )\n",
       "    (head): Linear(in_features=256, out_features=768, bias=True)\n",
       "  )\n",
       "  (decoder): Decoder(\n",
       "    (encoder_decoder): EncoderDecoderModel(\n",
       "      (decoder): GPT2LMHeadModel(\n",
       "        (transformer): GPT2Model(\n",
       "          (wte): Embedding(50259, 768)\n",
       "          (wpe): Embedding(1024, 768)\n",
       "          (drop): Dropout(p=0.1, inplace=False)\n",
       "          (h): ModuleList(\n",
       "            (0-5): 6 x GPT2Block(\n",
       "              (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              (attn): GPT2Attention(\n",
       "                (c_attn): Conv1D()\n",
       "                (c_proj): Conv1D()\n",
       "                (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "                (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              (crossattention): GPT2Attention(\n",
       "                (c_attn): Conv1D()\n",
       "                (q_attn): Conv1D()\n",
       "                (c_proj): Conv1D()\n",
       "                (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "                (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (ln_cross_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "              (mlp): GPT2MLP(\n",
       "                (c_fc): Conv1D()\n",
       "                (c_proj): Conv1D()\n",
       "                (act): NewGELUActivation()\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "          (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (lm_head): Linear(in_features=768, out_features=50259, bias=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the saved checkpoint onto the CPU; its 'model_state_dict' entry holds the weights.\n",
    "report_ckpt_path = '/yourpath/report/ckpt/DisGPT2_0_ckpt.pth'\n",
    "cpkt = torch.load(report_ckpt_path, map_location='cpu')\n",
    "\n",
    "# Load the weights while the model sits on the CPU; strict=True rejects any key mismatch.\n",
    "model.to('cpu')\n",
    "model.load_state_dict(cpkt['model_state_dict'], strict=True)\n",
    "\n",
    "# Final expression: move the model to the first GPU and display its structure.\n",
    "model.to('cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['encoder.downconv.weight', 'encoder.downconv.bias', 'encoder.att_pool_head.positional_embedding', 'encoder.att_pool_head.sep_embedding', 'encoder.att_pool_head.mhsa.in_proj_weight', 'encoder.att_pool_head.mhsa.in_proj_bias', 'encoder.att_pool_head.mhsa.out_proj.weight', 'encoder.att_pool_head.mhsa.out_proj.bias', 'encoder.att_pool_head.c_proj.weight', 'encoder.att_pool_head.c_proj.bias', 'encoder.linear1.positional_embedding', 'encoder.linear1.sep_embedding', 'encoder.linear1.mhsa.in_proj_weight', 'encoder.linear1.mhsa.in_proj_bias', 'encoder.linear1.mhsa.out_proj.weight', 'encoder.linear1.mhsa.out_proj.bias', 'encoder.linear1.c_proj.weight', 'encoder.linear1.c_proj.bias', 'encoder.linear2.positional_embedding', 'encoder.linear2.sep_embedding', 'encoder.linear2.mhsa.in_proj_weight', 'encoder.linear2.mhsa.in_proj_bias', 'encoder.linear2.mhsa.out_proj.weight', 'encoder.linear2.mhsa.out_proj.bias', 'encoder.linear2.c_proj.weight', 'encoder.linear2.c_proj.bias', 'encoder.decode_t.pos_embedding', 'encoder.decode_t.sep_embedding', 'encoder.decode_t.to_patch.fc1.weight', 'encoder.decode_t.to_patch.fc1.bias', 'encoder.decode_t.to_patch.fc2.weight', 'encoder.decode_t.to_patch.fc2.bias', 'encoder.decode_t.to_patch.fc3.weight', 'encoder.decode_t.to_patch.fc3.bias', 'encoder.decode_t.block0.attn.norm.weight', 'encoder.decode_t.block0.attn.norm.bias', 'encoder.decode_t.block0.attn.fn.to_qkv.weight', 'encoder.decode_t.block0.attn.fn.to_qkv.bias', 'encoder.decode_t.block0.attn.fn.to_out.0.weight', 'encoder.decode_t.block0.attn.fn.to_out.0.bias', 'encoder.decode_t.block0.ff.norm.weight', 'encoder.decode_t.block0.ff.norm.bias', 'encoder.decode_t.block0.ff.fn.fc1.weight', 'encoder.decode_t.block0.ff.fn.fc1.bias', 'encoder.decode_t.block0.ff.fn.fc2.weight', 'encoder.decode_t.block0.ff.fn.fc2.bias', 'encoder.decode_t.block0.ff.fn.fc3.weight', 'encoder.decode_t.block0.ff.fn.fc3.bias', 'encoder.decode_t.block1.attn.norm.weight', 
'encoder.decode_t.block1.attn.norm.bias', 'encoder.decode_t.block1.attn.fn.to_qkv.weight', 'encoder.decode_t.block1.attn.fn.to_qkv.bias', 'encoder.decode_t.block1.attn.fn.to_out.0.weight', 'encoder.decode_t.block1.attn.fn.to_out.0.bias', 'encoder.decode_t.block1.ff.norm.weight', 'encoder.decode_t.block1.ff.norm.bias', 'encoder.decode_t.block1.ff.fn.fc1.weight', 'encoder.decode_t.block1.ff.fn.fc1.bias', 'encoder.decode_t.block1.ff.fn.fc2.weight', 'encoder.decode_t.block1.ff.fn.fc2.bias', 'encoder.decode_t.block1.ff.fn.fc3.weight', 'encoder.decode_t.block1.ff.fn.fc3.bias', 'encoder.decode_t.norm.weight', 'encoder.decode_t.norm.bias', 'encoder.decode_t.head.fc1.weight', 'encoder.decode_t.head.fc1.bias', 'encoder.decode_t.head.fc2.weight', 'encoder.decode_t.head.fc2.bias', 'encoder.decode_t.head.fc3.weight', 'encoder.decode_t.head.fc3.bias', 'encoder.decode_e.pos_embedding', 'encoder.decode_e.sep_embedding', 'encoder.decode_e.to_patch.fc1.weight', 'encoder.decode_e.to_patch.fc1.bias', 'encoder.decode_e.to_patch.fc2.weight', 'encoder.decode_e.to_patch.fc2.bias', 'encoder.decode_e.to_patch.fc3.weight', 'encoder.decode_e.to_patch.fc3.bias', 'encoder.decode_e.block0.attn.norm.weight', 'encoder.decode_e.block0.attn.norm.bias', 'encoder.decode_e.block0.attn.fn.to_qkv.weight', 'encoder.decode_e.block0.attn.fn.to_qkv.bias', 'encoder.decode_e.block0.attn.fn.to_out.0.weight', 'encoder.decode_e.block0.attn.fn.to_out.0.bias', 'encoder.decode_e.block0.ff.norm.weight', 'encoder.decode_e.block0.ff.norm.bias', 'encoder.decode_e.block0.ff.fn.fc1.weight', 'encoder.decode_e.block0.ff.fn.fc1.bias', 'encoder.decode_e.block0.ff.fn.fc2.weight', 'encoder.decode_e.block0.ff.fn.fc2.bias', 'encoder.decode_e.block0.ff.fn.fc3.weight', 'encoder.decode_e.block0.ff.fn.fc3.bias', 'encoder.decode_e.block1.attn.norm.weight', 'encoder.decode_e.block1.attn.norm.bias', 'encoder.decode_e.block1.attn.fn.to_qkv.weight', 'encoder.decode_e.block1.attn.fn.to_qkv.bias', 
'encoder.decode_e.block1.attn.fn.to_out.0.weight', 'encoder.decode_e.block1.attn.fn.to_out.0.bias', 'encoder.decode_e.block1.ff.norm.weight', 'encoder.decode_e.block1.ff.norm.bias', 'encoder.decode_e.block1.ff.fn.fc1.weight', 'encoder.decode_e.block1.ff.fn.fc1.bias', 'encoder.decode_e.block1.ff.fn.fc2.weight', 'encoder.decode_e.block1.ff.fn.fc2.bias', 'encoder.decode_e.block1.ff.fn.fc3.weight', 'encoder.decode_e.block1.ff.fn.fc3.bias', 'encoder.decode_e.norm.weight', 'encoder.decode_e.norm.bias', 'encoder.decode_e.head.fc1.weight', 'encoder.decode_e.head.fc1.bias', 'encoder.decode_e.head.fc2.weight', 'encoder.decode_e.head.fc2.bias', 'encoder.decode_e.head.fc3.weight', 'encoder.decode_e.head.fc3.bias', 'encoder.ecg_encoder.conv1.weight', 'encoder.ecg_encoder.bn1.weight', 'encoder.ecg_encoder.bn1.bias', 'encoder.ecg_encoder.bn1.running_mean', 'encoder.ecg_encoder.bn1.running_var', 'encoder.ecg_encoder.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer1.0.conv1.weight', 'encoder.ecg_encoder.layer1.0.bn1.weight', 'encoder.ecg_encoder.layer1.0.bn1.bias', 'encoder.ecg_encoder.layer1.0.bn1.running_mean', 'encoder.ecg_encoder.layer1.0.bn1.running_var', 'encoder.ecg_encoder.layer1.0.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer1.0.conv2.weight', 'encoder.ecg_encoder.layer1.0.bn2.weight', 'encoder.ecg_encoder.layer1.0.bn2.bias', 'encoder.ecg_encoder.layer1.0.bn2.running_mean', 'encoder.ecg_encoder.layer1.0.bn2.running_var', 'encoder.ecg_encoder.layer1.0.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer1.1.conv1.weight', 'encoder.ecg_encoder.layer1.1.bn1.weight', 'encoder.ecg_encoder.layer1.1.bn1.bias', 'encoder.ecg_encoder.layer1.1.bn1.running_mean', 'encoder.ecg_encoder.layer1.1.bn1.running_var', 'encoder.ecg_encoder.layer1.1.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer1.1.conv2.weight', 'encoder.ecg_encoder.layer1.1.bn2.weight', 'encoder.ecg_encoder.layer1.1.bn2.bias', 'encoder.ecg_encoder.layer1.1.bn2.running_mean', 
'encoder.ecg_encoder.layer1.1.bn2.running_var', 'encoder.ecg_encoder.layer1.1.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer2.0.conv1.weight', 'encoder.ecg_encoder.layer2.0.bn1.weight', 'encoder.ecg_encoder.layer2.0.bn1.bias', 'encoder.ecg_encoder.layer2.0.bn1.running_mean', 'encoder.ecg_encoder.layer2.0.bn1.running_var', 'encoder.ecg_encoder.layer2.0.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer2.0.conv2.weight', 'encoder.ecg_encoder.layer2.0.bn2.weight', 'encoder.ecg_encoder.layer2.0.bn2.bias', 'encoder.ecg_encoder.layer2.0.bn2.running_mean', 'encoder.ecg_encoder.layer2.0.bn2.running_var', 'encoder.ecg_encoder.layer2.0.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer2.0.shortcut.0.weight', 'encoder.ecg_encoder.layer2.0.shortcut.1.weight', 'encoder.ecg_encoder.layer2.0.shortcut.1.bias', 'encoder.ecg_encoder.layer2.0.shortcut.1.running_mean', 'encoder.ecg_encoder.layer2.0.shortcut.1.running_var', 'encoder.ecg_encoder.layer2.0.shortcut.1.num_batches_tracked', 'encoder.ecg_encoder.layer2.1.conv1.weight', 'encoder.ecg_encoder.layer2.1.bn1.weight', 'encoder.ecg_encoder.layer2.1.bn1.bias', 'encoder.ecg_encoder.layer2.1.bn1.running_mean', 'encoder.ecg_encoder.layer2.1.bn1.running_var', 'encoder.ecg_encoder.layer2.1.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer2.1.conv2.weight', 'encoder.ecg_encoder.layer2.1.bn2.weight', 'encoder.ecg_encoder.layer2.1.bn2.bias', 'encoder.ecg_encoder.layer2.1.bn2.running_mean', 'encoder.ecg_encoder.layer2.1.bn2.running_var', 'encoder.ecg_encoder.layer2.1.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer3.0.conv1.weight', 'encoder.ecg_encoder.layer3.0.bn1.weight', 'encoder.ecg_encoder.layer3.0.bn1.bias', 'encoder.ecg_encoder.layer3.0.bn1.running_mean', 'encoder.ecg_encoder.layer3.0.bn1.running_var', 'encoder.ecg_encoder.layer3.0.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer3.0.conv2.weight', 'encoder.ecg_encoder.layer3.0.bn2.weight', 'encoder.ecg_encoder.layer3.0.bn2.bias', 
'encoder.ecg_encoder.layer3.0.bn2.running_mean', 'encoder.ecg_encoder.layer3.0.bn2.running_var', 'encoder.ecg_encoder.layer3.0.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer3.0.shortcut.0.weight', 'encoder.ecg_encoder.layer3.0.shortcut.1.weight', 'encoder.ecg_encoder.layer3.0.shortcut.1.bias', 'encoder.ecg_encoder.layer3.0.shortcut.1.running_mean', 'encoder.ecg_encoder.layer3.0.shortcut.1.running_var', 'encoder.ecg_encoder.layer3.0.shortcut.1.num_batches_tracked', 'encoder.ecg_encoder.layer3.1.conv1.weight', 'encoder.ecg_encoder.layer3.1.bn1.weight', 'encoder.ecg_encoder.layer3.1.bn1.bias', 'encoder.ecg_encoder.layer3.1.bn1.running_mean', 'encoder.ecg_encoder.layer3.1.bn1.running_var', 'encoder.ecg_encoder.layer3.1.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer3.1.conv2.weight', 'encoder.ecg_encoder.layer3.1.bn2.weight', 'encoder.ecg_encoder.layer3.1.bn2.bias', 'encoder.ecg_encoder.layer3.1.bn2.running_mean', 'encoder.ecg_encoder.layer3.1.bn2.running_var', 'encoder.ecg_encoder.layer3.1.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer4.0.conv1.weight', 'encoder.ecg_encoder.layer4.0.bn1.weight', 'encoder.ecg_encoder.layer4.0.bn1.bias', 'encoder.ecg_encoder.layer4.0.bn1.running_mean', 'encoder.ecg_encoder.layer4.0.bn1.running_var', 'encoder.ecg_encoder.layer4.0.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer4.0.conv2.weight', 'encoder.ecg_encoder.layer4.0.bn2.weight', 'encoder.ecg_encoder.layer4.0.bn2.bias', 'encoder.ecg_encoder.layer4.0.bn2.running_mean', 'encoder.ecg_encoder.layer4.0.bn2.running_var', 'encoder.ecg_encoder.layer4.0.bn2.num_batches_tracked', 'encoder.ecg_encoder.layer4.0.shortcut.0.weight', 'encoder.ecg_encoder.layer4.0.shortcut.1.weight', 'encoder.ecg_encoder.layer4.0.shortcut.1.bias', 'encoder.ecg_encoder.layer4.0.shortcut.1.running_mean', 'encoder.ecg_encoder.layer4.0.shortcut.1.running_var', 'encoder.ecg_encoder.layer4.0.shortcut.1.num_batches_tracked', 'encoder.ecg_encoder.layer4.1.conv1.weight', 
'encoder.ecg_encoder.layer4.1.bn1.weight', 'encoder.ecg_encoder.layer4.1.bn1.bias', 'encoder.ecg_encoder.layer4.1.bn1.running_mean', 'encoder.ecg_encoder.layer4.1.bn1.running_var', 'encoder.ecg_encoder.layer4.1.bn1.num_batches_tracked', 'encoder.ecg_encoder.layer4.1.conv2.weight', 'encoder.ecg_encoder.layer4.1.bn2.weight', 'encoder.ecg_encoder.layer4.1.bn2.bias', 'encoder.ecg_encoder.layer4.1.bn2.running_mean', 'encoder.ecg_encoder.layer4.1.bn2.running_var', 'encoder.ecg_encoder.layer4.1.bn2.num_batches_tracked', 'encoder.ecg_encoder.linear.weight', 'encoder.ecg_encoder.linear.bias', 'encoder.lm_model.embeddings.word_embeddings.weight', 'encoder.lm_model.embeddings.position_embeddings.weight', 'encoder.lm_model.embeddings.token_type_embeddings.weight', 'encoder.lm_model.embeddings.LayerNorm.weight', 'encoder.lm_model.embeddings.LayerNorm.bias', 'encoder.lm_model.encoder.layer.0.attention.self.query.weight', 'encoder.lm_model.encoder.layer.0.attention.self.query.bias', 'encoder.lm_model.encoder.layer.0.attention.self.key.weight', 'encoder.lm_model.encoder.layer.0.attention.self.key.bias', 'encoder.lm_model.encoder.layer.0.attention.self.value.weight', 'encoder.lm_model.encoder.layer.0.attention.self.value.bias', 'encoder.lm_model.encoder.layer.0.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.0.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.0.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.0.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.0.output.dense.weight', 'encoder.lm_model.encoder.layer.0.output.dense.bias', 'encoder.lm_model.encoder.layer.0.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.0.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.1.attention.self.query.weight', 'encoder.lm_model.encoder.layer.1.attention.self.query.bias', 
'encoder.lm_model.encoder.layer.1.attention.self.key.weight', 'encoder.lm_model.encoder.layer.1.attention.self.key.bias', 'encoder.lm_model.encoder.layer.1.attention.self.value.weight', 'encoder.lm_model.encoder.layer.1.attention.self.value.bias', 'encoder.lm_model.encoder.layer.1.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.1.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.1.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.1.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.1.output.dense.weight', 'encoder.lm_model.encoder.layer.1.output.dense.bias', 'encoder.lm_model.encoder.layer.1.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.1.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.2.attention.self.query.weight', 'encoder.lm_model.encoder.layer.2.attention.self.query.bias', 'encoder.lm_model.encoder.layer.2.attention.self.key.weight', 'encoder.lm_model.encoder.layer.2.attention.self.key.bias', 'encoder.lm_model.encoder.layer.2.attention.self.value.weight', 'encoder.lm_model.encoder.layer.2.attention.self.value.bias', 'encoder.lm_model.encoder.layer.2.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.2.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.2.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.2.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.2.output.dense.weight', 'encoder.lm_model.encoder.layer.2.output.dense.bias', 'encoder.lm_model.encoder.layer.2.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.2.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.3.attention.self.query.weight', 'encoder.lm_model.encoder.layer.3.attention.self.query.bias', 
'encoder.lm_model.encoder.layer.3.attention.self.key.weight', 'encoder.lm_model.encoder.layer.3.attention.self.key.bias', 'encoder.lm_model.encoder.layer.3.attention.self.value.weight', 'encoder.lm_model.encoder.layer.3.attention.self.value.bias', 'encoder.lm_model.encoder.layer.3.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.3.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.3.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.3.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.3.output.dense.weight', 'encoder.lm_model.encoder.layer.3.output.dense.bias', 'encoder.lm_model.encoder.layer.3.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.3.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.4.attention.self.query.weight', 'encoder.lm_model.encoder.layer.4.attention.self.query.bias', 'encoder.lm_model.encoder.layer.4.attention.self.key.weight', 'encoder.lm_model.encoder.layer.4.attention.self.key.bias', 'encoder.lm_model.encoder.layer.4.attention.self.value.weight', 'encoder.lm_model.encoder.layer.4.attention.self.value.bias', 'encoder.lm_model.encoder.layer.4.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.4.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.4.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.4.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.4.output.dense.weight', 'encoder.lm_model.encoder.layer.4.output.dense.bias', 'encoder.lm_model.encoder.layer.4.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.4.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.5.attention.self.query.weight', 'encoder.lm_model.encoder.layer.5.attention.self.query.bias', 
'encoder.lm_model.encoder.layer.5.attention.self.key.weight', 'encoder.lm_model.encoder.layer.5.attention.self.key.bias', 'encoder.lm_model.encoder.layer.5.attention.self.value.weight', 'encoder.lm_model.encoder.layer.5.attention.self.value.bias', 'encoder.lm_model.encoder.layer.5.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.5.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.5.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.5.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.5.output.dense.weight', 'encoder.lm_model.encoder.layer.5.output.dense.bias', 'encoder.lm_model.encoder.layer.5.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.5.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.6.attention.self.query.weight', 'encoder.lm_model.encoder.layer.6.attention.self.query.bias', 'encoder.lm_model.encoder.layer.6.attention.self.key.weight', 'encoder.lm_model.encoder.layer.6.attention.self.key.bias', 'encoder.lm_model.encoder.layer.6.attention.self.value.weight', 'encoder.lm_model.encoder.layer.6.attention.self.value.bias', 'encoder.lm_model.encoder.layer.6.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.6.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.6.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.6.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.6.output.dense.weight', 'encoder.lm_model.encoder.layer.6.output.dense.bias', 'encoder.lm_model.encoder.layer.6.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.6.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.7.attention.self.query.weight', 'encoder.lm_model.encoder.layer.7.attention.self.query.bias', 
'encoder.lm_model.encoder.layer.7.attention.self.key.weight', 'encoder.lm_model.encoder.layer.7.attention.self.key.bias', 'encoder.lm_model.encoder.layer.7.attention.self.value.weight', 'encoder.lm_model.encoder.layer.7.attention.self.value.bias', 'encoder.lm_model.encoder.layer.7.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.7.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.7.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.7.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.7.output.dense.weight', 'encoder.lm_model.encoder.layer.7.output.dense.bias', 'encoder.lm_model.encoder.layer.7.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.7.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.8.attention.self.query.weight', 'encoder.lm_model.encoder.layer.8.attention.self.query.bias', 'encoder.lm_model.encoder.layer.8.attention.self.key.weight', 'encoder.lm_model.encoder.layer.8.attention.self.key.bias', 'encoder.lm_model.encoder.layer.8.attention.self.value.weight', 'encoder.lm_model.encoder.layer.8.attention.self.value.bias', 'encoder.lm_model.encoder.layer.8.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.8.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.8.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.8.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.8.output.dense.weight', 'encoder.lm_model.encoder.layer.8.output.dense.bias', 'encoder.lm_model.encoder.layer.8.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.8.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.9.attention.self.query.weight', 'encoder.lm_model.encoder.layer.9.attention.self.query.bias', 
'encoder.lm_model.encoder.layer.9.attention.self.key.weight', 'encoder.lm_model.encoder.layer.9.attention.self.key.bias', 'encoder.lm_model.encoder.layer.9.attention.self.value.weight', 'encoder.lm_model.encoder.layer.9.attention.self.value.bias', 'encoder.lm_model.encoder.layer.9.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.9.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.9.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.9.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.9.output.dense.weight', 'encoder.lm_model.encoder.layer.9.output.dense.bias', 'encoder.lm_model.encoder.layer.9.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.9.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.10.attention.self.query.weight', 'encoder.lm_model.encoder.layer.10.attention.self.query.bias', 'encoder.lm_model.encoder.layer.10.attention.self.key.weight', 'encoder.lm_model.encoder.layer.10.attention.self.key.bias', 'encoder.lm_model.encoder.layer.10.attention.self.value.weight', 'encoder.lm_model.encoder.layer.10.attention.self.value.bias', 'encoder.lm_model.encoder.layer.10.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.10.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.10.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.10.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.10.output.dense.weight', 'encoder.lm_model.encoder.layer.10.output.dense.bias', 'encoder.lm_model.encoder.layer.10.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.10.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.11.attention.self.query.weight', 
'encoder.lm_model.encoder.layer.11.attention.self.query.bias', 'encoder.lm_model.encoder.layer.11.attention.self.key.weight', 'encoder.lm_model.encoder.layer.11.attention.self.key.bias', 'encoder.lm_model.encoder.layer.11.attention.self.value.weight', 'encoder.lm_model.encoder.layer.11.attention.self.value.bias', 'encoder.lm_model.encoder.layer.11.attention.output.dense.weight', 'encoder.lm_model.encoder.layer.11.attention.output.dense.bias', 'encoder.lm_model.encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.lm_model.encoder.layer.11.intermediate.dense.weight', 'encoder.lm_model.encoder.layer.11.intermediate.dense.bias', 'encoder.lm_model.encoder.layer.11.output.dense.weight', 'encoder.lm_model.encoder.layer.11.output.dense.bias', 'encoder.lm_model.encoder.layer.11.output.LayerNorm.weight', 'encoder.lm_model.encoder.layer.11.output.LayerNorm.bias', 'encoder.lm_model.pooler.dense.weight', 'encoder.lm_model.pooler.dense.bias', 'encoder.proj_t.0.weight', 'encoder.proj_t.0.bias', 'encoder.proj_t.2.weight', 'encoder.proj_t.2.bias', 'encoder.head.weight', 'encoder.head.bias', 'decoder.encoder_decoder.decoder.transformer.wte.weight', 'decoder.encoder_decoder.decoder.transformer.wpe.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.attn.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.c_attn.bias', 
'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.mlp.c_fc.bias', 'decoder.encoder_decoder.decoder.transformer.h.0.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.0.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.attn.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.1.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.mlp.c_fc.bias', 
'decoder.encoder_decoder.decoder.transformer.h.1.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.1.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.attn.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.mlp.c_fc.bias', 'decoder.encoder_decoder.decoder.transformer.h.2.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.2.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.attn.c_proj.bias', 
'decoder.encoder_decoder.decoder.transformer.h.3.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.mlp.c_fc.bias', 'decoder.encoder_decoder.decoder.transformer.h.3.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.3.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.attn.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.c_proj.weight', 
'decoder.encoder_decoder.decoder.transformer.h.4.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.mlp.c_fc.bias', 'decoder.encoder_decoder.decoder.transformer.h.4.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.4.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_1.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_1.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.attn.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.attn.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.attn.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.attn.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_2.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_2.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.c_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.c_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.q_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.q_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.crossattention.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_cross_attn.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.ln_cross_attn.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.mlp.c_fc.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.mlp.c_fc.bias', 'decoder.encoder_decoder.decoder.transformer.h.5.mlp.c_proj.weight', 'decoder.encoder_decoder.decoder.transformer.h.5.mlp.c_proj.bias', 'decoder.encoder_decoder.decoder.transformer.ln_f.weight', 
'decoder.encoder_decoder.decoder.transformer.ln_f.bias', 'decoder.encoder_decoder.decoder.lm_head.weight'])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cpkt['model_state_dict'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml as yaml\n",
    "import sys\n",
    "sys.path.append(\"../finetune/\")\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "\n",
    "prompt_type = 'CKEPE'\n",
    "prompt_dict = '/yourpath/zeroshot/CKEPE_prompt.json'\n",
    "with open(prompt_dict, 'r') as f:\n",
    "    prompt_dict = yaml.load(f, Loader=yaml.FullLoader)\n",
    "target_class = [class_name for class_name in prompt_dict.values()]\n",
    "\n",
    "def get_class_emd(model, class_name, device='cuda'):\n",
    "    model.eval()\n",
    "    with torch.no_grad(): # to(device=torch.device(\"cuda\"iftorch.cuda.is_available()else\"cpu\")) \n",
    "        zeroshot_weights = []\n",
    "        # compute embedding through model for each class\n",
    "        for texts in tqdm(class_name):\n",
    "            texts = texts.lower()\n",
    "            texts = [texts] # convert to list\n",
    "            texts = model._tokenize(texts) # tokenize\n",
    "            class_embeddings, _ = model.get_text_emb(texts.input_ids.to(device=device)\n",
    "                                                            , texts.attention_mask.to(device=device)\n",
    "                                                            ) # embed with text encoder\n",
    "            class_embeddings = model.proj_t(class_embeddings) # embed with text encoder\n",
    "\n",
    "            # normalize class_embeddings\n",
    "            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\n",
    "            # average over templates \n",
    "            class_embedding = class_embeddings.mean(dim=0) \n",
    "            # norm over new averaged templates\n",
    "            class_embedding /= class_embedding.norm() \n",
    "            zeroshot_weights.append(class_embedding)\n",
    "        zeroshot_weights = torch.stack(zeroshot_weights, dim=1)\n",
    "    return zeroshot_weights\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "def get_ground_truth(model, reports, class_weight, device='cuda'):\n",
    "    model.eval()\n",
    "    y_pred = []\n",
    "    with torch.no_grad():\n",
    "        \n",
    "        report_tokenize_output = model._tokenize(reports)\n",
    "        input_ids = report_tokenize_output.input_ids.to(\n",
    "            device).contiguous()\n",
    "        attention_mask = report_tokenize_output.attention_mask.to(\n",
    "            device).contiguous()\n",
    "        class_embeddings, _ = model.get_text_emb(input_ids.to(device=device)\n",
    "                                                        , attention_mask.to(device=device)\n",
    "                                                        ) # embed with text encoder\n",
    "        class_embeddings = model.proj_t(class_embeddings) # embed with text encoder\n",
    "\n",
    "        # normalize class_embeddings\n",
    "        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\n",
    "        logits = class_embeddings @ class_weight.to(device)\n",
    "        logits = torch.squeeze(logits, 0) # (N, num_classes)\n",
    "        # norm_logits = (logits - logits.mean()) / (logits.std())\n",
    "        # logits = torch.sigmoid(norm_logits) \n",
    "            \n",
    "        y_pred.append(logits.cpu().data.numpy())\n",
    "\n",
    "    y_pred = np.concatenate(y_pred, axis=0)\n",
    "    labels = np.array(y_pred)\n",
    "    labels = np.argmax(labels, axis=1)\n",
    "    return labels, y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 131/131 [00:00<00:00, 184.09it/s]\n"
     ]
    }
   ],
   "source": [
    "class_emb = get_class_emd(model=encoder.to('cuda:0'), class_name=target_class, device='cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([256, 131])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 482/482 [13:58<00:00,  1.74s/it]\n"
     ]
    }
   ],
   "source": [
    "labels_all = []\n",
    "reports_all = []\n",
    "outputs_all = []\n",
    "labels_pred_all = []\n",
    "with torch.no_grad():\n",
    "    for data in tqdm(val_loader):\n",
    "        report = data['raw_text']#.to(device)\n",
    "        # get ecg\n",
    "        ecg = data['ecg'].to(torch.float32).contiguous().to('cuda:0')\n",
    "        encoder_outputs = model.encoder_forward(ecg)\n",
    "        encoder_outputs = BaseModelOutput(\n",
    "                last_hidden_state=encoder_outputs)\n",
    "        outputs = model.decoder.encoder_decoder.generate(\n",
    "                max_length=256,\n",
    "                bos_token_id=model.tokenizer.bos_token_id,\n",
    "                eos_token_id=model.tokenizer.eos_token_id,\n",
    "                pad_token_id=model.tokenizer.pad_token_id,\n",
    "                num_beams=3,\n",
    "                return_dict_in_generate=True,\n",
    "                use_cache=True, \n",
    "                encoder_outputs=encoder_outputs,\n",
    "            )\n",
    "        output = model.tokenizer.batch_decode(\n",
    "            outputs['sequences'], skip_special_tokens=True)\n",
    "        reports_all.append(report)\n",
    "        outputs_all.append(output)\n",
    "        label, logits = get_ground_truth(model=model.encoder, reports=report, class_weight=class_emb, device='cuda:0')\n",
    "        label_pred, logits_pred = get_ground_truth(model=model.encoder, reports=output, class_weight=class_emb, device='cuda:0')\n",
    "        labels_all.append(label)\n",
    "        labels_pred_all.append(label_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "reports_all = [i for item in reports_all for i in item]\n",
    "outputs_all = [i for item in outputs_all for i in item]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'possible ectopic atrial rhythm.. left axis deviation. right bundle branch block. inferior infarct - age undetermined. abnormal ecg.'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reports_all[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sinus rhythm. left axis deviation. rbbb with left anterior fascicular block. inferior infarct - age undetermined. abnormal ecg.'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs_all[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_a = np.hstack(labels_all)\n",
    "labels_p = np.hstack(labels_pred_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average BLEU-1 Score: 0.6133\n",
      "Average BLEU-2 Score: 0.5598\n",
      "Average BLEU-3 Score: 0.5157\n",
      "Average BLEU-4 Score: 0.4850\n",
      "Average ROUGE-L F1 Score: 0.7059\n"
     ]
    }
   ],
   "source": [
    "import nltk\n",
    "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
    "from rouge_score import rouge_scorer\n",
    "\n",
    "\n",
    "\n",
    "# 初始化 BLEU 和 ROUGE scorer\n",
    "smooth = SmoothingFunction().method1\n",
    "scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)\n",
    "\n",
    "# 存储 BLEU 和 ROUGE 分数\n",
    "bleu_scores = {'bleu1': [], 'bleu2': [], 'bleu3': [], 'bleu4': []}\n",
    "rougeL_scores = []\n",
    "\n",
    "# 逐一计算每个样本的 BLEU 和 ROUGE 分数\n",
    "for ref, gen in zip(reports_all, outputs_all):\n",
    "    # 计算 BLEU-1 分数\n",
    "    bleu1 = sentence_bleu([ref.split()], gen.split(), weights=(1, 0, 0, 0), smoothing_function=smooth)\n",
    "    bleu_scores['bleu1'].append(bleu1)\n",
    "    \n",
    "    # 计算 BLEU-2 分数\n",
    "    bleu2 = sentence_bleu([ref.split()], gen.split(), weights=(0.5, 0.5, 0, 0), smoothing_function=smooth)\n",
    "    bleu_scores['bleu2'].append(bleu2)\n",
    "    \n",
    "    # 计算 BLEU-3 分数\n",
    "    bleu3 = sentence_bleu([ref.split()], gen.split(), weights=(0.33, 0.33, 0.33, 0), smoothing_function=smooth)\n",
    "    bleu_scores['bleu3'].append(bleu3)\n",
    "    \n",
    "    # 计算 BLEU-4 分数\n",
    "    bleu4 = sentence_bleu([ref.split()], gen.split(), weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)\n",
    "    bleu_scores['bleu4'].append(bleu4)\n",
    "    \n",
    "    # 计算 ROUGE-L 分数\n",
    "    rougeL = scorer.score(ref, gen)['rougeL'].fmeasure\n",
    "    rougeL_scores.append(rougeL)\n",
    "\n",
    "# 计算平均值\n",
    "average_bleu1 = sum(bleu_scores['bleu1']) / len(bleu_scores['bleu1'])\n",
    "average_bleu2 = sum(bleu_scores['bleu2']) / len(bleu_scores['bleu2'])\n",
    "average_bleu3 = sum(bleu_scores['bleu3']) / len(bleu_scores['bleu3'])\n",
    "average_bleu4 = sum(bleu_scores['bleu4']) / len(bleu_scores['bleu4'])\n",
    "average_rougeL = sum(rougeL_scores) / len(rougeL_scores)\n",
    "\n",
    "# 打印平均结果\n",
    "print(f\"Average BLEU-1 Score: {average_bleu1:.4f}\")\n",
    "print(f\"Average BLEU-2 Score: {average_bleu2:.4f}\")\n",
    "print(f\"Average BLEU-3 Score: {average_bleu3:.4f}\")\n",
    "print(f\"Average BLEU-4 Score: {average_bleu4:.4f}\")\n",
    "print(f\"Average ROUGE-L F1 Score: {average_rougeL:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1: 0.23329257874351925\n",
      "pre: 0.2495743946337152\n",
      "rec: 0.23694279461925052\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chenjian/miniconda3/envs/laecgm/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score, precision_score, recall_score\n",
    "\n",
    "f1 = f1_score(labels_a, labels_p, average='macro')\n",
    "pre = precision_score(labels_a, labels_p, average='macro')\n",
    "rec = recall_score(labels_a, labels_p, average='macro')\n",
    "\n",
    "print('f1:', f1)\n",
    "print('pre:', pre)\n",
    "print('rec:', rec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = {'BLEU-1':average_bleu1,\n",
    "        'BLEU-2':average_bleu2,\n",
    "        'BLEU-3':average_bleu3,\n",
    "        'BLEU-4':average_bleu4,\n",
    "        'ROUGE-L F1 Score':average_rougeL,\n",
    "        'CE F1 Score': f1,\n",
    "        'CE Precision Score': pre,\n",
    "        'CE Recall Score': rec,\n",
    "        'report':reports_all,\n",
    "        'generated report': outputs_all\n",
    "        }\n",
    "\n",
    "torch.save({\n",
    "    'metrics': metrics},\n",
    "    f'/yourpath/report/ckpt/disGPT2_align_metrics.pth'     \n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "laecgm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
