{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pinned on purpose: the world-map cell relies on gpd.datasets, which geopandas 1.0 removed.\n",
    "# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
    "%pip install geopandas==0.14.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.patches as patch\n",
    "import json\n",
    "import geopandas as gpd\n",
    "from pathlib import Path\n",
    "\n",
    "# suppress warnings\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "path_to_xlxs = \"./xlxs\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code to generate Figure 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-dataset metadata summaries stored as JSON files in ./metadata/.\n",
    "root = Path('./metadata/')\n",
    "\n",
    "sns.set(style='whitegrid')\n",
    "\n",
    "sex, age, length, labels, geographic_origin = {}, {}, {}, {}, {}\n",
    "\n",
    "for fn in root.glob('*.json'):\n",
    "    # Dataset key = file stem without its last 9 characters.\n",
    "    dataset_key = fn.stem[:-9]\n",
    "    with open(fn) as f:\n",
    "        js = json.load(f)\n",
    "    sex[dataset_key] = js['sex']\n",
    "    age[dataset_key] = js['age']\n",
    "    length[dataset_key] = js['length']\n",
    "    labels[dataset_key] = js['labels']\n",
    "    geographic_origin[dataset_key] = js['geographic_origin']\n",
    "\n",
    "\n",
    "def _sorted_by_dataset(per_dataset):\n",
    "    \"\"\"Return a DataFrame with one column per dataset key, columns in sorted order.\"\"\"\n",
    "    frame = pd.DataFrame(per_dataset)\n",
    "    return frame.reindex(sorted(frame.columns), axis=1)\n",
    "\n",
    "\n",
    "# geographic_origin is read above, but no frame is built from it in this cell.\n",
    "df_sex = _sorted_by_dataset(sex)\n",
    "df_age = _sorted_by_dataset(age)\n",
    "df_length = _sorted_by_dataset(length)\n",
    "df_labels = _sorted_by_dataset(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the axes first and pass them to pandas: without `ax`, DataFrame.plot opens its own\n",
    "# default-sized figure and the 12x8 one stays empty.\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "df_sex.T.plot(kind='bar', stacked=True, colormap='viridis', edgecolor='black', ax=ax)\n",
    "plt.xticks(rotation=45, ha='right', fontsize=12)\n",
    "\n",
    "ax.set_title('Sex Divisions Across Datasets', fontsize=20)\n",
    "ax.set_xlabel('Datasets', fontsize=15)\n",
    "ax.set_ylabel('Percentage', fontsize=15)\n",
    "\n",
    "ax.legend(title='Sex', title_fontsize='13', fontsize='11')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total samples per dataset on a log10 scale, one bar per dataset.\n",
    "dataset_totals = df_length.sum(axis=0)\n",
    "colors = sns.color_palette('viridis', n_colors=len(df_length.columns))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 8))\n",
    "ax.bar(dataset_totals.index, np.log10(dataset_totals), alpha=0.7, color=colors)\n",
    "plt.xticks(rotation=45, ha='right', fontsize=12)\n",
    "\n",
    "ax.set_title('Samples Across Dataset', fontsize=20)\n",
    "ax.set_xlabel('Dataset', fontsize=15)\n",
    "ax.set_ylabel('Count (log10 scale)', fontsize=15)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('sample_count_across_datasets.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Donut chart of the total number of samples contributed by each dataset.\n",
    "dataset_sums = df_length.sum(axis=0)\n",
    "colors = sns.color_palette('viridis', n_colors=len(df_length.columns))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 10), facecolor='white')\n",
    "ax.set_title('Sample Distribution Across Datasets', fontsize=22, weight='bold', pad=20)\n",
    "\n",
    "wedge_style = {'edgecolor': 'white', 'linewidth': 2}\n",
    "wedges, texts = ax.pie(dataset_sums, colors=colors, startangle=140, wedgeprops=wedge_style)\n",
    "\n",
    "# A white disc over the centre turns the pie into a ring.\n",
    "centre_circle = plt.Circle((0, 0), 0.70, color='white', fc='white', linewidth=1.25)\n",
    "ax.add_artist(centre_circle)\n",
    "ax.axis('equal')\n",
    "\n",
    "# One 'name: total' line per dataset, written inside the ring.\n",
    "summary_lines = []\n",
    "for name, count in zip(dataset_sums.index, dataset_sums.values):\n",
    "    summary_lines.append(f'{name}: {count}')\n",
    "center_text = '\\n'.join(summary_lines)\n",
    "ax.text(0, 0, center_text, ha='center', va='center', fontsize=14, weight='bold', color='black')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('size_datasets_pie.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of each dataset's samples per age range: every column is divided by its own total.\n",
    "df_age_norm = df_age / df_age.sum(axis=0)\n",
    "\n",
    "n_datasets = len(df_age_norm.columns)\n",
    "# Own palette, so this cell does not depend on `colors` left behind by an earlier cell.\n",
    "age_colors = sns.color_palette('viridis', n_colors=n_datasets)\n",
    "\n",
    "bar_width = 0.1\n",
    "x = np.arange(len(df_age_norm.index))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "for i, dataset in enumerate(df_age_norm.columns):\n",
    "    ax.bar(x + i * bar_width, df_age_norm[dataset], width=bar_width, label=dataset, alpha=0.7, color=age_colors[i])\n",
    "\n",
    "ax.set_title('Age Distribution Across Datasets', fontsize=20)\n",
    "ax.set_xlabel('Age Range', fontsize=15)\n",
    "# The bars show normalised values, not raw counts.\n",
    "ax.set_ylabel('Fraction of samples', fontsize=15)\n",
    "# Bars sit at offsets 0 .. (n_datasets - 1) * bar_width, so a group's centre is half the last offset.\n",
    "plt.xticks(x + bar_width * (n_datasets - 1) / 2, df_age_norm.index)\n",
    "\n",
    "ax.legend(title='Datasets', title_fontsize='13', fontsize='11')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('age_distribution_across_datasets.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse every length bin from '40-45' onwards into a single '>40' row and drop the two\n",
    "# shortest bins. The result gets its own name so df_length stays untouched and the cell can be\n",
    "# re-run: the previous in-place version removed '40-45' from df_length on its first run.\n",
    "tail_start = df_length.index.get_loc('40-45')\n",
    "df_length_binned = df_length.iloc[:tail_start].copy()\n",
    "df_length_binned.loc['>40'] = df_length.iloc[tail_start:].sum(axis=0)\n",
    "df_length_binned = df_length_binned.drop(['0-5', '5-10'])\n",
    "\n",
    "# Share of each dataset's samples per length bin: every column is divided by its own total.\n",
    "df_length_norm = df_length_binned / df_length_binned.sum(axis=0)\n",
    "\n",
    "n_datasets = len(df_length_norm.columns)\n",
    "# Own palette, so this cell does not depend on `colors` left behind by an earlier cell.\n",
    "length_colors = sns.color_palette('viridis', n_colors=n_datasets)\n",
    "\n",
    "bar_width = 0.1\n",
    "x = np.arange(len(df_length_norm.index))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "for i, dataset in enumerate(df_length_norm.columns):\n",
    "    ax.bar(x + i * bar_width, df_length_norm[dataset], width=bar_width, label=dataset, alpha=0.7, color=length_colors[i])\n",
    "\n",
    "ax.set_title('Length Distribution Across Datasets', fontsize=20)\n",
    "ax.set_xlabel('Length Range', fontsize=15)\n",
    "# The bars show normalised values, not raw counts.\n",
    "ax.set_ylabel('Fraction of samples', fontsize=15)\n",
    "# Bars sit at offsets 0 .. (n_datasets - 1) * bar_width, so a group's centre is half the last offset.\n",
    "plt.xticks(x + bar_width * (n_datasets - 1) / 2, df_length_norm.index)\n",
    "\n",
    "ax.legend(title='Datasets', title_fontsize='13', fontsize='11')\n",
    "\n",
    "fig.tight_layout()\n",
    "# File name (including its spelling) is unchanged so existing references to it keep working.\n",
    "fig.savefig('lenght_distribution_across_datasets.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the 'CHAGAS' row under the name 'SAMITROP-DEATH'.\n",
    "df_labels = df_labels.rename(index={'CHAGAS': 'SAMITROP-DEATH'})\n",
    "\n",
    "# Colour of the strip drawn under each label, keyed by diagnostic-role group.\n",
    "class_colors = {\n",
    "    0: 'lightgreen',  # group 0: \"ECG is the Primary Diagnostic Tool\"\n",
    "    1: 'skyblue',     # group 1: \"ECG is a Supportive Diagnostic Tool\"\n",
    "    2: 'tomato',      # group 2: \"ECG is not used in clinic: prediction of CVE\"\n",
    "}\n",
    "\n",
    "# Labels assigned to groups 2 and 1; any label not listed here falls back to group 0.\n",
    "group_2_labels = ['SAMITROP-DEATH', 'TIA']\n",
    "group_1_labels = [\n",
    "    'EAMI', 'IPLMI', 'PMI', 'ILMI', 'IPMI', 'ALMI', 'LMI', 'ASMI', 'IMI',\n",
    "    'INJLA', 'LAA', 'LAH', 'RAAB', 'LVH', 'SEHYP', 'RVH', 'AH', 'VH',\n",
    "    'CHD', 'CMIS', 'HF', 'HVD', 'LVS', 'PACE', 'ISC_', 'HYP',\n",
    "]\n",
    "classes_map = {**dict.fromkeys(group_2_labels, 2), **dict.fromkeys(group_1_labels, 1)}\n",
    "\n",
    "plt.figure(figsize=(18, 6))\n",
    "sns.set(style='whitegrid')\n",
    "\n",
    "# One bar series per dataset, all drawn at the same x positions (log10 of the label counts).\n",
    "colors = sns.color_palette('viridis', n_colors=len(df_labels.columns))\n",
    "for color, dataset in zip(colors, df_labels.columns):\n",
    "    plt.bar(df_labels.index, np.log10(df_labels[dataset]), label=dataset, color=color, alpha=0.7)\n",
    "\n",
    "plt.xticks(rotation=45, ha='right', fontsize=6)\n",
    "\n",
    "# Short strip below zero under every label, coloured by the label's group.\n",
    "strip_colors = [class_colors[classes_map.get(label, 0)] for label in df_labels.index]\n",
    "plt.bar(df_labels.index, -0.25, width=0.8, color=strip_colors, align='center')\n",
    "\n",
    "plt.title('Label Distribution Across Datasets', fontsize=20)\n",
    "plt.xlabel('Label', fontsize=15)\n",
    "plt.ylabel('Count (log10 scale)', fontsize=15)\n",
    "plt.grid(visible=False)\n",
    "\n",
    "plt.legend(title='Datasets', title_fontsize='13', fontsize='11')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('label_distribution_across_datasets_with_classes.pdf', dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Country of origin of each dataset; the names are matched against the map's 'name' column below.\n",
    "dataset_countries = {\n",
    "    'mimic': 'United States of America',\n",
    "    'ribeiroLabled': 'Brazil',\n",
    "    'samitrop': 'Brazil',\n",
    "    'chapman': 'China',\n",
    "    'georgia': 'United States of America',\n",
    "    'ptb': 'Germany',\n",
    "    'ningbo': 'China',\n",
    "    'ribeiroUnlabled': 'Brazil',\n",
    "    'cpscExtra': 'China',\n",
    "    'hefei': 'China',\n",
    "    'cpsc': 'China',\n",
    "    'sph': 'China',\n",
    "    'ptbxl': 'Germany',\n",
    "}\n",
    "\n",
    "# Low-resolution world map bundled with geopandas < 1.0 (gpd.datasets was removed in 1.0).\n",
    "# On newer versions, read a local Natural Earth countries shapefile instead, e.g.\n",
    "#   world = gpd.read_file('path/to/your/ne_110m_admin_0_countries.shp')\n",
    "world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))\n",
    "\n",
    "# Rows of the world map that correspond to a country of origin.\n",
    "countries = set(dataset_countries.values())\n",
    "highlight_countries = world[world['name'].isin(countries)]\n",
    "\n",
    "# Grey base map with the countries of origin drawn on top in blue. The former per-dataset\n",
    "# loop only filtered the frame and discarded the result (its text labels were disabled),\n",
    "# so it has been removed.\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "world.plot(ax=ax, color='lightgray')\n",
    "highlight_countries.plot(ax=ax, color='skyblue')\n",
    "\n",
    "ax.set_title('Geographical Origin of Datasets', fontsize=20)\n",
    "fig.savefig('geographical_origin_datasets.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code to generate Figure 2b, 2c and 2d in the paper, that is, the label-wise performance obtained by HuBERT-ECG SMALL, BASE and LARGE across tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = ['Figure 2b.xlsx', 'Figure 2c.xlsx', 'Figure 2d.xlsx']\n",
    "sizes = ['small', 'base', 'large']\n",
    "\n",
    "# Long CSV headers -> short column names used below.\n",
    "group_columns = {\n",
    "    'Gruppo 1 (ECG is the Primary Diagnostic Tool)': 'Gruppo 1',\n",
    "    'Gruppo 2 (ECG is a Supportive, Not Primary, Diagnostic Tool)': 'Gruppo 2',\n",
    "    'Gruppo 3 (prediction of CVE)': 'Gruppo 3',\n",
    "}\n",
    "\n",
    "# One row per label abbreviation with its three group columns; empty cells become 0.\n",
    "labels_abbreviations = pd.read_csv(os.path.join(path_to_xlxs, 'labels_abbreviations.csv'), sep=';')\n",
    "labels_abbreviations = (\n",
    "    labels_abbreviations[['Abbreviation', *group_columns]]\n",
    "    .rename(columns=group_columns)\n",
    "    .set_index('Abbreviation')\n",
    "    .fillna(0)\n",
    ")\n",
    "\n",
    "# Background colour of a label's column, keyed by the sum of its three group cells.\n",
    "group_colors = {1: 'green', 2: 'blue', 3: 'red'}\n",
    "\n",
    "for size, path in zip(sizes, paths):\n",
    "    performance = pd.read_excel(os.path.join(path_to_xlxs, path), index_col=0)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(40, 10))\n",
    "\n",
    "    # One colour per row of the sheet (shown as 'Tasks' in the legend).\n",
    "    colors = sns.color_palette('tab20', len(performance.index))\n",
    "    for i, label in enumerate(performance.index):\n",
    "        ax.scatter(performance.columns, performance.iloc[i], label=label, color=colors[i])\n",
    "\n",
    "    ax.set_xlabel('CONDITIONS', fontsize=15)\n",
    "    # Style the tick labels instead of overwriting their text: freezing rounded get_yticks()\n",
    "    # values can mislabel ticks whenever the tick step is finer than 0.1 or the limits change.\n",
    "    ax.tick_params(axis='x', labelrotation=90, labelsize=15)\n",
    "    ax.tick_params(axis='y', labelsize=15)\n",
    "    ax.set_xlim(-1, len(performance.columns) + 0.01)\n",
    "    ax.set_ylabel('AUROC', fontsize=15)\n",
    "    fig.suptitle(f'HuBERT-ECG {size.upper()} label-wise performance', fontsize=23, y=1.03)\n",
    "\n",
    "    # Tasks legend, built from patches so it uses the scatter colours.\n",
    "    legend = ax.legend(\n",
    "        handles=[patch.Patch(color=colors[i], label=performance.index[i]) for i in range(len(performance.index))],\n",
    "        ncol=max(1, len(performance.index) // 3),  # at least one column, even with fewer than 3 tasks\n",
    "        title='Tasks',\n",
    "        loc=(0.42, 1.02),\n",
    "        fontsize=18,\n",
    "        title_fontsize=15,\n",
    "    )\n",
    "\n",
    "    # Shade the column of every condition with the colour of its group.\n",
    "    for i, label in enumerate(performance.columns):\n",
    "        group = labels_abbreviations.loc[label].values.sum()\n",
    "        color = group_colors.get(group)\n",
    "        if color is None:\n",
    "            # No recognised group: leave the column unshaded rather than reuse a stale colour.\n",
    "            continue\n",
    "        ax.axvspan(i - 0.5, i + 0.5, color=color, alpha=0.2)\n",
    "\n",
    "    # A second ax.legend() call replaces the tasks legend as the axes' own legend ...\n",
    "    group_legend = ax.legend(\n",
    "        handles=[\n",
    "            patch.Patch(color='green', label='ECG is the primary diagnostic tool', alpha=0.2),\n",
    "            patch.Patch(color='blue', label='ECG is a supportive, not primary, diagnostic tool', alpha=0.2),\n",
    "            patch.Patch(color='red', label='Prediction of CVE', alpha=0.2),\n",
    "        ],\n",
    "        loc=(0.0, 1.07),\n",
    "        ncol=3,\n",
    "        title='ECG diagnostic role-based classes',\n",
    "        fontsize=18,\n",
    "        title_fontsize=15,\n",
    "    )\n",
    "    # ... so only the tasks legend is re-attached. The group legend is still drawn by the axes;\n",
    "    # adding it to the figure as well would draw it twice.\n",
    "    fig.add_artist(legend)\n",
    "\n",
    "    # Dashed reference line at AUROC = 0.9 across all condition columns.\n",
    "    ax.hlines(0.9, -0.5, len(performance.columns) - 0.5, linestyles='dashed', color='black', linewidth=0.5)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig('./label_wise_performance_' + size + '.svg', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code to generate Supplementary Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "supplementary_figure_1 = pd.read_excel(os.path.join(path_to_xlxs, 'Supplementary figure 1.xlsx'), index_col=0)\n",
    "\n",
    "# x axis: the sheet's index (plotted as the number of clusters C).\n",
    "c = supplementary_figure_1.index\n",
    "\n",
    "# Values quoted in the legend as DB_C100, one per feature set.\n",
    "db_time_freq = 1.225\n",
    "db_mixed = 1.516\n",
    "db_mfcc = 1.213\n",
    "\n",
    "# (sheet column, line format, legend label, colour) for each curve.\n",
    "curves = [\n",
    "    ('time-freq', '--*', f'time_freq (n=16, DB_C100={db_time_freq})', 'tab:orange'),\n",
    "    ('mixed', '--x', f'mixed (n=29, DB_C100={db_mixed})', 'tab:blue'),\n",
    "    ('mfccs', '--*', f'mfcc (n=39, DB_C100={db_mfcc})', 'tab:green'),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for column, fmt, label, color in curves:\n",
    "    ax.plot(c, supplementary_figure_1[column], fmt, label=label, color=color)\n",
    "ax.grid()\n",
    "ax.legend()\n",
    "ax.set_xlabel('Number of clusters - C')\n",
    "ax.set_ylabel('SSE')\n",
    "\n",
    "# Save before plt.show(): show() releases the figure, so saving afterwards wrote an empty image.\n",
    "fig.savefig('./SSE_vs_C.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "supplementary_figure_2 = pd.read_excel(os.path.join(path_to_xlxs, 'Supplementary Figure 2.xlsx'))\n",
    "# Shift the default 0-based row index so that the x axis starts at 1.\n",
    "supplementary_figure_2.index = supplementary_figure_2.index + 1\n",
    "\n",
    "fig, imgs = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "img1, img2 = imgs\n",
    "\n",
    "# (a) pre-training validation loss for the two sampling rates.\n",
    "img1.plot(supplementary_figure_2['Pre-training 100 Hz Validation Loss'], '-o', label='100 Hz')\n",
    "img1.plot(supplementary_figure_2['Pre-training 50 Hz Validation Loss'], '-o', label='50 Hz')\n",
    "img1.grid(True)\n",
    "img1.set_xlabel('Steps x 2500')\n",
    "img1.set_ylabel('Validation loss')\n",
    "img1.set_xticks([1, 10, 20, 30, 40, 50])\n",
    "img1.legend(loc='upper right')\n",
    "img1.set_title('(a)')\n",
    "\n",
    "# (b) macro-averaged AUROC of the linear evaluation for the two sampling rates.\n",
    "img2.plot(supplementary_figure_2['Macro-avg AUROC Linear Evaluation 100 Hz'], '-o', label='100 Hz')\n",
    "img2.plot(supplementary_figure_2['Macro-avg AUROC Linear Evaluation 50 Hz'], '-o', label='50 Hz')\n",
    "img2.grid(True)\n",
    "img2.set_xlabel('Steps x 5000')\n",
    "img2.set_ylabel('Macro-avg AUROC')\n",
    "img2.set_xticks([1, 3, 5, 7, 9, 11, 13])\n",
    "img2.legend(loc='lower right')\n",
    "img2.set_title('(b)')\n",
    "\n",
    "# Save before plt.show(): show() releases the figure, so saving afterwards wrote an empty image.\n",
    "fig.savefig('upstream_downstream_performance_varying_samp_rate.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "supplementary_figure_3 = pd.read_excel(os.path.join(path_to_xlxs, 'Supplementary Figure 3.xlsx'))\n",
    "\n",
    "# Encoding layers on the x axis of every curve.\n",
    "layers = [5, 6, 7, 8, 9, 10]\n",
    "n_layers = len(layers)\n",
    "\n",
    "# One panel per metric: (column prefix, y label, legend location, top of the grey bands).\n",
    "panels = [\n",
    "    ('SSE', 'SSE ←', 'lower left', 4300),\n",
    "    ('DB', 'Davies-Bouldin ←', 'lower left', 3),\n",
    "    ('CH', 'Calinski-Harabasz →', 'upper left', 65),\n",
    "]\n",
    "# One curve per (number of clusters, iteration): (C, iteration, line format, colour).\n",
    "# In every column the first len(layers) rows are plotted as it1 and the remaining rows as it2.\n",
    "curves = [\n",
    "    (500, 1, '-s', 'tab:blue'),\n",
    "    (1000, 1, '-D', 'tab:red'),\n",
    "    (500, 2, '-s', 'tab:green'),\n",
    "    (1000, 2, '-D', 'tab:purple'),\n",
    "]\n",
    "\n",
    "fig, imgs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))\n",
    "for ax, (metric, ylabel, legend_loc, band_top) in zip(imgs, panels):\n",
    "    for n_clusters, iteration, fmt, color in curves:\n",
    "        column = supplementary_figure_3[f'{metric} C {n_clusters}']\n",
    "        values = column.iloc[:n_layers] if iteration == 1 else column.iloc[n_layers:]\n",
    "        ax.plot(layers, values, fmt, label=f'C = {n_clusters} (it{iteration})', color=color)\n",
    "    ax.grid(True)\n",
    "    ax.set_xticks(np.arange(1, 13))\n",
    "    # Grey bands over the x ranges on either side of the plotted layers.\n",
    "    ax.fill_between(np.arange(1, 5, 0.1), 0, band_top, color='tab:grey', alpha=0.5)\n",
    "    ax.fill_between(np.arange(10.1, 12.1, 0.1), 0, band_top, color='tab:grey', alpha=0.5)\n",
    "    ax.set_xlabel('Encoding layers')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend(loc=legend_loc)\n",
    "\n",
    "# Save before plt.show(): show() releases the figure, so saving afterwards wrote an empty image.\n",
    "fig.savefig('clustering_quality_across_iterations_and_layers.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "supplementary_figure_4 = pd.read_excel(os.path.join(path_to_xlxs, 'Supplementary Figure 4.xlsx'), index_col=0)\n",
    "\n",
    "# (sheet column, colour, legend label) for each linear-evaluation curve.\n",
    "curves = [\n",
    "    ('Linear Evaluation BASE it1', 'tab:orange', 'BASE it1'),\n",
    "    ('Linear Evaluation BASE it2', 'tab:blue', 'BASE it2'),\n",
    "    ('Linear Evaluation SMALL', 'tab:green', 'SMALL'),\n",
    "    ('Linear Evaluation LARGE', 'tab:red', 'LARGE'),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "for column, color, label in curves:\n",
    "    ax.plot(supplementary_figure_4[column], '-o', color=color, label=label)\n",
    "ax.grid()\n",
    "ax.legend()\n",
    "ax.set_xlabel('Steps')\n",
    "ax.set_ylabel('Macro-avg AUROC')\n",
    "# One tick per row of the sheet.\n",
    "ax.set_xticks(supplementary_figure_4.index)\n",
    "\n",
    "# Save before plt.show(): show() releases the figure, so saving afterwards wrote an empty image.\n",
    "fig.savefig('linear_eval_varying_model_sizes.svg', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear-evaluation AUROC as a function of the masking percentage p.\n",
    "supplementary_figure_5 = pd.read_excel(os.path.join(path_to_xlxs, 'Supplementary Figure 5.xlsx'), index_col=0)\n",
    "\n",
    "p = supplementary_figure_5['masking p']\n",
    "aucs = supplementary_figure_5['Linear Evaluation AUROC']\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "ax.plot(p, aucs, '-o')\n",
    "# Dashed red marker at p = 0.21.\n",
    "ax.axvline(0.21, linestyle='--', color='tab:red')\n",
    "ax.set_xlabel('Masking percentage p = Percentage of masked embeddings')\n",
    "ax.set_xticks(p)\n",
    "ax.set_ylabel('Macro-averaged AUROC')\n",
    "ax.grid()\n",
    "fig.savefig('masking_p.svg', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
