{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a25c2be1",
   "metadata": {},
   "source": [
    "# Import Required Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d94d44",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "os.chdir(\"..\")\n",
    "os.chdir(\"..\")\n",
    "os.chdir(\"./src\")\n",
    "# sys.path.append(\"./src\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d233ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "import pylab as pl\n",
    "\n",
    "from PMF import *\n",
    "from general_utils import *\n",
    "from visualization_utils import *\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "notebook_name = \"Nonnegative_Antisparse_Copula\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d0fcdf8",
   "metadata": {},
   "source": [
    "# Source Generation and Mixing Scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f347c550",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "NumberofSources = 5\n",
    "NumberofMixtures = 10\n",
    "Sgt = generate_correlated_copula_sources(\n",
    "    rho=0.7,\n",
    "    df=4,\n",
    "    n_sources=NumberofSources,\n",
    "    size_sources=N,\n",
    "    decreasing_correlation=True,\n",
    ")\n",
    "\n",
    "print(\"The following is the correlation matrix of sources\")\n",
    "display_matrix(np.corrcoef(Sgt))\n",
    "\n",
    "# Generate Mxr random mixing from i.i.d N(0,1)\n",
    "Agt = np.random.randn(NumberofMixtures, NumberofSources)\n",
    "Y = np.dot(Agt, Sgt)\n",
    "\n",
    "SNR = 30\n",
    "Y, NoisePart = addWGN(Y, SNR, return_noise=True)\n",
    "\n",
    "SNRinp = 10 * np.log10(\n",
    "    np.sum(np.mean((Y - NoisePart) ** 2, axis=1))\n",
    "    / np.sum(np.mean(NoisePart**2, axis=1))\n",
    ")\n",
    "print(\"The following is the mixture matrix A\")\n",
    "display_matrix(Agt)\n",
    "print(\"Input SNR is : {}\".format(SNRinp))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f161ed1",
   "metadata": {},
   "source": [
    "# Visualize Generated Sources and Mixtures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fac953e",
   "metadata": {},
   "outputs": [],
   "source": [
    "subplot_1D_signals(\n",
    "    Sgt[:, 0:100], title=\"Original Signals\", figsize=(15.2, 9), colorcode=None\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    Y[:, 0:100], title=\"Mixture Signals\", figsize=(15, 18), colorcode=None\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08db0906",
   "metadata": {},
   "source": [
    "# Algorithm Hyperparameter Selection and Weight Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83583247",
   "metadata": {},
   "outputs": [],
   "source": [
    "s_dim = Sgt.shape[0]\n",
    "y_dim = Y.shape[0]\n",
    "debug_iteration_point = 25000\n",
    "model = PMFv2(s_dim=s_dim, y_dim=y_dim, set_ground_truth=True, Sgt=Sgt, Agt=Agt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e93d0f26",
   "metadata": {},
   "source": [
    "# Run PMF Algorithm on Mixture Signals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb6210bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit_batch_nnantisparse(\n",
    "    Y,\n",
    "    n_iterations=100000,\n",
    "    step_size_scale=100,\n",
    "    debug_iteration_point=debug_iteration_point,\n",
    "    plot_in_jupyter=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47b79106",
   "metadata": {},
   "source": [
    "# Calculate Resulting Component SNRs and Overall SINR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "072a1a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "Spmf = model.S\n",
    "Spmf = signed_and_permutation_corrected_sources(Sgt, Spmf)\n",
    "coef_ = ((Spmf * Sgt).sum(axis=1) / (Spmf * Spmf).sum(axis=1)).reshape(-1, 1)\n",
    "Spmf = coef_ * Spmf\n",
    "\n",
    "print(\"Component SNR Values : {}\\n\".format(snr_jit(Sgt, Spmf)))\n",
    "\n",
    "SINR = 10 * np.log10(CalculateSINRjit(Spmf, Sgt)[0])\n",
    "\n",
    "print(\"Overall SINR : {}\".format(SINR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5560b1df",
   "metadata": {},
   "outputs": [],
   "source": [
    "WfPMF = model.W\n",
    "Spmf = WfPMF @ Y\n",
    "Spmf = signed_and_permutation_corrected_sources(Sgt, Spmf)\n",
    "coef_ = ((Spmf * Sgt).sum(axis=1) / (Spmf * Spmf).sum(axis=1)).reshape(-1, 1)\n",
    "Spmf = coef_ * Spmf\n",
    "\n",
    "print(\"Component SNR Values : {}\\n\".format(snr_jit(Sgt, Spmf)))\n",
    "\n",
    "SINR = 10 * np.log10(CalculateSINRjit(Spmf, Sgt)[0])\n",
    "\n",
    "print(\"Overall SINR : {}\".format(SINR))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99e8c11e",
   "metadata": {},
   "source": [
    "# Vizualize Extracted Signals Compared to Original Sources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a10fd9e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "subplot_1D_signals(\n",
    "    Spmf[:, 0:100],\n",
    "    title=\"Extracted Signals (Sign and Permutation Corrected)\",\n",
    "    figsize=(15.2, 9),\n",
    "    colorcode=None,\n",
    ")\n",
    "subplot_1D_signals(\n",
    "    Sgt[:, 0:100], title=\"Original Signals\", figsize=(15.2, 9), colorcode=None\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
