{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "df0065c0",
   "metadata": {},
   "source": [
    "# Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "633c1b7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f539669f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "def rename_label_columns(input_dir: str = \"./raw\", output_dir: str = \"./preprocessed\"):\n",
    "    targets = ['death', 'STAY_DAYS', 'readmit_30d']\n",
    "    in_dir = Path(input_dir)\n",
    "    out_dir = Path(output_dir)\n",
    "    out_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    outputs = []\n",
    "    for csv_path in sorted(in_dir.glob(\"*.csv\")):\n",
    "        df = pd.read_csv(csv_path)\n",
    "        rename_map = {c: 'label' for c in df.columns if c in targets}\n",
    "        if rename_map:\n",
    "            df = df.rename(columns=rename_map)\n",
    "            # Convert True/False to 1/0 for death column\n",
    "            if 'death' in rename_map and 'label' in df.columns:\n",
    "                df['label'] = df['label'].map({True: 1, False: 0})\n",
    "        out_path = out_dir / csv_path.name\n",
    "        df.to_csv(out_path, index=False)\n",
    "        outputs.append(str(out_path))\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6cdee8af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['preprocessed/mortality_cato.csv',\n",
       " 'preprocessed/mortality_ours.csv',\n",
       " 'preprocessed/mortality_vanilla.csv',\n",
       " 'preprocessed/readmission_cato.csv',\n",
       " 'preprocessed/readmission_ours.csv',\n",
       " 'preprocessed/readmission_vanilla.csv',\n",
       " 'preprocessed/stay_periods_cato.csv',\n",
       " 'preprocessed/stay_periods_ours.csv',\n",
       " 'preprocessed/stay_periods_vanilla.csv']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relabel every ./raw/*.csv into ./preprocessed (the function's default directories)\n",
    "rename_label_columns()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1754f8c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cleaned and saved: preprocessed/mortality_cato.csv\n",
      "Cleaned and saved: preprocessed/mortality_ours.csv\n",
      "Cleaned and saved: preprocessed/mortality_vanilla.csv\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import re\n",
    "\n",
    "# Mortality CSVs to scrub; the loop at the end of this cell overwrites each file in place.\n",
    "files = [\n",
    "    \"preprocessed/mortality_cato.csv\",\n",
    "    \"preprocessed/mortality_ours.csv\",\n",
    "    \"preprocessed/mortality_vanilla.csv\",\n",
    "]\n",
    "\n",
    "def clean_notes(text: str) -> str:\n",
    "    if pd.isna(text):\n",
    "        return text\n",
    "    s = str(text)\n",
    "\n",
    "    # Remove segments like \" | DISCHARGE_LOCATION=HOME\" up to next pipe\n",
    "    s = re.sub(r'\\s*\\|\\s*DISCHARGE_LOCATION=[^|]*', '', s, flags=re.IGNORECASE)\n",
    "    # Also handle cases without pipes: \"DISCHARGE_LOCATION=HOME\"\n",
    "    s = re.sub(r'DISCHARGE_LOCATION=[^\\s|,;]*', '', s, flags=re.IGNORECASE)\n",
    "\n",
    "    # Remove mentions of the field name \"discharge_location\" (leave the value)\n",
    "    s = re.sub(r'\\*{0,2}\\s*discharge_location\\s*\\*{0,2}', '', s, flags=re.IGNORECASE)\n",
    "\n",
    "    # Remove death-related words (case-insensitive), keeping only the words themselves\n",
    "    s = re.sub(r'\\b(dead|death|expired)\\b', '', s, flags=re.IGNORECASE)\n",
    "\n",
    "    # Tidy up dangling separators created by removals\n",
    "    s = re.sub(r'\\s*\\|\\s*\\|\\s*', ' | ', s)          # \" |  | \" -> \" | \"\n",
    "    s = re.sub(r'\\s*\\|\\s*[=:]\\s*', ' | ', s)        # \" | = \" or \" | : \" -> \" | \"\n",
    "    s = re.sub(r'^[=:]\\s*', '', s)                  # leading \"= \" or \": \" -> \"\"\n",
    "    s = re.sub(r'\\s{2,}', ' ', s).strip()           # collapse spaces\n",
    "    return s\n",
    "\n",
    "# Overwrite each listed CSV in place with both note columns scrubbed.\n",
    "for path in files:\n",
    "    df = pd.read_csv(path)\n",
    "    for col in [c for c in (\"original_note\", \"augmented_note\") if c in df.columns]:\n",
    "        df[col] = df[col].map(clean_notes)\n",
    "    df.to_csv(path, index=False)\n",
    "    print(f\"Cleaned and saved: {path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "41fecc2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>original_note</th>\n",
       "      <th>augmented_note</th>\n",
       "      <th>label</th>\n",
       "      <th>method</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>DIAGNOSIS=NEWBORN | EDOUTTIME= | EDREGTIME= | ...</td>\n",
       "      <td>**diagnosis:** NEWBORN **edouttime:** | **edre...</td>\n",
       "      <td>0</td>\n",
       "      <td>cato</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>DIAGNOSIS=HYPOTENSION | EDOUTTIME=2101-10-20 1...</td>\n",
       "      <td>**Clinical Note:** **Diagnosis:** HYPOTENSION ...</td>\n",
       "      <td>0</td>\n",
       "      <td>cato</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>DIAGNOSIS=FEVER,DEHYDRATION,FAILURE TO THRIVE ...</td>\n",
       "      <td>**Clinical Note:** **Diagnosis:** Fever, dehyd...</td>\n",
       "      <td>0</td>\n",
       "      <td>cato</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>DIAGNOSIS=NEWBORN | EDOUTTIME= | EDREGTIME= | ...</td>\n",
       "      <td>**Diagnosis:** NEWBORN **EDOUTTIME:** | **EDRE...</td>\n",
       "      <td>0</td>\n",
       "      <td>cato</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>DIAGNOSIS=CHRONIC RENAL FAILURE/SDA | EDOUTTIM...</td>\n",
       "      <td>**Clinical Note:** **Diagnosis:** CHRONIC RENA...</td>\n",
       "      <td>0</td>\n",
       "      <td>cato</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       original_note  \\\n",
       "0  DIAGNOSIS=NEWBORN | EDOUTTIME= | EDREGTIME= | ...   \n",
       "1  DIAGNOSIS=HYPOTENSION | EDOUTTIME=2101-10-20 1...   \n",
       "2  DIAGNOSIS=FEVER,DEHYDRATION,FAILURE TO THRIVE ...   \n",
       "3  DIAGNOSIS=NEWBORN | EDOUTTIME= | EDREGTIME= | ...   \n",
       "4  DIAGNOSIS=CHRONIC RENAL FAILURE/SDA | EDOUTTIM...   \n",
       "\n",
       "                                      augmented_note  label method  \n",
       "0  **diagnosis:** NEWBORN **edouttime:** | **edre...      0   cato  \n",
       "1  **Clinical Note:** **Diagnosis:** HYPOTENSION ...      0   cato  \n",
       "2  **Clinical Note:** **Diagnosis:** Fever, dehyd...      0   cato  \n",
       "3  **Diagnosis:** NEWBORN **EDOUTTIME:** | **EDRE...      0   cato  \n",
       "4  **Clinical Note:** **Diagnosis:** CHRONIC RENA...      0   cato  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the cleaned mortality (cato) file\n",
    "df = pd.read_csv(\"preprocessed/mortality_cato.csv\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "509cd4a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'DIAGNOSIS=CORONARY ARTERY DISEASE\\\\CATH | EDOUTTIME= | EDREGTIME= | ETHNICITY=UNKNOWN/NOT SPECIFIED | MARITAL_STATUS=MARRIED | RELIGION=CATHOLIC| ADMISSION_LOCATION=CLINIC REFERRAL/PREMATURE | ADMISSION_TYPE=EMERGENCY | DISCHTIME=2192-11-27 14:35:00 | ADMITTIME=2192-11-19 18:14:00 | TEXT=[**2192-11-19**] 4:16 PM CAROTID SERIES COMPLETE Clip # [**Clip Number (Radiology) 37261**] Reason: r/o stenosis. Pre-op for CABG. ______________________________________________________________________________ [**Hospital 2**] MEDICAL CONDITION: 80 year old man with CAD. S/P bilateral CEA in [**2179**]. REASON FOR THIS EXAMINATION: r/o stenosis. Pre-op for CABG. ______________________________________________________________________________ FINAL REPORT STUDY: Carotid series complete. REASON: Preop CABG. Status post bilateral carotid endarterectomy in [**2179**]. FINDINGS: Duplex evaluation was performed of bilateral carotid arteries. There is no significant plaque noted in the right or left carotid arteries. On the right, peak systolic velocities are 73, 48, and 68 cm per second in the ICA, CCA, and ECA respectively. The ICA end-diastolic velocity is 26 cm per second. The ICA:CCA ratio is 1.5. This is consistent with no significant stenosis. On the left, peak systolic velocities are 64, 85, and 62 cm per second in the ICA, CCA, and ECA respectively. The ICA end-diastolic velocity is 25 cm per second. The ICA:CCA ratio is 0.8. This is consistent with no significant stenosis. There is antegrade vertebral artery flow bilaterally. IMPRESSION: No significant stenosis of the right or left ICA.'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one full original note\n",
    "df.loc[44, \"original_note\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c8535ad",
   "metadata": {},
   "source": [
    "# Convert to SFT Format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3c9fcfe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "#readmission\n",
    "def convert_to_sft_readmission(csv_file_path, output_file_path, augment=True):\n",
    "    \"\"\"Convert CSV format to SFT format (json) using chosen responses\"\"\"\n",
    "    try:\n",
    "        # Read CSV file\n",
    "        df = pd.read_csv(csv_file_path)\n",
    "        \n",
    "        sft_data = []\n",
    "        for index, row in df.iterrows():\n",
    "            readmission_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict whether the patient will readmit to the hospital in 30 Days.\n",
    "            Label: 1 (The patient will readmit within 30 days), 0 (The patient will not readmit within 30 days).\n",
    "            Clinical Note: {row['original_note']}\n",
    "            \"\"\"\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": \"You are a clinical expert classifier. Respond with only 1 or 0.\"},\n",
    "                {\"role\": \"user\", \"content\": readmission_prompt},\n",
    "                {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "            ]\n",
    "            sft_data.append({\"messages\": messages})\n",
    "\n",
    "            if augment and 'augmented_note' in df.columns:\n",
    "                augmented_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict whether the patient will readmit to the hospital in 30 Days.\n",
    "                Label: 1 (The patient will readmit within 30 days), 0 (The patient will not readmit within 30 days).\n",
    "                Clinical Note: {row['augmented_note']}\n",
    "                \"\"\"\n",
    "                augmented_messages = [\n",
    "                    {\"role\": \"system\", \"content\": \"You are a clinical expert classifier. Respond with only 1 or 0.\"},\n",
    "                    {\"role\": \"user\", \"content\": augmented_prompt},\n",
    "                    {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "                ]\n",
    "                sft_data.append({\"messages\": augmented_messages})                  \n",
    "        \n",
    "        # Create output directory if it doesn't exist\n",
    "        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)\n",
    "        \n",
    "        with open(output_file_path, 'w') as f:\n",
    "            json.dump(sft_data, f, indent=2)\n",
    "        \n",
    "        print(f\"Converted {len(sft_data)} samples to SFT format and saved to {output_file_path}\")\n",
    "        \n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File '{csv_file_path}' not found.\")\n",
    "    except KeyError as e:\n",
    "        print(f\"Error: Missing required column '{e}' in the CSV file.\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "31ab64d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "#mortality\n",
    "def convert_to_sft_mortality(csv_file_path, output_file_path, augment=True):\n",
    "    \"\"\"Convert CSV format to SFT format (json) using chosen responses\"\"\"\n",
    "    try:\n",
    "        # Read CSV file\n",
    "        df = pd.read_csv(csv_file_path)\n",
    "        \n",
    "        sft_data = []\n",
    "        for index, row in df.iterrows():\n",
    "            readmission_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict whether the patient will readmit to the hospital in 30 Days.\n",
    "                Label: 1 (The patient has died within their admission), 0 (The patient did not die during their admission).\n",
    "            Clinical Note: {row['original_note']}\n",
    "            \"\"\"\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": \"You are a clinical expert classifier. Respond with only 1 or 0.\"},\n",
    "                {\"role\": \"user\", \"content\": readmission_prompt},\n",
    "                {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "            ]\n",
    "            sft_data.append({\"messages\": messages})\n",
    "\n",
    "            if augment and 'augmented_note' in df.columns:\n",
    "                augmented_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict whether the patient will readmit to the hospital in 30 Days.\n",
    "                Label: 1 (The patient has died within their admission), 0 (The patient did not die during their admission).\n",
    "                Clinical Note: {row['augmented_note']}\n",
    "                \"\"\"\n",
    "                augmented_messages = [\n",
    "                    {\"role\": \"system\", \"content\": \"You are a clinical expert classifier. Respond with only 1 or 0.\"},\n",
    "                    {\"role\": \"user\", \"content\": augmented_prompt},\n",
    "                    {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "                ]\n",
    "                sft_data.append({\"messages\": augmented_messages})                  \n",
    "        \n",
    "        # Create output directory if it doesn't exist\n",
    "        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)\n",
    "        \n",
    "        with open(output_file_path, 'w') as f:\n",
    "            json.dump(sft_data, f, indent=2)\n",
    "        \n",
    "        print(f\"Converted {len(sft_data)} samples to SFT format and saved to {output_file_path}\")\n",
    "        \n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File '{csv_file_path}' not found.\")\n",
    "    except KeyError as e:\n",
    "        print(f\"Error: Missing required column '{e}' in the CSV file.\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "fcbd73bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "#stay\n",
    "def convert_to_sft_period(csv_file_path, output_file_path, augment=True):\n",
    "    \"\"\"Convert CSV format to SFT format (json) using chosen responses\"\"\"\n",
    "    try:\n",
    "        # Read CSV file\n",
    "        df = pd.read_csv(csv_file_path)\n",
    "        \n",
    "        sft_data = []\n",
    "        for index, row in df.iterrows():\n",
    "            readmission_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict how many days (float) the patient will stay.\n",
    "            Clinical Note: {row['original_note']}\n",
    "            \"\"\"\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": \"You are a clinical expert. Respond only with the number of days (e.g., XX.X)\"},\n",
    "                {\"role\": \"user\", \"content\": readmission_prompt},\n",
    "                {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "            ]\n",
    "            sft_data.append({\"messages\": messages})\n",
    "\n",
    "            if augment and 'augmented_note' in df.columns:\n",
    "                augmented_prompt = f\"\"\"You will be given clinical notes of a patient. Your task is to predict how many days (float) the patient will stay.\n",
    "                Clinical Note: {row['augmented_note']}\n",
    "                \"\"\"\n",
    "                augmented_messages = [\n",
    "                    {\"role\": \"system\", \"content\": \"You are a clinical expert. Respond only with the number of days (e.g., XX.X)\"},\n",
    "                    {\"role\": \"user\", \"content\": augmented_prompt},\n",
    "                    {\"role\": \"assistant\", \"content\": str(row['label'])}\n",
    "                ]\n",
    "                sft_data.append({\"messages\": augmented_messages})                  \n",
    "        \n",
    "        # Create output directory if it doesn't exist\n",
    "        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)\n",
    "        \n",
    "        with open(output_file_path, 'w') as f:\n",
    "            json.dump(sft_data, f, indent=2)\n",
    "        \n",
    "        print(f\"Converted {len(sft_data)} samples to SFT format and saved to {output_file_path}\")\n",
    "        \n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File '{csv_file_path}' not found.\")\n",
    "    except KeyError as e:\n",
    "        print(f\"Error: Missing required column '{e}' in the CSV file.\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "589df740",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 8951 samples to SFT format and saved to ./data/readmission.json\n",
      "Converted 11060 samples to SFT format and saved to ./data/mortality.json\n",
      "Converted 9532 samples to SFT format and saved to ./data/stay_periods.json\n",
      "Converted 17902 samples to SFT format and saved to ./data/readmission_vanilla.json\n",
      "Converted 16234 samples to SFT format and saved to ./data/readmission_cato.json\n",
      "Converted 14966 samples to SFT format and saved to ./data/readmission_ours.json\n",
      "Converted 22120 samples to SFT format and saved to ./data/mortality_vanilla.json\n",
      "Converted 19482 samples to SFT format and saved to ./data/mortality_cato.json\n",
      "Converted 13906 samples to SFT format and saved to ./data/mortality_ours.json\n",
      "Converted 19064 samples to SFT format and saved to ./data/stay_periods_vanilla.json\n",
      "Converted 21258 samples to SFT format and saved to ./data/stay_periods_cato.json\n",
      "Converted 17218 samples to SFT format and saved to ./data/stay_periods_ours.json\n"
     ]
    }
   ],
   "source": [
    "sft_converters = {\n",
    "    \"readmission\": convert_to_sft_readmission,\n",
    "    \"mortality\": convert_to_sft_mortality,\n",
    "    \"stay_periods\": convert_to_sft_period,\n",
    "}\n",
    "\n",
    "# Plain SFT sets (original notes only), built from each task's vanilla CSV\n",
    "for task, convert in sft_converters.items():\n",
    "    convert(f\"./preprocessed/{task}_vanilla.csv\", f\"./data/{task}.json\", augment=False)\n",
    "\n",
    "# Augmented SFT sets: original + augmented example for every variant of every task\n",
    "for task, convert in sft_converters.items():\n",
    "    for variant in (\"vanilla\", \"cato\", \"ours\"):\n",
    "        convert(f\"./preprocessed/{task}_{variant}.csv\", f\"./data/{task}_{variant}.json\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0c3bcdd",
   "metadata": {},
   "source": [
    "# Balance Data for Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "91c6e682",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Balanced to 14966 samples:\n",
      "./data/readmission_vanilla_balanced.json\n",
      "./data/readmission_cato_balanced.json\n",
      "./data/readmission_ours_balanced.json\n",
      "Balanced to 13906 samples:\n",
      "./data/mortality_vanilla_balanced.json\n",
      "./data/mortality_cato_balanced.json\n",
      "./data/mortality_ours_balanced.json\n",
      "Balanced to 17218 samples:\n",
      "./data/stay_periods_vanilla_balanced.json\n",
      "./data/stay_periods_cato_balanced.json\n",
      "./data/stay_periods_ours_balanced.json\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['./data/stay_periods_vanilla_balanced.json',\n",
       " './data/stay_periods_cato_balanced.json',\n",
       " './data/stay_periods_ours_balanced.json']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json, os, random\n",
    "\n",
    "\n",
    "def balance_json_group(file_paths, output_suffix=\"_balanced\", seed=42, output_dir=None):\n",
    "    \"\"\"Downsample all JSON datasets in file_paths to the same (minimum) length.\n",
    "    Saves new files with the given suffix in the same directory by default.\n",
    "    \"\"\"\n",
    "    datasets = []\n",
    "    for path in file_paths:\n",
    "        with open(path, \"r\") as f:\n",
    "            datasets.append(json.load(f))\n",
    "\n",
    "    min_len = min(len(d) for d in datasets)\n",
    "\n",
    "    random.seed(seed)\n",
    "    balanced_datasets = []\n",
    "    for data in datasets:\n",
    "        if len(data) > min_len:\n",
    "            sampled = random.sample(data, min_len)\n",
    "        else:\n",
    "            sampled = data\n",
    "        balanced_datasets.append(sampled)\n",
    "\n",
    "    out_paths = []\n",
    "    for in_path, data in zip(file_paths, balanced_datasets):\n",
    "        directory = os.path.dirname(in_path) if output_dir is None else output_dir\n",
    "        base = os.path.basename(in_path)\n",
    "        if base.endswith(\".json\"):\n",
    "            base = base[:-5]\n",
    "        out_path = os.path.join(directory, f\"{base}{output_suffix}.json\")\n",
    "        with open(out_path, \"w\") as f:\n",
    "            json.dump(data, f, indent=2)\n",
    "        out_paths.append(out_path)\n",
    "\n",
    "    print(f\"Balanced to {min_len} samples:\")\n",
    "    for p in out_paths:\n",
    "        print(p)\n",
    "    return out_paths\n",
    "\n",
    "\n",
    "# Each call downsamples one task's three variants to the size of its smallest file (default seed=42).\n",
    "# Balance readmission variants\n",
    "balance_json_group([\n",
    "    \"./data/readmission_vanilla.json\",\n",
    "    \"./data/readmission_cato.json\",\n",
    "    \"./data/readmission_ours.json\",\n",
    "])\n",
    "\n",
    "# Balance mortality variants\n",
    "balance_json_group([\n",
    "    \"./data/mortality_vanilla.json\",\n",
    "    \"./data/mortality_cato.json\",\n",
    "    \"./data/mortality_ours.json\",\n",
    "])\n",
    "\n",
    "# Balance stay period variants\n",
    "balance_json_group([\n",
    "    \"./data/stay_periods_vanilla.json\",\n",
    "    \"./data/stay_periods_cato.json\",\n",
    "    \"./data/stay_periods_ours.json\",\n",
    "])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02ff8138",
   "metadata": {},
   "source": [
    "# DPO Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "08bbd982",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating DPO datasets...\n",
      "==================================================\n",
      "Creating DPO dataset for readmission:\n",
      "  Vanilla samples: 8951\n",
      "  Cato samples: 8117\n",
      "  Ours samples: 7483\n",
      "  Created 14966 DPO pairs\n",
      "  Saved to: ./data/readmission_dpo.json\n",
      "Creating DPO dataset for mortality:\n",
      "  Vanilla samples: 11060\n",
      "  Cato samples: 9741\n",
      "  Ours samples: 6953\n",
      "  Created 13906 DPO pairs\n",
      "  Saved to: ./data/mortality_dpo.json\n",
      "Creating DPO dataset for stay_periods:\n",
      "  Vanilla samples: 9532\n",
      "  Cato samples: 10629\n",
      "  Ours samples: 8609\n",
      "  Created 17218 DPO pairs\n",
      "  Saved to: ./data/stay_periods_dpo.json\n",
      "==================================================\n",
      "Total DPO pairs created:\n",
      "  Readmission: 14966\n",
      "  Mortality: 13906\n",
      "  Stay Periods: 17218\n",
      "  Grand Total: 46090\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "def create_dpo_dataset(task_name, vanilla_path, cato_path, ours_path, output_path):\n",
    "    \"\"\"\n",
    "    Create DPO dataset by pairing ours (chosen) with vanilla/cato (rejected) augmented notes.\n",
    "    \n",
    "    Args:\n",
    "        task_name: Name of the task (readmission, mortality, stay_periods)\n",
    "        vanilla_path: Path to vanilla CSV file\n",
    "        cato_path: Path to cato CSV file  \n",
    "        ours_path: Path to ours CSV file\n",
    "        output_path: Path to save the DPO JSON file\n",
    "    \"\"\"\n",
    "    \n",
    "    # Read the CSV files\n",
    "    df_vanilla = pd.read_csv(vanilla_path)\n",
    "    df_cato = pd.read_csv(cato_path)\n",
    "    df_ours = pd.read_csv(ours_path)\n",
    "    \n",
    "    print(f\"Creating DPO dataset for {task_name}:\")\n",
    "    print(f\"  Vanilla samples: {len(df_vanilla)}\")\n",
    "    print(f\"  Cato samples: {len(df_cato)}\")\n",
    "    print(f\"  Ours samples: {len(df_ours)}\")\n",
    "    \n",
    "    dpo_data = []\n",
    "    \n",
    "    # Create DPO pairs: ours vs vanilla\n",
    "    for idx, row in df_ours.iterrows():\n",
    "        if idx < len(df_vanilla):\n",
    "            vanilla_row = df_vanilla.iloc[idx]\n",
    "            \n",
    "            # Ensure we have the same original note (for alignment)\n",
    "            if row['original_note'] == vanilla_row['original_note']:\n",
    "                dpo_entry = {\n",
    "                    \"prompt\": [\n",
    "                        {\n",
    "                            \"role\": \"user\", \n",
    "                            \"content\": f\"You are an AI assistant. Your task is to rewrite the given clinical note in a different writing style while maintaining the same medical information and clinical accuracy. You can change sentence structure, word choice, and writing flow, but preserve all medical facts and details.\\n\\nClinical Note: {row['original_note']}\"\n",
    "                        }\n",
    "                    ],\n",
    "                    \"rejected\": [\n",
    "                        {\n",
    "                            \"role\": \"assistant\", \n",
    "                            \"content\": vanilla_row['augmented_note']\n",
    "                        }\n",
    "                    ],\n",
    "                    \"chosen\": [\n",
    "                        {\n",
    "                            \"role\": \"assistant\", \n",
    "                            \"content\": row['augmented_note']\n",
    "                        }\n",
    "                    ]\n",
    "                }\n",
    "                dpo_data.append(dpo_entry)\n",
    "    \n",
    "    # Create DPO pairs: ours vs cato\n",
    "    for idx, row in df_ours.iterrows():\n",
    "        if idx < len(df_cato):\n",
    "            cato_row = df_cato.iloc[idx]\n",
    "            \n",
    "            # Ensure we have the same original note (for alignment)\n",
    "            if row['original_note'] == cato_row['original_note']:\n",
    "                dpo_entry = {\n",
    "                    \"prompt\": [\n",
    "                        {\n",
    "                            \"role\": \"user\", \n",
    "                            \"content\": f\"You are an AI assistant. Your task is to rewrite the given clinical note in a different writing style while maintaining the same medical information and clinical accuracy. You can change sentence structure, word choice, and writing flow, but preserve all medical facts and details.\\n\\nClinical Note: {row['original_note']}\"\n",
    "                        }\n",
    "                    ],\n",
    "                    \"rejected\": [\n",
    "                        {\n",
    "                            \"role\": \"assistant\", \n",
    "                            \"content\": cato_row['augmented_note']\n",
    "                        }\n",
    "                    ],\n",
    "                    \"chosen\": [\n",
    "                        {\n",
    "                            \"role\": \"assistant\", \n",
    "                            \"content\": row['augmented_note']\n",
    "                        }\n",
    "                    ]\n",
    "                }\n",
    "                dpo_data.append(dpo_entry)\n",
    "    \n",
    "    # Create output directory if it doesn't exist\n",
    "    os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
    "    \n",
    "    # Save DPO dataset\n",
    "    with open(output_path, 'w') as f:\n",
    "        json.dump(dpo_data, f, indent=2)\n",
    "    \n",
    "    print(f\"  Created {len(dpo_data)} DPO pairs\")\n",
    "    print(f\"  Saved to: {output_path}\")\n",
    "    \n",
    "    return len(dpo_data)\n",
    "\n",
    "# Build one DPO dataset per task from its vanilla / cato / ours CSVs\n",
    "print(\"Creating DPO datasets...\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "pair_counts = {}\n",
    "for task in (\"readmission\", \"mortality\", \"stay_periods\"):\n",
    "    pair_counts[task] = create_dpo_dataset(\n",
    "        task,\n",
    "        f\"./preprocessed/{task}_vanilla.csv\",\n",
    "        f\"./preprocessed/{task}_cato.csv\",\n",
    "        f\"./preprocessed/{task}_ours.csv\",\n",
    "        f\"./data/{task}_dpo.json\",\n",
    "    )\n",
    "\n",
    "readmission_pairs = pair_counts[\"readmission\"]\n",
    "mortality_pairs = pair_counts[\"mortality\"]\n",
    "stay_periods_pairs = pair_counts[\"stay_periods\"]\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"Total DPO pairs created:\")\n",
    "print(f\"  Readmission: {readmission_pairs}\")\n",
    "print(f\"  Mortality: {mortality_pairs}\")\n",
    "print(f\"  Stay Periods: {stay_periods_pairs}\")\n",
    "print(f\"  Grand Total: {sum(pair_counts.values())}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad365e37",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d163b0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a8be551",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e44a03",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76a2e63",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97df2bb4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
