{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8018a642",
   "metadata": {},
   "source": [
    "# Import Required Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760582cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# The project modules imported in the next cell are loaded from the src/\n",
    "# folder inside the directory two levels above this notebook's folder.\n",
    "# Remember where the kernel started so that re-running this cell always\n",
    "# resolves to the same src/ directory instead of climbing two more levels.\n",
    "NOTEBOOK_DIR = globals().get(\"NOTEBOOK_DIR\", os.getcwd())\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, \"..\", \"..\", \"src\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e77b841",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): plt, display and pl are not used by any cell in this notebook.\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "import pylab as pl\n",
    "\n",
    "# Project modules, presumably importable because the previous cell changed\n",
    "# the working directory into ./src.\n",
    "# NOTE(review): star imports hide where names used below (FastICA, snr,\n",
    "# CalculateSINR, subplot_1D_signals, ...) come from.\n",
    "from ICA import *\n",
    "from general_utils import *\n",
    "from visualization_utils import *\n",
    "\n",
    "import warnings\n",
    "\n",
    "# NOTE(review): this silences every warning for the rest of the session,\n",
    "# including NumPy RuntimeWarnings such as divide-by-zero or log of zero;\n",
    "# consider narrowing it with a category= or message= filter.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Not referenced by any other cell in this notebook.\n",
    "notebook_name = \"FastICA\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0071175",
   "metadata": {},
   "source": [
    "# Source Generation and Mixing Scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1351ebbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A fixed seed makes Restart & Run All reproduce the same sources and\n",
    "# mixtures; set SEED = np.random.randint(2500) to draw a fresh scenario.\n",
    "SEED = 42\n",
    "# Also pin NumPy's global RNG for any later draws made through it.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Generate the sources and mixtures\n",
    "S, X, A = generate_synthetic_data_SMICA(seed=SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf42a1cf",
   "metadata": {},
   "source": [
    "# Visualize Generated Sources and Mixtures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0df841",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the first 100 samples so individual waveforms stay readable.\n",
    "plot_window = slice(0, 100)\n",
    "subplot_1D_signals(\n",
    "    S[:, plot_window],\n",
    "    title=\"Original Signals\",\n",
    "    figsize=(15.2, 9),\n",
    "    colorcode=None,\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    X[:, plot_window],\n",
    "    title=\"Mixture Signals\",\n",
    "    figsize=(15, 18),\n",
    "    colorcode=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca4c8d4d",
   "metadata": {},
   "source": [
    "# Run FastICA Algorithm on Mixture Signals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d08673",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the model from the generated data instead of a hard-coded 4, so a\n",
    "# change to the synthetic scenario cannot silently desynchronise the two.\n",
    "NumberofSources = S.shape[0]  # sources are stored one per row of S\n",
    "# NOTE(review): x_dim is kept equal to s_dim as before; check it against\n",
    "# X.shape[0] and the FastICA implementation before decoupling them.\n",
    "NumberofMixtures = NumberofSources\n",
    "model = FastICA(s_dim=NumberofSources, x_dim=NumberofMixtures)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77bb46e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EPOCHS = 1  # forwarded to fit_transform as n_epochs\n",
    "Y = model.fit_transform(X, n_epochs=N_EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0693ec6",
   "metadata": {},
   "source": [
    "# Calculate Resulting Component SNRs and Overall SINR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea4ae2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_separation(S, Y):\n",
    "    \"\"\"Align an estimate Y with the true sources S and report its quality.\n",
    "\n",
    "    Y is sign- and permutation-corrected against S, then each row is\n",
    "    rescaled by <y, s> / <y, y>, the scalar c that minimises ||s - c * y||^2\n",
    "    for the matching row s of S. The per-component SNRs (from snr) and the\n",
    "    overall SINR in dB are printed.\n",
    "\n",
    "    Returns:\n",
    "        (Y_, SINRwsm): the aligned, rescaled estimate and the overall SINR\n",
    "        in dB.\n",
    "    \"\"\"\n",
    "    Y_ = signed_and_permutation_corrected_sources(S, Y)\n",
    "    coef_ = ((Y_ * S).sum(axis=1) / (Y_ * Y_).sum(axis=1)).reshape(-1, 1)\n",
    "    Y_ = coef_ * Y_\n",
    "\n",
    "    print(\"Component SNR Values : {}\\n\".format(snr(S, Y_)))\n",
    "\n",
    "    # NOTE(review): assumes CalculateSINR(...)[0] is a linear power ratio.\n",
    "    SINRwsm = 10 * np.log10(CalculateSINR(Y_, S)[0])\n",
    "\n",
    "    print(\"Overall SINR : {}\".format(SINRwsm))\n",
    "    return Y_, SINRwsm\n",
    "\n",
    "\n",
    "# Quality of the output returned by fit_transform.\n",
    "Y_, SINRwsm = evaluate_separation(S, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f30e80a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run the trained model on the full mixture and evaluate again.\n",
    "Y = model.predict(X)\n",
    "Y_, SINRwsm = evaluate_separation(S, Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e431ba6c",
   "metadata": {},
   "source": [
    "# Visualize Extracted Signals Compared to Original Sources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345a80da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the same 100-sample window for the recovered and the true sources.\n",
    "plot_window = slice(0, 100)\n",
    "subplot_1D_signals(\n",
    "    Y_[:, plot_window],\n",
    "    colorcode=None,\n",
    "    figsize=(15.2, 9),\n",
    "    title=\"Extracted Signals (Sign and Permutation Corrected)\",\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    S[:, plot_window],\n",
    "    colorcode=None,\n",
    "    figsize=(15.2, 9),\n",
    "    title=\"Original Signals\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
