{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71551323",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append(os.path.abspath(\"../../..\")) \n",
    "\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import commons.semantic_loss as semloss\n",
    "\n",
    "\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{lmodern}'\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "# Layout of each results csv: NUM_MAT materials x MAX_EPOCHS epochs x MAX_N rows per epoch\n",
    "NUM_MAT = 100\n",
    "MAX_EPOCHS = 200\n",
    "MAX_N = 128\n",
    "# Total number of data lines in a csv\n",
    "NUM_LINES = NUM_MAT * MAX_EPOCHS * MAX_N\n",
    "\n",
    "# [NUM_MAT] MUST be divisible by [N_THREADS]\n",
    "N_THREADS = 10\n",
    "assert NUM_MAT % N_THREADS == 0, \"NUM_MAT must be divisible by N_THREADS\"\n",
    "\n",
    "# The loss that is currently considered and plotted\n",
    "experiment_loss = semloss.SemanticExperiment.PERIODIC_2\n",
    "\n",
    "# exist_ok: no error when the folder already exists, so the cell can be re-run safely\n",
    "os.makedirs(\"plots/\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b298187",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(os.path.abspath(\"../..\")) \n",
    "from src.utils.data import simulate_material\n",
    "import src.utils.data as data\n",
    "\n",
    "\n",
    "# Only the test split is used in this notebook\n",
    "TEST_SIZE = 100\n",
    "_train_split, _val_split, test_data = data.get_x_y_data(invd_steps=TEST_SIZE, device='cuda:0', val_split=None)\n",
    "x_test, y_test = test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a312ffb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_ranges(N, num_workers, delta):\n",
    "    \"\"\"\n",
    "    Split N items into `num_workers` contiguous chunks and return their row-index bounds.\n",
    "\n",
    "    Each chunk gets N // num_workers items; the last one also takes the remainder.\n",
    "    Item bounds are multiplied by `delta` (rows per item) to obtain row indices.\n",
    "\n",
    "    Returns:\n",
    "        list of (start_row, end_row) tuples, end exclusive, one per worker.\n",
    "    \"\"\"\n",
    "    size = N // num_workers\n",
    "    last = num_workers - 1\n",
    "    return [\n",
    "        ((w * size) * delta, (N if w == last else w * size + size) * delta)\n",
    "        for w in range(num_workers)\n",
    "    ]\n",
    "\n",
    "def parse_array_column(s):\n",
    "    \"\"\"\n",
    "    Converter for the 'Decoded mat' csv column: parse the printed form of a numpy array\n",
    "    (e.g. \"[0.1 0.2 0.3]\") into a float array; the literal \"None\" becomes None.\n",
    "    \"\"\"\n",
    "    if s == \"None\":\n",
    "        return None\n",
    "    return np.fromstring(s.strip(\"[]\\n \"), sep=' ')\n",
    "\n",
    "\n",
    "def open_result_csv(worker_idx, csvfile, start, end) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Load the data rows [start, end) of a results csv (row 0 = first line after the header).\n",
    "\n",
    "    The header line is read on its own and passed as column names, because `skiprows`\n",
    "    also skips it when reading a slice.\n",
    "\n",
    "    Args:\n",
    "        worker_idx: index of the calling worker (not used here).\n",
    "        csvfile: path of the results csv.\n",
    "        start, end: half-open range of data-row indices to load.\n",
    "\n",
    "    Returns:\n",
    "        The selected rows; the 'Decoded mat' column is parsed with `parse_array_column`.\n",
    "    \"\"\"\n",
    "    with open(csvfile, \"r\") as handle:\n",
    "        column_names = handle.readline().strip().split(\",\")\n",
    "\n",
    "    column_dtypes = {\n",
    "        'Lr': np.float32,\n",
    "        'Mat_idx': np.uint16,\n",
    "        'Epochs': np.uint16,\n",
    "        'Trial': np.uint16,\n",
    "        'Simulator loss': np.float32,\n",
    "        'Semantic loss': np.float32,\n",
    "        'Onehot': np.float32,\n",
    "    }\n",
    "\n",
    "    return pd.read_csv(  # type: ignore\n",
    "        csvfile,\n",
    "        header=None,\n",
    "        names=column_names,\n",
    "        dtype=column_dtypes,\n",
    "        converters={'Decoded mat': parse_array_column},\n",
    "        skiprows=start + 1,  # +1 for the header line\n",
    "        nrows=end - start,\n",
    "    )\n",
    "\n",
    "\n",
    "\n",
    "def process_df_srmse(worker_idx, n_mat, df, epochs, n_points, constraint_function, norm_len=2001):\n",
    "    \"\"\"\n",
    "    Summarise, per material and per epoch, the best losses among the first `n_points` points.\n",
    "\n",
    "    `df` must follow the results-csv layout: every material occupies MAX_EPOCHS * MAX_N\n",
    "    consecutive rows and, inside it, every epoch MAX_N consecutive rows. Columns are read by\n",
    "    position: 1 = Mat_idx, 4 = Simulator loss, 6 = Onehot, 7 = Decoded mat (parsed array).\n",
    "\n",
    "    The reported \"SRMSE\" values are sqrt(l2 ** 2 / norm_len) of the best simulator loss l2,\n",
    "    kept as running minima over the epochs processed so far.\n",
    "\n",
    "    Args:\n",
    "        worker_idx: index of the calling worker (unused, kept for debugging).\n",
    "        n_mat: number of materials contained in `df`.\n",
    "        df: dataframe produced by `open_result_csv`.\n",
    "        epochs: number of epochs to scan for each material.\n",
    "        n_points: number of points (T) of each epoch taken into account.\n",
    "        constraint_function: callable mapping a tensor of decoded materials to `(count, mask)`,\n",
    "            where `mask` flags the rows that satisfy the constraint (SAT).\n",
    "        norm_len: divisor applied before the square root; defaults to 2001\n",
    "            (presumably the number of samples of a spectrum - TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        One list per material, holding one tuple per epoch:\n",
    "        (mat_idx, epoch, best overall SRMSE, SAT count of the epoch,\n",
    "         best SAT SRMSE or the 99999999.0 sentinel, onehot of that point or -1.0,\n",
    "         str(best SAT point of the epoch) at the last epoch, \"\" otherwise).\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if `df` does not contain every row required by the layout.\n",
    "    \"\"\"\n",
    "    not_found = 99999999.0  # sentinel for \"no value yet\"; exported unchanged when nothing is SAT\n",
    "\n",
    "    data = []\n",
    "    for mat_idx in range(n_mat):\n",
    "        best_srmse = not_found\n",
    "        best_overall_srmse = not_found\n",
    "        best_onehot = -1.0\n",
    "        epoch_data = []\n",
    "        for epoch in range(epochs):\n",
    "            # First row of material `mat_idx` at epoch `epoch`\n",
    "            cursor = (mat_idx * MAX_EPOCHS * MAX_N) + (epoch * MAX_N)\n",
    "            real_mat_idx = df.iloc[cursor, 1]\n",
    "\n",
    "            # The first n_points rows of that epoch (a slice instead of a list of row indices)\n",
    "            sub_df = df.iloc[cursor:cursor + n_points, :]\n",
    "            if len(sub_df) != n_points:\n",
    "                raise IndexError(f\"rows [{cursor}, {cursor + n_points}) are out of bounds for {len(df)} rows\")\n",
    "\n",
    "            decoded_materials_ = torch.tensor(np.stack(sub_df.iloc[:, 7]))  # type: ignore\n",
    "            count, mask = constraint_function(decoded_materials_)\n",
    "\n",
    "            overall_l2 = sub_df.iloc[:, 4].min()\n",
    "            best_overall_srmse = min(best_overall_srmse, overall_l2 * overall_l2)\n",
    "\n",
    "            dec_point = None\n",
    "            if count > 0:\n",
    "                # Restrict to the SAT rows and pick the one with the lowest simulator loss\n",
    "                sat_df = sub_df.iloc[torch.where(mask)[0].numpy(), :]\n",
    "                best_pos = sat_df.iloc[:, 4].argmin()\n",
    "                l2 = sat_df.iloc[best_pos, 4]\n",
    "                srmse = l2 * l2\n",
    "                dec_point = sat_df.iloc[best_pos, 7]\n",
    "\n",
    "                if srmse < best_srmse:\n",
    "                    best_srmse = srmse\n",
    "                    best_onehot = sat_df.iloc[best_pos, 6]\n",
    "\n",
    "            # dec_point stays None (exported as \"None\") when this epoch has no SAT point.\n",
    "            # NOTE(review): dec_point is the best SAT point of *this* epoch, while best_srmse and\n",
    "            # best_onehot are running minima over all epochs, so the exported point may not be the\n",
    "            # one that achieved the reported best SAT SRMSE - confirm this is intended.\n",
    "            epoch_data.append((\n",
    "                real_mat_idx,\n",
    "                epoch,\n",
    "                np.sqrt(best_overall_srmse / norm_len),\n",
    "                count,\n",
    "                np.sqrt(best_srmse / norm_len) if best_srmse < not_found else best_srmse,\n",
    "                best_onehot,\n",
    "                \"\" if epoch < epochs - 1 else str(dec_point),\n",
    "            ))\n",
    "\n",
    "        data.append(epoch_data)\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d1ddcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from concurrent.futures import ProcessPoolExecutor\n",
    "\n",
    "def worker(worker_idx, filename, start, end):\n",
    "    \"\"\"\n",
    "    Load rows [start, end) of `filename` and summarise them for every tested T.\n",
    "\n",
    "    Args:\n",
    "        worker_idx: index of this worker (forwarded to the helpers).\n",
    "        filename: results csv to read.\n",
    "        start, end: half-open row range, as produced by `split_ranges`.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping T (number of points per epoch) -> output of `process_df_srmse`.\n",
    "    \"\"\"\n",
    "    df = open_result_csv(worker_idx, filename, start, end)\n",
    "\n",
    "    # Every material spans MAX_EPOCHS * MAX_N rows, so derive the material count from the\n",
    "    # assigned range: unlike NUM_MAT // N_THREADS this also covers the larger last chunk\n",
    "    # that split_ranges builds when NUM_MAT is not divisible by N_THREADS.\n",
    "    n_mat_per_worker = (end - start) // (MAX_EPOCHS * MAX_N)\n",
    "    count_function = experiment_loss.get_count_function(num_lay=10, num_mat=7)\n",
    "\n",
    "    N_POINTS = [1, 2, 4, 8, 16, 32, 64, 128]\n",
    "    return {\n",
    "        np_: process_df_srmse(worker_idx, n_mat_per_worker, df, epochs=MAX_EPOCHS, n_points=np_, constraint_function=count_function)\n",
    "        for np_ in N_POINTS\n",
    "    }\n",
    "\n",
    "\n",
    "model_used = \"Gidnet\"\n",
    "csv_name = \"logs/[{}][{}]results.csv\"\n",
    "plot_title = experiment_loss.get_log_filenames()[1]\n",
    "\n",
    "# Row ranges per worker, from material 0 to material [NUM_MAT]:\n",
    "# each material has [Epochs] X [N] lines associated\n",
    "chunks = split_ranges(NUM_MAT, num_workers = N_THREADS, delta = MAX_EPOCHS * MAX_N)\n",
    "\n",
    "N_POINTS = [1, 2, 4, 8, 16, 32, 64, 128]\n",
    "results_noloss = {t: [] for t in N_POINTS}\n",
    "results_semloss = {t: [] for t in N_POINTS}\n",
    "\n",
    "with ProcessPoolExecutor(max_workers=N_THREADS) as executor:\n",
    "    # Queue the jobs of both runs up front so that they share the pool\n",
    "    jobs_noloss = [executor.submit(worker, w, csv_name.format(\"\", \"No_loss\"), lo, hi) for w, (lo, hi) in enumerate(chunks)]\n",
    "    jobs_semloss = [executor.submit(worker, w, csv_name.format(experiment_loss.get_log_filenames()[0], \"Sem_loss\"), lo, hi) for w, (lo, hi) in enumerate(chunks)]\n",
    "\n",
    "    # Gather in submission order, so the materials keep the order of the csv\n",
    "    for jobs, target in ((jobs_noloss, results_noloss), (jobs_semloss, results_semloss)):\n",
    "        for job in jobs:\n",
    "            for t, rows in job.result().items():\n",
    "                target[t].extend(rows)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793985e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results_numpy(results, n_points, exclude_decoded_mat=True):\n",
    "    \"\"\"\n",
    "    Stack the results collected for T = `n_points` into a (materials, epochs, fields) array.\n",
    "\n",
    "    With `exclude_decoded_mat`, field 6 (the decoded material string) is dropped so that the\n",
    "    remaining six fields can be cast to float32; otherwise the raw stacked array is returned.\n",
    "    \"\"\"\n",
    "    stacked = np.stack(results[n_points])\n",
    "    if not exclude_decoded_mat:\n",
    "        return stacked\n",
    "    # Keep fields 0-5 only, so that every remaining value is numeric\n",
    "    return stacked[:, :, :6].astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66dca170",
   "metadata": {},
   "outputs": [],
   "source": [
    "semloss_ = get_results_numpy(results_semloss, n_points=128, exclude_decoded_mat=False)\n",
    "noloss_ = get_results_numpy(results_noloss, n_points=128, exclude_decoded_mat=False)\n",
    "\n",
    "\n",
    "def collect_sat_materials(results_np):\n",
    "    \"\"\"\n",
    "    For each material, take the best generated point that satisfies the constraint (out of T=128,\n",
    "    based on reconstruction loss) as exported at the last epoch, so its real srmse can be computed.\n",
    "\n",
    "    Returns:\n",
    "        list of (mat_idx, decoded material tensor, NN srmse) for the materials that have such a point.\n",
    "    \"\"\"\n",
    "    materials = []\n",
    "    # Fields: 0 = mat_idx, 4 = best constrained SRMSE, 6 = decoded point (only filled at the last epoch)\n",
    "    for idx, nn_srmse, mat in results_np[:, -1, [0, 4, 6]]:\n",
    "        # mat == \"None\" if no points satisfy the constraint\n",
    "        if mat != \"None\" and mat != \"\":\n",
    "            materials.append((idx, torch.tensor(parse_array_column(mat)), nn_srmse))\n",
    "    return materials\n",
    "\n",
    "\n",
    "def simulated_mse(materials):\n",
    "    \"\"\"\n",
    "    Simulate every decoded material and compare it with its target in `y_test`.\n",
    "\n",
    "    Returns:\n",
    "        (list of (mat_idx, simulated output, NN srmse), list of per-material MSE as floats).\n",
    "    \"\"\"\n",
    "    spectra = [(mat_idx, simulate_material(decoded_mat), float(nn_srmse)) for mat_idx, decoded_mat, nn_srmse in materials]\n",
    "    # .item() gives plain floats, so np.mean / np.std / np.sqrt later operate on a float list\n",
    "    mse = [(spectrum - y_test[int(mat_idx)]).square().mean().item() for mat_idx, spectrum, _ in spectra]\n",
    "    return spectra, mse\n",
    "\n",
    "\n",
    "materials_sloss = collect_sat_materials(semloss_)\n",
    "materials_noloss = collect_sat_materials(noloss_)\n",
    "\n",
    "# NO loss - Calculate real srmse of decoded materials\n",
    "spectra_noloss, srmse_noloss = simulated_mse(materials_noloss)\n",
    "\n",
    "# SEM loss - Calculate real srmse of decoded materials\n",
    "spectra_semloss, srmse_semloss = simulated_mse(materials_sloss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190cbf80",
   "metadata": {},
   "outputs": [],
   "source": [
    "semloss_ = get_results_numpy(results_semloss, n_points=128, exclude_decoded_mat=True)\n",
    "noloss_ = get_results_numpy(results_noloss, n_points=128, exclude_decoded_mat=True)\n",
    "\n",
    "\n",
    "def sat_percentage(res):\n",
    "    \"\"\"Percentage of the NUM_MAT materials with at least one SAT point (field 3) at epoch index 199.\"\"\"\n",
    "    sat = res[:, 199, 3] >= 1\n",
    "    return (sat.sum() / NUM_MAT) * 100\n",
    "\n",
    "\n",
    "def found_onehots(res):\n",
    "    \"\"\"Onehot values (field 5) at epoch index 199, without the -1 used when no feasible material is found.\"\"\"\n",
    "    values = res[:, 199, 5]\n",
    "    return values[values != -1]\n",
    "\n",
    "\n",
    "perc_semloss = sat_percentage(semloss_)\n",
    "perc_noloss = sat_percentage(noloss_)\n",
    "\n",
    "onehots_semloss = found_onehots(semloss_)\n",
    "onehots_noloss = found_onehots(noloss_)\n",
    "\n",
    "print(f\"[Semloss] {plot_title}    -    MSE: {np.mean(srmse_semloss):.3f} +- ({np.std(srmse_semloss):.4f})   -   RMSE: {np.mean(np.sqrt(srmse_semloss)):.3f} +- ({np.std(np.sqrt(srmse_semloss)):.4f})   -   onehot: {onehots_semloss.mean():.3f} +- ({onehots_semloss.std():.3f}) - Sat: {perc_semloss:.2f}%\")\n",
    "print(f\"[Noloss] {plot_title}    -    MSE: {np.mean(srmse_noloss):.3f} +- ({np.std(srmse_noloss):.4f}) - RMSE: {np.mean(np.sqrt(srmse_noloss)):.3f} +- ({np.std(np.sqrt(srmse_noloss)):.4f})   -   onehot: {onehots_noloss.mean():.3f} +- ({onehots_noloss.std():.3f}) - Sat: {perc_noloss:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ceaa8c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.patches as mpatches\n",
    "\n",
    "all_colors = [\"#26A949FF\", \"#A92626FF\", \"#000000FF\", \"#B6B6B688\"]\n",
    "\n",
    "def get_scatter_x_y_data(noloss_, semloss_):\n",
    "    \"\"\"\n",
    "    Build the scatter coordinates comparing the two runs, one point per material.\n",
    "\n",
    "    Row i of `noloss_` and `semloss_` is (mat_idx, best SRMSE overall, best SRMSE among SAT points),\n",
    "    the last value being >= 999999.0 when no SAT point exists. Rows are paired by position, so both\n",
    "    arrays must be aligned. A run with SAT points is placed at its constrained SRMSE, otherwise at\n",
    "    its overall SRMSE; the colour (from `all_colors`) tells which runs had SAT points.\n",
    "\n",
    "    Returns:\n",
    "        (x from the no-loss run, y from the semantic-loss run, colours) as numpy arrays.\n",
    "    \"\"\"\n",
    "    threshold = 999999.0\n",
    "    xs, ys, cs = [], [], []\n",
    "\n",
    "    for i in range(noloss_.shape[0]):\n",
    "        no_row, sem_row = noloss_[i], semloss_[i]\n",
    "        no_best, sem_best = no_row[2], sem_row[2]\n",
    "\n",
    "        if no_best >= threshold and sem_best < threshold:\n",
    "            # SAT only with the semantic loss\n",
    "            x, y, c = no_row[1], sem_row[2], all_colors[0]\n",
    "        elif no_best < threshold and sem_best >= threshold:\n",
    "            # SAT only without the semantic loss\n",
    "            x, y, c = no_row[2], sem_row[1], all_colors[1]\n",
    "        elif no_best >= threshold and sem_best >= threshold:\n",
    "            # No SAT point in either run\n",
    "            x, y, c = no_row[1], sem_row[1], all_colors[2]\n",
    "        else:\n",
    "            # SAT in both runs\n",
    "            x, y, c = no_row[2], sem_row[2], all_colors[3]\n",
    "\n",
    "        xs.append(x)\n",
    "        ys.append(y)\n",
    "        cs.append(c)\n",
    "\n",
    "    return np.array(xs), np.array(ys), np.array(cs)\n",
    "\n",
    "def plot_scatter(data1, data2, colors = None, min_val = None, max_val = None, log_scale = False, axis = None, x_axis = True, y_axis = True, xlabel = \"\", ylabel = \"\"):\n",
    "    \"\"\"\n",
    "    Scatter `data2` against `data1` on a square panel with a dashed y = x reference line.\n",
    "\n",
    "    Args:\n",
    "        data1, data2: x and y coordinates.\n",
    "        colors: optional per-point colours forwarded to `scatter`.\n",
    "        min_val, max_val: limits of both axes; the diagonal is drawn only when both are given.\n",
    "        log_scale: use logarithmic x and y axes.\n",
    "        axis: Axes to draw on; a new figure with a single Axes is created when None.\n",
    "        x_axis, y_axis: when False, hide the tick labels and the label of that axis.\n",
    "        xlabel, ylabel: axis labels.\n",
    "\n",
    "    Returns:\n",
    "        The Axes that was drawn on.\n",
    "    \"\"\"\n",
    "    if axis is None:\n",
    "        _, axis = plt.subplots(1,1)\n",
    "\n",
    "    axis.scatter(data1, data2, linewidths=0.0001, c = colors)\n",
    "\n",
    "    # y = x reference line; skipped when a bound is missing instead of plotting [None, None]\n",
    "    if min_val is not None and max_val is not None:\n",
    "        axis.plot([min_val, max_val], [min_val, max_val], 'r--')\n",
    "\n",
    "    axis.grid(which=\"major\", alpha=0.6)\n",
    "    axis.grid(which=\"minor\", alpha=0.3)\n",
    "    axis.set_xlabel(xlabel)\n",
    "    axis.set_ylabel(ylabel)\n",
    "    axis.set_aspect(\"equal\", adjustable=\"box\")\n",
    "\n",
    "    if log_scale:\n",
    "        axis.set_yscale(\"log\")\n",
    "        axis.set_xscale(\"log\")\n",
    "\n",
    "    # Same limits on both axes (a None bound leaves that limit unchanged)\n",
    "    axis.set_xlim([min_val, max_val]) # type: ignore\n",
    "    axis.set_ylim([min_val, max_val]) # type: ignore\n",
    "\n",
    "    if not x_axis:\n",
    "        axis.set_xticklabels([])\n",
    "        axis.set_xlabel(\"\")\n",
    "\n",
    "    if not y_axis:\n",
    "        axis.set_yticklabels([])\n",
    "        axis.set_ylabel(\"\")\n",
    "\n",
    "    return axis\n",
    "\n",
    "\n",
    "#   ------ PLOT DATA -------------\n",
    "\n",
    "epochs = [10, 50, 100, 200]\n",
    "n_points = [1, 8, 16, 32, 64, 128]\n",
    "\n",
    "# figsize (width, height)\n",
    "fig, axes = plt.subplots(len(n_points), len(epochs), figsize=(20,30))\n",
    "\n",
    "# Construct a matrix of data: rows -> increasing T. columns -> increasing epochs\n",
    "data_matrix = [[None for j in range(len(epochs))] for i in range(len(n_points))]\n",
    "\n",
    "for i, np_ in enumerate(n_points):\n",
    "    # Take the data corresponding to a certain NUM_POINTS\n",
    "    # NUM_MAT x EPOCHS x [Mat_idx, epoch, best SRMSE (overall), constraint count, best SRMSE (constrained), onehot]\n",
    "    noloss_ = get_results_numpy(results_noloss, n_points=np_)\n",
    "    semloss_ = get_results_numpy(results_semloss, n_points=np_)\n",
    "    for j, epoch in enumerate(epochs):\n",
    "        ep_ = epoch - 1\n",
    "\n",
    "        # [mat_idx, best SRMSE (overall), best SRMSE (constrained)], sorted by mat_idx so the two runs pair up\n",
    "        data_noloss = noloss_[:, ep_, [0, 2, 4]]\n",
    "        data_noloss = data_noloss[data_noloss[:, 0].argsort()]\n",
    "\n",
    "        data_sloss = semloss_[:, ep_, [0, 2, 4]]\n",
    "        data_sloss = data_sloss[data_sloss[:, 0].argsort()]\n",
    "\n",
    "        data_matrix[i][j] = get_scatter_x_y_data(data_noloss, data_sloss) # type: ignore\n",
    "\n",
    "# Axis limits are fixed to [0.01, 1] on every panel\n",
    "for i, np_ in enumerate(n_points):\n",
    "    for j, epoch in enumerate(epochs):\n",
    "        data_x, data_y, colors = data_matrix[i][j] # type: ignore\n",
    "\n",
    "        plot_scatter(data_x, data_y, colors,\n",
    "            min_val = 0.01, max_val = 1,\n",
    "            log_scale = True,\n",
    "            axis = axes[i][j],\n",
    "            # Tick labels only on the bottom row and on the left column\n",
    "            x_axis = i == len(n_points) - 1, y_axis = j == 0,\n",
    "            xlabel = rf\"$e$={epoch}\", ylabel = rf\"$T$={np_}\"\n",
    "        )\n",
    "\n",
    "fig.supxlabel(r\"SRMSE (No $L^s$)\", fontsize=18)\n",
    "fig.text(x=0.01, y=0.5, s=r\"SRMSE (With $L^s$)\", fontsize=18, rotation=90, va='center')\n",
    "fig.suptitle(f\"[{model_used}] {plot_title}\")\n",
    "\n",
    "legend_elements = [\n",
    "    mpatches.Patch(color = all_colors[0], label = r'SAT only with $L^s$'),\n",
    "    mpatches.Patch(color = all_colors[1], label = r'SAT only no $L^s$'),\n",
    "    mpatches.Patch(color = all_colors[2], label = r'No SAT materials'),\n",
    "    mpatches.Patch(color = all_colors[3], label = r'SAT with/no $L^s$')\n",
    "]\n",
    "\n",
    "fig.legend(\n",
    "    handles=legend_elements,\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 0.96),\n",
    "    ncol=4,\n",
    "    columnspacing=2,\n",
    "    labelspacing=0.5,\n",
    "    fontsize=16,\n",
    ")\n",
    "\n",
    "fig.tight_layout(rect=[0.02, 0.01, 1, 0.96]) # type: ignore\n",
    "\n",
    "# Single quotes inside the f-string keep it valid on Python < 3.12 (PEP 701)\n",
    "fig.savefig(f\"plots/scatterplot_{plot_title.lower().replace(' ', '_')}_dataset10.svg\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21ff7dc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "def plot_grouped_histplot(data_noloss, data_sloss, axes, max_y, bins, x_axis = True, y_axis = True, xlabel = \"\", ylabel = \"\"):\n",
    "    \"\"\"\n",
    "    Draw dodged step histograms of the SRMSE with and without the semantic loss on `axes`.\n",
    "\n",
    "    Args:\n",
    "        data_noloss, data_sloss: 2D arrays whose column 1 holds the SRMSE (column 0 is ignored).\n",
    "        axes: Axes to draw on.\n",
    "        max_y: upper y limit.\n",
    "        bins: bin specification forwarded to seaborn.\n",
    "        x_axis, y_axis: when False, hide the tick labels and blank the label of that axis.\n",
    "        xlabel, ylabel: axis labels.\n",
    "\n",
    "    Labels, limits and grid are applied even when both arrays are empty; nothing is plotted then.\n",
    "    \"\"\"\n",
    "    values_noloss = data_noloss[:, 1]\n",
    "    values_sloss = data_sloss[:, 1]\n",
    "\n",
    "    # Long format for seaborn: one row per value, tagged with the run it comes from\n",
    "    df = pd.DataFrame({\n",
    "        \"Srmse\": np.concatenate([values_sloss, values_noloss]),\n",
    "        \"Class\": [r\"With $L^s$\"] * len(values_sloss) + [r\"No $L^s$\"] * len(values_noloss)\n",
    "    })\n",
    "\n",
    "    axes.set_xlabel(xlabel)\n",
    "    axes.set_ylabel(ylabel)\n",
    "    axes.set_ylim([0, max_y])\n",
    "    axes.grid(axis='y', alpha=0.5)\n",
    "\n",
    "    decorations = (\n",
    "        (x_axis, axes.set_xticklabels, axes.set_xlabel),\n",
    "        (y_axis, axes.set_yticklabels, axes.set_ylabel),\n",
    "    )\n",
    "    for visible, set_ticklabels, set_label in decorations:\n",
    "        if not visible:\n",
    "            set_ticklabels([])\n",
    "            set_label(\" \")\n",
    "\n",
    "    if df.empty:\n",
    "        return\n",
    "\n",
    "    h = sns.histplot(data=df, x=\"Srmse\", hue=\"Class\", palette=\"deep\", multiple=\"dodge\", bins=bins, element=\"step\", ax=axes)\n",
    "    legend = h.get_legend()\n",
    "    legend.set_title(\"\")\n",
    "    plt.setp(legend.get_texts(), fontsize='14')\n",
    "    plt.setp(legend.get_title(), fontsize='0')\n",
    "\n",
    "\n",
    "# ---- PLOT DATA --------\n",
    "\n",
    "epochs = [10, 50, 100, 200]\n",
    "n_points = [1, 8, 16, 32, 64, 128]\n",
    "\n",
    "# Shared bins: range up to the maximum constrained srmse obtained with T=2 at the last epoch.\n",
    "# NOTE(review): values above this maximum (e.g. from T=1 or early epochs) fall outside the bins\n",
    "# and are not drawn; widen the range if every panel must show all materials.\n",
    "noloss_ = get_results_numpy(results_noloss, n_points=2)\n",
    "data_noloss = noloss_[:, -1, [4]]\n",
    "# Filter out points with srmse >= 9999999.0 (does not satisfy constraint)\n",
    "data_noloss = data_noloss[data_noloss[:, 0] < 9999999.0]\n",
    "\n",
    "semloss_ = get_results_numpy(results_semloss, n_points=2)\n",
    "data_semloss = semloss_[:, -1, [4]]\n",
    "data_semloss = data_semloss[data_semloss[:, 0] < 9999999.0]\n",
    "\n",
    "global_max_srmse = np.max([data_noloss.max() if data_noloss.size > 0 else 0, data_semloss.max() if data_semloss.size > 0 else 0])\n",
    "if global_max_srmse <= 0:\n",
    "    # No SAT material at all: use a unit range instead of zero-width bins\n",
    "    global_max_srmse = 1.0\n",
    "global_max_count = 0\n",
    "num_bins = 20\n",
    "bin_edges = np.linspace(0, global_max_srmse, num_bins + 1)\n",
    "\n",
    "# Construct a matrix of data: rows -> increasing T. columns -> increasing epochs\n",
    "data_matrix = [[None for j in range(len(epochs))] for i in range(len(n_points))]\n",
    "for i, np_ in enumerate(n_points):\n",
    "    # Take the data corresponding to a certain NUM_POINTS\n",
    "    # NUM_MAT x EPOCHS x [Mat_idx, epoch, best SRMSE (overall), constraint count, best SRMSE (constrained), onehot]\n",
    "    noloss_ = get_results_numpy(results_noloss, n_points=np_)\n",
    "    semloss_ = get_results_numpy(results_semloss, n_points=np_)\n",
    "    for j, epoch in enumerate(epochs):\n",
    "        ep_ = epoch - 1\n",
    "\n",
    "        # [mat_idx, best_constrained_srmse] at this epoch, keeping only materials with a SAT point\n",
    "        data_noloss = noloss_[:, ep_, [0, 4]]\n",
    "        data_noloss = data_noloss[data_noloss[:, 1] < 9999999.0]\n",
    "        data_sloss = semloss_[:, ep_, [0, 4]]\n",
    "        data_sloss = data_sloss[data_sloss[:, 1] < 9999999.0]\n",
    "\n",
    "        # Histogram only the srmse column (column 0 is mat_idx) to find the tallest bin\n",
    "        counts_noloss = np.histogram(data_noloss[:, 1], bins=bin_edges)[0]\n",
    "        counts_sloss = np.histogram(data_sloss[:, 1], bins=bin_edges)[0]\n",
    "        global_max_count = max(global_max_count, counts_noloss.max(), counts_sloss.max())\n",
    "\n",
    "        # Store the computed data inside a data matrix\n",
    "        data_matrix[i][j] = (data_noloss, data_sloss) # type: ignore\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(len(n_points), len(epochs), figsize=(20,30))\n",
    "for i, np_ in enumerate(n_points):\n",
    "    for j, epoch in enumerate(epochs):\n",
    "        data_noloss, data_sloss = data_matrix[i][j] # type: ignore\n",
    "        plot_grouped_histplot(\n",
    "            data_noloss, data_sloss,\n",
    "            axes[i][j],\n",
    "            max_y = global_max_count,\n",
    "            bins = bin_edges,\n",
    "            # Tick labels only on the bottom row and on the left column\n",
    "            x_axis = i == len(n_points) - 1, y_axis = j == 0,\n",
    "            xlabel = rf\"$e$={epoch}\", ylabel = rf\"$T$={np_}\"\n",
    "        )\n",
    "\n",
    "fig.suptitle(f\"[{model_used}] {plot_title}\")\n",
    "fig.supylabel(\"Counts\")\n",
    "fig.supxlabel(\"SRMSE\", fontsize=14)\n",
    "fig.tight_layout(rect=[0.02, 0.01, 1, 0.98]) # type: ignore\n",
    "\n",
    "# Single quotes inside the f-string keep it valid on Python < 3.12 (PEP 701)\n",
    "fig.savefig(f\"plots/histogram_{plot_title.lower().replace(' ', '_')}_dataset10.svg\", dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e0c1c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.stats as stats\n",
    "\n",
    "def compute_p_ci_plots(results_noloss, results_semloss, n_point, alpha=0.05, axes=None, y_axis = False, xlabel = \"\", ylabel = \"\"):\n",
    "    \"\"\"\n",
    "    Plot, per epoch, the mean proportion (in %) of SAT points among the T = `n_point` points,\n",
    "    averaged over materials, with its two-sided (1 - alpha) t confidence interval, for both runs.\n",
    "\n",
    "    Args:\n",
    "        results_noloss, results_semloss: per-T results of the run without / with the semantic loss.\n",
    "        n_point: T; selects the results and normalises the SAT counts.\n",
    "        alpha: significance level (0.05 -> 95% CI).\n",
    "        axes: Axes to draw on; a new figure with a single Axes is created when None.\n",
    "        y_axis: when False, hide the y tick labels and the y label.\n",
    "        xlabel, ylabel: axis labels.\n",
    "\n",
    "    Returns:\n",
    "        (lowest lower bound, highest upper bound) over both runs, as proportions (not percentages).\n",
    "    \"\"\"\n",
    "    # Select the results of the requested T (not the caller's loop variable)\n",
    "    noloss_ = get_results_numpy(results_noloss, n_points=n_point)\n",
    "    semloss_ = get_results_numpy(results_semloss, n_points=n_point)\n",
    "\n",
    "    # For each material x_i and each epoch e_j: fraction of the T points that satisfy the constraint\n",
    "    p_semloss = semloss_[:, :, 3] / n_point\n",
    "    p_noloss = noloss_[:, :, 3] / n_point\n",
    "\n",
    "    # Mean and standard error of p over the materials (axis 0), for every epoch\n",
    "    mu_semloss_ = p_semloss.mean(axis = 0)\n",
    "    se_semloss_ = stats.sem(p_semloss, axis=0)\n",
    "\n",
    "    mu_noloss_ = p_noloss.mean(axis = 0)\n",
    "    se_noloss_ = stats.sem(p_noloss, axis=0)\n",
    "\n",
    "    # Margin of error given alpha: the sample size is the number of materials, so the t degrees\n",
    "    # of freedom are (materials - 1), not (epochs - 1)\n",
    "    me_semloss = stats.t.ppf(1 - alpha/2, df=p_semloss.shape[0] - 1) * se_semloss_\n",
    "    me_noloss = stats.t.ppf(1 - alpha/2, df=p_noloss.shape[0] - 1) * se_noloss_\n",
    "\n",
    "    # Lower/upper bound of the confidence interval\n",
    "    ub_semloss = mu_semloss_ + me_semloss\n",
    "    lb_semloss = mu_semloss_ - me_semloss\n",
    "\n",
    "    ub_noloss = mu_noloss_ + me_noloss\n",
    "    lb_noloss = mu_noloss_ - me_noloss\n",
    "\n",
    "    epochs = np.arange(mu_semloss_.shape[0])\n",
    "\n",
    "    if axes is None:\n",
    "        _, axes = plt.subplots(1, 1)\n",
    "\n",
    "    axes.plot(epochs, mu_semloss_ * 100, label=r'[With $L^s$] Mean proportion')\n",
    "    axes.fill_between(epochs, lb_semloss * 100, ub_semloss * 100, alpha=0.3)\n",
    "\n",
    "    axes.plot(epochs, mu_noloss_ * 100, label=r'[No $L^s$] Mean proportion')\n",
    "    axes.fill_between(epochs, lb_noloss * 100, ub_noloss * 100, alpha=0.3)\n",
    "\n",
    "    axes.set_ylabel(ylabel)\n",
    "    if not y_axis:\n",
    "        axes.set_yticklabels([])\n",
    "        axes.set_ylabel(\"\")\n",
    "\n",
    "    axes.set_xlabel(xlabel)\n",
    "\n",
    "    axes.grid(axis='y', alpha=0.5)\n",
    "    axes.legend()\n",
    "\n",
    "    # Return the minimum of lower bounds and the max of upper bounds\n",
    "    return np.min([lb_semloss.min(), lb_noloss.min()]), np.max([ub_semloss.max(), ub_noloss.max()])\n",
    "\n",
    "\n",
    "n_points = [1, 8, 16, 32, 64, 128]\n",
    "fig, axes = plt.subplots(1, len(n_points), figsize=(50, 10))\n",
    "global_max_p = 0\n",
    "global_min_p = 100\n",
    "\n",
    "for i, np_ in enumerate(n_points):\n",
    "    min_p, max_p = compute_p_ci_plots(results_noloss, results_semloss, n_point=np_, alpha=0.05, axes=axes[i], y_axis = i == 0, xlabel = rf\"$T$={np_}\")\n",
    "    global_max_p = max(global_max_p, max_p)\n",
    "    global_min_p = min(global_min_p, min_p)\n",
    "\n",
    "# Same y range (in %) on every panel\n",
    "for ax in axes:\n",
    "    ax.set_ylim([global_min_p * 100, global_max_p * 100])\n",
    "\n",
    "fig.suptitle(f\"[{model_used}] {plot_title}\")\n",
    "fig.supxlabel('Epoch')\n",
    "# The escaped percent sign is only needed by LaTeX; without usetex the backslash is printed literally\n",
    "fig.supylabel('Mean proportion of SAT materials (\\\\%)' if plt.rcParams['text.usetex'] else 'Mean proportion of SAT materials (%)')\n",
    "fig.tight_layout(rect=[0.02, 0.01, 1, 1]) # type: ignore\n",
    "\n",
    "# Single quotes inside the f-string keep it valid on Python < 3.12 (PEP 701)\n",
    "fig.savefig(f\"plots/lineplot_p_ci95_{plot_title.lower().replace(' ', '_')}_dataset10.svg\", dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
