{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WW06Wcmfcnd0"
   },
   "source": [
    "# Google Speech Commands"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "IQ0m4_80ZUfH"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))\n",
    "\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset\n",
    "\n",
    "from torchvision import models\n",
    "import torchaudio\n",
    "\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import cupy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "from utils import vectorize_tensor, reconstruct_tensor\n",
    "import ld\n",
    "\n",
    "# Run tensors and models on the GPU when CUDA is available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_random_seed(seed):\n",
    "    \"\"\"Seed every random number generator used in this notebook.\n",
    "\n",
    "    Seeds Python's `random`, NumPy, PyTorch (plus the current CUDA device when\n",
    "    CUDA is available) and CuPy, then puts cuDNN into deterministic mode.\n",
    "\n",
    "    Args:\n",
    "        seed: Seed value handed to each library.\n",
    "    \"\"\"\n",
    "    # Host-side generators\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "\n",
    "    # Device-side generators\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "    cp.random.seed(seed)\n",
    "\n",
    "    # Give up cuDNN autotuning in exchange for run-to-run reproducibility\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "# When True, later cells read cached results (results_LD.npz, results_BP.npz)\n",
    "# from disk instead of recomputing them.\n",
    "load = True\n",
    "\n",
    "set_random_seed(2)\n",
    "k = 3  # passed as `k` to ld.kNN in the backward-projection cell\n",
    "bandwidth = 0.05  # bandwidth of the Gaussian KDE fitted on the projected theta values\n",
    "bandwidth_AE = 0.05  # NOTE(review): presumably the KDE bandwidth for the autoencoder baseline; its use is not visible in this part of the notebook\n",
    "eps = np.asarray(1.0e-5)  # small constant added before divisions (Normalize, P / scaleX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NumpyDataset(Dataset):\n",
    "    \"\"\"In-memory dataset pairing feature arrays with integer labels.\n",
    "\n",
    "    1-D features are treated as one flat sample and 2-D features as\n",
    "    [N, length]; both are reshaped to [N, 1, length] before being stored.\n",
    "    Features are stored as float32 tensors and labels as int64 tensors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, features, labels, transform=None):\n",
    "        ndim = len(features.shape)\n",
    "        if ndim == 1:\n",
    "            # single flat sample -> [1, 1, length]\n",
    "            features = features.reshape(1, 1, -1)\n",
    "        elif ndim == 2:\n",
    "            # [N, length] -> [N, 1, length]\n",
    "            features = features.reshape(features.shape[0], 1, -1)\n",
    "\n",
    "        self.features = torch.tensor(features, dtype=torch.float32)\n",
    "        self.labels = torch.tensor(labels, dtype=torch.long)\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.features[idx].float()\n",
    "        target = self.labels[idx]\n",
    "\n",
    "        # Without a transform the stored sample is returned as is\n",
    "        if not self.transform:\n",
    "            return item, target\n",
    "        return self.transform(item), target\n",
    "\n",
    "class Normalize:\n",
    "    \"\"\"Standardise a sample as (sample - mean) / (std + eps).\n",
    "\n",
    "    `eps` is the notebook-level constant defined in the setup cell.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mean, std):\n",
    "        # Accept tensors as well as anything torch.tensor() can convert\n",
    "        mean = mean if torch.is_tensor(mean) else torch.tensor(mean)\n",
    "        std = std if torch.is_tensor(std) else torch.tensor(std)\n",
    "\n",
    "        # 1-D statistics get a leading axis of size 1: [L] -> [1, L]\n",
    "        self.mean = mean.reshape(1, -1) if len(mean.shape) == 1 else mean\n",
    "        self.std = std.reshape(1, -1) if len(std.shape) == 1 else std\n",
    "\n",
    "    def __call__(self, sample):\n",
    "        centred = sample - self.mean\n",
    "        return centred / (self.std + eps)\n",
    "\n",
    "# Additive-noise augmentation applied to a single sample tensor\n",
    "class AddGaussianNoise:\n",
    "    \"\"\"Add Gaussian noise with the given mean and std to a sample.\"\"\"\n",
    "\n",
    "    def __init__(self, mean=0.0, std=0.1):\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "\n",
    "    def __call__(self, sample):\n",
    "        # randn_like draws standard-normal noise shaped like the sample\n",
    "        return sample + (self.mean + self.std * torch.randn_like(sample))\n",
    "\n",
    "# Chains several transforms (e.g. Normalize followed by AddGaussianNoise)\n",
    "class CompositeTransform:\n",
    "    \"\"\"Apply a list of transforms one after another, in list order.\"\"\"\n",
    "\n",
    "    def __init__(self, transforms):\n",
    "        # transforms (list): callables, each taking and returning a sample\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __call__(self, sample):\n",
    "        result = sample\n",
    "        for step in self.transforms:\n",
    "            result = step(result)\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train size: 35000\n",
      "New sample size: 7000\n",
      "Number of classes: 35\n",
      "Number of features: 4000\n"
     ]
    }
   ],
   "source": [
    "D = 4000  # samples per clip after sub-sampling (16000 / 4)\n",
    "# Tensor shape each clip is reshaped to; 2*5*2*5*2*5*2*2 == 4000 == D\n",
    "S = (2, 5, 2, 5, 2, 5, 2, 2)\n",
    "train_size_per_class = 1000  # clips loaded per class as training data\n",
    "new_sample_size_per_class = 200  # augmented samples generated per class\n",
    "num_class = 35\n",
    "\n",
    "# Print dataset statistics\n",
    "print(f\"Train size: {train_size_per_class * num_class}\")\n",
    "print(f\"New sample size: {new_sample_size_per_class * num_class}\")\n",
    "print(f\"Number of classes: {num_class}\")\n",
    "print(f\"Number of features: {D}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Couldn't find appropriate backend to handle uri ./data/SpeechCommands/speech_commands_v0.02/backward/0165e0e8_nohash_0.wav and format None.",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mRuntimeError\u001b[39m                              Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 24\u001b[39m\n\u001b[32m     21\u001b[39m os.makedirs(save_dir)\n\u001b[32m     23\u001b[39m \u001b[38;5;66;03m# Get unique labels\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m24\u001b[39m all_labels = \u001b[38;5;28msorted\u001b[39m(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28;43mset\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mitem\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain_dataset\u001b[49m\u001b[43m)\u001b[49m))\n\u001b[32m     25\u001b[39m label_to_idx = {label: idx \u001b[38;5;28;01mfor\u001b[39;00m idx, label \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(all_labels)}\n\u001b[32m     27\u001b[39m \u001b[38;5;66;03m# Save audio files by class\u001b[39;00m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 24\u001b[39m, in \u001b[36m<genexpr>\u001b[39m\u001b[34m(.0)\u001b[39m\n\u001b[32m     21\u001b[39m os.makedirs(save_dir)\n\u001b[32m     23\u001b[39m \u001b[38;5;66;03m# Get unique labels\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m24\u001b[39m all_labels = \u001b[38;5;28msorted\u001b[39m(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mset\u001b[39m(item[\u001b[32m2\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m train_dataset)))\n\u001b[32m     25\u001b[39m label_to_idx = {label: idx \u001b[38;5;28;01mfor\u001b[39;00m idx, label \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(all_labels)}\n\u001b[32m     27\u001b[39m \u001b[38;5;66;03m# Save audio files by class\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/PNL/lib/python3.12/site-packages/torchaudio/datasets/speechcommands.py:179\u001b[39m, in \u001b[36mSPEECHCOMMANDS.__getitem__\u001b[39m\u001b[34m(self, n)\u001b[39m\n\u001b[32m    159\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"Load the n-th sample from the dataset.\u001b[39;00m\n\u001b[32m    160\u001b[39m \n\u001b[32m    161\u001b[39m \u001b[33;03mArgs:\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    176\u001b[39m \u001b[33;03m        Utterance number\u001b[39;00m\n\u001b[32m    177\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    178\u001b[39m metadata = \u001b[38;5;28mself\u001b[39m.get_metadata(n)\n\u001b[32m--> \u001b[39m\u001b[32m179\u001b[39m waveform = \u001b[43m_load_waveform\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_archive\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    180\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (waveform,) + metadata[\u001b[32m1\u001b[39m:]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/PNL/lib/python3.12/site-packages/torchaudio/datasets/utils.py:51\u001b[39m, in \u001b[36m_load_waveform\u001b[39m\u001b[34m(root, filename, exp_sample_rate)\u001b[39m\n\u001b[32m     45\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_load_waveform\u001b[39m(\n\u001b[32m     46\u001b[39m     root: \u001b[38;5;28mstr\u001b[39m,\n\u001b[32m     47\u001b[39m     filename: \u001b[38;5;28mstr\u001b[39m,\n\u001b[32m     48\u001b[39m     exp_sample_rate: \u001b[38;5;28mint\u001b[39m,\n\u001b[32m     49\u001b[39m ):\n\u001b[32m     50\u001b[39m     path = os.path.join(root, filename)\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m     waveform, sample_rate = \u001b[43mtorchaudio\u001b[49m\u001b[43m.\u001b[49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     52\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m exp_sample_rate != sample_rate:\n\u001b[32m     53\u001b[39m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33msample rate should be \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexp_sample_rate\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msample_rate\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/PNL/lib/python3.12/site-packages/torchaudio/_backend/utils.py:204\u001b[39m, in \u001b[36mget_load_func.<locals>.load\u001b[39m\u001b[34m(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size, backend)\u001b[39m\n\u001b[32m    118\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mload\u001b[39m(\n\u001b[32m    119\u001b[39m     uri: Union[BinaryIO, \u001b[38;5;28mstr\u001b[39m, os.PathLike],\n\u001b[32m    120\u001b[39m     frame_offset: \u001b[38;5;28mint\u001b[39m = \u001b[32m0\u001b[39m,\n\u001b[32m   (...)\u001b[39m\u001b[32m    126\u001b[39m     backend: Optional[\u001b[38;5;28mstr\u001b[39m] = \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m    127\u001b[39m ) -> Tuple[torch.Tensor, \u001b[38;5;28mint\u001b[39m]:\n\u001b[32m    128\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"Load audio data from source.\u001b[39;00m\n\u001b[32m    129\u001b[39m \n\u001b[32m    130\u001b[39m \u001b[33;03m    By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    202\u001b[39m \u001b[33;03m            `[channel, time]` else `[time, channel]`.\u001b[39;00m\n\u001b[32m    203\u001b[39m \u001b[33;03m    \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m204\u001b[39m     backend = \u001b[43mdispatcher\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    205\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m backend.load(uri, frame_offset, num_frames, normalize, channels_first, \u001b[38;5;28mformat\u001b[39m, buffer_size)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/PNL/lib/python3.12/site-packages/torchaudio/_backend/utils.py:116\u001b[39m, in \u001b[36mget_load_func.<locals>.dispatcher\u001b[39m\u001b[34m(uri, format, backend_name)\u001b[39m\n\u001b[32m    114\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m backend.can_decode(uri, \u001b[38;5;28mformat\u001b[39m):\n\u001b[32m    115\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m backend\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt find appropriate backend to handle uri \u001b[39m\u001b[38;5;132;01m{\u001b[39;00muri\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m and format \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mformat\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[31mRuntimeError\u001b[39m: Couldn't find appropriate backend to handle uri ./data/SpeechCommands/speech_commands_v0.02/backward/0165e0e8_nohash_0.wav and format None."
     ]
    }
   ],
   "source": [
    "# Step 1: Load the dataset\n",
    "# Training subset of Speech Commands, stored under ./data (download=True).\n",
    "# Each item unpacks as (waveform, sample_rate, label, ...).\n",
    "train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root='./data', download=True, subset='training')\n",
    "\n",
    "def pad_or_truncate(waveform, target_length):\n",
    "    \"\"\"\n",
    "    Pads or truncates the waveform to the target length.\n",
    "\n",
    "    Only the last axis is changed, so the input may be 1-D ([time]),\n",
    "    2-D ([channel, time]) or batched ([..., time]).\n",
    "\n",
    "    Args:\n",
    "        waveform (torch.Tensor): Tensor whose last axis is time.\n",
    "        target_length (int): Required size of the last axis.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: The input itself when it already has the target length,\n",
    "        otherwise a truncated view or a copy zero-padded at the end.\n",
    "    \"\"\"\n",
    "    waveform_length = waveform.shape[-1]\n",
    "    if waveform_length > target_length:\n",
    "        # Truncate; the ellipsis keeps every leading axis, whatever the rank\n",
    "        return waveform[..., :target_length]\n",
    "    if waveform_length < target_length:\n",
    "        # Pad with zeros on the right of the last axis\n",
    "        padding = target_length - waveform_length\n",
    "        return torch.nn.functional.pad(waveform, (0, padding))\n",
    "    return waveform\n",
    "\n",
    "# Step 2: Create directories and save audio files by class\n",
    "save_dir = './data/SPEECH_audio'\n",
    "if not os.path.exists(save_dir):\n",
    "    # Export into a staging directory and rename it only after the last clip has\n",
    "    # been written. If the export is interrupted (for example because torchaudio\n",
    "    # cannot decode a file), save_dir is never created, so re-running this cell\n",
    "    # starts the export again instead of accepting an empty or partial cache.\n",
    "    staging_dir = save_dir + '.partial'\n",
    "    os.makedirs(staging_dir, exist_ok=True)\n",
    "\n",
    "    # Get unique labels from the metadata only; this avoids decoding every\n",
    "    # waveform once just to read its label.\n",
    "    all_labels = sorted({train_dataset.get_metadata(n)[2] for n in range(len(train_dataset))})\n",
    "    label_to_idx = {label: idx for idx, label in enumerate(all_labels)}\n",
    "\n",
    "    # Save audio files by class\n",
    "    for idx, (waveform, sample_rate, label, *_) in enumerate(train_dataset):\n",
    "        class_label = label_to_idx[label]\n",
    "        class_dir = os.path.join(staging_dir, str(class_label))\n",
    "        os.makedirs(class_dir, exist_ok=True)\n",
    "\n",
    "        # Normalize to 1x16000\n",
    "        waveform = pad_or_truncate(waveform, target_length=16000)\n",
    "        waveform_np = waveform.numpy()\n",
    "\n",
    "        # Sub-sampling to 1x4000 (keeps every 4th sample)\n",
    "        waveform_np = waveform_np[:, ::4]\n",
    "        np.save(os.path.join(class_dir, f'{idx}.npy'), waveform_np)\n",
    "\n",
    "    # Publish the finished export under its final name\n",
    "    os.rename(staging_dir, save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 3: Function to load audio files from a directory\n",
    "def load_audio_from_folder(folder, max_files=20):\n",
    "    \"\"\"Load up to `max_files` .npy arrays from `folder`.\n",
    "\n",
    "    Files are taken in sorted (lexicographic) filename order.\n",
    "\n",
    "    Args:\n",
    "        folder (str): Directory containing the .npy files.\n",
    "        max_files (int): Maximum number of arrays to load.\n",
    "\n",
    "    Returns:\n",
    "        list: The loaded NumPy arrays, at most `max_files` of them.\n",
    "    \"\"\"\n",
    "    # Filter before truncating; otherwise entries that are not .npy files\n",
    "    # would use up part of the max_files budget.\n",
    "    npy_names = [name for name in sorted(os.listdir(folder)) if name.endswith('.npy')]\n",
    "    return [np.load(os.path.join(folder, name)) for name in npy_names[:max_files]]\n",
    "\n",
    "# Step 4: Load audio and preprocess them\n",
    "# P_class collects one array of reshaped clips per class, in class-index order\n",
    "P_class = []\n",
    "\n",
    "for class_ in range(num_class):\n",
    "    class_folder = os.path.join(save_dir, str(class_))\n",
    "    audio_files = load_audio_from_folder(class_folder, max_files=train_size_per_class)\n",
    "\n",
    "    # NOTE: `P` stays defined after this loop (holding the last class) and later\n",
    "    # cells read `P[0]`, so this name must not be renamed or removed here.\n",
    "    P = []\n",
    "    for audio in audio_files:\n",
    "        # Reshape the clip into the tensor shape S\n",
    "        array_obj = audio.reshape(S)\n",
    "        P.append(array_obj)\n",
    "    P_class.append(np.array(P))\n",
    "\n",
    "# Normalize to [0, 1]\n",
    "# min and max are taken over all classes together, so one global affine map is applied\n",
    "P_class = np.array(P_class)\n",
    "P_class = (P_class-P_class.min()) / (P_class.max()-P_class.min())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pseudo-Non-Linear Data Augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Legendre Decomposition (Many-Body Approximation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dimension of Base Sub-Manifold: 136\n"
     ]
    }
   ],
   "source": [
    "# The array module (NumPy or CuPy) is taken from P_class itself rather than from\n",
    "# `P`, which is only a loop temporary left over by the loading cell.\n",
    "B_LD = ld.default_B(S, 2, cp.get_array_module(P_class))\n",
    "# B_LD = ld.default_B(S, 2)\n",
    "print(f\"Dimension of Base Sub-Manifold: {B_LD.shape[0]}\")\n",
    "\n",
    "if load:\n",
    "    # Load all intermediate results from disk; the context manager closes the\n",
    "    # .npz archive once the arrays have been read.\n",
    "    with np.load('results_LD.npz') as results:\n",
    "        scaleX_class = results['scaleX_class']\n",
    "        Q_class = results['Q_class']\n",
    "        theta_class = results['theta_class']\n",
    "        X_recons_class = results['X_recons_class']\n",
    "else:\n",
    "    import tempfile\n",
    "\n",
    "    # Scratch directory for joblib. Use the dedicated location when it exists and\n",
    "    # fall back to the system default so the cell also runs on other machines.\n",
    "    temp_root = '/data/pbb/tmp'\n",
    "    temp_dir = tempfile.mkdtemp(dir=temp_root if os.path.isdir(temp_root) else None)\n",
    "\n",
    "    def LD_helper(i, class_):\n",
    "        \"\"\"Run ld.LD on training sample `i` of class `class_`; return (scaleX, Q, theta).\"\"\"\n",
    "        _, _, scaleX, Q, theta = ld.LD(P_class[class_][i], B=B_LD, verbose=False, n_iter=1000, lr=1e-1)\n",
    "        return (scaleX, Q, theta)\n",
    "\n",
    "\n",
    "    # Jobs are ordered sample-major: the result for (i, class_) is at index i*num_class + class_\n",
    "    results = Parallel(n_jobs=30, temp_folder=temp_dir)(delayed(LD_helper)(i, class_) for i in range(train_size_per_class) for class_ in range(num_class))\n",
    "\n",
    "    scaleX_class = []\n",
    "    Q_class = []\n",
    "    theta_class = []\n",
    "    X_recons_class = []\n",
    "\n",
    "    # Regroup the flat result list into one array per class\n",
    "    for class_ in range(num_class):\n",
    "        scaleX_list = []\n",
    "        Q_list = []\n",
    "        theta_list = []\n",
    "        X_recons_list = []\n",
    "        for i in range(train_size_per_class):\n",
    "            result = results[i*num_class + class_]\n",
    "\n",
    "            scaleX_list.append(result[0])\n",
    "            Q_list.append(result[1])\n",
    "            theta_list.append(result[2])\n",
    "            # NOTE(review): P_class is min-max scaled to [0, 1], so an int32 cast of\n",
    "            # Q * scaleX presumably truncates almost every value to 0 -- confirm\n",
    "            # whether a float reconstruction is intended here.\n",
    "            X_recons = (result[1] * result[0]).astype(np.int32)\n",
    "            X_recons_list.append(X_recons)\n",
    "\n",
    "        scaleX_class.append(np.array(scaleX_list))\n",
    "        Q_class.append(np.array(Q_list))\n",
    "        theta_class.append(np.array(theta_list))\n",
    "        X_recons_class.append(np.array(X_recons_list))\n",
    "\n",
    "    # Store all intermediate results to disk\n",
    "    np.savez('results_LD.npz', scaleX_class=scaleX_class, Q_class=Q_class, theta_class=theta_class, X_recons_class=X_recons_class)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fitting on Projected Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One array of sampled theta tensors per class, in class-index order\n",
    "sampled_theta_class = []\n",
    "\n",
    "for class_ in range(num_class):\n",
    "    # NOTE(review): presumably keeps only the theta entries selected by B_LD, one row per sample -- confirm in utils.vectorize_tensor\n",
    "    reduced_theta = vectorize_tensor(np.array(theta_class[class_]), B_LD)\n",
    "\n",
    "    # Fit a KDE to the theta values\n",
    "    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(reduced_theta)\n",
    "    # Sample new data from the KDE (no random_state is passed here)\n",
    "    sampled_reduced_theta = kde.sample(n_samples=new_sample_size_per_class)\n",
    "\n",
    "    # NOTE(review): presumably expands the sampled rows back to tensors of shape (new_sample_size_per_class, *S) -- confirm in utils.reconstruct_tensor\n",
    "    sampled_theta = reconstruct_tensor(sampled_reduced_theta, (new_sample_size_per_class, *S), B_LD)\n",
    "\n",
    "    sampled_theta_class.append(sampled_theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Construct Local-Data Sub-Manifold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dimension of Local Data Sub-Manifold: 3430\n"
     ]
    }
   ],
   "source": [
    "# Construct the constrained coordinates\n",
    "# B_BP = ld.block_B([14, 14], [15, 15])\n",
    "# The array module is taken from P_class itself: `P` is only a loop temporary\n",
    "# (and is reassigned further down in this cell), so it is not a reliable handle\n",
    "# on the training data.\n",
    "B_BP = ld.default_B(S, 3, cp.get_array_module(P_class))\n",
    "# B_BP = ld.default_B(S, 3)\n",
    "print(f\"Dimension of Local Data Sub-Manifold: {D - B_BP.shape[0]}\")\n",
    "\n",
    "# Compute every datapoint's eta_hat (served as the linear constraints)\n",
    "eta_hat_class = []\n",
    "\n",
    "for class_ in range(num_class):\n",
    "    eta_hat_list = []\n",
    "    for i in range(P_class[class_].shape[0]):\n",
    "        xp = cp.get_array_module(P_class[class_][i])\n",
    "        # xp = np\n",
    "        # eps keeps every entry strictly positive before dividing by the sample's scale\n",
    "        P = (P_class[class_][i] + eps) / scaleX_class[class_][i]\n",
    "        eta_hat = ld.get_eta(P, len(S), xp)\n",
    "        eta_hat_list.append(eta_hat)\n",
    "\n",
    "    # One CuPy array per class, stacking the eta_hat of every sample\n",
    "    eta_hat_class.append(cp.asarray(eta_hat_list))\n",
    "    # eta_hat_class.append(np.asarray(eta_hat_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Backward Projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if load:\n",
    "    # Load all intermediate results from disk; the context manager closes the\n",
    "    # .npz archive once the arrays have been read.\n",
    "    with np.load('results_BP.npz') as results_BP:\n",
    "        sampled_P_BP_class = results_BP['sampled_P_BP_class']\n",
    "        sampled_theta_BP_class = results_BP['sampled_theta_BP_class']\n",
    "        sampled_X_BP_class = results_BP['sampled_X_BP_class']\n",
    "else:\n",
    "    import tempfile\n",
    "\n",
    "    # Scratch directory for joblib. Use the dedicated location when it exists and\n",
    "    # fall back to the system default so the cell also runs on other machines.\n",
    "    temp_root = '/data/pbb/tmp'\n",
    "    temp_dir = tempfile.mkdtemp(dir=temp_root if os.path.isdir(temp_root) else None)\n",
    "\n",
    "\n",
    "    def BP_helper(i, class_):\n",
    "        \"\"\"Backward-project sampled theta `i` of class `class_`; return (P, theta, X_recons_).\"\"\"\n",
    "        # Indices of the k nearest training thetas; their scale and eta_hat are averaged\n",
    "        N = ld.kNN(sampled_theta_class[class_][i], theta_class[class_], k=k)\n",
    "        avg_scale = np.mean(scaleX_class[class_][N])\n",
    "        avg_eta_hat = np.mean(eta_hat_class[class_][N], axis=0)\n",
    "        _, _, P, theta = ld.BP(sampled_theta_class[class_][i], [(P_class[class_][j] + eps) / scaleX_class[class_][j] for j in N], avg_eta_hat, avg_scale, B=B_BP, verbose=False, n_iter=1000, lr=1e-1)\n",
    "        # NOTE(review): the training data is min-max scaled to [0, 1]; if P is on that\n",
    "        # scale, the int32 cast truncates almost every value to 0 -- confirm intent.\n",
    "        X_recons_ = (P).astype(np.int32).reshape(-1)\n",
    "        return (P, theta, X_recons_)\n",
    "\n",
    "    # Jobs are ordered sample-major: the result for (i, class_) is at index i*num_class + class_\n",
    "    results = Parallel(n_jobs=10, temp_folder=temp_dir)(delayed(BP_helper)(i, class_) for i in range(new_sample_size_per_class) for class_ in range(num_class))\n",
    "\n",
    "    sampled_P_BP_class = []\n",
    "    sampled_theta_BP_class = []\n",
    "    sampled_X_BP_class = []\n",
    "\n",
    "    # Regroup the flat result list into one array per class\n",
    "    for class_ in range(num_class):\n",
    "        sampled_P_BP = []\n",
    "        sampled_theta_BP = []\n",
    "        sampled_X_BP = []\n",
    "        for i in range(new_sample_size_per_class):\n",
    "            result = results[i*num_class + class_]\n",
    "\n",
    "            sampled_P_BP.append(result[0])\n",
    "            sampled_theta_BP.append(result[1])\n",
    "            sampled_X_BP.append(result[2])\n",
    "\n",
    "        sampled_P_BP_class.append(np.array(sampled_P_BP))\n",
    "        sampled_theta_BP_class.append(np.array(sampled_theta_BP))\n",
    "        sampled_X_BP_class.append(np.array(sampled_X_BP))\n",
    "\n",
    "    # Cache the results so later runs can set load = True\n",
    "    np.savez('results_BP.npz', sampled_P_BP_class=sampled_P_BP_class, sampled_theta_BP_class=sampled_theta_BP_class, sampled_X_BP_class=sampled_X_BP_class)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Augmentation with Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpeechCommandsDataset(Dataset):\n",
    "    \"\"\"Speech Commands clips, length-normalised and sub-sampled once up front.\n",
    "\n",
    "    Every clip is padded or truncated to `target_length` samples and then every\n",
    "    `step`-th sample is kept, so with the defaults each item is a [1, 4000]\n",
    "    tensor. All clips are held in memory as one stacked tensor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, subset='training', transform=None, target_length=16000, step=4):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            subset (str): 'training' or 'testing'\n",
    "            transform (callable, optional): Optional transform to be applied\n",
    "            target_length (int): Number of samples each clip is padded or truncated to\n",
    "            step (int): Keep every `step`-th sample after padding/truncation\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.dataset = torchaudio.datasets.SPEECHCOMMANDS(\n",
    "            root='./data',\n",
    "            download=True,\n",
    "            subset=subset\n",
    "        )\n",
    "        self.transform = transform\n",
    "\n",
    "        # Precompute all waveforms with subsampling. The label names are collected\n",
    "        # in the same pass, so the audio files are decoded once instead of twice.\n",
    "        print(f\"Precomputing features for {subset} set...\")\n",
    "        waveforms = []\n",
    "        label_names = []\n",
    "\n",
    "        for waveform, _, label, *_ in tqdm(self.dataset):\n",
    "            # Pad/truncate to target_length first\n",
    "            length = waveform.shape[-1]\n",
    "            if length > target_length:\n",
    "                waveform = waveform[:, :target_length]\n",
    "            elif length < target_length:\n",
    "                waveform = torch.nn.functional.pad(waveform, (0, target_length - length))\n",
    "\n",
    "            # Subsample by taking every step-th sample\n",
    "            waveforms.append(waveform[:, ::step])\n",
    "            label_names.append(label)\n",
    "\n",
    "        # Class indices follow the sorted order of the unique label names\n",
    "        self.label_to_idx = {label: idx for idx, label in enumerate(sorted(set(label_names)))}\n",
    "\n",
    "        # Convert to tensors\n",
    "        self.precomputed_waveforms = torch.stack(waveforms)\n",
    "        self.labels = torch.tensor([self.label_to_idx[name] for name in label_names], dtype=torch.long)\n",
    "\n",
    "        print(f\"Precomputed {len(self.precomputed_waveforms)} waveforms\")\n",
    "        print(f\"Waveform shape: {self.precomputed_waveforms[0].shape}\")  # [1, 4000] with the defaults\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.precomputed_waveforms)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        waveform = self.precomputed_waveforms[idx]\n",
    "        label = self.labels[idx]\n",
    "\n",
    "        if self.transform:\n",
    "            waveform = self.transform(waveform)\n",
    "\n",
    "        return waveform, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Two-layer MLP encoder: input_size -> hidden_size (ReLU) -> z_dim.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size=D, hidden_size=B_LD.shape[0], z_dim=3):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)\n",
    "        self.fc2 = nn.Linear(in_features=hidden_size, out_features=z_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The latent code is returned without an output activation\n",
    "        hidden = self.relu(self.fc1(x))\n",
    "        return self.fc2(hidden)\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Two-layer MLP decoder: z_dim -> hidden_size (ReLU) -> output_size (sigmoid).\"\"\"\n",
    "\n",
    "    def __init__(self, output_size=D, hidden_size=B_LD.shape[0], z_dim=3):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(in_features=z_dim, out_features=hidden_size)\n",
    "        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.relu(self.fc1(x))\n",
    "        # The sigmoid keeps every output inside (0, 1)\n",
    "        return torch.sigmoid(self.fc2(hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precomputing features for training set...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 84843/84843 [04:00<00:00, 352.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precomputed 84843 waveforms\n",
      "Waveform shape: torch.Size([1, 4000])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [10:00<00:00,  6.00s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fdbae9fec30>]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuY0lEQVR4nO3dfXBc1X3/8c+9u9rVylqtkLH8EMugGMpDRIgDDjHOKJ7ATyEwKW6YtE1IB1IKoV0ZZGbS4KRp6GSomKYtaTvEmXQau/MLrgkpwsSTOKNiW8S/2AaECVETxGPAYGQDxruyHla7e8/vj32QBAa0knyv7fN+zdyRtXu5e/ZKaD/7Ped+1zHGGAEAAPjEDXoAAADALoQPAADgK8IHAADwFeEDAAD4ivABAAB8RfgAAAC+InwAAABfET4AAICvwkEP4O08z9OBAwcUj8flOE7QwwEAAFNgjNHg4KAWLVok133v2sYJFz4OHDigpqamoIcBAACmYf/+/Vq8ePF77nPChY94PC6pMPi6urqARwMAAKYinU6rqamp/Dr+Xk648FGaaqmrqyN8AABwkpnKkgkWnAIAAF8RPgAAgK8IHwAAwFeEDwAA4CvCBwAA8BXhAwAA+IrwAQAAfEX4AAAAviJ8AAAAXxE+AACArwgfAADAV4QPAADgqxPug+WOl9cHM/rezucUDYd0+2fODXo4AABYy5rKR3o0qw3/7/fatPeloIcCAIDVrAkfYbfwEb+eCXggAABYrqLw0dnZqeXLlysej6uxsVGrV69Wf3//O/bbvXu3PvWpT2nOnDmqq6tTa2urRkZGZm3Q0+E6hfCR87xAxwEAgO0qCh89PT1KJpPas2ePuru7lc1m1dbWpqGhofI+u3fv1hVXXKG2tjY9+uijeuyxx9Te3i7XDbbIEg4VKx9kDwAAAlXRgtNt27ZN+n7jxo1qbGxUb2+vWltbJUlr167VLbfcottvv7283znnnDMLQ52ZEJUPAABOCDMqR6RSKUlSQ0ODJOnQoUPau3evGhsbdemll2r+/Pn65Cc/qV27dr3rMTKZjNLp9KTteAhNWPNhDAs/AAAIyrTDh+d56ujo0MqVK9XS0iJJeuGFFyRJd9xxh2688UZt27ZNH/3oR3XZZZfp2WefPeZxOjs7lUgkyltTU9N0h/SeSuFDkvKsOgUAIDDTDh/JZFJ9fX3avHlz+TavOKXxla98RV/+8pe1bNky3X333TrnnHP0wx/+8JjHWbdunVKpVHnbv3//dIf0niaFDyofAAAEZlpNxtrb27V161Y98sgjWrx4cfn2hQsXSpLOP//8Sfufd955evnll495rGg0qmg0Op1hVITKBwAAJ4aKKh/GGLW3t6urq0vbt29Xc3PzpPvPPPNMLVq06B2X3z7zzDM644wzZj7aGZgYPnKEDwAAAlNR5SOZTGrTpk3asmWL4vG4BgYGJEmJREKxWEyO4+irX/2qvvWtb+nCCy/URz7yEf3nf/6nnn76af3kJz85Lk9gqsITLvX1CB8AAASmovCxfv16SdKqVasm3b5hwwZdf/31kqSOjg6Njo5q7dq1Onz4sC688EJ1d3dr6dKlszLg6ZpQ+KDyAQBAgCoKH1O9RPX222+f1OfjROA4jkKuo7xnqHwAABAgaz7bRZrYaIzwAQBAUOwKH8W5F652AQAgOIQPAADgKzvDB03GAAAIjJ3hg8oHAACBIXwAAABf2RU+HMIHAABBsyt8uFxqCwBA0KwKH+EQlQ8AAIJmVfhg2gUAgODZFT5YcAoAQOAIHwAAwFd2hg+ajAEAEBg7w4fnBTwSAADsZWn4CHggAABYzK7w4VD5AAAgaHaFD5qMAQAQOKvCB03GAAAInlXhw6XJGAAAgbMqfITp8wEAQOCsCh80GQMAIHh2hg+ajAEAEBg7wweVDwAAAmNZ+Cg8XcIHAADBsSt8FAofhA8AAAJkV/ig8gEAQOAsCx+Fr3Q4
BQAgOJaFDyofAAAEzarwQZMxAACCZ1X44FJbAACCZ2f4oMkYAACBsTN8UPkAACAwhA8AAOAru8KHQ/gAACBodoUPKh8AAATOyvBBkzEAAIJjZfjIe17AIwEAwF5WhY/xJmMBDwQAAItZFT6ofAAAELyKwkdnZ6eWL1+ueDyuxsZGrV69Wv39/ZP2WbVqlRzHmbTdfPPNszro6RpvMhbwQAAAsFhF4aOnp0fJZFJ79uxRd3e3stms2traNDQ0NGm/G2+8Ua+99lp5+4d/+IdZHfR0UfkAACB44Up23rZt26TvN27cqMbGRvX29qq1tbV8e01NjRYsWDA7I5xFXGoLAEDwZrTmI5VKSZIaGhom3X7vvffq9NNPV0tLi9atW6fh4eGZPMysockYAADBq6jyMZHneero6NDKlSvV0tJSvv2LX/yizjjjDC1atEhPPfWUvva1r6m/v18PPPDAMY+TyWSUyWTK36fT6ekO6X1R+QAAIHjTDh/JZFJ9fX3atWvXpNtvuumm8r8vuOACLVy4UJdddpmef/55LV269B3H6ezs1N/93d9NdxgVockYAADBm9a0S3t7u7Zu3aodO3Zo8eLF77nvJZdcIkl67rnnjnn/unXrlEqlytv+/funM6QpofIBAEDwKqp8GGO0Zs0adXV1aefOnWpubn7f/+bJJ5+UJC1cuPCY90ejUUWj0UqGMW2EDwAAgldR+Egmk9q0aZO2bNmieDyugYEBSVIikVAsFtPzzz+vTZs26corr9TcuXP11FNPae3atWptbdWHP/zh4/IEKhEmfAAAELiKwsf69eslFRqJTbRhwwZdf/31ikQi+p//+R9997vf1dDQkJqamnTNNdfob/7mb2ZtwDMRcguzTHlD+AAAICgVT7u8l6amJvX09MxoQMdTqLjChcoHAADBseyzXYqVD8IHAACBsSt80GQMAIDA2RU+WHAKAEDgCB8AAMBXdoYPrnYBACAwVoaPXJ7wAQBAUKwKHzQZAwAgeFaFD9dh2gUAgKBZFT7CISofAAAEzarwwdUuAAAEz67wQZMxAAACZ1f4oPIBAEDgCB8AAMBXVoWPME3GAAAInFXhw51Q+TAEEAAAAmFV+ChVPiSmXgAACIpV4cOdGD6ofAAAEAirwgeVDwAAgmdV+AgRPgAACJxd4cMhfAAAEDS7wgeVDwAAAmdV+HAcR6X8QfgAACAYVoUPSQq7hafM1S4AAATDuvBRzB7K5QkfAAAEwbrwUap8eFQ+AAAIhHXho7TmI8eaDwAAAmFd+AiHims+CB8AAATCuvDhOuMfLgcAAPxnXfgIu4QPAACCZF34CBE+AAAIlLXhgwWnAAAEw7rwUZp24VJbAACCYV34cEuVD5qMAQAQCOvCB5UPAACCZV34KF1qy5oPAACCYV34CIdKV7t4AY8EAAA7WRc+xpuMBTwQAAAsZV34GG8yRvoAACAI1oWP8SZjAQ8EAABLWRs+clQ+AAAIREXho7OzU8uXL1c8HldjY6NWr16t/v7+Y+5rjNFnPvMZOY6jBx98cDbGOitCXGoLAECgKgofPT09SiaT2rNnj7q7u5XNZtXW1qahoaF37Pvd735XTnFx54kkRJMxAAACFa5k523btk36fuPGjWpsbFRvb69aW1vLtz/55JP6p3/6Jz3++ONauHDh7Ix0ltBkDACAYFUUPt4ulUpJkhoaGsq3DQ8P64tf/KLuueceLViw4H2PkclklMlkyt+n0+mZDOl90WQMAIBgTXvBqed56ujo0MqVK9XS0lK+fe3atbr00kt19dVXT+k4nZ2dSiQS5a2pqWm6Q5qS8SZjhA8AAIIw7cpHMplUX1+fdu3aVb7toYce0vbt27Vv374pH2fdunW67bbbyt+n0+njGkDGm4wRPgAACMK0Kh/t7e3aunWrduzYocWLF5dv3759u55//nnV19crHA4rHC5km2uuuUarVq065rGi0ajq6uombcfTeJMxwgcAAEGoqPJhjNGa
NWvU1dWlnTt3qrm5edL9t99+u/7iL/5i0m0XXHCB7r77bn32s5+d+WhngUv4AAAgUBWFj2QyqU2bNmnLli2Kx+MaGBiQJCUSCcViMS1YsOCYi0yXLFnyjqASlLDLglMAAIJU0bTL+vXrlUqltGrVKi1cuLC83XfffcdrfLMu5Baeskf4AAAgEBVPu1RqOv/N8RQqxi0qHwAABMO6z3YJlyofJ1goAgDAFtaFD5qMAQAQLOvCR6nJGGs+AAAIhnXhg8oHAADBsi580GQMAIBgWRc+aDIGAECwrAsfNBkDACBY1oWPkMuCUwAAgmRt+KDyAQBAMKwLH6VpF5qMAQAQDOvCB5faAgAQLOvCB03GAAAIlnXhY7zy4QU8EgAA7GRd+KDJGAAAwbIufNBkDACAYFkXPmgyBgBAsKwLHyEutQUAIFDWho9cnvABAEAQrAsfNBkDACBY1oUPmowBABAs68IHTcYAAAiWdeGDygcAAMGyLnyE3cJTps8HAADBsC58FLMH4QMAgIBYFz6ofAAAECzrwkeoVPngUlsAAAJhYfgoPGWajAEAEAzrwgdNxgAACJZ14YNLbQEACJZ14YMmYwAABMu68EHlAwCAYFkXPsprPggfAAAEwrrwEXKpfAAAECRrwwdNxgAACIa94YNLbQEACIS94cMzMgQQAAB8Z1/4KF7tIknMvAAA4D/7wkdoPHzkPC/AkQAAYCfrwkfpUltJInsAAOC/isJHZ2enli9frng8rsbGRq1evVr9/f2T9vnKV76ipUuXKhaLad68ebr66qv19NNPz+qgZ8J1qHwAABCkisJHT0+Pksmk9uzZo+7ubmWzWbW1tWloaKi8z0UXXaQNGzbod7/7nX7xi1/IGKO2tjbl8/lZH/x0UPkAACBYjpnBJR+vv/66Ghsb1dPTo9bW1mPu89RTT+nCCy/Uc889p6VLl77vMdPptBKJhFKplOrq6qY7tHdljFHzup9Jknr/5nLNrY3O+mMAAGCbSl6/Z7TmI5VKSZIaGhqOef/Q0JA2bNig5uZmNTU1zeShZo3jOCoVP2g0BgCA/6YdPjzPU0dHh1auXKmWlpZJ933ve99TbW2tamtr9fOf/1zd3d2KRCLHPE4mk1E6nZ60HW80GgMAIDjTDh/JZFJ9fX3avHnzO+679tprtW/fPvX09OgP/uAP9Md//McaHR095nE6OzuVSCTKmx8VkvLnu+QJHwAA+G1a4aO9vV1bt27Vjh07tHjx4nfcn0gkdPbZZ6u1tVU/+clP9PTTT6urq+uYx1q3bp1SqVR5279//3SGVJFSozGPygcAAL4LV7KzMUZr1qxRV1eXdu7cqebm5in9N8YYZTKZY94fjUYVjfq76JNPtgUAIDgVhY9kMqlNmzZpy5YtisfjGhgYkFSodMRiMb3wwgu677771NbWpnnz5umVV17RXXfdpVgspiuvvPK4PIHpCIcKBR+P8AEAgO8qmnZZv369UqmUVq1apYULF5a3++67T5JUXV2tX/7yl7ryyit11lln6U/+5E8Uj8f1q1/9So2NjcflCUxHqdEYlQ8AAPxX8bTLe1m0aJF+9rOfzWhAfghP+GRbAADgL+s+20WacKkt4QMAAN/ZHT642gUAAN/ZHT6ofAAA4DurwwdNxgAA8J+d4YMmYwAABMbO8EGTMQAAAmN1+KDJGAAA/rM6fFD5AADAf1aGD5qMAQAQHCvDh0v4AAAgMFaGjzBNxgAACIyV4WO8yZgX8EgAALCP1eGDJmMAAPjPzvBBkzEAAAJjZ/jgUlsAAAJjdfigyRgAAP6zOnxQ+QAAwH9Whg+ajAEAEBwrwwdNxgAACI6V4YMmYwAABMfK8FFuMkafDwAAfGd3+KDyAQCA7+wMHw5rPgAACIqd4cMtPG0utQUAwH+Who/CV5qMAQDgP0vDB5UPAACCYmn4KHxlzQcAAP6zNHwUnjbhAwAA/1kZPmgyBgBAcKwMHzQZAwAgOHaHDyofAAD4zs7w
QZMxAAACY2f4KFY+uNQWAAD/WR0+aDIGAID/rA4fOc8LeCQAANjH6vCRJ3sAAOA7y8MH6QMAAL9ZGT7Gm4wFPBAAACxkZfig8gEAQHAsDx+UPgAA8FtF4aOzs1PLly9XPB5XY2OjVq9erf7+/vL9hw8f1po1a3TOOecoFotpyZIluuWWW5RKpWZ94DNBkzEAAIJTUfjo6elRMpnUnj171N3drWw2q7a2Ng0NDUmSDhw4oAMHDugf//Ef1dfXp40bN2rbtm264YYbjsvgp4smYwAABCdcyc7btm2b9P3GjRvV2Nio3t5etba2qqWlRf/93/9dvn/p0qW688479aUvfUm5XE7hcEUPd9zQZAwAgODMKA2UplMaGhrec5+6urp3DR6ZTEaZTKb8fTqdnsmQpoTKBwAAwZn2glPP89TR0aGVK1eqpaXlmPu88cYb+va3v62bbrrpXY/T2dmpRCJR3pqamqY7pCljwSkAAMGZdvhIJpPq6+vT5s2bj3l/Op3WVVddpfPPP1933HHHux5n3bp1SqVS5W3//v3THdKUET4AAAjOtKZd2tvbtXXrVj3yyCNavHjxO+4fHBzUFVdcoXg8rq6uLlVVVb3rsaLRqKLR6HSGMW3lq10M4QMAAL9VVPkwxqi9vV1dXV3avn27mpub37FPOp1WW1ubIpGIHnroIVVXV8/aYGdLOETlAwCAoFRU+Ugmk9q0aZO2bNmieDyugYEBSVIikVAsFisHj+HhYf3oRz9SOp0uLyCdN2+eQqHQ7D+DaQi5hcxF+AAAwH8VhY/169dLklatWjXp9g0bNuj666/XE088ob1790qSzjrrrEn7vPjiizrzzDOnP9JZRJMxAACCU1H4MO+zRmLVqlXvu8+JgAWnAAAEh892AQAAvrI6fNBkDAAA/1kdPmivDgCA/6wMH2EqHwAABMbK8OG6NBkDACAoVoaPMAtOAQAIjJXhY+LVLifDpcEAAJxK7AwfxSZjkkTxAwAAf9kZPkLj4YOpFwAA/GVn+HAIHwAABMXO8OGOh4+c5wU4EgAA7GN9+CB7AADgLzvDh0PlAwCAoFgZPlzXUSl/0GgMAAB/WRk+JBqNAQAQFGvDh+sQPgAACIK14YPKBwAAwbA2fIQIHwAABILwQfgAAMBXFoePwlPnahcAAPxlcfgofM3lCR8AAPjJ2vARLlU+mHYBAMBX1oaPYvZg2gUAAJ9ZGz6ofAAAEAxrw0fps+UIHwAA+Mva8EHlAwCAYFgbPujzAQBAMAgfhA8AAHxF+CB8AADgK+vDR47wAQCAr6wPH1Q+AADwl73hwymGD5qMAQDgK2vDRzhUqnx4AY8EAAC7WBs+3FLlg+wBAICvrA0fYZfKBwAAQbA2fLgulQ8AAIJgbfig8gEAQDCsDR9cagsAQDCsDx80GQMAwF8VhY/Ozk4tX75c8XhcjY2NWr16tfr7+yft84Mf/ECrVq1SXV2dHMfRkSNHZnO8s4bKBwAAwagofPT09CiZTGrPnj3q7u5WNptVW1ubhoaGyvsMDw/riiuu0Ne//vVZH+xsoskYAADBCFey87Zt2yZ9v3HjRjU2Nqq3t1etra2SpI6ODknSzp07Z2WAx0u5yVie8AEAgJ9mtOYjlUpJkhoaGmZlMH5yqXwAABCIiiofE3mep46ODq1cuVItLS3THkAmk1Emkyl/n06np32sSoRZ8wEAQCCmXflIJpPq6+vT5s2bZzSAzs5OJRKJ8tbU1DSj402VS/gAACAQ0wof7e3t2rp1q3bs2KHFixfPaADr1q1TKpUqb/v375/R8aaKygcAAMGoaNrFGKM1a9aoq6tLO3fuVHNz84wHEI1GFY1GZ3ycSoXcQu4ifAAA4K+KwkcymdSmTZu0ZcsWxeNxDQwMSJISiYRisZgkaWBgQAMDA3ruueckSb/5zW8Uj8e1ZMmSE2phaqhY86HJGAAA/qpo2mX9+vVKpVJatWqVFi5cWN7uu+++8j7f//73tWzZMt14
442SpNbWVi1btkwPPfTQ7I58hkqVD4+rXQAA8FXF0y7v54477tAdd9wx3fH4ptRkjMoHAAD+svazXWgyBgBAMKwNHzQZAwAgGNaGDy61BQAgGNaGD5qMAQAQDGvDB5UPAACCYW34oPIBAEAwrA0fpcoHl9oCAOAva8NHqBg+aDIGAIC/7A0fNBkDACAQ1oaPcpMxzwt4JAAA2MXa8FFuMkblAwAAX1kbPrjUFgCAYFgbPrjUFgCAYFgbPqh8AAAQDGvDR7nywaW2AAD4ytrwUW4ylid8AADgJ2vDB03GAAAIhr3hgyZjAAAEwtrwUWoy5hE+AADwlbXhw6XyAQBAIKwNH2G38NS51BYAAH9ZGz6K2YPwAQCAz6wNH1Q+AAAIhrXhI1SqfHCpLQAAvrI4fBQrHzQZAwDAV/aGD4f26gAABMHe8BHiUlsAAIJgbfgofbYLTcYAAPCXteFjYpMxw9QLAAC+sTZ8lCofkkTxAwAA/1gbPtwJ4YNeHwAA+Mfa8BEmfAAAEAhrw0doYvhgzQcAAL4hfIhGYwAA+Mne8OFQ+QAAIAjWhg/XdVTKHznPC3YwAABYxNrwIU1sNBbwQAAAsIjV4WO80RjpAwAAv1gdPkqVDy61BQDAP1aHD5fwAQCA7yoKH52dnVq+fLni8bgaGxu1evVq9ff3T9pndHRUyWRSc+fOVW1tra655hodPHhwVgc9W6h8AADgv4rCR09Pj5LJpPbs2aPu7m5ls1m1tbVpaGiovM/atWv105/+VPfff796enp04MABfe5zn5v1gc+GUq8PLrUFAMA/4Up23rZt26TvN27cqMbGRvX29qq1tVWpVEr/8R//oU2bNulTn/qUJGnDhg0677zztGfPHn384x+fvZHPglL4yNFkDAAA38xozUcqlZIkNTQ0SJJ6e3uVzWZ1+eWXl/c599xztWTJEu3evfuYx8hkMkqn05M2v5QajXlUPgAA8M20w4fneero6NDKlSvV0tIiSRoYGFAkElF9ff2kfefPn6+BgYFjHqezs1OJRKK8NTU1TXdIFQuFSpfaEj4AAPDLtMNHMplUX1+fNm/ePKMBrFu3TqlUqrzt379/RserRLnyQfgAAMA3Fa35KGlvb9fWrVv1yCOPaPHixeXbFyxYoLGxMR05cmRS9ePgwYNasGDBMY8VjUYVjUanM4wZK6/5IHwAAOCbiiofxhi1t7erq6tL27dvV3Nz86T7L7roIlVVVenhhx8u39bf36+XX35ZK1asmJ0Rz6KwW3j6VD4AAPBPRZWPZDKpTZs2acuWLYrH4+V1HIlEQrFYTIlEQjfccINuu+02NTQ0qK6uTmvWrNGKFStOuCtdpPEmY1Q+AADwT0XhY/369ZKkVatWTbp9w4YNuv766yVJd999t1zX1TXXXKNMJqNPf/rT+t73vjcrg51tNBkDAMB/FYUPM4VLUqurq3XPPffonnvumfag/EJ7dQAA/Gf1Z7uEmXYBAMB3VocPmowBAOA/u8MHlQ8AAHxH+BCX2gIA4CfCh6h8AADgJ6vDR5jKBwAAvrM6fNBkDAAA/1kdPsabjHkBjwQAAHtYHT5oMgYAgP+sDh80GQMAwH9Whw+ajAEA4L+KPtvlVFO61Pa5Q0f1f/e8pP6BtPoHBpXzjM6cO6ewnV6jM+bO0aL6ap0+J1qeqgEAANND+JD048df0Y8ff2XSfftePvKO/SMhV/MTUS2si2lePKqaSKiwRcOqjYa1YulcLWuql+MQUAAAeDdWh48VS+fqgSde1dzaiM5dENc5C+p03sK4wq6rlw4P6fdvDOn3bw7rpTeHdGgwo7G8p/2HR7T/8Mi7HvP8hXX60sfP0NUfWaQ5UatPLwAAx+QYc2IteEin00okEkqlUqqrqzvuj5f3TLkC8l6yeU+HBjN67ciIDqRGdfhoRsPZvEbG8hrK5HUwParu3x3UWK5w2W48Gtb/+dB8zauNqjYaVrw6rNrqKtVEQqquclVdFVKsKqSa
SFinzanSaTURVYWsXoIDADiJVfL6bf1b86kED0mqCrn6QH1MH6iPves+bw2N6Se9r+jevS/p928O64EnXq1oLPFoWKfNiSgccpT3jHJ5o2y+EGbm1kbVGC9udVElYlWqCrkKh1xFQo6qQq6qQq4iYVeR0tewW7zdUaR4f3XVePiJht33nSIqZVOmkgAAs8X6ysfx4HlGv3r+TfW+9JaOZrI6mslpcLSwjWTzGi1WTEZzharJkeExBXW1byTsKuQ4CrmOXKcQxvKeUa4YfsbynhxHqquu0mk1Vaqviei0mipFwq48Uwgnec/IM4Wrhkzxq2eMqkJuoeITDas2WqXaaEhVIVehkFN+zLDrKFwMSGHXVTg03nvFM+Ot7yNhV9HweKiKVYUUi4Q0JxJWTSSkSNjVWN5TNm80lvM0lvOUzXvKFP89lveU9zyF3PEwFg4VjlkTKVShqiMhVYdDhZ/hhOciSY4jOXIKX53C+h8CGQCMo/IRMNd19ImzT9cnzj59Svt7nlF6NKvDQ2M6PDSmnGdUFXIUct1yL5LXj2b0ejqjQ4OjOjSY0dHRnMbyXrk6UnjhLb3omvILbnbC7WO5wovxxL4mpWmi92KMlBrJKjWSld4cnt5JOcVUV7maOyeq02sjmlsbVSwS0uhYXkNjOY2M5TWSzcuRo6rweFUq7DrFUOcp7xll84Wfc7Q4BVcKVKXqVOk2x5GyxZ9zziuEvfpYlebWRjW3NqLTayOKhkMaHstreCxX/JovPE4xHOY9I0dSTTRcCGzRkGqqQhrJ5pUayerIcGHL5PI6rSaihjkRNdRGNHdORLXRcGFc4cLYwiFHw5m8BjNZDY7mdDSTk6RCgCs9h3BhCrEUUD1TCLa10bCqqwhugO0IHycA13VUXxNRfU1EH5x3/B8vl/c0mvM0MpbXWN6TV3xxyhsjzzNyXUdVrquqcKEaYWSUGs7qreGs3hoe05HhQkBynUK1xHEcOVKxeuLILVZRsnmvXPE5msnp6Giu+OJZeBH1PKNs8cU4lx//t+NIbrka48iYwgt1JpcvB6iRbOEFdqT4gl+q30XCrqIhV1UTpp+qQo4i4dCkF/9ssUqSyXkaGStUpCqpPo1mPb16ZESvHnn3xcc4NteR5kQKgUZS+WeSe1uVKxouVLQkFatxnvJ5IyO9Y4qxENbHf2dCbuF3ciKjUkWt8PtujBStKgSlWLH6Var8lf5/MKZQ5YoVr2yLVYUUCjnl38NM1tNoLq/hTE5HM3kNZXIaGssplzfltVwNcwpbVchVLl94HrniOMLueDitCjly5Mio8LhmwvlyHUdO8f83Y1Q+RunNx5tDY3p9MFPecp6nptNq1NRQ3E6LaU40XD6W6zgKTZiOrSpO3WZy+fIbjfRI4blUueMVx4nnYk40XLziL6xoeHzaN1r8mU2sOmZznowKQbT0vPKeKZ/HifuV3jRlcoW/TeHQ5GplIlalxnhU8+LR91zUP5rNK116LqO5cmW39PsRDbvl5zAnEpbrFv7WDI+Vfo55GWPUMCeiuuqq49ZmwfOMDg+PKTWSLVeYw++z/s8Yo5FsXoOjOTmOTso2EIQPC4VDrmpDrmoruBqnMV59HEc0M8YUponCrjPtd9TGFKaYRsc8yZn4B1/veEHIe0bpkazeOJrRm0fH9OZQRkOZvOZEQ4pFwppTfJGSNGkqKOd5hWkn1ylXtvKep5GxQpgayeY1MpbTaNYrTM0Vp+eMUTkIRsKuHElHhrN6c6jw2G8czWgs55Vf0Etfq0KFP7ThUCnEScNjhT+qpQpJdTik+poq1ddUKRGrUiTk6q3hQhXuzaExHR7KaDhTmCLM5iens2i4MK1WeFFzyhWfkWy+XFErBUlH452EPSMNZnIaLFZM3m54LD+tnyEmO5jO6PGX3gp6GMdVTSSk02oiklSeLvaM0XDx/6NKRItTt8daiOA60mk1EdXXFNbald5wSYW/CWO5fDlEZXLeeIAthlhJmhMJKV5d
VZiGrg5rNJvX64MZvTk0NukjPhxHqo9VqWFORLFIqBhYTTmsDY0V3tBN/G8iYVeL62P6wGnj6xJLldChTF6ZXH7SG8OQ62hRIqbvfP7Cis7RbCJ84KTnOIUX85keIxoOKVpc8/F+ErEqNTXUzOgxTza54rvRbN5TTSRcrkocy7EWKnte4d3aUPEP4lAmJ9dxylWLsFs43lg+r9Fs4R1wJluohJWqGqVpyEnreUp/7CdMMb3985qMVK7Uld79Ok5h2rEU8krPLewW/kiH3rGPp5FsoaoRrXKLvy+Fr6VKwJxoSLXRsEKuUw6IbxWDXOFdfOE5hFxXrlMIZNli5a/wwmfGX9gmvLp5ZnxdlaTyNF6pKnBaTUTzitWAefGoXMfRK28N6+XDw9p/eESvvDWsTNYrr8fKm8IHapYet7RGKhoOKRErBNG6WFg1kXCxUlioEhYCeuFnWKoQDI/ly5WLY4kUx1h6Z+4Ufy9cR5MqKpFipatQuSxUOkKuMz7lmDfK5D2lhsd0aDBTnl4cHnv36qPrSHWxqvIbLa/8eyJlcoXxl35VMhOmoJ1idU6Sjhb3ebP4c5yu9GhO6dFjB26pcMHB0WIV961ipfn9lIodYzlPL7wxpBfeGJryeJbOmzPlfY8HwgeAKQkXy95TcawKlOs6xRfosBSf7dHh7T7SVO/r45WmR0shqhQqjtf6nqFMTq8PZvTW8Fh5KqW0xapCqotVKR4Nv+d0hDGFUFUKUdVVoUIFsypUHvdYztOR4TEdHi6uycuPTx2VwmJpMXy0uIVdd9J4jJGGxgpTz4ULELKKhF01xqs1Lx4tT8vlPaO3io/zxtGMMjlPVcVjlYJroXVDleLVhSmjnGc0kBrVK28VQuaBI6MKuSpXYWuihWmx0tqr0tRjKVwFhatdAADAjFXy+k1XKwAA4CvCBwAA8BXhAwAA+IrwAQAAfEX4AAAAviJ8AAAAXxE+AACArwgfAADAV4QPAADgK8IHAADwFeEDAAD4ivABAAB8RfgAAAC+CvYzdY+h9CG76XQ64JEAAICpKr1ul17H38sJFz4GBwclSU1NTQGPBAAAVGpwcFCJROI993HMVCKKjzzP04EDBxSPx+U4zqweO51Oq6mpSfv371ddXd2sHhuTca79w7n2D+faP5xr/8zWuTbGaHBwUIsWLZLrvveqjhOu8uG6rhYvXnxcH6Ouro5fZp9wrv3DufYP59o/nGv/zMa5fr+KRwkLTgEAgK8IHwAAwFdWhY9oNKpvfetbikajQQ/llMe59g/n2j+ca/9wrv0TxLk+4RacAgCAU5tVlQ8AABA8wgcAAPAV4QMAAPiK8AEAAHxlTfi45557dOaZZ6q6ulqXXHKJHn300aCHdNLr7OzU8uXLFY/H1djYqNWrV6u/v3/SPqOjo0omk5o7d65qa2t1zTXX6ODBgwGN+NRx1113yXEcdXR0lG/jXM+eV199VV/60pc0d+5cxWIxXXDBBXr88cfL9xtj9Ld/+7dauHChYrGYLr/8cj377LMBjvjklM/n9c1vflPNzc2KxWJaunSpvv3tb0/6bBDO9fQ98sgj+uxnP6tFixbJcRw9+OCDk+6fyrk9fPiwrr32WtXV1am+vl433HCDjh49OvPBGQts3rzZRCIR88Mf/tD87//+r7nxxhtNfX29OXjwYNBDO6l9+tOfNhs2bDB9fX3mySefNFdeeaVZsmSJOXr0aHmfm2++2TQ1NZmHH37YPP744+bjH/+4ufTSSwMc9cnv0UcfNWeeeab58Ic/bG699dby7Zzr2XH48GFzxhlnmOuvv97s3bvXvPDCC+YXv/iFee6558r73HXXXSaRSJgHH3zQ/PrXvzZ/+Id/aJqbm83IyEiAIz/53HnnnWbu3Llm69at5sUXXzT333+/qa2tNf/yL/9S3odzPX0/+9nPzDe+8Q3zwAMPGEmmq6tr0v1TObdXXHGFufDCC82ePXvML3/5S3PWWWeZL3zhCzMemxXh
42Mf+5hJJpPl7/P5vFm0aJHp7OwMcFSnnkOHDhlJpqenxxhjzJEjR0xVVZW5//77y/v87ne/M5LM7t27gxrmSW1wcNCcffbZpru723zyk58shw/O9ez52te+Zj7xiU+86/2e55kFCxaY73znO+Xbjhw5YqLRqPmv//ovP4Z4yrjqqqvMn//5n0+67XOf+5y59tprjTGc69n09vAxlXP729/+1kgyjz32WHmfn//858ZxHPPqq6/OaDyn/LTL2NiYent7dfnll5dvc11Xl19+uXbv3h3gyE49qVRKktTQ0CBJ6u3tVTabnXTuzz33XC1ZsoRzP03JZFJXXXXVpHMqca5n00MPPaSLL75Yn//859XY2Khly5bp3//938v3v/jiixoYGJh0rhOJhC655BLOdYUuvfRSPfzww3rmmWckSb/+9a+1a9cufeYzn5HEuT6epnJud+/erfr6el188cXlfS6//HK5rqu9e/fO6PFPuA+Wm21vvPGG8vm85s+fP+n2+fPn6+mnnw5oVKcez/PU0dGhlStXqqWlRZI0MDCgSCSi+vr6SfvOnz9fAwMDAYzy5LZ582Y98cQTeuyxx95xH+d69rzwwgtav369brvtNn3961/XY489pltuuUWRSETXXXdd+Xwe628K57oyt99+u9LptM4991yFQiHl83ndeeeduvbaayWJc30cTeXcDgwMqLGxcdL94XBYDQ0NMz7/p3z4gD+SyaT6+vq0a9euoIdyStq/f79uvfVWdXd3q7q6OujhnNI8z9PFF1+sv//7v5ckLVu2TH19ffr+97+v6667LuDRnVp+/OMf695779WmTZv0oQ99SE8++aQ6Ojq0aNEizvUp7pSfdjn99NMVCoXeser/4MGDWrBgQUCjOrW0t7dr69at2rFjhxYvXly+fcGCBRobG9ORI0cm7c+5r1xvb68OHTqkj370owqHwwqHw+rp6dG//uu/KhwOa/78+ZzrWbJw4UKdf/75k24777zz9PLLL0tS+XzyN2XmvvrVr+r222/Xn/7pn+qCCy7Qn/3Zn2nt2rXq7OyUxLk+nqZybhcsWKBDhw5Nuj+Xy+nw4cMzPv+nfPiIRCK66KKL9PDDD5dv8zxPDz/8sFasWBHgyE5+xhi1t7erq6tL27dvV3Nz86T7L7roIlVVVU069/39/Xr55Zc59xW67LLL9Jvf/EZPPvlkebv44ot17bXXlv/NuZ4dK1eufMcl488884zOOOMMSVJzc7MWLFgw6Vyn02nt3buXc12h4eFhue7kl6FQKCTP8yRxro+nqZzbFStW6MiRI+rt7S3vs337dnmep0suuWRmA5jRctWTxObNm000GjUbN240v/3tb81NN91k6uvrzcDAQNBDO6n95V/+pUkkEmbnzp3mtddeK2/Dw8PlfW6++WazZMkSs337dvP444+bFStWmBUrVgQ46lPHxKtdjOFcz5ZHH33UhMNhc+edd5pnn33W3Hvvvaampsb86Ec/Ku9z1113mfr6erNlyxbz1FNPmauvvprLP6fhuuuuMx/4wAfKl9o+8MAD5vTTTzd//dd/Xd6Hcz19g4ODZt++fWbfvn1Gkvnnf/5ns2/fPvPSSy8ZY6Z2bq+44gqzbNkys3fvXrNr1y5z9tlnc6ltJf7t3/7NLFmyxEQiEfOxj33M7NmzJ+ghnfQkHXPbsGFDeZ+RkRHzV3/1V+a0004zNTU15o/+6I/Ma6+9FtygTyFvDx+c69nz05/+1LS0tJhoNGrOPfdc84Mf/GDS/Z7nmW9+85tm/vz5JhqNmssuu8z09/cHNNqTVzqdNrfeeqtZsmSJqa6uNh/84AfNN77xDZPJZMr7cK6nb8eOHcf8G33dddcZY6Z2bt98803zhS98wdTW1pq6ujrz5S9/2QwODs54bI4xE1rJAQAAHGen/JoPAABwYiF8AAAAXxE+AACArwgfAADAV4QPAADgK8IHAADwFeEDAAD4ivABAAB8RfgAAAC+InwA
AABfET4AAICvCB8AAMBX/x858r5J2VlJ6AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_loader_original = DataLoader(SpeechCommandsDataset(subset='training'), batch_size=32, shuffle=True)\n",
    "\n",
    "enc = Encoder().to(device)\n",
    "dec = Decoder().to(device)\n",
    "loss_fn = nn.MSELoss()\n",
    "optimizer_enc = torch.optim.Adam(enc.parameters(), lr=1e-3)\n",
    "optimizer_dec = torch.optim.Adam(dec.parameters(), lr=1e-3)\n",
    "\n",
    "train_loss = []\n",
    "num_epochs = 100\n",
    "\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    train_epoch_loss = 0\n",
    "    for (x , _) in train_loader_original:\n",
    "        x = x.to(device)\n",
    "        x = x.flatten(1)\n",
    "        latents = enc(x)\n",
    "        output = dec(latents)\n",
    "        loss = loss_fn(output , x)\n",
    "        train_epoch_loss += loss.cpu().detach().numpy()\n",
    "        optimizer_enc.zero_grad()\n",
    "        optimizer_dec.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_enc.step()\n",
    "        optimizer_dec.step()\n",
    "    train_loss.append(train_epoch_loss)\n",
    "plt.plot(train_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "representation = None\n",
    "all_labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for (xs , labels) in train_loader_original:\n",
    "        xs = xs.to(device)\n",
    "        xs = xs.flatten(1)\n",
    "        all_labels.extend(list(labels.numpy()))\n",
    "        latents = enc(xs)\n",
    "        if representation is None:\n",
    "            representation = latents.cpu()\n",
    "        else:\n",
    "            representation = torch.vstack([representation , latents.cpu()])\n",
    "\n",
    "all_labels = np.array(all_labels)\n",
    "representation = representation.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampled_X_AE_class = []\n",
    "\n",
    "for class_ in range(num_class):\n",
    "    sampled_X_AE_list = []\n",
    "\n",
    "    rep = representation[np.argwhere(all_labels == class_)].squeeze()\n",
    "    # Fit a KDE to the theta values\n",
    "    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth_AE).fit(rep)\n",
    "\n",
    "    # Sample new data from the KDE\n",
    "    sampled_rep = kde.sample(n_samples=new_sample_size_per_class)\n",
    "    for i in range(new_sample_size_per_class):\n",
    "        pred = dec(torch.Tensor(sampled_rep[i])[None , ...].to(device)).cpu().detach().numpy()\n",
    "        sampled_X_AE_list.append(pred.flatten())\n",
    "\n",
    "    sampled_X_AE_class.append(sampled_X_AE_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification Performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precomputing features for training set...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 84843/84843 [03:59<00:00, 354.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precomputed 84843 waveforms\n",
      "Waveform shape: torch.Size([1, 4000])\n",
      "Precomputing features for testing set...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11005/11005 [00:30<00:00, 357.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precomputed 11005 waveforms\n",
      "Waveform shape: torch.Size([1, 4000])\n",
      "standard loader:\n",
      "Images shape: torch.Size([64, 1, 4000]), Labels shape: torch.Size([64])\n",
      "none loader:\n",
      "Images shape: torch.Size([64, 1, 4000]), Labels shape: torch.Size([64])\n",
      "PNL_augmented loader:\n",
      "Images shape: torch.Size([64, 1, 4000]), Labels shape: torch.Size([64])\n",
      "AE_augmented loader:\n",
      "Images shape: torch.Size([64, 1, 4000]), Labels shape: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "# Helper function to create a DataLoader for augmented datasets\n",
    "def prepare_augmented_dataset(original_dataset, augmented_data, augmented_labels, transform, batch_size, augmented_ratio=1.0):\n",
    "    # Ensure augmented data has the correct shape [N, 1, 4000]\n",
    "    if len(augmented_data.shape) == 1:\n",
    "        augmented_data = augmented_data.reshape(1, 1, -1)\n",
    "    elif len(augmented_data.shape) == 2:\n",
    "        augmented_data = augmented_data.reshape(augmented_data.shape[0], 1, -1)\n",
    "\n",
    "    # Reduce the size of augmented data according to the augmented_ratio\n",
    "    total_augmented = len(augmented_data)\n",
    "    num_selected = int(total_augmented * augmented_ratio)\n",
    "\n",
    "    if num_selected < total_augmented:  # Sample only if reducing\n",
    "        indices = np.random.choice(total_augmented, num_selected, replace=False)\n",
    "        augmented_data = augmented_data[indices]\n",
    "        augmented_labels = augmented_labels[indices]\n",
    "\n",
    "    # Create dataset from augmented data\n",
    "    augmented_dataset = NumpyDataset(features=augmented_data, labels=augmented_labels, transform=transform)\n",
    "\n",
    "    # Combine with the original dataset\n",
    "    combined_dataset = ConcatDataset([original_dataset, augmented_dataset])\n",
    "\n",
    "    # Create a DataLoader for the combined dataset\n",
    "    return DataLoader(dataset=combined_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Main function to prepare datasets and loaders\n",
    "def create_datasets_and_loaders(batch_size=64, augmented_ratio=1.0):\n",
    "    # Define transformations\n",
    "    train_dataset = SpeechCommandsDataset(subset='training')\n",
    "    test_dataset = SpeechCommandsDataset(subset='testing')\n",
    "\n",
    "    # Calculate normalization statistics from training data\n",
    "    mean = torch.mean(train_dataset.precomputed_waveforms)\n",
    "    std = torch.std(train_dataset.precomputed_waveforms)\n",
    "\n",
    "    transform_standard = CompositeTransform([\n",
    "        Normalize(mean, std),\n",
    "        AddGaussianNoise(mean=0.0, std=std.min()/4)  # Add Gaussian noise with 1/4 of the standard deviation\n",
    "    ])\n",
    "    transform_none = Normalize(mean, std)\n",
    "\n",
    "    train_standard = train_dataset\n",
    "    train_standard.transform = transform_standard\n",
    "    train_none = train_dataset\n",
    "    train_none.transform = transform_none\n",
    "\n",
    "    test_dataset.transform = transform_none\n",
    "\n",
    "    augmented_data_PNL, labels_PNL = [], []\n",
    "    augmented_data_AE, labels_AE = [], []\n",
    "\n",
    "    for class_ in range(num_class):\n",
    "        for data_PNL in sampled_X_BP_class[class_]:\n",
    "            augmented_data_PNL.append(data_PNL)\n",
    "            labels_PNL.append(class_)\n",
    "\n",
    "        for data_AE in sampled_X_AE_class[class_]:\n",
    "            augmented_data_AE.append(data_AE)\n",
    "            labels_AE.append(class_)\n",
    "\n",
    "    augmented_data_PNL = np.array(augmented_data_PNL)\n",
    "    labels_PNL = np.array(labels_PNL)\n",
    "    augmented_data_AE = np.array(augmented_data_AE)\n",
    "    labels_AE = np.array(labels_AE)\n",
    "\n",
    "    mean_LD = augmented_data_PNL.mean(axis=0)\n",
    "    std_LD = augmented_data_PNL.std(axis=0)\n",
    "    transform_none_LD = Normalize(mean_LD, std_LD)\n",
    "\n",
    "    mean_AE = augmented_data_AE.mean(axis=0)\n",
    "    std_AE = augmented_data_AE.std(axis=0)\n",
    "    transform_none_AE = Normalize(mean_AE, std_AE)\n",
    "\n",
    "    # Combine Speech Commands dataset with a subset of augmented data\n",
    "    train_loader_PNL = prepare_augmented_dataset(train_none, augmented_data_PNL, labels_PNL, transform_none_LD, batch_size, augmented_ratio)\n",
    "    train_loader_AE = prepare_augmented_dataset(train_none, augmented_data_AE, labels_AE, transform_none_AE, batch_size, augmented_ratio)\n",
    "\n",
    "    # DataLoader for Speech Commands only\n",
    "    train_loader_standard = DataLoader(train_standard, batch_size=batch_size, shuffle=True)\n",
    "    train_loader_none = DataLoader(train_none, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    return {\n",
    "        \"standard\": train_loader_standard,\n",
    "        \"none\": train_loader_none,\n",
    "        \"PNL_augmented\": train_loader_PNL,\n",
    "        \"AE_augmented\": train_loader_AE,\n",
    "    }, test_dataset\n",
    "\n",
    "# Create dictionaries of loaders for different datasets\n",
    "train_loader_standard = {}\n",
    "train_loader_none = {}\n",
    "train_loader_PNL = {}\n",
    "train_loader_AE = {}\n",
    "\n",
    "# Example Usage: Testing different dataset sizes\n",
    "for ratio in [0.0, 0.25, 0.5, 0.75, 1.0]:  # Testing different proportions of augmented data\n",
    "    print(f\"Testing with augmented_ratio = {ratio}\")\n",
    "    loaders, test_dataset = create_datasets_and_loaders(batch_size=64, augmented_ratio=ratio)\n",
    "\n",
    "    for name, loader in loaders.items():\n",
    "        print(f\"{name} loader:\")\n",
    "        for data, labels in loader:\n",
    "            print(f\"data shape: {data.shape}, Labels shape: {labels.shape}\")\n",
    "            break\n",
    "\n",
    "    train_loader_standard[ratio] = loaders[\"standard\"]\n",
    "    train_loader_none[ratio] = loaders[\"none\"]\n",
    "    train_loader_PNL[ratio] = loaders[\"PNL_augmented\"]\n",
    "    train_loader_AE[ratio] = loaders[\"AE_augmented\"]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ResNet18 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "class M5(nn.Module):\n",
    "    \"\"\"\n",
    "    M5 model architecture adapted for subsampled audio (4000 samples instead of 16000)\n",
    "    \"\"\"\n",
    "    def __init__(self, n_input=1, n_output=35, stride=4, n_channel=32):\n",
    "        super().__init__()\n",
    "\n",
    "        # Adjusted kernel sizes and strides for 4000-sample input\n",
    "        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=20, stride=stride).float()\n",
    "        self.bn1 = nn.BatchNorm1d(n_channel).float()\n",
    "        self.pool1 = nn.MaxPool1d(4)\n",
    "\n",
    "        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3).float()\n",
    "        self.bn2 = nn.BatchNorm1d(n_channel).float()\n",
    "        self.pool2 = nn.MaxPool1d(4)\n",
    "\n",
    "        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3).float()\n",
    "        self.bn3 = nn.BatchNorm1d(2 * n_channel).float()\n",
    "        self.pool3 = nn.MaxPool1d(4)\n",
    "\n",
    "        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3).float()\n",
    "        self.bn4 = nn.BatchNorm1d(2 * n_channel).float()\n",
    "        self.pool4 = nn.MaxPool1d(4)\n",
    "\n",
    "        self.fc1 = nn.Linear(2 * n_channel, n_output).float()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(self.bn1(x))\n",
    "        x = self.pool1(x)\n",
    "\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(self.bn2(x))\n",
    "        x = self.pool2(x)\n",
    "\n",
    "        x = self.conv3(x)\n",
    "        x = F.relu(self.bn3(x))\n",
    "        x = self.pool3(x)\n",
    "\n",
    "        x = self.conv4(x)\n",
    "        x = F.relu(self.bn4(x))\n",
    "        x = self.pool4(x)\n",
    "\n",
    "        x = F.avg_pool1d(x, x.shape[-1])\n",
    "        x = x.permute(0, 2, 1)\n",
    "        x = self.fc1(x)\n",
    "\n",
    "        return F.log_softmax(x, dim=2).squeeze()\n",
    "\n",
    "def train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs, device='cuda'):\n",
    "    train_loss = []\n",
    "    model.to(device)\n",
    "    for epoch in tqdm(range(num_epochs)):\n",
    "        model.train()\n",
    "        train_epoch_loss = 0.0\n",
    "\n",
    "        for inputs, labels in train_loader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            train_epoch_loss += loss.item()\n",
    "\n",
    "        scheduler.step()\n",
    "        train_loss.append(train_epoch_loss)\n",
    "\n",
    "    plt.plot(train_loss)\n",
    "\n",
    "def test_model(model, test_loader, device='cuda'):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    accuracy = 100 * correct / total\n",
    "    return accuracy\n",
    "\n",
    "def bootstrapping(train_loader, test_dataset, num_epochs=100, learning_rate=1e-1, n_bootstrap=20, device='cuda'):\n",
    "    \"\"\"\n",
    "    Train the ResNet model on the training dataset, and evaluate it using bootstrapping on the test dataset.\n",
    "\n",
    "    Args:\n",
    "        train_loader: DataLoader for training data.\n",
    "        test_dataset: Dataset object for the test data.\n",
    "        num_epochs: Number of epochs for training.\n",
    "        learning_rate: Learning rate for the optimizer.\n",
    "        device: Device to run the training on ('cuda' or 'cpu').\n",
    "\n",
    "    Returns:\n",
    "        Prints the mean accuracy and 95% confidence interval after bootstrapping.\n",
    "    \"\"\"\n",
    "    # Initialize the model, loss, optimizer, and scheduler\n",
    "    model = M5().to(device)\n",
    "    model = model.float()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)  # Decay LR every 30 epochs\n",
    "\n",
    "    # Train the model\n",
    "    train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs=num_epochs, device=device)\n",
    "\n",
    "    # Perform bootstrapping\n",
    "    accuracies = []\n",
    "\n",
    "    num_test_samples = len(test_dataset) // 2\n",
    "    for i in range(n_bootstrap):\n",
    "        indices = torch.randint(len(test_dataset), size=(num_test_samples,))  # Sample 500 random indices\n",
    "        bootstrap_subset = Subset(test_dataset, indices)\n",
    "        bootstrap_loader = DataLoader(dataset=bootstrap_subset, batch_size=num_test_samples, shuffle=False)\n",
    "\n",
    "        accuracy = test_model(model, bootstrap_loader, device=device)\n",
    "        accuracies.append(accuracy)\n",
    "\n",
    "    # Compute statistics\n",
    "    mean_accuracy = np.mean(accuracies)\n",
    "    std_accuracy = np.std(accuracies)\n",
    "\n",
    "    print(f\"Mean accuracy: {mean_accuracy:.2f}%\")\n",
    "    print(f\"Standard deviation: {std_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/50], Loss: 1.5258\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [2/50], Loss: 0.9701\n",
      "Epoch [3/50], Loss: 0.8856\n",
      "Epoch [4/50], Loss: 0.8579\n",
      "Epoch [5/50], Loss: 0.8359\n",
      "Epoch [6/50], Loss: 0.8195\n",
      "Epoch [7/50], Loss: 0.8117\n",
      "Epoch [8/50], Loss: 0.8087\n",
      "Epoch [9/50], Loss: 0.7983\n",
      "Epoch [10/50], Loss: 0.7910\n",
      "Epoch [11/50], Loss: 0.7990\n",
      "Epoch [12/50], Loss: 0.7876\n",
      "Epoch [13/50], Loss: 0.7903\n",
      "Epoch [14/50], Loss: 0.7882\n",
      "Epoch [15/50], Loss: 0.7845\n",
      "Epoch [16/50], Loss: 0.7889\n",
      "Epoch [17/50], Loss: 0.7835\n",
      "Epoch [18/50], Loss: 0.7796\n",
      "Epoch [19/50], Loss: 0.7752\n",
      "Epoch [20/50], Loss: 0.7802\n",
      "Epoch [21/50], Loss: 0.7816\n",
      "Epoch [22/50], Loss: 0.7698\n",
      "Epoch [23/50], Loss: 0.7811\n",
      "Epoch [24/50], Loss: 0.7694\n",
      "Epoch [25/50], Loss: 0.7754\n",
      "Epoch [26/50], Loss: 0.7821\n",
      "Epoch [27/50], Loss: 0.7733\n",
      "Epoch [28/50], Loss: 0.7758\n",
      "Epoch [29/50], Loss: 0.7707\n",
      "Epoch [30/50], Loss: 0.7752\n",
      "Epoch [31/50], Loss: 0.5287\n",
      "Epoch [32/50], Loss: 0.4795\n",
      "Epoch [33/50], Loss: 0.4669\n",
      "Epoch [34/50], Loss: 0.4605\n",
      "Epoch [35/50], Loss: 0.4572\n",
      "Epoch [36/50], Loss: 0.4538\n",
      "Epoch [37/50], Loss: 0.4552\n",
      "Epoch [38/50], Loss: 0.4584\n",
      "Epoch [39/50], Loss: 0.4612\n",
      "Epoch [40/50], Loss: 0.4620\n",
      "Epoch [41/50], Loss: 0.4674\n",
      "Epoch [42/50], Loss: 0.4654\n",
      "Epoch [43/50], Loss: 0.4709\n",
      "Epoch [44/50], Loss: 0.4699\n",
      "Epoch [45/50], Loss: 0.4721\n",
      "Epoch [46/50], Loss: 0.4710\n",
      "Epoch [47/50], Loss: 0.4715\n",
      "Epoch [48/50], Loss: 0.4742\n",
      "Epoch [49/50], Loss: 0.4723\n",
      "Epoch [50/50], Loss: 0.4720\n",
      "Mean accuracy: 84.56%\n",
      "Standard deviation: 0.62\n"
     ]
    }
   ],
   "source": [
    "for ratio in [0.0, 0.25, 0.5, 0.75, 1.0]:\n",
    "    print(f\"Testing with augmented_ratio = {ratio}\")\n",
    "    print(\"Standard:\")\n",
    "    bootstrapping(train_loader_standard[ratio], test_dataset, num_epochs=50, learning_rate=0.1, device=device)\n",
    "    print(\"None:\")\n",
    "    bootstrapping(train_loader_none[ratio], test_dataset, num_epochs=50, learning_rate=0.1, device=device)\n",
    "    print(\"PNL:\")\n",
    "    bootstrapping(train_loader_PNL[ratio], test_dataset, num_epochs=50, learning_rate=0.1, device=device)\n",
    "    print(\"AE:\")\n",
    "    bootstrapping(train_loader_AE[ratio], test_dataset, num_epochs=50, learning_rate=0.1, device=device)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "PNL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
