{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c422fdb-9ae4-4d38-bff2-b9c8839cbaa1",
   "metadata": {},
   "source": [
    "# Demonstration of unsupervised learning on engineered features using Iterative NMF\n",
    "## The code is implemented as a class that takes the features previously extracted from the 4D-STEM datasets and performs unsupervised learning on them. We show the use of both PCA and Iterative NMF. PCA is used to select the initial number of components for the Iterative NMF process.\n",
    "\n",
    "### Last modified September 25th, 2022"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a212f0e7-8335-4c93-8f48-f9b310e87a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import py4DSTEM\n",
    "from py4DSTEM.visualize import show_image_grid\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle\n",
    "import matplotlib.ticker as plticker\n",
    "from Featurization import Featurization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2c67c46-6881-4380-9589-d6b50a2119ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Create a list of NUM_COLORS RGBA colors sampled from the 'gist_rainbow' colormap\n",
    "import matplotlib.cm as mplcm\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "NUM_COLORS = 200\n",
    "\n",
    "cm = plt.get_cmap('gist_rainbow')\n",
    "cmap = [cm(1.*i/NUM_COLORS) for i in range(NUM_COLORS)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f32556af-e067-498a-ae54-b0b1d6028518",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "975b0b3e-5d5d-4892-8c50-0cf6ad319966",
   "metadata": {},
   "outputs": [],
   "source": [
    "R_Nx = 100\n",
    "R_Ny = 100\n",
    "Q_Nx = 252\n",
    "Q_Ny = 252"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce5f717-df01-4114-8778-9d63da2628ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ag1 dataset features\n",
    "fp_BP_new = 'data/Ag1_BP.npy'\n",
    "fp_AA_new = 'data/Ag1_AA.npy'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aa4bdc1-75cf-4f21-84c8-a7704a318158",
   "metadata": {},
   "outputs": [],
   "source": [
    "BP_new = np.load(fp_BP_new)\n",
    "aa_new = np.load(fp_AA_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "060a2a25-a21b-4c7a-b57e-82f9ebe9db91",
   "metadata": {},
   "source": [
    "## Initialize classification class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1aa7d2-0ced-4c3c-82c4-5a0888d0a4b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = ['BP', 'aa']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6d4aa7b-5108-4f78-87e0-616f40df8baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification = Featurization(keys, [BP_new, aa_new])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028cbf7b-3c14-4091-ade5-1d6f7db5be69",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.MinMaxScaler(keys = ['BP', 'aa'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84f5469d-7821-402f-b9c7-1b54062c3b78",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Bragg Disks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3a4e9e8-427e-46e8-bd36-0074153fbd91",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Here, we perform only 5 iterations to demonstrate Iterative NMF and consensus clustering, using the following parameters:\n",
    "1. Initial number of components = 60 (comps = 60) for Ag1, Ag2, and Ag3\n",
    "2. Iterations performed = 5 (iters_ = 5), i.e. the number of times the model is run with different random seeds\n",
    "3. Merge threshold = 0.20 (_thresh = 0.20) for Ag1, 0.25 for Ag2, and 0.25 for Ag3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e8d7e34-9213-43b4-a81b-82e3aa678a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = ['BP_mms']\n",
    "comps = [60] * len(keys)\n",
    "iters_ = [5] * len(keys)\n",
    "_thresh = [0.20] * len(keys)\n",
    "max_components = dict(zip(keys, comps))\n",
    "merge_thresh = dict(zip(keys, _thresh))\n",
    "iters = dict(zip(keys, iters_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9535488c-0c93-4301-be0c-8748a49f7e1f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "classification.NMF_iterative(\n",
    "    keys = keys,\n",
    "    max_components = max_components,\n",
    "    merge_thresh = merge_thresh,\n",
    "    iters = iters,\n",
    "    return_all = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ebb6ef9-1e81-4082-85a7-11372411b16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The length of this list should always be the number of iterations performed.\n",
    "print(len(classification.W['BP_mms']))\n",
    "print(classification.W['BP_mms'][0].shape)\n",
    "print(type(classification.W['BP_mms']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8229f65a-4fae-4b4f-a8ca-951431e91b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.get_class_ims(keys, classification_method = ['nmf'], R_Nx=R_Nx, R_Ny=R_Ny)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a279e6e0-4384-4714-96f8-383577379715",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "## This cell will show the raw individual clusters associated with each iteration\n",
    "# for j in range(len(classification.W['BP_mms'])):\n",
    "#     fig, ax = show_image_grid(lambda i:classification.class_ims['BP_mms'][j][i]**0.5, 4,10, returnfig = True, cmap = 'inferno')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d493abd8-e662-4c21-847a-4e6fd8a7f010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot raw cluster maps, no post-processing\n",
    "thresh = 0.01\n",
    "for j in range(len(classification.class_ims['BP_mms'])):\n",
    "    fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "    if  len(classification.class_ims['BP_mms'][j]) > 0:\n",
    "        ival_1 = NUM_COLORS / len(classification.class_ims['BP_mms'][j])\n",
    "    else: ival_1 = 1\n",
    "    \n",
    "    for i in range(len(classification.class_ims['BP_mms'][j])):\n",
    "        iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "        c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "        cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "        ax.matshow(np.ma.array(\n",
    "            classification.class_ims['BP_mms'][j][i], \n",
    "            mask = classification.class_ims['BP_mms'][j][i]<thresh), cmap = cm)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2374c8c-34c0-4ec2-a158-9bb2e1d9c98f",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.spatial_separation(keys, size = 25, threshold = thresh, method = 'yen', clean = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a241b199-203d-41f5-9330-de95431a12fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell will show the spatially separated and filtered individual clusters associated with each iteration\n",
    "# for j in range(len(classification.spatially_separated_ims['BP_mms'])):\n",
    "#     fig, ax = show_image_grid(lambda i:classification.spatially_separated_ims['BP_mms'][j][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84fce305-3c0d-4ce9-b043-852e73745ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot spatially separated and filtered cluster maps\n",
    "for j in range(len(classification.spatially_separated_ims['BP_mms'])):\n",
    "    fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "    if  len(classification.spatially_separated_ims['BP_mms'][j]) > 0:\n",
    "        ival_1 = NUM_COLORS / len(classification.spatially_separated_ims['BP_mms'][j])\n",
    "    else: ival_1 = 1\n",
    "    \n",
    "    for i in range(len(classification.spatially_separated_ims['BP_mms'][j])):\n",
    "        iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "        c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "        cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "        ax.matshow(np.ma.array(\n",
    "            classification.spatially_separated_ims['BP_mms'][j][i], \n",
    "            mask = classification.spatially_separated_ims['BP_mms'][j][i]<thresh), cmap = cm)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71739523-3fd2-464c-8407-bafef75be5f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.consensus(\n",
    "    keys=keys,\n",
    "    threshold = thresh,\n",
    "    location = 'spatially_separated_ims',\n",
    "    method = 'mean',\n",
    "    drop = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7e18db-f324-42fc-8c5e-6aa4c239ba9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell shows the first 8 matched clusters after performing label correspondence\n",
    "consensus_bins = list(classification.consensus_dict['BP_mms'].keys())\n",
    "for j in range(len(consensus_bins[0:8])):\n",
    "    fig, ax = show_image_grid(lambda i:classification.consensus_dict['BP_mms'][consensus_bins[j]][i]**0.5,\n",
    "                              1, 10, returnfig = True, cmap = 'inferno')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9757c030-3268-4dfc-91c2-a69294cbd5cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell will show the averaged consensus clusters from the bins in the cell above\n",
    "# fig, ax = show_image_grid(lambda i:classification.consensus_clusters['BP_mms'][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0953fd42-abc0-4ac0-88a1-44dae8a75857",
   "metadata": {},
   "outputs": [],
   "source": [
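    "# Plot consensus cluster maps\n",
    "fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "ax.axis('off')\n",
    "\n",
    "# Spread cluster colors evenly over the colormap; max(..., 1) avoids division by zero when no clusters remain\n",
    "ival_1 = NUM_COLORS / max(len(classification.consensus_clusters['BP_mms']), 1)\n",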
    "\n",
    "for i in range(len(classification.consensus_clusters['BP_mms'])):\n",
    "    iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "    c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "    ax.matshow(np.ma.array(\n",
    "        classification.consensus_clusters['BP_mms'][i], \n",
    "        mask = classification.consensus_clusters['BP_mms'][i]<thresh), cmap = cm)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d376b6f-4429-4e36-b381-2a1b8523d4f7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Angular Average"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bb5f7f2-3458-42f9-a59f-af66ca6798a7",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Here, we perform only 5 iterations to demonstrate Iterative NMF and consensus clustering, using the following parameters:\n",
    "1. Initial number of components = 50 (comps = 50) for Ag1, Ag2, and Ag3\n",
    "2. Iterations performed = 5 (iters_ = 5), i.e. the number of times the model is run with different random seeds\n",
    "3. Merge threshold = 0.45 (_thresh = 0.45) for Ag1, 0.40 for Ag2, and 0.40 for Ag3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a492b14-5370-46a2-be1d-6458fe815f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = ['aa_mms']\n",
    "comps = [50] * len(keys)\n",
    "iters_ = [5] * len(keys)\n",
    "_thresh = [0.45] * len(keys)\n",
    "max_components = dict(zip(keys, comps))\n",
    "merge_thresh = dict(zip(keys, _thresh))\n",
    "iters = dict(zip(keys, iters_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70292d73-d4fd-4a10-b457-8aba17c286f9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "classification.NMF_iterative(\n",
    "    keys = keys,\n",
    "    max_components = max_components,\n",
    "    merge_thresh = merge_thresh,\n",
    "    iters = iters,\n",
    "    return_all = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce110cb4-3d80-416d-a629-9c693eb93f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.get_class_ims(keys, classification_method = ['nmf'], R_Nx=R_Nx, R_Ny=R_Ny)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db4a020a-678b-47a3-bc9a-dbdde8113303",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell will show the raw individual clusters associated with each iteration\n",
    "# for j in range(len(classification.W['aa_mms'])):\n",
    "#     fig, ax = show_image_grid(lambda i:classification.class_ims['aa_mms'][j][i]**0.5, 4,10, returnfig = True, cmap = 'inferno')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3157bef-fd51-4760-8f38-a293e908549a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot raw cluster maps, no post-processing\n",
    "\n",
    "thresh = 0.01\n",
    "for j in range(len(classification.class_ims['aa_mms'])):\n",
    "    fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "    if  len(classification.class_ims['aa_mms'][j]) > 0:\n",
    "        ival_1 = NUM_COLORS / len(classification.class_ims['aa_mms'][j])\n",
    "    else: ival_1 = 1\n",
    "    \n",
    "    for i in range(len(classification.class_ims['aa_mms'][j])):\n",
    "        iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "        c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "        cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "        ax.matshow(np.ma.array(\n",
    "            classification.class_ims['aa_mms'][j][i], \n",
    "            mask = classification.class_ims['aa_mms'][j][i]<thresh), cmap = cm)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b914ca-978a-49f3-bf79-e856397d38c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "classification.spatial_separation(keys, size = 25, threshold = thresh, method = 'yen', clean = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f8d5a3c-bc05-414e-bdcc-86685c5e9390",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "## This cell will show the spatially separated and filtered individual clusters associated with each iteration\n",
    "# for j in range(len(classification.spatially_separated_ims['aa_mms'])):\n",
    "#     fig, ax = show_image_grid(lambda i:classification.spatially_separated_ims['aa_mms'][j][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f8528f-e332-4cf5-9fdd-962d2416f480",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot spatially separated and filtered cluster maps\n",
    "for j in range(len(classification.spatially_separated_ims['aa_mms'])):\n",
    "    fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "    ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "    if  len(classification.spatially_separated_ims['aa_mms'][j]) > 0:\n",
    "        ival_1 = NUM_COLORS / len(classification.spatially_separated_ims['aa_mms'][j])\n",
    "    else: ival_1 = 1\n",
    "    \n",
    "    for i in range(len(classification.spatially_separated_ims['aa_mms'][j])):\n",
    "        iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "        c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "        cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "        ax.matshow(np.ma.array(\n",
    "            classification.spatially_separated_ims['aa_mms'][j][i], \n",
    "            mask = classification.spatially_separated_ims['aa_mms'][j][i]<thresh), cmap = cm)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7703b4-2cae-498e-ac1c-6efcd2548222",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification.consensus(\n",
    "    keys=keys,\n",
    "    threshold = thresh,\n",
    "    location = 'spatially_separated_ims',\n",
    "    method = 'mean',\n",
    "    drop = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f0f7c42-7dca-4a11-b74e-7ffb846d4835",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell shows the first 8 matched clusters after performing label correspondence\n",
    "consensus_bins = list(classification.consensus_dict['aa_mms'].keys())\n",
    "for j in range(len(consensus_bins[0:8])):\n",
    "    fig, ax = show_image_grid(lambda i:classification.consensus_dict['aa_mms'][consensus_bins[j]][i]**0.5,\n",
    "                              1, 10, returnfig = True, cmap = 'inferno')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5fc0d85-236c-4f61-bfa1-414d709823d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "## This cell will show the averaged consensus clusters from the bins in the cell above\n",
    "# fig, ax = show_image_grid(lambda i:classification.consensus_clusters['aa_mms'][i]**0.5, 5,10, returnfig = True, cmap = 'inferno')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7e1163-51a0-4b39-836a-1abab151b358",
   "metadata": {},
   "outputs": [],
   "source": [
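    "# Plot consensus cluster maps\n",
    "fig, (ax) = plt.subplots(figsize = (6,6))\n",
    "ax.matshow(np.zeros((R_Nx, R_Ny)), cmap = 'gray')\n",
    "ax.axis('off')\n",
    "\n",
    "# Spread cluster colors evenly over the colormap; max(..., 1) avoids division by zero when no clusters remain\n",
    "ival_1 = NUM_COLORS / max(len(classification.consensus_clusters['aa_mms']), 1)\n",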
    "\n",
    "for i in range(len(classification.consensus_clusters['aa_mms'])):\n",
    "    iterval_1 = np.floor(ival_1 * i).astype(int)\n",
    "    c0, c1 = (cmap[iterval_1][0]*0.35,cmap[iterval_1][1]*0.35,cmap[iterval_1][2]*0.35,1), cmap[iterval_1]\n",
    "    cm = mpl.colors.LinearSegmentedColormap.from_list('cmap', [c0,c1], N = 10)\n",
    "    ax.matshow(np.ma.array(\n",
    "        classification.consensus_clusters['aa_mms'][i], \n",
    "        mask = classification.consensus_clusters['aa_mms'][i]<thresh), cmap = cm)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18f8d5e-3f41-46c9-945a-1ea8454f9ae7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py4dstem_12.13",
   "language": "python",
   "name": "py4dstem_12.13"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
