{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105ee81b-a62d-4021-80dd-ee6d534b9d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import pandas as pd\n",
    "\n",
    "# Initialize wandb\n",
    "wandb.init(project=\"your_project_name\")\n",
    "\n",
    "# Fetch runs from your project\n",
    "api = wandb.Api()\n",
    "runs = api.runs(\"multi_reward_feedback\", filters={\"display_name\": {\"$regex\": \"^RL_.*\"}})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8516eb46-50d9-4978-955e-b6ab46c474ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a list to store data from filtered runs\n",
    "filtered_run_data = []\n",
    "\n",
    "# Iterate through the runs\n",
    "for run in runs:\n",
    "    # Check if the run name starts with \"ppo_\"\n",
    "    if run.name.startswith(\"RL_\") and \"noise\" not in run.name and \"ensemble\" not in run.name:\n",
    "        # Get the summary statistics (includes final values of metrics)\n",
    "        summary = run.summary._json_dict\n",
    "\n",
    "        # Get the history (includes all logged metrics)\n",
    "        history = run.history()\n",
    "\n",
    "        # Combine summary and history data\n",
    "        run_data = {\n",
    "            \"run_id\": run.id,\n",
    "            \"run_name\": run.name,\n",
    "            **summary,\n",
    "            **{f\"{k}_history\": v.tolist() for k, v in history.items()}\n",
    "        }\n",
    "\n",
    "        filtered_run_data.append(run_data)\n",
    "\n",
    "# Create a DataFrame from filtered run data\n",
    "orig_df = pd.DataFrame(filtered_run_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a602d039-ad57-4472-a062-59bda636ce3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict, OrderedDict\n",
    "from matplotlib.colors import hex2color, rgb2hex\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "\n",
    "def generate_muted_colors(hex_color, num_colors=10, mute_factor=0.5):\n",
    "    rgb = hex2color(hex_color)\n",
    "    muted_colors = []\n",
    "    for i in range(num_colors):\n",
    "        muted_rgb = tuple(c + (1 - c) * mute_factor for c in rgb)\n",
    "        muted_hex = rgb2hex(muted_rgb)\n",
    "        muted_colors.append(muted_hex)\n",
    "    return muted_colors\n",
    "\n",
    "df = orig_df.copy()\n",
    "\n",
    "score_field = \"eval/mean_reward_history\"\n",
    "\n",
    "# Function to extract environment and feedback type from run name\n",
    "def extract_info(run_name):\n",
    "    parts = run_name.split('_')\n",
    "    return parts[2], parts[-2]\n",
    "\n",
    "# Function to interpolate NaN values in a series\n",
    "def interpolate_nans(series):\n",
    "    return pd.Series(series).interpolate().ffill().bfill().values\n",
    "\n",
    "# Group runs by environment and feedback type\n",
    "grouped_runs = defaultdict(lambda: defaultdict(list))\n",
    "for _, row in df.iterrows():\n",
    "    env, feedback = extract_info(row['run_name'])\n",
    "    # Convert string \"nan\" to np.nan and other values to float\n",
    "    if isinstance(row[score_field], float):\n",
    "        print(\"WTF\")\n",
    "        continue\n",
    "    row[score_field] = [np.nan if x == \"nan\" else x for x in row[score_field]]\n",
    "    # Interpolate NaN values in val_loss_history\n",
    "    row[score_field] = interpolate_nans(row[score_field])\n",
    "    grouped_runs[env][feedback].append(row)\n",
    "\n",
    "# Read the CSV file with evaluation scores\n",
    "eval_df = pd.read_csv(\"../main/gt_agents/collected_results.csv\")\n",
    "\n",
    "# Define a color scale for feedback types\n",
    "color_scale = OrderedDict([\n",
    "    ('evaluative', '#1f77b4'),  # blue\n",
    "    ('comparative', '#ff7f0e'),  # orange\n",
    "    ('demonstrative', '#2ca02c'),  # green\n",
    "    ('corrective', '#d62728'),  # red\n",
    "    ('description', '#9467bd'),  # purple\n",
    "])\n",
    "\n",
    "# Plotting function\n",
    "def plot_environment(env, feedback_runs):\n",
    "    plt.figure(figsize=(15, 10))\n",
    "    \n",
    "    # Increase font size for all text elements\n",
    "    plt.rcParams.update({'font.size': 18})  # Adjust this value as needed\n",
    "    \n",
    "    for feedback in color_scale.keys():\n",
    "        if feedback not in feedback_runs:\n",
    "            continue  # Skip if this feedback type is not present for this environment\n",
    "        \n",
    "        runs = feedback_runs[feedback]\n",
    "        color = color_scale.get(feedback, '#7f7f7f')  # Default to gray if feedback type not in scale\n",
    "        muted_colors = generate_muted_colors(color)\n",
    "        \n",
    "        # Find the maximum length of steps\n",
    "        max_steps = max(len(run['global_step_history']) for run in runs)\n",
    "        \n",
    "        # Initialize arrays for losses and steps\n",
    "        all_losses = np.full((len(runs), max_steps), np.nan)\n",
    "        all_steps = np.full((len(runs), max_steps), np.nan)\n",
    "        \n",
    "        # Fill the arrays with available data\n",
    "        for i, run in enumerate(runs):\n",
    "            print(run[\"run_name\"])\n",
    "            length = len(run['global_step_history'])\n",
    "            all_losses[i, :length] = run[score_field]\n",
    "            all_steps[i, :length] = run['global_step_history']\n",
    "\n",
    "            smoothed_scores =  gaussian_filter1d(run[score_field], sigma=2)\n",
    "            plt.plot(run['global_step_history'], smoothed_scores, color=muted_colors[i], linewidth=1.5) \n",
    "\n",
    "            \n",
    "        \n",
    "        # Calculate statistics\n",
    "        mean_loss = np.nanmean(all_losses, axis=0)\n",
    "        min_loss = mean_loss - np.nanstd(all_losses, axis=0)\n",
    "        max_loss = mean_loss + np.nanstd(all_losses, axis=0)\n",
    "        \n",
    "        # Use the mean of steps for x-axis (ignoring NaNs)\n",
    "        steps = np.nanmean(all_steps, axis=0)\n",
    "        \n",
    "        # Remove NaN entries\n",
    "        valid = ~np.isnan(mean_loss)\n",
    "        steps = steps[valid]\n",
    "        mean_loss = mean_loss[valid]\n",
    "        min_loss = min_loss[valid]\n",
    "        max_loss = max_loss[valid]\n",
    "\n",
    "        display_feedback = feedback\n",
    "        if display_feedback == \"preference\":\n",
    "            display_feedback = \"descriptive pref.\"\n",
    "        elif display_feedback == \"descriptive\":\n",
    "            display_feedback = \"attention\"\n",
    "        elif display_feedback == \"description\":\n",
    "            display_feedback = \"Descriptive\"\n",
    "        plt.plot(steps, mean_loss, label=f\"{display_feedback.capitalize()}\", color=color, linewidth=3.0)\n",
    "        #plt.fill_between(steps, min_loss, max_loss, alpha=0.2, color=color)\n",
    "\n",
    "    # Filter eval scores for the current environment\n",
    "    env_eval_scores = eval_df[eval_df['env'] == env]\n",
    "    \n",
    "    # Sort and select the best four scores\n",
    "    best_scores = env_eval_scores.nlargest(4, 'eval_score')\n",
    "    \n",
    "    # Calculate statistics for the best scores\n",
    "    mean_score = best_scores['eval_score'].mean()\n",
    "    min_score = best_scores['eval_score'].min()\n",
    "    max_score = best_scores['eval_score'].max()\n",
    "    \n",
    "    # Plot evaluation scores as horizontal lines with updated styles\n",
    "    plt.axhline(y=mean_score, color='grey', linewidth=2.5)\n",
    "    plt.axhline(y=min_score, color='grey', linestyle='--', linewidth=2.0)\n",
    "    plt.axhline(y=max_score, color='grey', linestyle='--', linewidth=2.0)\n",
    "    \n",
    "    #plt.title(f\"Episode Reward for {env.capitalize()} Environment\", fontsize=20)\n",
    "    plt.xlabel(\"Env. Steps\", fontsize=18)\n",
    "    plt.ylabel(\"Episode Rew.\", fontsize=18)\n",
    "    plt.legend()\n",
    "    #plt.grid(True)\n",
    "    # Use log scale for y-axis if the range of values is large\n",
    "    #if np.nanmax(mean_loss) / np.nanmin(mean_loss[np.isfinite(mean_loss)]) > 100:\n",
    "    #    plt.yscale('log')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"rl_reward_curves_{env}.png\")\n",
    "    plt.close()\n",
    "    print(f\"Reward plot for {env} environment has been saved to rl_reward_curves_{env}.png\")\n",
    "\n",
    "# Create plots for each environment\n",
    "for env, feedback_runs in grouped_runs.items():\n",
    "    plot_environment(env, feedback_runs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rlhfblender",
   "language": "python",
   "name": "rlhfblender"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
