{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b870371e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append(\"../../src\")\n",
    "sys.path.append(\"../../data\")\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "import cv2\n",
    "from IPython import display\n",
    "import pylab as pl\n",
    "\n",
    "from CorInfoMaxBSS import *\n",
    "from general_utils import *\n",
    "from visualization_utils import * \n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "notebook_name = 'Image Separation'\n",
    "\n",
    "# np.random.seed(250)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8a7d2e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_paths = '../../data/TestImages'\n",
    "loaded_images = []\n",
    "skipped_files = []\n",
    "# Sort the directory listing: os.listdir order is filesystem-dependent, and\n",
    "# the source images are later selected by position (small_images[[12,6,3]]).\n",
    "for im_dir in sorted(os.listdir(image_paths)):\n",
    "    try:\n",
    "        loaded_images.append(mpimg.imread(os.path.join(image_paths,im_dir)))\n",
    "    except Exception:\n",
    "        # Best effort: skip entries that are not readable images, but report\n",
    "        # them instead of silently swallowing every error.\n",
    "        skipped_files.append(im_dir)\n",
    "if skipped_files:\n",
    "    print(\"Skipped unreadable files: {}\".format(skipped_files))\n",
    "# The test images do not all share one resolution, so np.array(loaded_images)\n",
    "# would build a ragged array (deprecated, and a ValueError since NumPy 1.24).\n",
    "# A 1-D object array keeps `images.shape[0]` and `images[i]` working.\n",
    "images = np.empty(len(loaded_images), dtype=object)\n",
    "for k, img in enumerate(loaded_images):\n",
    "    images[k] = img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c238871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each image is flattened into one row of `small_images`, so every image must\n",
    "# end up with exactly this (height, width) and 3 colour channels.\n",
    "image_height_and_width = [454, 605]\n",
    "target_height, target_width = image_height_and_width\n",
    "small_images = np.zeros((len(images), target_height * target_width * 3))\n",
    "small_to_large_image_size_ratio = 0.15\n",
    "for i in range(len(images)):\n",
    "    img = images[i]\n",
    "    if img.ndim == 2:  # grayscale -> replicate into 3 channels\n",
    "        img = np.stack([img] * 3, axis=-1)\n",
    "    img = np.ascontiguousarray(img[..., :3])  # drop an alpha channel if present\n",
    "    small_img = cv2.resize(img, # original image\n",
    "                           (0,0), # set fx and fy, not the final size\n",
    "                           fx=small_to_large_image_size_ratio, \n",
    "                           fy=small_to_large_image_size_ratio, \n",
    "                           interpolation=cv2.INTER_NEAREST)\n",
    "    # The test images do not all share one resolution (a 0.15 downscale gave\n",
    "    # both 454x605 and 486x648 images), so snap any mismatch onto the target\n",
    "    # size. The previous `shape[0] == 302` special case only covered an older\n",
    "    # ratio and let every other mismatch crash the row assignment below.\n",
    "    if small_img.shape[:2] != (target_height, target_width):\n",
    "        # cv2.resize takes dsize as (width, height).\n",
    "        small_img = cv2.resize(small_img, (target_width, target_height))\n",
    "    # Integer images (e.g. uint8 JPEGs) are scaled by their dtype maximum\n",
    "    # (255 for uint8); float images are assumed to already be in [0, 1].\n",
    "    if np.issubdtype(small_img.dtype, np.integer):\n",
    "        pixel_scale = np.iinfo(small_img.dtype).max\n",
    "    else:\n",
    "        pixel_scale = 1.0\n",
    "    small_images[i] = small_img.reshape(-1,) / pixel_scale\n",
    "\n",
    "small_images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d4f2140",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick three of the loaded test images as the ground-truth sources\n",
    "# (the indices are positions in the loaded image list).\n",
    "source_indices = [12, 6, 3]\n",
    "S = small_images[source_indices]\n",
    "Subplot_RGB_images(S, imsize=image_height_and_width, height=4, width=18)\n",
    "# Pairwise correlation coefficients between the chosen sources.\n",
    "display_matrix(np.corrcoef(S))\n",
    "# plt.savefig('Original_Images.pdf', format = 'pdf', dpi = 1500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b8cc6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(650)\n",
    "NumberofMixtures = 5\n",
    "NumberofSources = 3\n",
    "\n",
    "# Random Gaussian mixing matrix: each mixture (row of X) is a linear\n",
    "# combination of the source images (rows of S).\n",
    "A = np.random.randn(NumberofMixtures, NumberofSources)\n",
    "X = np.dot(A, S)\n",
    "\n",
    "# Corrupt the mixtures via addWGN at the requested SNR (dB); it also returns\n",
    "# the noise component it added (presumably white Gaussian noise).\n",
    "SNR = 40\n",
    "X, NoisePart = addWGN(X, SNR, return_noise=True)\n",
    "\n",
    "# Measured input SNR in dB: summed per-row power of (X - noise) over the\n",
    "# summed per-row power of the noise.\n",
    "signal_power = np.sum(np.mean((X - NoisePart) ** 2, axis=1))\n",
    "noise_power = np.sum(np.mean(NoisePart**2, axis=1))\n",
    "SNRinp = 10 * np.log10(signal_power / noise_power)\n",
    "\n",
    "# Per-row rescaled copy of the mixtures, used for plotting in the next cell.\n",
    "X_ = ZeroOneNormalizeColumns(X.T).T\n",
    "print(\"The following is the mixture matrix A\")\n",
    "display_matrix(A)\n",
    "print(\"Input SNR is : {}\".format(SNRinp))\n",
    "\n",
    "print(\"Row standard deviation of mixtures : {}\".format(X.std(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3517060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the (row-rescaled) noisy mixtures.\n",
    "Subplot_RGB_images(X_,\n",
    "                   imsize=image_height_and_width,\n",
    "                   height=3,\n",
    "                   width=18)\n",
    "# plt.savefig('Mixture_Images.pdf', format = 'pdf', dpi = 1500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42be0f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional diagnostic (disabled): plot columns 1000-1199 of each source row.\n",
    "# subplot_1D_signals(S[:,1000:1200], title = 'Original Signals', figsize = (15.2,9), colorcode = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ae66aad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional diagnostic (disabled): pairwise scatter plots of the source rows\n",
    "# (0 vs 1, 1 vs 2, 0 vs 2), one figure each.\n",
    "# plt.figure()\n",
    "# plt.scatter(S[0,:], S[1,:])\n",
    "# plt.figure()\n",
    "# plt.scatter(S[1,:], S[2,:])\n",
    "# plt.figure()\n",
    "# plt.scatter(S[0,:], S[2,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e09efc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# lambday / lambdae are passed straight to the model (presumably forgetting\n",
    "# factors for the output and error covariance estimates).\n",
    "lambday = 1 - 1e-1/15\n",
    "lambdae = 0.5\n",
    "\n",
    "s_dim = S.shape[0]  # number of sources\n",
    "x_dim = X.shape[0]  # number of mixtures\n",
    "\n",
    "n_samples = X.shape[1]\n",
    "n_iter = n_samples\n",
    "# X is reshaped below to (1, NumberofMixtures, pixels), i.e. presumably one\n",
    "# \"frame\" holding every pixel, so pixels-per-frame is the full row length.\n",
    "ppf = X.shape[1]\n",
    "Wf_arr = []\n",
    "\n",
    "# Inverse output covariance (By) and inverse error covariance (Be)\n",
    "By = 1 * np.eye(s_dim)\n",
    "Be = 100 * np.eye(s_dim)\n",
    "\n",
    "debug_iteration_point = 2500\n",
    "model_params = dict(s_dim=s_dim, x_dim=x_dim, muW=50*1e-3, lambday=lambday,\n",
    "                    lambdae=lambdae, By=By, Be=Be, neural_OUTPUT_COMP_TOL=1e-6,\n",
    "                    set_ground_truth=True, S=S, A=A)\n",
    "model = CorInfoMaxVideoSeparation(**model_params)\n",
    "\n",
    "separation_params = dict(Wf_list=Wf_arr, n_pixel_per_frame=ppf,\n",
    "                         X=X.reshape(1, NumberofMixtures, -1),\n",
    "                         n_iter=n_iter, neural_dynamic_iterations=500,\n",
    "                         plot_in_jupyter=True, neural_lr_start=0.5,\n",
    "                         neural_lr_stop=0.001,\n",
    "                         debug_iteration_point=debug_iteration_point,\n",
    "                         shuffle=True)\n",
    "model.seperate_videos(**separation_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0363945",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the -1 flips the sign of the learned overall mapping,\n",
    "# presumably because the outputs come out sign-inverted -- confirm.\n",
    "Wf = -1*model.compute_overall_mapping(return_mapping = True)\n",
    "OverallMatrix = Wf @ A\n",
    "# Columns of OverallMatrix follow the columns of A (one per source): for each\n",
    "# source pick the output row with the largest gain so Y[k] lines up with S[k].\n",
    "perm = np.argmax(OverallMatrix, axis = 0)\n",
    "if len(np.unique(perm)) != len(perm):\n",
    "    # Per-column argmax is not guaranteed to be a permutation; if two sources\n",
    "    # map to the same output, outputs get duplicated/dropped and the PSNR/SINR\n",
    "    # values below are misleading. Printed (not warnings.warn) because the\n",
    "    # first cell ignores all warnings.\n",
    "    print(\"Warning: output-to-source assignment is not a permutation: {}\".format(perm))\n",
    "Y = (Wf @ X)[perm]\n",
    "Y = ZeroOneNormalizeColumns(Y.T).T\n",
    "PSNR_levels = [psnr(S[kk], Y[kk]) for kk in range(S.shape[0])]\n",
    "\n",
    "SINR = 10*np.log10(CalculateSINRjit(Y, S, False)[0])\n",
    "print(\"Component PSNR Values : {}\\n\".format(PSNR_levels))\n",
    "print(\"Overall SINR : {}\".format(SINR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "906ca41e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): Y was already passed through ZeroOneNormalizeColumns in the\n",
    "# evaluation cell, so this re-normalization is presumably a no-op.\n",
    "Subplot_RGB_images(ZeroOneNormalizeColumns(Y.T).T,\n",
    "                   imsize=image_height_and_width,\n",
    "                   height=4,\n",
    "                   width=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8fe6db8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the separated outputs with values clipped to the [0, 1] display range.\n",
    "Subplot_RGB_images(np.clip(Y, 0, 1),\n",
    "                   imsize=image_height_and_width,\n",
    "                   height=4,\n",
    "                   width=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35821169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the separated outputs exactly as computed in the evaluation cell.\n",
    "Subplot_RGB_images(Y,\n",
    "                   imsize=image_height_and_width,\n",
    "                   height=4,\n",
    "                   width=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "534cf670",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth sources, for comparison with the separated outputs above.\n",
    "Subplot_RGB_images(S,\n",
    "                   imsize=image_height_and_width,\n",
    "                   height=4,\n",
    "                   width=18)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
