{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob\n",
    "import pandas\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from evaluate import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_std(ax, x, y, label, color):\n",
    "    mean = smooth(np.mean(y, axis=0))\n",
    "    std = smooth(np.std(y, axis=0))\n",
    "    ci = 0.15*std\n",
    "    # sns.set_style(\"darkgrid\")\n",
    "    ax.plot(x, mean, label=label, color=color)\n",
    "    ax.fill_between(x, mean-ci, mean+ci, facecolor=color, alpha=0.3)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rpe2ut(rpe1, rpe2):\n",
    "    r1 = np.array([rpe1,rpe2])\n",
    "    r1 = np.transpose(r1, axes=[1,2,0])\n",
    "    ut = avg_utility(r1)\n",
    "    return ut\n",
    "\n",
    "def rpe2pd(rpe11, rpe12, rpe21, rpe22):\n",
    "    r1 = np.array([rpe11,rpe12]) \n",
    "    r2 = np.array([rpe21,rpe22])\n",
    "    r1 = np.transpose(r1, axes=[1,2,0])\n",
    "    r2 = np.transpose(r2, axes=[1,2,0])\n",
    "    pd = pareto_dominance(r1, r2)\n",
    "    return pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_ut(envs, dir_name, params, verbose=0):\n",
    "    UT = dict()\n",
    "    for env in envs:\n",
    "        Ut = dict()\n",
    "        for key in params:\n",
    "            nA = params[key][\"nA\"]\n",
    "            eps = params[key][\"eps\"]\n",
    "            ud = params[key][\"ud\"]\n",
    "            frames, rpe1, rpe2, epsilon, ploss = synthesize(env, dir_name, eps=eps, nA=nA, ud=ud)\n",
    "            ut = rpe2ut(rpe1, rpe2)\n",
    "            ut_mean = np.mean(ut)\n",
    "            ut_std = np.std(ut)\n",
    "            Ut[key] = ut\n",
    "            eps_name = \"\".join(str(eps))\n",
    "            if verbose:\n",
    "                out = f\"Env = {env}, eps = {eps_name}, adaptive = {not nA}, ud = {ud}, ut_mean = {ut_mean}, ut_std={ut_std}\"\n",
    "                print(out)\n",
    "        UT[env] = Ut\n",
    "    return UT, frames\n",
    "\n",
    "def calc_pd(envs, dir_name, params, verbose=0):\n",
    "    PD = dict()\n",
    "    for env in envs:\n",
    "        pd1 = dict()\n",
    "        for method1 in params.keys():\n",
    "            pd2 = []\n",
    "            for method2 in params.keys():\n",
    "                if method1 != method2:\n",
    "                    eps1, nA1, ud1 = params[method1][\"eps\"], params[method1][\"nA\"], params[method1][\"ud\"]\n",
    "                    eps2, nA2, ud2 = params[method2][\"eps\"], params[method2][\"nA\"], params[method2][\"ud\"]\n",
    "                    frames, rpe11, rpe12, epsilon, ploss = synthesize(env, dir_name, eps=eps1, nA=nA1, ud=ud1)\n",
    "                    frames, rpe21, rpe22, epsilon, ploss = synthesize(env, dir_name, eps=eps2, nA=nA2, ud=ud2)\n",
    "                    pd = rpe2pd(rpe11, rpe12, rpe21, rpe22)\n",
    "                    pd_mean = np.mean(pd)\n",
    "                    pd_std = np.std(pd)\n",
    "                    if verbose:\n",
    "                        out = f\"Env = {env}, {method1} -> {method2}, pd_mean = {pd_mean}, pd_std={pd_std}\"\n",
    "                        print(out)\n",
    "                    pd2.append(pd)\n",
    "            if pd2:\n",
    "                pd2 = np.concatenate(pd2)\n",
    "            pd1[method1] = pd2\n",
    "        PD[env] = pd1 \n",
    "    return PD, frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set globals\n",
    "dir_name = \"mogw_dst_final\"\n",
    "envs = [\"Rover\", \"DST\"]\n",
    "params = {\n",
    "    \"ppa\": {\"eps\": None, \"nA\": False, \"ud\": False, \"label\": \"PPA\"}, \n",
    "    # \"ppaud\": {\"eps\": None, \"nA\": False, \"ud\": True, \"label\": \"up-PPA\"},\n",
    "    \"nappa\": {\"eps\": None, \"nA\": True, \"ud\": False, \"label\": \"na-PPA\"},\n",
    "    \"fp1\": {\"eps\": [0.5,0.5], \"nA\": False, \"ud\": False, \"label\": \"$\\omega_1=0.5$\"},\n",
    "    \"fp2\": {\"eps\": [0.75,0.25], \"nA\": False, \"ud\": False, \"label\": \"$\\omega_1=0.75$\"},\n",
    "    \"fp3\": {\"eps\": [0.25,0.75], \"nA\": False, \"ud\": False, \"label\": \"$\\omega_1=0.25$\"},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calc metrics\n",
    "UT, _ = calc_ut(envs, dir_name, params)\n",
    "PD, frames = calc_pd(envs, dir_name, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot reward curves\n",
    "fig, ax = plt.subplots(2,len(envs), figsize=(10,8))\n",
    "clrs = sns.color_palette(\"husl\", len(params))\n",
    "for i, env in enumerate(envs):\n",
    "    for j, key in enumerate(params):\n",
    "        nA = params[key][\"nA\"]\n",
    "        eps = params[key][\"eps\"]\n",
    "        ud = params[key][\"ud\"]\n",
    "        label = params[key][\"label\"]\n",
    "        frames, rpe1, rpe2, epsilon, ploss = synthesize(env, dir_name, eps=eps, ud=ud, nA=nA)\n",
    "        plot_mean_std(ax=ax[i,0], x=frames[0], y=rpe1, label=label, color=clrs[j])\n",
    "        plot_mean_std(ax=ax[i,1], x=frames[0], y=rpe2, label=label, color=clrs[j])\n",
    "        ax[i,0].set_xlabel(\"frames\")\n",
    "        ax[i,0].set_ylabel(\"episode return\")\n",
    "        ax[i,1].set_xlabel(\"frames\")\n",
    "        ax[i,1].set_ylabel(\"time penalty\")\n",
    "\n",
    "ax[0,0].set_title(\"Multi-Objective GridWorld Episode Reward\")\n",
    "ax[0,1].set_title(\"Multi-Objective GridWorld Episode Time Penalty\")\n",
    "\n",
    "ax[1,0].set_title(\"Deep Sea Treasure Episode Reward\")\n",
    "ax[1,1].set_title(\"Deep Sea Treasure Episode Time Penalty\")\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.legend(bbox_to_anchor = (0.58,-0.156), ncol=5)\n",
    "plt.savefig(\"plots/mogw_dst.pdf\", bbox_inches = 'tight',\n",
    "    pad_inches = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot metrics \n",
    "fig, ax = plt.subplots(2,len(envs), figsize=(14,7))\n",
    "clrs = sns.color_palette(\"husl\", len(params))\n",
    "per = 0.5 \n",
    "for i,env in enumerate(envs): \n",
    "    for j,key in enumerate(params):\n",
    "        nA = params[key][\"nA\"]\n",
    "        eps = params[key][\"eps\"]\n",
    "        ud = params[key][\"ud\"]\n",
    "        label = params[key][\"label\"]\n",
    "        vals = PD[env][key]\n",
    "        end = int(0.5*vals.shape[1])\n",
    "        plot_mean_std(ax=ax[0,i], x=2*frames[0,:end], y=vals[:,:end], color=clrs[j], label=label)\n",
    "        vals = UT[env][key]\n",
    "        plot_mean_std(ax=ax[1,i], x=2*frames[0,:end], y=vals[:,:end], color=clrs[j], label=label)\n",
    "        ax[0,i].set_xlabel(\"frames\")\n",
    "        ax[0,i].set_ylabel(\"pareto dominance\")\n",
    "        ax[0,i].set_title(env)\n",
    "        ax[1,i].set_xlabel(\"frames\")\n",
    "        ax[1,i].set_ylabel(\"average utility\")\n",
    "        ax[1,i].set_title(env)\n",
    "fig.tight_layout()\n",
    "plt.legend(bbox_to_anchor = (-0.3,-0.2), ncol=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from env_rover import RoverEnv\n",
    "from dst import DeepSea\n",
    "rover = RoverEnv()\n",
    "rover.reset()\n",
    "rover_img = rover.render(highlight=False,mode='rgb')\n",
    "dst = DeepSea()\n",
    "dst.reset()\n",
    "dst_image = dst.render(highlight=False, mode='rgb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "from dst import DeepSea\n",
    "img = DeepSea().sea_map\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "im=ax.imshow(img)\n",
    "fig.colorbar(im)\n",
    "ax.set_title(\"DST Enviroment Reward Map\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"dst.pdf\",bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "from env_rover import RoverEnv\n",
    "env = RoverEnv()\n",
    "env.reset()\n",
    "img = env.render(highlight=False, mode='rgb')\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "im=ax.imshow(img)\n",
    "ax.set_title(\"MOGW Envriroment\")\n",
    "plt.savefig(\"mogw.pdf\",bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python392jvsc74a57bd0a61fc779d08deaee957108e0966e05422a283b1c762ee0e088310fcc993d6c21",
   "display_name": "Python 3.9.2 64-bit ('torch': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "metadata": {
   "interpreter": {
    "hash": "a61fc779d08deaee957108e0966e05422a283b1c762ee0e088310fcc993d6c21"
   }
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}