{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T07:41:03.562265Z",
     "start_time": "2024-05-22T07:41:03.559003Z"
    }
   },
   "outputs": [],
   "source": [
    "# Standard Imports\n",
    "import copy\n",
    "import warnings\n",
    "\n",
    "# Third Party Imports\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import RandomizedSearchCV, train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "644fbcb1-a948-4036-8fec-6f18369ec046",
   "metadata": {},
   "source": [
    "#### Data Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc9ca94d-fd7a-46c5-932d-01ca644994c9",
   "metadata": {},
   "source": [
    "**Reading the Blood Transfusion Dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77cb6a91a07ed68d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T07:41:03.731947Z",
     "start_time": "2024-05-22T07:41:03.723709Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Recency (months)</th>\n",
       "      <th>Frequency (times)</th>\n",
       "      <th>Monetary (c.c. blood)</th>\n",
       "      <th>Time (months)</th>\n",
       "      <th>Blood Donated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>50</td>\n",
       "      <td>12500</td>\n",
       "      <td>98</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>3250</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>16</td>\n",
       "      <td>4000</td>\n",
       "      <td>35</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>5000</td>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>6000</td>\n",
       "      <td>77</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Recency (months)  Frequency (times)  Monetary (c.c. blood)  Time (months)  \\\n",
       "0                 2                 50                  12500             98   \n",
       "1                 0                 13                   3250             28   \n",
       "2                 1                 16                   4000             35   \n",
       "3                 2                 20                   5000             45   \n",
       "4                 1                 24                   6000             77   \n",
       "\n",
       "   Blood Donated  \n",
       "0              1  \n",
       "1              1  \n",
       "2              1  \n",
       "3              1  \n",
       "4              0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the blood transfusion dataset (path is relative to this notebook)\n",
    "blood_csv_path = \"./../../../datasets/blood/blood.data\"\n",
    "blood = pd.read_csv(blood_csv_path)\n",
    "blood.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317ecda9-87df-40b3-918d-19b7ccd05777",
   "metadata": {},
   "source": [
    "**Print Info and Missing Values**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b167864-ced8-46fd-89d8-74266115ab90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 748 entries, 0 to 747\n",
      "Data columns (total 5 columns):\n",
      " #   Column                 Non-Null Count  Dtype\n",
      "---  ------                 --------------  -----\n",
      " 0   Recency (months)       748 non-null    int64\n",
      " 1   Frequency (times)      748 non-null    int64\n",
      " 2   Monetary (c.c. blood)  748 non-null    int64\n",
      " 3   Time (months)          748 non-null    int64\n",
      " 4   Blood Donated          748 non-null    int64\n",
      "dtypes: int64(5)\n",
      "memory usage: 29.3 KB\n",
      "None\n",
      "\n",
      "\n",
      "\n",
      "Missing values:  False\n"
     ]
    }
   ],
   "source": [
    "# DataFrame.info() prints its report itself and returns None,\n",
    "# so wrapping it in print() only adds a stray \"None\" line.\n",
    "blood.info()\n",
    "print(\"\\n\\n\")\n",
    "print(\"Missing values: \", blood.isnull().values.any())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d921143a-172f-4f29-80c6-5a877643d067",
   "metadata": {},
   "source": [
    "**Variance Check**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "948672e3-370b-466e-97ac-be927c47d05f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Recency (months)              65.535\n",
       "Frequency (times)             34.098\n",
       "Monetary (c.c. blood)    2131094.230\n",
       "Time (months)                594.224\n",
       "Blood Donated                  0.182\n",
       "dtype: float64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column variance of the full dataset, rounded for display\n",
    "blood_variance = blood.var()\n",
    "blood_variance.round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0aac7df9-f6fe-43ef-96a9-b10fa43a1ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Disabled) Rescale Monetary (c.c. blood) by 1/100 to reduce its variance and range; this is a linear rescale, not a log transformation\n",
    "# blood['Monetary (c.c. blood)'] = blood['Monetary (c.c. blood)'] / 100\n",
    "# blood.var().round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a9aaeed7-a7e6-4fac-83fb-213f6b0c3774",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Recency (months)</th>\n",
       "      <th>Frequency (times)</th>\n",
       "      <th>Monetary (c.c. blood)</th>\n",
       "      <th>Time (months)</th>\n",
       "      <th>Blood Donated</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>50</td>\n",
       "      <td>12500</td>\n",
       "      <td>98</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>3250</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>16</td>\n",
       "      <td>4000</td>\n",
       "      <td>35</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>5000</td>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>6000</td>\n",
       "      <td>77</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Recency (months)  Frequency (times)  Monetary (c.c. blood)  Time (months)  \\\n",
       "0                 2                 50                  12500             98   \n",
       "1                 0                 13                   3250             28   \n",
       "2                 1                 16                   4000             35   \n",
       "3                 2                 20                   5000             45   \n",
       "4                 1                 24                   6000             77   \n",
       "\n",
       "   Blood Donated  \n",
       "0              1  \n",
       "1              1  \n",
       "2              1  \n",
       "3              1  \n",
       "4              0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows again (the rescaling cell above is commented out, so they are unchanged)\n",
    "blood.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06776d06-aac3-4cd8-b1e8-b8498381d713",
   "metadata": {},
   "source": [
    "**Checking the distribution of target values**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "339d9bbe-8754-4572-87e1-3365bed7af5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Blood Donated\n",
       "0    0.762\n",
       "1    0.238\n",
       "Name: proportion, dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class balance of the target: share of rows per label\n",
    "target_share = blood.loc[:, \"Blood Donated\"].value_counts(normalize=True)\n",
    "target_share.round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ce86e3c-d237-459e-805b-83a02b971354",
   "metadata": {},
   "source": [
    "**Train-Test Split**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "852899ac-5378-4914-a9e7-c206ded4d0de",
   "metadata": {},
   "source": [
    "Train - Test Split &nbsp;&nbsp;&nbsp;&nbsp; 75%-25%\n",
    "\n",
    "The test set is further split into the Verb and Gen sets\n",
    "\n",
    "Verb - Gen Split   &nbsp;&nbsp;&nbsp;&nbsp; 50%-50%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bf7ff54f-f3f4-40e6-8bc4-424d99ae8af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stratified 75% / 25% train-test split on the target column\n",
    "target = blood[\"Blood Donated\"]\n",
    "features = blood.drop(columns=\"Blood Donated\")\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    features, target, test_size=0.25, random_state=400, stratify=target\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5cc4deea-e333-4ce0-890d-ad1078aac260",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Halve the held-out set into the Verb and Gen sets, again stratified on the target\n",
    "X_test_verb, X_test_gen, y_test_verb, y_test_gen = train_test_split(\n",
    "    X_test,\n",
    "    y_test,\n",
    "    test_size=0.5,\n",
    "    random_state=400,\n",
    "    stratify=y_test,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bda17883-3302-4fc9-8b3b-46c6c669a070",
   "metadata": {},
   "source": [
    "**Checking the distribution of target values after the split**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f79613d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class proportions per split (train, verb, gen); stratification keeps them close\n",
    "for split_labels in (y_train, y_test_verb, y_test_gen):\n",
    "    print(split_labels.value_counts(normalize=True).round(3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccf6c676-335d-4133-afa2-5161bd932f2a",
   "metadata": {},
   "source": [
    "**Variance Check**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf865bb-067e-40f8-bfb8-03afba7c71a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature variances on the training split only\n",
    "train_variance = X_train.var()\n",
    "train_variance.round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a7a2750-30fc-45f5-97d9-9bd0863cab60",
   "metadata": {},
   "source": [
    "**Standardization** - Skipped, because standardized features become long, high-precision floats (which may not work well for the LLM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebbccb2f-d31e-451c-83a4-ba3b35025796",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scaler = StandardScaler()\n",
    "# X_train_scaled = scaler.fit_transform(X_train)\n",
    "# X_test_gen_scaled = scaler.transform(X_test_gen)\n",
    "# X_test_verb_scaled = scaler.transform(X_test_verb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99cb1c0e-8239-4489-a757-51307bfe7e8a",
   "metadata": {},
   "source": [
    "**Converting them back to dataframes** - No need for this now as there's no standardization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "563a7aff-a958-42bf-a740-bfe04b397aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_train.columns)\n",
    "# X_test_gen_scaled_df = pd.DataFrame(X_test_gen_scaled, columns=X_test.columns)\n",
    "# X_test_verb_scaled_df = pd.DataFrame(X_test_verb_scaled, columns=X_test.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53928be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-set variances again (unchanged, since standardization is skipped)\n",
    "train_variance = X_train.var()\n",
    "train_variance.round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fe3ae9-5bdc-488c-9339-af497d6999ee",
   "metadata": {},
   "source": [
    "From here onwards, `X_test_verb` is referred to as `X_test_scaled_df` and `y_test_verb` as `y_test` (no scaling is actually applied). `X_test_gen` is only used later, in the Sample Data Creation section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61841898",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No scaling is applied; the *_scaled_df names are kept so later cells work unchanged.\n",
    "# The Verb subset becomes the evaluation set, and y_test is overwritten with its labels.\n",
    "X_train_scaled_df, X_test_scaled_df, y_test = X_train, X_test_verb, y_test_verb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a87a0a77-0623-4f63-b3ca-ea30e19fde46",
   "metadata": {},
   "source": [
    "#### Model Variations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6bcbf72-6031-44ae-84b6-5fb629c78e1f",
   "metadata": {},
   "source": [
    "**Helper Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61462da2-b794-495a-b4b7-6ead0a222549",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_decision_tree(decision_tree, feature_names, class_names=None):\n",
    "    \"\"\"Draw a fitted decision tree on a wide 30x10 figure with filled, rounded nodes.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(30, 10))\n",
    "    plot_tree(\n",
    "        decision_tree,\n",
    "        feature_names=feature_names,\n",
    "        class_names=class_names,\n",
    "        filled=True,\n",
    "        rounded=True,\n",
    "        fontsize=12,\n",
    "        ax=ax,\n",
    "    )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98c6a1a3",
   "metadata": {},
   "source": [
    "**Base Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bfb42b7-65e2-4992-b68f-ccf0bf452c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomized hyperparameter search for the base tree (5-fold CV, accuracy).\n",
    "# RandomizedSearchCV is imported in the imports cell at the top of the notebook.\n",
    "SEARCH_SEED = 400\n",
    "np.random.seed(SEARCH_SEED)\n",
    "\n",
    "# Candidate values sampled by the search\n",
    "param_distributions = {\n",
    "    'max_depth': [None] + list(range(1, 31)),\n",
    "    'min_samples_split': range(2, 21),\n",
    "    'min_samples_leaf': range(1, 21),\n",
    "    'criterion': ['gini'],\n",
    "    'splitter': ['best', 'random'],\n",
    "    'min_weight_fraction_leaf': np.linspace(0, 0.5, 100),  # 100 evenly spaced values in [0, 0.5]\n",
    "    'max_leaf_nodes': [None] + list(range(5, 25)),\n",
    "}\n",
    "\n",
    "# NOTE(review): the tree has no random_state, so with n_jobs=-1 the CV fits may run in\n",
    "# worker processes that do not share the np.random.seed above; scores for\n",
    "# splitter='random' candidates may then vary between runs. Confirm before relying on them.\n",
    "dt = DecisionTreeClassifier()\n",
    "\n",
    "random_search = RandomizedSearchCV(\n",
    "    estimator=dt,\n",
    "    param_distributions=param_distributions,\n",
    "    n_iter=10,\n",
    "    cv=5,\n",
    "    n_jobs=-1,\n",
    "    scoring='accuracy',\n",
    "    random_state=SEARCH_SEED,  # makes the sampled candidate set reproducible\n",
    ")\n",
    "random_search.fit(X_train_scaled_df, y_train)\n",
    "\n",
    "# Best hyperparameters; used by train_base_model in the next cell\n",
    "best_params = random_search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c1d7336",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_base_model(X_train, y_train, params=None):\n",
    "    \"\"\"Fit and return a DecisionTreeClassifier built from ``params``.\n",
    "\n",
    "    ``params`` defaults to the notebook-level ``best_params`` from the randomized\n",
    "    search; pass it explicitly to avoid depending on hidden notebook state.\n",
    "    \"\"\"\n",
    "    if params is None:\n",
    "        params = best_params\n",
    "    best_dt = DecisionTreeClassifier(**params)\n",
    "    best_dt.fit(X_train, y_train)\n",
    "    return best_dt\n",
    "\n",
    "\n",
    "base_model = train_base_model(X_train_scaled_df, y_train, params=best_params)\n",
    "base_pred = base_model.predict(X_test_scaled_df)\n",
    "accuracy_score(y_test, base_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53505624-8563-43ec-90e6-24ca812ac563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters chosen by the randomized search\n",
    "dict(best_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "087401a5-0018-4860-a224-9966a4e1b198",
   "metadata": {},
   "source": [
    "**Functions to generate random variations of decision tree models**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9296776c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_hyperparameters():\n",
    "    hyperparameters = {\n",
    "        'max_depth': np.random.randint(1, 21),\n",
    "        'min_samples_split': np.random.randint(2, 21),\n",
    "        'min_samples_leaf': np.random.randint(1, 21),\n",
    "        'max_features': np.random.choice([None, 'sqrt', 'log2']),\n",
    "        'criterion': np.random.choice(['gini', 'entropy']),\n",
    "        'splitter': np.random.choice(['best', 'random']),\n",
    "        'min_weight_fraction_leaf': np.random.uniform(0.0, 0.5),\n",
    "        'max_leaf_nodes': np.random.choice([None, *range(2, 26)])\n",
    "    }\n",
    "    return hyperparameters\n",
    "\n",
    "def add_noise_to_thresholds(tree, modification_factor=0.2):\n",
    "    # Copy the tree and generate random noise\n",
    "    modified_tree = copy.deepcopy(tree)\n",
    "    thresholds = modified_tree.tree_.threshold\n",
    "    noise = np.random.normal(0, modification_factor, size=thresholds.shape)\n",
    "    \n",
    "    # generate new thresholds\n",
    "    new_thresholds = thresholds * (1 + noise)\n",
    "\n",
    "    # Have to use slicing because threshold is not writeable\n",
    "    modified_tree.tree_.threshold[:] = new_thresholds\n",
    "    return modified_tree\n",
    "\n",
    "def compare_models(base_model, modified_model, X):\n",
    "    base_predictions = base_model.predict(X)\n",
    "    modified_predictions = modified_model.predict(X)\n",
    "    mismatch_percentage = np.mean(base_predictions != modified_predictions)\n",
    "    return mismatch_percentage\n",
    "\n",
    "def compute_diff(y_pred_1, y_pred_2):\n",
    "    return np.mean(y_pred_1 != y_pred_2)\n",
    "\n",
    "def modify_decision_tree(base_model, X_train, y_train, X_test, modification_factors,\n",
    "                         n_trials=None, target_range=(0.25, 0.30)):\n",
    "    \"\"\"Search for a noisy decision tree that disagrees with ``base_model`` by a target amount.\n",
    "\n",
    "    For each of ``n_trials`` random hyperparameter sets (default: ``len(modification_factors)``),\n",
    "    a fresh tree is fitted for every factor in ``modification_factors``, its thresholds are\n",
    "    perturbed with ``add_noise_to_thresholds`` and its prediction mismatch against\n",
    "    ``base_model`` on ``X_test`` is measured with ``compare_models``.\n",
    "\n",
    "    Returns the candidate with the largest mismatch whose value rounded to 2 decimals lies\n",
    "    inside ``target_range`` (inclusive), or None (with a warning) if none qualifies.\n",
    "    \"\"\"\n",
    "    low, high = target_range\n",
    "    if n_trials is None:\n",
    "        n_trials = len(modification_factors)\n",
    "\n",
    "    best_mismatch_percentage = -1\n",
    "    best_modified_model = None\n",
    "\n",
    "    for _ in range(n_trials):\n",
    "        hyperparameters = generate_random_hyperparameters()\n",
    "        for modification_factor in modification_factors:\n",
    "            # Refit per factor: with a random splitter or max_features, fits can differ\n",
    "            random_model = DecisionTreeClassifier(**hyperparameters)\n",
    "            random_model.fit(X_train, y_train)\n",
    "\n",
    "            modified_model = add_noise_to_thresholds(random_model, modification_factor)\n",
    "            mismatch_percentage = compare_models(base_model, modified_model, X_test)\n",
    "            print(f\"Modification factor: {modification_factor}, Mismatch percentage: {mismatch_percentage:.2f}\")\n",
    "\n",
    "            in_target_range = low <= round(mismatch_percentage, 2) <= high\n",
    "            if in_target_range and mismatch_percentage > best_mismatch_percentage:\n",
    "                best_mismatch_percentage = mismatch_percentage\n",
    "                print(f\"Best Model Found with {best_mismatch_percentage}\")\n",
    "                best_modified_model = modified_model\n",
    "\n",
    "    if best_modified_model is None:\n",
    "        warnings.warn(f\"No modified tree had a mismatch rate within {target_range}; returning None.\")\n",
    "    return best_modified_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb632903-7860-4b55-9fb8-b8a185e0ac82",
   "metadata": {},
   "source": [
    "**Generate and Compare Model Variations**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b0870f5-53fa-445f-8f91-e25e14e9e89c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw threshold array of the base tree (leaf entries hold sklearn's -2 placeholder)\n",
    "base_thresholds = base_model.tree_.threshold\n",
    "base_thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e30765c-4a5c-4d95-93a1-90b51ff03c8e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d408e6-b010-4d78-abfa-55944b10da72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global NumPy seed for the model-variation search below.\n",
    "# Other seeds tried previously: 138, 121, 123, 400, 125.\n",
    "VARIATION_SEED = 153\n",
    "np.random.seed(VARIATION_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078b33f9-aeb2-443e-975e-9ff5e16b0bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise scales to try: 0.1-0.9 in steps of 0.1, then the integers 1-9\n",
    "modification_factors = [\n",
    "    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,\n",
    "    1, 2, 3, 4, 5, 6, 7, 8, 9,\n",
    "]\n",
    "modified_tree = modify_decision_tree(\n",
    "    base_model,\n",
    "    X_train,\n",
    "    y_train,\n",
    "    X_test_scaled_df,\n",
    "    modification_factors,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1cdef7c-4522-43ad-959c-ee9404a503e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base tree, with class labels shown in each node\n",
    "plot_decision_tree(base_model, feature_names=X_train.columns, class_names=[\"0\", \"1\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d906dd12-16f9-4650-a3a2-22dfdb13ab3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perturbed tree, labelled the same way as the base tree for side-by-side comparison\n",
    "plot_decision_tree(modified_tree, X_train.columns, class_names=[\"0\", \"1\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab694fe8-8735-48d3-97c0-2ce64df963c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model naming used from here on (and in the exported samples/structures):\n",
    "#   model1 -> perturbed tree, model2 -> base tree\n",
    "model1 = modified_tree\n",
    "model2 = base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15fa4c56-456e-4865-bfb2-3950a1722062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Effective mapping: model1 is the perturbed tree, model2 the base tree\n",
    "model1, model2 = modified_tree, base_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ea03138-8060-44ec-80ba-72d49272ec80",
   "metadata": {},
   "source": [
    "**Evaluate Both Models**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e595955-68c6-41b3-8b0e-051de7cf86ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions of both models on the evaluation (Verb) set\n",
    "y_pred_1, y_pred_2 = (model.predict(X_test_scaled_df) for model in (model1, model2))\n",
    "accuracy1, accuracy2 = (accuracy_score(y_test, preds) for preds in (y_pred_1, y_pred_2))\n",
    "\n",
    "# Report accuracies, their gap, and how often the two models disagree\n",
    "print(\"Model 1 accuracy:\", accuracy1)\n",
    "print(\"Model 2 accuracy:\", accuracy2)\n",
    "print(\"\\nAccuracy difference:\", abs(accuracy1 - accuracy2))\n",
    "disagreement_rate = compute_diff(y_pred_1, y_pred_2)\n",
    "print(f\"\\nPercentage of different outputs: {disagreement_rate:.2%}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae007ade-1c68-4393-bc5c-eccf0668a368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters of model1\n",
    "model1.get_params(deep=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64ca79ff-3ef4-42b4-8c99-deb3718bc3ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters of model2\n",
    "model2.get_params(deep=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd8a30d3-da04-4c7f-953f-a6eec2515698",
   "metadata": {},
   "source": [
    "#### Compare Model Boundaries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15fea562-bb74-41a7-a270-87fc65b1fa25",
   "metadata": {},
   "source": [
    "**Plot Model Decision Boundaries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b922f5-4307-452f-90d2-397fc5e969ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_decision_boundary(X_test, y_pred, feature_1, feature_2, title='Decision Tree Predictions'):\n",
    "    \"\"\"Scatter the rows of X_test on two features, coloured by predicted class.\n",
    "\n",
    "    Shows predictions at the test points (not a dense decision surface).\n",
    "    The colour bar maps class 0 to \"No\" and class 1 to \"Yes\".\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    points = ax.scatter(X_test[feature_1], X_test[feature_2], c=y_pred, cmap='rainbow', edgecolor='black', s=20)\n",
    "\n",
    "    ax.set_xlabel(feature_1)\n",
    "    ax.set_ylabel(feature_2)\n",
    "    # Title is a parameter; the models plotted here are decision trees, not logistic regression\n",
    "    ax.set_title(title)\n",
    "\n",
    "    cbar = fig.colorbar(points, ax=ax)\n",
    "    cbar.set_ticks([0, 1])\n",
    "    cbar.set_ticklabels([\"No\", \"Yes\"])\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d4d70a4-794b-4bc0-8659-d244baa22e7c",
   "metadata": {},
   "source": [
    "**Model 1 Decision Boundary**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de4db779-ea88-4543-944d-46c15b3e31de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model 1 predictions on the evaluation set: Monetary vs Frequency\n",
    "plot_decision_boundary(\n",
    "    X_test_scaled_df,\n",
    "    y_pred_1,\n",
    "    feature_1=\"Monetary (c.c. blood)\",\n",
    "    feature_2=\"Frequency (times)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "599fcca1-20b3-4b16-89c4-4fba23c314ce",
   "metadata": {},
   "source": [
    "**Model 2 Decision Boundary**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff87fe6-d656-416c-88a7-47fd6d503a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model 2 predictions on the evaluation set: Monetary vs Frequency\n",
    "plot_decision_boundary(\n",
    "    X_test_scaled_df,\n",
    "    y_pred_2,\n",
    "    feature_1=\"Monetary (c.c. blood)\",\n",
    "    feature_2=\"Frequency (times)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3249b8a-a969-4b3c-a2a4-77e0e6f99f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b20f6fdd-cb92-43b7-b1bf-b32f34fea7ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "43a1d4fa-fb39-46e7-8c14-645e213aa7f3",
   "metadata": {},
   "source": [
    "#### Sample Data Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c67e91-d622-4a73-bedd-f3535c46c5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_data(data, file_name, varname):\n",
    "    datastr = f\"\\n{varname} = {data}\"\n",
    "    \n",
    "    # Write this string to the file\n",
    "    with open(file_name, 'a') as file:\n",
    "        file.write(datastr)\n",
    "\n",
    "# This stays constant for this iPython file\n",
    "FILE_NAME = \"./../samples/blood/level_1.py\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166d67f7-55ae-4e56-8ee5-a45649a95bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many evaluation rows the two models label differently\n",
    "mismatch_count = np.sum(y_pred_1 != y_pred_2)\n",
    "print(f\"Number of mismatched samples: {mismatch_count}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2ba908c-9ccc-43ed-aebc-34c5e3eeac92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per evaluation row: its feature values plus both models' predicted labels\n",
    "verb_data = [\n",
    "    {\n",
    "        \"input\": row,\n",
    "        \"output\": {\"model1\": int(label_1), \"model2\": int(label_2)},\n",
    "    }\n",
    "    for row, label_1, label_2 in zip(X_test_scaled_df.values.tolist(), y_pred_1, y_pred_2)\n",
    "]\n",
    "\n",
    "print(f\"Number of samples in verb_data: {len(verb_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b0374cf-250d-4b9c-ad17-b30db0b6d8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the Verb samples to the samples file\n",
    "varname, data = \"verb_data\", verb_data\n",
    "write_data(data, FILE_NAME, varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc9f635",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9c5769",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32b9d1e8-7c1a-4839-8b3f-05eea5e03bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions of both models on the held-out Gen set\n",
    "y_gen_pred_1, y_gen_pred_2 = model1.predict(X_test_gen), model2.predict(X_test_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b11b0a3-deea-4f17-99f7-c992d2b470df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per Gen row: its feature values plus both models' predicted labels\n",
    "gen_data = [\n",
    "    {\n",
    "        \"input\": row,\n",
    "        \"output\": {\"model1\": int(label_1), \"model2\": int(label_2)},\n",
    "    }\n",
    "    for row, label_1, label_2 in zip(X_test_gen.values.tolist(), y_gen_pred_1, y_gen_pred_2)\n",
    "]\n",
    "\n",
    "print(f\"Number of samples in gen_data: {len(gen_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64268a33-c206-4002-b456-f9ed0d5651dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the Gen samples to the samples file\n",
    "varname, data = \"gen_data\", gen_data\n",
    "write_data(data, FILE_NAME, varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f132e020-5ee2-4c63-94a1-ab02b41d9a52",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe35af28",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569c615f-6aaa-42a0-bd78-57f67a13c8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune_data(gen_data):\n",
    "    return [{\"input\": inst[\"input\"], \"output\": {\"model1\": inst[\"output\"][\"model1\"]}} for inst in gen_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca15745c-cc52-4855-a227-06aef729d537",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gen samples with model2's labels removed\n",
    "varname, data = \"gen_data_pruned\", prune_data(gen_data)\n",
    "write_data(data, FILE_NAME, varname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39486abe-a340-480f-a7e4-8449c05d372d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00da3464",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffcfa7d-c895-42f0-bee6-dec69c44e339",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_structures(structure_text, file_name, varname):\n",
    "    \"\"\"Append ``varname = '''structure_text'''`` to the Python file ``file_name``.\n",
    "\n",
    "    Backslashes, single quotes and carriage returns in ``structure_text`` are\n",
    "    escaped, so the emitted triple-quoted literal evaluates back to the original\n",
    "    text even if it contains ``'''``, ends with a quote, or holds a backslash.\n",
    "    Text without those characters is written unchanged.\n",
    "\n",
    "    The file is opened in append mode, so calling this twice with the same\n",
    "    ``varname`` leaves two assignments (the later one wins on import).\n",
    "    \"\"\"\n",
    "    escaped = (\n",
    "        structure_text\n",
    "        .replace(\"\\\\\", \"\\\\\\\\\")\n",
    "        .replace(\"'\", \"\\\\'\")\n",
    "        .replace(\"\\r\", \"\\\\r\")\n",
    "    )\n",
    "    datastr = f\"\\n{varname} = '''{escaped}'''\"\n",
    "\n",
    "    # utf-8 matches the default source encoding Python uses when importing the file.\n",
    "    with open(file_name, 'a', encoding='utf-8') as file:\n",
    "        file.write(datastr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0acfe6a3-a651-4bcf-91a6-1cf038ff43a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Destination module for the exported tree text. The path is relative to the kernel's\n",
    "# working directory, and write_structures appends to it, so re-running the export\n",
    "# cells adds duplicate assignments (the last one wins on import).\n",
    "STRUCTURES_FILE_NAME = \"./../structures/blood/level_1.py\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe59780-2e74-4418-9a06-cff330c76e6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give export_text a plain list of str: some older scikit-learn releases evaluate\n",
    "# `if feature_names:`, which raises on a pandas Index.\n",
    "feature_names = list(X_train.columns)\n",
    "model1_text = export_text(model1, feature_names=feature_names)\n",
    "model2_text = export_text(model2, feature_names=feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c585b9b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the text dumps of both trees to STRUCTURES_FILE_NAME as `model1` and `model2`.\n",
    "write_structures(model1_text, STRUCTURES_FILE_NAME, \"model1\")\n",
    "write_structures(model2_text, STRUCTURES_FILE_NAME, \"model2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc266e8a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93681121",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b8f7c24",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "854c8d87",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d06e6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e988969-a87e-42f7-8a5e-d161adb0d32a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(a, b, model=\"model2\"):\n",
    "    \"\"\"Print and return the fraction of paired records that agree on ``model``.\n",
    "\n",
    "    ``a`` and ``b`` are parallel lists of records shaped like\n",
    "    ``{\"input\": [...], \"output\": {model_name: label, ...}}``. A pair counts as\n",
    "    correct only when both the inputs and the chosen model's output match; pairs\n",
    "    whose inputs differ are printed as a mismatch and counted as incorrect.\n",
    "\n",
    "    Args:\n",
    "        a: reference records.\n",
    "        b: records to compare against ``a`` (same order and length).\n",
    "        model: key inside ``output`` to compare. Defaults to ``\"model2\"``,\n",
    "            which reproduces the original behaviour.\n",
    "\n",
    "    Returns:\n",
    "        ``correct / len(a)``, or ``float(\"nan\")`` when ``a`` is empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``a`` and ``b`` have different lengths (misaligned data).\n",
    "    \"\"\"\n",
    "    if len(a) != len(b):\n",
    "        raise ValueError(f\"length mismatch: len(a)={len(a)}, len(b)={len(b)}\")\n",
    "\n",
    "    correct = 0\n",
    "    total = len(a)\n",
    "\n",
    "    for rec_a, rec_b in zip(a, b):\n",
    "        if rec_a['input'] == rec_b['input']:\n",
    "            if rec_a['output'][model] == rec_b['output'][model]:\n",
    "                correct += 1\n",
    "        else:\n",
    "            print(\"Mismatch\")\n",
    "            print(rec_a)\n",
    "            print(rec_b)\n",
    "            print(\"\\n\\n\\n\")\n",
    "\n",
    "    # An empty comparison has no defined accuracy; avoid ZeroDivisionError.\n",
    "    accuracy = correct / total if total else float(\"nan\")\n",
    "    print(correct)\n",
    "    print(accuracy)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2e5462-92f8-43f3-9f15-c39d7a3c2916",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad485001-fb0e-4ad4-a968-40ce44e0a33c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23cd5e81-71ca-405f-a26e-836a19e3fcab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36dfd05-2889-4a04-bcf2-44755e582830",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prediction_zero(data):\n",
    "    \"\"\"Return, for each model, the indices of records it labels ``0``.\n",
    "\n",
    "    Args:\n",
    "        data: list of records with ``record[\"output\"][\"model1\"]`` and\n",
    "            ``record[\"output\"][\"model2\"]`` labels.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(model1_zero_indices, model2_zero_indices)`` of ascending index lists.\n",
    "    \"\"\"\n",
    "    def _zero_indices(model_key):\n",
    "        return [idx for idx, record in enumerate(data) if record[\"output\"][model_key] == 0]\n",
    "\n",
    "    return _zero_indices(\"model1\"), _zero_indices(\"model2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346e21df-a19c-4af6-81f2-30017c29bece",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): verb_data is assigned in later cells of this notebook; confirm it is\n",
    "# already defined above this cell, otherwise Restart & Run All raises NameError here.\n",
    "verb_zeros_model1, verb_zeros_model2 = prediction_zero(verb_data)  # call once, not once per model\n",
    "print(f\"Model 1: {len(verb_zeros_model1)}\")\n",
    "print(f\"Model 2: {len(verb_zeros_model2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a443c126-4c38-4bef-b3cf-cf959db8269f",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_zeros_model1, gen_zeros_model2 = prediction_zero(gen_data)  # call once, not once per model\n",
    "print(f\"Model 1: {len(gen_zeros_model1)}\")\n",
    "print(f\"Model 2: {len(gen_zeros_model2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73d4fd9-65ec-48ae-86e5-f1c9f6c279b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d3f2ea6-7d97-4a55-88d9-974b64e8b102",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e182992-fd80-4bac-974d-f2e6ae379660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded snapshot of verb_data / gen_data records. It overwrites any\n",
    "# earlier gen_data, and the following cells reassign both names with different\n",
    "# model outputs, so only the last-executed snapshot survives. Consider distinct names.\n",
    "verb_data = [{'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 98], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 8, 2000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 10, 2500, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 41], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 77], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 4, 1000, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 11, 2750, 28], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 6, 1500, 47], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 15, 3750, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 41], 'output': {'model1': 1, 'model2': 1}}, {'input': [7, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 5, 1250, 60], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 14], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 11, 2750, 23], 
'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 14, 3500, 40], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 9, 2250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 58], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 3, 750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 8, 2000, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 5, 1250, 51], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 22], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 3, 750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 93], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [1, 14, 3500, 95], 'output': {'model1': 1, 'model2': 1}}, {'input': [16, 3, 750, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 13, 3250, 59], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 3, 750, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 76], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 5, 1250, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 45], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 7, 1750, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 
1750, 32], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 4, 1000, 34], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 7, 1750, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [39, 1, 250, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 8, 2000, 50], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 4, 1000, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 48], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [9, 8, 2000, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 8, 2000, 38], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 4, 1000, 57], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [18, 8, 2000, 95], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 10], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [26, 5, 1250, 49], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': 
[7, 9, 2250, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 6, 1500, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [7, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}]\n",
    "gen_data = [{'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 8, 2000, 40], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 2, 500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 5, 1250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 9, 2250, 75], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 45], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 12, 3000, 34], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 3, 750, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 23, 5750, 58], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 3, 750, 24], 'output': {'model1': 0, 'model2': 1}}, {'input': [13, 7, 1750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 15, 3750, 64], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 16, 4000, 70], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [5, 6, 1500, 28], 'output': 
{'model1': 1, 'model2': 0}}, {'input': [23, 3, 750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 75], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 38], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 47], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 39], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 16, 4000, 98], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 5, 1250, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 33], 'output': {'model1': 0, 'model2': 1}}, {'input': [8, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [9, 4, 1000, 65], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 4, 1000, 29], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 5, 1250, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 12, 3000, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 22], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 
'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 6, 1500, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 9, 2250, 52], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 19], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [6, 3, 750, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 2, 500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 6, 1500, 21], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 5, 1250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 4, 1000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 11, 2750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 10, 2500, 28], 'output': {'model1': 1, 'model2': 0}}, {'input': [3, 14, 3500, 35], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 1750, 
82], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b130c3-457b-4549-8fc8-6919003ff04d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): another hard-coded snapshot; verb_data and gen_data were also assigned,\n",
    "# with different model outputs, in the previous cell, and verb_data is reassigned again\n",
    "# in the next cell. Only the last-executed assignment survives.\n",
    "verb_data = [{'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 98], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 10, 2500, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 41], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 77], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 4, 1000, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 11, 2750, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 15, 3750, 71], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 41], 'output': {'model1': 1, 'model2': 1}}, {'input': [7, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 8, 2000, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 5, 1250, 60], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 14], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 4, 1000, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 11, 2750, 23], 
'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 14, 3500, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 9, 2250, 33], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 58], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 3, 750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 72], 'output': {'model1': 0, 'model2': 1}}, {'input': [9, 5, 1250, 51], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 3, 750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 93], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [1, 14, 3500, 95], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 3, 750, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 13, 3250, 59], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 3, 750, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 76], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 5, 1250, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 45], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 7, 1750, 72], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 
1750, 32], 'output': {'model1': 1, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 4, 1000, 34], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 7, 1750, 58], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [39, 1, 250, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 8, 2000, 50], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 4, 1000, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 3, 750, 48], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 8, 2000, 38], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 8, 2000, 38], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 4, 1000, 57], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 5, 1250, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [18, 8, 2000, 95], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 10], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [26, 5, 1250, 49], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': 
[7, 9, 2250, 89], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 6, 1500, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [7, 5, 1250, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}]\n",
    "gen_data = [{'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 8, 2000, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 9, 2250, 75], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 45], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 12, 3000, 34], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 23, 5750, 58], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 7, 1750, 76], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 15, 3750, 64], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 16, 4000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [5, 6, 1500, 28], 'output': 
{'model1': 1, 'model2': 1}}, {'input': [23, 3, 750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 75], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 39], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 16, 4000, 98], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 5, 1250, 28], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 3, 750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [8, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 4, 1000, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 4, 1000, 65], 'output': {'model1': 0, 'model2': 1}}, {'input': [3, 4, 1000, 29], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 12, 3000, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 
'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 71], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 5, 1250, 33], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 6, 1500, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 9, 2250, 52], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 4, 1000, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 19], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [6, 3, 750, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 6, 1500, 21], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 5, 1250, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 4, 1000, 64], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 11, 2750, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 10, 2500, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [3, 14, 3500, 35], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 1750, 
82], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "591868a2-01be-4c54-98d7-ecec9a459545",
   "metadata": {},
   "outputs": [],
   "source": [
    "verb_data = [{'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 98], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 8, 2000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 10, 2500, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 77], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 11, 2750, 28], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 15, 3750, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [7, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 5, 1250, 60], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 11, 2750, 23], 
'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 14, 3500, 40], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 9, 2250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 3, 750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 8, 2000, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 5, 1250, 51], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 3, 750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 93], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [1, 14, 3500, 95], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 3, 750, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 13, 3250, 59], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 3, 750, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 5, 1250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 45], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 7, 1750, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 7, 
1750, 32], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 4, 1000, 34], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 7, 1750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 7, 1750, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [39, 1, 250, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 8, 2000, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 48], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 8, 2000, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 8, 2000, 38], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 4, 1000, 57], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [18, 8, 2000, 95], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 10], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [26, 5, 1250, 49], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': 
[7, 9, 2250, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 6, 1500, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [7, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}]\n",
    "gen_data = [{'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 8, 2000, 40], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 9, 2250, 75], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 45], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 12, 3000, 34], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 3, 750, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 23, 5750, 58], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 7, 1750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 15, 3750, 64], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 16, 4000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [5, 6, 1500, 28], 'output': 
{'model1': 0, 'model2': 1}}, {'input': [23, 3, 750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 75], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 6, 1500, 39], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 16, 4000, 98], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 5, 1250, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [8, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [9, 4, 1000, 65], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 4, 1000, 29], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 12, 3000, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 5, 1250, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 12, 3000, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 
'output': {'model1': 1, 'model2': 0}}, {'input': [11, 6, 1500, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 6, 1500, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 9, 2250, 52], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 19], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [6, 3, 750, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 6, 1500, 21], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 5, 1250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 4, 1000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [12, 11, 2750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 10, 2500, 28], 'output': {'model1': 0, 'model2': 1}}, {'input': [3, 14, 3500, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 7, 1750, 
82], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 1, 'model2': 0}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc3aa08-b881-4383-a2d9-c033c984e150",
   "metadata": {},
   "outputs": [],
   "source": [
    "verb_data = [{'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 98], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 64], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 3, 750, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 10, 2500, 89], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 6, 1500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 77], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 4, 1000, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 11, 2750, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 15, 3750, 71], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 6, 1500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [7, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 28], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 5, 1250, 60], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 11, 2750, 23], 
'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 14, 3500, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 9, 2250, 33], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 58], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 3, 750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 72], 'output': {'model1': 1, 'model2': 0}}, {'input': [9, 5, 1250, 51], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 3, 750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 93], 'output': {'model1': 1, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [1, 14, 3500, 95], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 3, 750, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 39], 'output': {'model1': 1, 'model2': 0}}, {'input': [12, 13, 3250, 59], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 3, 750, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 76], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 5, 1250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 45], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 7, 1750, 72], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 7, 
1750, 32], 'output': {'model1': 1, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 4, 1000, 34], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 28], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 7, 1750, 58], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [39, 1, 250, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 8, 2000, 50], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 4, 1000, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 48], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 8, 2000, 38], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 8, 2000, 38], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 4, 1000, 57], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [18, 8, 2000, 95], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 10], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [26, 5, 1250, 49], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': 
[7, 9, 2250, 89], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 6, 1500, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [7, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}]\n",
    "gen_data = [{'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 8, 2000, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 9, 2250, 75], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 45], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 12, 3000, 34], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 23, 5750, 58], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 3, 750, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 7, 1750, 76], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 8, 2000, 46], 'output': {'model1': 1, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 76], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 15, 3750, 64], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 16, 4000, 70], 'output': {'model1': 1, 'model2': 0}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [5, 6, 1500, 28], 'output': 
{'model1': 0, 'model2': 1}}, {'input': [23, 3, 750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 75], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 47], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 39], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 16, 4000, 98], 'output': {'model1': 1, 'model2': 0}}, {'input': [14, 5, 1250, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [8, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 8, 2000, 41], 'output': {'model1': 1, 'model2': 0}}, {'input': [16, 4, 1000, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 4, 1000, 65], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 4, 1000, 29], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 70], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 12, 3000, 50], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 2, 500, 22], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 
'output': {'model1': 0, 'model2': 0}}, {'input': [11, 6, 1500, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 6, 1500, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 9, 2250, 52], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 19], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [6, 3, 750, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 6, 1500, 21], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 5, 1250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 4, 1000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 11, 2750, 23], 'output': {'model1': 1, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 10, 2500, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [3, 14, 3500, 35], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 1750, 
82], 'output': {'model1': 1, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 0}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6cb9424-ba89-49e1-8ddd-274243038f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "verb_data = [{'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 70], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 12, 3000, 98], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 8, 2000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 3, 750, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 10, 2500, 89], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 41], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 77], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 4, 1000, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 11, 2750, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 6, 1500, 47], 'output': {'model1': 1, 'model2': 1}}, {'input': [23, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 15, 3750, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 41], 'output': {'model1': 1, 'model2': 1}}, {'input': [7, 2, 500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 6, 1500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 5, 1250, 60], 'output': {'model1': 0, 'model2': 0}}, {'input': [21, 2, 500, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 6, 1500, 14], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 4, 1000, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 11, 2750, 23], 
'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 14, 3500, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [11, 9, 2250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 2, 500, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 58], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 1, 250, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 3, 750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 8, 2000, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [9, 5, 1250, 51], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 3, 750, 22], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 3, 750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 93], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [1, 14, 3500, 95], 'output': {'model1': 1, 'model2': 1}}, {'input': [16, 3, 750, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 7, 1750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [12, 13, 3250, 59], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 3, 750, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 7, 1750, 76], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 5, 1250, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 45], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 7, 1750, 72], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 16], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 
1750, 32], 'output': {'model1': 1, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 2, 500, 14], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 4, 1000, 34], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 7, 1750, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 7, 1750, 58], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [39, 1, 250, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 24], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 8, 2000, 50], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 4, 1000, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 3, 750, 48], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [9, 8, 2000, 38], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 8, 2000, 38], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 4, 1000, 57], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 5, 1250, 58], 'output': {'model1': 0, 'model2': 0}}, {'input': [18, 8, 2000, 95], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 10], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [26, 5, 1250, 49], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': 
[7, 9, 2250, 89], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 6, 1500, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [7, 5, 1250, 35], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}]\n",
    "gen_data = [{'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 11], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 8, 2000, 40], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 2, 500, 41], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 5, 1250, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 9, 2250, 75], 'output': {'model1': 1, 'model2': 1}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 6, 1500, 45], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 3, 750, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 12, 3000, 34], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 3, 750, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 4, 1000, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 23, 5750, 58], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 3, 750, 24], 'output': {'model1': 0, 'model2': 1}}, {'input': [13, 7, 1750, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [23, 8, 2000, 46], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 8, 2000, 76], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 15, 3750, 64], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 1, 250, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 16, 4000, 70], 'output': {'model1': 1, 'model2': 1}}, {'input': [21, 3, 750, 38], 'output': {'model1': 0, 'model2': 0}}, {'input': [5, 6, 1500, 28], 'output': 
{'model1': 1, 'model2': 1}}, {'input': [23, 3, 750, 39], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 5, 1250, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 6, 1500, 75], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 3, 750, 38], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 6, 1500, 47], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 6, 1500, 39], 'output': {'model1': 1, 'model2': 1}}, {'input': [14, 16, 4000, 98], 'output': {'model1': 0, 'model2': 0}}, {'input': [14, 5, 1250, 28], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 3, 750, 28], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 33], 'output': {'model1': 0, 'model2': 1}}, {'input': [8, 2, 500, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 8, 2000, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 33], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 4, 1000, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 2, 500, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [9, 4, 1000, 65], 'output': {'model1': 0, 'model2': 1}}, {'input': [3, 4, 1000, 29], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 12, 3000, 70], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 5, 1250, 41], 'output': {'model1': 0, 'model2': 0}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 12, 3000, 50], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 22], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 
'output': {'model1': 0, 'model2': 1}}, {'input': [11, 6, 1500, 71], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 5, 1250, 33], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 6, 1500, 35], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 9, 2250, 52], 'output': {'model1': 1, 'model2': 1}}, {'input': [2, 4, 1000, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 3, 750, 19], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 2, 500, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [23, 1, 250, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [21, 2, 500, 26], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 14], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [6, 3, 750, 26], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 2, 500, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 2, 500, 70], 'output': {'model1': 0, 'model2': 0}}, {'input': [13, 3, 750, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [16, 3, 750, 21], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 2, 500, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [16, 1, 250, 16], 'output': {'model1': 0, 'model2': 0}}, {'input': [3, 6, 1500, 21], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 5, 1250, 23], 'output': {'model1': 0, 'model2': 1}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}, {'input': [14, 4, 1000, 64], 'output': {'model1': 0, 'model2': 0}}, {'input': [4, 2, 500, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [12, 11, 2750, 23], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 4, 1000, 16], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 10, 2500, 28], 'output': {'model1': 1, 'model2': 1}}, {'input': [3, 14, 3500, 35], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 7, 1750, 
82], 'output': {'model1': 1, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [4, 1, 250, 4], 'output': {'model1': 0, 'model2': 1}}, {'input': [11, 1, 250, 11], 'output': {'model1': 0, 'model2': 0}}, {'input': [2, 1, 250, 2], 'output': {'model1': 0, 'model2': 1}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb92919-4141-4e1b-826f-36f0133de00d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788ed8c5-08a6-44e7-81dc-3cddab421823",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5fc898-c4b5-4c1e-b61b-b0b5f413c2be",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea935b96-3e8f-4ecf-8f62-70de40a42956",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6697b78-feb2-47b3-9292-3f36a4803345",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87d16061-8c36-4344-a471-16ac8842266c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea76c10-7c95-4881-866e-c2dbb2af555f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6d52a3-cc51-4ade-a19e-0f5dec85633c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce1d25cc-bf05-490b-83bb-1e7033d43d23",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75eb9019-fd80-40fa-be36-ec1e9a4a04f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
