{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T07:41:03.562265Z",
     "start_time": "2024-05-22T07:41:03.559003Z"
    }
   },
   "outputs": [],
   "source": [
    "# Standard Imports\n",
    "import copy\n",
    "\n",
    "# Third Party Imports\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "644fbcb1-a948-4036-8fec-6f18369ec046",
   "metadata": {},
   "source": [
    "#### Data Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc9ca94d-fd7a-46c5-932d-01ca644994c9",
   "metadata": {},
   "source": [
    "**Reading the Blood Transfusion Dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77cb6a91a07ed68d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T07:41:03.731947Z",
     "start_time": "2024-05-22T07:41:03.723709Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Recency (months)</th>\n",
       "      <th>Frequency (times)</th>\n",
       "      <th>Monetary (c.c. blood)</th>\n",
       "      <th>Time (months)</th>\n",
       "      <th>Blood Donated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>50</td>\n",
       "      <td>12500</td>\n",
       "      <td>98</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>3250</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>16</td>\n",
       "      <td>4000</td>\n",
       "      <td>35</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>5000</td>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>6000</td>\n",
       "      <td>77</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Recency (months)  Frequency (times)  Monetary (c.c. blood)  Time (months)  \\\n",
       "0                 2                 50                  12500             98   \n",
       "1                 0                 13                   3250             28   \n",
       "2                 1                 16                   4000             35   \n",
       "3                 2                 20                   5000             45   \n",
       "4                 1                 24                   6000             77   \n",
       "\n",
       "   Blood Donated  \n",
       "0              1  \n",
       "1              1  \n",
       "2              1  \n",
       "3              1  \n",
       "4              0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reading the Blood Dataset\n",
    "blood = pd.read_csv(\"./../../../datasets/blood/blood.data\")\n",
    "blood.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317ecda9-87df-40b3-918d-19b7ccd05777",
   "metadata": {},
   "source": [
    "**Print Info and Missing Values**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b167864-ced8-46fd-89d8-74266115ab90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 748 entries, 0 to 747\n",
      "Data columns (total 5 columns):\n",
      " #   Column                 Non-Null Count  Dtype\n",
      "---  ------                 --------------  -----\n",
      " 0   Recency (months)       748 non-null    int64\n",
      " 1   Frequency (times)      748 non-null    int64\n",
      " 2   Monetary (c.c. blood)  748 non-null    int64\n",
      " 3   Time (months)          748 non-null    int64\n",
      " 4   Blood Donated          748 non-null    int64\n",
      "dtypes: int64(5)\n",
      "memory usage: 29.3 KB\n",
      "None\n",
      "\n",
      "\n",
      "\n",
      "Missing values:  False\n"
     ]
    }
   ],
   "source": [
    "print(blood.info())\n",
    "print(\"\\n\\n\")\n",
    "print(\"Missing values: \", blood.isnull().values.any())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d921143a-172f-4f29-80c6-5a877643d067",
   "metadata": {},
   "source": [
    "**Variance Check**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "948672e3-370b-466e-97ac-be927c47d05f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Recency (months)              65.535\n",
       "Frequency (times)             34.098\n",
       "Monetary (c.c. blood)    2131094.230\n",
       "Time (months)                594.224\n",
       "Blood Donated                  0.182\n",
       "dtype: float64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blood.var().round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0aac7df9-f6fe-43ef-96a9-b10fa43a1ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log Transformation of Monetary (c.c. blood) to reduce the variance and range\n",
    "# blood['Monetary (c.c. blood)'] = blood['Monetary (c.c. blood)'] / 100\n",
    "# blood.var().round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a9aaeed7-a7e6-4fac-83fb-213f6b0c3774",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Recency (months)</th>\n",
       "      <th>Frequency (times)</th>\n",
       "      <th>Monetary (c.c. blood)</th>\n",
       "      <th>Time (months)</th>\n",
       "      <th>Blood Donated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>50</td>\n",
       "      <td>12500</td>\n",
       "      <td>98</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>3250</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>16</td>\n",
       "      <td>4000</td>\n",
       "      <td>35</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>5000</td>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>6000</td>\n",
       "      <td>77</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Recency (months)  Frequency (times)  Monetary (c.c. blood)  Time (months)  \\\n",
       "0                 2                 50                  12500             98   \n",
       "1                 0                 13                   3250             28   \n",
       "2                 1                 16                   4000             35   \n",
       "3                 2                 20                   5000             45   \n",
       "4                 1                 24                   6000             77   \n",
       "\n",
       "   Blood Donated  \n",
       "0              1  \n",
       "1              1  \n",
       "2              1  \n",
       "3              1  \n",
       "4              0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blood.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06776d06-aac3-4cd8-b1e8-b8498381d713",
   "metadata": {},
   "source": [
    "**Checking the distribution of target values**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "339d9bbe-8754-4572-87e1-3365bed7af5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Blood Donated\n",
       "0    0.762\n",
       "1    0.238\n",
       "Name: proportion, dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blood[\"Blood Donated\"].value_counts(normalize=True).round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ce86e3c-d237-459e-805b-83a02b971354",
   "metadata": {},
   "source": [
    "**Train-Test Split**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "852899ac-5378-4914-a9e7-c206ded4d0de",
   "metadata": {},
   "source": [
    "Train - Test Split &nbsp;&nbsp;&nbsp;&nbsp; 60%-40%\n",
    "\n",
    "The test set is further split, stratified on the target, into a Verb set and a Gen set.\n",
    "\n",
    "Verb - Gen Split &nbsp;&nbsp;&nbsp;&nbsp; 50%-50% (each is 20% of the full dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bf7ff54f-f3f4-40e6-8bc4-424d99ae8af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Test Split Stratified on Target Feature\n",
    "X_train, X_test, y_train, y_test = train_test_split(blood.drop(columns=\"Blood Donated\"), blood[\"Blood Donated\"], test_size=0.40, random_state=400, stratify=blood[\"Blood Donated\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5cc4deea-e333-4ce0-890d-ad1078aac260",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Test Split on the Test set to get Verb and Gen Sets. Stratified on the Target Feature\n",
    "X_test_verb, X_test_gen, y_test_verb, y_test_gen = train_test_split(X_test, y_test, test_size=0.5, random_state=400, stratify=y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bda17883-3302-4fc9-8b3b-46c6c669a070",
   "metadata": {},
   "source": [
    "**Checking the distribution of target values after the split**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f79613d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Blood Donated\n",
      "0    0.761\n",
      "1    0.239\n",
      "Name: proportion, dtype: float64\n",
      "Blood Donated\n",
      "0    0.76\n",
      "1    0.24\n",
      "Name: proportion, dtype: float64\n",
      "Blood Donated\n",
      "0    0.767\n",
      "1    0.233\n",
      "Name: proportion, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "print(y_train.value_counts(normalize=True).round(3))\n",
    "print(y_test_verb.value_counts(normalize=True).round(3))\n",
    "print(y_test_gen.value_counts(normalize=True).round(3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccf6c676-335d-4133-afa2-5161bd932f2a",
   "metadata": {},
   "source": [
    "**Variance Check**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9bf865bb-067e-40f8-bfb8-03afba7c71a1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Recency (months)              71.826\n",
       "Frequency (times)             38.962\n",
       "Monetary (c.c. blood)    2435111.807\n",
       "Time (months)                569.385\n",
       "dtype: float64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.var().round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a7a2750-30fc-45f5-97d9-9bd0863cab60",
   "metadata": {},
   "source": [
    "**Standardization**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ebbccb2f-d31e-451c-83a4-ba3b35025796",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "X_train_scaled = scaler.fit_transform(X_train)\n",
    "X_test_gen_scaled = scaler.transform(X_test_gen)\n",
    "X_test_verb_scaled = scaler.transform(X_test_verb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99cb1c0e-8239-4489-a757-51307bfe7e8a",
   "metadata": {},
   "source": [
    "**Converting them back to dataframes**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "563a7aff-a958-42bf-a740-bfe04b397aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_train.columns)\n",
    "X_test_gen_scaled_df = pd.DataFrame(X_test_gen_scaled, columns=X_test.columns)\n",
    "X_test_verb_scaled_df = pd.DataFrame(X_test_verb_scaled, columns=X_test.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d53928be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Recency (months)         1.002\n",
       "Frequency (times)        1.002\n",
       "Monetary (c.c. blood)    1.002\n",
       "Time (months)            1.002\n",
       "dtype: float64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train_scaled_df.var().round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fe3ae9-5bdc-488c-9339-af497d6999ee",
   "metadata": {},
   "source": [
    "From here onwards, `X_test_verb_scaled_df` is referred to as `X_test_scaled_df`, and `y_test` is rebound to `y_test_verb`. `X_test_gen_scaled_df` is not used in the rest of the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "13015d0d-47a8-4148-9f3c-074d1eec82b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "expansion_factor = 100\n",
    "X_train_scaled_df = X_train_scaled_df * expansion_factor\n",
    "X_test_verb_scaled_df = X_test_verb_scaled_df * expansion_factor\n",
    "X_test_gen_scaled_df = X_test_gen_scaled_df * expansion_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9090224d-4b22-423c-932d-990204d4c703",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Recency (months)         10022.37136\n",
       "Frequency (times)        10022.37136\n",
       "Monetary (c.c. blood)    10022.37136\n",
       "Time (months)            10022.37136\n",
       "dtype: float64"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train_scaled_df.var().round(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "61841898",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test_scaled_df = X_test_verb_scaled_df\n",
    "y_test = y_test_verb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c89636ce-0600-405c-ad76-9f9edb475f76",
   "metadata": {},
   "source": [
    "#### Model Variations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98c6a1a3",
   "metadata": {},
   "source": [
    "**Base Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1384fb24-8c96-4831-870f-5085890016c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best parameters found:  {'C': np.float64(0.0013101258887574746), 'class_weight': None, 'l1_ratio': np.float64(0.8745607282441125), 'max_iter': 990, 'penalty': 'l2', 'solver': 'sag', 'tol': np.float64(0.005361746692294743)}\n",
      "Best score:  0.6473408239700374\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l1)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l1)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l1)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l1)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l1)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:540: FitFailedWarning: \n",
      "15 fits failed out of a total of 40.\n",
      "The score on these train-test partitions for these parameters will be set to nan.\n",
      "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
      "\n",
      "Below are more details about the failures:\n",
      "--------------------------------------------------------------------------------\n",
      "5 fits failed with the following error:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 888, in _fit_and_score\n",
      "    estimator.fit(X_train, y_train, **fit_params)\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/base.py\", line 1473, in wrapper\n",
      "    return fit_method(estimator, *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 1194, in fit\n",
      "    solver = _check_solver(self.solver, self.penalty, self.dual)\n",
      "             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 67, in _check_solver\n",
      "    raise ValueError(\n",
      "ValueError: Solver newton-cg supports only 'l2' or None penalties, got elasticnet penalty.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "5 fits failed with the following error:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 888, in _fit_and_score\n",
      "    estimator.fit(X_train, y_train, **fit_params)\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/base.py\", line 1473, in wrapper\n",
      "    return fit_method(estimator, *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 1194, in fit\n",
      "    solver = _check_solver(self.solver, self.penalty, self.dual)\n",
      "             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 67, in _check_solver\n",
      "    raise ValueError(\n",
      "ValueError: Solver lbfgs supports only 'l2' or None penalties, got elasticnet penalty.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "5 fits failed with the following error:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/model_selection/_validation.py\", line 888, in _fit_and_score\n",
      "    estimator.fit(X_train, y_train, **fit_params)\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/base.py\", line 1473, in wrapper\n",
      "    return fit_method(estimator, *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 1194, in fit\n",
      "    solver = _check_solver(self.solver, self.penalty, self.dual)\n",
      "             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py\", line 67, in _check_solver\n",
      "    raise ValueError(\n",
      "ValueError: Solver lbfgs supports only 'l2' or None penalties, got l1 penalty.\n",
      "\n",
      "  warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1102: UserWarning: One or more of the test scores are non-finite: [       nan        nan 0.64289638 0.63625468 0.63625468 0.64734082\n",
      " 0.64734082        nan]\n",
      "  warnings.warn(\n",
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:1197: UserWarning: l1_ratio parameter is only used when penalty is 'elasticnet'. Got (penalty=l2)\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from scipy.stats import loguniform, uniform\n",
    "\n",
    "np.random.seed(400)\n",
    "\n",
    "param_dist = {\n",
    "    'C': loguniform(1e-4, 1e4),\n",
    "    'penalty': ['l1', 'l2', 'elasticnet'],\n",
    "    'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'],\n",
    "    'tol': loguniform(1e-6, 1e-1),\n",
    "    'max_iter': [int(x) for x in range(10, 1001)],\n",
    "    'class_weight': [None, 'balanced'],\n",
    "    'l1_ratio': uniform(0, 1)  # Only used if penalty is 'elasticnet'\n",
    "}\n",
    "\n",
    "# Create the logistic regression model\n",
    "log_reg = LogisticRegression()\n",
    "\n",
    "# Create the RandomizedSearchCV object\n",
    "random_search = RandomizedSearchCV(log_reg, param_distributions=param_dist, n_iter=8, cv=5, random_state=400)\n",
    "\n",
    "# Fit the model\n",
    "random_search.fit(X_train_scaled_df, y_train)\n",
    "\n",
    "# Print the best parameters and the best score\n",
    "print(\"Best parameters found: \", random_search.best_params_)\n",
    "print(\"Best score: \", random_search.best_score_)\n",
    "best_params = random_search.best_params_\n",
    "best_score = random_search.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9c1d7336",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shravan/NLE/Explainable-AI/venv/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "# Train the base model\n",
    "def train_base_model(X_train, y_train):\n",
    "    model = LogisticRegression(C=10.001, max_iter=13, random_state=400)\n",
    "    # model = LogisticRegression(**best_params)\n",
    "    model.fit(X_train, y_train)\n",
    "    return model\n",
    "\n",
    "base_model = train_base_model(X_train_scaled_df, y_train)\n",
    "base_pred = base_model.predict(X_test_scaled_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5f4bde23-82a5-4aa6-8e66-a21cb7dbb11b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7266666666666667\n"
     ]
    }
   ],
   "source": [
    "print(accuracy_score(y_test, base_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "592177b8-82c2-4fa2-9882-d7ce7ae9a40e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': np.float64(0.0013101258887574746),\n",
       " 'class_weight': None,\n",
       " 'l1_ratio': np.float64(0.8745607282441125),\n",
       " 'max_iter': 990,\n",
       " 'penalty': 'l2',\n",
       " 'solver': 'sag',\n",
       " 'tol': np.float64(0.005361746692294743)}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "087401a5-0018-4860-a224-9966a4e1b198",
   "metadata": {},
   "source": [
    "**Functions to copy and modify coefficients of a model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "9296776c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This function works for binary classification only\n",
    "def modify_coefficients(model, modification_factor=0.2):\n",
    "    modified_model = copy.deepcopy(model)\n",
    "    \n",
    "    # Get the coefficients and intercept\n",
    "    coef = modified_model.coef_[0]\n",
    "    intercept = modified_model.intercept_[0]\n",
    "    \n",
    "    # Modify coefficients\n",
    "    noise = np.random.normal(0, modification_factor, size=coef.shape)\n",
    "    modified_coef = coef * (1 + noise) # This helps scale the noise to the coefficients\n",
    "\n",
    "    # Print Modifications and Noise\n",
    "    print(f\"Modification Factor: {modification_factor}\", f\"Noise: {noise}\")\n",
    "    \n",
    "    # Modify intercept\n",
    "    modified_intercept = intercept * (1 + np.random.normal(0, modification_factor))\n",
    "    \n",
    "    # Set the modified coefficients and intercept\n",
    "    modified_model.coef_[0] = modified_coef\n",
    "    modified_model.intercept_[0] = modified_intercept\n",
    "    \n",
    "    return modified_model\n",
    "    \n",
    "# Function to compute the differences\n",
    "def compute_difference(pred1, pred2):\n",
    "    \"\"\"Return the fraction of positions at which two prediction sequences differ.\n",
    "\n",
    "    Both inputs are converted with ``np.asarray`` first so that plain Python\n",
    "    lists/tuples are compared element by element; ``list != list`` would\n",
    "    otherwise collapse to a single boolean and yield only 0.0 or 1.0.\n",
    "    \"\"\"\n",
    "    return np.mean(np.asarray(pred1) != np.asarray(pred2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb632903-7860-4b55-9fb8-b8a185e0ac82",
   "metadata": {},
   "source": [
    "**Generate and Compare Model Variations**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b20a39dd-a769-4559-87f8-6d7bf72cfa35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Modification Factor: 0.1 Noise: [-0.17497655  0.03426804  0.11530358 -0.0252436 ]\n",
      "Modification Factor: 0.2 Noise: [ 0.10284377  0.04423593 -0.21400867 -0.03789917]\n",
      "Modification Factor: 0.3 Noise: [-0.1374081   0.13054905 -0.17507852  0.24505412]\n",
      "Modification Factor: 0.4 Noise: [-0.04176446 -0.21251215  0.41189307 -0.17525425]\n",
      "Modification Factor: 0.5 Noise: [ 0.80949083  0.77080259 -0.12593957 -0.42121787]\n",
      "Modification Factor: 0.6 Noise: [ 0.56224932  0.43860021  0.81693368 -0.19574284]\n",
      "Modification Factor: 0.7 Noise: [ 0.15567973 -1.0102519  -0.52944661  0.57151781]\n",
      "Modification Factor: 0.8 Noise: [-0.36475754  0.95169781 -1.35249346 -1.08511924]\n",
      "Modification Factor: 0.9 Noise: [-0.48999525 -0.60135456  0.00658311 -0.55164486]\n",
      "Modification Factor: 1.0 Noise: [-1.73309562 -0.9833101   0.35750775 -1.6135785 ]\n",
      "Modification Factor: 1.1 Noise: [-1.30681936 -0.60472081 -1.03405078 -0.9107256 ]\n",
      "Modification Factor: 1.2 Noise: [ 0.60937151 -1.03467282  1.49936369 -0.0955335 ]\n",
      "Modification Factor: 1.3 Noise: [-1.14633791  0.02423063  0.30919801  0.01761311]\n",
      "Modification Factor: 1.4 Noise: [-1.46189383  0.85825443  1.0306873   1.43769002]\n",
      "Modification Factor: 1.5 Noise: [-2.76178245  0.54913984 -0.4976657  -1.03382697]\n",
      "Modification Factor: 1.6 Noise: [-0.88114306  1.20072533 -2.09118774  0.92891734]\n",
      "Modification Factor: 1.7 Noise: [ 1.1732065   1.16771311 -2.6633688   1.53845601]\n"
     ]
    }
   ],
   "source": [
    "# Generate Model Variations\n",
    "# Seed NumPy's global RNG so the noise drawn by modify_coefficients is repeatable\n",
    "np.random.seed(100)\n",
    "\n",
    "# Relative noise scales to try: 0.1, 0.2, ..., 1.7\n",
    "modification_factors = [round(0.1 * step, 1) for step in range(1, 18)]\n",
    "\n",
    "# Create variations of the model: one (label, perturbed copy) pair per scale\n",
    "variations = [\n",
    "    (f\"Modified (factor={factor})\", modify_coefficients(base_model, factor))\n",
    "    for factor in modification_factors\n",
    "]\n",
    "\n",
    "# Find the pair of models with the highest difference\n",
    "best_diff, best_pair = 0, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "366a89cf-f11e-463c-ba84-3bbfcabbe0ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comparison with the base model\n",
    "# A variation becomes the new best only if it disagrees with the base model on\n",
    "# 20%-25% of the rows of X_test_scaled_df and beats the best difference so far.\n",
    "for name, model in variations:\n",
    "    diff = compute_difference(base_pred, model.predict(X_test_scaled_df))\n",
    "    print(f\"With Model: {name}, Diff: {diff}\")\n",
    "    if not (0.20 <= diff <= 0.25) or diff <= best_diff:\n",
    "        continue\n",
    "    print(\"Best Model till now is\", name)\n",
    "    best_diff = diff\n",
    "    best_pair = ((\"Base Model\", base_model), (name, model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34016cf3-d5b9-46b2-8fee-29d49da5dccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comparing variations with each other\n",
    "# Predict once per variation up front; the nested loop below would otherwise\n",
    "# re-run predict() on the same model for every pair it appears in.\n",
    "variation_preds = [model.predict(X_test_scaled_df) for _, model in variations]\n",
    "\n",
    "# best_diff / best_pair carry over from the comparison against the base model\n",
    "for i in range(len(variations)):\n",
    "    name1, model1 = variations[i]\n",
    "    for j in range(i + 1, len(variations)):\n",
    "        name2, model2 = variations[j]\n",
    "        diff = compute_difference(variation_preds[i], variation_preds[j])\n",
    "        print(f\"With Models: {name1} vs {name2}, Diff: {diff}\")\n",
    "        # Keep only pairs that disagree on 20%-25% of rows, preferring the largest\n",
    "        if diff > best_diff and 0.20 <= diff <= 0.25:\n",
    "            print(f\"Best Model Pair till now is {name1} vs {name2}\")\n",
    "            best_diff = diff\n",
    "            best_pair = ((name1, model1), (name2, model2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e9b87d-af8e-4cd9-b0df-14dd242c6958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# best_pair is still None when no candidate fell inside the accepted difference window\n",
    "if best_pair is None:\n",
    "    raise ValueError(\"No model pair with a prediction difference in [0.20, 0.25] was found\")\n",
    "\n",
    "(model1_name, model1), (model2_name, model2) = best_pair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "604a50ca-5b68-4263-9d01-961facc8345c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack best_pair with the roles swapped: its first entry becomes model2 and\n",
    "# its second entry becomes model1. Run after the cell above, this assignment wins.\n",
    "# NOTE(review): presumably only one of the two unpacking cells is meant to run - confirm.\n",
    "(model2_name, model2), (model1_name, model1) = best_pair"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ea03138-8060-44ec-80ba-72d49272ec80",
   "metadata": {},
   "source": [
    "**Evaluate Both Models**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e595955-68c6-41b3-8b0e-051de7cf86ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate both models\n",
    "y_pred_1, y_pred_2 = (clf.predict(X_test_scaled_df) for clf in (model1, model2))\n",
    "accuracy1, accuracy2 = (accuracy_score(y_test, pred) for pred in (y_pred_1, y_pred_2))\n",
    "\n",
    "# Print results\n",
    "print(f\"Model 1: {model1_name}\")\n",
    "print(\"Model 1 accuracy:\", accuracy1)\n",
    "print(f\"\\nModel 2: {model2_name}\")\n",
    "print(\"Model 2 accuracy:\", accuracy2)\n",
    "print(\"\\nAccuracy difference:\", abs(accuracy1 - accuracy2))\n",
    "# best_diff is the difference rate recorded when best_pair was selected\n",
    "print(f\"\\nPercentage of different outputs: {best_diff:.2%}\")\n",
    "\n",
    "# Blank separator, then the coefficient arrays of both models\n",
    "print(\"\\n\\n\\n\")\n",
    "print(model1.coef_, model2.coef_, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd8a30d3-da04-4c7f-953f-a6eec2515698",
   "metadata": {},
   "source": [
    "#### Compare Model Boundaries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15fea562-bb74-41a7-a270-87fc65b1fa25",
   "metadata": {},
   "source": [
    "**Plot Model Decision Boundaries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b922f5-4307-452f-90d2-397fc5e969ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to plot decision boundary\n",
    "def plot_decision_boundary(X_test, y_pred, feature_1, feature_2):\n",
    "    \"\"\"Scatter two columns of ``X_test`` coloured by the predictions ``y_pred``.\n",
    "\n",
    "    ``feature_1`` is used for the x-axis and ``feature_2`` for the y-axis. The colour\n",
    "    scale is pinned to [0, 1] and its colorbar ticks are labelled \"No\" and \"Yes\".\n",
    "    The figure is shown immediately; nothing is returned.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    points = ax.scatter(\n",
    "        X_test[feature_1],\n",
    "        X_test[feature_2],\n",
    "        c=y_pred,\n",
    "        cmap='rainbow',\n",
    "        edgecolor='black',\n",
    "        s=20,\n",
    "        vmin=0,\n",
    "        vmax=1,\n",
    "    )\n",
    "\n",
    "    # Axis labels are the column names themselves\n",
    "    ax.set_xlabel(feature_1)\n",
    "    ax.set_ylabel(feature_2)\n",
    "    ax.set_title('Logistic Regression Decision Boundary')\n",
    "\n",
    "    # Colorbar with one tick per predicted class\n",
    "    cbar = fig.colorbar(points, ax=ax)\n",
    "    cbar.set_ticks([0, 1])\n",
    "    cbar.set_ticklabels([\"No\", \"Yes\"])\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d4d70a4-794b-4bc0-8659-d244baa22e7c",
   "metadata": {},
   "source": [
    "**Model 1 Decision Boundary**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de4db779-ea88-4543-944d-46c15b3e31de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions y_pred_1 over the Monetary (x) / Frequency (y) columns\n",
    "plot_decision_boundary(\n",
    "    X_test_scaled_df,\n",
    "    y_pred_1,\n",
    "    feature_1=\"Monetary (c.c. blood)\",\n",
    "    feature_2=\"Frequency (times)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "599fcca1-20b3-4b16-89c4-4fba23c314ce",
   "metadata": {},
   "source": [
    "**Model 2 Decision Boundary**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff87fe6-d656-416c-88a7-47fd6d503a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions y_pred_2 over the Monetary (x) / Frequency (y) columns\n",
    "plot_decision_boundary(\n",
    "    X_test_scaled_df,\n",
    "    y_pred_2,\n",
    "    feature_1=\"Monetary (c.c. blood)\",\n",
    "    feature_2=\"Frequency (times)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3249b8a-a969-4b3c-a2a4-77e0e6f99f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def round_values(data, precision=3):\n",
    "    \"\"\"Recursively round every ``int``/``float`` inside ``data`` to ``precision`` digits.\n",
    "\n",
    "    Lists and dicts are rebuilt with their items/values processed recursively\n",
    "    (dict keys are kept as-is). ``int``/``float`` leaves are passed to ``round``;\n",
    "    anything else is returned unchanged. Booleans are returned unchanged too:\n",
    "    ``bool`` subclasses ``int``, so ``round(True, 3)`` would otherwise silently\n",
    "    turn ``True`` into ``1``.\n",
    "    \"\"\"\n",
    "    # Must be checked before the numeric branch because bool is a subclass of int\n",
    "    if isinstance(data, bool):\n",
    "        return data\n",
    "    if isinstance(data, (int, float)):\n",
    "        return round(data, precision)\n",
    "    if isinstance(data, list):\n",
    "        return [round_values(item, precision) for item in data]\n",
    "    if isinstance(data, dict):\n",
    "        return {key: round_values(value, precision) for key, value in data.items()}\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277cf8a1-0910-4840-a36e-2cbb3cbcdbd5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2efbe7b5-3156-42e3-8867-3435af12d92b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f19681ea-701d-47e6-b783-53de3937da76",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182bc281-2188-4ef1-bfb1-de796f9b4e87",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b20f6fdd-cb92-43b7-b1bf-b32f34fea7ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "43a1d4fa-fb39-46e7-8c14-645e213aa7f3",
   "metadata": {},
   "source": [
    "#### Sample Data Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c67e91-d622-4a73-bedd-f3535c46c5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_data(data, file_name, varname):\n",
    "    \"\"\"Append the text ``<varname> = <data>`` (preceded by a newline) to ``file_name``.\n",
    "\n",
    "    The file is opened in append mode, so existing content is kept and every call\n",
    "    adds another assignment line; ``data`` is rendered with its default formatting.\n",
    "    \"\"\"\n",
    "    with open(file_name, 'a') as out_file:\n",
    "        out_file.write(f\"\\n{varname} = {data}\")\n",
    "\n",
    "# This stays constant for this iPython file\n",
    "FILE_NAME = \"./../samples/blood/level_3.py\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166d67f7-55ae-4e56-8ee5-a45649a95bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the positions at which the two models' predictions differ\n",
    "n_mismatched = (y_pred_1 != y_pred_2).sum()\n",
    "print(f\"Number of mismatched samples: {n_mismatched}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2ba908c-9ccc-43ed-aebc-34c5e3eeac92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per row of X_test_scaled_df: rounded feature values plus both models' labels\n",
    "verb_data = [\n",
    "    {\n",
    "        \"input\": round_values(X_test_scaled_df.iloc[row].to_list()),\n",
    "        \"output\": {\"model1\": int(y_pred_1[row]), \"model2\": int(y_pred_2[row])},\n",
    "    }\n",
    "    for row in range(len(X_test_scaled_df))\n",
    "]\n",
    "\n",
    "print(f\"Number of samples in verb_data: {len(verb_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4051e96-5b23-4745-9954-d55495df2008",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97da7c78-ebf8-4ce7-913f-88dc065fcf6e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b0374cf-250d-4b9c-ad17-b30db0b6d8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append verb_data to FILE_NAME as a `verb_data = ...` assignment\n",
    "varname, data = \"verb_data\", verb_data\n",
    "write_data(data=data, file_name=FILE_NAME, varname=varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53874348-fd8d-402c-86a5-8c1372bc8706",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dada228-0bb5-4ec6-b117-cf5ecd20d768",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d415f065-bda1-47a2-963a-d610530a5332",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f094023f-354c-465c-af91-64085494585d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32b9d1e8-7c1a-4839-8b3f-05eea5e03bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label the rows of X_test_gen_scaled_df with both selected models\n",
    "y_gen_pred_1, y_gen_pred_2 = (\n",
    "    clf.predict(X_test_gen_scaled_df) for clf in (model1, model2)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b11b0a3-deea-4f17-99f7-c992d2b470df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per row of X_test_gen_scaled_df: rounded feature values plus both models' labels\n",
    "gen_data = [\n",
    "    {\n",
    "        \"input\": round_values(X_test_gen_scaled_df.iloc[row].to_list()),\n",
    "        \"output\": {\"model1\": int(y_gen_pred_1[row]), \"model2\": int(y_gen_pred_2[row])},\n",
    "    }\n",
    "    for row in range(len(X_test_gen_scaled_df))\n",
    "]\n",
    "\n",
    "print(f\"Number of samples in gen_data: {len(gen_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64268a33-c206-4002-b456-f9ed0d5651dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append gen_data to FILE_NAME as a `gen_data = ...` assignment\n",
    "varname, data = \"gen_data\", gen_data\n",
    "write_data(data=data, file_name=FILE_NAME, varname=varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f132e020-5ee2-4c63-94a1-ab02b41d9a52",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569c615f-6aaa-42a0-bd78-57f67a13c8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune_data(gen_data):\n",
    "    \"\"\"Return new records that keep each record's input and only model1's output.\n",
    "\n",
    "    Each result is ``{\"input\": ..., \"output\": {\"model1\": ...}}``. The records in\n",
    "    ``gen_data`` are not modified; the ``input`` object is shared, not copied.\n",
    "    \"\"\"\n",
    "    pruned = []\n",
    "    for record in gen_data:\n",
    "        pruned.append({\n",
    "            \"input\": record[\"input\"],\n",
    "            \"output\": {\"model1\": record[\"output\"][\"model1\"]},\n",
    "        })\n",
    "    return pruned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca15745c-cc52-4855-a227-06aef729d537",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the model1-only records to FILE_NAME as a `gen_data_pruned = ...` assignment\n",
    "varname, data = \"gen_data_pruned\", prune_data(gen_data)\n",
    "write_data(data=data, file_name=FILE_NAME, varname=varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39486abe-a340-480f-a7e4-8449c05d372d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffcfa7d-c895-42f0-bee6-dec69c44e339",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0acfe6a3-a651-4bcf-91a6-1cf038ff43a6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f985bf05-ccf6-4ba9-a6ee-c4a5eb7dc8f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe59780-2e74-4418-9a06-cff330c76e6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b12448-1adf-43a5-bdad-c3b85be3fcc1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e988969-a87e-42f7-8a5e-d161adb0d32a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(a, b):\n",
    "    \"\"\"Print and return the share of records in ``a`` whose model2 output matches ``b``.\n",
    "\n",
    "    Records are compared position by position. A position counts as correct only\n",
    "    when both the ``input`` values and the ``output['model2']`` values are equal.\n",
    "    Positions whose inputs differ are printed as a mismatch and count as incorrect.\n",
    "\n",
    "    Prints the number of correct positions followed by the accuracy, and returns\n",
    "    the accuracy (``correct / len(a)``). An empty ``a`` yields 0.0 instead of\n",
    "    raising ``ZeroDivisionError``.\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    total = len(a)\n",
    "\n",
    "    for i in range(total):\n",
    "        if a[i]['input'] != b[i]['input']:\n",
    "            # Misaligned records are reported and never counted as correct\n",
    "            print(\"Mismatch\")\n",
    "            print(a[i])\n",
    "            print(b[i])\n",
    "            print(\"\\n\\n\\n\")\n",
    "        elif a[i]['output']['model2'] == b[i]['output']['model2']:\n",
    "            correct += 1\n",
    "\n",
    "    # Nothing to score: report 0.0 rather than dividing by zero\n",
    "    accuracy = correct / total if total else 0.0\n",
    "\n",
    "    print(correct)\n",
    "    print(accuracy)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2e5462-92f8-43f3-9f15-c39d7a3c2916",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad485001-fb0e-4ad4-a968-40ce44e0a33c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23cd5e81-71ca-405f-a26e-836a19e3fcab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b0dff0-f837-4af3-a039-2a269a492db5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "432aac87-4b82-4366-8b72-6b1c2951d782",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f05dff-21c8-4b96-93ed-534230524c65",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "473a513b-6dea-497a-b710-5345bfb2a264",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36199077-e835-45d7-9b7b-0a6f6b2133cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prediction_zero(data, label=1):\n",
    "    \"\"\"Return, per model, the indices of records whose output equals ``label``.\n",
    "\n",
    "    The result is a pair ``(indices_for_model1, indices_for_model2)`` built from\n",
    "    ``record[\"output\"][\"model1\"]`` and ``record[\"output\"][\"model2\"]``.\n",
    "\n",
    "    NOTE(review): despite the function name, the default matches outputs equal\n",
    "    to 1 (the value that used to be hard-coded). Pass ``label=0`` to select the\n",
    "    zero predictions - confirm which one the callers intend.\n",
    "    \"\"\"\n",
    "    prediction0_1, prediction0_2 = [], []\n",
    "    for idx, record in enumerate(data):\n",
    "        output = record[\"output\"]\n",
    "        if output[\"model1\"] == label:\n",
    "            prediction0_1.append(idx)\n",
    "        if output[\"model2\"] == label:\n",
    "            prediction0_2.append(idx)\n",
    "    return prediction0_1, prediction0_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae6cf11-dea7-4ad8-b018-1a76ff690285",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model counts of verb_data records selected by prediction_zero\n",
    "verb_idx_1, verb_idx_2 = prediction_zero(verb_data)\n",
    "print(f\"Model 1: {len(verb_idx_1)}\")\n",
    "print(f\"Model 2: {len(verb_idx_2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba93234e-dd91-46b5-8002-081ba7d2810d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model counts of gen_data records selected by prediction_zero\n",
    "gen_idx_1, gen_idx_2 = prediction_zero(gen_data)\n",
    "print(f\"Model 1: {len(gen_idx_1)}\")\n",
    "print(f\"Model 2: {len(gen_idx_2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8339c15-a764-4120-9954-cfc0c0cbfa9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gen_data records whose model1 and model2 outputs differ\n",
    "mismatches = [\n",
    "    sample\n",
    "    for sample in gen_data\n",
    "    if sample[\"output\"][\"model1\"] != sample[\"output\"][\"model2\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92cc5ac-5539-4b9e-af70-c11a63106546",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of records collected in mismatches\n",
    "len(mismatches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad80f5c-c487-4851-af7f-c5b45eab517a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display every record in mismatches (the full list is rendered as cell output)\n",
    "mismatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "067664f2-4f75-4821-b979-6aaeae3e2be8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41aae9d-7dcc-4893-adaf-4e9082bde5ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d04ab0-774f-4f2a-8b11-5a687fcc1ed2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505413c4-859b-4b67-9423-427c46c0e217",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5c4bdc-444e-4511-88b4-7067f0adf595",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab6a327-4ac2-4d0e-a191-682ebdc7d13f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95c94e15-3b72-4b88-8fa7-7dd9e5993159",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11becfd7-0870-4c55-8094-a2f023dc98e5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2ad600e-5826-41c9-8f92-f9a7e5cbaa99",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
