{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# UCI Dataset: Log-Linear Model vs. Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import sys\n",
    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))\n",
    "\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader, Subset\n",
    "\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import cupy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from ucimlrepo import fetch_ucirepo\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "import ld\n",
    "from utlis import vectorize_tensor, reconstruct_tensor\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_random_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "\n",
    "    cp.random.seed(seed)\n",
    "\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "set_random_seed(2)\n",
    "k = 10\n",
    "bandwidth = 0.1\n",
    "bandwidth_AE = 0.1\n",
    "eps = np.asarray(1.0e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "UCI_id = 186 # https://archive.ics.uci.edu/dataset/186/wine-quality\n",
    "test_size = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of features: (6497, 12)\n",
      "Shape of labels: (6497,)\n",
      "Unique labels: [0 1 2 3 4 5 6]\n",
      "Shape of features: (5197, 12)\n"
     ]
    }
   ],
   "source": [
    "UCI_dataset = fetch_ucirepo(id=UCI_id)\n",
    "\n",
    "X = np.array(UCI_dataset.data.features)\n",
    "Y = np.array(LabelEncoder().fit_transform(UCI_dataset.data.targets))\n",
    "\n",
    "# normalize to [0, 1]\n",
    "X = (X - X.min()) / (X.max() - X.min())\n",
    "\n",
    "# extend features to 12 dimensions\n",
    "X = np.pad(X, ((0, 0), (0, 12 - X.shape[1])), mode=\"constant\")\n",
    "\n",
    "\n",
    "# Print the shape of features and labels\n",
    "print(\"Shape of features:\", X.shape)\n",
    "print(\"Shape of labels:\", Y.shape)\n",
    "\n",
    "# Find unique labels\n",
    "unique_labels = np.unique(Y)\n",
    "print(\"Unique labels:\", unique_labels)\n",
    "\n",
    "# Create an array of indices\n",
    "indices = np.arange(len(Y))\n",
    "np.random.shuffle(indices)\n",
    "\n",
    "# Use the shuffled indices to randomly select data for training and testing\n",
    "n_train = int(len(Y) * (1 - test_size))\n",
    "X_train, Y_train = X[indices[:n_train]], Y[indices[:n_train]]\n",
    "X_test, Y_test = X[indices[n_train:]], Y[indices[n_train:]]\n",
    "\n",
    "print(\"Shape of features:\", X_train.shape)\n",
    "\n",
    "X_train_class = []\n",
    "Y_train_class = []\n",
    "for i in unique_labels:\n",
    "    X_train_class.append(X_train[np.isin(Y_train, i).flatten()])\n",
    "    Y_train_class.append(Y_train[np.isin(Y_train, i).flatten()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor Structure of the Feature: (2, 2, 3)\n",
      "Number of new samples per class: 118\n"
     ]
    }
   ],
   "source": [
    "# Feature dimension\n",
    "D = X_train.shape[1]\n",
    "S = (2, 2, 3)\n",
    "print(\"Tensor Structure of the Feature:\", S)\n",
    "\n",
    "if np.prod(S) != D:\n",
    "    raise ValueError(\"The product of the tensor structure is not equal to the feature dimension\")\n",
    "\n",
    "# 20% of the data of the training set\n",
    "num_new_samples = int((1 - test_size) * len(Y_train) * 0.2 // len(unique_labels))\n",
    "print(\"Number of new samples per class:\", num_new_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a custom Dataset class\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, data, labels, transform=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            data (numpy array): NumPy array of shape (num, 28*28)\n",
    "            labels (numpy array): Corresponding labels for each image\n",
    "            transform (callable, optional): Optional transform to be applied on a sample.\n",
    "        \"\"\"\n",
    "        self.data = data\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Get the vector and label for a given index\n",
    "        vector = self.data[idx].astype(np.float32)  # No reshaping, as data is general vector signal\n",
    "        label = self.labels[idx]\n",
    "\n",
    "        if self.transform:\n",
    "            vector = self.transform(vector)\n",
    "\n",
    "        return vector, label\n",
    "\n",
    "train_data_original = np.array(X_train)\n",
    "labels = Y_train\n",
    "custom_train_dataset = CustomDataset(train_data_original, labels)\n",
    "train_loader_original = DataLoader(dataset=custom_train_dataset, batch_size=16, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Augmentation with Log-Linear Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Legendre Decomposition (Many-Body Approximation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 3)\n"
     ]
    }
   ],
   "source": [
    "B_LD = ld.default_B(S, 2, cp.get_array_module(X_train[0]))\n",
    "\n",
    "print(B_LD.shape)\n",
    "\n",
    "scaleX_class = []\n",
    "theta_class = []\n",
    "\n",
    "def LD_helper(i, class_):\n",
    "    _, _, scaleX, _, theta = ld.LD(X_train_class[class_][i].reshape(*S), B=B_LD, verbose=False, n_iter=1000, lr=1e-1)\n",
    "    return (scaleX, theta)\n",
    "\n",
    "results = Parallel(n_jobs=30)(delayed(LD_helper)(i, class_) for class_ in unique_labels for i in range(len(X_train_class[class_])))\n",
    "\n",
    "len_class = 0\n",
    "for class_ in unique_labels:\n",
    "    scaleX_list = []\n",
    "    theta_list = []\n",
    "\n",
    "    for i in range(len(X_train_class[class_])):\n",
    "        result = results[i + len_class]\n",
    "\n",
    "        scaleX_list.append(result[0])\n",
    "        theta_list.append(result[1])\n",
    "\n",
    "    len_class += len(X_train_class[class_])\n",
    "\n",
    "    scaleX_class.append(scaleX_list)\n",
    "    theta_class.append(theta_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fitting and Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_theta_class = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    reduced_theta = vectorize_tensor(np.array(theta_class[class_]), B_LD)\n",
    "\n",
    "    # Fit a KDE to the theta values\n",
    "    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(reduced_theta)\n",
    "    # Sample new data from the KDE\n",
    "    sampled_reduced_theta = kde.sample(n_samples=num_new_samples)\n",
    "\n",
    "    sampled_theta = reconstruct_tensor(sampled_reduced_theta, (num_new_samples, *S), B_LD)\n",
    "\n",
    "    sampled_theta_class.append(sampled_theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Construct Submanifold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 3)\n"
     ]
    }
   ],
   "source": [
    "# Construct the constrained coordinates\n",
    "B_BP = ld.default_B(S, 1, cp.get_array_module(X_train[0]))\n",
    "# B_BP = B_LD\n",
    "\n",
    "print(B_BP.shape)\n",
    "\n",
    "# Compute every datapoint's eta_hat (served as the linear constraints)\n",
    "eta_hat_class = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    eta_hat_list = []\n",
    "    for i in range(X_train_class[class_].shape[0]):\n",
    "        xp = cp.get_array_module(X_train_class[class_][i])\n",
    "        P = (X_train_class[class_][i].reshape(*S) + eps) / scaleX_class[class_][i]\n",
    "        eta_hat = ld.get_eta(P, len(S), xp)\n",
    "        eta_hat_list.append(eta_hat)\n",
    "    eta_hat_list = cp.asarray(eta_hat_list)\n",
    "\n",
    "    eta_hat_class.append(eta_hat_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Backward Projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n",
      "Warning: Not Converged. Consider increasing the number of iterations.\n"
     ]
    }
   ],
   "source": [
    "def BP_helper(i, class_):\n",
    "    N = ld.kNN(sampled_theta_class[class_][i], np.array(theta_class[class_]), k=k)\n",
    "    avg_scale = np.mean(np.array(scaleX_class[class_])[N])\n",
    "    avg_eta_hat = np.mean(eta_hat_class[class_][N], axis=0)\n",
    "\n",
    "    _, _, P, theta = ld.BP(sampled_theta_class[class_][i], [(X_train_class[class_][j].reshape(*S) + eps) / scaleX_class[class_][j] for j in N], avg_eta_hat, avg_scale, B=B_BP, verbose=False, n_iter=1000, lr=5e-2, exit_abs=True)\n",
    "    X_recons_ = P.reshape(-1)\n",
    "    return (theta, X_recons_)\n",
    "\n",
    "results = Parallel(n_jobs=30)(delayed(BP_helper)(i, class_) for i in range(num_new_samples) for class_ in unique_labels)\n",
    "\n",
    "sampled_theta_BP_class = []\n",
    "X_recons_class = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    sampled_theta_BP = []\n",
    "    sampled_X_recons = []\n",
    "    X_recons_list = []\n",
    "    for i in range(num_new_samples):\n",
    "        result = results[i + num_new_samples * class_]\n",
    "\n",
    "        sampled_theta_BP.append(result[0])\n",
    "        sampled_X_recons.append(result[1])\n",
    "\n",
    "    sampled_theta_BP_class.append(np.array(sampled_theta_BP))\n",
    "    X_recons_class.append(np.array(sampled_X_recons))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Augmented Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmented_data_LD = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    for i in range(num_new_samples):\n",
    "        augmented_data_LD.append(X_recons_class[class_][i])\n",
    "\n",
    "train_data_LD = np.array(augmented_data_LD)\n",
    "labels = np.repeat(unique_labels, num_new_samples)\n",
    "custom_train_dataset = CustomDataset(train_data_LD, labels)\n",
    "train_loader_LD = DataLoader(dataset=custom_train_dataset, batch_size=16, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Augmentation with Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_size=D, hidden_size=D//4, z_dim=B_LD.shape[0]):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        # self.fc2 = nn.Linear(hidden_size , hidden_size2)\n",
    "        self.fc3 = nn.Linear(hidden_size, z_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "    def forward(self , x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        # x = self.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_size=D, hidden_size=D//4, z_dim=B_LD.shape[0]):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(z_dim , hidden_size)\n",
    "        # self.fc2 = nn.Linear(hidden_size , hidden_size)\n",
    "        self.fc3 = nn.Linear(hidden_size, output_size)\n",
    "        self.relu = nn.ReLU()\n",
    "    def forward(self , x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        # x = self.relu(self.fc2(x))\n",
    "        x = torch.sigmoid(self.fc3(x))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = Encoder().to(device)\n",
    "dec = Decoder().to(device)\n",
    "loss_fn = nn.MSELoss()\n",
    "optimizer_enc = torch.optim.Adam(enc.parameters())\n",
    "optimizer_dec = torch.optim.Adam(dec.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [05:47<00:00,  3.47s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f18f1c58310>]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApqklEQVR4nO3df3SU5Z338c9kBoYAk6lIk5lISOMW1mrEreAilJUfldSUpVXaPVathbPneEoFFsrpYoHdp6lPIRz3yNJ9WFllfSgeZcPp8cfaliKxlLAui0aQNUKP4mOUWBOzImTCDycmuZ4/yAwZEjCTzNwXcL1fp/eBue97Zi4uU+bDd64fPmOMEQAAgEdybDcAAAC4hfABAAA8RfgAAACeInwAAABPET4AAICnCB8AAMBThA8AAOApwgcAAPBUwHYDztXZ2akPPvhAoVBIPp/PdnMAAEAfGGPU2tqqwsJC5eRcuLZx0YWPDz74QEVFRbabAQAA+qGhoUGjRo264D0XXfgIhUKSzjQ+Ly/PcmsAAEBfxGIxFRUVJT/HL+SiCx+Jr1ry8vIIHwAAXGL6MmSCAacAAMBThA8AAOApwgcAAPAU4QMAAHiK8AEAADxF+AAAAJ4ifAAAAE8RPgAAgKcIHwAAwFOEDwAA4CnCBwAA8BThAwAAeGpA4aOyslI+n09LlixJnjPGqKKiQoWFhcrNzdW0adN08ODBgbZzwD7t6NSDvzqkiucPKt7eYbs5AAA4q9/ho7a2Vo899pjGjRuXcv6hhx7S2rVrtX79etXW1ioSiWjmzJlqbW0dcGMHwhjp//5nvX6x513F2zuttgUAAJf1K3ycOHFC99xzjzZu3Kgrrrgied4Yo3Xr1mnlypWaM2eOSktLtXnzZp06dUpbtmzJWKP7w59zdovfjg5jsSUAALitX+FjwYIFmjVrlm699daU8/X19WpqalJZWVnyXDAY1NSpU7Vnz55eXysejysWi6Uc2dAte6i9k/ABAIAtgXSfUFVVpf3796u2trbHtaamJklSQUFByvmCggK99957vb5eZWWlfvrTn6bbjLT5fD4Fcnxq7zTqIHwAAGBNWpWPhoYGLV68WE8++aSGDBly3vt8Pl/KY2NMj3MJy5cvV0tLS/JoaGhIp0lpSXz10mEIHwAA2JJW5WPfvn1qbm7W+PHjk+c6Ojq0e/durV+/Xm+++aakMxWQaDSavKe5ublHNSQhGAwqGAz2p+1pC+T4FBdjPgAAsCmtysdXv/pV1dXV6cCBA8ljwoQJuueee3TgwAFdffXVikQiqq6uTj6nra1NNTU1mjx5csYbn65E5aO9k9kuAADYklblIxQKqbS0NOXcsGHDdOWVVybPL1myRKtXr9aYMWM0ZswYrV69WkOHDtXdd9+duVb3U8B/Jmsx5gMAAHvSHnD6WZYtW6bTp0/r/vvv17FjxzRx4kTt2LFDoVAo02+VtrOVD8IHAAC2+Iy5uEZfxmIxhcNhtbS0KC8vL6OvPanyd2ps+US/XjRFpVeFM/raAAC4LJ3Pb6f2dqHyAQCAfU6Fj0Biqi0DTgEAsMap8JGsfDDVFgAAa5wKH4EcZrsAAGCbU+GDMR8AANjnZPig8gEAgD1Ohg8qHwAA2ONU+GC2CwAA9jkVPs5+7WK5IQAAOMyp8BHws7EcAAC2ORU+/Ey1BQDAOqfCR4ABpwAAWOdU+GCqLQAA9jkVPqh8AABgn1PhI1n5YLoLAADWOBU+qHwAAGCfU+GD2S4AANjnVPig8gEAgH1OhQ+/n9kuAADY5lb48FH5AADANrfCBxvLAQBgnVPhgzEfAADY51T4SIz56CR8AABgjVPhg8oHAAD2ORU+WOcDAAD7nAofVD4AALDPqfBxdm8XwgcAALY4FT6ofAAAYJ9T4YN1PgAAsM+p8EHlAwAA+5wKH34/s10AALAtrfCxYcMGjRs3Tnl5ecrLy9OkSZP029/+Nnl93rx58vl8
KcfNN9+c8Ub3F5UPAADsC6Rz86hRo7RmzRp98YtflCRt3rxZ3/zmN/Xaa6/puuuukyTddttt2rRpU/I5gwcPzmBzByaxsRyVDwAA7EkrfMyePTvl8apVq7Rhwwbt3bs3GT6CwaAikUjmWphBfiofAABY1+8xHx0dHaqqqtLJkyc1adKk5Pldu3YpPz9fY8eO1X333afm5uYLvk48HlcsFks5siXgZ7YLAAC2pR0+6urqNHz4cAWDQc2fP1/PPvusrr32WklSeXm5nnrqKe3cuVMPP/ywamtrNWPGDMXj8fO+XmVlpcLhcPIoKirq/5/mM5ydakvlAwAAW3zGmLQ+idva2nTkyBEdP35cTz/9tP71X/9VNTU1yQDSXWNjo4qLi1VVVaU5c+b0+nrxeDwlnMRiMRUVFamlpUV5eXlp/nEubPsbjZr/5H7d9IUr9Mv5kzP62gAAuCwWiykcDvfp8zutMR/SmQGkiQGnEyZMUG1trX7+85/r0Ucf7XFvNBpVcXGxDh8+fN7XCwaDCgaD6TajXxIbyzHmAwAAewa8zocx5rxfqxw9elQNDQ2KRqMDfZuMCPC1CwAA1qVV+VixYoXKy8tVVFSk1tZWVVVVadeuXdq+fbtOnDihiooKfetb31I0GtW7776rFStWaOTIkbrjjjuy1f60JGe7sLEcAADWpBU+PvzwQ917771qbGxUOBzWuHHjtH37ds2cOVOnT59WXV2dnnjiCR0/flzRaFTTp0/X1q1bFQqFstX+tFD5AADAvrTCx+OPP37ea7m5uXrhhRcG3KBsOrvOB1NtAQCwxam9Xc6u80HlAwAAW5wKH8x2AQDAPqfCB2M+AACwz6nwwd4uAADY52T4oPIBAIA9ToaP9g5muwAAYItT4YMxHwAA2OdU+Eh+7ZLeXnoAACCDnAofga6ptlQ+AACwx6nwwWwXAADscyp8JMZ8GCN1EkAAALDCqfDh71peXaL6AQCALU6Fj0TlQ2LcBwAAtjgVPvw53SsfrPUBAIANToWPxGwXicoHAAC2OBU+uhU+GPMBAIAlToUPn8/HKqcAAFjmVPiQpBzW+gAAwCrnwkey8tFB+AAAwAbnwsfZVU6Z7QIAgA3OhQ/GfAAAYJdz4cOf2FyOnW0BALDCufCRqHy0M+YDAAArnAsffr52AQDAKufCR8DPVFsAAGxyLnxQ+QAAwC7nwkeAqbYAAFjlXPhIznah8gEAgBXOhY8Ay6sDAGCVc+HDz/LqAABY5Vz4oPIBAIBdaYWPDRs2aNy4ccrLy1NeXp4mTZqk3/72t8nrxhhVVFSosLBQubm5mjZtmg4ePJjxRg9EDrNdAACwKq3wMWrUKK1Zs0avvvqqXn31Vc2YMUPf/OY3kwHjoYce0tq1a7V+/XrV1tYqEolo5syZam1tzUrj+4PZLgAA2JVW+Jg9e7a+/vWva+zYsRo7dqxWrVql4cOHa+/evTLGaN26dVq5cqXmzJmj0tJSbd68WadOndKWLVuy1f60sc4HAAB29XvMR0dHh6qqqnTy5ElNmjRJ9fX1ampqUllZWfKeYDCoqVOnas+ePed9nXg8rlgslnJkE2M+AACwK+3wUVdXp+HDhysYDGr+/Pl69tlnde2116qpqUmSVFBQkHJ/QUFB8lpvKisrFQ6Hk0dRUVG6TUpLYp2PTsIHAABWpB0+/vRP/1QHDhzQ3r179YMf/EBz587VoUOHktd9Pl/K/caYHue6W758uVpaWpJHQ0NDuk1KC5UPAADsCqT7hMGDB+uLX/yiJGnChAmqra3Vz3/+cz3wwAOSpKamJkWj0eT9zc3NPaoh3QWDQQWDwXSb0W9+P2M+AACwacDrfBhjFI/HVVJSokgkourq6uS1trY21dTUaPLkyQN9m4yh8gEAgF1pVT5WrFih8vJyFRUVqbW1VVVVVdq1a5e2b98un8+nJUuWaPXq1RozZozGjBmj1atXa+jQobr77ruz1f60nZ3twlRbAABsSCt8fPjh
h7r33nvV2NiocDiscePGafv27Zo5c6YkadmyZTp9+rTuv/9+HTt2TBMnTtSOHTsUCoWy0vj+oPIBAIBdaYWPxx9//ILXfT6fKioqVFFRMZA2ZVVyV1v2dgEAwAr2dgEAAJ5yLnywwikAAHY5Gz6ofAAAYIdz4SPAbBcAAKxyLnxQ+QAAwC7nwkeAMR8AAFjlXPhITrUlfAAAYIVz4SPA3i4AAFjlXPhgzAcAAHY5Fz4Y8wEAgF3OhQ8qHwAA2OVc+GCdDwAA7HIufCRmu7SzsRwAAFY4Fz4Y8wEAgF3OhQ/GfAAAYJdz4YN1PgAAsMu58JHjS1Q+GHAKAIANzoUPxnwAAGCXc+GDMR8AANjlXPhgzAcAAHY5Fz7Y1RYAALucCx+M+QAAwC7nwgdjPgAAsMu58EHlAwAAu5wLH2crH6zzAQCADc6Fj0BiwCkbywEAYIVz4YMxHwAA2OVc+GCdDwAA7HIufFD5AADALufCB7NdAACwy7nwwa62AADYlVb4qKys1E033aRQKKT8/HzdfvvtevPNN1PumTdvnnw+X8px8803Z7TRA8GYDwAA7EorfNTU1GjBggXau3evqqur1d7errKyMp08eTLlvttuu02NjY3JY9u2bRlt9EAw5gMAALsC6dy8ffv2lMebNm1Sfn6+9u3bp1tuuSV5PhgMKhKJZKaFGZZY58MYqbPTKKcrjAAAAG8MaMxHS0uLJGnEiBEp53ft2qX8/HyNHTtW9913n5qbm8/7GvF4XLFYLOXIJn+3sNFhqH4AAOC1focPY4yWLl2qKVOmqLS0NHm+vLxcTz31lHbu3KmHH35YtbW1mjFjhuLxeK+vU1lZqXA4nDyKior626Q+CXQPH3z1AgCA53zG9O+f/wsWLNBvfvMbvfTSSxo1atR572tsbFRxcbGqqqo0Z86cHtfj8XhKMInFYioqKlJLS4vy8vL607QL+uTTDl3z92e+Pnrjp1/T8GBa3zwBAIBexGIxhcPhPn1+9+uTd9GiRXr++ee1e/fuCwYPSYpGoyouLtbhw4d7vR4MBhUMBvvTjH5JqXywvwsAAJ5LK3wYY7Ro0SI9++yz2rVrl0pKSj7zOUePHlVDQ4Oi0Wi/G5lJ3cd8sNYHAADeS2vMx4IFC/Tkk09qy5YtCoVCampqUlNTk06fPi1JOnHihH70ox/pv/7rv/Tuu+9q165dmj17tkaOHKk77rgjK3+AdPl8vmQAYcwHAADeS6vysWHDBknStGnTUs5v2rRJ8+bNk9/vV11dnZ544gkdP35c0WhU06dP19atWxUKhTLW6IHy5/jU0WlY6wMAAAvS/trlQnJzc/XCCy8MqEFeCOT41CYqHwAA2ODc3i4Sq5wCAGCT0+GjgwGnAAB4zsnwEaDyAQCANU6Gj+TXLqzzAQCA55wMH4nN5RhwCgCA95wMH8kxH2wsBwCA55wMHwEWGQMAwBonwwdjPgAAsMfp8EHlAwAA7zkZPgL+xFRb1vkAAMBrToYPP7NdAACwxsnwwSJjAADY42T4YMwHAAD2OBk+qHwAAGCPk+GDjeUAALDH6fDBOh8AAHjPyfDBCqcAANjjZPjwM+YDAABrnAwf7GoLAIA9ToYPptoCAGCPk+GDMR8AANjjZPhgzAcAAPY4GT4SG8uxzgcAAN5zMnxQ+QAAwB4nwwezXQAAsMfJ8EHlAwAAe5wMH8x2AQDAHifDB3u7AABgj9Phg9kuAAB4z+nwwZgPAAC852T4YMwHAAD2OBk+/F1Tbal8AADgvbTCR2VlpW666SaFQiHl5+fr9ttv15tvvplyjzFGFRUVKiwsVG5urqZNm6aDBw9mtNEDReUDAAB70gofNTU1WrBggfbu3avq6mq1t7errKxMJ0+eTN7z0EMPae3atVq/fr1qa2sViUQ0c+ZMtba2Zrzx/cWutgAA2BNI5+bt27enPN60aZPy8/O1b98+3XLLLTLGaN26dVq5cqXmzJkjSdq8ebMKCgq0ZcsWff/7
389cywfg7N4uhA8AALw2oDEfLS0tkqQRI0ZIkurr69XU1KSysrLkPcFgUFOnTtWePXt6fY14PK5YLJZyZNvZ2S5MtQUAwGv9Dh/GGC1dulRTpkxRaWmpJKmpqUmSVFBQkHJvQUFB8tq5KisrFQ6Hk0dRUVF/m9RnjPkAAMCefoePhQsX6vXXX9e//du/9bjm8/lSHhtjepxLWL58uVpaWpJHQ0NDf5vUZ8x2AQDAnrTGfCQsWrRIzz//vHbv3q1Ro0Ylz0ciEUlnKiDRaDR5vrm5uUc1JCEYDCoYDPanGf1G5QMAAHvSqnwYY7Rw4UI988wz2rlzp0pKSlKul5SUKBKJqLq6Onmura1NNTU1mjx5cmZanAHs7QIAgD1pVT4WLFigLVu26N///d8VCoWS4zjC4bByc3Pl8/m0ZMkSrV69WmPGjNGYMWO0evVqDR06VHfffXdW/gD9QeUDAAB70gofGzZskCRNmzYt5fymTZs0b948SdKyZct0+vRp3X///Tp27JgmTpyoHTt2KBQKZaTBmcBsFwAA7EkrfBjz2ZUCn8+niooKVVRU9LdNWcciYwAA2OPo3i7sagsAgC1Oho9A11RbKh8AAHjPyfBB5QMAAHucDB/s7QIAgD1Ohg8GnAIAYI+T4YN1PgAAsMfJ8ME6HwAA2ONk+GC2CwAA9jgZPpjtAgCAPU6Gj+SYDzaWAwDAc06GDyofAADY42T4YJ0PAADscTJ8+H3MdgEAwBY3w0fX1y6dRuqk+gEAgKecDB+JqbaS1GEIHwAAeMnJ8OHvGvMhMe4DAACvORk+ElNtJWa8AADgNSfDhz+HygcAALa4GT58hA8AAGxxMnzk5PiUKH4w3RYAAG85GT4kNpcDAMAWZ8NHcol19ncBAMBTzoaP5OZyVD4AAPCUs+EjsdYHU20BAPCWs+GDygcAAHY4Gz6SYz6Y7QIAgKfcDR8+Kh8AANjgbvhgzAcAAFY4Gz5Y5wMAADucDR+s8wEAgB3Ohg9muwAAYEfa4WP37t2aPXu2CgsL5fP59Nxzz6Vcnzdvnnw+X8px8803Z6q9GZOofHQYwgcAAF5KO3ycPHlSN9xwg9avX3/ee2677TY1NjYmj23btg2okdlwtvLBVFsAALwUSPcJ5eXlKi8vv+A9wWBQkUik343yAmM+AACwIytjPnbt2qX8/HyNHTtW9913n5qbm897bzweVywWSzm8wGwXAADsyHj4KC8v11NPPaWdO3fq4YcfVm1trWbMmKF4PN7r/ZWVlQqHw8mjqKgo003q1dkVTgkfAAB4Ke2vXT7LnXfemfx9aWmpJkyYoOLiYv3mN7/RnDlzety/fPlyLV26NPk4Fot5EkACfma7AABgQ8bDx7mi0aiKi4t1+PDhXq8Hg0EFg8FsN6MHKh8AANiR9XU+jh49qoaGBkWj0Wy/VVqY7QIAgB1pVz5OnDiht99+O/m4vr5eBw4c0IgRIzRixAhVVFToW9/6lqLRqN59912tWLFCI0eO1B133JHRhg8UlQ8AAOxIO3y8+uqrmj59evJxYrzG3LlztWHDBtXV1emJJ57Q8ePHFY1GNX36dG3dulWhUChzrc4APyucAgBgRdrhY9q0aTIXWBX0hRdeGFCDvOLvmmrLOh8AAHiLvV2ofAAA4ClnwwdjPgAAsMPZ8MFsFwAA7HA2fJwdcGq5IQAAOMbZ8EHlAwAAO5wNH8nZLoz5AADAU86GD/Z2AQDADmfDB7NdAACww9nwwTofAADY4Wz4OFv5YMApAABecjZ8UPkAAMAOZ8NHTqLywd4uAAB4ytnwQeUDAAA7nA0frPMBAIAdzoYPKh8AANjhbPhgtgsAAHY4Gz4CbCwHAIAVzoYPPxvLAQBghbPhI7G3CwNOAQDwlrPhIzHbhQGnAAB4y9nwEWBjOQAArHA2fPiZagsAgBXOhg8qHwAA2OFs+GC2CwAAdjgbPgKJ5dXZWA4AAE85Gz66sgdjPgAA8Jiz4SPAVFsAAKxw
Nnz4GXAKAIAVzoYPdrUFAMAOZ8MHu9oCAGCHs+EjsbcLu9oCAOCttMPH7t27NXv2bBUWFsrn8+m5555LuW6MUUVFhQoLC5Wbm6tp06bp4MGDmWpvxgRY5wMAACvSDh8nT57UDTfcoPXr1/d6/aGHHtLatWu1fv161dbWKhKJaObMmWptbR1wYzMpsbEcA04BAPBWIN0nlJeXq7y8vNdrxhitW7dOK1eu1Jw5cyRJmzdvVkFBgbZs2aLvf//7A2ttBjHgFAAAOzI65qO+vl5NTU0qKytLngsGg5o6dar27NnT63Pi8bhisVjK4QWm2gIAYEdGw0dTU5MkqaCgIOV8QUFB8tq5KisrFQ6Hk0dRUVEmm3ReVD4AALAjK7NdfD5fymNjTI9zCcuXL1dLS0vyaGhoyEaTevB3Cx/GEEAAAPBK2mM+LiQSiUg6UwGJRqPJ883NzT2qIQnBYFDBYDCTzeiTxPLq0pkAkph6CwAAsiujlY+SkhJFIhFVV1cnz7W1tammpkaTJ0/O5FsNWLfswbgPAAA8lHbl48SJE3r77beTj+vr63XgwAGNGDFCo0eP1pIlS7R69WqNGTNGY8aM0erVqzV06FDdfffdGW34QJ1b+QAAAN5IO3y8+uqrmj59evLx0qVLJUlz587VL37xCy1btkynT5/W/fffr2PHjmnixInasWOHQqFQ5lqdAYkxHxKVDwAAvOQzF9loy1gspnA4rJaWFuXl5WXtfTo7ja5esU2StP/vZ2rEsMFZey8AAC536Xx+O7u3S06OT4kJOGwuBwCAd5wNHxJrfQAAYIPT4cNP+AAAwHNOh4/EjBfCBwAA3nE6fLC/CwAA3nM6fDDmAwAA7zkdPpKVjw7CBwAAXnE6fFD5AADAe06HD78/MeaDdT4AAPCK0+GD2S4AAHjP6fCRk1zhlPABAIBXnA4fVD4AAPCe0+GDdT4AAPCe0+Ej4E/MdmHAKQAAXnE6fLDOBwAA3nM6fLDOBwAA3nM6fCR3tTWEDwAAvOJ0+GC2CwAA3nM6fDDmAwAA7zkdPhjzAQCA95wOH6zzAQCA95wOH6zzAQCA95wOH/6uAadUPgAA8I7b4aNrYznGfAAA4B23wweVDwAAPOd0+GC2CwAA3nM6fPj9rPMBAIDXnA4fZysfzHYBAMArTocP1vkAAMB7ToePABvLAQDgOafDR2K2SwdjPgAA8EzGw0dFRYV8Pl/KEYlEMv02GRHgaxcAADwXyMaLXnfddXrxxReTj/1+fzbeZsD8TLUFAMBzWQkfgUDgoq12dEflAwAA72VlzMfhw4dVWFiokpISfec739E777xz3nvj8bhisVjK4ZXBgTN//NNt7Z69JwAArst4+Jg4caKeeOIJvfDCC9q4caOampo0efJkHT16tNf7KysrFQ6Hk0dRUVGmm3RehZ/LlSS9f+y0Z+8JAIDrfMZkd57pyZMn9Sd/8idatmyZli5d2uN6PB5XPB5PPo7FYioqKlJLS4vy8vKy2TS9/v5xfWP9fyo/FNQrK2/N6nsBAHA5i8ViCofDffr8zsqYj+6GDRum66+/XocPH+71ejAYVDAYzHYzejV6xFBJUnNrXKfbOpQ7+OIcGAsAwOUk6+t8xONx/eEPf1A0Gs32W6UtnDtIoSFn8lfDsVOWWwMAgBsyHj5+9KMfqaamRvX19Xr55Zf17W9/W7FYTHPnzs30Ww2Yz+dLVj+OHCV8AADghYx/7fL+++/rrrvu0kcffaTPf/7zuvnmm7V3714VFxdn+q0yovjKoTr4QUxHPiZ8AADghYyHj6qqqky/ZFYVJSofhA8AADzh9N4u0tlBp4QPAAC8QfggfAAA4CnCR1f4aPj4lDpZZh0AgKxzPnwUfi5X/hyf4u2d+p8T8c9+AgAAGBDnw8cgf44KPzdEEl+9AADgBefDh3T2q5f3WOsDAICsI3yIQacAAHiJ8CFp9Ihhks4MOgUAANlF+BCVDwAAvET4EGM+AADwEuFDZ8PH
RyfiOtXWbrk1AABc3ggfksJDBymcO0iS1PDxacutAQDg8kb46MK4DwAAvEH46EL4AADAG4SPLkWJ8HH0pOWWAABweSN8dCm+ksoHAABeIHx04WsXAAC8QfjokggfDcdOq7PTWG4NAACXL8JHl2h4iPw5PrW1d+rD1k9sNwcAgMsW4aNLwJ+jqz6XK0k6wkqnAABkDeGjGwadAgCQfYSPbhLTbdndFgCA7CF8dJPcYI7wAQBA1hA+umG6LQAA2Uf46GY0X7sAAJB1hI9uRncNOP3oRJtaP/nUcmsAALg8ET66yRsySJG8IZKkOx/dqzebWi23CACAyw/h4xwPfXucrhg6SIcaY5r9f17Sxt3vsOIpAAAZRPg4xy1jP68XfniLZlyTr7aOTq3a9gfdtXGv9vy/j3Sqrd128wAAuOT5jDEX1T/rY7GYwuGwWlpalJeXZ60dxhhV1Tbof//6kE61dUiS/Dk+XRMJ6cbRV+j6q8KKhIcoPy+o/NAQXTF0kHw+n7X2AgBgUzqf34SPz/De0ZP6x+q39HL9x2psOf+eL4P8Pg0PBhQM+DVkUI6CAb8GB3IU8PsUyPHJn+NTICdHOTk+5fikHN+ZX32JX+VTTs6ZX7v+lwwzZ35/9pyv62TX75IS2ceX8vve70n9va/H88/VW6660Gtn28Ue8z7r/1SJ9p8vsJ7v/5bd7zfGXPB90umj3l7n3Of35Z50nfvnv9BfR73dm6k//0AM9L/1gN+/D3+F96WfLtS+7u+RyT5Pt0/S/bj6rJ/Zgfw3ydRH5/namMmfl97a6s/J0f+afW3G3kO6SMLHI488on/4h39QY2OjrrvuOq1bt05/8Rd/8ZnPu9jCR3eNLae1/73j2n/kmN76sFX/0xrXh7FPdOwUM2MAAJeOwYEcvfWz8oy+Zjqf34GMvnOXrVu3asmSJXrkkUf0la98RY8++qjKy8t16NAhjR49Ohtv6YloOFezxuVq1rhoyvm29k79z4m4TsXbFW/v1Cefdije3ql4e4c6OqX2jk61dxq1d3aqs1PqNEbGnPm100hGZ35V4nHXv+iMUdevZ/Kh6bo38XvpbGo++9jofHEy5V8v5zz/3PMp5865q/d7+smY9EsmffyXXrb+5Xu+d0/3/Xp7ne7dcb7Kw4Xu6X7vZ7Wnxz3d/zskfub68Drp3Hfuc3rT2+sk/v9w5vdGPvk+sw+ypdc/6/l+hrv143ku9ZCpf/D29WXO/Tsk9ZpJVjgT1dd0XjMb0v557OXnOnFPuvrz19WFnO9nfSDv19fn+HPs1o+zUvmYOHGibrzxRm3YsCF57ktf+pJuv/12VVZWXvC5F3PlAwAA9C6dz++Mz3Zpa2vTvn37VFZWlnK+rKxMe/bsyfTbAQCAS0zGv3b56KOP1NHRoYKCgpTzBQUFampq6nF/PB5XPB5PPo7FYpluEgAAuIhkbZ2P3kam9zZ6t7KyUuFwOHkUFRVlq0kAAOAikPHwMXLkSPn9/h5Vjubm5h7VEElavny5WlpakkdDQ0OmmwQAAC4iGQ8fgwcP1vjx41VdXZ1yvrq6WpMnT+5xfzAYVF5eXsoBAAAuX1mZart06VLde++9mjBhgiZNmqTHHntMR44c0fz587PxdgAA4BKSlfBx55136ujRo3rwwQfV2Nio0tJSbdu2TcXFxdl4OwAAcAlheXUAADBgVtf5AAAAuBDCBwAA8BThAwAAeIrwAQAAPEX4AAAAnsrKVNuBSEy+YY8XAAAuHYnP7b5Mor3owkdra6sksccLAACXoNbWVoXD4Qvec9Gt89HZ2akPPvhAoVCo143oBiIWi6moqEgNDQ2sIZJl9LV36Gvv0Nfeoa+9k6m+NsaotbVVhYWFysm58KiOi67ykZOTo1GjRmX1PdhDxjv0tXfoa+/Q196hr72Tib7+rIpHAgNOAQCApwgfAADAU06Fj2AwqJ/85CcKBoO2m3LZo6+9Q197h772Dn3tHRt9fdENOAUAAJc3pyof
AADAPsIHAADwFOEDAAB4ivABAAA85Uz4eOSRR1RSUqIhQ4Zo/Pjx+o//+A/bTbrkVVZW6qabblIoFFJ+fr5uv/12vfnmmyn3GGNUUVGhwsJC5ebmatq0aTp48KClFl8+Kisr5fP5tGTJkuQ5+jpz/vjHP+q73/2urrzySg0dOlR/9md/pn379iWv09eZ097err/7u79TSUmJcnNzdfXVV+vBBx9UZ2dn8h76u392796t2bNnq7CwUD6fT88991zK9b70azwe16JFizRy5EgNGzZM3/jGN/T+++8PvHHGAVVVVWbQoEFm48aN5tChQ2bx4sVm2LBh5r333rPdtEva1772NbNp0ybzxhtvmAMHDphZs2aZ0aNHmxMnTiTvWbNmjQmFQubpp582dXV15s477zTRaNTEYjGLLb+0vfLKK+YLX/iCGTdunFm8eHHyPH2dGR9//LEpLi428+bNMy+//LKpr683L774onn77beT99DXmfOzn/3MXHnllebXv/61qa+vN7/85S/N8OHDzbp165L30N/9s23bNrNy5Urz9NNPG0nm2WefTbnel36dP3++ueqqq0x1dbXZv3+/mT59urnhhhtMe3v7gNrmRPj48z//czN//vyUc9dcc4358Y9/bKlFl6fm5mYjydTU1BhjjOns7DSRSMSsWbMmec8nn3xiwuGw+Zd/+Rdbzbyktba2mjFjxpjq6mozderUZPigrzPngQceMFOmTDnvdfo6s2bNmmX++q//OuXcnDlzzHe/+11jDP2dKeeGj7706/Hjx82gQYNMVVVV8p4//vGPJicnx2zfvn1A7bnsv3Zpa2vTvn37VFZWlnK+rKxMe/bssdSqy1NLS4skacSIEZKk+vp6NTU1pfR9MBjU1KlT6ft+WrBggWbNmqVbb7015Tx9nTnPP/+8JkyYoL/6q79Sfn6+vvzlL2vjxo3J6/R1Zk2ZMkW/+93v9NZbb0mS/vu//1svvfSSvv71r0uiv7OlL/26b98+ffrppyn3FBYWqrS0dMB9f9FtLJdpH330kTo6OlRQUJByvqCgQE1NTZZadfkxxmjp0qWaMmWKSktLJSnZv731/Xvvved5Gy91VVVV2r9/v2pra3tco68z55133tGGDRu0dOlSrVixQq+88or+5m/+RsFgUN/73vfo6wx74IEH1NLSomuuuUZ+v18dHR1atWqV7rrrLkn8bGdLX/q1qalJgwcP1hVXXNHjnoF+fl724SPB5/OlPDbG9DiH/lu4cKFef/11vfTSSz2u0fcD19DQoMWLF2vHjh0aMmTIee+jrweus7NTEyZM0OrVqyVJX/7yl3Xw4EFt2LBB3/ve95L30deZsXXrVj355JPasmWLrrvuOh04cEBLlixRYWGh5s6dm7yP/s6O/vRrJvr+sv/aZeTIkfL7/T1SWnNzc4/Eh/5ZtGiRnn/+ef3+97/XqFGjkucjkYgk0fcZsG/fPjU3N2v8+PEKBAIKBAKqqanRP/3TPykQCCT7k74euGg0qmuvvTbl3Je+9CUdOXJEEj/Xmfa3f/u3+vGPf6zvfOc7uv7663Xvvffqhz/8oSorKyXR39nSl36NRCJqa2vTsWPHzntPf1324WPw4MEaP368qqurU85XV1dr8uTJllp1eTDGaOHChXrmmWe0c+dOlZSUpFwvKSlRJBJJ6fu2tjbV1NTQ92n66le/qrq6Oh04cCB5TJgwQffcc48OHDigq6++mr7OkK985Ss9poy/9dZbKi4ulsTPdaadOnVKOTmpH0V+vz851Zb+zo6+9Ov48eM1aNCglHsaGxv1xhtvDLzvBzRc9RKRmGr7+OOPm0OHDpklS5aYYcOGmXfffdd20y5pP/jBD0w4HDa7du0yjY2NyePUqVPJe9asWWPC4bB55plnTF1dnbnrrruYIpch3We7GENfZ8orr7xiAoGAWbVqlTl8+LB56qmnzNChQ82TTz6ZvIe+zpy5c+eaq666KjnV9plnnjEjR440
y5YtS95Df/dPa2uree2118xrr71mJJm1a9ea1157LbnMRF/6df78+WbUqFHmxRdfNPv37zczZsxgqm06/vmf/9kUFxebwYMHmxtvvDE5HRT9J6nXY9OmTcl7Ojs7zU9+8hMTiURMMBg0t9xyi6mrq7PX6MvIueGDvs6cX/3qV6a0tNQEg0FzzTXXmMceeyzlOn2dObFYzCxevNiMHj3aDBkyxFx99dVm5cqVJh6PJ++hv/vn97//fa9/R8+dO9cY07d+PX36tFm4cKEZMWKEyc3NNX/5l39pjhw5MuC2+YwxZmC1EwAAgL677Md8AACAiwvhAwAAeIrwAQAAPEX4AAAAniJ8AAAATxE+AACApwgfAADAU4QPAADgKcIHAADwFOEDAAB4ivABAAA8RfgAAACe+v8/EIyaWhrq4QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_loss = []\n",
    "num_epochs = 100\n",
    "\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    train_epoch_loss = 0\n",
    "    for (x , _) in train_loader_original:\n",
    "        x = x.to(device)\n",
    "        #100 , 1 , 28 , 28 ---> (100 , 28*28)\n",
    "        x = x.flatten(1)\n",
    "        latents = enc(x)\n",
    "        output = dec(latents)\n",
    "        loss = loss_fn(output , x)\n",
    "        train_epoch_loss += loss.cpu().detach().numpy()\n",
    "        optimizer_enc.zero_grad()\n",
    "        optimizer_dec.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_enc.step()\n",
    "        optimizer_dec.step()\n",
    "    train_loss.append(train_epoch_loss)\n",
    "plt.plot(train_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "representation = None\n",
    "all_labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for (xs , labels) in train_loader_original:\n",
    "        xs = xs.to(device)\n",
    "        xs = xs.flatten(1)\n",
    "        all_labels.extend(list(labels.numpy()))\n",
    "        latents = enc(xs)\n",
    "        if representation is None:\n",
    "            representation = latents.cpu()\n",
    "        else:\n",
    "            representation = torch.vstack([representation , latents.cpu()])\n",
    "\n",
    "all_labels = np.array(all_labels)\n",
    "representation = representation.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_AE_class = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    sampled_AE_list = []\n",
    "\n",
    "    rep = representation[np.argwhere(all_labels == class_)].squeeze()\n",
    "    # Fit a KDE to the theta values\n",
    "    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth_AE).fit(rep)\n",
    "\n",
    "    # Sample new data from the KDE\n",
    "    sampled_rep = kde.sample(n_samples=num_new_samples)\n",
    "    for i in range(num_new_samples):\n",
    "        pred = dec(torch.Tensor(sampled_rep[i])[None , ...].to(device)).cpu().detach().numpy()\n",
    "        sampled_AE_list.append(pred)\n",
    "\n",
    "    sampled_AE_class.append(sampled_AE_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Augmented Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmented_data_AE = []\n",
    "\n",
    "for class_ in unique_labels:\n",
    "    for i in range(num_new_samples):\n",
    "        augmented_data_AE.append(sampled_AE_class[class_][i].flatten())\n",
    "\n",
    "train_data_AE = np.array(augmented_data_AE)\n",
    "labels = np.repeat(unique_labels, num_new_samples)\n",
    "\n",
    "custom_train_dataset = CustomDataset(train_data_AE, labels)\n",
    "train_loader_AE = DataLoader(dataset=custom_train_dataset, batch_size=16, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification Performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logistic Regression Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegressionModel(nn.Module):\n",
    "    def __init__(self, input_size, num_classes=10):\n",
    "        super(LogisticRegressionModel, self).__init__()\n",
    "        self.linear = nn.Linear(input_size, num_classes)  # Linear layer for general input size\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.linear(x)   # Apply linear transformation\n",
    "        return out\n",
    "\n",
    "# Function to train the model\n",
    "def train_model(model, train_loader, criterion, optimizer, num_epochs=5, device='cpu'):\n",
    "    model.to(device)\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()  # Set model to training mode\n",
    "        running_loss = 0.0\n",
    "\n",
    "        for vectors, labels in train_loader:\n",
    "            vectors, labels = vectors.to(device), labels.to(device)\n",
    "\n",
    "            # Zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Forward pass\n",
    "            outputs = model(vectors)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            # Backward pass and optimization\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}\")\n",
    "\n",
    "# Function to test the model\n",
    "def test_model(model, test_loader, device='cpu'):\n",
    "    model.eval()  # Set model to evaluation mode\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for vectors, labels in test_loader:\n",
    "            vectors, labels = vectors.to(device), labels.to(device)\n",
    "            outputs = model(vectors)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    accuracy = 100 * correct / total\n",
    "    return accuracy\n",
    "\n",
    "# Main function to run training and testing on a dataset\n",
    "def bootstrapping(train_loader, test_dataset, input_size, num_classes=10, num_epochs=5, learning_rate=0.05, device='cpu'):\n",
    "    # Initialize model, loss function, and optimizer\n",
    "    model = LogisticRegressionModel(input_size=input_size, num_classes=num_classes).to(device)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    # Train the model\n",
    "    train_model(model, train_loader, criterion, optimizer, num_epochs=num_epochs, device=device)\n",
    "\n",
    "    # Number of bootstrap resamples\n",
    "    n_bootstrap = 20\n",
    "    accuracies = []\n",
    "\n",
    "    # Perform bootstrapping\n",
    "    for i in range(n_bootstrap):\n",
    "        # Randomly sample 500 examples from the dataset with replacement\n",
    "        indices = torch.randint(len(test_dataset), size=(50,))\n",
    "        bootstrap_subset = Subset(test_dataset, indices)\n",
    "        bootstrap_loader = DataLoader(dataset=bootstrap_subset, batch_size=50, shuffle=False)\n",
    "\n",
    "        # Calculate accuracy on the bootstrap sample\n",
    "        accuracy = test_model(model, bootstrap_loader, device=device)\n",
    "        accuracies.append(accuracy)\n",
    "\n",
    "    # Compute the mean and standard deviation of accuracy\n",
    "    mean_accuracy = np.mean(accuracies)\n",
    "    std_accuracy = np.std(accuracies)\n",
    "\n",
    "    # Calculate 95% confidence interval (mean ± 1.96 * standard deviation)\n",
    "    confidence_interval = (mean_accuracy - 1.96 * std_accuracy, mean_accuracy + 1.96 * std_accuracy)\n",
    "\n",
    "    print(f\"Mean accuracy: {mean_accuracy:.2f}%\")\n",
    "    print(f\"95% confidence interval: ({confidence_interval[0]:.2f}%, {confidence_interval[1]:.2f}%)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test dataset size: 1300\n"
     ]
    }
   ],
   "source": [
    "test_dataset = CustomDataset(X_test, Y_test)\n",
    "print(\"Test dataset size:\", len(test_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Loss: 1.4315\n",
      "Epoch [2/20], Loss: 1.3012\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [3/20], Loss: 1.2870\n",
      "Epoch [4/20], Loss: 1.2822\n",
      "Epoch [5/20], Loss: 1.2789\n",
      "Epoch [6/20], Loss: 1.2776\n",
      "Epoch [7/20], Loss: 1.2760\n",
      "Epoch [8/20], Loss: 1.2757\n",
      "Epoch [9/20], Loss: 1.2740\n",
      "Epoch [10/20], Loss: 1.2742\n",
      "Epoch [11/20], Loss: 1.2731\n",
      "Epoch [12/20], Loss: 1.2731\n",
      "Epoch [13/20], Loss: 1.2729\n",
      "Epoch [14/20], Loss: 1.2722\n",
      "Epoch [15/20], Loss: 1.2722\n",
      "Epoch [16/20], Loss: 1.2720\n",
      "Epoch [17/20], Loss: 1.2717\n",
      "Epoch [18/20], Loss: 1.2718\n",
      "Epoch [19/20], Loss: 1.2714\n",
      "Epoch [20/20], Loss: 1.2708\n",
      "Mean accuracy: 42.40%\n",
      "95% confidence interval: (27.44%, 57.36%)\n"
     ]
    }
   ],
   "source": [
    "bootstrapping(train_loader_original, test_dataset, input_size=X_train.shape[1], num_classes=len(unique_labels), num_epochs=20, learning_rate=0.05, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Augmented Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Log-Linear Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Loss: 1.9563\n",
      "Epoch [2/20], Loss: 1.9513\n",
      "Epoch [3/20], Loss: 1.9488\n",
      "Epoch [4/20], Loss: 1.9477\n",
      "Epoch [5/20], Loss: 1.9471\n",
      "Epoch [6/20], Loss: 1.9469\n",
      "Epoch [7/20], Loss: 1.9466\n",
      "Epoch [8/20], Loss: 1.9467\n",
      "Epoch [9/20], Loss: 1.9466\n",
      "Epoch [10/20], Loss: 1.9462\n",
      "Epoch [11/20], Loss: 1.9463\n",
      "Epoch [12/20], Loss: 1.9463\n",
      "Epoch [13/20], Loss: 1.9463\n",
      "Epoch [14/20], Loss: 1.9462\n",
      "Epoch [15/20], Loss: 1.9460\n",
      "Epoch [16/20], Loss: 1.9460\n",
      "Epoch [17/20], Loss: 1.9460\n",
      "Epoch [18/20], Loss: 1.9458\n",
      "Epoch [19/20], Loss: 1.9460\n",
      "Epoch [20/20], Loss: 1.9459\n",
      "Mean accuracy: 21.80%\n",
      "95% confidence interval: (9.85%, 33.75%)\n"
     ]
    }
   ],
   "source": [
    "bootstrapping(train_loader_LD, test_dataset, input_size=X_train.shape[1], num_classes=len(unique_labels), num_epochs=20, learning_rate=0.05, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Loss: 1.9533\n",
      "Epoch [2/20], Loss: 1.9499\n",
      "Epoch [3/20], Loss: 1.9485\n",
      "Epoch [4/20], Loss: 1.9478\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/20], Loss: 1.9473\n",
      "Epoch [6/20], Loss: 1.9476\n",
      "Epoch [7/20], Loss: 1.9477\n",
      "Epoch [8/20], Loss: 1.9473\n",
      "Epoch [9/20], Loss: 1.9472\n",
      "Epoch [10/20], Loss: 1.9473\n",
      "Epoch [11/20], Loss: 1.9473\n",
      "Epoch [12/20], Loss: 1.9474\n",
      "Epoch [13/20], Loss: 1.9473\n",
      "Epoch [14/20], Loss: 1.9473\n",
      "Epoch [15/20], Loss: 1.9474\n",
      "Epoch [16/20], Loss: 1.9474\n",
      "Epoch [17/20], Loss: 1.9472\n",
      "Epoch [18/20], Loss: 1.9474\n",
      "Epoch [19/20], Loss: 1.9474\n",
      "Epoch [20/20], Loss: 1.9475\n",
      "Mean accuracy: 20.80%\n",
      "95% confidence interval: (11.01%, 30.59%)\n"
     ]
    }
   ],
   "source": [
    "bootstrapping(train_loader_AE, test_dataset, input_size=X_train.shape[1], num_classes=len(unique_labels), num_epochs=20, learning_rate=0.05, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original and Augmented Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Log-Linear Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Loss: 1.5604\n",
      "Epoch [2/20], Loss: 1.4818\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [3/20], Loss: 1.4782\n",
      "Epoch [4/20], Loss: 1.4762\n",
      "Epoch [5/20], Loss: 1.4764\n",
      "Epoch [6/20], Loss: 1.4763\n",
      "Epoch [7/20], Loss: 1.4762\n",
      "Epoch [8/20], Loss: 1.4766\n",
      "Epoch [9/20], Loss: 1.4757\n",
      "Epoch [10/20], Loss: 1.4765\n",
      "Epoch [11/20], Loss: 1.4758\n",
      "Epoch [12/20], Loss: 1.4750\n",
      "Epoch [13/20], Loss: 1.4762\n",
      "Epoch [14/20], Loss: 1.4748\n",
      "Epoch [15/20], Loss: 1.4747\n",
      "Epoch [16/20], Loss: 1.4752\n",
      "Epoch [17/20], Loss: 1.4754\n",
      "Epoch [18/20], Loss: 1.4744\n",
      "Epoch [19/20], Loss: 1.4755\n",
      "Epoch [20/20], Loss: 1.4747\n",
      "Mean accuracy: 43.00%\n",
      "95% confidence interval: (31.81%, 54.19%)\n"
     ]
    }
   ],
   "source": [
    "train_data = np.vstack([train_data_original, train_data_LD])\n",
    "labels = np.hstack([Y_train, np.repeat(unique_labels, num_new_samples)])\n",
    "custom_train_dataset = CustomDataset(train_data, labels)\n",
    "\n",
    "train_loader = DataLoader(dataset=custom_train_dataset, batch_size=16, shuffle=True)\n",
    "\n",
    "bootstrapping(train_loader, test_dataset, input_size=X_train.shape[1], num_classes=len(unique_labels), num_epochs=20, learning_rate=0.05, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Loss: 1.5556\n",
      "Epoch [2/20], Loss: 1.4808\n",
      "Epoch [3/20], Loss: 1.4785\n",
      "Epoch [4/20], Loss: 1.4779\n",
      "Epoch [5/20], Loss: 1.4763\n",
      "Epoch [6/20], Loss: 1.4776\n",
      "Epoch [7/20], Loss: 1.4762\n",
      "Epoch [8/20], Loss: 1.4757\n",
      "Epoch [9/20], Loss: 1.4756\n",
      "Epoch [10/20], Loss: 1.4761\n",
      "Epoch [11/20], Loss: 1.4765\n",
      "Epoch [12/20], Loss: 1.4761\n",
      "Epoch [13/20], Loss: 1.4758\n",
      "Epoch [14/20], Loss: 1.4757\n",
      "Epoch [15/20], Loss: 1.4759\n",
      "Epoch [16/20], Loss: 1.4755\n",
      "Epoch [17/20], Loss: 1.4746\n",
      "Epoch [18/20], Loss: 1.4755\n",
      "Epoch [19/20], Loss: 1.4750\n",
      "Epoch [20/20], Loss: 1.4758\n",
      "Mean accuracy: 44.10%\n",
      "95% confidence interval: (29.67%, 58.53%)\n"
     ]
    }
   ],
   "source": [
    "train_data = np.vstack([train_data_original, train_data_AE])\n",
    "labels = np.hstack([Y_train, np.repeat(unique_labels, num_new_samples)])\n",
    "custom_train_dataset = CustomDataset(train_data, labels)\n",
    "\n",
    "train_loader = DataLoader(dataset=custom_train_dataset, batch_size=16, shuffle=True)\n",
    "\n",
    "bootstrapping(train_loader, test_dataset, input_size=X_train.shape[1], num_classes=len(unique_labels), num_epochs=20, learning_rate=0.05, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
