{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimal on LQR\n",
    "\n",
    "## Define paramters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T14:07:49.474494Z",
     "iopub.status.busy": "2022-09-19T14:07:49.474265Z",
     "iopub.status.idle": "2022-09-19T14:07:50.101602Z",
     "shell.execute_reply": "2022-09-19T14:07:50.101089Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "import jax\n",
    "import os\n",
    "import json\n",
    "\n",
    "parameters = json.load(open(\"parameters.json\"))\n",
    "env_seed = parameters[\"env_seed\"]\n",
    "max_discrete_state = parameters[\"max_discrete_state\"]\n",
    "\n",
    "# keys\n",
    "env_key = jax.random.PRNGKey(env_seed)\n",
    "dummy_q_network_key = env_key.copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T14:07:50.105869Z",
     "iopub.status.busy": "2022-09-19T14:07:50.105607Z",
     "iopub.status.idle": "2022-09-19T14:07:50.373661Z",
     "shell.execute_reply": "2022-09-19T14:07:50.373125Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transition: s' = As + Ba\n",
      "Transition: s' = 0.7204725742340088s + -0.5264108180999756a\n",
      "Reward: Qs² + Ra² + 2 Ssa\n",
      "Reward: -0.13297832012176514s² + -0.8039400577545166a² + 0.2581009864807129sa\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from pbo.environments.linear_quadratic import LinearQuadraticEnv\n",
    "\n",
    "\n",
    "env = LinearQuadraticEnv(env_key, max_init_state=max_discrete_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimal weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T14:07:50.376088Z",
     "iopub.status.busy": "2022-09-19T14:07:50.375926Z",
     "iopub.status.idle": "2022-09-19T14:07:50.390328Z",
     "shell.execute_reply": "2022-09-19T14:07:50.389799Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([-0.22707027,  0.19779846, -0.8541705 ], dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.optimal_weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T14:07:50.393304Z",
     "iopub.status.busy": "2022-09-19T14:07:50.392799Z",
     "iopub.status.idle": "2022-09-19T14:07:50.475005Z",
     "shell.execute_reply": "2022-09-19T14:07:50.474490Z"
    }
   },
   "outputs": [],
   "source": [
    "if not os.path.exists(\"figures/data/optimal/\"):\n",
    "    os.makedirs(\"figures/data/optimal/\")\n",
    "np.save(f\"figures/data/optimal/W.npy\", env.optimal_weights)\n",
    "np.save(f\"figures/data/optimal/Pi.npy\", env.greedy_V(env.optimal_weights))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('env_cpu': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "af5525a3273d35d601ae265c5d3634806dd61a1c4d085ae098611a6832982bdb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
