{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1fc9d9ff",
   "metadata": {},
   "source": [
    "### Load library and set seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f6501b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constructing <class 'dataset.EasyReader'>\n",
      "Constructing <class 'dataset.FastListDataset'>\n",
      "Label 0 to word terrible (6659)\n",
      "Label 1 to word great (2307)\n",
      "label_to_word:  {'0': 6659, '1': 2307}\n",
      "label_list:  [6659, 2307]\n",
      "Constructing <class 'probe.NTKProbe'>\n",
      "Constructing NTKProbe\n",
      "Constructing PromptFinetuneProbe\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probe has 109514298 parameters\n",
      "Constructing <class 'dvutils.Data_Shapley.Fast_Data_Shapley'>\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Export tokenizer/thread/GPU settings before torch and the other numerical\n",
    "# libraries are imported, so the values are already in place when they initialise.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"OMP_NUM_THREADS\"] = '16'\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = '16'\n",
    "os.environ[\"MKL_NUM_THREADS\"] = '16'\n",
    "# NOTE(review): GPU index 7 is machine-specific; adjust for your host.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n",
    "\n",
    "# multi-processing for NTK kernel\n",
    "from torch.multiprocessing import set_start_method, set_sharing_strategy\n",
    "import torch.multiprocessing as mp\n",
    "# force=True keeps this cell re-runnable: without it, a second call in the same\n",
    "# kernel raises RuntimeError because the start method has already been set.\n",
    "set_start_method(\"spawn\", force=True)\n",
    "set_sharing_strategy(\"file_system\")\n",
    "\n",
    "import pickle\n",
    "import yaml as yaml\n",
    "import click\n",
    "from datetime import datetime\n",
    "import random\n",
    "from datasets import load_dataset, concatenate_datasets, Value\n",
    "import torch\n",
    "\n",
    "# Make the project-local packages importable (paths are relative to the\n",
    "# notebook's working directory).\n",
    "sys.path.insert(0, './lmntk')\n",
    "sys.path.insert(0, './vinfo/lmntk')\n",
    "\n",
    "# NOTE(review): `np` is used later in this notebook but never imported explicitly;\n",
    "# presumably it is provided by one of these star imports - confirm.\n",
    "from dataset import *\n",
    "from probe import *\n",
    "from dvutils.Data_Shapley import *\n",
    "import logging\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from transformers import HfArgumentParser\n",
    "\n",
    "import time\n",
    "\n",
    "# ---- Experiment configuration (everything a reader might tune) ----\n",
    "dataset_name = \"sst2\"\n",
    "seed = 2023  # seeds torch, numpy and random, and the training-set shuffle\n",
    "num_dp = 5000  # maximum number of training points kept after shuffling\n",
    "tmc_iter = 200  # passed to dshap_com.run as `iteration`\n",
    "prompt = True  # usually True, whether use prompt fine-tuning\n",
    "signgd = False  # usually False, whether use signGD kernel; not adopted in FreeShap\n",
    "approximate = \"inv\"  # can also be \"none\" (use no approximation, exact inverse); \"diagonal\" (use block diagonal for inverse)\n",
    "per_point = True  # if True: get the instance score for each test point; if False: get instance score for test sets\n",
    "# NOTE(review): this is the string \"True\", not a bool; confirm what dshap_com.run expects.\n",
    "early_stopping = \"True\"\n",
    "tmc_seed = 2023  # passed to dshap_com.run as `seed`\n",
    "val_sample_num = 1000  # number of validation points to sample (all are used if fewer exist)\n",
    "yaml_path = \"../configs/dshap/sst2/ntk_prompt.yaml\"\n",
    "file_path = \"./freeshap_res/\"  # prefix for the cached NTK / Shapley pickle files\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "# The config entries are used as live objects below, so the full (unsafe) YAML\n",
    "# loader is required. NOTE: yaml.Loader can construct arbitrary Python objects;\n",
    "# only load configuration files you trust.\n",
    "# The file is opened in a `with` block so the handle is closed after parsing.\n",
    "with open(yaml_path) as config_file:\n",
    "    yaml_args = yaml.load(config_file, Loader=yaml.Loader)\n",
    "list_dataset = yaml_args['dataset']\n",
    "probe_model = yaml_args['probe_com']\n",
    "dshap_com = yaml_args['dshap_com']\n",
    "\n",
    "# Configure the probe according to the flags in the configuration block.\n",
    "if prompt:\n",
    "    probe_model.model.init(list_dataset.label_word_list)\n",
    "if approximate != \"none\":\n",
    "    probe_model.approximate(approximate)\n",
    "if approximate == \"inv\":\n",
    "    probe_model.normalize_ntk()\n",
    "if signgd:\n",
    "    probe_model.signgd()\n",
    "\n",
    "# Seed numpy and random only here, after the probe is built, so the random\n",
    "# stream seen by the validation sampling below is unchanged.\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "    \n",
    "# `load_dataset` arguments for every dataset name this cell can load.\n",
    "hf_dataset_args = {\n",
    "    \"sst2\": (\"sst2\",),\n",
    "    \"mr\": (\"rotten_tomatoes\",),\n",
    "    \"rte\": (\"glue\", \"rte\"),  # 1: not entail; 0: entail\n",
    "    \"mnli\": (\"glue\", \"mnli\"),\n",
    "    \"mrpc\": (\"glue\", \"mrpc\"),\n",
    "}\n",
    "# Fail with a clear message instead of a NameError on `dataset` further down.\n",
    "if dataset_name not in hf_dataset_args:\n",
    "    raise ValueError(f\"Unsupported dataset_name {dataset_name!r}; expected one of {sorted(hf_dataset_args)}\")\n",
    "dataset = load_dataset(*hf_dataset_args[dataset_name])\n",
    "\n",
    "# Tag each training example with its original row index, then keep a seeded\n",
    "# random subset of at most `num_dp` examples.\n",
    "train_data = dataset['train']\n",
    "train_data = train_data.map(lambda example, idx: {'idx': idx}, with_indices=True)\n",
    "train_data = train_data.shuffle(seed).select(range(min(train_data.num_rows, num_dp)))\n",
    "# Original training-set row indices of the sampled points, in sampled order.\n",
    "sampled_idx = train_data['idx']\n",
    "\n",
    "# Pick the evaluation split once and use it both for counting and for fetching\n",
    "# examples, so the sampled indices always refer to the split they came from.\n",
    "if dataset_name == \"mnli\":\n",
    "    val_split = 'validation_matched'\n",
    "elif dataset_name == \"subj\" or dataset_name == \"ag_news\":\n",
    "    val_split = 'test'\n",
    "else:\n",
    "    val_split = 'validation'\n",
    "valid_data = dataset[val_split]\n",
    "val_num = valid_data.num_rows\n",
    "\n",
    "# Use every validation point when fewer than `val_sample_num` exist; otherwise\n",
    "# draw `val_sample_num` distinct indices with the seeded numpy generator.\n",
    "if val_sample_num > val_num:\n",
    "    sampled_val_idx = np.arange(val_num)\n",
    "else:\n",
    "    sampled_val_idx = np.random.choice(np.arange(val_num), val_sample_num, replace=False).tolist()\n",
    "\n",
    "# Short model-family tag used in the cache file names.\n",
    "# 'roberta' must be tested before 'bert' because the string \"roberta\" contains \"bert\".\n",
    "if 'llama' in probe_model.args['model']:\n",
    "    model_name = 'llama'\n",
    "elif 'roberta' in probe_model.args['model']:\n",
    "    model_name = 'roberta'\n",
    "elif 'bert' in probe_model.args['model']:\n",
    "    model_name = 'bert'\n",
    "else:\n",
    "    # Fail here rather than with a NameError when the cache paths are built.\n",
    "    raise ValueError(f\"Cannot infer model family from {probe_model.args['model']!r}\")\n",
    "\n",
    "# Validation examples in sampled order: position i here corresponds to sampled_val_idx[i].\n",
    "reindex_valid_data = [valid_data[int(index)] for index in sampled_val_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e387e2e2",
   "metadata": {},
   "source": [
    "### Build NTK kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "10a03906-ac84-41f1-927c-03100d711d16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./freeshap_res/sst2_bert_ntk_seed2023_num5000_signFalse.pkl\n",
      "++++++++++++++++++++++++++++++++++no cached ntk, computing+++++++++++++++++++++++++++++++++++\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[loading]: 10250it [00:17, 1002.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sst2: 67349\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[loading]: 67349it [00:18, 3575.48it/s] \n",
      "[loading]: 872it [00:15, 56.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sst2: 872\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [05:30<00:00, 17.74it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:33<00:00,  5.34it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [05:29<00:00, 17.79it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:34<00:00,  5.27it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [05:30<00:00, 17.74it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:38<00:00,  5.06it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [03:35<00:00, 27.26it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:40<00:00,  4.99it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [02:22<00:00, 41.29it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:38<00:00,  5.09it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5872/5872 [01:54<00:00, 51.43it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current mean value:  tensor(0.4992)\n",
      "++++++++++++++++++++++++++++++++++saving ntk cache+++++++++++++++++++++++++++++++++++\n"
     ]
    }
   ],
   "source": [
    "# Location of the cached NTK matrix for this dataset / model / seed / size / signGD setting.\n",
    "ntk_cache_path = f\"{file_path}{dataset_name}_{model_name}_ntk_seed{seed}_num{num_dp}_sign{signgd}.pkl\"\n",
    "print(ntk_cache_path)\n",
    "try:\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load cache files you produced yourself.\n",
    "    with open(ntk_cache_path, \"rb\") as f:\n",
    "        ntk = pickle.load(f)\n",
    "except (FileNotFoundError, EOFError, pickle.UnpicklingError):\n",
    "    # Only a missing or unreadable cache triggers recomputation; any other error\n",
    "    # (including KeyboardInterrupt) propagates instead of being silently swallowed.\n",
    "    print(\"++++++++++++++++++++++++++++++++++no cached ntk, computing+++++++++++++++++++++++++++++++++++\")\n",
    "    # Create the output directory before the expensive computation so the result\n",
    "    # cannot be lost to a missing folder at save time.\n",
    "    os.makedirs(os.path.dirname(ntk_cache_path) or \".\", exist_ok=True)\n",
    "    train_set = list_dataset.get_idx_dataset(sampled_idx, split=\"train\")\n",
    "    val_set = list_dataset.get_idx_dataset(sampled_val_idx, split=\"val\")\n",
    "    # compute ntk matrix\n",
    "    ntk = probe_model.compute_ntk(train_set, val_set)\n",
    "    # save the ntk matrix\n",
    "    with open(ntk_cache_path, \"wb\") as f:\n",
    "        pickle.dump(ntk, f)\n",
    "    print(\"++++++++++++++++++++++++++++++++++saving ntk cache+++++++++++++++++++++++++++++++++++\")\n",
    "else:\n",
    "    # Cache hit: hand the stored kernel and the training labels to the probe.\n",
    "    # These calls sit outside the try body so their failures are reported as-is\n",
    "    # rather than being mistaken for a cache miss.\n",
    "    print(\"++++++++++++++++++++++++++++++++++++using cached ntk++++++++++++++++++++++++++++++++++++\")\n",
    "    probe_model.get_cached_ntk(ntk)\n",
    "    probe_model.get_train_labels(list_dataset.get_idx_dataset(sampled_idx, split=\"train\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5d5dc7b",
   "metadata": {},
   "source": [
    "### Compute Shapley value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a061e26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing FreeShap Results\n",
      "start to compute shapley value\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[loading]: 872it [00:16, 53.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sst2: 872\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "[TMC iterations]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [5:47:11<00:00, 104.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving FreeShap result to ./freeshap_res/sst2_bert_shapley_result_seed2023_num5000_approinv_signFalse_earlystopTrue_tmc2023_iter200.pkl\n"
     ]
    }
   ],
   "source": [
    "# Location of the cached Shapley result for the current configuration.\n",
    "shapley_file_path=f\"{file_path}{dataset_name}_{model_name}_shapley_result_seed{seed}_num{num_dp}_appro{approximate}_sign{signgd}_earlystop{early_stopping}_tmc{tmc_seed}_iter{tmc_iter}.pkl\"\n",
    "try:\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load result files you produced yourself.\n",
    "    with open(shapley_file_path,'rb') as f:\n",
    "        result_dict = pickle.load(f)\n",
    "except (FileNotFoundError, EOFError, pickle.UnpicklingError):\n",
    "    # Only a missing or unreadable result file triggers recomputation; any other\n",
    "    # error (including KeyboardInterrupt) propagates instead of being swallowed.\n",
    "    print(\"Computing FreeShap Results\")\n",
    "    # Create the output directory up front so a multi-hour run cannot be lost\n",
    "    # to a missing folder when the result is written.\n",
    "    os.makedirs(os.path.dirname(shapley_file_path) or \".\", exist_ok=True)\n",
    "    dv_result = dshap_com.run(data_idx=sampled_idx, val_data_idx=sampled_val_idx, iteration=tmc_iter,\n",
    "                              use_cache_ntk=True, prompt=prompt, seed=tmc_seed, num_dp=num_dp,\n",
    "                              checkpoint=False, per_point=per_point, early_stopping=early_stopping)\n",
    "\n",
    "    # Not used later in this notebook; kept available for interactive inspection.\n",
    "    mc_com = np.array(dshap_com.mc_cache)\n",
    "    result_dict = {'dv_result': dv_result,  # entropy, accuracy\n",
    "                   'sampled_idx': sampled_idx}\n",
    "    with open(shapley_file_path, \"wb\") as f:\n",
    "        pickle.dump(result_dict, f)\n",
    "    print(f\"Saving FreeShap result to {shapley_file_path}\")\n",
    "else:\n",
    "    print(f\"Loading FreeShap result from {shapley_file_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40691cea-20ad-4b42-bd13-0265d38aac74",
   "metadata": {},
   "source": [
    "### Explain a test prediction (for instance, the test point with index 535)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d019521",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'idx': 535, 'sentence': 'a quiet , pure , elliptical film ', 'label': 1}\n",
      "================================ Most influential ==========================\n",
      "score: 0.02  |  {'idx': 10497, 'sentence': 'a funny little film ', 'label': 1}\n",
      "score: 0.0175  |  {'idx': 3364, 'sentence': 'pretty good little movie ', 'label': 1}\n",
      "score: 0.015  |  {'idx': 53851, 'sentence': \"it 's refreshing that someone understands the need for the bad boy \", 'label': 1}\n",
      "score: 0.015  |  {'idx': 7139, 'sentence': 'its ripe recipe , inspiring ingredients ', 'label': 1}\n",
      "score: 0.015  |  {'idx': 5403, 'sentence': \"is that it 's a rock-solid little genre picture \", 'label': 1}\n",
      "score: 0.015  |  {'idx': 33494, 'sentence': 'trial movie , escape movie and unexpected fable ', 'label': 1}\n",
      "score: 0.015  |  {'idx': 55842, 'sentence': 'dazzling entertainment ', 'label': 1}\n",
      "score: 0.015  |  {'idx': 52439, 'sentence': 'the ya-ya sisterhood ', 'label': 1}\n",
      "score: 0.015  |  {'idx': 28783, 'sentence': 'a well-put-together piece ', 'label': 1}\n",
      "score: 0.0125  |  {'idx': 55951, 'sentence': 'a serious drama ', 'label': 1}\n",
      "================================ Least influential ==========================\n",
      "score: -0.0225  |  {'idx': 5321, 'sentence': 'a particularly slanted , gay s/m fantasy ', 'label': 0}\n",
      "score: -0.02  |  {'idx': 12133, 'sentence': 'overly-familiar and poorly-constructed comedy ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 18678, 'sentence': 'an empty exercise ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 55981, 'sentence': 'a shaky , uncertain film ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 53925, 'sentence': 'wishy-washy melodramatic movie ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 16948, 'sentence': 'a long , dull procession of despair ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 41452, 'sentence': 'bombastic self-glorification ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 53831, 'sentence': 'be better to wait for the video ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 21006, 'sentence': 'an inept , tedious spoof ', 'label': 0}\n",
      "score: -0.015  |  {'idx': 52560, 'sentence': 'it the adventures of direct-to-video nash ', 'label': 0}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Slice 1 along axis 1 of dv_result (presumably the accuracy-based values, per the\n",
    "# 'entropy, accuracy' note where the result is built). Rows index train_data,\n",
    "# columns index reindex_valid_data.\n",
    "acc = result_dict['dv_result'][:, 1, :]\n",
    "\n",
    "top_10_high = {}\n",
    "top_10_low = {}\n",
    "\n",
    "idx=535\n",
    "# Score of every training point with respect to this single test point.\n",
    "column_vector = acc[:, idx]\n",
    "print(f\"{reindex_valid_data[int(idx)]}\")\n",
    "# Sort once (ascending) and slice both ends from the same ranking.\n",
    "ranking = np.argsort(column_vector)\n",
    "top_10_high[idx] = ranking[-10:][::-1]  # Indices of the 10 highest scores, highest first\n",
    "print(\"================================ Most influential ==========================\")\n",
    "for aindex in top_10_high[idx]:\n",
    "    print(f\"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}\")\n",
    "print(\"================================ Least influential ==========================\")\n",
    "top_10_low[idx] = ranking[:10]  # Indices of the 10 lowest scores, lowest first\n",
    "for aindex in top_10_low[idx]:\n",
    "    print(f\"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}\")\n",
    "print()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6679576c-b2a3-40ef-8cbc-ce401dc54aa9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'idx': 1, 'sentence': 'unflinchingly bleak and desperate ', 'label': 0}\n",
      "================================ Most influential ==========================\n",
      "score: 0.05999999999999999  |  {'idx': 25883, 'sentence': 'self-defeatingly decorous ', 'label': 0}\n",
      "score: 0.045  |  {'idx': 23066, 'sentence': 'oddly moving ', 'label': 1}\n",
      "score: 0.04  |  {'idx': 45633, 'sentence': 'vicious and absurd ', 'label': 0}\n",
      "score: 0.04  |  {'idx': 36959, 'sentence': 'scare ', 'label': 0}\n",
      "score: 0.04  |  {'idx': 11913, 'sentence': 'a deep vein of sadness ', 'label': 0}\n",
      "score: 0.035  |  {'idx': 39254, 'sentence': 'could be a passable date film ', 'label': 1}\n",
      "score: 0.035  |  {'idx': 51429, 'sentence': 'druggy and self-indulgent ', 'label': 0}\n",
      "score: 0.035  |  {'idx': 56599, 'sentence': 'predictable and cloying ', 'label': 0}\n",
      "score: 0.035  |  {'idx': 53672, 'sentence': 'multi-layered ', 'label': 1}\n",
      "score: 0.035  |  {'idx': 51283, 'sentence': 'absolutely , inescapably gorgeous , ', 'label': 1}\n",
      "================================ Least influential ==========================\n",
      "score: -0.10000000000000002  |  {'idx': 29181, 'sentence': 'excessively quirky ', 'label': 1}\n",
      "score: -0.08  |  {'idx': 66471, 'sentence': 'it tight and nasty ', 'label': 1}\n",
      "score: -0.075  |  {'idx': 2057, 'sentence': 'wickedly sick and twisted humor ', 'label': 1}\n",
      "score: -0.075  |  {'idx': 21938, 'sentence': 'damning and damned compelling ', 'label': 1}\n",
      "score: -0.06499999999999999  |  {'idx': 9318, 'sentence': 'fast-moving and cheerfully simplistic ', 'label': 1}\n",
      "score: -0.05999999999999999  |  {'idx': 30319, 'sentence': \"jonah 's despair -- in all its agonizing , catch-22 glory -- \", 'label': 1}\n",
      "score: -0.05999999999999999  |  {'idx': 19197, 'sentence': 'peril ', 'label': 1}\n",
      "score: -0.05999999999999999  |  {'idx': 12298, 'sentence': 'funny , somber , absurd , and , finally , achingly sad ', 'label': 1}\n",
      "score: -0.049999999999999996  |  {'idx': 4280, 'sentence': \"the paranoid claustrophobia of a submarine movie with the unsettling spookiness of the supernatural -- why did n't hollywood think of this sooner ? \", 'label': 1}\n",
      "score: -0.049999999999999996  |  {'idx': 28909, 'sentence': 'goes down easy ', 'label': 0}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Slice 1 along axis 1 of dv_result (presumably the accuracy-based values, per the\n",
    "# 'entropy, accuracy' note where the result is built). Rows index train_data,\n",
    "# columns index reindex_valid_data.\n",
    "acc = result_dict['dv_result'][:, 1, :]\n",
    "\n",
    "top_10_high = {}\n",
    "top_10_low = {}\n",
    "\n",
    "idx=1\n",
    "# Score of every training point with respect to this single test point.\n",
    "column_vector = acc[:, idx]\n",
    "print(f\"{reindex_valid_data[int(idx)]}\")\n",
    "# Sort once (ascending) and slice both ends from the same ranking.\n",
    "ranking = np.argsort(column_vector)\n",
    "top_10_high[idx] = ranking[-10:][::-1]  # Indices of the 10 highest scores, highest first\n",
    "print(\"================================ Most influential ==========================\")\n",
    "for aindex in top_10_high[idx]:\n",
    "    print(f\"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}\")\n",
    "print(\"================================ Least influential ==========================\")\n",
    "top_10_low[idx] = ranking[:10]  # Indices of the 10 lowest scores, lowest first\n",
    "for aindex in top_10_low[idx]:\n",
    "    print(f\"score: {column_vector[int(aindex)]}  |  {train_data[int(aindex)]}\")\n",
    "print()\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a890dd0f-1668-450e-9e2a-008a6c429d9a",
   "metadata": {},
   "source": [
    "### Check most helpful/harmful data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "86776f83-1143-48a6-8743-caa78fde6aad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================== Most Helpful ===================================\n",
      "score: 2.505 | {'idx': 13235, 'sentence': 'a dramatic comedy as pleasantly dishonest and pat as any hollywood fluff . ', 'label': 0}\n",
      "score: 2.28 | {'idx': 40598, 'sentence': \"it 's anchored by splendid performances from an honored screen veteran and a sparkling newcomer who instantly transform themselves into a believable mother/daughter pair . \", 'label': 1}\n",
      "score: 1.645 | {'idx': 54789, 'sentence': ', the humor dwindles . ', 'label': 0}\n",
      "score: 1.58 | {'idx': 13053, 'sentence': 'is highly pleasurable . ', 'label': 1}\n",
      "score: 1.56 | {'idx': 49861, 'sentence': ', alas , it collapses like an overcooked soufflé . ', 'label': 0}\n",
      "=================================== Most Harmful ===================================\n",
      "score: -1.6100000000000003 | {'idx': 43875, 'sentence': 'could possibly be more contemptuous of the single female population . ', 'label': 1}\n",
      "score: -2.245 | {'idx': 51208, 'sentence': ', the more outrageous bits achieve a shock-you-into-laughter intensity of almost dadaist proportions . ', 'label': 1}\n",
      "score: -2.785 | {'idx': 64148, 'sentence': 'could have easily become a cold , calculated exercise in postmodern pastiche winds up a powerful and deeply moving example of melodramatic moviemaking . ', 'label': 1}\n",
      "score: -3.25 | {'idx': 32491, 'sentence': 'kids who are into this thornberry stuff will probably be in wedgie heaven . ', 'label': 1}\n",
      "score: -3.7049999999999996 | {'idx': 4280, 'sentence': \"the paranoid claustrophobia of a submarine movie with the unsettling spookiness of the supernatural -- why did n't hollywood think of this sooner ? \", 'label': 1}\n"
     ]
    }
   ],
   "source": [
    "# Slice 1 along axis 1 of dv_result; rows index train_data, columns index test points.\n",
    "acc = result_dict['dv_result'][:, 1, :]\n",
    "# Aggregate each training point's scores over all test points.\n",
    "acc_sum = np.sum(acc, axis=1)\n",
    "\n",
    "# Training-point positions ordered from highest to lowest aggregated score.\n",
    "sorted_indices = np.argsort(acc_sum)[::-1]\n",
    "\n",
    "top = 5  # how many points to list at each end of the ranking\n",
    "equal_symbol=\"=\"* 35\n",
    "print(f\"{equal_symbol} Most Helpful {equal_symbol}\")\n",
    "for index in sorted_indices[:top]:\n",
    "    print(f\"score: {acc_sum[int(index)]} | {train_data[int(index)]}\")\n",
    "print(f\"{equal_symbol} Most Harmful {equal_symbol}\")\n",
    "# The tail of the descending ranking, so the most harmful point is printed last.\n",
    "for index in sorted_indices[-top:]:\n",
    "    print(f\"score: {acc_sum[int(index)]} | {train_data[int(index)]}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
