{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero-shot classification with `eeg_emb`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Rank 0][GLIMDataModule] running `setup()`...\n",
      "[Rank 0][GLIMDataModule] running `setup()`...Done! 😋😋😋\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from rich.progress import track\n",
    "\n",
    "from model.glim import GLIM\n",
    "from data.datamodule import GLIMDataModule\n",
    "from torchmetrics.functional.classification import multiclass_accuracy\n",
    "\n",
    "# All GLIM inference in this notebook runs on the first GPU.\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Restore the trained GLIM weights onto `device`; `strict = False` is forwarded to\n",
    "# the checkpoint loader as given.\n",
    "model = GLIM.load_from_checkpoint(\n",
    "    \"checkpoints/glim-zuco-epoch=199-step=49600.ckpt\",\n",
    "    map_location = device,\n",
    "    strict = False,\n",
    "    # Optional hyper-parameter overrides, kept for reference:\n",
    "    # evaluate_prompt_embed = 'src',\n",
    "    # prompt_dropout_probs = (0.0, 0.1, 0.1),\n",
    "    )\n",
    "model.setup(stage='test')\n",
    "# A freshly loaded module is not guaranteed to be in inference mode; switch it\n",
    "# explicitly so dropout and other train-only behaviour is off during evaluation.\n",
    "model.eval()\n",
    "\n",
    "# Data module for the test stage (`bsz_test` is presumably the test batch size).\n",
    "dm = GLIMDataModule(data_path = './data/tmp/zuco_eeg_label_8variants.df',\n",
    "                    eval_noise_input = False,\n",
    "                    bsz_test = 24,\n",
    "                    )\n",
    "dm.setup(stage='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "880ccb803e8c4f96836af88a3b86f83e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clip-like acc1:  0.4292565882205963\n"
     ]
    }
   ],
   "source": [
    "prefix = \"Sentiment classification: \"\n",
    "template = \"It is <MASK>.\"\n",
    "all_sentiments = ['negative', 'neutral', 'positive']\n",
    "# Alternative, more descriptive candidate phrases (kept for reference):\n",
    "# candidates = ['bad or boring',\n",
    "#                 'normal or ordinary',\n",
    "#                 'good or great',]\n",
    "# One candidate sentence per sentiment class, built from prefix + filled-in template.\n",
    "candidates = [prefix + template.replace(\"<MASK>\", name) for name in all_sentiments]\n",
    "\n",
    "# Class index of every known sentiment string; anything else is mapped to -1 below.\n",
    "sentiment_to_idx = {name: idx for idx, name in enumerate(all_sentiments)}\n",
    "\n",
    "results = []\n",
    "with torch.no_grad():\n",
    "    for batch in track(dm.test_dataloader()):\n",
    "        eeg = batch['eeg'].to(device)\n",
    "        eeg_mask = batch['mask'].to(device)\n",
    "        prompts = batch['prompt'] # NOTE: [tuple('task'), tuple('dataset'), tuple('subject')] after collate\n",
    "        raw_task_key = batch['raw task key']    # list[str]\n",
    "        sentiment_label = batch['sentiment label']      # list[str]\n",
    "\n",
    "        # -1 marks samples whose sentiment is not one of `all_sentiments`;\n",
    "        # the accuracy call at the end skips them via `ignore_index=-1`.\n",
    "        labels = torch.tensor(\n",
    "            [sentiment_to_idx.get(sentiment, -1) for sentiment in sentiment_label],\n",
    "            device=device,\n",
    "        )\n",
    "\n",
    "        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
    "            prob, gen_str = model.predict(eeg, eeg_mask, prompts, candidates, generate=False)\n",
    "\n",
    "        # One record per sample in the batch.\n",
    "        results.extend(\n",
    "            {'input text': batch['raw input text'][i],\n",
    "             'label':      labels[i],\n",
    "             'prob':       prob[i],\n",
    "             'gen_str':    gen_str[i]}\n",
    "            for i in range(len(eeg))\n",
    "        )\n",
    "\n",
    "# Collapse the per-sample records into two tensors and score top-1 accuracy.\n",
    "probs = torch.stack([record['prob'] for record in results])\n",
    "labels = torch.stack([record['label'] for record in results])\n",
    "acc1 = multiclass_accuracy(probs, labels, num_classes=3, top_k=1, ignore_index=-1, average='micro')\n",
    "print('clip-like acc1: ',acc1.item())\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load generated texts, labels and raw texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df_results = pd.read_pickle('data/tmp/glim_gen_results.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero-shot classification on text embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "\n",
    "class BatchedDF(Dataset):\n",
    "    \"\"\"Map-style dataset that serves rows of a results dataframe as dicts of strings.\n",
    "\n",
    "    Reads the 'raw input text', 'gen text', 'sentiment label' and 'relation label'\n",
    "    columns of `df`; both label columns are converted to `str` element by element.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df: pd.DataFrame) -> None:\n",
    "        self.raw_input_text = df['raw input text'].tolist()\n",
    "        self.gen_text = df['gen text'].tolist()\n",
    "        # Labels are stringified so every item is a plain `str`.\n",
    "        self.senti_label = df['sentiment label'].map(str).tolist()\n",
    "        self.rela_label = df['relation label'].map(str).tolist()\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the four fields of row `idx` keyed by their column names.\"\"\"\n",
    "        item = {}\n",
    "        item['raw input text'] = self.raw_input_text[idx]\n",
    "        item['gen text'] = self.gen_text[idx]\n",
    "        item['sentiment label'] = self.senti_label[idx]\n",
    "        item['relation label'] = self.rela_label[idx]\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of rows taken from the dataframe.\"\"\"\n",
    "        return len(self.raw_input_text)\n",
    "    \n",
    "\n",
    "prefix = \"Sentiment: \"\n",
    "template = \"It is <MASK>.\"\n",
    "all_sentiments = ['negative', 'neutral', 'positive']\n",
    "candidates = [prefix + template.replace(\"<MASK>\", label) for label in all_sentiments]\n",
    "# Template handed to `predict_text_embedding` with every batch; it never changes,\n",
    "# so it is defined once instead of on every loop iteration.\n",
    "input_template = \"Sentiment classification: <MASK>.\"\n",
    "\n",
    "# Keep only rows that carry a sentiment label. Comparing the string form catches both\n",
    "# the literal string 'nan' and a real float NaN (which `!= 'nan'` alone lets through).\n",
    "has_sentiment = df_results['sentiment label'].astype(str) != 'nan'\n",
    "df_filtered = df_results[has_sentiment]\n",
    "dataset = BatchedDF(df_filtered)\n",
    "loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False)\n",
    "\n",
    "probs1, probs2 = [], []   # per-batch outputs for raw input texts / generated texts\n",
    "labels = []\n",
    "# Inference only: `no_grad` keeps the stored outputs from retaining autograd graphs.\n",
    "with torch.no_grad():\n",
    "    for batch in track(loader):\n",
    "        input_texts = batch['raw input text']\n",
    "        gen_texts = batch['gen text']\n",
    "        senti_label = batch['sentiment label']\n",
    "        # rela_label = batch['relation label']\n",
    "\n",
    "        # Class index per sample; -1 for anything outside `all_sentiments`.\n",
    "        for sentiment in senti_label:\n",
    "            if sentiment not in all_sentiments:\n",
    "                labels.append(-1)\n",
    "            else:\n",
    "                labels.append(all_sentiments.index(sentiment))\n",
    "\n",
    "        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
    "            prob1 = model.predict_text_embedding(input_texts, input_template, candidates)\n",
    "            prob2 = model.predict_text_embedding(gen_texts, input_template, candidates)\n",
    "        probs1.append(prob1)\n",
    "        probs2.append(prob2)\n",
    "\n",
    "probs1 = torch.cat(probs1, dim=0)\n",
    "probs2 = torch.cat(probs2, dim=0)\n",
    "labels = torch.tensor(labels, device=probs1.device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clip-like acc [raw input text]:  0.4556354880332947\n",
      "clip-like acc [gen text]:        0.3956834673881531\n"
     ]
    }
   ],
   "source": [
    "# Shared metric settings: 3-way top-1 micro accuracy, skipping samples labelled -1.\n",
    "metric_kwargs = dict(num_classes=3, top_k=1, ignore_index=-1, average='micro')\n",
    "\n",
    "# Accuracy of the predictions made from the raw input texts.\n",
    "acc1 = multiclass_accuracy(probs1, labels, **metric_kwargs)\n",
    "print('clip-like acc [raw input text]: ',acc1.item())\n",
    "# Accuracy of the predictions made from the generated texts.\n",
    "acc2 = multiclass_accuracy(probs2, labels, **metric_kwargs)\n",
    "print('clip-like acc [gen text]:       ',acc2.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero-shot classification with an LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e361c545a9284e889ea394a4f737590f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import transformers\n",
    "# Hugging Face model id of the LLM used as the sentiment judge.\n",
    "eval_llm_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "\n",
    "# Build the pipeline with fp16 weights and place it on the second GPU.\n",
    "pipe = transformers.pipeline(\n",
    "    model = eval_llm_id,\n",
    "    model_kwargs = {\"torch_dtype\": torch.float16},\n",
    "    device_map = torch.device(\"cuda:1\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentences to classify: the generated texts of the filtered dataset.\n",
    "# Swap in the commented line to score the raw input texts instead.\n",
    "# input_sentences = dataset.raw_input_text\n",
    "input_sentences = dataset.gen_text\n",
    "filtered_labels = dataset.senti_label\n",
    "\n",
    "# System prompt shared by every request.\n",
    "instructions = {\"role\": \"system\", \n",
    "          \"content\": \n",
    "            (\"You task is sentiment classification. Please pick the most likely label from:\\n\"\n",
    "             \" negative, neutral and positive.\\n\"\n",
    "             \" Please just output your predicted label do not output any other words!\"\n",
    "            )}\n",
    "\n",
    "# One chat (system instruction + user sentence) per input, rendered to a prompt string.\n",
    "messages = [[instructions, {\"role\": \"user\", \"content\": sen}] for sen in input_sentences]\n",
    "inputs = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "# Generation stops at either the tokenizer's EOS token or the `<|eot_id|>` token.\n",
    "terminators = [pipe.tokenizer.eos_token_id, \n",
    "               pipe.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n",
    "\n",
    "# Batched generation needs a pad token: reuse EOS and pad on the left so the\n",
    "# newly generated tokens directly follow each prompt.\n",
    "pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id\n",
    "pipe.tokenizer.padding_side = 'left'\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = pipe(inputs,\n",
    "                   batch_size = 16,\n",
    "                   max_new_tokens = 8,\n",
    "                   eos_token_id = terminators,\n",
    "                   # Deterministic beam search: with sampling enabled (and no seed)\n",
    "                   # the reported accuracy would change from run to run.\n",
    "                   do_sample = False,\n",
    "                   num_beams = 2,\n",
    "                   pad_token_id = pipe.tokenizer.eos_token_id,\n",
    "                   )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llm-pred acc:  0.39568345323741005\n"
     ]
    }
   ],
   "source": [
    "n_correct = 0\n",
    "n_invalid = 0   # completions that are not one of `all_sentiments` after normalisation\n",
    "total = len(input_sentences)\n",
    "for i in range(total):\n",
    "    label = filtered_labels[i]\n",
    "    # Drop the prompt prefix to keep only the completion, then normalise whitespace,\n",
    "    # trailing punctuation and case so that e.g. \" Positive.\" is accepted as 'positive'.\n",
    "    pred = outputs[i][0]['generated_text'][len(inputs[i]):] # str, the pred \"sentiment label\"\n",
    "    pred = pred.strip().rstrip('.!').lower()\n",
    "\n",
    "    # An off-list answer is a wrong prediction, not a reason to abort the evaluation.\n",
    "    if pred not in all_sentiments:\n",
    "        n_invalid += 1\n",
    "        continue\n",
    "\n",
    "    if label == pred:\n",
    "        n_correct += 1\n",
    "\n",
    "if n_invalid:\n",
    "    print(f'warning: {n_invalid}/{total} LLM outputs were not a valid label and were counted as wrong')\n",
    "# Guard against an empty evaluation set instead of raising ZeroDivisionError.\n",
    "llm_acc = n_correct/total if total else float('nan')\n",
    "print('llm-pred acc: ', llm_acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
