{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import numpy as np\n",
    "\n",
    "# custom libraries \n",
    "from envs import BasicGrid\n",
    "from functools import partial\n",
    "from utils import * \n",
    "from agents import Pi1Agent, LambdaR\n",
    "from runners import * \n",
    "from tqdm import tqdm\n",
    "\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "# Create a ListedColormap from the extracted colors\n",
    "# Define the color segments for the colormap\n",
    "segments = [(i/(len(colors)-1), colors[i]) for i in range(len(colors))]\n",
    "\n",
    "# Create a LinearSegmentedColormap from the color segments\n",
    "cmap = LinearSegmentedColormap.from_list(name='my_colormap', colors=segments)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dynamic programming\n",
    "def compute_LF(LR, Phi, P, pi, lambda_, gamma, max_iters, q_true=None, r=None, tol=1e-2):\n",
    "    S, A, _ = P.shape\n",
    "    _, D = Phi.shape\n",
    "    deltas, mses = [], []\n",
    "    for _ in tqdm(range(max_iters)):\n",
    "        delta = 0.0\n",
    "        for s in range(S):\n",
    "            feat = Phi[s]\n",
    "            for a in range(A):\n",
    "                lr = np.zeros(D)\n",
    "                for stp1 in range(S):\n",
    "                    for atp1 in range(A):\n",
    "                        lr += pi[stp1, atp1] * P[s, a, stp1] * feat * (\n",
    "                            1 + lambda_ * gamma * LR[stp1, atp1, :])\n",
    "                        lr += pi[stp1, atp1] * P[s, a, stp1] * gamma * (\n",
    "                            1 - feat) * LR[stp1, atp1, :]\n",
    "                delta = max(np.max(delta), np.max(np.abs(LR[s, a, :] - lr)))\n",
    "                LR[s, a, :] = lr\n",
    "        deltas.append(delta)\n",
    "        if q_true is not None and r is not None:\n",
    "            q_pred = LR @ r\n",
    "            mses.append(np.mean((q_true - q_pred.flatten())**2))\n",
    "        if delta < tol:\n",
    "            break\n",
    "    return LR, deltas, mses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = BasicGrid(lambda_=0.5)\n",
    "policy = Pi1Agent(env._number_of_states, env.get_obs(s=env._start_state)).q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get optimal Q\n",
    "LR = np.zeros((env._number_of_states, env._number_of_actions, env._number_of_states))\n",
    "LR, deltas, _ = compute_LF(LR, np.eye(env._number_of_states), env.P, policy, env.lambda_, env.discount, 100, tol=5e-2)\n",
    "q_opt = LR @ env.r\n",
    "v_opt = np.max(q_opt, axis=1)\n",
    "q_opt = q_opt.flatten()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lambdas = [0.0, 0.5, 1.0]\n",
    "dp_results = dict()\n",
    "for i, lambda_ in enumerate(lambdas):\n",
    "    dp_results[lambda_] = dict()\n",
    "    LR = np.zeros((env._number_of_states, env._number_of_actions, env._number_of_states))\n",
    "    LR, deltas, mses = compute_LF(LR, np.eye(env._number_of_states), env.P, policy, lambda_, env.discount, 10, q_true=q_opt, r=env.r, tol=-5)\n",
    "    dp_results[lambda_]['LR'] = LR\n",
    "    dp_results[lambda_]['deltas'] = deltas\n",
    "    dp_results[lambda_]['mses'] = mses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tabular TD learning\n",
    "env = BasicGrid(lambda_=0.5)\n",
    "policy = Pi1Agent(env._number_of_states, env.get_obs(s=env._start_state))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pi = [policy]\n",
    "td_LRs_all = {}\n",
    "lambdas = [0.5, 1.0]\n",
    "# multiple start states to get DR\n",
    "start_states = [7, 8, 9, 10, 13, 14, 16, 19, 20, 21, 22, 25, 26, 27, 28]\n",
    "n_repeats = 3\n",
    "for lambda_ in lambdas:\n",
    "  td_LRs_all[lambda_] = []\n",
    "  for i in range(n_repeats):\n",
    "    print(f\"\\nlambda {lambda_}, expt \", i)\n",
    "    np.random.seed(i)\n",
    "    LRs = []\n",
    "    for i, pi in enumerate(Pi):\n",
    "        lr_agent = LambdaR(env._layout.size, 4, env.get_obs(),\n",
    "                  policy=partial(epsilon_greedy, epsilon=0.2), q=pi.q, \n",
    "                  step_size=0.1, sa=True, lambda_=lambda_)\n",
    "        for _ in range(100):\n",
    "          for start in start_states:\n",
    "              lr_agent._state = start\n",
    "              grid = BasicGrid(start_state=start, discount=env.discount, lambda_=0.5)\n",
    "              results = run_experiment_episodic(grid, lr_agent, 1, display_eps=1, respect_done=False)\n",
    "          \n",
    "              LRs += results['lambdaR_hist']\n",
    "\n",
    "    td_LRs_all[lambda_].append(LRs)\n",
    "\n",
    "  td_LRs_all[lambda_] = np.stack(td_LRs_all[lambda_])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
