{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "obtain_subtypes.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Subtyping pysustain notebook\n",
        "--\n",
        "Built to replicate the paper's subtyping results using the available data\n",
        "\n",
        "---\n",
        "Heavily based on tutorial code: https://github.com/ucl-pond/pySuStaIn/blob/master/notebooks/SuStaInWorkshop.ipynb\n",
        "\n",
        "---\n",
        "Important: file names reflect the original (local/Colab) environment. Before running the notebook, make sure the CSV of z-scores for the regions of interest is in the working directory and that `data_filename` matches its name. The output file name can also be changed as desired, e.g. to keep the results of multiple runs.\n"
      ],
      "metadata": {
        "id": "9FsZwXm_IaGs"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xLtQFHaeN1Sl"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import shutil\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pickle\n",
        "from pathlib import Path\n",
        "import sklearn.model_selection\n",
        "import pandas as pd\n",
        "import pylab\n",
        "import sys"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel\n",
        "# TODO: pin a specific commit (git+https://...pySuStaIn@<sha>) for reproducibility\n",
        "%pip install git+https://github.com/ucl-pond/pySuStaIn"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rAxzP7QpP6_u",
        "outputId": "ecaf7839-cf5f-4183-8bbd-4dd7d30713a7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/ucl-pond/pySuStaIn\n",
            "  Cloning https://github.com/ucl-pond/pySuStaIn to /tmp/pip-req-build-52lcawnb\n",
            "  Running command git clone -q https://github.com/ucl-pond/pySuStaIn /tmp/pip-req-build-52lcawnb\n",
            "Collecting awkde@ git+https://github.com/noxtoby/awkde.git\n",
            "  Cloning https://github.com/noxtoby/awkde.git to /tmp/pip-install-l5fh3stw/awkde_d59315f5aca147c4ba762bbdf863c9ee\n",
            "  Running command git clone -q https://github.com/noxtoby/awkde.git /tmp/pip-install-l5fh3stw/awkde_d59315f5aca147c4ba762bbdf863c9ee\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting kde_ebm@ git+https://github.com/ucl-pond/kde_ebm.git\n",
            "  Cloning https://github.com/ucl-pond/kde_ebm.git to /tmp/pip-install-l5fh3stw/kde-ebm_7bc75a565ea5447bb40a4b235f487832\n",
            "  Running command git clone -q https://github.com/ucl-pond/kde_ebm.git /tmp/pip-install-l5fh3stw/kde-ebm_7bc75a565ea5447bb40a4b235f487832\n",
            "Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (1.19.5)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (1.4.1)\n",
            "Requirement already satisfied: matplotlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (3.2.2)\n",
            "Collecting pathos\n",
            "  Downloading pathos-0.2.8-py2.py3-none-any.whl (81 kB)\n",
            "\u001b[K     |████████████████████████████████| 81 kB 7.9 MB/s \n",
            "\u001b[?25hRequirement already satisfied: sklearn in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (0.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (1.1.5)\n",
            "Collecting pybind11\n",
            "  Using cached pybind11-2.9.0-py2.py3-none-any.whl (210 kB)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from pySuStaIn==0.1) (4.62.3)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from awkde@ git+https://github.com/noxtoby/awkde.git->pySuStaIn==0.1) (1.0.2)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from awkde@ git+https://github.com/noxtoby/awkde.git->pySuStaIn==0.1) (0.16.0)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.0.0->pySuStaIn==0.1) (2.8.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.0.0->pySuStaIn==0.1) (0.11.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.0.0->pySuStaIn==0.1) (1.3.2)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.0.0->pySuStaIn==0.1) (3.0.6)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib>=3.0.0->pySuStaIn==0.1) (1.15.0)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->pySuStaIn==0.1) (2018.9)\n",
            "Collecting ppft>=1.6.6.4\n",
            "  Downloading ppft-1.6.6.4-py3-none-any.whl (65 kB)\n",
            "\u001b[K     |████████████████████████████████| 65 kB 3.7 MB/s \n",
            "\u001b[?25hRequirement already satisfied: multiprocess>=0.70.12 in /usr/local/lib/python3.7/dist-packages (from pathos->pySuStaIn==0.1) (0.70.12.2)\n",
            "Collecting pox>=0.3.0\n",
            "  Downloading pox-0.3.0-py2.py3-none-any.whl (30 kB)\n",
            "Requirement already satisfied: dill>=0.3.4 in /usr/local/lib/python3.7/dist-packages (from pathos->pySuStaIn==0.1) (0.3.4)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->awkde@ git+https://github.com/noxtoby/awkde.git->pySuStaIn==0.1) (1.1.0)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->awkde@ git+https://github.com/noxtoby/awkde.git->pySuStaIn==0.1) (3.0.0)\n",
            "Building wheels for collected packages: pySuStaIn, awkde, kde-ebm\n",
            "  Building wheel for pySuStaIn (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pySuStaIn: filename=pySuStaIn-0.1-py3-none-any.whl size=56671 sha256=e7912126b4e74f332388d09b76465b5d166ebec9753500698f02ef6c6982add0\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-q6tq19_6/wheels/f6/20/a4/80afe31cb7283ac3f2aaa5ce62332686ed34bda7d33a17292e\n",
            "  Building wheel for awkde (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for awkde: filename=awkde-0.1-cp37-cp37m-linux_x86_64.whl size=65958 sha256=9a2157d3d6786f25d1f6c7bdd8edbec90c901bb9e154b1b2fa05836ca6a535aa\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-q6tq19_6/wheels/42/d8/a4/66c5c60a3b2e00fb8b51f6d96935a2844b4370ac118bf8ea4f\n",
            "  Building wheel for kde-ebm (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for kde-ebm: filename=kde_ebm-0.0.2-py3-none-any.whl size=79080 sha256=38c37a9215e7414457a5acfc96e25f3f2b3b215af8732a1ce73e36f3c41f8c6d\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-q6tq19_6/wheels/4f/d5/fe/d58d72ed0b987048a9a9627fe4f12892402dddd67a8e8e0bec\n",
            "Successfully built pySuStaIn awkde kde-ebm\n",
            "Installing collected packages: pybind11, ppft, pox, pathos, kde-ebm, awkde, pySuStaIn\n",
            "Successfully installed awkde-0.1 kde-ebm-0.0.2 pathos-0.2.8 pox-0.3.0 ppft-1.6.6.4 pySuStaIn-0.1 pybind11-2.9.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import pySuStaIn"
      ],
      "metadata": {
        "id": "y9Jn1y6EkfN7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Configuration: adjust these for your environment ---\n",
        "# Input CSV of regional z-scores, given relative to the working directory (notebook was developed in Colab)\n",
        "data_filename = \"rois_zscores.csv\"\n",
        "# Prefix for SuStaIn output files (pickles are read back as <dataset_name>_subtype<s>.pickle)\n",
        "dataset_name = 'ADSubtypes'\n",
        "# Results directory: <current working directory>/<dataset_name>\n",
        "output_folder = os.path.join(os.getcwd(), dataset_name)"
      ],
      "metadata": {
        "id": "5vwFS1XGTam8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the regional z-scores; the first CSV column becomes the row index.\n",
        "# NOTE(review): the index can contain repeated values (see the rows shown below),\n",
        "# presumably several scans per subject -- confirm.\n",
        "data = pd.read_csv(data_filename, index_col=0)\n",
        "\n",
        "# Replace negative z-scores with a small positive value, made to adjust for data requirements\n",
        "data[data < 0] = 0.01\n",
        "\n",
        "# Display the cleaned table (a head() call in the middle of a cell is never rendered)\n",
        "data.head()"
      ],
      "metadata": {
        "id": "ODQ-j5LMTRjL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Every column currently in `data` (one per region of interest) is passed to SuStaIn as a biomarker.\n",
        "# Run this before later cells append the subtype/stage columns to `data`.\n",
        "biomarkers = data.keys()"
      ],
      "metadata": {
        "id": "KSdUTDDrTo_v",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9e1a1534-ce7a-44ba-df83-e79d78192f47"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Index(['Temporal LH', 'Temporal RH', 'Parietal LH', 'Parietal RH',\n",
            "       'Occipital LH', 'Occipital RH', 'Frontal LH', 'Frontal RH', 'MTL LH',\n",
            "       'MTL RH'],\n",
            "      dtype='object')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sized from the biomarker list, so the set of ROIs can change without editing the arrays below\n",
        "N = len(biomarkers)\n",
        "SuStaInLabels = biomarkers\n",
        "\n",
        "# z-score levels [2, 5, 10] for every biomarker (pySuStaIn Z_vals), one row per biomarker\n",
        "Z_vals = np.array([[2, 5, 10] for _ in range(N)])\n",
        "# Z_max of 25 for every biomarker\n",
        "Z_max = np.array([25 for _ in range(N)])"
      ],
      "metadata": {
        "id": "XI7ek9Z0TvZA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# SuStaIn fitting settings\n",
        "N_startpoints = 10             # number of start points for the maximum-likelihood fit\n",
        "N_S_max = 4                    # largest model fitted (pickles subtype0 .. subtype{N_S_max-1})\n",
        "N_iterations_MCMC = int(1e4)   # number of MCMC iterations\n",
        "use_parallel_startpoints = True\n",
        "\n",
        "# NOTE(review): no random seed is set anywhere, so repeated fits presumably differ between runs\n",
        "sustain_input = pySuStaIn.ZscoreSustain(data.values, Z_vals, Z_max, SuStaInLabels,\n",
        "                                        N_startpoints, N_S_max, N_iterations_MCMC,\n",
        "                                        output_folder, dataset_name,\n",
        "                                        use_parallel_startpoints)"
      ],
      "metadata": {
        "id": "HpIcPygIU4fm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Start every run from an empty results directory so results are overwritten.\n",
        "# NOTE: this deletes previously saved pickles, so run_sustain_algorithm cannot reuse them.\n",
        "if os.path.exists(output_folder):\n",
        "    shutil.rmtree(output_folder)\n",
        "os.mkdir(output_folder)"
      ],
      "metadata": {
        "id": "GpGgxuBmVWo_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fit SuStaIn with the settings held in sustain_input and unpack its seven outputs\n",
        "sustain_outputs = sustain_input.run_sustain_algorithm()\n",
        "(samples_sequence, samples_f,\n",
        " ml_subtype, prob_ml_subtype,\n",
        " ml_stage, prob_ml_stage,\n",
        " prob_subtype_stage) = sustain_outputs"
      ],
      "metadata": {
        "id": "Ozo9I35EVe7Y",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "062bd82a-a551-4ac9-cb03-217f43e8213d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found pickle file: /content/drive/MyDrive/ProjectX/pickle_files/SubtypesAtLast_subtype0.pickle. Using pickled variables for 0 subtype.\n",
            "Found pickle file: /content/drive/MyDrive/ProjectX/pickle_files/SubtypesAtLast_subtype1.pickle. Using pickled variables for 1 subtype.\n",
            "Found pickle file: /content/drive/MyDrive/ProjectX/pickle_files/SubtypesAtLast_subtype2.pickle. Using pickled variables for 2 subtype.\n",
            "Found pickle file: /content/drive/MyDrive/ProjectX/pickle_files/SubtypesAtLast_subtype3.pickle. Using pickled variables for 3 subtype.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
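        "Quick and dirty: apply the fitted subtypes (as in the paper) by attaching each subject's assigned subtype and stage, loaded from the saved SuStaIn pickle, to `data`."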
      ],
      "metadata": {
        "id": "fxGvBVylpVz2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# s is the 0-based index of the saved model to load; files are\n",
        "# <dataset_name>_subtype<s>.pickle for s = 0 .. N_S_max-1.\n",
        "# NOTE(review): s = 0 presumably loads the single-subtype model; replicating a\n",
        "# multi-subtype result would need a larger s (e.g. N_S_max - 1) -- confirm.\n",
        "s = 0\n",
        "pickle_filename_s = os.path.join(output_folder, 'pickle_files', dataset_name + '_subtype' + str(s) + '.pickle')\n",
        "# Only load pickles produced by this notebook: unpickling can execute arbitrary code\n",
        "pk = pd.read_pickle(pickle_filename_s)\n",
        "\n",
        "for variable in ['ml_subtype', # the assigned subtype\n",
        "                 'prob_ml_subtype', # the probability of the assigned subtype\n",
        "                 'ml_stage', # the assigned stage \n",
        "                 'prob_ml_stage',]: # the probability of the assigned stage\n",
        "    \n",
        "    # add SuStaIn output to dataframe\n",
        "    data.loc[:,variable] = pk[variable] \n",
        "\n",
        "# let's also add the probability for each subject of being each subtype.\n",
        "# Loop over every column of prob_subtype: range(s) skipped the last subtype\n",
        "# (and with s = 0 added no columns at all).\n",
        "for i in range(pk['prob_subtype'].shape[1]):\n",
        "    data.loc[:,'prob_S%s'%i] = pk['prob_subtype'][:,i]\n",
        "data.head()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "40_2Wnq3pVOH",
        "outputId": "cfdfe05b-7c0f-484b-8e79-90b6f1b23096"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-aae2c7b4-87b4-4a7f-a863-d56ee5e7d700\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Temporal LH</th>\n",
              "      <th>Temporal RH</th>\n",
              "      <th>Parietal LH</th>\n",
              "      <th>Parietal RH</th>\n",
              "      <th>Occipital LH</th>\n",
              "      <th>Occipital RH</th>\n",
              "      <th>Frontal LH</th>\n",
              "      <th>Frontal RH</th>\n",
              "      <th>MTL LH</th>\n",
              "      <th>MTL RH</th>\n",
              "      <th>ml_subtype</th>\n",
              "      <th>prob_ml_subtype</th>\n",
              "      <th>ml_stage</th>\n",
              "      <th>prob_ml_stage</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>21</th>\n",
              "      <td>1.247968</td>\n",
              "      <td>1.325007</td>\n",
              "      <td>0.354052</td>\n",
              "      <td>0.205143</td>\n",
              "      <td>1.389530</td>\n",
              "      <td>1.084380</td>\n",
              "      <td>1.154372</td>\n",
              "      <td>0.990390</td>\n",
              "      <td>0.647129</td>\n",
              "      <td>0.599295</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.420074</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>31</th>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.195678</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.107418</td>\n",
              "      <td>0.104943</td>\n",
              "      <td>0.689582</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.950795</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>31</th>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.010000</td>\n",
              "      <td>0.203146</td>\n",
              "      <td>0.438188</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.961475</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>56</th>\n",
              "      <td>0.768775</td>\n",
              "      <td>0.877256</td>\n",
              "      <td>0.803005</td>\n",
              "      <td>0.916798</td>\n",
              "      <td>1.003151</td>\n",
              "      <td>0.916630</td>\n",
              "      <td>0.936634</td>\n",
              "      <td>1.065935</td>\n",
              "      <td>1.073632</td>\n",
              "      <td>1.131283</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.488860</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>56</th>\n",
              "      <td>0.470495</td>\n",
              "      <td>0.566332</td>\n",
              "      <td>0.992443</td>\n",
              "      <td>1.117978</td>\n",
              "      <td>1.243071</td>\n",
              "      <td>1.066782</td>\n",
              "      <td>0.739452</td>\n",
              "      <td>0.871078</td>\n",
              "      <td>0.711645</td>\n",
              "      <td>0.680519</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.453924</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-aae2c7b4-87b4-4a7f-a863-d56ee5e7d700')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-aae2c7b4-87b4-4a7f-a863-d56ee5e7d700 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-aae2c7b4-87b4-4a7f-a863-d56ee5e7d700');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "    Temporal LH  Temporal RH  ...  ml_stage  prob_ml_stage\n",
              "21     1.247968     1.325007  ...       1.0       0.420074\n",
              "31     0.010000     0.010000  ...       0.0       0.950795\n",
              "31     0.010000     0.010000  ...       0.0       0.961475\n",
              "56     0.768775     0.877256  ...       1.0       0.488860\n",
              "56     0.470495     0.566332  ...       0.0       0.453924\n",
              "\n",
              "[5 rows x 14 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Shift SuStaIn's 0-based subtype labels to 1-based (1 .. number of subtypes).\n",
        "# Recomputed from pk instead of `+= 1` so re-running this cell does not keep incrementing.\n",
        "data.loc[:,'ml_subtype'] = np.asarray(pk['ml_subtype']).reshape(-1) + 1\n",
        "\n",
        "# convert \"Stage 0\" subjects to subtype 0, consistent with no disease\n",
        "data.loc[data.ml_stage==0,'ml_subtype'] = 0"
      ],
      "metadata": {
        "id": "Ytq8l33hqnJT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Write the data with subtype/stage assignments; replace with the desired output path and file name\n",
        "data.to_csv(\"subtypes.csv\")"
      ],
      "metadata": {
        "id": "8otzBovTq-zP"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}