{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from utils import check_acc_cov\n",
    "from datasets import load_from_disk\n",
    "\n",
    "train_dataset = load_from_disk(\"datasets/smallgrammars_train\")\n",
    "validation_dataset = load_from_disk(\"datasets/smallgrammars_test\")\n",
    "\n",
    "influence_inf = pd.read_csv('cache/smallgrammars_10_inf.csv')\n",
    "influence_lissa = pd.read_csv('cache/smallgrammars_10_lissa.csv')\n",
    "gradient_match = pd.read_csv('cache/smallgrammars_10_gmatch.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_data_inf = influence_inf.apply(lambda x: x.abs().argmax(), axis=1)\n",
    "most_data_lissa = influence_lissa.apply(lambda x: x.abs().argmax(), axis=1)\n",
    "most_data_gmatch = gradient_match.apply(lambda x: x.abs().argmax(), axis=1)\n",
    "\n",
    "def check_example(val_id, target):\n",
    "    print(f'Validation Sample ID: {val_id}\\n', \n",
    "          validation_dataset[val_id][target], '\\n')\n",
    "    print(f'The most influential training sample: {int(most_data_inf.iloc[val_id])}\\n', \n",
    "          train_dataset[int(most_data_inf.iloc[val_id])][target], '\\n')\n",
    "    print('=' * 100)\n",
    "    print(f'The most influential training sample: {int(most_data_lissa.iloc[val_id])}\\n', \n",
    "          train_dataset[int(most_data_lissa.iloc[val_id])][target], '\\n')\n",
    "    print('=' * 100)\n",
    "    print(f'The most influential training sample: {int(most_data_gmatch.iloc[val_id])}\\n', \n",
    "          train_dataset[int(most_data_gmatch.iloc[val_id])][target], '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Sample ID: 25\n",
      " qadauo is a chatbot that performs a specific transformation on sentences: Insert Number 1 Between Every Word\n",
      "    For example:\n",
      "    Mountains touch heavens above. ->  \n",
      "\n",
      "The most influential training sample: 21\n",
      " qadauo is a chatbot that performs a specific transformation on sentences: Insert Number 1 Between Every Word\n",
      "    For example:\n",
      "    Silence speaks louder sometimes. ->  \n",
      "\n",
      "====================================================================================================\n",
      "The most influential training sample: 86\n",
      " xiycz is a chatbot that performs a specific transformation on sentences: Repeat Each Word Twice\n",
      "    For example:\n",
      "    Magic hides in corners. ->  \n",
      "\n",
      "====================================================================================================\n",
      "The most influential training sample: 50\n",
      " yeli is a chatbot that performs a specific transformation on sentences: Capitalize Every Word\n",
      "    For example:\n",
      "    Magic hides in corners. ->  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "check_example(25, 'prompts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.69 \n",
      "Cover: 0.5733333333333334\n"
     ]
    }
   ],
   "source": [
    "check_acc_cov(influence_inf, train_dataset, validation_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.1 \n",
      "Cover: 0.1\n"
     ]
    }
   ],
   "source": [
    "check_acc_cov(influence_lissa, train_dataset, validation_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc: 0.11 \n",
      "Cover: 0.10222222222222223\n"
     ]
    }
   ],
   "source": [
    "check_acc_cov(gradient_match, train_dataset, validation_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
