{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94cc36d7-46e9-41bf-b3b8-22b50c22c579",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(r'..')\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import seaborn as sns\n",
    "import math\n",
    "import statistics\n",
    "import pickle as pkl\n",
    "import torch\n",
    "\n",
    "from expression_tree import simplify_equation, calc_r_squared, NMSE_reward_func\n",
    "from scipy import optimize\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51602c80-c33e-4c55-97ce-0d013e3343de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def smooth_data(data, window_size=5):\n",
    "    # return [max(data[i:i+window_size]) for i in range(len(data)-window_size+1)]\n",
    "    return [sum(data[i:i+window_size])/window_size for i in range(len(data)-window_size+1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b7d8a55-5948-46b2-bda2-874e0e3d4d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled results of one run from ../run_data.\n",
    "# os.path.join keeps the path portable; the previous hard-coded Windows backslash separators\n",
    "# made this cell fail on Linux/macOS.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open run files you produced yourself.\n",
    "RUN_FILE = \"nguyen_4_2321_drop_dct_0.pkl\"\n",
    "run_path = os.path.join(os.getcwd(), os.pardir, \"run_data\", RUN_FILE)\n",
    "with open(run_path, 'rb') as file:\n",
    "    results = pkl.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e77237a1-f46c-4583-b64b-ab97ebc525dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the hyper-parameters this run was launched with, one per line.\n",
    "hyper_parameters = results[\"Hyper-parameters\"]\n",
    "for name, value in hyper_parameters.items():\n",
    "    print(f\"{name}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737ae111-2eef-4342-a1ae-a80dd6a45da6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the parameters recorded for the first training cycle, one per line.\n",
    "cycle_parameters = results['Training Cycle'][0]['parameters']\n",
    "for name, value in cycle_parameters.items():\n",
    "    print(f\"{name}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b8bfa7b-05ed-41f7-a0e1-c3066528b28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the best discovered function: its stored fields, the total run time, the epoch\n",
    "# whose rewards peaked, and a simplified closed form with the fitted constants substituted.\n",
    "best_function = results[\"Best Function\"]\n",
    "first_cycle = results['Training Cycle'][0]\n",
    "\n",
    "for key, value in best_function.items():\n",
    "    print(f'{key}: {value}')\n",
    "\n",
    "time_keys = ['Sample Time', 'Opt Time', 'Reward', 'Prediction']\n",
    "total_seconds = sum(sum(first_cycle['Timings'][key]) for key in time_keys)\n",
    "print(f\"Run Time: {round(total_seconds/3600, 2)} hours\")\n",
    "\n",
    "\n",
    "def _epoch_best(epoch_rewards):\n",
    "    \"\"\"Largest non-NaN reward of one epoch, or -inf when the epoch has none.\"\"\"\n",
    "    rewards = np.asarray(epoch_rewards, dtype=float)\n",
    "    rewards = rewards[~np.isnan(rewards)]\n",
    "    return rewards.max() if rewards.size else -np.inf\n",
    "\n",
    "\n",
    "# Rewards are only cleaned in a later cell, so NaNs must be ignored here: builtin max() and\n",
    "# np.argmax both return order-dependent/wrong results once a NaN is present, and an empty\n",
    "# epoch made max() raise.\n",
    "id_epoch = int(np.argmax([_epoch_best(r) for r in first_cycle[\"Iteration Info\"][\"Rewards\"]]))\n",
    "print(f\"Discovered Epoch: {id_epoch}\")\n",
    "\n",
    "# Substitute the fitted constants into the equation text. Constants may be stored as scalars\n",
    "# or as 1-element sequences/arrays; np.ravel(...)[0] reads both without rewriting `results` in\n",
    "# place (the old rewrite made later cells depend on whether this cell had run, and crashed on\n",
    "# numpy scalar types that are not float subclasses, e.g. np.float32).\n",
    "equ = best_function[\"Equation\"]\n",
    "for i, const in enumerate(best_function[\"Constants\"]):\n",
    "    equ = equ.replace(f\"c[{i}]\", str(round(float(np.ravel(const)[0]), 3)))\n",
    "equ = equ.replace(\"np.\", \"torch.\")\n",
    "simplified_equation = simplify_equation(equ, 0, 25)\n",
    "print(f\"Final Equation: {simplified_equation}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7cea996-098b-44fa-9d84-de29d7a96711",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the best function on the training data. The stored equation string refers to `x`\n",
    "# and `c`, so both must be bound before eval(); it is our own pickled output -- never eval\n",
    "# equation strings loaded from untrusted files.\n",
    "x = np.array(results['Training Cycle'][0][\"Training Data\"][0])\n",
    "c = np.array(results[\"Best Function\"][\"Constants\"])\n",
    "y_pred = torch.tensor(eval(results[\"Best Function\"][\"Equation\"].replace(\"torch.\", \"np.\")))\n",
    "y_real = torch.tensor(results['Training Cycle'][0][\"Training Data\"][1])\n",
    "# Align shapes before any metric: a constant expression yields a scalar, and an (n, 1)\n",
    "# prediction would otherwise broadcast against an (n,) target into an (n, n) matrix and\n",
    "# silently corrupt every number below.\n",
    "if y_pred.numel() == y_real.numel():\n",
    "    y_pred = y_pred.reshape(y_real.shape)\n",
    "else:\n",
    "    y_pred = torch.broadcast_to(y_pred, y_real.shape)\n",
    "\n",
    "MSE = torch.mean((y_pred - y_real) ** 2)\n",
    "r_squared = 1 - torch.sum((y_pred - y_real) ** 2) / torch.sum((y_real - torch.mean(y_real)) ** 2)\n",
    "reward = NMSE_reward_func(y_pred.numpy(), y_real.numpy(), y_real.std())\n",
    "\n",
    "print(\"Training Data\")\n",
    "print(f\"R2: {r_squared}\")\n",
    "print(f\"Reward: {reward}\")\n",
    "print(f\"MSE: {MSE}\")\n",
    "# NOTE(review): normalised by mean(y**2), not var(y) -- confirm which NMSE definition is intended.\n",
    "print(f\"NMSE: {MSE/(y_real**2).mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e86611-cd9d-4020-ba2f-c84985bcd694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the best function on the held-out test data, exactly as for the training data.\n",
    "# Versus the old cell: inputs go through np.array (as in the training cell), and torch.*\n",
    "# calls are translated to np.* before eval -- without that, any equation using torch\n",
    "# functions failed here on numpy inputs. The string is our own pickled output; never eval\n",
    "# equation strings loaded from untrusted files.\n",
    "x = np.array(results[\"Test Data\"][0])\n",
    "c = np.array(results[\"Best Function\"][\"Constants\"])\n",
    "y_pred = torch.tensor(eval(results[\"Best Function\"][\"Equation\"].replace(\"torch.\", \"np.\")))\n",
    "y_real = torch.tensor(results[\"Test Data\"][1])\n",
    "# Align shapes before any metric: a constant expression yields a scalar, and an (n, 1)\n",
    "# prediction would otherwise broadcast against an (n,) target into an (n, n) matrix and\n",
    "# silently corrupt every number below.\n",
    "if y_pred.numel() == y_real.numel():\n",
    "    y_pred = y_pred.reshape(y_real.shape)\n",
    "else:\n",
    "    y_pred = torch.broadcast_to(y_pred, y_real.shape)\n",
    "\n",
    "MSE = torch.mean((y_pred - y_real) ** 2)\n",
    "r_squared = 1 - torch.sum((y_pred - y_real) ** 2) / torch.sum((y_real - torch.mean(y_real)) ** 2)\n",
    "reward = NMSE_reward_func(y_pred.numpy(), y_real.numpy(), y_real.std())\n",
    "\n",
    "print(\"Test Data\")\n",
    "print(f\"R2: {r_squared}\")\n",
    "print(f\"Reward: {reward}\")\n",
    "print(f\"MSE: {MSE}\")\n",
    "# NOTE(review): normalised by mean(y**2), not var(y) -- confirm which NMSE definition is intended.\n",
    "print(f\"NMSE: {MSE/(y_real**2).mean()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221a64f3-9b12-421a-81ff-1a0571035790",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parity plot: predicted vs. observed y for the y_pred / y_real left by the last scoring cell\n",
    "# (the test-data cell in a top-to-bottom run). Points on the black y = x line are exact fits.\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.set_title(\"Comparison between Best Function and Ground Truth\")\n",
    "ax.scatter(y_pred, y_real, label=\"Observed Data\", s=20)\n",
    "# Dedicated name for the identity line so the global `x` (used by eval in the scoring cells)\n",
    "# is not clobbered.\n",
    "identity = np.linspace(float(y_real.min()), float(y_real.max()), 100)\n",
    "ax.plot(identity, identity, color=\"black\")\n",
    "ax.set_xlabel(\"Modeled Y\")\n",
    "ax.set_ylabel(\"Actual Y\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b35a1a7-708d-4495-8555-ff2baf3b82d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoothing window for the per-iteration curves: each window spans ~1/50 of the run.\n",
    "# Clamped to >= 1 because an empty loss history would otherwise give 0 and make\n",
    "# smooth_data divide by zero.\n",
    "window_size = max(1, math.ceil(len(results['Training Cycle'][0][\"Iteration Info\"][\"Loss\"])/50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d561a99e-6bfc-495f-8a96-b892bb92d20c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only usable rewards per epoch: drop the failure sentinel (<= -1e32), NaN and +/-inf.\n",
    "# A single mask replaces the previous three passes (the NaN pass was a no-op, since\n",
    "# NaN > -1e32 is already False), and the loop no longer shadows the builtin `iter`.\n",
    "def _clean_rewards(epoch_rewards):\n",
    "    \"\"\"Return the finite rewards above the -1e32 sentinel as a numpy array, order preserved.\"\"\"\n",
    "    rewards = np.array(epoch_rewards)\n",
    "    return rewards[np.isfinite(rewards) & (rewards > -1E32)]\n",
    "\n",
    "\n",
    "results['Training Cycle'][0][\"Iteration Info\"][\"Rewards\"] = [\n",
    "    _clean_rewards(r) for r in results['Training Cycle'][0][\"Iteration Info\"][\"Rewards\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f721f014-1ced-4312-a709-7b2c9e97ce78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"New Equations\" per epoch, unsmoothed (window 1). Presumably the number of previously\n",
    "# unseen expressions sampled that epoch -- confirm against the trainer.\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.plot(smooth_data(results['Training Cycle'][0][\"Iteration Info\"][\"New Equations\"], 1))\n",
    "ax.set_xlabel(\"Epoch\")\n",
    "ax.set_ylabel(\"New Equations\")\n",
    "ax.set_title(\"New Equations per Epoch\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e944fdcf-698e-4f3f-9d24-b29ce3a91a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-loss components per iteration (window 1, i.e. unsmoothed).\n",
    "iteration_info = results['Training Cycle'][0][\"Iteration Info\"]\n",
    "# (stored key, legend label) pairs, plotted in this order.\n",
    "loss_series = [(\"Loss\", \"Total Loss\"), (\"Policy Loss\", \"Policy Loss\"),\n",
    "               (\"Entropy Loss\", \"Entropy Loss\"), (\"KL Loss\", \"KL Loss\")]\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "smoothed_loss = smooth_data(iteration_info[\"Loss\"], 1)\n",
    "steps = range(len(smoothed_loss))\n",
    "for key, label in loss_series:\n",
    "    ax.plot(steps, smooth_data(iteration_info[key], 1), label=label)\n",
    "\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Smoothed Loss\")\n",
    "ax.set_title(\"Different Losses\")\n",
    "# Legend outside the axes, to the right, so it never hides the curves.\n",
    "ax.legend(title=\"Losses\", loc=\"center left\", bbox_to_anchor=(1, 0, 0.5, 1))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a421263b-5181-447c-87a1-6a9d5c38c60a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch statistics of expression size (node count), smoothed with the global `window_size`.\n",
    "# .item() turns each 0-d tensor into a plain Python number so smoothing and plotting work on\n",
    "# floats (assumes each \"Node Counts\" entry is a 1-D tensor -- TODO confirm).\n",
    "node_counts = results['Training Cycle'][0][\"Iteration Info\"][\"Node Counts\"]\n",
    "node_means = [counts.float().mean().item() for counts in node_counts]\n",
    "node_median = [counts.median().item() for counts in node_counts]\n",
    "node_mode = [torch.mode(counts.cpu())[0].item() for counts in node_counts]\n",
    "node_stds = [counts.float().std().item() for counts in node_counts]\n",
    "\n",
    "smoothed_node_means = smooth_data(node_means, window_size)\n",
    "smoothed_node_median = smooth_data(node_median, window_size)\n",
    "smoothed_node_mode = smooth_data(node_mode, window_size)\n",
    "smoothed_node_stds = smooth_data(node_stds, window_size)\n",
    "\n",
    "# Dedicated x-axis name so the global `x` is not clobbered.\n",
    "steps = range(len(smoothed_node_means))\n",
    "mean_arr = np.array(smoothed_node_means)\n",
    "std_arr = np.array(smoothed_node_stds)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.plot(steps, smoothed_node_means, label=\"Node Means\")\n",
    "ax.plot(steps, smoothed_node_median, label=\"Node Medians\")\n",
    "ax.plot(steps, smoothed_node_mode, label=\"Node Modes\")\n",
    "# Shaded band: smoothed mean +/- one smoothed standard deviation.\n",
    "ax.fill_between(steps, mean_arr - std_arr, mean_arr + std_arr, color='tab:red', alpha=0.2)\n",
    "\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Node Counts\")\n",
    "ax.set_title(\"Smoothed Average Node Numbers per Expression Tree\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1318bd3-1130-4ffc-a047-eb20490d2d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One KDE of expression size (node count) per epoch, shaded light -> dark with the epoch index.\n",
    "node_nums = [counts.cpu().numpy() for counts in results['Training Cycle'][0][\"Iteration Info\"][\"Node Counts\"]]\n",
    "num_epochs = len(node_nums)\n",
    "# Shared scale for line colours and the colorbar, so colorbar ticks read as epoch numbers\n",
    "# (the old bare ScalarMappable used the default 0-1 range). max(..., 1) avoids a zero range.\n",
    "epoch_scale = max(num_epochs, 1)\n",
    "\n",
    "plt.figure(figsize=(12, 8))\n",
    "\n",
    "# Sequential \"Blues\" colormap: later epochs are drawn darker.\n",
    "base_color = mpl.colormaps[\"Blues\"]\n",
    "\n",
    "for epoch_number, node_num_epoch in enumerate(node_nums):\n",
    "    # A KDE is order-independent, so the data is passed as-is (the old descending sort was redundant).\n",
    "    sns.kdeplot(node_num_epoch, color=base_color(epoch_number / epoch_scale), fill=False, label='', alpha=0.5, warn_singular=False)\n",
    "plt.xlabel(\"Node Number\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.title(\"Continuous Distribution Plot for Expression Size Across Epochs\")\n",
    "\n",
    "# Colorbar in its own axes just right of the plot, acting as a continuous epoch legend.\n",
    "cax = plt.axes([0.92, 0.1, 0.02, 0.8])\n",
    "epoch_mappable = plt.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=epoch_scale), cmap=base_color)\n",
    "cbar = plt.colorbar(epoch_mappable, cax=cax)\n",
    "cbar.set_label('Epoch Number')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6757f895-a2c4-4463-9424-2e6823ed237f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (a) Reward curves per epoch; (b) per epoch, the share of rewards above the R_alpha threshold\n",
    "# that lie within 1e-3 of it (labelled as the top-alpha% expressions pruned).\n",
    "iteration_info = results['Training Cycle'][0][\"Iteration Info\"]\n",
    "# Dedicated name so this cell does not overwrite the global `window_size` used elsewhere.\n",
    "reward_window = 1\n",
    "\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "fig.set_size_inches(14, 7)\n",
    "\n",
    "smoothed_best_reward = smooth_data(iteration_info[\"Best Reward\"], reward_window)\n",
    "epochs = range(len(smoothed_best_reward))\n",
    "ax[0].plot(epochs, smoothed_best_reward, label=\"Best\")\n",
    "\n",
    "smoothed_median_reward = smooth_data(iteration_info[\"Median Reward\"], reward_window)\n",
    "ax[0].plot(epochs, smoothed_median_reward, label=r'Median Top $\\alpha\\%$')\n",
    "\n",
    "smoothed_baseline_reward = smooth_data(iteration_info[\"Baseline Reward\"], reward_window)\n",
    "ax[0].plot(epochs, smoothed_baseline_reward, label=r'$R_\\alpha$')\n",
    "\n",
    "# This curve is labelled \"Median\", so take the median of each epoch's rewards (it previously\n",
    "# used np.mean); this matches the \"Median\" curve of the cumulative-performance plot.\n",
    "smoothed_epoch_median = smooth_data([np.median(np.asarray(r)) for r in iteration_info[\"Rewards\"]], reward_window)\n",
    "ax[0].plot(epochs, smoothed_epoch_median, label=\"Median\")\n",
    "\n",
    "ax[0].set_xlabel(\"Epoch\", fontsize=14)\n",
    "ax[0].set_ylabel(\"Reward\", fontsize=14)\n",
    "ax[0].set_title(\"(a)\", fontsize=16, y=-0.25)\n",
    "ax[0].legend(loc=\"lower left\", fontsize=10)\n",
    "\n",
    "expressions_pruned = []\n",
    "for epoch in range(len(iteration_info[\"Best Reward\"])):\n",
    "    # float() accepts Python, numpy or 0-d torch scalars alike.\n",
    "    epsilon = float(iteration_info[\"Baseline Reward\"][epoch])\n",
    "    # Work on an array: the old code subtracted a float from a Python list, which raises\n",
    "    # TypeError unless epsilon happened to be a numpy scalar.\n",
    "    epoch_rewards = np.asarray(iteration_info[\"Rewards\"][epoch], dtype=float)\n",
    "    above = epoch_rewards[epoch_rewards > epsilon]\n",
    "    # Mean of an empty selection is NaN (its RuntimeWarning is silenced by the import cell).\n",
    "    percent = 100 * (np.abs(above - epsilon) < 0.001).mean()\n",
    "    expressions_pruned.append(percent)\n",
    "smoothed_percent = smooth_data(expressions_pruned, reward_window)\n",
    "\n",
    "ax[1].plot(epochs, smoothed_percent)\n",
    "ax[1].set_xlabel(\"Epoch\", fontsize=14)\n",
    "ax[1].set_ylabel(r\"Percentage of Top $\\alpha$ % Expressions Pruned\", fontsize=14)\n",
    "ax[1].set_title(\"(b)\", fontsize=16, y=-0.25)\n",
    "\n",
    "plt.show()\n",
    "# fig.savefig(\"NMSE_unbiased_seed_314.pdf\", dpi=500, bbox_inches='tight', pad_inches=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8ff913-7b38-4ce5-91fe-8f2762f7b487",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cumulative (running-best) reward curves: each point is the best value seen up to that epoch.\n",
    "# Each series starts with a -inf sentinel (as before), so all four share the same x range.\n",
    "iteration_info = results['Training Cycle'][0][\"Iteration Info\"]\n",
    "\n",
    "\n",
    "def _running_best(values):\n",
    "    \"\"\"Running maximum of ``values`` prefixed with -inf; NaN entries carry the previous best forward.\"\"\"\n",
    "    series = np.concatenate(([-np.inf], np.array([float(v) for v in values], dtype=float)))\n",
    "    # fmax (unlike maximum) ignores NaN, matching the old `if r > best` loop semantics.\n",
    "    return list(np.fmax.accumulate(series))\n",
    "\n",
    "\n",
    "def _epoch_stat(stat, epoch_rewards):\n",
    "    \"\"\"Apply ``stat`` to one epoch's rewards; NaN (i.e. skipped) for an epoch with no rewards.\"\"\"\n",
    "    rewards = np.asarray(epoch_rewards, dtype=float)\n",
    "    return stat(rewards) if rewards.size else np.nan\n",
    "\n",
    "\n",
    "# An empty epoch used to make max() raise; it is now skipped like a NaN.\n",
    "cum_best = _running_best(_epoch_stat(np.max, r) for r in iteration_info[\"Rewards\"])\n",
    "cum_epsilon_m = _running_best(iteration_info[\"Median Reward\"])\n",
    "cum_epsilon = _running_best(iteration_info[\"Baseline Reward\"])\n",
    "cum_median = _running_best(_epoch_stat(np.median, r) for r in iteration_info[\"Rewards\"])\n",
    "\n",
    "# Window 1 leaves the curves unsmoothed; a dedicated name avoids overwriting the global `window_size`.\n",
    "cum_window = 1\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "fig.set_size_inches(7, 7)\n",
    "\n",
    "smoothed_best_reward = smooth_data(cum_best, cum_window)\n",
    "epochs = range(len(smoothed_best_reward))\n",
    "ax.plot(epochs, smoothed_best_reward, label=\"Best\")\n",
    "ax.plot(epochs, smooth_data(cum_epsilon_m, cum_window), label=r'Median Top $\\alpha\\%$')\n",
    "ax.plot(epochs, smooth_data(cum_epsilon, cum_window), label=r'$R_\\alpha$')\n",
    "ax.plot(epochs, smooth_data(cum_median, cum_window), label=\"Median\")\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", fontsize=14)\n",
    "ax.set_ylabel(\"NMSE Reward\", fontsize=14)\n",
    "ax.set_title(\"Cumulative Performance\", fontsize=16)\n",
    "ax.legend(loc=\"center right\")\n",
    "\n",
    "plt.show()\n",
    "# fig.savefig(\"Mutations.png\", dpi=500, bbox_inches='tight', pad_inches=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8de0844d-4ca7-44c7-80e5-1374664eb32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One KDE of reward values per epoch (zero rewards excluded), shaded light -> dark by epoch.\n",
    "np.seterr(over='ignore')\n",
    "rewards_per_epoch = results['Training Cycle'][0][\"Iteration Info\"][\"Rewards\"]\n",
    "num_epochs = len(rewards_per_epoch)\n",
    "# Shared scale for line colours and the colorbar, so colorbar ticks read as epoch numbers\n",
    "# (the old bare ScalarMappable used the default 0-1 range). max(..., 1) avoids a zero range.\n",
    "epoch_scale = max(num_epochs, 1)\n",
    "\n",
    "plt.figure(figsize=(12, 8))\n",
    "\n",
    "# Sequential \"Blues\" colormap: later epochs are drawn darker.\n",
    "base_color = mpl.colormaps[\"Blues\"]\n",
    "\n",
    "for epoch_number, rewards_for_single_epoch in enumerate(rewards_per_epoch):\n",
    "    # np.asarray makes the mask elementwise even if this epoch is still a plain list; on a list,\n",
    "    # `!= 0.0` is a single bool and the indexing silently picked one element.\n",
    "    rewards_for_single_epoch = np.asarray(rewards_for_single_epoch)\n",
    "    rewards_for_single_epoch = rewards_for_single_epoch[rewards_for_single_epoch != 0.0]\n",
    "    sns.kdeplot(rewards_for_single_epoch, color=base_color(epoch_number / epoch_scale), fill=False, label='', alpha=0.5, warn_singular=False)\n",
    "plt.xlabel(\"Rewards\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.title(\"Continuous Distribution Plot for Rewards Across Epochs\")\n",
    "\n",
    "cax = plt.axes([0.92, 0.1, 0.02, 0.8])  # colorbar axes just right of the plot\n",
    "epoch_mappable = plt.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=epoch_scale), cmap=base_color)\n",
    "cbar = plt.colorbar(epoch_mappable, cax=cax)\n",
    "cbar.set_label('Epoch Number')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cd09501-f250-4a7a-a11c-3b0ef5f44298",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One KDE of the TOP_K highest rewards per epoch, shaded light -> dark by epoch.\n",
    "TOP_K = 50\n",
    "rewards_per_epoch = results['Training Cycle'][0][\"Iteration Info\"][\"Rewards\"]\n",
    "num_epochs = len(rewards_per_epoch)\n",
    "# Shared scale for line colours and the colorbar, so colorbar ticks read as epoch numbers\n",
    "# (the old bare ScalarMappable with ticks=[0, 1] only showed the 0-1 range).\n",
    "epoch_scale = max(num_epochs, 1)\n",
    "\n",
    "plt.figure(figsize=(12, 8))\n",
    "\n",
    "# Sequential \"Blues\" colormap: later epochs are drawn darker.\n",
    "base_color = mpl.colormaps[\"Blues\"]\n",
    "\n",
    "for epoch_number, rewards_for_single_epoch in enumerate(rewards_per_epoch):\n",
    "    top_rewards = sorted(rewards_for_single_epoch, reverse=True)[:TOP_K]\n",
    "    sns.kdeplot(top_rewards, color=base_color(epoch_number / epoch_scale), fill=False, label=f'Epoch {epoch_number}', alpha=0.5, warn_singular=False)\n",
    "\n",
    "plt.xlabel(\"Rewards\")\n",
    "plt.ylabel(\"Density\")\n",
    "plt.ylim(0, 15)  # fixed y-range, presumably chosen for this run -- adjust for other runs\n",
    "plt.title(f\"Continuous Distribution Plot for Top {TOP_K} Rewards Across Epochs\")\n",
    "\n",
    "cax = plt.axes([0.92, 0.1, 0.02, 0.8])  # colorbar axes just right of the plot\n",
    "epoch_mappable = plt.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=epoch_scale), cmap=base_color)\n",
    "cbar = plt.colorbar(epoch_mappable, cax=cax)\n",
    "cbar.set_label('Epoch Number')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e26a536-4d14-423c-a183-b888d24f8223",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of non-zero rewards recorded in each epoch. np.count_nonzero replaces the manual\n",
    "# filter loop, which also shadowed the builtin `iter`. (NaN counts as non-zero, as before.)\n",
    "rewards_entries = [np.count_nonzero(np.asarray(r)) for r in results['Training Cycle'][0][\"Iteration Info\"][\"Rewards\"]]\n",
    "iterations = range(1, len(rewards_entries) + 1)\n",
    "\n",
    "smoothed_entries = smooth_data(rewards_entries, window_size=1)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Each smoothed value is plotted at the last iteration of its window.\n",
    "ax.plot(iterations[len(iterations) - len(smoothed_entries):], smoothed_entries, linestyle='-', color='blue')\n",
    "ax.set_xlabel(\"Iteration\")\n",
    "ax.set_ylabel(\"Smoothed Number of Entries in Rewards\")\n",
    "ax.set_title(\"Smoothed Number of Entries in Rewards Over Iterations\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f383952-5524-4859-ba3b-e36612e12018",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of total run time spent in each training phase, summed over all epochs.\n",
    "time_keys = ['Sample Time', 'Opt Time', 'Reward', 'Prediction']\n",
    "time_values = [sum(results['Training Cycle'][0]['Timings'][key]) for key in time_keys]\n",
    "print(time_values)  # raw totals per phase\n",
    "total_time = sum(time_values)\n",
    "\n",
    "# Guard against a run with no recorded timings (the percentage step divided by zero).\n",
    "if total_time > 0:\n",
    "    time_values = [100 * v / total_time for v in time_values]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    # Unlabelled wedges with black borders; the legend carries the names and percentages.\n",
    "    wedges, texts, autotexts = ax.pie(time_values, labels=None, autopct='', startangle=90, wedgeprops=dict(linewidth=2, edgecolor='black'))\n",
    "    legend_labels = [f\"{key}: {time:.1f}%\" for key, time in zip(time_keys, time_values)]\n",
    "    ax.legend(wedges, legend_labels, title=\"Time Components\", loc=\"center left\", bbox_to_anchor=(1, 0, 0.5, 1))\n",
    "    ax.set_title(\"Overall Time Breakdown per Epoch\")\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No timing data recorded; skipping the time breakdown.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74122f1-c3ac-4848-994f-ae009a69cb0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of sampling time spent in each sub-step, summed over all epochs.\n",
    "sample_breakdown = results['Training Cycle'][0]['Timings'][\"Sample Time In-depth\"]\n",
    "# Sub-step names come from the first epoch's record; a separate name keeps the global\n",
    "# `time_keys` (the top-level phases) intact.\n",
    "sample_keys = list(sample_breakdown[0].keys())\n",
    "sample_totals = [sum(epoch_timing[key] for epoch_timing in sample_breakdown) for key in sample_keys]\n",
    "total_sample_time = sum(sample_totals)\n",
    "\n",
    "# Guard against all-zero timings (the percentage step divided by zero); the total is computed\n",
    "# once instead of once per sub-step.\n",
    "if total_sample_time > 0:\n",
    "    sample_timings = [100 * v / total_sample_time for v in sample_totals]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    # Unlabelled wedges with black borders; the legend carries the names and percentages.\n",
    "    wedges, texts, autotexts = ax.pie(sample_timings, labels=None, autopct='', startangle=90, wedgeprops=dict(linewidth=2, edgecolor='black'))\n",
    "    legend_labels = [f\"{key}: {time:.1f}%\" for key, time in zip(sample_keys, sample_timings)]\n",
    "    ax.legend(wedges, legend_labels, title=\"Time Components\", loc=\"center left\", bbox_to_anchor=(1, 0, 0.5, 1))\n",
    "    ax.set_title(\"Time Breakdown for Sampling Step\")\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No sampling timings recorded; skipping the breakdown.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23daff1f-f186-4802-868f-255c166f9d23",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1743db5-fe94-4bef-8114-54c48c7286b3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b021da5-4f51-4d82-bc4f-6ea293673f0c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "301f181d-e6f9-4cd7-afcd-7fb299a90fce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3f1fd5-f70a-470c-99dd-9ebf0018c17a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "739cd869-f1c2-4e8b-b195-3bdf749ab75c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0742b8c5-4a26-4685-9f62-b82bd0169657",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ccaeca2-0654-4606-b493-d1d46a67691c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b348c39b-62b8-4cba-b72e-e3a160e6eaa0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
