{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "altered-south",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Standard library ---------------------------------------------------\n",
    "import collections\n",
    "import copy as cp\n",
    "import csv\n",
    "import datetime as dt\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "from collections import Counter\n",
    "# After this line the bare name `datetime` is the class; the module stays available as `dt`.\n",
    "from datetime import date, datetime, timedelta\n",
    "from os import path\n",
    "from os.path import join as pjoin\n",
    "\n",
    "# --- Third-party ----------------------------------------------------------\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import joblib\n",
    "from joblib import Parallel, delayed\n",
    "import tqdm\n",
    "from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "from sklearn.metrics import log_loss, f1_score, precision_score, recall_score, accuracy_score\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "import gym\n",
    "from gym import spaces\n",
    "from gym.utils import seeding\n",
    "\n",
    "# Turn off NumPy's floating-point error reporting for the whole session.\n",
    "np.seterr(all=\"ignore\")\n",
    "\n",
    "# --- Project-local --------------------------------------------------------\n",
    "# The appended paths are relative, so the notebook is expected to be run from\n",
    "# the directory that contains `environments/` and `models/`.\n",
    "# Each append stays directly ahead of the imports that followed it originally.\n",
    "sys.path.append('environments/')\n",
    "from cartpole_regulator import CartPoleRegulatorEnv\n",
    "import util as util_fqi\n",
    "sys.path.append('models/')\n",
    "from lmmfqi import LMMFQIagent\n",
    "from fqi import FQIagent\n",
    "from cfqi import CFQIagent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "proved-notification",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/aishwaryamandyam/anaconda3/envs/research/lib/python3.7/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n"
     ]
    }
   ],
   "source": [
    "# Environment instance; its generate_tuples method supplies the data for the train/test split.\n",
    "cartpole = CartPoleRegulatorEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "responsible-british",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate tuples for both groups, then build a shuffled train/test split.\n",
    "N_ITER = 10000        # passed as n_iter to generate_tuples for each group\n",
    "TRAIN_FRACTION = 0.8  # share of the shuffled tuples that goes to the training split\n",
    "SPLIT_SEED = 0        # fixed seed so the shuffle order is repeatable for a given input list\n",
    "\n",
    "fg_tuples = cartpole.generate_tuples(group='foreground', n_iter=N_ITER)\n",
    "bg_tuples = cartpole.generate_tuples(group='background', n_iter=N_ITER)\n",
    "all_tuples = bg_tuples + fg_tuples\n",
    "\n",
    "# A dedicated seeded generator keeps the shuffle deterministic across re-runs\n",
    "# without depending on (or disturbing) the global `random` state.\n",
    "# NOTE(review): the split is only fully reproducible if generate_tuples is itself deterministic - confirm.\n",
    "random.Random(SPLIT_SEED).shuffle(all_tuples)\n",
    "\n",
    "split = TRAIN_FRACTION\n",
    "split_idx = int(split * len(all_tuples))  # computed once and shared by both slices\n",
    "train_tuples = all_tuples[:split_idx]\n",
    "test_tuples = all_tuples[split_idx:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "brutal-acquisition",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of tuples in the training split.\n",
    "len(train_tuples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "elementary-algeria",
   "metadata": {},
   "source": [
    "# Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "differential-hardware",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N actions:  2\n",
      "Learning policy\n",
      "Run 0 :\n",
      "Initialize: get batch, set initial Q\n",
      "Learn policy\n",
      "Opta:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "This solver needs samples of at least 2 classes in the data, but the data contains only one class: 0",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-38d8b4d160b6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mfqi_agent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mFQIagent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_tuples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_tuples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_tuples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest_tuples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'gbm'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mQ_dist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfqi_agent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrunFQI\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrepeats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mQ_dist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;34m\"FQI\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxlabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Iteration\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mylabel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Q 
Estimate\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/Research/BEE/contrastive-rl/contrastive-rl/simulated_fqi/models/fqi.py\u001b[0m in \u001b[0;36mrunFQI\u001b[0;34m(self, repeats)\u001b[0m\n\u001b[1;32m    138\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mQtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeanQtable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Learn policy'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetPi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeanQtable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    141\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mQdist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/Research/BEE/contrastive-rl/contrastive-rl/simulated_fqi/models/fqi.py\u001b[0m in \u001b[0;36mgetPi\u001b[0;34m(self, Qtable)\u001b[0m\n\u001b[1;32m    155\u001b[0m         \u001b[0;31m#print(\"Optimal actions: \", optA)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    156\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptA\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptA\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpiE\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_set\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m's'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptA\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    158\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Fit score: \"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpiE\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_set\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m's'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptA\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    159\u001b[0m         \u001b[0;31m#print(\"Done 
Fitting\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/research/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m   1374\u001b[0m             raise ValueError(\"This solver needs samples of at least 2 classes\"\n\u001b[1;32m   1375\u001b[0m                              \u001b[0;34m\" in the data, but the data contains only one\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1376\u001b[0;31m                              \" class: %r\" % classes_[0])\n\u001b[0m\u001b[1;32m   1377\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1378\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclasses_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: This solver needs samples of at least 2 classes in the data, but the data contains only one class: 0"
     ]
    }
   ],
   "source": [
    "# Build an FQI agent on the train/test split, run it once, and plot what runFQI returns.\n",
    "# NOTE(review): in the saved output of this cell, runFQI raised a ValueError from getPi\n",
    "# (every entry of the printed optimal-action array was 0), so the plotting lines were not reached.\n",
    "fqi_agent = FQIagent(\n",
    "    train_tuples=train_tuples,\n",
    "    test_tuples=test_tuples,\n",
    "    state_dim=4,\n",
    "    gamma=0.5,\n",
    "    batch_size=64,\n",
    "    iters=1000,\n",
    "    estimator='gbm',\n",
    ")\n",
    "Q_dist = fqi_agent.runFQI(repeats=1)\n",
    "\n",
    "# Q_dist is plotted against its index, labelled as the iteration on the x-axis.\n",
    "plt.plot(Q_dist, label=\"FQI\")\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Q Estimate\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "sacred-economics",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.0637384 , 0.03365227],\n",
       "       [0.09662369, 0.06749495],\n",
       "       [0.09105657, 0.06192782],\n",
       "       [0.09662369, 0.06749495],\n",
       "       [0.04782297, 0.01773684],\n",
       "       [0.05847077, 0.02838463],\n",
       "       [0.03132398, 0.00123785],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.04782297, 0.01773684],\n",
       "       [0.04222579, 0.01213966],\n",
       "       [0.05114537, 0.02105924],\n",
       "       [0.09105657, 0.06192782],\n",
       "       [0.0759316 , 0.04584547],\n",
       "       [0.05847077, 0.02838463],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.14906228, 0.11993354],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.14906228, 0.11993354],\n",
       "       [0.05394114, 0.023855  ],\n",
       "       [0.07905213, 0.04896599],\n",
       "       [0.05114537, 0.02105924],\n",
       "       [0.07905213, 0.04896599],\n",
       "       [0.04782297, 0.01773684],\n",
       "       [0.06140751, 0.03132137],\n",
       "       [0.05114537, 0.02105924],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.06140751, 0.03132137],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.05394114, 0.023855  ],\n",
       "       [0.0637384 , 0.03365227],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.03132398, 0.00123785],\n",
       "       [0.05847077, 0.02838463],\n",
       "       [0.09105657, 0.06192782],\n",
       "       [0.0759316 , 0.04584547],\n",
       "       [0.14906228, 0.11993354],\n",
       "       [0.14906228, 0.11993354],\n",
       "       [0.07905213, 0.04896599],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.07935299, 0.04926685],\n",
       "       [0.06845118, 0.03836504],\n",
       "       [0.04782297, 0.01773684],\n",
       "       [0.03132398, 0.00123785],\n",
       "       [0.0637384 , 0.03365227],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.0759316 , 0.04584547],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.04782297, 0.01773684],\n",
       "       [0.05114537, 0.02105924],\n",
       "       [0.05394114, 0.023855  ],\n",
       "       [0.05114537, 0.02105924],\n",
       "       [0.08461364, 0.0554849 ],\n",
       "       [0.08339798, 0.05331184],\n",
       "       [0.14906228, 0.11993354],\n",
       "       [0.0759316 , 0.04584547],\n",
       "       [0.        , 0.        ]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the Qtable attribute that runFQI stored on the agent.\n",
    "fqi_agent.Qtable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unnecessary-change",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
