{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "NUM_SEEDS = 20\n",
    "DIMS = [2, 5, 10, 20 ,30]\n",
    "DIMS = [5, 10, 20 ,30,50]\n",
    "\n",
    "LESLOGEIONLY = False\n",
    "\n",
    "WITHIN_MDL = False\n",
    "\n",
    "function = 'gpsample_gpytorchMDL'#,'square' #'gpsample_gpytorchMDL'\n",
    "#function = 'gpsample_within_mdl'\n",
    "function = 'gpsample_within_mdl_20250512'\n",
    "\n",
    "\n",
    "avrg_best_histories = []\n",
    "lower_quantiles_histories = []\n",
    "upper_quantiles_histories = []\n",
    "avrg_rank_histories = []\n",
    "significantly_worse_at_end_histories = []\n",
    "cum_avrg_best_histories = []\n",
    "cum_lower_quantiles_histories = []\n",
    "cum_upper_quantiles_histories = []\n",
    "cum_significantly_worse_at_end_histories = []\n",
    "cum_avrg_rank_histories = []\n",
    "\n",
    "\n",
    " \n",
    "function_names = [r'\\makecell{\\textbf{high}: $\\mathrm{p}(l) =$ \\\\   $\\mathrm{logn}(-2.5\\sqrt{2} + \\mathrm{log} \\sqrt{d},\\sqrt{3}/5)$ }', # \\\\ \\mathbb{E}[\\mathrm{p}(l)] =  \n",
    "                  r'\\makecell{\\textbf{medium}: $\\mathrm{p}(l) =$ \\\\   $\\mathrm{logn}(-2.0\\sqrt{2} + \\mathrm{log} \\sqrt{d},\\sqrt{3}/4)$ }',\n",
    "                  r'\\makecell{\\textbf{low}: $\\mathrm{p}(l) =$ \\\\   $\\mathrm{logn}(-1.0\\sqrt{2} + \\mathrm{log} \\sqrt{d},\\sqrt{3}/2)$ }',\n",
    "                 r'\\makecell{\\textbf{extremely low} \\cite{hvarfner2024vanilla}: $\\mathrm{p}(l) =$ \\\\   $\\mathrm{logn}(1.0\\sqrt{2} + \\mathrm{log} \\sqrt{d},\\sqrt{3})$ }']\n",
    "\n",
    "\n",
    "excpected_ls = np.array([[0.08, 0.11, 0.15, 0.19, 0.25],[0.16, 0.23,0.33,0.40,0.52],[0.83, 1.19, 1.67, 2.05, 2.65],[21.86,30.92,34.73,53.56,69.15]]) # medium, low, ext_low\n",
    "# wilson within model length scales: 0.56, 0.79, 1.18, 1.37,  1.77   \n",
    "if WITHIN_MDL:\n",
    "    functions = ['within_mdl_high/gpsample','within_mdl_medium/gpsample','within_mdl_low/gpsample','within_mdl_ext_low/gpsample']\n",
    "else:\n",
    "    #functions = ['oom_HVARFNER_HYPERPRIOR_very2_complex/gpsample','oom_HVARFNER_HYPERPRIOR_very_complex/gpsample','oom_HVARFNER_HYPERPRIOR_complex/gpsample','gpsample_oom_HVARFNER_HYPERPRIOR/gpsample']\n",
    "    functions = ['oom_high/gpsample','oom_medium/gpsample','oom_low/gpsample','oom_ext_low/gpsample']\n",
    "\n",
    "for CUM_RES in [True, False]:\n",
    "    for function in functions:\n",
    "        DISTR_PATH = \"./Data/\"+function+\"/local_optima_distribution/\"\n",
    "        BEST_HIST_PATH = \"./Data/\"+function+\"/optimizer_history/list_of_bests_\"\n",
    "        SAMPLED_POINTS_PATH = \"./Data/\"+function+\"/sampled_data/sampled_data_history_\"\n",
    "        LENGTH_SCALE_PATH = \"./Data/\"+function+\"/length_scale/length_scale_\"\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        if LESLOGEIONLY:\n",
    "            METHODS = ['les_250_8','logei','turbo', 'sobol']  \n",
    "        else:\n",
    "            if WITHIN_MDL:\n",
    "                METHODS = ['les_250_8','mes','logei','turbo','hci_gibo_09', 'sobol'] # final selection of algorithms #,'hci_gibo_09'\n",
    "            else:\n",
    "                METHODS = ['les_250_8','mes','logei','loghvarei','turbo','hci_gibo_09', 'sobol']\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        LABEL_NAMES = {'tracing':'With gradient tracing',\n",
    "        'mes':'MES',\n",
    "        'logei':'logEI',\n",
    "        'loghvarei':'logEI-DSP',\n",
    "        'turbo':'TuRBO',\n",
    "        'sobol':'Sobol random',\n",
    "        'std_gibo':'GIBO',\n",
    "        'hci_gibo':'HCI-GIBO',\n",
    "        'hci_gibo_09':'HCI-GIBO',\n",
    "        'les_20_8':'LES-ADAM: L = 20, P = 8',\n",
    "        'les_250_4':'LES: L = 250, P = 4',\n",
    "        'les_250_16':'LES: L = 250, P = 16',\n",
    "        'les_250_8':'LES (ours)',\n",
    "        'les_fp_wgrad':'LES-ADAM-Opt.Cond.: L = 250, P = 8 ',\n",
    "        'lesGD_250_8':'LES-GD: L = 250, P = 8',\n",
    "        'lesgradcond_20_8':'LES-ADAM-Grad.-Cond.: L = 20, P = 8',\n",
    "        'localTS':'Local Thompson Sampling',\n",
    "        'lesCMAES_20_8':'LES-CMAES: L = 20, P = 8'}\n",
    "\n",
    "        METHOD_TABLE_NAMES = []\n",
    "        for method in METHODS:\n",
    "            METHOD_TABLE_NAMES.append(LABEL_NAMES[method])\n",
    "\n",
    "\n",
    "        def decompress_gibo(df):\n",
    "            data = df[['y','n']].to_numpy(dtype=float)\n",
    "            repeats = np.diff(data[:,-1].astype(int))\n",
    "            repeats = np.insert(repeats, 0, data[0,-1])\n",
    "            repeated = np.repeat(data[:, :-1], repeats, axis=0)\n",
    "            return np.minimum.accumulate(repeated)\n",
    "\n",
    "        ######################## calculate median and quantiles of the best observed function value ##############\n",
    "\n",
    "        from matplotlib.lines import Line2D\n",
    "\n",
    "        avrg_best_history = []\n",
    "        std_best_history = []\n",
    "        lower_quantiles_history =[]\n",
    "        upper_quantiles_history =[]\n",
    "        avrg_rank_history = []\n",
    "        significantly_worse_at_end_history = []\n",
    "        avrg_time_deltas = []\n",
    "        std_time_deltas = []\n",
    "\n",
    "    \n",
    "\n",
    "        for dim in DIMS: \n",
    "            data_mean = []\n",
    "            stds = []\n",
    "\n",
    "            time_deltas_mean_per_method = []\n",
    "            time_deltas_std_per_method = []\n",
    "            lower_quantile =[]\n",
    "            upper_quantile =[]\n",
    "\n",
    "\n",
    "            num_objective_calls = min(20*dim, 400)\n",
    "\n",
    "            all_y_data_per_dim = np.zeros((NUM_SEEDS,len(METHODS),num_objective_calls))\n",
    "\n",
    "\n",
    "            for method in range(len(METHODS)):\n",
    "                y_data = np.zeros((0,0))\n",
    "                # Collect timestamp differences for all seeds having a timestamp column\n",
    "                time_arrays = []\n",
    "                for seed in range(NUM_SEEDS):\n",
    "                    \n",
    "\n",
    "                    \n",
    "                    file_identifier = f'{(seed+1):05d}_{dim}_{METHODS[method]}.csv'\n",
    "                    if CUM_RES:\n",
    "                        try: \n",
    "                            table = pd.read_csv(SAMPLED_POINTS_PATH + file_identifier) \n",
    "                        except: \n",
    "                            print(f'Unable to find file {SAMPLED_POINTS_PATH+file_identifier}.')\n",
    "                            #continue\n",
    "                    else:\n",
    "                        try: \n",
    "                            table = pd.read_csv(BEST_HIST_PATH + file_identifier) \n",
    "                        except: \n",
    "                            print(f'Unable to find file {BEST_HIST_PATH+file_identifier}.')\n",
    "                            #continue\n",
    "                        \n",
    "                    if METHODS[method] == 'std_gibo' or METHODS[method] == 'hci_gibo' or METHODS[method] == 'hci_gibo_09' and not CUM_RES :\n",
    "                        new_data = decompress_gibo(table.dropna())\n",
    "                    else:\n",
    "                        if CUM_RES:\n",
    "                            new_data = np.reshape(table['y'].to_numpy(), [-1,1])\n",
    "                            if np.isnan(new_data[-1]):\n",
    "                                new_data[-1] = new_data[-2]\n",
    "                        else:\n",
    "                            new_data = np.reshape(table['f'].to_numpy(), [-1,1])\n",
    "                    \n",
    "\n",
    "                    if CUM_RES:\n",
    "                        if y_data.shape[0] == 0:\n",
    "                            y_data = np.expand_dims(np.cumsum(new_data[:num_objective_calls, :]),axis = 1)\n",
    "                        else:\n",
    "                            try: \n",
    "                                new_data = np.expand_dims(np.cumsum(new_data[:num_objective_calls, :]),axis = 1)\n",
    "                                y_data = np.concatenate([y_data, new_data], axis=1)\n",
    "                            except:\n",
    "                                y_data = np.concatenate([y_data, np.expand_dims(y_data[:,-1],1)], axis=1)\n",
    "                                print(METHODS[method])\n",
    "                                print('data missing augmenting with exisiting data set')\n",
    "\n",
    "                    else:\n",
    "                        if y_data.shape[0] == 0:\n",
    "                            y_data = new_data[:num_objective_calls, :]\n",
    "                        else: \n",
    "                            new_data = new_data[:num_objective_calls, :]\n",
    "                            y_data = np.concatenate([y_data, new_data], axis=1)\n",
    "\n",
    "                    if 'timestamp' in table.columns:\n",
    "                        # Crop timestamps to match the # of objective calls\n",
    "                        timestamps = table['timestamp'].to_numpy()[:num_objective_calls]\n",
    "                        # Compute the consecutive differences\n",
    "                        if len(timestamps) > 1:\n",
    "                            diffs = np.diff(timestamps)\n",
    "                            time_arrays.append(diffs)\n",
    "\n",
    "\n",
    "                    \n",
    "                data_mean.append(np.median(y_data, axis=1))\n",
    "                stds.append(np.std(y_data, axis=1))\n",
    "                all_y_data_per_dim[:,method,0:y_data.shape[0]] = np.transpose(y_data)\n",
    "                for i_add in range(y_data.shape[0],num_objective_calls,1):\n",
    "                    all_y_data_per_dim[:,method,i_add] = all_y_data_per_dim[:,method,y_data.shape[0]-1] \n",
    "                try:\n",
    "                    lower_quantile.append(np.quantile(y_data, 0.25, axis=1))\n",
    "                    upper_quantile.append(np.quantile(y_data, 0.75, axis=1))\n",
    "                except:\n",
    "                    lower_quantile.append([])\n",
    "                    upper_quantile.append([])\n",
    "                    \n",
    "                stds.append(np.std(y_data, axis=1))\n",
    "\n",
    "                # Collect mean/std across seeds for the time-deltas (if available)\n",
    "                if len(time_arrays) > 0:\n",
    "                    time_arrays = np.array(time_arrays)  # shape: [num_seeds_with_timestamps, num_objective_calls-1]\n",
    "                    time_deltas_mean_per_method.append(np.mean(time_arrays, axis=0))\n",
    "                    time_deltas_std_per_method.append(np.std(time_arrays, axis=0))\n",
    "                else:\n",
    "                    # No timestamp data for this method in this dimension\n",
    "                    time_deltas_mean_per_method.append(None)\n",
    "                    time_deltas_std_per_method.append(None)\n",
    "\n",
    "\n",
    "            avrg_best_history.append(data_mean)\n",
    "            #avrg_best_history_normalized.append(data_mean_normalized)\n",
    "            std_best_history.append(stds)\n",
    "            #std_best_history_normalized.append(stds_normalized)\n",
    "            lower_quantiles_history.append(lower_quantile)\n",
    "            upper_quantiles_history.append(upper_quantile)\n",
    "\n",
    "            avrg_time_deltas.append(time_deltas_mean_per_method)\n",
    "            std_time_deltas.append(time_deltas_std_per_method)\n",
    "\n",
    "            # calculate ranks\n",
    "            tmp_sorted = np.argsort(all_y_data_per_dim,axis = 1)\n",
    "            ranks_per_dim = np.zeros(tmp_sorted.shape)\n",
    "            \n",
    "\n",
    "\n",
    "            for i_method in range(len(METHODS)):\n",
    "                res0,res1,res2 = np.where(tmp_sorted == i_method)\n",
    "                ranks_per_dim[res0,i_method,res2] = res1\n",
    "\n",
    "            ranks_per_dim += 1\n",
    "            # calculate average rank across seeds\n",
    "            avrg_rank_history.append(np.mean(ranks_per_dim,axis=0)) \n",
    "            lowest_average_rank_ind_per_dim = np.argmin(np.mean(ranks_per_dim,axis=0),axis=0)\n",
    "            \n",
    "    \n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "            significantly_worse_at_end = np.zeros(len(METHODS))\n",
    "\n",
    "            for method in range(len(METHODS)): # check wether the respective algorithms are statistically significantly worse then the best one\n",
    "                if method == lowest_average_rank_ind_per_dim[-1]:\n",
    "                    significantly_worse_at_end[method] = 0\n",
    "                else:\n",
    "                    best_results = all_y_data_per_dim[:,lowest_average_rank_ind_per_dim[-1],-1]\n",
    "                    method_results = all_y_data_per_dim[:,method,-1] \n",
    "                    test_result = scipy.stats.wilcoxon(best_results, method_results, alternative='less', axis=0, nan_policy='propagate', keepdims=False)\n",
    "                    significantly_worse_at_end[method] = test_result.pvalue < 0.05 #95 % significance level\n",
    "            significantly_worse_at_end_history.append(copy.deepcopy(significantly_worse_at_end)) \n",
    "            \n",
    "\n",
    "        if CUM_RES:\n",
    "            cum_avrg_best_histories.append(copy.deepcopy(avrg_best_history))\n",
    "            cum_lower_quantiles_histories.append(copy.deepcopy(lower_quantiles_history))\n",
    "            cum_upper_quantiles_histories.append(copy.deepcopy(upper_quantiles_history))\n",
    "            cum_avrg_rank_histories.append(copy.deepcopy(avrg_rank_history))\n",
    "            cum_significantly_worse_at_end_histories.append(copy.deepcopy(significantly_worse_at_end_history))\n",
    "        else:\n",
    "            avrg_best_histories.append(copy.deepcopy(avrg_best_history))\n",
    "            lower_quantiles_histories.append(copy.deepcopy(lower_quantiles_history))\n",
    "            upper_quantiles_histories.append(copy.deepcopy(upper_quantiles_history))   \n",
    "            avrg_rank_histories.append(copy.deepcopy(avrg_rank_history))\n",
    "\n",
    "            significantly_worse_at_end_histories.append(copy.deepcopy(significantly_worse_at_end_history))\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_last_non_nan_value(data_array):\n",
    "    \"\"\"\n",
    "    Returns the last non-NaN entry of a 1D NumPy array.\n",
    "    If all entries are NaN, returns None.\n",
    "    \"\"\"\n",
    "    if data_array is None or len(data_array) == 0:\n",
    "        return None\n",
    "    not_nan_indices = np.where(~np.isnan(data_array))[0]\n",
    "    if len(not_nan_indices) == 0:\n",
    "        return None\n",
    "    return data_array[not_nan_indices[-1]]\n",
    "\n",
    "def gather_all_medians(median_results):\n",
    "    \"\"\"\n",
    "    Go through the entire median_results structure and collect all\n",
    "    last non-NaN values in one list (ignoring None).\n",
    "    This helps us find the min/max for color mapping.\n",
    "    \"\"\"\n",
    "    all_medians = []\n",
    "    for func_idx in range(len(median_results)):\n",
    "        for dim_idx in range(len(median_results[func_idx])):\n",
    "            for method_idx in range(len(median_results[func_idx][dim_idx])):\n",
    "                val = get_last_non_nan_value(median_results[func_idx][dim_idx][method_idx])\n",
    "                if val is not None:\n",
    "                    all_medians.append(val)\n",
    "    return all_medians\n",
    "\n",
    "def interpolate_color(value, vmin, vmax):\n",
    "    \"\"\"\n",
    "    Linearly maps 'value' in [vmin, vmax] to a color ranging from (0,0,1) [blue]\n",
    "    to (1,0.65,0) [orange]. If vmin == vmax, returns blue (0,0,1).\n",
    "\n",
    "    Return tuple (R, G, B) in [0, 1].\n",
    "    \"\"\"\n",
    "    if vmin == vmax:\n",
    "        # If there's only one median value in the entire dataset,\n",
    "        # just return blue.\n",
    "        return (0.0, 0.0, 1.0)\n",
    "\n",
    "    # Normalize to [0,1]\n",
    "    ratio = (value - vmin) / (vmax - vmin)\n",
    "    # (R,G,B) = (0,0,1) -> (1,0.65,0)\n",
    "    R = 0.0 + ratio * (1.0 - 0.0)       # 0 -> 1\n",
    "    G = 0.0 + ratio * (0.65 - 0.0)      # 0 -> 0.65\n",
    "    B = 1.0 - ratio * (1.0 - 0.0)       # 1 -> 0\n",
    "    return (R, G, B)\n",
    "\n",
    "def color_text_with_median(median_val, text, overall_min, overall_max):\n",
    "    \"\"\"\n",
    "    Wrap the given 'text' in a latex color command based on the median_val,\n",
    "    mapping from overall_min (blue) to overall_max (orange).\n",
    "\n",
    "    If median_val is None, returns \"-\" without color.\n",
    "    \"\"\"\n",
    "    if median_val is None:\n",
    "        return \"-\"\n",
    "\n",
    "    (r, g, b) = interpolate_color(median_val, overall_min, overall_max)\n",
    "    # Format color with 2 decimals in the rgb specification\n",
    "    return rf\"\\textcolor[rgb]{{{r:.2f},{g:.2f},{b:.2f}}}{{{text}}}\"\n",
    "\n",
    "###############################################################################\n",
    "# Main table creation function\n",
    "###############################################################################\n",
    "\n",
    "def create_latex_table(function_names, DIMS, METHODS,\n",
    "                       median_results_list,\n",
    "                       lower_quantiles_list,\n",
    "                       upper_quantiles_list,\n",
    "                       data_type_flag,\n",
    "                       use_multirow=False,\n",
    "                       significantly_worse_at_end = None,\n",
    "                       expected_ls = None):\n",
    "\n",
    "\n",
    "    # 1) Find global min/max among all medians for color mapping\n",
    "    all_medians = gather_all_medians(median_results_list)\n",
    "    if len(all_medians) == 0:\n",
    "        # If no medians at all, handle gracefully\n",
    "        global_min = 0\n",
    "        global_max = 1\n",
    "    else:\n",
    "        global_min = min(all_medians)\n",
    "        global_max = max(all_medians)\n",
    "\n",
    "    # find best performing algorithm for each dimension and complexity:\n",
    "\n",
    "    # Start building the table\n",
    "    latex_table = []\n",
    "    latex_table.append(r\"\\begin{table}\")\n",
    "    latex_table.append(r\"\\centering\")\n",
    "\n",
    "    # Define columns depending on use_multirow\n",
    "    if use_multirow:\n",
    "        # Function + Method + columns for each dimension\n",
    "        \n",
    "        col_spec = \"l l \" + \" \".join([\"c\"] * len(DIMS))\n",
    "        header = [\"Complexity\", \"Method\"] + [f\"$d = {d}$\" for d in DIMS]\n",
    "    else:\n",
    "        # Method + columns for each dimension; function name per \\multicolumn row\n",
    "        col_spec = \"l \" + \" \".join([\"c\"] * len(DIMS))\n",
    "        header = [\"Method\"] + [f\"Dim {d}\" for d in DIMS]\n",
    "\n",
    "    latex_table.append(r\"\\begin{tabular}{\" + col_spec + \"}\")\n",
    "    latex_table.append(r\"\\hline\")\n",
    "    latex_table.append(\" & \".join(header) + r\" \\\\\")\n",
    "    latex_table.append(r\"\\hline\")\n",
    "\n",
    "    # Helper to build the cell with median, upper, lower\n",
    "    def build_quantile_string(func_id, dim_id, method_id,data_type_flag):\n",
    "        # Last non-NaN from each array\n",
    "\n",
    "        if not significantly_worse_at_end == None:\n",
    "            significantly_worse = significantly_worse_at_end[func_id][dim_id][method_id] \n",
    "        else:\n",
    "            significantly_worse = True\n",
    "\n",
    "        val_median = get_last_non_nan_value(median_results_list[func_id][dim_id][method_id])\n",
    "        if not data_type_flag == 'rank':  \n",
    "            val_lower  = get_last_non_nan_value(lower_quantiles_list[func_id][dim_id][method_id])\n",
    "            val_upper  = get_last_non_nan_value(upper_quantiles_list[func_id][dim_id][method_id])\n",
    "\n",
    "        if val_median is None:\n",
    "            # No valid median => no data for this cell\n",
    "            return \"-\"\n",
    "\n",
    "        if data_type_flag == 'cumulative': \n",
    "            # Format all as 2 decimals\n",
    "            median_str = f\"{val_median:.0f}\"\n",
    "            lower_str  = \"-\" if (val_lower is None) else f\"{val_lower:.0f}\"\n",
    "            upper_str  = \"-\" if (val_upper is None) else f\"{val_upper:.0f}\"\n",
    "\n",
    "        elif data_type_flag == 'simple': \n",
    "            # Format all as 2 decimals\n",
    "            median_str = f\"{val_median:.2f}\"\n",
    "            lower_str  = \"-\" if (val_lower is None) else f\"{val_lower:.2f}\"\n",
    "            upper_str  = \"-\" if (val_upper is None) else f\"{val_upper:.2f}\"\n",
    "\n",
    "        if data_type_flag == 'rank':  \n",
    "            median_str = f\"{val_median:.1f}\"\n",
    "            if not significantly_worse:\n",
    "                cell_text = r\"$\\boldsymbol{\" + median_str + r\"}$\"\n",
    "            else:\n",
    "                cell_text = r\"$\" + median_str + r\"$\"\n",
    "        else:\n",
    "            # Combine into median^(upper)_(lower)\n",
    "            if not significantly_worse:\n",
    "                raise NotImplementedError\n",
    "            else:\n",
    "                cell_text = f\"${median_str}^{{{upper_str}}}_{{{lower_str}}}$\"\n",
    "        if data_type_flag == \"rank\":\n",
    "            colored_text = cell_text\n",
    "        else:\n",
    "            # Now color the entire cell text based on the median\n",
    "            colored_text = color_text_with_median(\n",
    "                val_median, cell_text, global_min, global_max\n",
    "            )\n",
    "        return colored_text\n",
    "\n",
    "    # Build the rows\n",
    "    for func_idx, func_name in enumerate(function_names):\n",
    "        if use_multirow:\n",
    "            # Each function name in a multirow for all methods\n",
    "            num_methods = len(METHODS)\n",
    "            \n",
    "            col_spec = \"l l \" + \" \".join([\"c\"] * len(DIMS))\n",
    "            header = [\"Complexity\", \"Method\"] + [f\"$d = {d}$\" for d in DIMS]\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "            for method_idx, method_name in enumerate(METHODS):\n",
    "                row_cells = []\n",
    "                for dim_idx in range(len(DIMS)):\n",
    "                    row_cells.append(build_quantile_string(func_idx, dim_idx, method_idx,data_type_flag))\n",
    "\n",
    "                if method_idx == 0:\n",
    "                    # First row => use multirow for the function name\n",
    "                    if expected_ls is None:\n",
    "                        row_items = [\n",
    "                            rf\"\\multirow{{{num_methods}}}{{*}}{{{func_name}}}\",\n",
    "                            method_name\n",
    "                        ] + row_cells\n",
    "                    else:\n",
    "                        row_items = [\n",
    "                            rf\"\\multirow{{{num_methods+1}}}{{*}}{{{func_name}}}\",\n",
    "                            method_name\n",
    "                        ] + row_cells                        \n",
    "                else:\n",
    "                    # Subsequent rows => empty cell for function name\n",
    "                    row_items = [\" \", method_name] + row_cells\n",
    "\n",
    "\n",
    "                latex_table.append(\" & \".join(row_items) + r\" \\\\\")\n",
    "            \n",
    "            if not expected_ls is None:\n",
    "                row_items = [\" \",r\"$\\mathbb{E}[p(l)]$\"]+[f\"${expected_ls[func_idx,d]}$\" for d in range(len(DIMS))]\n",
    "                latex_table.append(\" & \".join(row_items) + r\" \\\\\")    \n",
    "\n",
    "            latex_table.append(r\"\\hline\")\n",
    "\n",
    "        else:\n",
    "            # Put function name across the top via multicolumn (no bold)\n",
    "            latex_table.append(\n",
    "                rf\"\\multicolumn{{{len(DIMS) + 1}}}{{c}}{{{func_name}}} \\\\\"\n",
    "            )\n",
    "            latex_table.append(r\"\\hline\")\n",
    "\n",
    "            for method_idx, method_name in enumerate(METHODS):\n",
    "                row_entries = [method_name]\n",
    "                for dim_idx in range(len(DIMS)):\n",
    "                    row_entries.append(build_quantile_string(func_idx, dim_idx, method_idx))\n",
    "                latex_table.append(\" & \".join(row_entries) + r\" \\\\\")\n",
    "            latex_table.append(r\"\\hline\")\n",
    "\n",
    "    latex_table.append(r\"\\end{tabular}\")\n",
    "    if data_type_flag == 'rank':\n",
    "        latex_table.append(r\"\\caption{Your Caption Here}\")\n",
    "        latex_table.append(r\"\\label{tab:ranks_results_table}\")\n",
    "    else:\n",
    "        latex_table.append(r\"\\caption{Your Caption Here}\")\n",
    "        latex_table.append(r\"\\label{tab:your_label}\")\n",
    "    latex_table.append(r\"\\end{table}\")\n",
    "\n",
    "    return \"\\n\".join(latex_table)\n",
    "\n",
    "\n",
    "if LESLOGEIONLY:\n",
    "    prefix = \"LES_EI_\"\n",
    "else:\n",
    "    prefix = \"\"\n",
    "\n",
    "\n",
    "\n",
    "# Generate the LaTeX table code with multirow=True\n",
    "latex_code_multirow = create_latex_table(function_names, DIMS, METHOD_TABLE_NAMES, avrg_best_histories,[],[],\n",
    "                                         data_type_flag = \"rank\",\n",
    "                                         use_multirow=True,\n",
    "                                         significantly_worse_at_end = significantly_worse_at_end_histories)#,\n",
    "                                        \n",
    "if WITHIN_MDL:\n",
    "    with open(prefix+\"ranks_results_table_within_mdl.tex\", \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex_code_multirow)   \n",
    "else:\n",
    "    with open(prefix+\"ranks_results_table.tex\", \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex_code_multirow)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Generate the LaTeX table code with multirow=True\n",
    "latex_code_multirow = create_latex_table(function_names, DIMS, METHOD_TABLE_NAMES, cum_avrg_best_histories,[],[],\n",
    "                                         data_type_flag = \"rank\",\n",
    "                                         use_multirow=True,\n",
    "                                         significantly_worse_at_end = cum_significantly_worse_at_end_histories) #,\n",
    "                                         #expected_ls = excpected_ls)\n",
    "\n",
    "if WITHIN_MDL:\n",
    "    with open(prefix+\"cum_ranks_results_table_within_mdl.tex\", \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex_code_multirow)   \n",
    "else:\n",
    "    with open(prefix+\"cum_ranks_results_table_table.tex\", \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex_code_multirow)\n",
    "        \n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
