{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd4178f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import pickle\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cf5d2b1",
   "metadata": {},
   "source": [
    "## Preprocess metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1173431",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read metadata\n",
    "path = 'somewhere/HAM10000/'\n",
    "\n",
    "demo_data = pd.read_csv(path + 'HAM10000_metadata.csv')\n",
    "demo_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95739fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "Counter(demo_data['dataset'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a376a0aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add image path to the metadata\n",
    "pathlist = demo_data['image_id'].values.tolist()\n",
    "paths = ['HAM10000_images/' + i + '.jpg' for i in pathlist]\n",
    "demo_data['Path'] = paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c80a746b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove age/sex == null \n",
    "demo_data = demo_data[~demo_data['age'].isnull()]\n",
    "demo_data = demo_data[~demo_data['sex'].isnull()]\n",
    "demo_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e08809a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unify the value of sensitive attributes\n",
    "sex = demo_data['sex'].values\n",
    "sex[sex == 'male'] = 'M'\n",
    "sex[sex == 'female'] = 'F'\n",
    "demo_data['Sex'] = sex\n",
    "demo_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39807da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split subjects to different age groups\n",
    "demo_data['Age_multi'] = demo_data['age'].values.astype('int')\n",
    "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(-1,19), 0, demo_data['Age_multi'])\n",
    "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(20,39), 1, demo_data['Age_multi'])\n",
    "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(40,59), 2, demo_data['Age_multi'])\n",
    "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(60,79), 3, demo_data['Age_multi'])\n",
    "demo_data['Age_multi'] = np.where(demo_data['Age_multi']>=80, 4, demo_data['Age_multi'])\n",
    "\n",
    "demo_data['Age_binary'] = demo_data['age'].values.astype('int')\n",
    "demo_data['Age_binary'] = np.where(demo_data['Age_binary'].between(-1, 60), 0, demo_data['Age_binary'])\n",
    "demo_data['Age_binary'] = np.where(demo_data['Age_binary']>= 60, 1, demo_data['Age_binary'])\n",
    "demo_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6135d73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to binary labels\n",
    "# benign: bcc, bkl, dermatofibroma, nv, vasc\n",
    "# maglinant: akiec, mel\n",
    "\n",
    "labels = demo_data['dx'].values.copy()\n",
    "labels[labels == 'akiec'] = '1'\n",
    "labels[labels == 'mel'] = '1'\n",
    "labels[labels != '1'] = '0'\n",
    "\n",
    "labels = labels.astype('int')\n",
    "\n",
    "demo_data['binaryLabel'] = labels\n",
    "demo_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b528c300",
   "metadata": {},
   "source": [
    "## Split train/val/test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f757b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_811(all_meta, patient_ids):\n",
    "    sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.2, random_state=5)\n",
    "    sub_val, sub_test = train_test_split(sub_val_test, test_size=0.5, random_state=6)\n",
    "    train_meta = all_meta[all_meta.lesion_id.isin(sub_train)]\n",
    "    val_meta = all_meta[all_meta.lesion_id.isin(sub_val)]\n",
    "    test_meta = all_meta[all_meta.lesion_id.isin(sub_test)]\n",
    "    return train_meta, val_meta, test_meta\n",
    "\n",
    "sub_train, sub_val, sub_test = split_811(demo_data, np.unique(demo_data['lesion_id']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b91657ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_train.to_csv('somewhere/HAM10000/split/new_train.csv')\n",
    "sub_val.to_csv('somewhere/HAM10000/split/new_val.csv')\n",
    "sub_test.to_csv('somewhere/HAM10000/split/new_test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab951da8",
   "metadata": {},
   "source": [
    "## Save images into pickle files\n",
    "This is optional, but if you are training many models, this step can save a lot of time by reducing the data IO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d667132",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_meta = pd.read_csv('/split/new_test.csv')\n",
    "\n",
    "path = 'save_path/'\n",
    "images = []\n",
    "start = time.time()\n",
    "for i in range(len(test_meta)):\n",
    "\n",
    "    img = cv2.imread(path + test_meta.iloc[i]['Path'])\n",
    "    # resize to the input size in advance to save time during training\n",
    "    img = cv2.resize(img, (256, 256))\n",
    "    images.append(img)\n",
    "    #print(img.shape)\n",
    "    \n",
    "\n",
    "end = time.time()\n",
    "end-start\n",
    "with open('your_path.pkl', 'wb') as f:\n",
    "    pickle.dump(images, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch11",
   "language": "python",
   "name": "torch11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
