{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "60fa5139",
   "metadata": {},
   "source": [
    "# Import Required Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a691d44b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import warnings\n",
    "\n",
    "# Project modules (CorInfoMaxBSS, *_utils) live in ../../src, relative to the notebook's working directory.\n",
    "sys.path.insert(0, \"../../src\")\n",
    "import numpy as np\n",
    "# Imported explicitly because mpl.rcParams is used later; do not rely on the star imports to provide it.\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "import pylab as pl\n",
    "\n",
    "# NOTE(review): star imports make name origins opaque; order matters because later modules shadow earlier ones.\n",
    "from CorInfoMaxBSS import *\n",
    "from general_utils import *\n",
    "from visualization_utils import *\n",
    "from polytope_utils import *\n",
    "\n",
    "# Suppress all warnings for the rest of the session (warnings raised by the imports above are still shown).\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "notebook_name = \"General_Polytope\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c14087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set SEED to an integer (e.g. a value printed by a previous run) to reproduce that run;\n",
    "# leave it as None to draw a fresh random seed.\n",
    "SEED = None\n",
    "seed_ = SEED if SEED is not None else np.random.randint(500000)\n",
    "print(seed_)\n",
    "np.random.seed(seed_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9bb1aa",
   "metadata": {},
   "source": [
    "# Source Generation and Mixing Scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16f9ff13",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 3\n",
    "N = 500000\n",
    "NumberofSources = dim\n",
    "NumberofMixtures = 6\n",
    "\n",
    "# Source-domain structure (presumably: signed_dims are signed components, nn_dims are\n",
    "# nonnegative components, and each array in sparse_dims_list is a group with a sparsity\n",
    "# constraint; see polytope_utils to confirm).\n",
    "signed_dims = np.array([0, 1])\n",
    "nn_dims = np.array([2])\n",
    "sparse_dims_list = [np.array([0, 1]), np.array([1, 2])]\n",
    "# The (A, b) pair describing the polytope gets its own names so it is not confused with,\n",
    "# or silently overwritten by, the mixing matrix A created in the mixing cell.\n",
    "(A_poly, b_poly), V = generate_practical_polytope(dim, signed_dims, nn_dims, sparse_dims_list)\n",
    "S = generate_uniform_points_in_polytope(V, N)\n",
    "# The mixing step computes A @ S with A of shape (NumberofMixtures, NumberofSources).\n",
    "assert S.shape[0] == NumberofSources, \"expected one row of S per source component\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f5a88e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise scatter plots of the source components: (1,2), (1,3), (2,3).\n",
    "plt.figure(figsize=(20, 5))\n",
    "for panel, (row_x, row_y, label_x, label_y) in enumerate(\n",
    "    [\n",
    "        (0, 1, \"$S_{:,1}$\", \"$S_{:,2}$\"),\n",
    "        (0, 2, \"$S_{:,1}$\", \"$S_{:,3}$\"),\n",
    "        (1, 2, \"$S_{:,2}$\", \"$S_{:,3}$\"),\n",
    "    ],\n",
    "    start=1,\n",
    "):\n",
    "    plt.subplot(1, 3, panel)\n",
    "    plt.scatter(S[row_x, :], S[row_y, :])\n",
    "    plt.xlabel(label_x, fontsize=25)\n",
    "    plt.ylabel(label_y, fontsize=25)\n",
    "    plt.grid(linewidth=0.2)\n",
    "\n",
    "plt.suptitle(\"Scatter Plot of Source Components\", fontsize=30)\n",
    "# plt.savefig('Pex_source_components.pdf', format='pdf', dpi = 1200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40568ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random M x r mixing matrix (NumberofMixtures x NumberofSources) with i.i.d. N(0, 1) entries\n",
    "A = np.random.randn(NumberofMixtures, NumberofSources)\n",
    "X = A @ S\n",
    "\n",
    "# Add noise at the target SNR (dB); addWGN also returns the noise it added\n",
    "SNR = 30\n",
    "X, NoisePart = addWGN(X, SNR, return_noise=True)\n",
    "\n",
    "# Empirical input SNR in dB: power of the noise-free mixtures over power of the noise\n",
    "SNRinp = 10 * np.log10(\n",
    "    np.mean((X - NoisePart) ** 2, axis=1).sum()\n",
    "    / np.mean(NoisePart**2, axis=1).sum()\n",
    ")\n",
    "print(\"The following is the mixture matrix A\")\n",
    "display_matrix(A)\n",
    "print(f\"Input SNR is : {SNRinp}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8d7386a",
   "metadata": {},
   "source": [
    "# Visualize Generated Sources and Mixtures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c069e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 100 samples of the sources, then of the noisy mixtures\n",
    "subplot_1D_signals(\n",
    "    S[:, :100],\n",
    "    title=\"Original Signals\",\n",
    "    figsize=(15.2, 9),\n",
    "    colorcode=None,\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    X[:, :100],\n",
    "    title=\"Mixture Signals\",\n",
    "    figsize=(15, 18),\n",
    "    colorcode=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4626081",
   "metadata": {},
   "source": [
    "# Algorithm Hyperparameter Selection and Weight Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7319fdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both lambdas equal 1 - 1e-2 (presumably forgetting factors of the output / error covariance estimates)\n",
    "lambday = lambdae = 1 - 1e-1 / 10\n",
    "s_dim, x_dim = S.shape[0], X.shape[0]\n",
    "\n",
    "# Initial inverse output covariance (By) and inverse error covariance (Be)\n",
    "By = 5 * np.eye(s_dim)\n",
    "Be = 1000 * np.eye(s_dim)\n",
    "\n",
    "# The SINR convergence plot's x-axis is in units of this many iterations\n",
    "debug_iteration_point = 25000\n",
    "model = OnlineCorInfoMax(\n",
    "    s_dim=s_dim,\n",
    "    x_dim=x_dim,\n",
    "    muW=50 * 1e-3,\n",
    "    lambday=lambday,\n",
    "    lambdae=lambdae,\n",
    "    By=By,\n",
    "    Be=Be,\n",
    "    neural_OUTPUT_COMP_TOL=1e-6,\n",
    "    # Ground truth is handed over presumably so SINR can be tracked during training\n",
    "    set_ground_truth=True,\n",
    "    S=S,\n",
    "    A=A,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be406768",
   "metadata": {},
   "source": [
    "# Run CorInfoMax Algorithm on Mixture Signals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c178c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One unshuffled pass over the mixtures, wall-clock timed\n",
    "with Timer() as t:\n",
    "    model.fit_batch_general_polytope(\n",
    "        X=X,\n",
    "        signed_dims=signed_dims,\n",
    "        nn_dims=nn_dims,\n",
    "        sparse_dims_list=sparse_dims_list,\n",
    "        n_epochs=1,\n",
    "        shuffle=False,\n",
    "        neural_dynamic_iterations=500,\n",
    "        neural_lr_start=0.1,  # Usually 0.1 works very well\n",
    "        neural_lr_stop=1e-10,\n",
    "        debug_iteration_point=debug_iteration_point,\n",
    "        plot_in_jupyter=True,\n",
    "    )\n",
    "print(\"Algorithm took %f sec.\" % t.interval)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab8dd4e5",
   "metadata": {},
   "source": [
    "# Visualize SINR Convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6964847",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcParams is matplotlib's global rcParams object; using it means this cell does not\n",
    "# depend on an `mpl` alias being in scope (e.g. leaked through a star import).\n",
    "plt.rcParams[\"xtick.labelsize\"] = 18\n",
    "plt.rcParams[\"ytick.labelsize\"] = 18\n",
    "plot_convergence_plot(\n",
    "    model.SIR_list,\n",
    "    xlabel=\"Number of Iterations / {}\".format(debug_iteration_point),\n",
    "    ylabel=\"SINR (dB)\",\n",
    "    title=\"SINR Convergence Plot\",\n",
    "    colorcode=None,\n",
    "    linewidth=1.8,\n",
    ")\n",
    "\n",
    "print(\"Final SIR: {}\".format(np.array(model.SIR_list[-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0613f3a",
   "metadata": {},
   "source": [
    "# Calculate Resulting Component SNRs and Overall SINR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc37de94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-mean sources and network outputs; outputs are aligned to the sources (sign and\n",
    "# permutation), then each output row is scaled by its least-squares fit coefficient <y,s>/<y,y>.\n",
    "Szeromean = S - S.mean(axis=1, keepdims=True)\n",
    "Wf = model.compute_overall_mapping(return_mapping=True)\n",
    "Y_ = Wf @ X\n",
    "Yzeromean = Y_ - Y_.mean(axis=1, keepdims=True)\n",
    "Y_ = model.signed_and_permutation_corrected_sources(Szeromean, Yzeromean)\n",
    "coef_ = (Y_ * Szeromean).sum(axis=1, keepdims=True) / (Y_ * Y_).sum(axis=1, keepdims=True)\n",
    "Y_ = coef_ * Y_\n",
    "\n",
    "print(f\"Component SNR Values : {snr_jit(Szeromean, Y_)}\\n\")\n",
    "\n",
    "SINR = 10 * np.log10(CalculateSINRjit(Y_, Szeromean, False)[0])\n",
    "\n",
    "print(f\"Overall SINR : {SINR}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88de065e",
   "metadata": {},
   "source": [
    "# Visualize Extracted Signals Compared to Original Sources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c895ba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the first 100 samples of the recovered outputs with the zero-mean sources\n",
    "subplot_1D_signals(\n",
    "    Y_[:, :100],\n",
    "    title=\"Extracted Signals (Sign and Permutation Corrected)\",\n",
    "    figsize=(15.2, 9),\n",
    "    colorcode=None,\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    Szeromean[:, :100],\n",
    "    title=\"Original Signals\",\n",
    "    figsize=(15.2, 9),\n",
    "    colorcode=None,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
