{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import seaborn as sns\n",
    "import tqdm\n",
    "\n",
    "from train_cnfqi import run"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sanity check: with force_left = 0, contrastive FQI (CFQI) and FQI should perform the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train contrastive FQI with force_left=0 a few times and record the BG / FG results run() reports.\n",
    "N_SANITY_RUNS = 2\n",
    "bg_successes = []\n",
    "fg_successes = []\n",
    "for i in range(N_SANITY_RUNS):\n",
    "    print(str(i))\n",
    "    printed_bg, printed_fg, performance, nfq_agent = run(verbose=True, is_contrastive=True, evaluations=1, force_left=0)\n",
    "    # Collect the BG / FG results returned by run() for this training.\n",
    "    bg_successes.append(printed_bg)\n",
    "    fg_successes.append(printed_fg)\n",
    "    print(\"BG Succeeded: \" + str(printed_bg))\n",
    "    print(\"FG Succeeded: \" + str(printed_fg))\n",
    "\n",
    "plt.title(\"Comparing BG and FG success for CFQI when force_left=0\")\n",
    "sns.stripplot(x=bg_successes, label='Background', color='blue')\n",
    "sns.stripplot(x=fg_successes, label='Foreground', color='red')\n",
    "# NOTE(review): the x label claims a scale of 10, but this cell makes N_SANITY_RUNS runs with evaluations=1 -- confirm.\n",
    "plt.xlabel(\"# of successful runs (out of 10)\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the weight of layers_fg[2] on the agent's internal network.\n",
    "getattr(nfq_agent._nfq_net.layers_fg[2], \"weight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-draw the BG vs FG comparison collected by the sanity-check cell.\n",
    "for _values, _label, _color in ((bg_successes, 'Background', 'blue'), (fg_successes, 'Foreground', 'red')):\n",
    "    sns.stripplot(x=_values, label=_label, color=_color)\n",
    "plt.title(\"Comparing BG and FG success for CFQI when force_left=0\")\n",
    "plt.xlabel(\"# of successful runs (out of 10)\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/amandyam/.conda/envs/research/lib/python3.6/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fg trained after 129 epochs\n",
      "BG stayed up for steps:  [1000, 1000, 181, 1000, 166, 1000, 1000, 1000, 120, 1000]\n",
      "FG stayed up for steps:  [307, 1000, 1000, 1000, 1000, 246, 1000, 1000, 1000, 136]\n",
      "Fg trained after 351 epochs\n",
      "BG stayed up for steps:  [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n",
      "FG stayed up for steps:  [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# force_left = 0: pool the per-evaluation performance of CFQI and plain FQI over several trainings.\n",
    "N_TRIALS_F0 = 10\n",
    "N_EVALS_F0 = 10\n",
    "cfqi_success = []\n",
    "fqi_success = []\n",
    "for i in tqdm.tqdm(range(N_TRIALS_F0), desc=\"force_left=0 trials\"):\n",
    "    # nfq_agent is kept because later cells inspect the most recently trained agent.\n",
    "    _, _, performance, nfq_agent = run(verbose=False, is_contrastive=True, evaluations=N_EVALS_F0, force_left=0)\n",
    "    cfqi_success.extend(performance)\n",
    "    _, _, performance, nfq_agent = run(verbose=False, is_contrastive=False, evaluations=N_EVALS_F0, force_left=0)\n",
    "    fqi_success.extend(performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of per-evaluation survival steps for CFQI vs FQI at force_left = 0.\n",
    "sns.stripplot(x=cfqi_success, label='CFQI', color='blue')\n",
    "sns.stripplot(x=fqi_success, label='FQI', color='red')\n",
    "plt.title(\"Force left = 0, CFQI and FQI performance\")\n",
    "plt.xlabel(\"Number of steps eval survived (out of 1000)\")\n",
    "plt.legend()\n",
    "# Render explicitly (like the other plotting cells) instead of echoing the Legend object.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Success vs. force_left"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_confidence_interval(data, confidence=0.95):\n",
    "    \"\"\"Return (mean, half_width) of a two-sided Student-t confidence interval for `data`.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of numbers.\n",
    "        confidence: two-sided confidence level, e.g. 0.95.\n",
    "\n",
    "    Returns:\n",
    "        (m, h) where m is the sample mean and the interval is m +/- h.\n",
    "        m is NaN for empty input; h is NaN when fewer than two samples are given.\n",
    "    \"\"\"\n",
    "    a = np.asarray(data, dtype=float)\n",
    "    n = len(a)\n",
    "    if n == 0:\n",
    "        return float('nan'), float('nan')\n",
    "    m = np.mean(a)\n",
    "    if n < 2:\n",
    "        # sem (ddof=1) and t.ppf with 0 degrees of freedom are undefined; skip their warnings.\n",
    "        return m, float('nan')\n",
    "    se = scipy.stats.sem(a)\n",
    "    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)\n",
    "    return m, h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep force_left and record, per method, the mean performance and its 95% CI half-width.\n",
    "N_FORCES = 10            # force_left values 0 .. N_FORCES - 1\n",
    "N_TRIALS_PER_FORCE = 10  # independent trainings per force_left value\n",
    "N_EVALS_PER_TRIAL = 2    # evaluations requested from each run()\n",
    "c_success = []  # CFQI mean performance, one entry per force_left\n",
    "f_success = []  # FQI mean performance, one entry per force_left\n",
    "c_errs = []     # CFQI CI half-widths\n",
    "f_errs = []     # FQI CI half-widths\n",
    "for force_left in range(N_FORCES):\n",
    "    cfqi_success = []\n",
    "    fqi_success = []\n",
    "    for trial in tqdm.tqdm(range(N_TRIALS_PER_FORCE), desc=f\"force_left={force_left}\"):\n",
    "        _, _, performance, nfq_agent = run(verbose=False, is_contrastive=True, evaluations=N_EVALS_PER_TRIAL, force_left=force_left)\n",
    "        cfqi_success.extend(performance)\n",
    "        _, _, performance, nfq_agent = run(verbose=False, is_contrastive=False, evaluations=N_EVALS_PER_TRIAL, force_left=force_left)\n",
    "        fqi_success.extend(performance)\n",
    "\n",
    "    # mean_confidence_interval already returns the mean, so it is not recomputed separately.\n",
    "    m, h = mean_confidence_interval(cfqi_success)\n",
    "    c_success.append(m)\n",
    "    c_errs.append(h)\n",
    "    m, h = mean_confidence_interval(fqi_success)\n",
    "    f_success.append(m)\n",
    "    f_errs.append(h)\n",
    "\n",
    "    # Redraw the cumulative sweep after every force value so partial results are visible.\n",
    "    forces_so_far = list(range(force_left + 1))\n",
    "    # One errorbar call per method keeps each method's markers and error bars the same colour.\n",
    "    plt.errorbar(forces_so_far, c_success, yerr=c_errs, fmt='o', capsize=3, label='CFQI')\n",
    "    plt.errorbar(forces_so_far, f_success, yerr=f_errs, fmt='o', capsize=3, label='FQI')\n",
    "    plt.title(\"Performance of CFQI and FQI when force on cart is modified\")\n",
    "    plt.xlabel(\"Force Left\")\n",
    "    plt.ylabel(\"Average steps the cartpole environment runs for\")\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Foreground group (group argument 1): action returned by nfq_agent.get_best_action over actions [0, 1]\n",
    "# for cart positions -1.0 .. 1.0, with cart velocity 0, pole angle -4 and pole angular velocity 0.\n",
    "# NOTE(review): nfq_agent is whichever agent the last executed run() produced; in top-to-bottom order\n",
    "# that is the is_contrastive=False agent from the force_left sweep -- confirm a contrastive agent is intended.\n",
    "# NOTE(review): confirm -4 is the intended pole angle (its units/scale are not visible here).\n",
    "cart_vel, pole_angle, pole_vel = 0, -4, 0\n",
    "positions = [p / 10 for p in range(-10, 11)]\n",
    "heatmap = np.asarray([\n",
    "    [nfq_agent.get_best_action(np.asarray([pos, cart_vel, pole_angle, pole_vel]), np.array([0, 1]), 1)]\n",
    "    for pos in positions\n",
    "])\n",
    "fig, ax = plt.subplots(figsize=(10, 2))\n",
    "# Tick labels come from the same positions list, so every column (including 1.0) is labelled.\n",
    "sns.heatmap(heatmap.T, xticklabels=positions, yticklabels=[pole_angle], ax=ax)\n",
    "ax.set_title(\"Foreground\")\n",
    "ax.set_xlabel(\"Cart Position\")\n",
    "ax.set_ylabel(\"Pole Angle\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Background group (group argument 0): action returned by nfq_agent.get_best_action over actions [0, 1]\n",
    "# for cart positions -1.0 .. 1.0, with cart velocity 0, pole angle -4 and pole angular velocity 0.\n",
    "# NOTE(review): nfq_agent is whichever agent the last executed run() produced; in top-to-bottom order\n",
    "# that is the is_contrastive=False agent from the force_left sweep -- confirm a contrastive agent is intended.\n",
    "# NOTE(review): confirm -4 is the intended pole angle (its units/scale are not visible here).\n",
    "cart_vel, pole_angle, pole_vel = 0, -4, 0\n",
    "positions = [p / 10 for p in range(-10, 11)]\n",
    "heatmap = np.asarray([\n",
    "    [nfq_agent.get_best_action(np.asarray([pos, cart_vel, pole_angle, pole_vel]), np.array([0, 1]), 0)]\n",
    "    for pos in positions\n",
    "])\n",
    "fig, ax = plt.subplots(figsize=(10, 2))\n",
    "# Tick labels come from the same positions list, so every column (including 1.0) is labelled.\n",
    "sns.heatmap(heatmap.T, xticklabels=positions, yticklabels=[pole_angle], ax=ax)\n",
    "ax.set_title(\"Background\")\n",
    "ax.set_xlabel(\"Cart Position\")\n",
    "ax.set_ylabel(\"Pole Angle\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display nfq_agent._nfq_net.layers_fg (presumably the foreground-branch layers).\n",
    "getattr(nfq_agent._nfq_net, \"layers_fg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-training success rate for FQI (nc_success) and CFQI (c_success) at run()'s default force_left.\n",
    "# run() yields four values (printed_bg, printed_fg, performance, agent); only performance is used here.\n",
    "# NOTE(review): an evaluation counts as successful when it lasts MAX_STEPS steps; the 1000 cap is taken\n",
    "# from the survival-steps plot label earlier in this notebook -- confirm against train_cnfqi.\n",
    "MAX_STEPS = 1000\n",
    "N_RATE_TRIALS = 10\n",
    "N_RATE_EVALS = 5\n",
    "nc_success = []\n",
    "c_success = []\n",
    "for i in tqdm.tqdm(range(N_RATE_TRIALS), desc=\"success-rate trials\"):\n",
    "    _, _, performance, _ = run(verbose=False, is_contrastive=False, evaluations=N_RATE_EVALS)\n",
    "    nc_success.append(np.mean(np.asarray(performance) >= MAX_STEPS))\n",
    "\n",
    "    _, _, performance, _ = run(verbose=False, is_contrastive=True, evaluations=N_RATE_EVALS)\n",
    "    c_success.append(np.mean(np.asarray(performance) >= MAX_STEPS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FQI: distribution of per-run success rates, box plot with the individual runs overlaid.\n",
    "ax = sns.boxplot(x=nc_success)\n",
    "# Pass x= explicitly and draw on the same axes so the swarm matches the box plot's orientation.\n",
    "sns.swarmplot(x=nc_success, color='.25', ax=ax)\n",
    "ax.set(title=\"FQI\", xlabel=\"Fraction of evaluations that were successful\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CFQI: distribution of per-run success rates, box plot with the individual runs overlaid.\n",
    "ax = sns.boxplot(x=c_success)\n",
    "# Pass x= explicitly and draw on the same axes so the swarm matches the box plot's orientation.\n",
    "sns.swarmplot(x=c_success, color='.25', ax=ax)\n",
    "ax.set(title=\"CFQI\", xlabel=\"Fraction of evaluations that were successful\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side comparison of per-run success rates for FQI and CFQI.\n",
    "fig, ax = plt.subplots()\n",
    "ax.boxplot([nc_success, c_success])\n",
    "ax.set_xticklabels(['FQI', 'CFQI'])\n",
    "# Values are fractions (successes / evaluations), so the axis says fraction, not percentage.\n",
    "ax.set_ylabel(\"Fraction of evaluations that were successful\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research [~/.conda/envs/research/]",
   "language": "python",
   "name": "conda_research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
