{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "deadly-broadcast",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os, sys\n",
    "sys.path.append('environments/')\n",
    "from generate_pendulum_tuples import tuples\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import pickle, os, csv, math, time, joblib\n",
    "from joblib import Parallel, delayed\n",
    "import datetime as dt\n",
    "from datetime import date, datetime, timedelta\n",
    "from collections import Counter\n",
    "import copy as cp\n",
    "import tqdm\n",
    "from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "from sklearn.metrics import log_loss, f1_score, precision_score, recall_score, accuracy_score\n",
    "#import matplotlib.pyplot as plt\n",
    "#import matplotlib.ticker as ticker\n",
    "import collections \n",
    "#import shap\n",
    "import seaborn as sns\n",
    "import random\n",
    "from sklearn.linear_model import LinearRegression\n",
    "np.seterr(all=\"ignore\")\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "import math\n",
    "import statsmodels.api as sm\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import numpy as np\n",
    "import json\n",
    "import util as util_fqi\n",
    "import sys\n",
    "sys.path.append('models/')\n",
    "from lmmfqi import LMMFQIagent\n",
    "from fqi import FQIagent\n",
    "from cfqi import CFQIagent\n",
    "import gym\n",
    "from gym import spaces\n",
    "from gym.utils import seeding\n",
    "import numpy as np\n",
    "from os import path\n",
    "from os.path import join as pjoin\n",
    "from pendulum import PendulumEnv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "injured-election",
   "metadata": {},
   "source": [
    "# Generate pendulum data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "desirable-squad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the RNGs so the shuffle below (and therefore the train/test split)\n",
    "# is identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "# NOTE(review): this only makes the generated tuples repeatable if `tuples`\n",
    "# draws from NumPy's global RNG - confirm against generate_pendulum_tuples.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Generate background and foreground transition tuples and pool them.\n",
    "bg_tuples, fg_tuples = tuples(n_trajectories=10)\n",
    "all_tuples = bg_tuples + fg_tuples\n",
    "random.shuffle(all_tuples)\n",
    "\n",
    "# The first `split` fraction of the shuffled tuples is used for training,\n",
    "# the remainder for testing.\n",
    "split = 0.8\n",
    "n_train = int(split * len(all_tuples))\n",
    "train_tuples = all_tuples[:n_train]\n",
    "test_tuples = all_tuples[n_train:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "continuous-firewall",
   "metadata": {},
   "source": [
    "# Train Agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ordinary-ocean",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N actions:  5\n",
      "Learning policy\n",
      "Run 0 :\n",
      "Initialize: get batch, set initial Q\n"
     ]
    }
   ],
   "source": [
    "# Fit the LMMFQI agent. After runFQI the two learned policies are available\n",
    "# as lmm_agent.piE_foreground and lmm_agent.piE_background.\n",
    "lmm_agent = LMMFQIagent(\n",
    "    train_tuples=train_tuples,\n",
    "    test_tuples=test_tuples,\n",
    "    estimator='gbm',\n",
    "    gamma=0.99,\n",
    "    state_dim=3,\n",
    "    batch_size=1600,\n",
    "    iters=1000,\n",
    ")\n",
    "Q_dist = lmm_agent.runFQI(repeats=1)\n",
    "\n",
    "# Plot the values returned by runFQI against the iteration index.\n",
    "lmm_ax = plt.gca()\n",
    "lmm_ax.plot(Q_dist, label=\"LMMFQI\")\n",
    "lmm_ax.set_xlabel(\"Iteration\")\n",
    "lmm_ax.set_ylabel(\"Q Estimate\")\n",
    "lmm_ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "innovative-bachelor",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the plain FQI agent on the same train/test tuples; its learned policy\n",
    "# is read later from fqi_agent.piE.\n",
    "fqi_agent = FQIagent(\n",
    "    train_tuples=train_tuples,\n",
    "    test_tuples=test_tuples,\n",
    "    state_dim=3,\n",
    "    gamma=0.5,\n",
    "    batch_size=1600,\n",
    "    iters=200,\n",
    "    estimator='gbm',\n",
    ")\n",
    "Q_dist = fqi_agent.runFQI(repeats=1)\n",
    "\n",
    "# Plot the values returned by runFQI against the iteration index.\n",
    "fqi_ax = plt.gca()\n",
    "fqi_ax.plot(Q_dist, label=\"FQI\")\n",
    "fqi_ax.set_xlabel(\"Iteration\")\n",
    "fqi_ax.set_ylabel(\"Q Estimate\")\n",
    "fqi_ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "objective-slovakia",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the `coefs_shared` attribute of the fitted LMMFQI Q-estimator.\n",
    "shared_coefs = lmm_agent.q_est.coefs_shared\n",
    "sns.heatmap(shared_coefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "flush-scotland",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the `coefs_fg` attribute of the fitted LMMFQI Q-estimator.\n",
    "fg_coefs = lmm_agent.q_est.coefs_fg\n",
    "sns.heatmap(fg_coefs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "equal-abortion",
   "metadata": {},
   "source": [
    "# Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "emotional-deployment",
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "def validate_agent(pend, ds='foreground', plot='action', n_steps=100):\n",
    "    \"\"\"Roll out every policy in `pend` and plot the FQI and LMMFQI results.\n",
    "\n",
    "    Each algorithm ('fqi', 'lmmfqi', 'oracle', 'random') gets its own\n",
    "    `pend.reset()` followed by `n_steps` environment steps. The second value\n",
    "    returned by `pend.step` is treated as a reward (the oracle keeps the\n",
    "    largest one). The function reads the notebook globals `fqi_agent`,\n",
    "    `lmm_agent`, `util_fqi`, `np` and `plt`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pend : environment object\n",
    "        Must expose `reset()`, `step(action)`, `_get_obs()` and\n",
    "        `action_space.sample()`, and must be deep-copyable (the oracle\n",
    "        evaluates candidate actions on copies).\n",
    "    ds : str\n",
    "        'foreground' queries `lmm_agent.piE_foreground` with group `[1]`; any\n",
    "        other value queries `lmm_agent.piE_background` with group `[0]`.\n",
    "    plot : str\n",
    "        'reward' draws cumulative reward; any other value draws the actions.\n",
    "    n_steps : int\n",
    "        Number of environment steps per algorithm (default 100).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    Exception\n",
    "        If an unknown algorithm name is encountered.\n",
    "    \"\"\"\n",
    "    import copy  # local import: keeps the cell independent of the notebook's `cp` alias\n",
    "\n",
    "    algos = ['fqi', 'lmmfqi', 'oracle', 'random']\n",
    "    val_rewards = {alg: [] for alg in algos}\n",
    "    alg_actions = {alg: [] for alg in algos}\n",
    "\n",
    "    for alg in algos:\n",
    "        print(\"Alg: \", alg)\n",
    "        pend.reset()\n",
    "\n",
    "        for _ in range(n_steps):\n",
    "            # Current observation as a single-row 2-D array for the policies.\n",
    "            s = pend._get_obs().reshape((1, -1))\n",
    "            if alg == 'fqi':\n",
    "                # FQI agent\n",
    "                fqi_action = fqi_agent.piE.predict(s)\n",
    "                alg_actions['fqi'].append(fqi_action[0])\n",
    "                _, reward, _, _ = pend.step(fqi_action)\n",
    "                # NOTE(review): the learned agents' rewards are divided by 10\n",
    "                # while oracle/random are stored unscaled - confirm this is\n",
    "                # intended before plotting them on the same axis.\n",
    "                val_rewards['fqi'].append(reward / 10)\n",
    "            elif alg == 'lmmfqi':\n",
    "                # LMMFQI agent: policy and group flag depend on the dataset.\n",
    "                if ds == 'foreground':\n",
    "                    lmmfqi_action = lmm_agent.piE_foreground.predict(s, [1])\n",
    "                else:\n",
    "                    lmmfqi_action = lmm_agent.piE_background.predict(s, [0])\n",
    "                alg_actions['lmmfqi'].append(lmmfqi_action[0])\n",
    "                _, reward, _, _ = pend.step(lmmfqi_action)\n",
    "                val_rewards['lmmfqi'].append(reward / 10)\n",
    "            elif alg == 'random':\n",
    "                # Random action, rounded to the nearest integer.\n",
    "                random_action = np.rint(pend.action_space.sample())\n",
    "                alg_actions['random'].append(random_action[0])\n",
    "                _, reward, _, _ = pend.step(random_action)\n",
    "                val_rewards['random'].append(reward)\n",
    "            elif alg == 'oracle':\n",
    "                # One-step greedy oracle. Every candidate action is tried on a\n",
    "                # throwaway copy of the environment so that `pend` itself only\n",
    "                # advances once per step, with the action that was chosen.\n",
    "                best_reward = None\n",
    "                best_action = None\n",
    "                for candidate in (-2, -1, 0, 1, 2):\n",
    "                    action = np.asarray([candidate])\n",
    "                    _, reward, _, _ = copy.deepcopy(pend).step(action)\n",
    "                    if best_reward is None or reward > best_reward:\n",
    "                        best_reward = reward\n",
    "                        best_action = action\n",
    "                _, reward, _, _ = pend.step(best_action)\n",
    "                alg_actions['oracle'].append(best_action[0])\n",
    "                val_rewards['oracle'].append(reward)\n",
    "            else:\n",
    "                raise Exception(\"Invalid algorithm selected\")\n",
    "\n",
    "    plt.xlabel(\"Step\")\n",
    "    x = list(range(n_steps))\n",
    "\n",
    "    # Only the two learned agents are drawn; the oracle and random rollouts are\n",
    "    # collected in val_rewards / alg_actions but not plotted.\n",
    "    if plot == 'reward':\n",
    "        rewards_fqi = util_fqi.cumulative_reward(val_rewards['fqi'])\n",
    "        rewards_lmmfqi = util_fqi.cumulative_reward(val_rewards['lmmfqi'])\n",
    "        plt.plot(x, rewards_fqi, label=\"FQI\", alpha=0.7)\n",
    "        plt.plot(x, rewards_lmmfqi, label='LMMFQI', alpha=0.7)\n",
    "        plt.ylabel(\"Cumulative Reward\")\n",
    "    else:\n",
    "        plt.plot(x, alg_actions['fqi'], label=\"FQI\", alpha=0.7)\n",
    "        plt.plot(x, alg_actions['lmmfqi'], label='LMMFQI', alpha=0.7)\n",
    "        plt.ylabel(\"Action\")\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "present-letter",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Foreground environment: pendulum constructed with m=5.0, evaluated with the\n",
    "# foreground policy and plotted as cumulative reward.\n",
    "pend = PendulumEnv(m=5.0)\n",
    "validate_agent(pend=pend, ds='foreground', plot='reward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "exact-revision",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Background environment: pendulum constructed with m=1.0, evaluated with the\n",
    "# background policy and plotted as cumulative reward.\n",
    "pend = PendulumEnv(m=1.0)\n",
    "validate_agent(pend=pend, ds='background', plot='reward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inner-rebate",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pressed-glass",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "numeric-geology",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
