{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c35fe35d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "import scipy.stats as stats\n",
    "import random\n",
    "import re\n",
    "import utilities as ut\n",
    "import modularised_utils as mut\n",
    "import networkx as nx\n",
    "\n",
    "from matplotlib.animation import FuncAnimation\n",
    "from IPython.display import HTML\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "seed = 0\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dbee5ff9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data loaded for 'slc'.\n"
     ]
    }
   ],
   "source": [
    "# Experiment to analyse and evaluation regime ('gaussian' or 'empirical').\n",
    "experiment = 'slc'\n",
    "setting    = 'gaussian'\n",
    "\n",
    "# Directory holding the cross-validation results for the chosen regime.\n",
    "if setting == 'gaussian':\n",
    "    path = f\"data/{experiment}/results\"\n",
    "elif setting == 'empirical':\n",
    "    path = f\"data/{experiment}/results_empirical\"\n",
    "else:\n",
    "    # Stop here instead of hitting a NameError on `path` in a later cell.\n",
    "    raise ValueError(f\"Unknown setting: {setting}\")\n",
    "\n",
    "saved_folds = joblib.load(f\"data/{experiment}/cv_folds.pkl\")\n",
    "\n",
    "# Load the original data dictionary\n",
    "all_data      = ut.load_all_data(experiment)\n",
    "\n",
    "# Low-level (LL) and high-level (HL) model dictionaries and the pieces of\n",
    "# them that the evaluation cells read directly.\n",
    "LLmodel       = all_data['LLmodel']\n",
    "HLmodel       = all_data['HLmodel']\n",
    "Dll_samples   = LLmodel['data']\n",
    "Dhl_samples   = HLmodel['data']\n",
    "ll_graph      = LLmodel['graph']\n",
    "hl_graph      = HLmodel['graph']\n",
    "ll_interventions = LLmodel['intervention_set']\n",
    "hl_interventions = HLmodel['intervention_set']\n",
    "# Same object as ll_interventions; both names are used by later cells.\n",
    "I_ll_relevant = ll_interventions\n",
    "# omega is indexed by LL intervention and yields the matching HL intervention.\n",
    "omega         = all_data['abstraction_data']['omega']\n",
    "ll_var_names  = list(ll_graph.nodes())\n",
    "hl_var_names  = list(hl_graph.nodes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34efba9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dictionaries containing the results for each optimization method\n",
    "if setting == 'gaussian':\n",
    "    diroca_results = joblib.load(f\"{path}/diroca_cv_results.pkl\")\n",
    "    gradca_results = joblib.load(f\"{path}/gradca_cv_results.pkl\")\n",
    "    baryca_results = joblib.load(f\"{path}/baryca_cv_results.pkl\")\n",
    "\n",
    "elif setting == 'empirical':\n",
    "    diroca_results = joblib.load(f\"{path}/diroca_cv_results_empirical.pkl\")\n",
    "    gradca_results = joblib.load(f\"{path}/gradca_cv_results_empirical.pkl\")\n",
    "    baryca_results = joblib.load(f\"{path}/baryca_cv_results_empirical.pkl\")\n",
    "    abslingam_results = joblib.load(f\"{path}/abslingam_cv_results_empirical.pkl\")\n",
    "\n",
    "# Method label -> per-fold results. For the unpacked methods below the value\n",
    "# has the form {fold_key: {run_key: run_data}}; GradCA/BARYCA are stored as loaded.\n",
    "results_to_evaluate = {}\n",
    "\n",
    "if setting == 'empirical':\n",
    "    # Abs-LiNGAM: one entry per style key found in the first fold.\n",
    "    if abslingam_results:\n",
    "        first_fold_key = list(abslingam_results.keys())[0]\n",
    "        for style in abslingam_results[first_fold_key].keys():\n",
    "            method_name = f\"Abs-LiNGAM ({style})\"\n",
    "            new_abslingam_dict = {}\n",
    "            for fold_key, fold_results in abslingam_results.items():\n",
    "                if style in fold_results:\n",
    "                    new_abslingam_dict[fold_key] = {style: fold_results[style]}\n",
    "            results_to_evaluate[method_name] = new_abslingam_dict\n",
    "    \n",
    "    def create_diroca_label(run_id):\n",
    "        \"\"\"Parses a run_id and creates a simplified label if epsilon and delta are equal.\"\"\"\n",
    "        # Use regular expression to find numbers for epsilon and delta\n",
    "        matches = re.findall(r'(\\d+\\.?\\d*)', run_id)\n",
    "        if len(matches) == 2:\n",
    "            eps, delta = matches\n",
    "            # The comparison is textual: the two number strings must be identical.\n",
    "            if eps == delta:\n",
    "                # Handle integer conversion for clean labels like '1' instead of '1.0'\n",
    "                val = int(float(eps)) if float(eps).is_integer() else float(eps)\n",
    "                return f\"DIROCA (eps_delta_{val})\"\n",
    "        return f\"DIROCA ({run_id})\"\n",
    "\n",
    "    # Unpack each DIROCA hyperparameter run with the new, clean label\n",
    "    if diroca_results:\n",
    "        first_fold_key = list(diroca_results.keys())[0]\n",
    "        for run_id in diroca_results[first_fold_key].keys():\n",
    "            method_name = create_diroca_label(run_id) # Use the new helper to create the name\n",
    "            new_diroca_dict = {}\n",
    "            for fold_key, fold_results in diroca_results.items():\n",
    "                if run_id in fold_results:\n",
    "                    new_diroca_dict[fold_key] = {run_id: fold_results[run_id]}\n",
    "            results_to_evaluate[method_name] = new_diroca_dict\n",
    "\n",
    "    results_to_evaluate['GradCA'] = gradca_results\n",
    "    results_to_evaluate['BARYCA'] = baryca_results\n",
    "\n",
    "elif setting == 'gaussian':\n",
    "    results_to_evaluate['GradCA'] = gradca_results\n",
    "    results_to_evaluate['BARYCA'] = baryca_results\n",
    "\n",
    "    if diroca_results:\n",
    "        first_fold_key = list(diroca_results.keys())[0]\n",
    "        diroca_run_ids = list(diroca_results[first_fold_key].keys())\n",
    "\n",
    "        # create a separate entry for each DIROCA run\n",
    "        # (the run_id is embedded verbatim here; no label compression in this branch)\n",
    "        for run_id in diroca_run_ids:\n",
    "            method_name = f\"DIROCA ({run_id})\"\n",
    "            \n",
    "            new_diroca_dict = {}\n",
    "            for fold_key, fold_results in diroca_results.items():\n",
    "                # For each fold grab the data for the current run_id\n",
    "                if run_id in fold_results:\n",
    "                    new_diroca_dict[fold_key] = {run_id: fold_results[run_id]}\n",
    "            \n",
    "            results_to_evaluate[method_name] = new_diroca_dict\n",
    "\n",
    "# Raw method label -> canonical internal key for the gaussian setting.\n",
    "label_map_gaussian = {\n",
    "                        'DIROCA (eps_delta_0.111)': 'DiRoCA_star',\n",
    "                        'DIROCA (eps_delta_1)': 'DIROCA_1',\n",
    "                        'DIROCA (eps_delta_2)': 'DIROCA_2',\n",
    "                        'DIROCA (eps_delta_4)': 'DIROCA_4',\n",
    "                        'DIROCA (eps_delta_8)': 'DIROCA_8',\n",
    "                        'GradCA': 'GradCA',\n",
    "                        'BARYCA': 'BARYCA'\n",
    "                    }\n",
    "\n",
    "# Raw method label -> canonical internal key for the empirical setting.\n",
    "label_map_empirical = {\n",
    "                        'DIROCA (eps_0.328_delta_0.107)': 'DiRoCA_star',\n",
    "                        'DIROCA (eps_delta_1)': 'DIROCA_1',\n",
    "                        'DIROCA (eps_delta_2)': 'DIROCA_2',\n",
    "                        'DIROCA (eps_delta_4)': 'DIROCA_4',\n",
    "                        'DIROCA (eps_delta_8)': 'DIROCA_8',\n",
    "                        'GradCA': 'GradCA',\n",
    "                        'BARYCA': 'BARYCA',\n",
    "                        'Abs-LiNGAM (Perfect)': 'Abslin_p',\n",
    "                        'Abs-LiNGAM (Noisy)': 'Abslin_n'\n",
    "                    }\n",
    "\n",
    "# Rename the keys; labels missing from the map are kept unchanged (dict.get fallback).\n",
    "if setting == 'empirical':\n",
    "    results_to_evaluate = {label_map_empirical.get(key, key): value for key, value in results_to_evaluate.items()}\n",
    "\n",
    "elif setting == 'gaussian':\n",
    "    results_to_evaluate = {label_map_gaussian.get(key, key): value for key, value in results_to_evaluate.items()}\n",
    "\n",
    "print(\"\\nMethods available for evaluation:\")\n",
    "for key in results_to_evaluate.keys():\n",
    "    print(f\"  - {key}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d7384ef",
   "metadata": {},
   "source": [
    "# F-misspecification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43608d96",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_structural_contamination(\n",
    "    linear_data,\n",
    "    graph,\n",
    "    coeffs,\n",
    "    noise,\n",
    "    nonlinear_func=np.sin,\n",
    "    k=1\n",
    "):\n",
    "    \"\"\"\n",
    "    Regenerate SCM samples using a nonlinear structural function.\n",
    "\n",
    "    Variables are visited in topological order and each one is rebuilt as\n",
    "    ``k * sum_p nonlinear_func(parent_p) + noise``, where the parent values\n",
    "    are the already-contaminated ones. The linear mechanism described by\n",
    "    ``coeffs`` does not enter the output.\n",
    "\n",
    "    Args:\n",
    "        linear_data (np.ndarray): Original SCM output; only its 2-D shape\n",
    "            (n_samples, dim) is read.\n",
    "        graph: DAG accepted by ``nx.topological_sort`` and exposing\n",
    "            ``predecessors(node)``.\n",
    "        coeffs (dict): Edge weights {(parent, child): weight}. Kept for\n",
    "            interface compatibility; not used in the computation.\n",
    "        noise (np.ndarray): Exogenous noise, one column per variable.\n",
    "        nonlinear_func (callable | None): Elementwise function applied to the\n",
    "            parent values. ``None`` falls back to ``np.sin``.\n",
    "        k (float): Contamination strength multiplying the nonlinear term.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Contaminated data with the shape and dtype of ``noise``.\n",
    "    \"\"\"\n",
    "    # Callers such as run_k_sweep forward nonlinear_func=None explicitly, which\n",
    "    # bypasses the signature default; treat it as a request for that default.\n",
    "    if nonlinear_func is None:\n",
    "        nonlinear_func = np.sin\n",
    "\n",
    "    n_samples, _ = linear_data.shape\n",
    "    topo_order = list(nx.topological_sort(graph))\n",
    "    # NOTE(review): column j is taken to belong to the j-th node in topological\n",
    "    # order -- confirm this matches the column layout of the data/noise arrays.\n",
    "    var_index = {var: idx for idx, var in enumerate(topo_order)}\n",
    "\n",
    "    contaminated = np.zeros_like(noise)\n",
    "\n",
    "    for var in topo_order:\n",
    "        var_idx = var_index[var]\n",
    "        parents = list(graph.predecessors(var))\n",
    "\n",
    "        if parents:\n",
    "            parent_vals = contaminated[:, [var_index[p] for p in parents]]\n",
    "            nonlinear_part = k * nonlinear_func(parent_vals).sum(axis=1)\n",
    "        else:\n",
    "            # Root variables are pure noise.\n",
    "            nonlinear_part = np.zeros(n_samples)\n",
    "\n",
    "        # NOTE(review): the linear term (parents @ weights) is not added here,\n",
    "        # so k=0 yields pure noise -- confirm this is the intended design.\n",
    "        contaminated[:, var_idx] = nonlinear_part + noise[:, var_idx]\n",
    "\n",
    "    return contaminated\n",
    "\n",
    "def sin(x):\n",
    "    \"\"\"Return np.sin(x); a named wrapper usable as nonlinear_func.\"\"\"\n",
    "    return np.sin(x)\n",
    "\n",
    "def tanh(x):\n",
    "    \"\"\"Return np.tanh(x); a named wrapper usable as nonlinear_func.\"\"\"\n",
    "    return np.tanh(x)\n",
    "\n",
    "def square(x):\n",
    "    \"\"\"Return x squared (x**2); usable as nonlinear_func.\"\"\"\n",
    "    return x**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6aea8c0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes of the outer trial loop in the F-misspecification cell.\n",
    "num_trials              = 100\n",
    "# Function passed as nonlinear_func to apply_structural_contamination.\n",
    "nonlinear_func          = sin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "258478b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per (trial, fold, method, run): the abstraction error averaged over\n",
    "# all low-level interventions after contaminating both models' data.\n",
    "# NOTE(review): nothing in this loop draws random numbers, so every trial\n",
    "# presumably yields identical records -- confirm the ut.calculate_* helpers are deterministic.\n",
    "f_spec_records = []\n",
    "for trial in range(num_trials):\n",
    "    for i, fold_info in enumerate(saved_folds):\n",
    "        for method_name, results_dict in results_to_evaluate.items():\n",
    "            # Methods without an entry for this fold are skipped (empty dict).\n",
    "            fold_results = results_dict.get(f'fold_{i}', {})\n",
    "            for run_key, run_data in fold_results.items():\n",
    "                T_learned = run_data['T_matrix']\n",
    "                test_indices = run_data['test_indices']\n",
    "\n",
    "                errors_per_intervention = []\n",
    "\n",
    "                for iota in I_ll_relevant:\n",
    "                    # Prepare inputs\n",
    "                    # (omega[iota] is the high-level intervention paired with iota)\n",
    "                    Dll_clean = Dll_samples[iota][test_indices]\n",
    "                    Dhl_clean = Dhl_samples[omega[iota]][test_indices]\n",
    "\n",
    "                    noise_ll = LLmodel['noise'][iota][test_indices]\n",
    "                    noise_hl = HLmodel['noise'][omega[iota]][test_indices]\n",
    "\n",
    "                    # k is left at its default (1) for both contaminations.\n",
    "                    Dll_cont = apply_structural_contamination(\n",
    "                        linear_data=Dll_clean,\n",
    "                        graph=ll_graph,\n",
    "                        coeffs=LLmodel['coeffs'],\n",
    "                        noise=noise_ll,\n",
    "                        nonlinear_func=nonlinear_func\n",
    "                    )\n",
    "\n",
    "                    Dhl_cont = apply_structural_contamination(\n",
    "                        linear_data=Dhl_clean,\n",
    "                        graph=hl_graph,\n",
    "                        coeffs=HLmodel['coeffs'],\n",
    "                        noise=noise_hl,\n",
    "                        nonlinear_func=nonlinear_func\n",
    "                    )\n",
    "\n",
    "                    if setting == 'gaussian':\n",
    "                        error = ut.calculate_abstraction_error(T_learned, Dll_cont, Dhl_cont)\n",
    "                    elif setting == 'empirical':\n",
    "                        error = ut.calculate_empirical_error(T_learned, Dll_cont, Dhl_cont)\n",
    "                    else:\n",
    "                        raise ValueError(f\"Unknown setting: {setting}\")\n",
    "\n",
    "                    # NaN errors are dropped before averaging.\n",
    "                    if not np.isnan(error):\n",
    "                        errors_per_intervention.append(error)\n",
    "\n",
    "                # NaN is recorded when no intervention produced a usable error.\n",
    "                avg_error = np.mean(errors_per_intervention) if errors_per_intervention else np.nan\n",
    "                f_spec_records.append({\n",
    "                    'method': method_name,\n",
    "                    'trial': trial,\n",
    "                    'fold': i,\n",
    "                    'error': avg_error\n",
    "                })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e29b9f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "f_spec_df = pd.DataFrame(f_spec_records)\n",
    "\n",
    "print(\"\\n--- F-Misspecification Evaluation Complete ---\")\n",
    "print(\"=\"*65)\n",
    "print(\"Overall Performance (Averaged Across All Nonlinearity Strengths)\")\n",
    "print(\"=\"*65)\n",
    "print(f\"{'Method/Run':<35} | {'Mean ± Std'}\")\n",
    "print(\"=\"*65)\n",
    "\n",
    "# Per-method mean, standard deviation and record count of the error.\n",
    "summary = f_spec_df.groupby('method')['error'].agg(['mean', 'std', 'count'])\n",
    "# NOTE: 'sem' is currently a plain copy of the standard deviation; the actual\n",
    "# standard-error formula is the commented-out line below.\n",
    "summary['sem'] = summary['std']  \n",
    "# summary['sem'] = summary['std'] / np.sqrt(summary['count'])\n",
    "# 'ci95' is likewise a copy of 'sem' (i.e. the std), matching the 'Mean ± Std' header.\n",
    "summary['ci95'] = summary['sem']  \n",
    "\n",
    "# Print methods from lowest to highest mean error.\n",
    "for method_name, row in summary.sort_values('mean').iterrows():\n",
    "    print(f\"{method_name:<35} | {row['mean']:.4f} ± {row['ci95']:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0474993b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw method label -> canonical internal key (empirical setting).\n",
    "label_map_empirical = {\n",
    "    'DIROCA (eps_0.328_delta_0.107)': 'DiRoCA_star',\n",
    "    'DIROCA (eps_delta_1)': 'DIROCA_1',\n",
    "    'DIROCA (eps_delta_2)': 'DIROCA_2',\n",
    "    'DIROCA (eps_delta_4)': 'DIROCA_4',\n",
    "    'DIROCA (eps_delta_8)': 'DIROCA_8',\n",
    "    'GradCA': 'GradCA',\n",
    "    'BARYCA': 'BARYCA',\n",
    "    'Abs-LiNGAM (Perfect)': 'Abslin_p',\n",
    "    'Abs-LiNGAM (Noisy)': 'Abslin_n'\n",
    "}\n",
    "# Raw method label -> canonical internal key (gaussian setting).\n",
    "label_map_gaussian = {\n",
    "    'DIROCA (eps_delta_0.111)': 'DiRoCA_star',\n",
    "    'DIROCA (eps_delta_1)': 'DIROCA_1',\n",
    "    'DIROCA (eps_delta_2)': 'DIROCA_2',\n",
    "    'DIROCA (eps_delta_4)': 'DIROCA_4',\n",
    "    'DIROCA (eps_delta_8)': 'DIROCA_8',\n",
    "    'GradCA': 'GradCA',\n",
    "    'BARYCA': 'BARYCA'\n",
    "}\n",
    "\n",
    "# Canonical internal key -> mathtext display name used in figures.\n",
    "print_label_map  = {\n",
    "    'DiRoCA_star':  r'DiRoCA$_{\\epsilon_\\ell^*, \\epsilon_h^*}$',\n",
    "    'DIROCA_1':     r'DiRoCA$_{1,1}$',\n",
    "    'DIROCA_2':     r'DiRoCA$_{2,2}$',\n",
    "    'DIROCA_4':     r'DiRoCA$_{4,4}$',\n",
    "    'DIROCA_8':     r'DiRoCA$_{8,8}$',\n",
    "    'GradCA':       r'GRAD$_{(\\tau, \\omega)}$',\n",
    "    'BARYCA':       r'BARY$_{(\\tau, \\omega)}$',\n",
    "    'Abslin_p':     r'AbsLin$_{\\text{p}}$',\n",
    "    'Abslin_n':     r'AbsLin$_{\\text{n}}$'\n",
    "}\n",
    "\n",
    "# Global matplotlib style: serif fonts rendered by mathtext rather than external LaTeX.\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.serif\": [\"Computer Modern Roman\", \"CMU Serif\", \"DejaVu Serif\"],\n",
    "    \"mathtext.fontset\": \"cm\",\n",
    "    \"mathtext.rm\": \"serif\"\n",
    "})\n",
    "\n",
    "# Canonical keys in plotting/legend order, and their display names in the same order.\n",
    "methods_to_plot = ['DiRoCA_star', 'DIROCA_1', 'DIROCA_2', 'DIROCA_4', 'DIROCA_8', 'GradCA', 'BARYCA', 'Abslin_p', 'Abslin_n']\n",
    "display_names = [print_label_map[m] for m in methods_to_plot]\n",
    "\n",
    "# Display name -> line colour; the keys must match the values of print_label_map.\n",
    "color_map = {\n",
    "    r'DiRoCA$_{\\epsilon_\\ell^*, \\epsilon_h^*}$': '#1f77b4',\n",
    "    r'DiRoCA$_{1,1}$': 'gold',\n",
    "    r'DiRoCA$_{2,2}$': 'darkorange',\n",
    "    r'DiRoCA$_{4,4}$': 'lightskyblue',\n",
    "    r'DiRoCA$_{8,8}$': 'violet',\n",
    "    r'GRAD$_{(\\tau, \\omega)}$': '#2ca02c',\n",
    "    r'BARY$_{(\\tau, \\omega)}$': '#d62728',\n",
    "    r'AbsLin$_{\\text{p}}$': '#9467bd',\n",
    "    r'AbsLin$_{\\text{n}}$': '#8c564b'\n",
    "}\n",
    "\n",
    "def _create_diroca_label(run_id):\n",
    "    \"\"\"If epsilon==delta in run_id string, compress to eps_delta_v form.\"\"\"\n",
    "    matches = re.findall(r'(\\d+\\.?\\d*)', run_id)\n",
    "    if len(matches) == 2:\n",
    "        eps, delta = matches\n",
    "        if eps == delta:\n",
    "            val = int(float(eps)) if float(eps).is_integer() else float(eps)\n",
    "            return f\"DIROCA (eps_delta_{val})\"\n",
    "    return f\"DIROCA ({run_id})\"\n",
    "\n",
    "def _split_runs_by_key(cv_results, make_label):\n",
    "    \"\"\"Regroup {fold: {run: data}} into {label: {fold: {run: data}}}, one label per run key of the first fold.\"\"\"\n",
    "    per_method = {}\n",
    "    if not cv_results:\n",
    "        return per_method\n",
    "    first_fold_key = list(cv_results.keys())[0]\n",
    "    for run_key in list(cv_results[first_fold_key].keys()):\n",
    "        # Folds that lack this run key are simply left out of the entry.\n",
    "        per_method[make_label(run_key)] = {\n",
    "            fold_key: {run_key: fold_res[run_key]}\n",
    "            for fold_key, fold_res in cv_results.items()\n",
    "            if run_key in fold_res\n",
    "        }\n",
    "    return per_method\n",
    "\n",
    "def build_results_to_evaluate(experiment, setting):\n",
    "    \"\"\"\n",
    "    Rebuild results_to_evaluate for a given (experiment, setting).\n",
    "\n",
    "    Args:\n",
    "        experiment (str): Sub-directory of ``data/`` holding the result pickles.\n",
    "        setting (str): 'gaussian' or 'empirical'.\n",
    "\n",
    "    Returns:\n",
    "        dict: canonical method key -> per-fold results, inserted in the order\n",
    "        DIROCA runs, GradCA, BARYCA, then (empirical only) Abs-LiNGAM styles.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if setting is neither 'gaussian' nor 'empirical'.\n",
    "    \"\"\"\n",
    "    # Reject unknown settings instead of silently loading the empirical files\n",
    "    # while skipping the empirical-only processing.\n",
    "    if setting not in ('gaussian', 'empirical'):\n",
    "        raise ValueError(f\"Unknown setting: {setting}\")\n",
    "    is_empirical = setting == 'empirical'\n",
    "\n",
    "    if is_empirical:\n",
    "        path = f\"data/{experiment}/results_empirical\"\n",
    "        suffix = \"_empirical\"\n",
    "        label_map = label_map_empirical\n",
    "        diroca_label = _create_diroca_label\n",
    "    else:\n",
    "        path = f\"data/{experiment}/results\"\n",
    "        suffix = \"\"\n",
    "        label_map = label_map_gaussian\n",
    "        diroca_label = lambda run_id: f\"DIROCA ({run_id})\"\n",
    "\n",
    "    diroca_results = joblib.load(f\"{path}/diroca_cv_results{suffix}.pkl\")\n",
    "    gradca_results = joblib.load(f\"{path}/gradca_cv_results{suffix}.pkl\")\n",
    "    baryca_results = joblib.load(f\"{path}/baryca_cv_results{suffix}.pkl\")\n",
    "\n",
    "    results_to_evaluate = {}\n",
    "\n",
    "    # DIROCA variants: one entry per hyperparameter run\n",
    "    results_to_evaluate.update(_split_runs_by_key(diroca_results, diroca_label))\n",
    "\n",
    "    # Baselines\n",
    "    results_to_evaluate['GradCA'] = gradca_results\n",
    "    results_to_evaluate['BARYCA'] = baryca_results\n",
    "\n",
    "    # Abs-LiNGAM results exist only for the empirical setting: one entry per style\n",
    "    if is_empirical:\n",
    "        abslingam_results = joblib.load(f\"{path}/abslingam_cv_results_empirical.pkl\")\n",
    "        results_to_evaluate.update(\n",
    "            _split_runs_by_key(abslingam_results, lambda style: f\"Abs-LiNGAM ({style})\")\n",
    "        )\n",
    "\n",
    "    # Map to canonical internal keys; unmapped labels are kept as they are\n",
    "    results_to_evaluate = {label_map.get(k, k): v for k, v in results_to_evaluate.items()}\n",
    "    return results_to_evaluate\n",
    "\n",
    "def run_k_sweep(experiment, setting, k_values, num_trials=5, strength=1, scaled=True, nonlinear_func=None):\n",
    "    \"\"\"\n",
    "    Compute the F-misspecification error table for (experiment, setting) over k_values.\n",
    "\n",
    "    For every k, trial, fold, method and run, both the low- and high-level test\n",
    "    data are rebuilt with apply_structural_contamination and the error of the\n",
    "    stored T_matrix is averaged over all low-level interventions.\n",
    "\n",
    "    Args:\n",
    "        experiment (str): Sub-directory of ``data/`` to load folds and data from.\n",
    "        setting (str): 'gaussian' uses ut.calculate_abstraction_error; any other\n",
    "            value uses ut.calculate_empirical_error.\n",
    "        k_values (iterable): Contamination strengths to sweep.\n",
    "        num_trials (int): Number of repetitions per k.\n",
    "        strength, scaled: Accepted for interface compatibility; currently unused.\n",
    "        nonlinear_func (callable | None): Elementwise nonlinearity; None selects np.sin.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (df, keep) where df has columns method, k_value, trial, fold,\n",
    "        error, display_name restricted to methods_to_plot, and keep lists the\n",
    "        plotted method keys that are present, in methods_to_plot order.\n",
    "    \"\"\"\n",
    "    # Forwarding None would override the np.sin default of\n",
    "    # apply_structural_contamination and fail when the function is called.\n",
    "    if nonlinear_func is None:\n",
    "        nonlinear_func = np.sin\n",
    "\n",
    "    # Load folds & data\n",
    "    folds_path = f\"data/{experiment}/cv_folds.pkl\"\n",
    "    saved_folds = joblib.load(folds_path)\n",
    "    all_data = ut.load_all_data(experiment)\n",
    "\n",
    "    LLmodel       = all_data['LLmodel']\n",
    "    HLmodel       = all_data['HLmodel']\n",
    "    Dll_samples   = LLmodel['data']\n",
    "    Dhl_samples   = HLmodel['data']\n",
    "    ll_graph      = LLmodel['graph']\n",
    "    hl_graph      = HLmodel['graph']\n",
    "    I_ll_relevant = LLmodel['intervention_set']\n",
    "    omega         = all_data['abstraction_data']['omega']\n",
    "\n",
    "    results_to_evaluate = build_results_to_evaluate(experiment, setting)\n",
    "\n",
    "    # Compute records\n",
    "    # NOTE(review): nothing below draws random numbers, so repeated trials\n",
    "    # presumably produce identical rows -- confirm before relying on the spread.\n",
    "    records = []\n",
    "    for k in k_values:\n",
    "        for trial in range(num_trials):\n",
    "            for i, fold_info in enumerate(saved_folds):\n",
    "                for method_name, results_dict in results_to_evaluate.items():\n",
    "                    fold_results = results_dict.get(f'fold_{i}', {})\n",
    "                    for run_key, run_data in fold_results.items():\n",
    "                        T_learned = run_data['T_matrix']\n",
    "                        test_indices = run_data['test_indices']\n",
    "\n",
    "                        errors = []\n",
    "                        for iota in I_ll_relevant:\n",
    "                            Dll_clean = Dll_samples[iota][test_indices]\n",
    "                            Dhl_clean = Dhl_samples[omega[iota]][test_indices]\n",
    "\n",
    "                            noise_ll = LLmodel['noise'][iota][test_indices]\n",
    "                            noise_hl = HLmodel['noise'][omega[iota]][test_indices]\n",
    "\n",
    "                            Dll_cont = apply_structural_contamination(\n",
    "                                linear_data=Dll_clean,\n",
    "                                graph=ll_graph,\n",
    "                                coeffs=LLmodel['coeffs'],\n",
    "                                noise=noise_ll,\n",
    "                                nonlinear_func=nonlinear_func,\n",
    "                                k=k\n",
    "                            )\n",
    "                            Dhl_cont = apply_structural_contamination(\n",
    "                                linear_data=Dhl_clean,\n",
    "                                graph=hl_graph,\n",
    "                                coeffs=HLmodel['coeffs'],\n",
    "                                noise=noise_hl,\n",
    "                                nonlinear_func=nonlinear_func,\n",
    "                                k=k\n",
    "                            )\n",
    "\n",
    "                            if setting == 'gaussian':\n",
    "                                err = ut.calculate_abstraction_error(T_learned, Dll_cont, Dhl_cont)\n",
    "                            else:\n",
    "                                err = ut.calculate_empirical_error(T_learned, Dll_cont, Dhl_cont)\n",
    "\n",
    "                            # NaN errors are dropped before averaging.\n",
    "                            if not np.isnan(err):\n",
    "                                errors.append(err)\n",
    "\n",
    "                        avg_error = np.mean(errors) if errors else np.nan\n",
    "                        records.append({\n",
    "                            'method': method_name,\n",
    "                            'k_value': k,\n",
    "                            'trial': trial,\n",
    "                            'fold': i,\n",
    "                            'error': avg_error\n",
    "                        })\n",
    "\n",
    "    # Fixing the columns keeps the frame well-formed (and df['method'] valid)\n",
    "    # even when no record was produced, e.g. for an empty k_values.\n",
    "    df = pd.DataFrame(records, columns=['method', 'k_value', 'trial', 'fold', 'error'])\n",
    "    present_methods = set(df['method'])\n",
    "    keep = [m for m in methods_to_plot if m in present_methods]\n",
    "    df = df[df['method'].isin(keep)].copy()\n",
    "    df['display_name'] = df['method'].map(print_label_map)\n",
    "    return df, keep\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8d3f67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep configuration. The names are cell-specific on purpose: reusing\n",
    "# `experiment`, `setting`, `num_trials` or `nonlinear_func` here would overwrite\n",
    "# the notebook-level values that the data loaded above (and later cells) rely on.\n",
    "k_values = np.linspace(0, 100, 15)\n",
    "sweep_num_trials = 1\n",
    "sweep_nonlinear_func = tanh\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(18, 12), sharey=False)\n",
    "axes = np.array(axes).reshape(2, 2)\n",
    "\n",
    "# (row, col) -> (experiment, setting)\n",
    "panels = [\n",
    "    ((0, 0), ('slc', 'gaussian'),   \"Gaussian\"),\n",
    "    ((0, 1), ('slc', 'empirical'),  \"Empirical\"),\n",
    "    ((1, 0), ('lilucas', 'gaussian'),  \"Gaussian\"),\n",
    "    ((1, 1), ('lilucas', 'empirical'), \"Empirical\"),\n",
    "]\n",
    "\n",
    "# Display names of every method drawn in at least one panel (for the shared legend).\n",
    "present_disp_names_global = []\n",
    "\n",
    "for (r, c), (panel_experiment, panel_setting), title_str in panels:\n",
    "    ax = axes[r, c]\n",
    "\n",
    "    df_panel, keep_methods = run_k_sweep(\n",
    "        panel_experiment, panel_setting, k_values,\n",
    "        num_trials=sweep_num_trials, nonlinear_func=sweep_nonlinear_func\n",
    "    )\n",
    "\n",
    "    disp_order = [print_label_map[m] for m in methods_to_plot if m in keep_methods]\n",
    "\n",
    "    if not df_panel.empty and disp_order:\n",
    "        sns.lineplot(\n",
    "            data=df_panel,\n",
    "            x='k_value',\n",
    "            y='error',\n",
    "            hue='display_name',\n",
    "            hue_order=disp_order,\n",
    "            palette=color_map,\n",
    "            marker='o',\n",
    "            linewidth=2.5,\n",
    "            markersize=8,\n",
    "            errorbar='sd',\n",
    "            ax=ax,\n",
    "            legend=False  \n",
    "        )\n",
    "\n",
    "        present_disp_names_global.extend(disp_order)\n",
    "\n",
    "    ax.set_title(title_str, fontsize=30)\n",
    "    ax.set_xlabel(r'$k$', fontsize=32)\n",
    "    # Only the left column carries the y-axis label.\n",
    "    if c == 0:\n",
    "        ax.set_ylabel('Abstraction Error', fontsize=32)\n",
    "    else:\n",
    "        ax.set_ylabel('')\n",
    "    ax.tick_params(axis='both', labelsize=18)\n",
    "    ax.grid(True, linestyle='--', alpha=0.7)\n",
    "\n",
    "\n",
    "# De-duplicate while keeping the canonical display order.\n",
    "drawn_names = set(present_disp_names_global)\n",
    "present_disp_names_global = [dn for dn in display_names if dn in drawn_names]\n",
    "\n",
    "# Per-panel legends are disabled; build proxy handles for one figure-level legend.\n",
    "ordered_handles = [\n",
    "    plt.Line2D([], [], linestyle='-', linewidth=6,\n",
    "               label=dn, color=color_map.get(dn, '#000000'))\n",
    "    for dn in present_disp_names_global\n",
    "]\n",
    "\n",
    "fig.legend(\n",
    "    ordered_handles,\n",
    "    present_disp_names_global,\n",
    "    loc='lower center',\n",
    "    ncol=min(6, len(present_disp_names_global)),\n",
    "    fontsize=20,\n",
    "    frameon=False\n",
    ")\n",
    "\n",
    "# Leave room at the bottom of the figure for the shared legend.\n",
    "fig.tight_layout(rect=[0, 0.12, 1, 0.97])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f620c121",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "30cf506e",
   "metadata": {},
   "source": [
    "# ω-misspecification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8619c1be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def contaminate_omega_semantic(original_omega, ll_interventions, hl_interventions, \n",
    "                                        num_misalignments, seed=None, delta=None, return_changed=False):\n",
    "    rng = random.Random(seed)\n",
    "    contaminated_omega = dict(original_omega)\n",
    "\n",
    "    eligible_ll = [ll for ll in original_omega if ll is not None]\n",
    "    to_corrupt = rng.sample(eligible_ll, k=min(num_misalignments, len(eligible_ll)))\n",
    "\n",
    "    changed = 0\n",
    "\n",
    "    for ll_intervention in to_corrupt:\n",
    "        original_target = original_omega[ll_intervention]\n",
    "        ll_complexity = 0 if ll_intervention is None else len(ll_intervention.vv())\n",
    "\n",
    "        # 1) same complexity, different target\n",
    "        same = [hl for hl in hl_interventions\n",
    "                if hl is not None and hl != original_target and len(hl.vv()) == ll_complexity]\n",
    "\n",
    "        candidates = same\n",
    "\n",
    "        # 2) fallback: nearest complexity (if needed)\n",
    "        if not candidates:\n",
    "            pairs = [(hl, abs(len(hl.vv()) - ll_complexity))\n",
    "                     for hl in hl_interventions if hl is not None and hl != original_target]\n",
    "            if pairs:\n",
    "                min_diff = min(diff for _, diff in pairs)\n",
    "                # enforce a cap if provided (delta)\n",
    "                if delta is not None and min_diff > delta:\n",
    "                    continue  # skip this ll; no near-enough HL target\n",
    "                candidates = [hl for hl, diff in pairs if diff == min_diff]\n",
    "            else:\n",
    "                continue  # no alternative HL at all\n",
    "\n",
    "        # pick a new target and set\n",
    "        new_target = rng.choice(candidates)\n",
    "        if new_target != original_target:\n",
    "            contaminated_omega[ll_intervention] = new_target\n",
    "            changed += 1\n",
    "\n",
    "    if return_changed:\n",
    "        return contaminated_omega, changed\n",
    "    return contaminated_omega\n",
    "\n",
    "\n",
    "def evaluate_omega_contamination(original_omega, ll_interventions, hl_interventions, total_interventions, results_to_evaluate, saved_folds, \n",
    "                                Dll_samples, Dhl_samples, setting, num_trials, delta):\n",
    "    \"\"\"\n",
    "    Evaluate the stored abstractions against a corrupted intervention map.\n",
    "\n",
    "    For each trial a contaminated omega is drawn with contaminate_omega_semantic\n",
    "    (all total_interventions entries are candidates for corruption) and every\n",
    "    stored T_matrix is scored on its test indices, pairing each low-level\n",
    "    intervention with its contaminated high-level target.\n",
    "\n",
    "    Args:\n",
    "        original_omega (dict): Low-level -> high-level intervention map.\n",
    "        ll_interventions, hl_interventions: Intervention sets of the two models.\n",
    "        total_interventions (int): Number of entries to try to misalign.\n",
    "        results_to_evaluate (dict): method -> {fold key: {run key: run data}}.\n",
    "        saved_folds: Sequence of folds; only its length/order is used.\n",
    "        Dll_samples, Dhl_samples: Samples indexed by intervention.\n",
    "        setting (str): 'gaussian' or 'empirical'; selects the error function.\n",
    "        num_trials (int): Number of contaminated maps to draw.\n",
    "        delta: Forwarded to contaminate_omega_semantic.\n",
    "\n",
    "    Returns:\n",
    "        list[dict]: One record per (trial, fold, method, run) with keys\n",
    "        method, num_misalignments, trial, fold, error.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if setting is neither 'gaussian' nor 'empirical'.\n",
    "    \"\"\"\n",
    "    if setting not in ('gaussian', 'empirical'):\n",
    "        # Previously an unknown setting left `error` unbound (or stale).\n",
    "        raise ValueError(f\"Unknown setting: {setting}\")\n",
    "\n",
    "    contamination_levels = [int(total_interventions * 1.0)]\n",
    "    omega_contamination_records = []\n",
    "    \n",
    "    for num_misalignments in contamination_levels:\n",
    "        for trial in range(num_trials):\n",
    "            # Derive a distinct seed per trial from the notebook-level `seed`.\n",
    "            # Passing the same seed every time made all trials draw the\n",
    "            # identical contaminated map, so repeating them added no information.\n",
    "            trial_seed = None if seed is None else seed + trial\n",
    "            contaminated_omega = contaminate_omega_semantic(\n",
    "                original_omega, ll_interventions, hl_interventions, num_misalignments, \n",
    "                seed=trial_seed, delta=delta, return_changed=False\n",
    "            )\n",
    "            \n",
    "            for fold_id, fold_info in enumerate(saved_folds):\n",
    "                for method_name, results_dict in results_to_evaluate.items():\n",
    "                    fold_results = results_dict.get(f'fold_{fold_id}', {})\n",
    "                    \n",
    "                    for run_key, run_data in fold_results.items():\n",
    "                        T_learned = run_data['T_matrix']\n",
    "                        test_indices = run_data['test_indices']\n",
    "                        \n",
    "                        errors_per_intervention = []\n",
    "                        \n",
    "                        for ll_intervention in ll_interventions:\n",
    "                            # The None intervention and None targets are skipped.\n",
    "                            if ll_intervention is None:\n",
    "                                continue\n",
    "                                \n",
    "                            contaminated_hl_intervention = contaminated_omega[ll_intervention]\n",
    "                            if contaminated_hl_intervention is None:\n",
    "                                continue\n",
    "                            \n",
    "                            Dll_test = Dll_samples[ll_intervention][test_indices]\n",
    "                            Dhl_test = Dhl_samples[contaminated_hl_intervention][test_indices]\n",
    "                            \n",
    "                            if setting == 'gaussian':\n",
    "                                error = ut.calculate_abstraction_error(T_learned, Dll_test, Dhl_test)\n",
    "                            else:\n",
    "                                error = ut.calculate_empirical_error(T_learned, Dll_test, Dhl_test)\n",
    "                            \n",
    "                            # NaN errors are dropped before averaging.\n",
    "                            if not np.isnan(error):\n",
    "                                errors_per_intervention.append(error)\n",
    "                        \n",
    "                        avg_error = np.mean(errors_per_intervention) if errors_per_intervention else np.nan\n",
    "                        \n",
    "                        omega_contamination_records.append({\n",
    "                            'method': method_name,\n",
    "                            'num_misalignments': num_misalignments,\n",
    "                            'trial': trial,\n",
    "                            'fold': fold_id,\n",
    "                            'error': avg_error\n",
    "                        })\n",
    "    \n",
    "    return omega_contamination_records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1b63c0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get total number of interventions\n",
    "# (counts the keys of omega that are not None)\n",
    "total_interventions  = len([ll for ll in omega if ll is not None])\n",
    "# Number of contaminated omega maps drawn per delta.\n",
    "num_trials = 100\n",
    "# Values passed as `delta` to evaluate_omega_contamination, one run each.\n",
    "deltas = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebcb599d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the omega-contamination evaluation once per delta and print a summary.\n",
    "# Note: the records and DataFrame are overwritten on every iteration, so after\n",
    "# the loop only the last delta's results remain in memory.\n",
    "for delta in deltas:\n",
    "    omega_contamination_records = evaluate_omega_contamination(\n",
    "        omega, ll_interventions, hl_interventions, total_interventions, \n",
    "        results_to_evaluate, saved_folds, Dll_samples, Dhl_samples, setting, num_trials, delta\n",
    "    )\n",
    "\n",
    "    omega_contamination_df = pd.DataFrame(omega_contamination_records)\n",
    "\n",
    "    print(\"=== OMEGA CONTAMINATION EVALUATION RESULTS ===\")\n",
    "    print(\"=\"*60)\n",
    "\n",
    "    # Overall performance summary\n",
    "    # (mean and std of the error per method, rounded to 2 decimals, best first)\n",
    "    overall_performance = omega_contamination_df.groupby('method')['error'].agg(['mean', 'std']).round(2)\n",
    "    overall_performance = overall_performance.sort_values('mean')\n",
    "    overall_performance.columns = ['Mean Error', 'Std Error']\n",
    "\n",
    "    print(\"Overall Performance (All Contamination Levels):\")\n",
    "    print(\"=\"*60)\n",
    "    for method, row in overall_performance.iterrows():\n",
    "        print(f\"{method:<15} | {row['Mean Error']:6.2f} ± {row['Std Error']:5.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f93008a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "erica",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
