{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warning: Flow failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'flow'\n",
      "/home/tarun/orl/lib/python3.10/site-packages/glfw/__init__.py:914: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'\n",
      "  warnings.warn(message, GLFWError)\n",
      "Warning: CARLA failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'carla'\n",
      "pybullet build time: Nov 28 2023 23:45:17\n",
      "/home/tarun/orl/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "import d4rl \n",
    "import gym \n",
    "from pathlib import Path\n",
    "from tabulate import tabulate\n",
    "import train_behavior_model as train_behavior_model\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_with_shaded_error_simple(data_dict, title='Plot Title', xlabel='X-axis', ylabel='Y-axis'):\n",
    "    \"\"\"\n",
    "    Plots multiple lines with shaded areas, assuming that only y-values are provided.\n",
    "\n",
    "    Parameters:\n",
    "    - data_dict: dict\n",
    "        A dictionary where keys are labels (for the legend) and values are lists of y-values.\n",
    "    - title: str\n",
    "        The title of the plot.\n",
    "    - xlabel: str\n",
    "        The label for the x-axis.\n",
    "    - ylabel: str\n",
    "        The label for the y-axis.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(8, 6))\n",
    "\n",
    "    for label, y_values in data_dict.items():\n",
    "        x_values = np.arange(len(y_values))  # Use indices as x-values\n",
    "        y_values = np.array(y_values)\n",
    "        y_mean = np.mean(y_values)\n",
    "        y_std = np.std(y_values)\n",
    "        \n",
    "        # Plot mean line\n",
    "        plt.plot(x_values, y_values, label=label)\n",
    "        \n",
    "        # Plot shaded area as ± standard deviation\n",
    "        plt.fill_between(x_values, y_values - y_std, y_values + y_std, alpha=0.2)\n",
    "\n",
    "    # Customize the plot\n",
    "    plt.title(title)\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.legend()\n",
    "\n",
    "    # Show the plot\n",
    "    \n",
    "    local_folder = Path('local_plots/')\n",
    "    local_folder.mkdir(parents = True, exist_ok= True)\n",
    "    plt.savefig(f'{local_folder}/{title}.png', bbox_inches='tight')\n",
    "    \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load(path):\n",
    "    with open(path, 'r') as f:\n",
    "        data = json.load(f)\n",
    "    return data['true_pfms'], data['sim_pfms']\n",
    "\n",
    "def behavior_clone(env, seed, data_num):\n",
    "    behavior_model = train_behavior_model.behavior_model(env_name=env, ac_kwargs=dict(hidden_sizes=[64] * 3))\n",
    "    behavior_path = '../behavior/%s_%d_%d' % (env, data_num, seed)\n",
    "    if not os.path.isfile(behavior_path):\n",
    "        raise Exception('behavior clone not done')\n",
    "    else:\n",
    "        behavior_model.action_nn.load_state_dict(torch.load(behavior_path))\n",
    "    test_result = behavior_model.test()\n",
    "    return test_result\n",
    "\n",
    "def visualize_results(env_name, data_num):\n",
    "    results = []\n",
    "    to_plot = {}\n",
    "    env = gym.make(env_name)\n",
    "    #clone', 'baseline', 'behavior', 'brac', 'orl', 'oracle', 'rm'\n",
    "    for learn in ['clone']:\n",
    "        if learn == 'clone':\n",
    "            res = [env.get_normalized_score(behavior_clone(env_name, seed, data_num)) for seed in range(3)]\n",
    "            results.append([learn, round(np.mean(res) * 100.0, 1), np.std(res) * 100])\n",
    "            \n",
    "        else:\n",
    "            res = []\n",
    "            res2 = []\n",
    "            for seed in range(3):\n",
    "                file = Path(f'/home/tarun/PBRL/rlhf/result/{learn}_{env_name}/outputs_{seed}.json')\n",
    "                \n",
    "                if not file.exists():\n",
    "                    continue\n",
    "                result_tup = load(file)\n",
    "                \n",
    "                result = [env.get_normalized_score(eval_score) for eval_score in result_tup[0]]\n",
    "                \n",
    "                res.append(np.mean(result[-10:]))\n",
    "            results.append([learn, round(np.mean(sorted(res, reverse = True)) * 100.0, 1), np.std(res) * 100])\n",
    "            print(np.std(res) * 100)\n",
    "            #to_plot[learn] = result\n",
    "            \n",
    "    \n",
    "    print(f'\\n Overall Results for {env_name}: \\n')\n",
    "    print(tabulate(results, headers=['Learn', 'Normalized Performance', 'Std'], tablefmt=\"pipe\"))\n",
    "    #plot_with_shaded_error_simple(to_plot, title= f'{env_name}', xlabel= 'Epochs', ylabel = 'D4RL Normalized Score')\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_results2(env_name):\n",
    "    results = []\n",
    "    to_plot = {}\n",
    "    env = gym.make(env_name)\n",
    "    #clone', 'baseline', 'behavior', 'brac', 'orl', 'oracle', 'rm'\n",
    "    for learn in ['naive-behavior-ppo', 'brac-ppo', 'behavior', 'rm', 'oracle']:\n",
    "        if learn == 'clone':\n",
    "            res = [env.get_normalized_score(behavior_clone(env_name, seed, data_num)) for seed in range(3)]\n",
    "            results.append([learn, round(np.mean(res) * 100.0, 1)])\n",
    "        else:\n",
    "            res = []\n",
    "            res2 = []\n",
    "            for seed in range(3):\n",
    "                file = Path(f'/home/tarun/PBRL/rlhf/result/{learn}_{env_name}/outputs_{seed}.json')\n",
    "                \n",
    "                if not file.exists():\n",
    "                    continue\n",
    "                result_tup = load(file)\n",
    "                \n",
    "                result = [env.get_normalized_score(eval_score) for eval_score in result_tup[0]]\n",
    "                \n",
    "                res.append(np.max(result))\n",
    "            results.append([learn, round(np.mean(sorted(res, reverse = True)) * 100.0, 1), np.std(res) * 100])\n",
    "            print(res)\n",
    "            #to_plot[learn] = result\n",
    "            \n",
    "    \n",
    "    print(f'\\n Overall Results for {env_name}: \\n')\n",
    "    print(tabulate(results, headers=['Learn', 'Normalized Performance', 'Margain of Error'], tablefmt=\"pipe\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.8875825736969482, 0.91301352332615, 1.0020923989755741]\n",
      "[0.9436912244399447, 0.8980798365671756]\n",
      "[1.1125094711737722, 1.1003188846137049, 1.0979854281504753]\n",
      "[0.9603249626814807, 1.1002420082155961, 1.1110339602092405]\n",
      "[1.109254079048405, 1.1185356208238026, 1.1184250393226272]\n",
      "\n",
      " Overall Results for walker2d-medium-expert-v2: \n",
      "\n",
      "| Learn              |   Normalized Performance |   Margain of Error |\n",
      "|:-------------------|-------------------------:|-------------------:|\n",
      "| naive-behavior-ppo |                     93.4 |           4.90966  |\n",
      "| brac-ppo           |                     92.1 |           2.28057  |\n",
      "| behavior           |                    110.4 |           0.636835 |\n",
      "| rm                 |                    105.7 |           6.86428  |\n",
      "| oracle             |                    111.5 |           0.434953 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results2(f'walker2d-medium-expert-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tarun/orl/lib/python3.10/site-packages/gym/spaces/box.py:84: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5049.980164970453\n",
      "4957.315729038743\n",
      "4844.279422655241\n",
      "\n",
      " Overall Results for halfcheetah-medium-v2: \n",
      "\n",
      "| Learn   |   Normalized Performance |      Std |\n",
      "|:--------|-------------------------:|---------:|\n",
      "| clone   |                     42.1 | 0.677511 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results(f'halfcheetah-medium-v2', 24975)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1650.536575403404\n",
      "1803.743429622819\n",
      "1495.280841818763\n",
      "\n",
      " Overall Results for hopper-medium-v2: \n",
      "\n",
      "| Learn   |   Normalized Performance |     Std |\n",
      "|:--------|-------------------------:|--------:|\n",
      "| clone   |                     51.3 | 3.86933 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results(f'hopper-medium-v2', 23958)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1194.4640148386238\n",
      "1447.8712732305974\n",
      "2040.5825880092586\n",
      "\n",
      " Overall Results for hopper-medium-expert-v2: \n",
      "\n",
      "| Learn   |   Normalized Performance |     Std |\n",
      "|:--------|-------------------------:|--------:|\n",
      "| clone   |                     48.6 | 10.8943 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results(f'hopper-medium-expert-v2', 48892)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "321.5357316757302\n",
      "343.447265118752\n",
      "386.90099612356977\n",
      "\n",
      " Overall Results for walker2d-medium-replay-v2: \n",
      "\n",
      "| Learn   |   Normalized Performance |      Std |\n",
      "|:--------|-------------------------:|---------:|\n",
      "| clone   |                      7.6 | 0.591722 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results(f'walker2d-medium-replay-v2', 7172)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Overall Results for walker2d-expert-v2: \n",
      "\n",
      "| Learn              |   Normalized Performance |\n",
      "|:-------------------|-------------------------:|\n",
      "| behavior           |                    109.7 |\n",
      "| orl                |                    110.7 |\n",
      "| oracle             |                    111.2 |\n",
      "| rm                 |                    110   |\n",
      "| naive-behavior-ppo |                    108.7 |\n"
     ]
    }
   ],
   "source": [
    "visualize_results(f'walker2d-expert-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Overall Results for hopper-medium-v2: \n",
      "\n",
      "| Learn              |   Normalized Performance |\n",
      "|:-------------------|-------------------------:|\n",
      "| behavior           |                     67.2 |\n",
      "| orl                |                    nan   |\n",
      "| oracle             |                     76.3 |\n",
      "| rm                 |                     67.2 |\n",
      "| naive-behavior-ppo |                     68.9 |\n",
      "\n",
      " Overall Results for hopper-medium-expert-v2: \n",
      "\n",
      "| Learn              |   Normalized Performance |\n",
      "|:-------------------|-------------------------:|\n",
      "| behavior           |                     86.4 |\n",
      "| orl                |                    nan   |\n",
      "| oracle             |                    113.1 |\n",
      "| rm                 |                     97.1 |\n",
      "| naive-behavior-ppo |                     65.5 |\n",
      "\n",
      " Overall Results for hopper-expert-v2: \n",
      "\n",
      "| Learn              |   Normalized Performance |\n",
      "|:-------------------|-------------------------:|\n",
      "| behavior           |                    106.7 |\n",
      "| orl                |                    nan   |\n",
      "| oracle             |                    113.2 |\n",
      "| rm                 |                    113.2 |\n",
      "| naive-behavior-ppo |                     48.4 |\n"
     ]
    }
   ],
   "source": [
    "for traj_type in ['medium', 'medium-expert', 'expert']:\n",
    "    for env_name in ['hopper']:\n",
    "        visualize_results(f'{env_name}-{traj_type}-v2')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['true_pfms', 'sim_pfms'])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "import pandas as pd\n",
    "res = json.load(open('/home/tarun/PBRL/rlhf/result/brac_hopper-medium-v2/outputs_1.json', 'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_csv(file_path, save_name, col):\n",
    "    res = json.load(open(file_path, 'r'))\n",
    "    res = {'step': list(range(len(res[col]))), save_name: res[col]}\n",
    "    df = pd.DataFrame.from_dict(res)\n",
    "    df.to_csv(f'{save_name}.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "convert_to_csv('/home/tarun/PBRL/rlhf/result/brac_hopper-medium-v2/outputs_1.json', 'initalsetup_learn=baseline_env=hopper-medium-v2_alg=sac_seed=1_', 'sim_pfms')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "orl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
