{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Rank 0][GLIMDataModule] running `setup()`...\n",
      "[Rank 0][GLIMDataModule] running `setup()`...Done! 😋😋😋\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from rich.progress import track\n",
    "\n",
    "from model.glim import GLIM\n",
    "from data.datamodule import GLIMDataModule\n",
    "from torchmetrics.functional.classification import multiclass_accuracy\n",
    "\n",
    "# Single GPU that hosts the model and every tensor created in this notebook.\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Restore the trained GLIM checkpoint directly onto `device`.\n",
    "# `strict = False` lets loading proceed when checkpoint and model keys do not match exactly.\n",
    "model = GLIM.load_from_checkpoint(\n",
    "    \"checkpoints/glim-zuco-epoch=199-step=49600.ckpt\",\n",
    "    map_location = device,\n",
    "    strict = False,\n",
    "    # evaluate_prompt_embed = 'src',\n",
    "    # prompt_dropout_probs = (0.0, 0.1, 0.1),\n",
    "    )\n",
    "model.setup(stage='test')\n",
    "# This notebook only evaluates: switch to eval mode so dropout and similar train-time\n",
    "# layers are disabled. Called after `setup()` so that any sub-module `setup()` may create\n",
    "# is covered as well (assumes GLIM is a torch/Lightning module -- it is loaded via\n",
    "# `load_from_checkpoint`).\n",
    "model.eval()\n",
    "\n",
    "# Test split of the ZuCo EEG data, clean (non-noise) input, 24 samples per batch.\n",
    "dm = GLIMDataModule(data_path = './data/tmp/zuco_eeg_label_8variants.df',\n",
    "                    eval_noise_input = False,\n",
    "                    bsz_test = 24,\n",
    "                    )\n",
    "dm.setup(stage='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5994007f8f7540d7b7c494364e09f7de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clip-like acc:  0.9343296885490417\n"
     ]
    }
   ],
   "source": [
    "# Candidate topic descriptions handed to `model.predict` for every EEG sample.\n",
    "# Presumably column j of the returned `prob` scores candidates[j] -- verify against GLIM.predict.\n",
    "prefix = \"The topic is about: \"\n",
    "template = \"\"\n",
    "candidates = [\n",
    "    prefix + template + candi\n",
    "    for candi in (\"movie, good or bad\",\n",
    "                  \"life experiences, relationship\")\n",
    "]\n",
    "\n",
    "# One dict per test sample; later cells turn this list into a DataFrame / DataLoader.\n",
    "results = []\n",
    "with torch.no_grad():\n",
    "    for batch in track(dm.test_dataloader()):\n",
    "        eeg, eeg_mask = batch['eeg'].to(device), batch['mask'].to(device)\n",
    "        prompts = batch['prompt'] # NOTE: [tuple('task'), tuple('dataset'), tuple('subject')] after collate\n",
    "        raw_task_key = batch['raw task key']            # list[str]\n",
    "        relation_label = batch['relation label']        # list[str]\n",
    "        sentiment_label = batch['sentiment label']      # list[str]\n",
    "\n",
    "        # Binary target: 'task1' -> class 0, every other task key -> class 1.\n",
    "        labels = torch.tensor([0 if t_key == 'task1' else 1 for t_key in raw_task_key],\n",
    "                              device=device)\n",
    "\n",
    "        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
    "            prob, gen_str = model.predict(eeg, eeg_mask, prompts, candidates, generate=True)\n",
    "\n",
    "        # Flatten the batch into per-sample records (key order fixes the DataFrame column order).\n",
    "        results.extend(\n",
    "            {'raw input text':   batch['raw input text'][i],\n",
    "             'label':        labels[i],\n",
    "             'prob':         prob[i],\n",
    "             'gen text':      gen_str[i],\n",
    "             'relation label': relation_label[i],\n",
    "             'sentiment label': sentiment_label[i],\n",
    "             }\n",
    "            for i in range(len(eeg))\n",
    "        )\n",
    "\n",
    "# Re-assemble the per-sample entries into full-dataset tensors and score them.\n",
    "# `labels` is reused by the text-embedding evaluation cell further down.\n",
    "probs = torch.stack([entry['prob'] for entry in results])\n",
    "labels = torch.stack([entry['label'] for entry in results])\n",
    "acc = multiclass_accuracy(probs, labels, num_classes=2, top_k=1, average='micro')\n",
    "print('clip-like acc: ', acc.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Persist the per-sample records so the LLM-based evaluation below can run without\n",
    "# repeating the EEG inference. The 'label' / 'prob' entries are torch tensors and are\n",
    "# pickled as-is.\n",
    "df_results = pd.DataFrame(results)\n",
    "df_results.to_pickle('data/tmp/glim_gen_results.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37544a0fd5c9402fb7320ed6d262cb4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clip-like acc [raw input text]:  0.91847825050354\n",
      "clip-like acc [gen text]:        0.873641312122345\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Template passed to `model.predict_text_embedding`; it does not change per batch,\n",
    "# so it is defined once outside the loop.\n",
    "input_template = \"To English: <MASK>\"\n",
    "\n",
    "# probs1: scores for the ground-truth sentences, probs2: scores for the generated sentences.\n",
    "probs1, probs2 = [], []\n",
    "# `results` (list of per-sample dicts from the EEG evaluation cell) is batched by the\n",
    "# default collate function; shuffle is off so the order stays aligned with `labels`.\n",
    "loader = DataLoader(results, batch_size=64, shuffle=False, drop_last=False)\n",
    "# Inference only: without no_grad every batch would build an autograd graph that stays\n",
    "# alive through the tensors collected in probs1 / probs2.\n",
    "with torch.no_grad():\n",
    "    for batch in track(loader):\n",
    "        input_texts = batch['raw input text']\n",
    "        gen_texts = batch['gen text']\n",
    "        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
    "            prob1 = model.predict_text_embedding(input_texts, input_template, candidates)\n",
    "            prob2 = model.predict_text_embedding(gen_texts, input_template, candidates)\n",
    "        probs1.append(prob1)\n",
    "        probs2.append(prob2)\n",
    "probs1 = torch.cat(probs1, dim=0)\n",
    "probs2 = torch.cat(probs2, dim=0)\n",
    "# `labels` is the full-dataset label tensor built in the EEG evaluation cell.\n",
    "acc1 = multiclass_accuracy(probs1, labels, num_classes=2, top_k=1, average='micro')\n",
    "print('clip-like acc [raw input text]: ',acc1.item())\n",
    "acc2 = multiclass_accuracy(probs2, labels, num_classes=2, top_k=1, average='micro')\n",
    "print('clip-like acc [gen text]:       ',acc2.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c40151d7deb8466a8907a6291cb23529",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import transformers\n",
    "\n",
    "# Hugging Face model id of the instruction-tuned LLM used below to classify sentences.\n",
    "eval_llm_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "\n",
    "# fp16 weights, whole model placed on the first GPU. No `task` argument is given, so\n",
    "# transformers has to infer the pipeline task from the model id.\n",
    "pipe = transformers.pipeline(\n",
    "    model = eval_llm_id,\n",
    "    model_kwargs = {\"torch_dtype\": torch.float16},\n",
    "    device_map = torch.device(\"cuda:0\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Reload the per-sample results written by the save cell above, so the LLM evaluation\n",
    "# can start from here. Unpickling can execute arbitrary code: only load files you produced.\n",
    "df_results = pd.read_pickle(filepath_or_buffer='data/tmp/glim_gen_results.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask PyTorch's CUDA caching allocator to hand back cached memory blocks that are not\n",
    "# currently in use. Tensors and models that are still referenced stay on the GPU.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentences to classify. Swap the two lines below to score the generated text\n",
    "# instead of the ground-truth text.\n",
    "# input_sentences = df_results['gen text'].tolist()\n",
    "input_sentences = df_results['raw input text'].tolist()\n",
    "\n",
    "# Alternative system prompt (corpus source, plain-language labels):\n",
    "# instructions = {\"role\": \"system\", \n",
    "#           \"content\": \n",
    "#             (\"Your task is to classify the most likely corpus source of the following sentence.\"\n",
    "#              \" Label '0' for 'movie review', '1' for 'personal biography'.\"\n",
    "#              \" Please just output the integer label.\"\n",
    "#             )}\n",
    "\n",
    "# Active system prompt: label ids follow the same 0 / 1 convention as df_results['label'].\n",
    "instructions = {\"role\": \"system\", \n",
    "          \"content\": \n",
    "            (\"Your task is to classify the most likely topic of the following sentence.\"\n",
    "             \" Label '0' for 'movie review', '1' for 'personal biography'.\"\n",
    "             \" Please just output the integer label.\"\n",
    "            )}\n",
    "\n",
    "# Alternative system prompt (corpus source, dataset names as labels):\n",
    "# instructions = {\"role\": \"system\", \n",
    "#           \"content\": \n",
    "#             (\"Your task is to classify the most likely corpus source of the following sentence.\"\n",
    "#              \" Label '0' for 'Stanford Sentiment Treebank', '1' for 'Wikipedia relation extraction corpus'.\"\n",
    "#              \" Please just output the integer label.\"\n",
    "#             )}\n",
    "\n",
    "# One two-turn conversation (system + user) per sentence, rendered to plain prompt strings.\n",
    "# The accuracy cell relies on `inputs[i]` being the exact prompt prefix of the generated text.\n",
    "messages = [[instructions, {\"role\": \"user\", \"content\": sen}] for sen in input_sentences]\n",
    "inputs = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "# NOTE(review): `inputs` already contains the chat template's special tokens; check whether\n",
    "# the pipeline's own tokenizer call prepends a second BOS token -- TODO confirm.\n",
    "\n",
    "# Stop on either the regular EOS token or Llama-3's end-of-turn token.\n",
    "terminators = [pipe.tokenizer.eos_token_id, \n",
    "               pipe.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n",
    "\n",
    "# Batched generation needs a pad token; decoder-only models must be padded on the left\n",
    "# so that the new tokens directly follow the prompt.\n",
    "pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id\n",
    "pipe.tokenizer.padding_side = 'left'\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = pipe(inputs, \n",
    "                   batch_size = 16, \n",
    "                   max_new_tokens = 4,\n",
    "                   eos_token_id = terminators,\n",
    "                   # Deterministic beam search: sampling would make the reported\n",
    "                   # classification accuracy change from run to run.\n",
    "                   do_sample = False,\n",
    "                   num_beams = 2,\n",
    "                   pad_token_id = pipe.tokenizer.eos_token_id,\n",
    "                   )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llm-pred acc:  0.8614130434782609\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "# Compare the LLM's predicted label id with the stored ground-truth label of each sample.\n",
    "n_correct = 0\n",
    "n_unparsed = 0   # completions that contain no integer at all\n",
    "total = df_results.shape[0]\n",
    "for i in range(total):\n",
    "    # 'generated_text' holds prompt + completion; drop the prompt prefix to keep the answer.\n",
    "    completion = outputs[i][0]['generated_text'][len(inputs[i]):]\n",
    "    # The model is asked for a bare integer but may add punctuation or words (e.g. \"0.\").\n",
    "    # Take the first integer instead of calling int() on the raw text, which would raise\n",
    "    # ValueError and abort the whole evaluation.\n",
    "    match = re.search(r'-?\\d+', completion)\n",
    "    if match is None:\n",
    "        n_unparsed += 1   # counted as a wrong prediction\n",
    "        continue\n",
    "    pred = int(match.group())   # int, label id\n",
    "    label = df_results.iloc[i]['label'].item()\n",
    "    if pred == label:\n",
    "        n_correct += 1\n",
    "if n_unparsed:\n",
    "    print(f'warning: {n_unparsed} of {total} LLM outputs had no integer label (counted as wrong)')\n",
    "llm_acc = n_correct/total\n",
    "print('llm-pred acc: ', llm_acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
