{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!-- ## Plots -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch\n",
    "import itertools\n",
    "\n",
    "# import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "pd.set_option('display.max_rows', 200)\n",
    "pd.set_option('display.max_columns', 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.style as style\n",
    "\n",
    "# Global seaborn theme: white background with grid lines.\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# Fixed colour list. The plotting cells take the first len(all_models) entries\n",
    "# and pair them, in order, with the model labels. The first nine entries are\n",
    "# light/medium/dark triples (reds, greens, blues).\n",
    "paired_palette_hex = [\n",
    "    '#fb9a99', '#ef5350', '#e31a1c',\n",
    "    '#b2df8a', '#7cb342', '#33a02c',\n",
    "    '#a6cee3', '#039be5', '#1f78b4',\n",
    "    '#fdbf6f', '#ff7f00',\n",
    "    '#cab2d6', '#6a3d9a',\n",
    "    '#ffff99', '#b15928',\n",
    "]\n",
    "\n",
    "# Raw model identifiers (values of the `Model` column in the result files)\n",
    "# mapped to the labels shown in figures. prep_df applies this mapping, and the\n",
    "# plotting cells select models by these exact label strings, so a label must be\n",
    "# changed in both places at once.\n",
    "map_model_name = {\n",
    "    'mse': r'MSE ($\\mathcal{D}_\\text{O})$',\n",
    "    'mse_T': r'MSE ($\\mathcal{D}_\\text{T})$',\n",
    "    'mse_is': r'MSE ($\\mathcal{D}_\\text{IR})$',\n",
    "    'cspo+': r'SPO-RC+ ($\\mathcal{D}_\\text{O})$',\n",
    "    'cspo+_T': r'SPO-RC+ ($\\mathcal{D}_\\text{T})$',\n",
    "    'cspo+_is': r'SPO-RC+ ($\\mathcal{D}_\\text{IR})$',\n",
    "    # Space before '(' added so this label is formatted like all the others.\n",
    "    'cspo+_ws': r'SPO-RC+_WS ($\\mathcal{D}_\\text{O})$',\n",
    "    'cspo+_ws_T': r'SPO-RC+_WS ($\\mathcal{D}_\\text{T})$',\n",
    "    'cspo+_ws_is': r'SPO-RC+_WS ($\\mathcal{D}_\\text{IR})$',\n",
    "    'cspo+_mse': r'SPO-RC+_withMSE ($\\mathcal{D}_\\text{O})$',\n",
    "    'cspo+_mse_T': r'SPO-RC+_withMSE ($\\mathcal{D}_\\text{T})$',\n",
    "    'cspo+_mse_is': r'SPO-RC+_withMSE ($\\mathcal{D}_\\text{IR})$',\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prep_df(dir_name,prefix_name ,arg_dict):\n",
    "    # Build list of parameter combinations\n",
    "    param_combos = [dict(zip(arg_dict.keys(), vals)) for vals in zip(*arg_dict.values())]\n",
    "\n",
    "    # Enumerate all parameter combinations (Cartesian product)\n",
    "    param_combos = [\n",
    "        dict(zip(arg_dict.keys(), vals))\n",
    "        for vals in itertools.product(*arg_dict.values())\n",
    "    ]\n",
    "\n",
    "\n",
    "    # Read each file, tag with its parameters, and collect dataframes\n",
    "    df_list = []\n",
    "    for params in param_combos:\n",
    "        # Build filename from params\n",
    "        param_str = \"_\".join(f\"{k}={v}\" for k, v in params.items())\n",
    "        fname = f\"{prefix_name}_{param_str}.txt\"\n",
    "        # print(f\"Loading {fname}\")\n",
    "        fpath = os.path.join(dir_name, fname)\n",
    "        if os.path.exists(fpath):\n",
    "            df = pd.read_csv(fpath, delimiter=\",\", header=0)\n",
    "            # Add parameter columns to df\n",
    "            for k, v in params.items():\n",
    "                df[k] = v\n",
    "            df_list.append(df)\n",
    "        else:\n",
    "            print(f\"File {fpath} does not exist\")\n",
    "\n",
    "    # Concatenate into one unified DataFrame and clean column names\n",
    "    if len(df_list) > 0:\n",
    "        results = pd.concat(df_list, ignore_index=True)\n",
    "\n",
    "    results.columns = results.columns.str.strip()\n",
    "    results.Model = results.Model.str.strip()\n",
    "    # Sort the DataFrame based on the model column\n",
    "    model_order = ['mse',\n",
    "                    'mse_T',\n",
    "                    'mse_is',\n",
    "                    'cspo+',\n",
    "                    'cspo+_T',\n",
    "                    'cspo+_is',\n",
    "                    'cspo+_ws',\n",
    "                    'cspo+_ws_T',\n",
    "                    'cspo+_ws_is',\n",
    "                    'cspo+_mse',\n",
    "                    'cspo+_mse_T',\n",
    "                    'cspo+_mse_is']\n",
    "    \n",
    "    results = results.sort_values(by='Model', key=lambda x: x.map({model: i for i, model in enumerate(model_order)}))\n",
    "\n",
    "    # print(f'Model list before applying the map {results.Model.unique()}')\n",
    "    results['Model'] = results['Model'].map(map_model_name)\n",
    "    # print(f'Model list After applying the map {results.Model.unique()}')\n",
    "    # Change infeasible count to percentage\n",
    "    infeasible_cols = ['nominal_infeas_count_pred', 'robust_nominal_infeas_count', 'robust_infeas_count']\n",
    "    infeasible_percent_cols = [col_name + \"_percent\" for col_name in infeasible_cols]\n",
    "    results[infeasible_percent_cols] = results[infeasible_cols].div(results['test_num_data'],axis=0)\n",
    "    results[infeasible_percent_cols] = (results[infeasible_percent_cols] * 100).round(2)\n",
    "    results['NormTest'] = (results['nominal_mean'] - results['robust_mean'])/ results['nominal_mean']\n",
    "    # results[infeasible_percent_cols]\n",
    "\n",
    "\n",
    "    # combine solve_ratio and lr into a single label column for plotting\n",
    "    # results['combo'] = results['solve_ratio'].astype(str) + '_' + results['lr'].astype(str)\n",
    "    results['Model_sr'] = results['Model'].astype(str) + '_' + results['solve_ratio'].astype(str)\n",
    "\n",
    "    filtered_results = results[results[\"epoch\"]== -1]\n",
    "    results = results[results['epoch'] != -1]\n",
    "    # Group by each parameter information\n",
    "    group_cols = [\"num_feat\",\"num_item\",\"weight_deg\",\"noise_width\",\"cost_deg\",\"Model\",'solve_ratio']\n",
    "    cols_of_interest = [\"alpha\",\"test_num_data\",\"test_regret\", 'nominal_mean',\n",
    "        'nominal_std', 'nominal_mean_pred', 'nominal_std_pred',\n",
    "        'nominal_infeas_count_pred', 'robust_nominal_mean',\n",
    "        'robust_nominal_std', 'robust_nominal_infeas_count', 'robust_mean',\n",
    "        'robust_std', 'robust_infeas_count'] + infeasible_percent_cols\n",
    "    grouped_results = filtered_results.groupby(group_cols)[cols_of_interest].mean()\n",
    "    \n",
    "\n",
    "    return results, filtered_results, grouped_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualizing_regret_results_aistats(best_results, weight_deg, alpha, model_color_dict, y_axis,log=False):\n",
    "\n",
    "    \"\"\"\n",
    "    Plotting functionfor different num_feat and capacity\n",
    "    \"\"\"\n",
    "\n",
    "    font_size = 50\n",
    "\n",
    "    # Filter results based on weight_deg\n",
    "    plot_results = best_results[(best_results[\"weight_deg\"] == weight_deg) & (best_results[\"alpha\"] == alpha)]\n",
    "    # Get unique values for training num\n",
    "    num_data_list = plot_results[\"num_data\"].unique().tolist()\n",
    "    \n",
    "    # Determine the number of subplots based on the number of unique num_feat and noise_width\n",
    "    num_subplots = len(num_data_list)\n",
    "    rows = int(num_subplots**0.5)\n",
    "    # cols = (num_subplots // rows) + (num_subplots % rows > 0)\n",
    "    # Set dpi to 300\n",
    "    plt.rcParams['figure.dpi'] = 300\n",
    "\n",
    "    # Create a figure and axes\n",
    "    fig, axes = plt.subplots(rows, 1, figsize=(20, 12))  # Adjust figsize as needed\n",
    "    if rows>1:\n",
    "        axes = axes.flatten()\n",
    "    else:\n",
    "        axes= [axes]\n",
    "    # Create a boxplot for each combination of num_feat and noise_width\n",
    "    for i, num_data in enumerate(num_data_list):\n",
    "        idx = i \n",
    "        ax = axes[idx]\n",
    "        \n",
    "        # Filter the DataFrame for the current num_feat and noise_width\n",
    "        boxplot_df = plot_results[(plot_results[\"num_data\"] == num_data)]\n",
    "        # Plot the boxplot\n",
    "        hue_order = all_models\n",
    "        \n",
    "        \n",
    "        # sns.boxplot(x='cost_deg', y=y_axis, hue='Model', data=boxplot_df, ax=ax, showfliers=False, palette=model_color_dict, saturation=0.75, width=0.6)\n",
    "        sns.boxplot(x='cost_deg', y=y_axis, hue='Model',hue_order=hue_order, data=boxplot_df, ax=ax, showfliers=False, palette=model_color_dict, saturation=0.75, width=0.6)\n",
    "        if log:\n",
    "            ax.set_yscale('log')\n",
    "        \n",
    "        # Set the title and labels\n",
    "        ax.set_title(f'Training Size: {num_data}, Test Size: {boxplot_df[\"test_num_data\"].unique()[0]}', fontsize=font_size)\n",
    "        ax.set_xlabel('Polynomial Degree', fontsize=font_size)\n",
    "        if y_axis == 'test_regret':\n",
    "            if log:\n",
    "                ax.set_ylabel('NormSPORCTest (log scale)', fontsize=font_size)\n",
    "            else:\n",
    "                ax.set_ylabel('NormSPORCTest ', fontsize=font_size)\n",
    "        elif y_axis == 'test_cspop_loss':\n",
    "            ax.set_ylabel('Test CSPO+ Loss (log scale)', fontsize=font_size)\n",
    "        elif y_axis == 'test_mse_loss':\n",
    "            ax.set_ylabel('Test MSE Loss (log scale)', fontsize=font_size)\n",
    "        else:\n",
    "            ax.set_ylabel(f'{y_axis} (log scale)', fontsize=font_size)\n",
    "        \n",
    "        # Increase the font size of xtick and ytick labels\n",
    "        ax.set_xticklabels(ax.get_xticklabels(), fontsize=font_size)\n",
    "        ax.set_yticklabels(ax.get_yticklabels(), fontsize=font_size)\n",
    "        ax.legend_.remove()\n",
    "        \n",
    "        # Turn off axis for any empty subplots\n",
    "        if idx >= num_subplots:\n",
    "            ax.axis('off')\n",
    "            \n",
    "    # Create a single legend for the entire figure with more space at the top\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    \n",
    "    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=len(hue_order)//2, fontsize=40, frameon=False)\n",
    "    \n",
    "    # Save figure with bbox_inches='tight' to ensure nothing is cut off\n",
    "    plt.savefig(f'{dir_name}/{prefix_name}_{y_axis}_metric.jpeg', bbox_inches='tight', dpi=300)\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualizing_regret_results_per_sr(best_results, weight_deg, alpha,cost_deg, model_color_dict, y_axis):\n",
    "\n",
    "    \"\"\"\n",
    "    Plotting functionfor different num_feat and capacity\n",
    "    \"\"\"\n",
    "\n",
    "    font_size = 50\n",
    "\n",
    "    # Filter results based on weight_deg\n",
    "    plot_results = best_results[(best_results[\"weight_deg\"] == weight_deg) & (best_results[\"alpha\"] == alpha)& (best_results[\"cost_deg\"] == cost_deg)]\n",
    "    # Get unique values for training num\n",
    "    num_data_list = plot_results[\"num_data\"].unique().tolist()\n",
    "    model_list_here = plot_results.Model.unique()\n",
    "\n",
    "    # Determine the number of subplots based on the number of unique num_feat and noise_width\n",
    "    num_subplots = len(num_data_list)\n",
    "    rows = int(num_subplots**0.5)\n",
    "    # cols = (num_subplots // rows) + (num_subplots % rows > 0)\n",
    "    # Set dpi to 300\n",
    "    plt.rcParams['figure.dpi'] = 300\n",
    "\n",
    "    # Create a figure and axes\n",
    "    fig, axes = plt.subplots(rows, 1, figsize=(20, 12))  # Adjust figsize as needed\n",
    "    if rows>1:\n",
    "        axes = axes.flatten()\n",
    "    else:\n",
    "        axes= [axes]\n",
    "    # Create a boxplot for each combination of num_feat and noise_width\n",
    "    for i, num_data in enumerate(num_data_list):\n",
    "        idx = i \n",
    "        ax = axes[idx]\n",
    "        \n",
    "        # Filter the DataFrame for the current num_feat and noise_width\n",
    "        boxplot_df = plot_results[(plot_results[\"num_data\"] == num_data)]\n",
    "        \n",
    "        # sns.boxplot(x='cost_deg', y=y_axis, hue='Model', data=boxplot_df, ax=ax, showfliers=False, palette=model_color_dict, saturation=0.75, width=0.6)\n",
    "        sns.boxplot(x='solve_ratio', y=y_axis, hue='Model', data=boxplot_df, ax=ax, showfliers=False, palette=model_color_dict, saturation=0.75, width=0.6)\n",
    "        \n",
    "        # Set the title and labels\n",
    "        ax.set_title(f'Training Size: {num_data}, Test Size: {boxplot_df[\"test_num_data\"].unique()[0]}', fontsize=font_size)\n",
    "        ax.set_xlabel('Probability of Solving', fontsize=font_size)\n",
    "        if y_axis == 'test_regret':\n",
    "            ax.set_ylabel('NormSPORCTest', fontsize=font_size)\n",
    "        elif y_axis == 'test_cspop_loss':\n",
    "            ax.set_ylabel('Test CSPO+ Loss', fontsize=font_size)\n",
    "        elif y_axis == 'test_mse_loss':\n",
    "            ax.set_ylabel('Test MSE Loss', fontsize=font_size)\n",
    "        elif y_axis == 'training_time':\n",
    "            ax.set_ylabel('Training time', fontsize=font_size)\n",
    "        else:\n",
    "            ax.set_ylabel(y_axis, fontsize=font_size)\n",
    "        \n",
    "        # Increase the font size of xtick and ytick labels\n",
    "        ax.set_xticklabels(ax.get_xticklabels(), fontsize=font_size)\n",
    "        ax.set_yticklabels(ax.get_yticklabels(), fontsize=font_size)\n",
    "        ax.legend_.remove()\n",
    "        \n",
    "        # Turn off axis for any empty subplots\n",
    "        if idx >= num_subplots:\n",
    "            ax.axis('off')\n",
    "            \n",
    "    # Create a single legend for the entire figure with more space at the top\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    \n",
    "    fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=len(model_list_here), fontsize=40, frameon=False)\n",
    "    # plt.title(' under l1-score confomity metric', pad=20)\n",
    "    \n",
    "    # Save figure with bbox_inches='tight' to ensure nothing is cut off\n",
    "    plt.savefig(f'{dir_name}/sr_{prefix_name}_{y_axis}_metric.jpeg', bbox_inches='tight', dpi=300)\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Knapsack experiment: load the sweep results and plot the test regret.\n",
    "# dir_name, prefix_name and all_models are also read as globals by the\n",
    "# plotting function, so their names must stay as they are.\n",
    "pb_type = 'knapsack'\n",
    "dir_name = './results/'\n",
    "\n",
    "# Problem size used in the result file names.\n",
    "feat, var, const = 10, 5, 1\n",
    "\n",
    "# Swept parameters; prep_df looks for one file per combination, and the key\n",
    "# order fixes the order of the parameters in the file name.\n",
    "arg_dict = {\n",
    "    'CSPO': ['False', 'True'],\n",
    "    'solve_ratio': [1.0],\n",
    "}\n",
    "\n",
    "prefix_name = f\"{pb_type}_results_num_data=1000_feat={feat}_num_var={var}_num_const={const}\"\n",
    "\n",
    "results, filtered_results, grouped_results = prep_df(dir_name, prefix_name, arg_dict)\n",
    "\n",
    "# Models shown in the figure, in legend order.\n",
    "all_models = [\n",
    "    r'MSE ($\\mathcal{D}_\\text{O})$',\n",
    "    r'MSE ($\\mathcal{D}_\\text{T})$',\n",
    "    r'MSE ($\\mathcal{D}_\\text{IR})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{O})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{T})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{IR})$',\n",
    "]\n",
    "\n",
    "# Pair every model with the palette colour at the same position.\n",
    "colors = paired_palette_hex[:len(all_models)]\n",
    "model_color_dict = {model: color for model, color in zip(all_models, colors)}\n",
    "\n",
    "# Setting to plot.\n",
    "weight_deg = 2\n",
    "alpha = 0.2\n",
    "\n",
    "# Keep only the selected models; copy so the slice is an independent frame.\n",
    "best_results_temp = filtered_results.loc[filtered_results['Model'].isin(all_models)].copy()\n",
    "visualizing_regret_results_aistats(\n",
    "    best_results_temp,\n",
    "    weight_deg,\n",
    "    alpha,\n",
    "    model_color_dict=model_color_dict,\n",
    "    y_axis='test_regret',\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cover experiment: load the sweep results and plot the test regret.\n",
    "# dir_name, prefix_name and all_models are also read as globals by the\n",
    "# plotting function, so their names must stay as they are.\n",
    "pb_type = 'cover'\n",
    "dir_name = './results/'\n",
    "\n",
    "# Problem size used in the result file names.\n",
    "feat, var, const = 10, 5, 2\n",
    "\n",
    "# Swept parameters; prep_df looks for one file per combination, and the key\n",
    "# order fixes the order of the parameters in the file name.\n",
    "arg_dict = {\n",
    "    'CSPO': ['False', 'True'],\n",
    "    'lr': [1e-2],\n",
    "    'solve_ratio': [1.0],\n",
    "}\n",
    "\n",
    "prefix_name = f\"{pb_type}_results_num_data=1000_feat={feat}_num_var={var}_num_const={const}\"\n",
    "\n",
    "results, filtered_results, grouped_results = prep_df(dir_name, prefix_name, arg_dict)\n",
    "\n",
    "# Models shown in the figure, in legend order.\n",
    "all_models = [\n",
    "    r'MSE ($\\mathcal{D}_\\text{O})$',\n",
    "    r'MSE ($\\mathcal{D}_\\text{T})$',\n",
    "    r'MSE ($\\mathcal{D}_\\text{IR})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{O})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{T})$',\n",
    "    r'SPO-RC+ ($\\mathcal{D}_\\text{IR})$',\n",
    "]\n",
    "\n",
    "# Pair every model with the palette colour at the same position.\n",
    "colors = paired_palette_hex[:len(all_models)]\n",
    "model_color_dict = {model: color for model, color in zip(all_models, colors)}\n",
    "\n",
    "# Setting to plot.\n",
    "weight_deg = 2\n",
    "alpha = 0.1\n",
    "\n",
    "# Keep only the selected models; copy so the slice is an independent frame.\n",
    "best_results_temp = filtered_results.loc[filtered_results['Model'].isin(all_models)].copy()\n",
    "visualizing_regret_results_aistats(\n",
    "    best_results_temp,\n",
    "    weight_deg,\n",
    "    alpha,\n",
    "    model_color_dict=model_color_dict,\n",
    "    y_axis='test_regret',\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "l2_conformity",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
