{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing the PTB-XL, CPSC2018, and CSN datasets for fine-tuning tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import wfdb\n",
    "import os\n",
    "import ast\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "from pprint import pprint\n",
    "from tqdm import tqdm\n",
    "from scipy.ndimage import zoom\n",
    "from scipy.io import loadmat\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the raw downloads (it should contain icbeb2018/ and chapman/).\n",
    "# End it with a path separator, e.g. '/data/ecg/'.\n",
    "meta_path = ''\n",
    "# Folder that receives the processed train/val/test csv files (also end it with a separator).\n",
    "split_path = ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing PTB-XL dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- PTB-XL provides an official split, so we use it for the fine-tuning dataset.\n",
    "- The official preprocessing code is referenced in the original paper: https://www.nature.com/articles/s41597-020-0495-6\n",
    "- The preprocessed csv files are also available in `MERL/finetune/data_split/ptbxl`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing CPSC2018 Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "This dataset provides the raw files in .mat format.\n",
    "We convert each .mat file to a WFDB record (.hea + .dat) with the wfdb package,\n",
    "storing the signal at 500 Hz and a copy downsampled to 100 Hz (only when WRITE_WFDB is True).\n",
    "All information about this dataset can be found at: http://2018.icbeb.org/Challenge.html\n",
    "'''\n",
    "\n",
    "# original data folder; download the data from the website first\n",
    "ori_data_folder = os.path.join(meta_path, 'icbeb2018')\n",
    "\n",
    "# output folders for the converted records\n",
    "output_folder = os.path.join(meta_path, 'icbeb2018')\n",
    "output_datafolder_100 = os.path.join(output_folder, 'records100')\n",
    "output_datafolder_500 = os.path.join(output_folder, 'records500')\n",
    "for out_dir in (output_folder, output_datafolder_100, output_datafolder_500):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# set to True to (re)write the WFDB records; False only builds the label table\n",
    "WRITE_WFDB = False\n",
    "\n",
    "\n",
    "def store_as_wfdb(signame, data, sigfolder, fs):\n",
    "    '''Write a 12-lead ECG as the WFDB record `signame` (.hea + .dat) in `sigfolder`.\n",
    "\n",
    "    `data` is passed to wfdb.wrsamp as `p_signal` (one column per lead, in mV)\n",
    "    and stored in 16-bit format with sampling frequency `fs`.\n",
    "    '''\n",
    "    channel_itos = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n",
    "    wfdb.wrsamp(signame,\n",
    "                fs=fs,\n",
    "                sig_name=channel_itos,\n",
    "                p_signal=data,\n",
    "                units=['mV'] * len(channel_itos),\n",
    "                fmt=['16'] * len(channel_itos),\n",
    "                write_dir=sigfolder)\n",
    "\n",
    "\n",
    "# REFERENCE.csv holds up to three integer labels per recording\n",
    "reference_path = os.path.join(output_folder, 'REFERENCE.csv')\n",
    "df_reference = pd.read_csv(reference_path)\n",
    "reference_label_columns = ['First_label', 'Second_label', 'Third_label']\n",
    "\n",
    "# integer label id -> label name\n",
    "label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD', 9:'STE'}\n",
    "\n",
    "data = {'ecg_id':[], 'filename':[], 'validation':[], 'age':[], 'sex':[], 'scp_codes':[]}\n",
    "\n",
    "# read all .mat files in sorted order so that ecg_id (and therefore patient_id and\n",
    "# the seeded train/val/test split) does not depend on the os.listdir order\n",
    "ecg_counter = 0\n",
    "for folder in ['all_data']:\n",
    "    folder_path = os.path.join(ori_data_folder, folder)\n",
    "    for filename in tqdm(sorted(os.listdir(folder_path))):\n",
    "        name, ext = os.path.splitext(filename)\n",
    "        if ext != '.mat':\n",
    "            continue\n",
    "        ecg_counter += 1\n",
    "\n",
    "        sex, age, sig = loadmat(os.path.join(folder_path, filename))['ECG'][0][0]\n",
    "        data['ecg_id'].append(ecg_counter)\n",
    "        data['filename'].append(name)\n",
    "        data['validation'].append(False)\n",
    "        data['age'].append(age[0][0])\n",
    "        data['sex'].append(1 if sex[0] == 'Male' else 0)\n",
    "        labels = df_reference[df_reference.Recording == name][reference_label_columns].values.flatten()\n",
    "        labels = labels[~np.isnan(labels)].astype(int)\n",
    "        data['scp_codes'].append({label_dict[key]: 1 for key in labels})\n",
    "\n",
    "        if WRITE_WFDB:\n",
    "            # sig has one row per lead, while wfdb expects one column per lead\n",
    "            store_as_wfdb(str(ecg_counter), sig.T, output_datafolder_500, 500)\n",
    "            down_sig = np.array([zoom(channel, .2) for channel in sig])\n",
    "            store_as_wfdb(str(ecg_counter), down_sig.T, output_datafolder_100, 100)\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "df['patient_id'] = df.ecg_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make patient_id the first column; .copy() gives an independent frame so the\n",
    "# label columns added in the next cell do not trigger pandas' SettingWithCopyWarning\n",
    "cols = ['patient_id'] + [col for col in df.columns if col != 'patient_id']\n",
    "switched_df = df[cols].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the nine CPSC2018 classes; this list fixes the order of the label columns in the csv files\n",
    "all_labels = ['AFIB', 'VPC', 'NORM', '1AVB', 'CRBBB', 'STE', 'PAC', 'CLBBB', 'STD']\n",
    "\n",
    "# one 0/1 column per label, read from each record's scp_codes dict; assign() returns\n",
    "# a new frame, so this never writes into a possible view of df\n",
    "switched_df = switched_df.assign(**{\n",
    "    label: switched_df['scp_codes'].apply(lambda codes, label=label: codes.get(label, 0))\n",
    "    for label in all_labels\n",
    "})\n",
    "\n",
    "cols = list(switched_df.columns)\n",
    "print(cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split into train/val/test: 20% test, then 10% of the remaining records for validation\n",
    "train_df, test_df = train_test_split(switched_df, test_size=0.2, random_state=42)\n",
    "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)\n",
    "\n",
    "print(f'train_df shape: {train_df.shape}')\n",
    "print(f'val_df shape: {val_df.shape}')\n",
    "print(f'test_df shape: {test_df.shape}')\n",
    "\n",
    "# set to True to write the csv files into split_path\n",
    "SAVE_SPLITS = False\n",
    "if SAVE_SPLITS:\n",
    "    if split_path:\n",
    "        os.makedirs(split_path, exist_ok=True)\n",
    "    train_df.to_csv(os.path.join(split_path, 'icbeb_train.csv'), index=False)\n",
    "    val_df.to_csv(os.path.join(split_path, 'icbeb_val.csv'), index=False)\n",
    "    test_df.to_csv(os.path.join(split_path, 'icbeb_test.csv'), index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing CSN Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "For all details of the dataset, please refer to: https://physionet.org/content/ecg-arrhythmia/1.0.0/\n",
    "'''\n",
    "\n",
    "your_path = meta_path\n",
    "\n",
    "data_path = os.path.join(your_path, 'chapman', 'WFDBRecords')\n",
    "folders = sorted(os.listdir(data_path))\n",
    "folders = [os.path.join(data_path, f) for f in folders]\n",
    "folders = [f for f in folders if os.path.isdir(f)]\n",
    "\n",
    "\n",
    "def read_header_file(file_path):\n",
    "    '''Return the lines of a WFDB .hea header file with surrounding whitespace stripped.'''\n",
    "    with open(file_path, 'r') as file:\n",
    "        return [line.strip() for line in file]\n",
    "\n",
    "\n",
    "def get_header_field(header_lines, field):\n",
    "    '''Return the value of the first '#<field>: <value>' comment line, or None if there is none.'''\n",
    "    prefix = f'{field}:'\n",
    "    for line in header_lines:\n",
    "        if line.startswith('#'):\n",
    "            comment = line.lstrip('#').strip()\n",
    "            if comment.startswith(prefix):\n",
    "                return comment[len(prefix):].strip()\n",
    "    return None\n",
    "\n",
    "\n",
    "df = {'ecg_path': [],\n",
    "      'age': [],\n",
    "      'diagnose': []}\n",
    "\n",
    "ref = pd.read_csv(os.path.join(your_path, 'chapman', 'ConditionNames_SNOMED-CT.csv'))\n",
    "ref['Snomed_CT'] = ref['Snomed_CT'].astype(str)\n",
    "# SNOMED-CT code -> acronym; for duplicated codes the first row wins\n",
    "unique_ref = ref.drop_duplicates('Snomed_CT')\n",
    "snomed_to_acronym = dict(zip(unique_ref['Snomed_CT'], unique_ref['Acronym Name']))\n",
    "\n",
    "# collect path, age and diagnoses of every record (the signals themselves are not loaded here)\n",
    "for folder in tqdm(folders):\n",
    "    subfolders = [os.path.join(folder, f) for f in sorted(os.listdir(folder))]\n",
    "    for subfolder in (f for f in subfolders if os.path.isdir(f)):\n",
    "        mat_files = sorted(f for f in os.listdir(subfolder) if f.endswith('.mat'))\n",
    "        for mat_file in mat_files:\n",
    "            mat_path = os.path.join(subfolder, mat_file)\n",
    "            # pair each .mat with the .hea of the same record; zipping two independently\n",
    "            # sorted lists misaligns every later record once one file of a pair is missing\n",
    "            hea_path = os.path.splitext(mat_path)[0] + '.hea'\n",
    "            if not os.path.isfile(hea_path):\n",
    "                continue\n",
    "            hea = read_header_file(hea_path)\n",
    "\n",
    "            df['ecg_path'].append(mat_path)\n",
    "            # the age is stored in the '#Age:' comment line; the first header line is the\n",
    "            # WFDB record line (record name, number of signals, ...), not the age\n",
    "            df['age'].append(get_header_field(hea, 'Age'))\n",
    "\n",
    "            # records without a Dx line, or with any code lacking an acronym in the\n",
    "            # reference table, are marked 'Unknown' (they are dropped in the next cell)\n",
    "            dx = get_header_field(hea, 'Dx')\n",
    "            codes = [code.strip() for code in dx.split(',')] if dx else []\n",
    "            acronyms = [snomed_to_acronym.get(code) for code in codes]\n",
    "            if acronyms and all(isinstance(acronym, str) for acronym in acronyms):\n",
    "                df['diagnose'].append(','.join(acronyms))\n",
    "            else:\n",
    "                df['diagnose'].append('Unknown')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df = pd.DataFrame(df)\n",
    "# keep only records whose diagnoses were all mapped to acronyms\n",
    "new_df = new_df[new_df['diagnose'] != 'Unknown'].reset_index(drop=True)\n",
    "\n",
    "diagnose_lists = new_df['diagnose'].str.split(',')\n",
    "# sorted: iterating a set of strings varies between runs (hash randomization),\n",
    "# which would shuffle the column order of the saved csv files\n",
    "unique_labels = sorted({label for labels in diagnose_lists for label in labels})\n",
    "# one 0/1 column per label; test exact membership in the split list, because a substring\n",
    "# test on the joined string would also flag a label contained in another (e.g. 'AF' in 'AFIB')\n",
    "label_matrix = pd.DataFrame(\n",
    "    {label: diagnose_lists.apply(lambda labels, label=label: 1 if label in labels else 0)\n",
    "     for label in unique_labels},\n",
    "    index=new_df.index,\n",
    ")\n",
    "new_df = pd.concat([new_df, label_matrix], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# count the number of samples for each label (labels dropped by an earlier run of this cell are skipped)\n",
    "present_labels = [label for label in unique_labels if label in new_df.columns]\n",
    "label_count = {label: new_df[label].sum() for label in present_labels}\n",
    "# sort by count, most frequent first\n",
    "label_count = dict(sorted(label_count.items(), key=lambda item: item[1], reverse=True))\n",
    "# keep only labels with at least min_label_count samples\n",
    "min_label_count = 10\n",
    "label_count = {label: count for label, count in label_count.items() if count >= min_label_count}\n",
    "# drop only the rare label columns; ecg_path, age and diagnose must stay so the\n",
    "# csv files still point to the signals\n",
    "rare_labels = [label for label in present_labels if label not in label_count]\n",
    "new_df = new_df.drop(columns=rare_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split into train/val/test: 20% test, then 10% of the remaining records for validation\n",
    "train_df, test_df = train_test_split(new_df, test_size=0.2, random_state=42)\n",
    "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)\n",
    "train_df = train_df.reset_index(drop=True)\n",
    "val_df = val_df.reset_index(drop=True)\n",
    "test_df = test_df.reset_index(drop=True)\n",
    "\n",
    "print(f'train_df shape: {train_df.shape}')\n",
    "print(f'val_df shape: {val_df.shape}')\n",
    "print(f'test_df shape: {test_df.shape}')\n",
    "\n",
    "# set to True to write the csv files into <split_path>/chapman/\n",
    "SAVE_SPLITS = False\n",
    "if SAVE_SPLITS:\n",
    "    chapman_split_path = os.path.join(split_path, 'chapman')\n",
    "    os.makedirs(chapman_split_path, exist_ok=True)\n",
    "    train_df.to_csv(os.path.join(chapman_split_path, 'chapman_train.csv'), index=False)\n",
    "    val_df.to_csv(os.path.join(chapman_split_path, 'chapman_val.csv'), index=False)\n",
    "    test_df.to_csv(os.path.join(chapman_split_path, 'chapman_test.csv'), index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "medvlp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
