{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105ee81b-a62d-4021-80dd-ee6d534b9d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import pandas as pd\n",
    "\n",
    "# Open a wandb session under the \"analysis_runs\" project.\n",
    "# NOTE(review): nothing in this cell closes the session again, and the\n",
    "# public API used below may not require it - confirm before changing.\n",
    "wandb.init(project=\"analysis_runs\")\n",
    "\n",
    "# Query the runs stored under \"multi_reward_feedback\" through the public API\n",
    "api = wandb.Api()\n",
    "runs = api.runs(\"multi_reward_feedback\")\n",
    "\n",
    "# One dict per retained run; converted into DataFrame rows at the end\n",
    "filtered_run_data = []\n",
    "\n",
    "for run in runs:\n",
    "    # Keep only runs named \"ppo_*\" whose name does not contain \"noise\"\n",
    "    if not run.name.startswith(\"ppo_\") or \"noise\" in run.name:\n",
    "        continue\n",
    "\n",
    "    # Summary mapping as stored on the run object\n",
    "    summary = run.summary._json_dict\n",
    "\n",
    "    # Logged history as returned by run.history() with its default arguments\n",
    "    history = run.history()\n",
    "\n",
    "    # Identification first, then the summary entries, then one\n",
    "    # \"<column>_history\" list per history column; on a key collision the\n",
    "    # later source overwrites the earlier value.\n",
    "    run_data = {\"run_id\": run.id, \"run_name\": run.name}\n",
    "    run_data.update(summary)\n",
    "    run_data.update({f\"{k}_history\": v.tolist() for k, v in history.items()})\n",
    "\n",
    "    filtered_run_data.append(run_data)\n",
    "\n",
    "# Assemble the collected records into a single DataFrame\n",
    "df = pd.DataFrame(filtered_run_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a602d039-ad57-4472-a062-59bda636ce3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "\n",
    "def extract_info(run_name):\n",
    "    \"\"\"Pull the environment and feedback-type tokens out of a run name.\n",
    "\n",
    "    The name is split on underscores; the second token and the\n",
    "    second-to-last token are returned, in that order. A name without any\n",
    "    underscore raises IndexError.\n",
    "    \"\"\"\n",
    "    tokens = run_name.split('_')\n",
    "    env_token = tokens[1]\n",
    "    feedback_token = tokens[-2]\n",
    "    return env_token, feedback_token\n",
    "\n",
    "def interpolate_nans(series):\n",
    "    \"\"\"Return the values of ``series`` with its NaN entries filled in.\n",
    "\n",
    "    Gaps are filled by ``Series.interpolate()`` first; whatever is still\n",
    "    missing is then forward-filled and finally backward-filled.\n",
    "    \"\"\"\n",
    "    filled = pd.Series(series).interpolate()\n",
    "    filled = filled.ffill()\n",
    "    filled = filled.bfill()\n",
    "    return filled.values\n",
    "\n",
    "# Nested mapping: feedback type -> environment -> list of run rows\n",
    "grouped_runs = defaultdict(lambda: defaultdict(list))\n",
    "for _, row in df.iterrows():\n",
    "    env, feedback = extract_info(row['run_name'])\n",
    "    # Rows whose 'val_loss_history' cell is a plain float are skipped\n",
    "    # (presumably the NaN pandas fills in for runs lacking that column - confirm)\n",
    "    if not isinstance(row['val_loss_history'], float):\n",
    "        # Swap the literal string \"nan\" for np.nan (every other entry is kept\n",
    "        # unchanged), then fill the gaps before storing the curve on the row\n",
    "        row['val_loss_history'] = interpolate_nans(\n",
    "            [np.nan if x == \"nan\" else x for x in row['val_loss_history']]\n",
    "        )\n",
    "        grouped_runs[feedback][env].append(row)\n",
    "\n",
    "# Plotting function\n",
    "def plot_feedback_type(feedback, env_runs):\n",
    "    \"\"\"Plot per-environment validation-loss curves for one feedback type and save them as a PNG.\n",
    "\n",
    "    For every environment the runs are stacked into NaN-padded arrays; the\n",
    "    mean loss is drawn as a line and the min/max envelope as a shaded band.\n",
    "    The figure is written to ``loss_curves_{feedback}.png`` and then closed.\n",
    "\n",
    "    Args:\n",
    "        feedback: Feedback-type key; used for the title and the file name.\n",
    "        env_runs: Mapping of environment name to a list of runs. Each run\n",
    "            must support ``run['_step_history']`` and\n",
    "            ``run['val_loss_history']`` with sequences of matching length.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(15, 10))\n",
    "\n",
    "    # Increase font size for all text elements (this updates the global rcParams)\n",
    "    plt.rcParams.update({'font.size': 16})  # Adjust this value as needed\n",
    "\n",
    "    # Mean curve of every plotted environment, kept for the axis-scale decision below\n",
    "    plotted_means = []\n",
    "\n",
    "    for env, runs in env_runs.items():\n",
    "        # Nothing to aggregate for an environment without runs\n",
    "        if not runs:\n",
    "            continue\n",
    "\n",
    "        # Find the maximum length of steps\n",
    "        max_steps = max(len(run['_step_history']) for run in runs)\n",
    "\n",
    "        # Initialize arrays for losses and steps; shorter runs stay NaN-padded on the right\n",
    "        all_losses = np.full((len(runs), max_steps), np.nan)\n",
    "        all_steps = np.full((len(runs), max_steps), np.nan)\n",
    "\n",
    "        # Fill the arrays with available data\n",
    "        for i, run in enumerate(runs):\n",
    "            length = len(run['_step_history'])\n",
    "            all_losses[i, :length] = run['val_loss_history']\n",
    "            all_steps[i, :length] = run['_step_history']\n",
    "\n",
    "        # Calculate statistics\n",
    "        mean_loss = np.nanmean(all_losses, axis=0)\n",
    "        min_loss = np.nanmin(all_losses, axis=0)\n",
    "        max_loss = np.nanmax(all_losses, axis=0)\n",
    "\n",
    "        # Use the mean of steps for x-axis (ignoring NaNs)\n",
    "        steps = np.nanmean(all_steps, axis=0)\n",
    "\n",
    "        # Remove NaN entries\n",
    "        valid = ~np.isnan(mean_loss)\n",
    "        steps = steps[valid]\n",
    "        mean_loss = mean_loss[valid]\n",
    "        min_loss = min_loss[valid]\n",
    "        max_loss = max_loss[valid]\n",
    "\n",
    "        plt.plot(steps, mean_loss, label=f\"{env}\")\n",
    "        plt.fill_between(steps, min_loss, max_loss, alpha=0.2)\n",
    "        plotted_means.append(mean_loss)\n",
    "\n",
    "    # Title label shown for each feedback key; unknown keys are shown as-is.\n",
    "    # NOTE(review): the labels deliberately differ from the keys (e.g.\n",
    "    # \"descriptive\" is displayed as \"attention\") - confirm against the experiment naming.\n",
    "    display_labels = {\n",
    "        \"preference\": \"descriptive pref.\",\n",
    "        \"descriptive\": \"attention\",\n",
    "        \"description\": \"cluster description\",\n",
    "    }\n",
    "    display_feedback = display_labels.get(feedback, feedback)\n",
    "    plt.title(f\"Validation Loss Curves for {display_feedback.capitalize()} Feedback\", fontsize=20)\n",
    "    plt.xlabel(\"Steps\", fontsize=18)\n",
    "    plt.ylabel(\"Validation Loss\", fontsize=18)\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "\n",
    "    # Use log scale for y-axis if the range of values is large. The decision\n",
    "    # pools the mean curves of ALL plotted environments (not only the last one\n",
    "    # iterated) and requires a positive minimum, since a log axis cannot show\n",
    "    # non-positive values. With nothing plotted the linear scale is kept.\n",
    "    finite_means = [curve[np.isfinite(curve)] for curve in plotted_means]\n",
    "    pooled = np.concatenate(finite_means) if finite_means else np.array([])\n",
    "    if pooled.size and pooled.min() > 0 and pooled.max() / pooled.min() > 100:\n",
    "        plt.yscale('log')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"loss_curves_{feedback}.png\")\n",
    "    plt.close()\n",
    "\n",
    "    print(f\"Loss curves for {feedback} feedback have been saved to loss_curves_{feedback}.png\")\n",
    "\n",
    "# Render and save one figure per feedback type present in grouped_runs\n",
    "for feedback in grouped_runs:\n",
    "    env_runs = grouped_runs[feedback]\n",
    "    plot_feedback_type(feedback, env_runs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rlhfblender",
   "language": "python",
   "name": "rlhfblender"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
