{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3e5183",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import sys\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "from constants import *\n",
    "\n",
    "# Seed NumPy's and Python's global RNGs so the np.random.choice draws later in the\n",
    "# notebook give the same result on every Restart & Run All.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692d1d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# root: directory that holds labels_brset.csv; tgt_root: directory the split CSVs go to.\n",
    "# NOTE(review): BRSET_root / BRSET_tgt_root presumably come from `from constants import *`.\n",
    "root, tgt_root = BRSET_root, BRSET_tgt_root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d235fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the label table that sits directly under `root`.\n",
    "labels_csv = os.path.join(root, 'labels_brset.csv')\n",
    "df = pd.read_csv(labels_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f32a129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows per value of the `quality` column.\n",
    "df.loc[:, 'quality'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85eccabc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the rows whose `quality` is 'Adequate'. The explicit .copy() makes `df` an\n",
    "# independent frame rather than a slice of the loaded table, so adding a column to it\n",
    "# later does not raise SettingWithCopyWarning.\n",
    "df = df[df['quality'] == 'Adequate'].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3a052ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct patient_id values; dropna=False matches len(unique()), which\n",
    "# would count a missing ID as one value as well.\n",
    "df['patient_id'].nunique(dropna=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "099e25a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row count of the current frame.\n",
    "df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a8b4a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the columns at positions 20..32 (the positional slice 20:33).\n",
    "df.iloc[:, 20:33].columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90b2a5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A row is flagged abnormal when the row-wise sum over columns 20..32 is non-zero.\n",
    "df['abnormal'] = df.iloc[:, 20:33].sum(axis=1).ne(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26d9b799",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the frame on the `abnormal` flag and print the size of each part\n",
    "# (normal first, then abnormal).\n",
    "is_abnormal = df['abnormal'] != 0\n",
    "df_normal = df[~is_abnormal]\n",
    "df_abnormal = df[is_abnormal]\n",
    "print(len(df_normal))\n",
    "print(len(df_abnormal))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89ce96ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Among the abnormal rows, keep those whose `other` column equals 0.\n",
    "other_is_zero = df_abnormal['other'].eq(0)\n",
    "df_abnormal_no_others = df_abnormal.loc[other_is_zero]\n",
    "print(len(df_abnormal_no_others))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce753862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working set: the normal rows followed by the abnormal rows whose `other` is 0.\n",
    "# Rows are stacked and each row keeps its original index label.\n",
    "df = pd.concat([df_normal, df_abnormal_no_others], axis=0, ignore_index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692d1634",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patient-level split: randomly draw 500 patient_ids for the test set and 200 for the\n",
    "# validation set; every remaining patient_id goes to training. Splitting on patient_id\n",
    "# rather than on rows keeps all rows of one patient inside a single split.\n",
    "n_test_patients = 500\n",
    "n_val_patients = 200\n",
    "\n",
    "# Step 1: Get unique patient IDs\n",
    "unique_patient_ids = df['patient_id'].unique()\n",
    "print(f\"Total unique patients: {len(unique_patient_ids)}\")\n",
    "\n",
    "# Fail early with a readable message; np.random.choice would raise a less specific\n",
    "# ValueError in exactly this situation.\n",
    "if len(unique_patient_ids) < n_test_patients + n_val_patients:\n",
    "    raise ValueError(\n",
    "        f'Need at least {n_test_patients + n_val_patients} unique patients for the '\n",
    "        f'test/validation draw, got {len(unique_patient_ids)}'\n",
    "    )\n",
    "\n",
    "# Step 2: Randomly select patient IDs (test first, so the draw order matches the seed)\n",
    "test_patient_ids = np.random.choice(unique_patient_ids, size=n_test_patients, replace=False)\n",
    "\n",
    "# Remove test patient IDs from the pool\n",
    "remaining_patient_ids = np.setdiff1d(unique_patient_ids, test_patient_ids)\n",
    "\n",
    "# Select the validation patient IDs from what is left\n",
    "validation_patient_ids = np.random.choice(remaining_patient_ids, size=n_val_patients, replace=False)\n",
    "\n",
    "# Remaining patient IDs are for training\n",
    "train_patient_ids = np.setdiff1d(remaining_patient_ids, validation_patient_ids)\n",
    "\n",
    "# The three groups must partition the patients: if the sizes do not add up, a patient\n",
    "# has leaked into more than one split.\n",
    "assert (\n",
    "    len(train_patient_ids) + len(validation_patient_ids) + len(test_patient_ids)\n",
    "    == len(unique_patient_ids)\n",
    "), 'patient-level splits overlap or miss patients'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e32f14f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build each split by selecting the rows whose patient_id is in that split's ID array.\n",
    "patient_col = df['patient_id']\n",
    "train_set = df.loc[patient_col.isin(train_patient_ids)]\n",
    "validation_set = df.loc[patient_col.isin(validation_patient_ids)]\n",
    "test_set = df.loc[patient_col.isin(test_patient_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db766dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows per `abnormal` value in the training split.\n",
    "train_set.loc[:, 'abnormal'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "565de86e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows per `abnormal` value in the validation split.\n",
    "validation_set.loc[:, 'abnormal'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3073ee1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows per `abnormal` value in the test split.\n",
    "test_set.loc[:, 'abnormal'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91224449",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataFrame.to_csv does not create missing directories, so make sure the target\n",
    "# directory exists before writing (a no-op when it is already there).\n",
    "os.makedirs(tgt_root, exist_ok=True)\n",
    "\n",
    "# Write one CSV per split; index=False keeps the row index out of the files.\n",
    "train_set.to_csv(os.path.join(tgt_root, 'train.csv'), index=False)\n",
    "validation_set.to_csv(os.path.join(tgt_root, 'val.csv'), index=False)\n",
    "test_set.to_csv(os.path.join(tgt_root, 'test.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9668edd7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
