{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Paper Analysis Consolidation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Common Initializations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Project-local analysis helpers\n",
    "import paper_analysis_utils as pau\n",
    "\n",
    "# =============================================================================\n",
    "# Configuration: replace the placeholder paths with the Zenodo / demo locations.\n",
    "# =============================================================================\n",
    "SNS_METADATA_PATH = \"path/to/SnS_multiexp_dirs.json\"     # SnS_experiments index (Zenodo)\n",
    "DIST_PARAMS_PATH = 'path/to/demo/distance_params.json'    # distance_params.json from the demo\n",
    "HUMAN_DATA_PATH = 'path/to/humanAcc.csv'                  # human accuracies, for the MultiNet correlation\n",
    "MULTINET_ICLR_DIR = \"/path/to/multintACC_ICLR\"            # MultiNet ICLR accuracies directory (Zenodo)\n",
    "VIT_ACC_DIR = \"/path/to/multinetACC_VIT\"                  # ViT accuracies directory (Zenodo)\n",
    "\n",
    "# Root directory for analysis outputs (created if missing).\n",
    "BASE_OUTPUT_DIR = \"./paper_analysis_outputs\"\n",
    "os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = torch.device(\"cuda\")\n",
    "else:\n",
    "    DEVICE = torch.device(\"cpu\")\n",
    "print(f\"Using device: {DEVICE}\")\n",
    "\n",
    "# =============================================================================\n",
    "# Shared objects. Each load is best effort: on failure the error is printed and\n",
    "# the object is set to None, which several later sections check before running.\n",
    "# =============================================================================\n",
    "try:\n",
    "    experiments_metadata = pau.SnS_metadata.from_json(SNS_METADATA_PATH, recalculate=False)\n",
    "    print(f\"Successfully loaded SnS_metadata from {SNS_METADATA_PATH}\")\n",
    "except Exception as err:\n",
    "    print(f\"Error loading SnS_metadata: {err}. Some analyses might fail.\")\n",
    "    experiments_metadata = None\n",
    "\n",
    "try:\n",
    "    prms_dist = pau.read_json(DIST_PARAMS_PATH)\n",
    "    print(f\"Successfully loaded distance parameters from {DIST_PARAMS_PATH}\")\n",
    "except Exception as err:\n",
    "    print(f\"Error loading distance parameters: {err}. Distance analysis might fail.\")\n",
    "    prms_dist = None\n",
    "\n",
    "# General-purpose DeePSiM generator (fc7 variant). pau.WEIGHTS may be a placeholder\n",
    "# if experiment.utils.args is not importable.\n",
    "try:\n",
    "    generator_fc7 = pau.DeePSiMGenerator(\n",
    "        root=str(pau.WEIGHTS),\n",
    "        variant='fc7',\n",
    "    ).to(DEVICE)\n",
    "    print(\"Initialized DeePSiMGenerator (fc7 variant).\")\n",
    "except Exception as err:\n",
    "    print(f\"Error initializing DeePSiMGenerator (fc7): {err}\")\n",
    "    generator_fc7 = None\n",
    "\n",
    "# Reference network for the SVC analysis: vanilla ResNet50 (built through\n",
    "# pau.torch_load) with a recording probe on layer '00_input_01'.\n",
    "rec_ly_svc = '00_input_01'\n",
    "try:\n",
    "    probe_svc = pau.RecordingProbe(target={rec_ly_svc: []})\n",
    "    repr_net_svc_resnet50_vanilla = (\n",
    "        pau.TorchNetworkSubject(\n",
    "            record_probe=probe_svc,\n",
    "            network_name='resnet50',\n",
    "            t_net_loading=pau.torch_load,\n",
    "            custom_weights_path='',\n",
    "        )\n",
    "        .eval()\n",
    "        .to(DEVICE)\n",
    "    )\n",
    "    print(\"Initialized TorchNetworkSubject for SVC (ResNet50 vanilla, '00_input_01').\")\n",
    "except Exception as err:\n",
    "    print(f\"Error initializing TorchNetworkSubject for SVC: {err}\")\n",
    "    repr_net_svc_resnet50_vanilla = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Centroid Plot + Activation Percentiles Boxplot "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from snslib.metaexperiment.metaexp_functs import nat_percentiles\n",
    "from snslib.metaexperiment.plots import SnS_scatterplot\n",
    "\n",
    "# Centroid plot (SnS_scatterplot) + nat_percentiles for the robust_l2 ResNet50,\n",
    "# '00_input_01' (pixel space) invariance vs. adversarial experiments.\n",
    "# If snslib is edited while the notebook runs, enable %autoreload instead of reloading here.\n",
    "if experiments_metadata:\n",
    "    experiments_metadata.apply_analysis(\n",
    "        queries=[\n",
    "            ['resnet50', '56_linear_01', 'robust_l2', '00_input_01', '500', 'invariance 2.0'],\n",
    "            ['resnet50', '56_linear_01', 'robust_l2', '00_input_01', '500', 'adversarial 1.0'],\n",
    "        ],\n",
    "        callables=[nat_percentiles, SnS_scatterplot],\n",
    "    )\n",
    "else:\n",
    "    print(\"SnS_metadata not loaded. Skipping Centroid Plot.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if experiments_metadata:\n",
    "    # All readout-layer queries share network, target ('56_linear_01') and '500'; only the\n",
    "    # weights, the 4th field ('00_input_01' = pixel space, or an inner layer) and the task vary.\n",
    "    def _readout_query(weights, space, task):\n",
    "        return experiments_metadata.get_experiments(\n",
    "            queries=[['resnet50', '56_linear_01', weights, space, '500', task]]\n",
    "        )\n",
    "\n",
    "    # Standard (vanilla) weights\n",
    "    exp_std_inv_pix = _readout_query('vanilla', '00_input_01', 'invariance 2.0')\n",
    "    exp_std_inv_25  = _readout_query('vanilla', '26_conv_25', 'invariance 2.0')\n",
    "    exp_std_inv_51  = _readout_query('vanilla', '52_conv_51', 'invariance 2.0')\n",
    "    exp_std_adv_pix = _readout_query('vanilla', '00_input_01', 'adversarial 1.0')\n",
    "\n",
    "    # Robust (robust_l2) weights\n",
    "    exp_r_inv_pix = _readout_query('robust_l2', '00_input_01', 'invariance 2.0')\n",
    "    exp_r_inv_25  = _readout_query('robust_l2', '26_conv_25', 'invariance 2.0')\n",
    "    exp_r_inv_51  = _readout_query('robust_l2', '52_conv_51', 'invariance 2.0')\n",
    "    exp_r_adv_pix = _readout_query('robust_l2', '00_input_01', 'adversarial 1.0')\n",
    "\n",
    "    print(\"Calculating activation percentiles...\")\n",
    "    _percentiles = pau.calculate_activation_percentiles\n",
    "    nat_stats_std_inv_pix = _percentiles(exps=exp_std_inv_pix)\n",
    "    nat_stats_r_inv_pix   = _percentiles(exps=exp_r_inv_pix)\n",
    "    nat_stats_std_inv_25  = _percentiles(exps=exp_std_inv_25)\n",
    "    nat_stats_r_inv_25    = _percentiles(exps=exp_r_inv_25)\n",
    "    nat_stats_std_inv_51  = _percentiles(exps=exp_std_inv_51)\n",
    "    nat_stats_r_inv_51    = _percentiles(exps=exp_r_inv_51)\n",
    "    nat_stats_adv_std     = _percentiles(exps=exp_std_adv_pix)\n",
    "    nat_stats_adv_rob     = _percentiles(exps=exp_r_adv_pix)\n",
    "\n",
    "    print(\"\\nGenerating Plot 1 (Standard Invariance and Adversarial)...\")\n",
    "    # (label, standard-weights stats, robust-weights stats); list order drives the visual grouping.\n",
    "    sns_experiments_plot1 = [\n",
    "        (\"Adv Pixel space\", nat_stats_adv_std, nat_stats_adv_rob),\n",
    "        (\"Inv Pixel space\", nat_stats_std_inv_pix, nat_stats_r_inv_pix),\n",
    "        (\"Inv Layer3_conv1\", nat_stats_std_inv_25, nat_stats_r_inv_25),\n",
    "        (\"Inv Layer4_conv7\", nat_stats_std_inv_51, nat_stats_r_inv_51),\n",
    "    ]\n",
    "    _boxplot_dir = os.path.join(BASE_OUTPUT_DIR, \"activation_boxplot_outputs\")\n",
    "    pau.plot_activation_percentiles_boxplot(\n",
    "        general_stats_std=nat_stats_std_inv_pix,  # used for the 'Same cat.' / 'Other cats.' entries\n",
    "        general_stats_rob=nat_stats_r_inv_pix,\n",
    "        sns_experiment_data=sns_experiments_plot1,\n",
    "        save_dir=_boxplot_dir,\n",
    "        filename=\"activation_percentiles_plot1_inv_adv.png\",\n",
    "    )\n",
    "else:\n",
    "    print(\"SnS_metadata not loaded. Skipping Activation Percentiles Boxplot.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-experiment summary of the readout-layer experiments loaded above.\n",
    "# Convergence rate = 1 - fraction of runs whose 'nat_stat_early_stopping_p1_n_it'\n",
    "# is NaN (assumed: NaN means early stopping never triggered; TODO confirm). This is\n",
    "# the same definition used for the mid/high-level targets further below.\n",
    "percent_convergence_niter = {}\n",
    "median_nat_perc = {}\n",
    "if experiments_metadata:\n",
    "    readout_exps = [exp_std_inv_pix, exp_std_inv_25, exp_std_inv_51, exp_std_adv_pix,\n",
    "                    exp_r_inv_pix, exp_r_inv_25, exp_r_inv_51, exp_r_adv_pix]\n",
    "    for exp in readout_exps:\n",
    "        # Each exp comes from a single query: use its first key.\n",
    "        exp_name = next(iter(exp))\n",
    "        exp_df = exp[exp_name]['df']\n",
    "        n_it = exp_df['nat_stat_early_stopping_p1_n_it']\n",
    "        percent_convergence_niter[exp_name] = 1 - n_it.isna().mean()\n",
    "        median_nat_perc[exp_name] = exp_df['nat_percentiles'].median()\n",
    "else:\n",
    "    print(\"SnS_metadata not loaded. Skipping convergence summary.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median of the 'nat_percentiles' column for each readout-layer experiment\n",
    "median_nat_perc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convergence rate for target units in readout layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-experiment values of percent_convergence_niter, derived from the NaN count of\n",
    "# 'nat_stat_early_stopping_p1_n_it' in each readout-layer experiment\n",
    "percent_convergence_niter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convergence rate for target units in mid and high levels of the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same pixel-space invariance experiments, but with the target in a mid-level\n",
    "# ('26_conv_25') or high-level ('52_conv_51') layer instead of the readout layer.\n",
    "percent_convergence_niter_diffTRG = {}\n",
    "median_nat_perc_diffTRG = {}\n",
    "if experiments_metadata:\n",
    "    exp_std_mid_trg25 = experiments_metadata.get_experiments(queries=[['resnet50', '26_conv_25','vanilla', '00_input_01', '500', 'invariance 2.0']])\n",
    "    exp_std_high_trg51 = experiments_metadata.get_experiments(queries=[['resnet50', '52_conv_51','vanilla', '00_input_01', '500', 'invariance 2.0']])\n",
    "    exp_r_mid_trg25 = experiments_metadata.get_experiments(queries=[['resnet50', '26_conv_25','robust_l2', '00_input_01', '500', 'invariance 2.0']])\n",
    "    exp_r_high_trg51 = experiments_metadata.get_experiments(queries=[['resnet50', '52_conv_51','robust_l2', '00_input_01', '500', 'invariance 2.0']])\n",
    "\n",
    "    for exp in [exp_std_mid_trg25, exp_std_high_trg51, exp_r_mid_trg25, exp_r_high_trg51]:\n",
    "        # Each exp comes from a single query: use its first key.\n",
    "        exp_name = next(iter(exp))\n",
    "        exp_df = exp[exp_name]['df']\n",
    "        n_it = exp_df['nat_stat_early_stopping_p1_n_it']\n",
    "        # Convergence rate: 1 - fraction of runs with a NaN early-stopping iteration.\n",
    "        percent_convergence_niter_diffTRG[exp_name] = 1 - n_it.isna().mean()\n",
    "        # Mirrors median_nat_perc of the readout-layer experiments; skipped if the column is absent.\n",
    "        if 'nat_percentiles' in exp_df.columns:\n",
    "            median_nat_perc_diffTRG[exp_name] = exp_df['nat_percentiles'].median()\n",
    "\n",
    "    nat_stats_std_mid_trg25 = pau.calculate_activation_percentiles(exps=exp_std_mid_trg25)\n",
    "    nat_stats_r_mid_trg25 = pau.calculate_activation_percentiles(exps=exp_r_mid_trg25)\n",
    "    nat_stats_std_high_trg51 = pau.calculate_activation_percentiles(exps=exp_std_high_trg51)\n",
    "    nat_stats_r_high_trg51 = pau.calculate_activation_percentiles(exps=exp_r_high_trg51)\n",
    "\n",
    "    # (label, standard-weights stats, robust-weights stats)\n",
    "    sns_experiments_plot2 = [\n",
    "        (\"Mid Level Targets \", nat_stats_std_mid_trg25, nat_stats_r_mid_trg25),\n",
    "        (\"High Level Targets \", nat_stats_std_high_trg51, nat_stats_r_high_trg51),\n",
    "    ]\n",
    "    pau.plot_activation_percentiles_boxplot(\n",
    "        general_stats_std=nat_stats_std_mid_trg25,\n",
    "        general_stats_rob=nat_stats_r_mid_trg25,\n",
    "        sns_experiment_data=sns_experiments_plot2,\n",
    "        save_dir=os.path.join(BASE_OUTPUT_DIR, \"activation_boxplot_outputs\"),\n",
    "        filename=\"activation_percentiles_plot2_mid_high.png\"\n",
    "    )\n",
    "else:\n",
    "    print(\"Data for Plot 2 (Mid/High Target) not available. Skipping.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convergence rate (1 - fraction of NaN 'nat_stat_early_stopping_p1_n_it') per\n",
    "# mid/high-level-target experiment\n",
    "percent_convergence_niter_diffTRG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. SVC Analysis "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _svc_pipeline(exps, variant, npcs_values):\n",
    "    \"\"\"Run the SVC pipeline for one group of experiments.\n",
    "\n",
    "    Representations are extracted by pau.process_experiment_data_for_svc (using\n",
    "    generator_fc7 and repr_net_svc_resnet50_vanilla at layer rec_ly_svc), unified,\n",
    "    then scored by pau.calculate_classwise_scores_svc for each entry of npcs_values.\n",
    "\n",
    "    Returns (reprs, unified, scores); `scores` is [] when the unified\n",
    "    representation for rec_ly_svc is missing or empty.\n",
    "    \"\"\"\n",
    "    print(f\"\\nProcessing '{variant}' experiments for SVC...\")\n",
    "    reprs = pau.process_experiment_data_for_svc(\n",
    "        exps, p1='end_p1_idxs', rec_ly=rec_ly_svc,\n",
    "        generator=generator_fc7, repr_net=repr_net_svc_resnet50_vanilla, DEVICE=DEVICE\n",
    "    )\n",
    "    unified = pau.unify_representations(reprs)\n",
    "    if unified and rec_ly_svc in unified[0] and unified[0][rec_ly_svc].size > 0:\n",
    "        scores = pau.calculate_classwise_scores_svc(\n",
    "            unified[0][rec_ly_svc],\n",
    "            np.array(unified[1]),\n",
    "            npcs_values\n",
    "        )\n",
    "        print(f\"Completed '{variant}' SVC processing.\")\n",
    "    else:\n",
    "        scores = []\n",
    "        print(f\"Warning: No data for '{variant}' SVC after processing. Skipping score calculation.\")\n",
    "    return reprs, unified, scores\n",
    "\n",
    "\n",
    "def _has_results(scores):\n",
    "    \"\"\"True if `scores` is non-empty; safe for lists and numpy arrays alike\n",
    "    (plain truthiness raises ValueError on multi-element arrays).\"\"\"\n",
    "    return scores is not None and len(scores) > 0\n",
    "\n",
    "\n",
    "# The model objects are None when their initialization in Section 1 failed.\n",
    "if experiments_metadata and generator_fc7 is not None and repr_net_svc_resnet50_vanilla is not None:\n",
    "    print(\"Starting SVC Analysis...\")\n",
    "    # Same three invariance experiments for standard ('vanilla') and robust ('robust_l2') weights.\n",
    "    queries_vanilla_svc = [\n",
    "        ['resnet50', '56_linear_01','vanilla', '00_input_01', '500', 'invariance 2.0'],\n",
    "        ['resnet50', '56_linear_01','vanilla', '26_conv_25', '500', 'invariance 2.0'],\n",
    "        ['resnet50', '56_linear_01','vanilla', '52_conv_51', '500', 'invariance 2.0']\n",
    "    ]\n",
    "    exp_vanilla_svc = experiments_metadata.get_experiments(queries=queries_vanilla_svc)\n",
    "\n",
    "    queries_robust_l2_svc = [\n",
    "        ['resnet50', '56_linear_01','robust_l2', '00_input_01', '500', 'invariance 2.0'],\n",
    "        ['resnet50', '56_linear_01','robust_l2', '26_conv_25', '500', 'invariance 2.0'],\n",
    "        ['resnet50', '56_linear_01','robust_l2', '52_conv_51', '500', 'invariance 2.0']\n",
    "    ]\n",
    "    exp_robust_l2_svc = experiments_metadata.get_experiments(queries=queries_robust_l2_svc)\n",
    "\n",
    "    npcs_values_svc = [2, 10, 25, 50, 100]\n",
    "    # Both groups are represented with the same standard ResNet50 reference network.\n",
    "    reprs_vanilla, final_result_vanilla, van_results_svc = _svc_pipeline(exp_vanilla_svc, 'vanilla', npcs_values_svc)\n",
    "    reprs_robust, final_result_robust, rob_results_svc = _svc_pipeline(exp_robust_l2_svc, 'robust_l2', npcs_values_svc)\n",
    "\n",
    "    if _has_results(van_results_svc) or _has_results(rob_results_svc):\n",
    "        print(\"\\nPlotting SVC results...\")\n",
    "        display_names_map_svc = {\n",
    "            '00_input_01': 'Pixel Space',\n",
    "            '26_conv_25':  'Layer3_conv1',\n",
    "            '52_conv_51':  'Layer4_conv7'\n",
    "        }\n",
    "        # Legend labels come from the experiment keys; the robust keys are used when\n",
    "        # available (assumes both groups format their keys the same way).\n",
    "        exp_keys_for_plot = list(exp_robust_l2_svc.keys()) if exp_robust_l2_svc else list(exp_vanilla_svc.keys())\n",
    "\n",
    "        pau.plot_svc_accuracy(\n",
    "            van_results_svc, rob_results_svc,\n",
    "            exp_keys_source=exp_keys_for_plot,\n",
    "            display_names_map=display_names_map_svc,\n",
    "            save_dir=os.path.join(BASE_OUTPUT_DIR, \"svc_analysis_outputs\")\n",
    "        )\n",
    "    else:\n",
    "        print(\"No results to plot for SVC analysis.\")\n",
    "else:\n",
    "    print(\"SnS_metadata, generator, or SVC reference network not loaded. Skipping SVC Analysis.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Distance Analysis "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run_full_distance_analysis receives the SnS metadata *path* and loads\n",
    "# SnS_metadata itself, independently of experiments_metadata from Section 1.\n",
    "if prms_dist and SNS_METADATA_PATH:\n",
    "    print(\"Starting Distance Analysis...\")\n",
    "    da_results = pau.run_full_distance_analysis(\n",
    "        prms_dist,\n",
    "        SNS_METADATA_PATH,\n",
    "        DEVICE,\n",
    "        BASE_OUTPUT_DIR\n",
    "    )\n",
    "    print(\"\\nDistance Analysis complete.\")\n",
    "    print(f\"Results saved in: {da_results['analysis_dir']}\")\n",
    "else:\n",
    "    print(\"Distance parameters not loaded. Skipping Distance Analysis.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Multi-Network Accuracy Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import patches\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "# Root containing the Obs_<observer>_<rob|std>_Gen_<generator> folders. Reuses\n",
    "# MULTINET_ICLR_DIR from Section 1 so the path is configured in one place.\n",
    "ICLR_ROOT = MULTINET_ICLR_DIR\n",
    "# Figure output directory (set to None to skip saving); kept under BASE_OUTPUT_DIR\n",
    "# like the other analyses rather than a placeholder absolute path.\n",
    "SAVE_DIR = os.path.join(BASE_OUTPUT_DIR, \"multinet_accuracy_outputs\")\n",
    "\n",
    "# Figure styling\n",
    "FSZ_TITLE   = 29\n",
    "FSZ_TICKS   = 26\n",
    "FSZ_LEGEND  = 21\n",
    "LINE_WIDTH  = 3\n",
    "MARKER_SIZE = 11\n",
    "\n",
    "# Plotting order (mimics plot_multinet_iclr_averages). The Spacing* entries are\n",
    "# never filled, so they stay NaN and leave gaps between the groups.\n",
    "categories = [\n",
    "    \"Natural\", \"Spacing1\", \"Robust MEI\", \"Robust Pixel space\",\n",
    "    \"Robust Layer3_conv1\", \"Robust Layer4_conv7\",\n",
    "    \"Spacing2\", \"Standard MEI\", \"Standard Pixel space\",\n",
    "    \"Standard Layer3_conv1\", \"Standard Layer4_conv7\",\n",
    "]\n",
    "\n",
    "# Tick label per category: the Robust/Standard prefix is dropped because each\n",
    "# block is labelled separately on the figure.\n",
    "cat2lbl = {\n",
    "    \"Robust MEI\": \"MEI\",\n",
    "    \"Robust Pixel space\":  \"Pixel space\",\n",
    "    \"Robust Layer3_conv1\": \"Layer3_conv1\",\n",
    "    \"Robust Layer4_conv7\": \"Layer4_conv7\",\n",
    "    \"Standard MEI\": \"MEI\",\n",
    "    \"Standard Pixel space\":  \"Pixel space\",\n",
    "    \"Standard Layer3_conv1\": \"Layer3_conv1\",\n",
    "    \"Standard Layer4_conv7\": \"Layer4_conv7\",\n",
    "}\n",
    "\n",
    "def parse_folder(entry_name):\n",
    "    \"\"\"Split an 'Obs_<obs>_<rob|std>_Gen_<gen>' folder name (case-insensitive).\n",
    "\n",
    "    Returns (obs, obs_type, gen) with obs_type lower-cased to 'rob' or 'std',\n",
    "    or None when the name does not follow that pattern.\n",
    "    \"\"\"\n",
    "    match = re.match(r\"^Obs_(?P<obs>.+?)_(?P<obs_type>rob|std)_Gen_(?P<gen>.+)$\", entry_name, flags=re.IGNORECASE)\n",
    "    if match is None:\n",
    "        return None\n",
    "    fields = match.groupdict()\n",
    "    return fields[\"obs\"], fields[\"obs_type\"].lower(), fields[\"gen\"]\n",
    "\n",
    "def build_category_series(acc_dict):\n",
    "    \"\"\"Average per-case accuracies into one value per entry of `categories`.\n",
    "\n",
    "    Case names are mapped as follows:\n",
    "      'nat_refs'                      -> 'Natural'\n",
    "      'mXDREAM - ...'                 -> 'Robust MEI' if 'l2robust' in it, else 'Standard MEI'\n",
    "      '<..>#<..>#<weights>#<loc>...'  -> '<Robust|Standard> <Pixel space|Layer3_conv1|Layer4_conv7>'\n",
    "    where a <loc> starting with '00_input' is pixel space and the other locations,\n",
    "    in sorted order, become Layer3_conv1 then Layer4_conv7. Each case value is\n",
    "    presumably a {neuron_id: accuracy} dict (TODO confirm); only ids shared by all\n",
    "    cases are averaged. Categories without a matching case (e.g. Spacing*) are NaN.\n",
    "    \"\"\"\n",
    "    non_layer_cases = (\"nat_refs\", \"mXDREAM - vanilla\", \"mXDREAM - l2robust\")\n",
    "\n",
    "    # Neuron ids present in every case, so all categories average the same units.\n",
    "    key_sets = [set(v) for v in acc_dict.values() if isinstance(v, dict)]\n",
    "    shared_keys = set.intersection(*key_sets) if key_sets else set()\n",
    "\n",
    "    # Non-input locations in sorted order; the first two become the layer categories.\n",
    "    split_names = (name.split(\"#\") for name in acc_dict if name not in non_layer_cases)\n",
    "    internal_locs = sorted({\n",
    "        fields[3] for fields in split_names\n",
    "        if len(fields) >= 4 and not str(fields[3]).startswith(\"00_input\")\n",
    "    })\n",
    "    loc_map = dict(zip(internal_locs, (\"Layer3_conv1\", \"Layer4_conv7\")))\n",
    "\n",
    "    def mean_accuracy(per_unit):\n",
    "        if not isinstance(per_unit, dict) or not per_unit:\n",
    "            return np.nan\n",
    "        wanted = shared_keys or per_unit.keys()\n",
    "        picked = [per_unit[k] for k in wanted if k in per_unit]\n",
    "        return float(np.mean(picked)) if picked else np.nan\n",
    "\n",
    "    def category_of(case_name):\n",
    "        # Category label for a case, or None to ignore it.\n",
    "        if case_name == \"nat_refs\":\n",
    "            return \"Natural\"\n",
    "        if case_name.startswith(\"mXDREAM - \"):\n",
    "            return (\"Robust\" if \"l2robust\" in case_name.lower() else \"Standard\") + \" MEI\"\n",
    "        fields = case_name.split(\"#\")\n",
    "        if len(fields) < 4:\n",
    "            return None\n",
    "        tag = \"Robust\" if \"robust\" in fields[2].lower() else \"Standard\"\n",
    "        if str(fields[3]).startswith(\"00_input\"):\n",
    "            return f\"{tag} Pixel space\"\n",
    "        layer = loc_map.get(fields[3])\n",
    "        return f\"{tag} {layer}\" if layer else None\n",
    "\n",
    "    averaged = dict.fromkeys(categories, np.nan)\n",
    "    for case_name, per_unit in acc_dict.items():\n",
    "        label = category_of(case_name)\n",
    "        if label is not None:\n",
    "            averaged[label] = mean_accuracy(per_unit)\n",
    "\n",
    "    return pd.Series([averaged[c] for c in categories], index=categories)\n",
    "\n",
    "# Collect one category Series per observer net, for generator ResNet50 only.\n",
    "by_obs_type = {\"rob\": {}, \"std\": {}}\n",
    "\n",
    "if not os.path.isdir(ICLR_ROOT):\n",
    "    raise FileNotFoundError(f\"Provided iclr_root_dir does not exist: {ICLR_ROOT}\")\n",
    "\n",
    "for entry in sorted(os.listdir(ICLR_ROOT)):\n",
    "    subdir = os.path.join(ICLR_ROOT, entry)\n",
    "    if not os.path.isdir(subdir):\n",
    "        continue\n",
    "    parsed = parse_folder(entry)\n",
    "    if not parsed:\n",
    "        continue\n",
    "    obs_name, obs_type, gen_name = parsed\n",
    "    # Only *_Gen_Resnet50 folders, compared case-insensitively to match\n",
    "    # parse_folder's case-insensitive pattern.\n",
    "    if gen_name.lower() != \"resnet50\":\n",
    "        continue\n",
    "\n",
    "    acc_fp = os.path.join(subdir, \"accuracy.json\")\n",
    "    if not os.path.exists(acc_fp):\n",
    "        continue\n",
    "    try:\n",
    "        # Explicit encoding: JSON is UTF-8 regardless of the platform locale.\n",
    "        with open(acc_fp, \"r\", encoding=\"utf-8\") as f:\n",
    "            acc_dict = json.load(f)\n",
    "    except Exception as e:\n",
    "        # Best effort: an unreadable file skips that observer instead of aborting the plot.\n",
    "        print(f\"Could not read {acc_fp}: {e}\")\n",
    "        continue\n",
    "\n",
    "    # The legend shows the observer net name parsed from the folder.\n",
    "    by_obs_type[obs_type][obs_name] = build_category_series(acc_dict)\n",
    "\n",
    "# Build DataFrames for plotting: rows follow `categories`, one column per observer.\n",
    "def to_df(group_dict):\n",
    "    \"\"\"Frame for one observer group; an empty frame indexed by categories if none.\"\"\"\n",
    "    if not group_dict:\n",
    "        return pd.DataFrame(index=categories)\n",
    "    return pd.DataFrame(group_dict).reindex(categories)\n",
    "\n",
    "df_rob, df_std = (to_df(by_obs_type[kind]) for kind in (\"rob\", \"std\"))\n",
    "\n",
    "# If no data, abort gracefully\n",
    "if df_rob.empty and df_std.empty:\n",
    "    print(\"No data found for *_Gen_Resnet50 under the provided root.\")\n",
    "else:\n",
    "    # Fixed colour per observer net (keys are lower-cased observer names) so the\n",
    "    # same observer has the same colour in both panels; unknown names are gray.\n",
    "    color_map = {\n",
    "        'densenet161':        np.array([0.190631, 0.407061, 0.556089, 1.]),\n",
    "        'resnet18':           np.array([0.565498, 0.84243,  0.262877, 1.]),\n",
    "        'resnet50':           np.array([0.993248, 0.906157, 0.143936, 1.]),\n",
    "        'resnext50_32x4d':    np.array([0.127568, 0.566949, 0.550556, 1.]),\n",
    "        'wide_resnet50_2':    np.array([0.267968, 0.84243, 0.512008, 1.]),\n",
    "        'shufflenet_v2_x1_0': np.array([0.267004, 0.004874, 0.329415, 1.]),\n",
    "        'vgg16_bn':           np.array([0.267968, 0.223549, 0.512008, 1.]),\n",
    "    }\n",
    "    colors_rob = color_map\n",
    "    colors_std = color_map\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(15, 10), sharey=True)\n",
    "    idx_map = {lbl: i for i, lbl in enumerate(categories)}\n",
    "    x_pos = list(range(len(categories)))\n",
    "\n",
    "    # Category indices used to draw the connecting segments of each block\n",
    "    def idx(lbl): return idx_map.get(lbl)\n",
    "    i_rm, i_rp, i_r3, i_r4 = idx(\"Robust MEI\"), idx(\"Robust Pixel space\"), idx(\"Robust Layer3_conv1\"), idx(\"Robust Layer4_conv7\")\n",
    "    i_sm, i_sp, i_s3, i_s4 = idx(\"Standard MEI\"), idx(\"Standard Pixel space\"), idx(\"Standard Layer3_conv1\"), idx(\"Standard Layer4_conv7\")\n",
    "\n",
    "    # Spacing* categories get no tick\n",
    "    plot_ticks = [i for i, lbl in enumerate(categories) if not lbl.startswith(\"Spacing\")]\n",
    "    plot_tick_labels = [lbl for lbl in categories if not lbl.startswith(\"Spacing\")]\n",
    "\n",
    "    def plot_group(ax, df_group, title, color_map, legend=True):\n",
    "        \"\"\"Draw one panel with one marker+line series per observer column of df_group.\n",
    "\n",
    "        Inside each block MEI -> Pixel space is dotted and Pixel space ->\n",
    "        Layer3_conv1 -> Layer4_conv7 is solid; segments touching NaN are skipped.\n",
    "        \"\"\"\n",
    "        if df_group.empty:\n",
    "            ax.set_title(title)\n",
    "            ax.text(0.5, 0.5, \"No data\", ha='center', va='center', transform=ax.transAxes)\n",
    "            return\n",
    "\n",
    "        legend_handles, legend_labels = [], []\n",
    "\n",
    "        def draw_series(y_vals, color, label):\n",
    "            # Markers at all available points\n",
    "            idx_non_nan = [i for i, v in enumerate(y_vals) if not np.isnan(v)]\n",
    "            if idx_non_nan:\n",
    "                ax.plot([x_pos[i] for i in idx_non_nan], [y_vals[i] for i in idx_non_nan],\n",
    "                        linestyle='None', marker='o', markersize=MARKER_SIZE, color=color, alpha=1.0)\n",
    "\n",
    "            def seg(i, j, ls, lw):\n",
    "                if i is None or j is None:\n",
    "                    return\n",
    "                if np.isnan(y_vals[i]) or np.isnan(y_vals[j]):\n",
    "                    return\n",
    "                ax.plot([x_pos[i], x_pos[j]], [y_vals[i], y_vals[j]],\n",
    "                        linestyle=ls, color=color, linewidth=lw, alpha=1.0)\n",
    "\n",
    "            # Robust block: dotted between MEI->Pixel, solid afterwards\n",
    "            seg(i_rm, i_rp, ':', 2.0)\n",
    "            seg(i_rp, i_r3, '-', 2.0)\n",
    "            seg(i_r3, i_r4, '-', 2.0)\n",
    "\n",
    "            # Standard block: dotted between MEI->Pixel, solid afterwards\n",
    "            seg(i_sm, i_sp, ':', 2.0)\n",
    "            seg(i_sp, i_s3, '-', 2.0)\n",
    "            seg(i_s3, i_s4, '-', 2.0)\n",
    "\n",
    "            # Legend entry (line sample)\n",
    "            legend_handles.append(Line2D([0], [0], color=color, linewidth=LINE_WIDTH, linestyle='-'))\n",
    "            legend_labels.append(label)\n",
    "\n",
    "        for col in df_group.columns:\n",
    "            y_vals = df_group.loc[categories, col].values\n",
    "            draw_series(y_vals=y_vals, color=color_map.get(col.lower(), 'gray'), label=col)\n",
    "\n",
    "        # Axes styling: x spine at y=0, thick spines/ticks, top/right spines hidden\n",
    "        ax.spines['bottom'].set_position(('data', 0))\n",
    "        ax.set_xticks(plot_ticks)\n",
    "        ax.set_xticklabels([cat2lbl.get(pt, pt) for pt in plot_tick_labels], rotation=45, ha='right', fontsize=FSZ_TICKS)\n",
    "        ax.set_title(title, fontsize=FSZ_TITLE, fontweight='bold')\n",
    "        ax.tick_params(axis='both', which='major', length=10, width=LINE_WIDTH)\n",
    "        ax.spines['left'].set_linewidth(LINE_WIDTH)\n",
    "        ax.spines['bottom'].set_linewidth(LINE_WIDTH)\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        ax.spines['bottom'].set_bounds(ax.get_xticks()[0], ax.get_xticks()[-1])\n",
    "        # Thick hatch stripes. NOTE: global rcParams change that also affects later figures.\n",
    "        plt.rcParams['hatch.linewidth'] = 10\n",
    "\n",
    "        c_std = (1, 0, 0, 0.1)       # translucent red\n",
    "        c_rob = (0, 0.8, 0.6, 0.1)   # translucent teal\n",
    "        add_alpha = (0, 0, 0, 0.4)   # added to the band colour to make the group label more opaque\n",
    "        # Hatched background bands over the four categories of each block (x = category index).\n",
    "        rect1 = patches.Rectangle(\n",
    "            (1.5, 0),            # bottom-left corner: just left of 'Robust MEI' (index 2)\n",
    "            4,                   # width: indices 2-5\n",
    "            1,                   # height: accuracy range [0, 1]\n",
    "            facecolor=c_rob,\n",
    "            hatch='/',           # 45° stripes\n",
    "            edgecolor='white',   # hatch color\n",
    "        )\n",
    "        rect2 = patches.Rectangle(\n",
    "            (6.5, 0),            # bottom-left corner: just left of 'Standard MEI' (index 7)\n",
    "            4,                   # width: indices 7-10; a wider band would stretch the x-axis past the last category\n",
    "            1,                   # height: accuracy range [0, 1]\n",
    "            facecolor=c_std,\n",
    "            hatch='/',           # 45° stripes\n",
    "            edgecolor='white',   # hatch color\n",
    "        )\n",
    "        ax.add_patch(rect1)\n",
    "        ax.add_patch(rect2)\n",
    "        # Group labels centred on each band\n",
    "        ax.text(x=8.5, y=0.01, s=\"Standard\", ha='center', va='bottom', fontsize=FSZ_TITLE, color=tuple(np.array(c_std) + np.array(add_alpha)), fontweight='bold')\n",
    "        ax.text(x=3.5, y=0.01, s=\"Robust\", ha='center', va='bottom', fontsize=FSZ_TITLE, color=tuple(np.array(c_rob) + np.array(add_alpha)), fontweight='bold')\n",
    "\n",
    "        # Chance level at 1/12 (presumably 12 response alternatives; TODO confirm)\n",
    "        ax.axhline(y=1/12, color='k', linestyle='--', linewidth=LINE_WIDTH)\n",
    "        ax.text(x=9.5, y=1/12+ 0.01, s=\"Chance\", ha='center', va='bottom', fontsize=FSZ_TICKS-2, color='black')\n",
    "\n",
    "        if legend_handles and legend:\n",
    "            ax.legend(\n",
    "                legend_handles,\n",
    "                legend_labels,\n",
    "                fontsize=FSZ_LEGEND,\n",
    "                loc=\"lower left\",\n",
    "                bbox_to_anchor=(0, 0.12),  # lower-left corner at 12% of the axes height (axes coords)\n",
    "                frameon=True,\n",
    "                ncol=1,\n",
    "                title=\"Observer nets\",\n",
    "                title_fontproperties={'weight':'bold', 'size': FSZ_LEGEND}\n",
    "            )\n",
    "\n",
    "    plot_group(axes[0], df_rob, \"Robust observers\", colors_rob, legend=True)\n",
    "    plot_group(axes[1], df_std, \"Standard observers\", colors_std, legend=False)\n",
    "\n",
    "    axes[0].set_ylabel(\"Accuracy\", fontsize=FSZ_TICKS)\n",
    "    axes[0].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])\n",
    "    for ax in axes:\n",
    "        ax.spines['left'].set_bounds(ax.get_yticks()[0], ax.get_yticks()[-1])\n",
    "        ax.tick_params(axis='y', labelsize=FSZ_TICKS)\n",
    "        ax.grid(False)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if SAVE_DIR:\n",
    "        os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "        out_fp = os.path.join(SAVE_DIR, \"multinet_iclr_observers_Resnet50_noavg.svg\")\n",
    "        fig.savefig(out_fp, format=Path(out_fp).suffix.lstrip('.'), bbox_inches='tight')\n",
    "        print(f\"Saved multi-net ICLR accuracy plot to {out_fp}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "def perform_correlation_analysis_with_human(human_data_path, df_robust, df_standard, color_map,\n",
    "                                            save_dir=None, filename=\"model_human_correlations.png\"):\n",
    "    \"\"\"\n",
    "    Correlate model accuracies (robust and standard observers) with human accuracy.\n",
    "\n",
    "    For each model column of `df_robust` / `df_standard`, a Pearson correlation with the\n",
    "    human `Avg_human` column is computed over the index labels shared with the human CSV.\n",
    "    Results are drawn as adjacent bars (robust vs. standard) per model architecture.\n",
    "\n",
    "    Args:\n",
    "        human_data_path: CSV read with `index_col=0`; its `participant_accuracy` column is\n",
    "            renamed to `Avg_human` and used as the human reference.\n",
    "        df_robust, df_standard: DataFrames with one column per observer model; the index must\n",
    "            match the cleaned human index.\n",
    "        color_map: mapping from lower-cased model name to a matplotlib color (any spec\n",
    "            accepted by `matplotlib.colors.to_rgba`); unknown models are drawn grey.\n",
    "        save_dir: if given, the figure is saved there as `filename`, with the format taken\n",
    "            from the file extension.\n",
    "        filename: output file name.\n",
    "\n",
    "    Returns:\n",
    "        ((corr_rob, p_val_rob), (corr_std, p_val_std)) as dicts keyed by model name, or\n",
    "        (None, None) if the human file is missing or no correlation could be computed.\n",
    "    \"\"\"\n",
    "    from matplotlib.colors import to_rgba\n",
    "\n",
    "    y_floor = 0.35  # lower y-limit of the plot; the x-axis spine is drawn at this height\n",
    "\n",
    "    try:\n",
    "        df_human = pd.read_csv(human_data_path, index_col=0)\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: Human data file not found at {human_data_path}. Skipping correlation analysis.\")\n",
    "        return None, None\n",
    "\n",
    "    # --- Data Cleaning ---\n",
    "    # Normalise condition names so they match the index of the model DataFrames.\n",
    "    df_human.index = df_human.index.str.replace(\"_group\", \"\").str.replace(\"_\", \" \").str.replace(\" conv\", \"_conv\").str.replace(\"Space\", \"space\")\n",
    "    df_human.index = df_human.index.str.replace('natural', 'Natural', case=False)\n",
    "    df_human = df_human.rename(columns={'participant_accuracy': 'Avg_human'})\n",
    "\n",
    "    def calculate_correlations(model_df):\n",
    "        \"\"\"Return ({model: r}, {model: p}) of each model column against `Avg_human`.\"\"\"\n",
    "        common_indices = model_df.index.intersection(df_human.index)\n",
    "        if common_indices.empty:\n",
    "            return {}, {}\n",
    "\n",
    "        # Only the human average is merged, so extra CSV columns are never treated as models.\n",
    "        merged_df = pd.concat([df_human.loc[common_indices, ['Avg_human']], model_df.loc[common_indices]], axis=1)\n",
    "        merged_df = merged_df.dropna(subset=['Avg_human'])\n",
    "        merged_df = merged_df.dropna(axis=1, how='all')\n",
    "\n",
    "        correlations, p_values = {}, {}\n",
    "        for col in merged_df.columns:\n",
    "            if col == \"Avg_human\":\n",
    "                continue\n",
    "            # Drop NaNs pairwise: dropping them per series would misalign the two inputs\n",
    "            # (and pearsonr rejects inputs of different lengths).\n",
    "            paired = merged_df[[\"Avg_human\", col]].dropna()\n",
    "            if len(paired) < 2:\n",
    "                continue\n",
    "            r_value, p_value = pearsonr(paired[\"Avg_human\"], paired[col])\n",
    "            if np.isnan(r_value) or np.isnan(p_value):  # e.g. constant input\n",
    "                continue\n",
    "            correlations[col] = r_value\n",
    "            p_values[col] = p_value\n",
    "        return correlations, p_values\n",
    "\n",
    "    # --- Calculate Correlations for Both DataFrames ---\n",
    "    corr_rob, p_val_rob = calculate_correlations(df_robust)\n",
    "    corr_std, p_val_std = calculate_correlations(df_standard)\n",
    "\n",
    "    if not corr_rob and not corr_std:\n",
    "        print(\"Warning: No valid correlations could be computed for either robust or standard models.\")\n",
    "        return None, None\n",
    "\n",
    "    # --- Prepare for Plotting ---\n",
    "    # Union of model names, sorted by the standard-observer correlation (highest first).\n",
    "    all_models = sorted(set(corr_rob) | set(corr_std), key=lambda model: corr_std.get(model, 0), reverse=True)\n",
    "    # A model missing from one group gets a zero-height bar in that group.\n",
    "    r_values_rob = [corr_rob.get(model, 0) for model in all_models]\n",
    "    r_values_std = [corr_std.get(model, 0) for model in all_models]\n",
    "\n",
    "    # --- Plotting ---\n",
    "    x = np.arange(len(all_models))\n",
    "    width = 0.35\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(15, 10))\n",
    "\n",
    "    # One hue per architecture: robust bars opaque, standard bars at alpha 0.5.\n",
    "    # to_rgba accepts RGB/RGBA tuples, hex strings and named colors alike.\n",
    "    base_colors = [color_map.get(model.lower(), (0.5, 0.5, 0.5, 1.0)) for model in all_models]\n",
    "    colors_rob = [to_rgba(color, alpha=1.0) for color in base_colors]\n",
    "    colors_std = [to_rgba(color, alpha=0.5) for color in base_colors]\n",
    "\n",
    "    rects1 = ax.bar(x - width/2, r_values_rob, width, label='Robust Observers', color=colors_rob)\n",
    "    rects2 = ax.bar(x + width/2, r_values_std, width, label='Standard Observers', color=colors_std)\n",
    "\n",
    "    # Label the first (left-most) bar of each group inside its visible part. Bars that do not\n",
    "    # reach above y_floor are not labelled, since the text would land below the x-axis.\n",
    "    for rects, bar_label in ((rects1, 'Robust'), (rects2, 'Standard')):\n",
    "        if len(rects) == 0:\n",
    "            continue\n",
    "        rect = rects[0]\n",
    "        height = rect.get_height()\n",
    "        if height > y_floor:\n",
    "            ax.text(rect.get_x() + rect.get_width() / 2.0,\n",
    "                    y_floor + (height - y_floor) / 2.0, bar_label,\n",
    "                    ha='center', va='center', color='white', rotation=90, fontsize=30)\n",
    "\n",
    "    fonts = 30\n",
    "    ax.set_ylabel(\"Pearson corr.coef with \\navg. human\", fontsize=fonts)\n",
    "\n",
    "    ax.set_ylim(y_floor, 1)\n",
    "    line_thickness = 2\n",
    "    ax.tick_params(axis='both', which='major', length=10, width=line_thickness)\n",
    "    ax.spines['left'].set_linewidth(line_thickness)\n",
    "    ax.spines['bottom'].set_linewidth(line_thickness)\n",
    "    # Remove the top and right spines and draw the x-axis at the bottom of the y-range.\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['bottom'].set_position(('data', y_floor))\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(all_models, rotation=45, ha='right', fontsize=fonts)\n",
    "\n",
    "    ax.grid(False)\n",
    "    ax.set_yticks([0.4, 0.6, 0.8, 1.0])\n",
    "    ax.tick_params(axis='y', labelsize=fonts)\n",
    "    ax.spines['left'].set_bounds(ax.get_yticks()[0], ax.get_yticks()[-1])\n",
    "    ax.spines['bottom'].set_bounds(ax.get_xticks()[0], ax.get_xticks()[-1])\n",
    "    fig.tight_layout()\n",
    "\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "        out_fp = os.path.join(save_dir, filename)\n",
    "        # Format comes from the extension; without one, matplotlib falls back to its default.\n",
    "        out_format = os.path.splitext(out_fp)[1].lstrip('.') or None\n",
    "        fig.savefig(out_fp, format=out_format, bbox_inches='tight')\n",
    "        print(f\"Saved model-human correlation plot to {out_fp}\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "    # --- Print Results (same order as the plot) ---\n",
    "    print(\"Correlations (r-value, p-value):\")\n",
    "    for group_name, corr, p_vals in ((\"Robust\", corr_rob, p_val_rob), (\"Standard\", corr_std, p_val_std)):\n",
    "        print(f\"\\n--- {group_name} Observers ---\")\n",
    "        for model in all_models:\n",
    "            if model in corr:\n",
    "                print(f\"  {model}: r = {corr[model]:.3f}, p = {p_vals[model]:.3g}\")\n",
    "\n",
    "    return (corr_rob, p_val_rob), (corr_std, p_val_std)\n",
    "\n",
    "\n",
    "# --- Function Call ---\n",
    "# Group-level human means for this analysis. Kept under its own name so the\n",
    "# HUMAN_DATA_PATH / BASE_OUTPUT_DIR settings from the setup cell are not overwritten.\n",
    "HUMAN_GROUP_MEANS_PATH = '/path/to/means_by_group.csv'\n",
    "\n",
    "correlation_results = perform_correlation_analysis_with_human(\n",
    "    human_data_path=HUMAN_GROUP_MEANS_PATH,\n",
    "    df_robust=df_rob,\n",
    "    df_standard=df_std,\n",
    "    color_map=color_map,\n",
    "    save_dir=os.path.join(BASE_OUTPUT_DIR, \"multinet_iclr_analysis_outputs\"),\n",
    "    filename=\"multinet_iclr_human_correlation_sorted_by_standard.svg\"\n",
    ")\n",
    "# The function returns (None, None) when the human file is missing or no correlation\n",
    "# could be computed; unpacking that directly would raise a TypeError.\n",
    "if correlation_results[0] is None:\n",
    "    print(\"Model-human correlation analysis produced no results.\")\n",
    "else:\n",
    "    (corr_rob, p_val_rob), (corr_std, p_val_std) = correlation_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relies on `pau`, MULTINET_ICLR_DIR and BASE_OUTPUT_DIR from the setup cell;\n",
    "# they are not re-imported or redefined here.\n",
    "print(\"Starting Multi-Network ICLR Accuracy Analysis...\")\n",
    "pau.plot_multinet_iclr_averages(MULTINET_ICLR_DIR, save_dir=BASE_OUTPUT_DIR, filename=\"multinet_iclr_averages.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use VIT-specific names so this cell does not overwrite the `color_map` (and the generic\n",
    "# name `df`) that the MultiNet ICLR correlation cell above uses.\n",
    "df_vit, df_vit_orig, color_map_vit, ordered_labels_vit = pau.load_and_process_vit_multinet_accuracies(VIT_ACC_DIR)\n",
    "if df_vit.empty:\n",
    "    print(\"No VIT compare data found.\")\n",
    "else:\n",
    "    pau.plot_vit_multinet_accuracy_trends(\n",
    "        df_vit, ordered_labels_vit, color_map_vit,\n",
    "        save_dir=os.path.join(BASE_OUTPUT_DIR, \"multinet_vit_outputs\"),\n",
    "        filename=\"multinet_accuracy_trends_VIT.svg\"\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zdream",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
