{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mountaincar Environment\n",
    "* Evaluation episodes can start anywhere from the left up to the goal state, with zero velocity (the same holds for training). They need 71 episodes.\n",
    "* Modify cartpole to only have two actions: left and right. The magnitude of the actions is much larger in the NFQ paper.\n",
    "* Hint to goal, which sometimes makes the agent perform worse\n",
    "* Group: the magnitude of the action\n",
    "* Made the forces symmetric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import configargparse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from environments import MountainCarEnv, Continuous_MountainCarEnv\n",
    "from models.agents import NFQAgent\n",
    "from models.networks import NFQNetwork, ContrastiveNFQNetwork\n",
    "from util import get_logger, close_logger, load_models, make_reproducible, save_models\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "import seaborn as sns\n",
    "import tqdm\n",
    "from train_mountaincar import generate_data\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Running experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## \"Structureless Test\"\n",
    "* The dynamics of the systems are actually the same. Do any of the algorithms learn a difference?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from train_mountaincar import fqi, warm_start, transfer_learning\n",
    "\n",
    "# Repeat the structureless experiment num_iter times and keep the\n",
    "# (background, foreground) performance pair that fqi returns for each run.\n",
    "num_iter = 15\n",
    "perf_foreground = []\n",
    "perf_background = []\n",
    "for i in range(num_iter):\n",
    "    print(str(i))\n",
    "    perf_bg, perf_fg = fqi(epoch=1500, gravity=0.0025, verbose=True, is_contrastive=True, structureless=True, hint_to_goal=False)\n",
    "    perf_foreground.append(perf_fg)\n",
    "    perf_background.append(perf_bg)\n",
    "\n",
    "# sns.distplot is deprecated since seaborn 0.11; histplot with a KDE overlay and\n",
    "# density normalisation is its documented replacement. Both distributions are\n",
    "# drawn on the same Axes (with explicit colours) so they can be compared directly.\n",
    "fig, ax = plt.subplots()\n",
    "sns.histplot(perf_foreground, kde=True, stat=\"density\", color=\"C0\", label='Foreground Performance', ax=ax)\n",
    "sns.histplot(perf_background, kde=True, stat=\"density\", color=\"C1\", label='Background Performance', ax=ax)\n",
    "ax.legend()\n",
    "ax.set_xlabel(\"Average Reward Earned\")\n",
    "ax.set_title(\"Dynamics are the same in fg and bg environments\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## \"Performance when force left is different\"\n",
    "* We change the gravity on the foreground environments. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Group imbalance test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from train_mountaincar import fqi, warm_start, transfer_learning\n",
    "\n",
    "num_iter = 2\n",
    "GRAVITY = 0.004\n",
    "total_samples = 400\n",
    "fg_sample_fractions = [0.1 * x for x in np.arange(1, 6)]\n",
    "\n",
    "# results[fg fraction][method][run index] -> (perf_bg, perf_fg)\n",
    "results = {\n",
    "    fraction: {\"fg_only\": {}, \"cfqi\": {}, \"fqi_joint\": {}}\n",
    "    for fraction in fg_sample_fractions\n",
    "}\n",
    "\n",
    "# Keyword arguments shared by every fqi call in this experiment.\n",
    "common_kwargs = dict(epoch=1500, verbose=False, gravity=GRAVITY)\n",
    "\n",
    "for i in range(num_iter):\n",
    "    for fg_sample_fraction in fg_sample_fractions:\n",
    "        # Split the fixed sample budget between foreground and background.\n",
    "        n_fg = int(total_samples * fg_sample_fraction)\n",
    "        n_bg = int(total_samples - n_fg)\n",
    "\n",
    "        # A list (not a dict) so the three methods always run in this order.\n",
    "        method_kwargs = [\n",
    "            # Only train/test on small set of foreground samples\n",
    "            (\n",
    "                \"fg_only\",\n",
    "                dict(\n",
    "                    is_contrastive=True,\n",
    "                    structureless=False,\n",
    "                    fg_only=True,\n",
    "                    init_experience_bg=n_fg // 2,\n",
    "                    init_experience_fg=n_fg // 2,\n",
    "                ),\n",
    "            ),\n",
    "            # Use contrastive model with larger pool of background samples\n",
    "            (\n",
    "                \"cfqi\",\n",
    "                dict(\n",
    "                    is_contrastive=True,\n",
    "                    init_experience_bg=n_bg,\n",
    "                    init_experience_fg=n_fg,\n",
    "                    fg_only=False,\n",
    "                ),\n",
    "            ),\n",
    "            # Use non-contrastive model with larger pool of background samples\n",
    "            (\n",
    "                \"fqi_joint\",\n",
    "                dict(\n",
    "                    is_contrastive=False,\n",
    "                    init_experience_bg=n_bg,\n",
    "                    init_experience_fg=n_fg,\n",
    "                    fg_only=False,\n",
    "                ),\n",
    "            ),\n",
    "        ]\n",
    "        for method, kwargs in method_kwargs:\n",
    "            perf_bg, perf_fg = fqi(**common_kwargs, **kwargs)\n",
    "            results[fg_sample_fraction][method][i] = (perf_bg, perf_fg)\n",
    "\n",
    "        # Checkpoint after every fraction so partial results survive an interrupted run.\n",
    "        with open(\"class_imbalance_cfqi.json\", \"w\") as f:\n",
    "            json.dump(results, f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing out some of the other methods\n",
    "* Allowing it to succeed, maybe after some training\n",
    "* Adding the successful episodes to the train_rollouts\n",
    "* TODO: modifying the environment to include stronger actions\n",
    "* TODO: modifying training regime\n",
    "* TODO: evaluating appropriate times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_car(agent, env, eps=0.9, max_steps=None):\n",
    "    \"\"\"Run one epsilon-greedy episode in ``env`` and record every transition.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agent : object\n",
    "        Must provide ``get_best_action(obs, unique_oh_actions, group)``.\n",
    "    env : object\n",
    "        Must provide ``reset()``, ``step(action)``, ``a_to_oh(action)`` and the\n",
    "        attributes ``unique_actions``, ``unique_oh_actions`` and ``group``.\n",
    "    eps : float, default 0.9\n",
    "        Probability of taking a uniformly random action instead of the\n",
    "        agent's best action at each step.\n",
    "    max_steps : int or None, default None\n",
    "        Optional cap on the number of environment steps. ``None`` keeps the\n",
    "        original behaviour of running until the environment reports ``done``,\n",
    "        which can take hundreds of thousands of steps with a poor policy.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    episode_length : int\n",
    "        Number of steps taken.\n",
    "    success : bool\n",
    "        The final ``done`` flag: True when the environment ended the episode,\n",
    "        False when the rollout was cut off by ``max_steps``.\n",
    "    episode_cost : number\n",
    "        Sum of the per-step costs returned by ``env.step``.\n",
    "    rollouts : list of tuple\n",
    "        One ``(obs, action, cost, next_obs, done, group)`` tuple per step, with\n",
    "        both observations squeezed.\n",
    "    \"\"\"\n",
    "    episode_length = 0\n",
    "    episode_cost = 0\n",
    "    rollouts = []\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    while not done and (max_steps is None or episode_length < max_steps):\n",
    "        if random.random() < eps:\n",
    "            # Explore: draw a random action and convert it with env.a_to_oh.\n",
    "            action = np.random.choice(env.unique_actions)\n",
    "            action = env.a_to_oh(action)\n",
    "        else:\n",
    "            # Exploit: let the agent choose among the env's one-hot actions.\n",
    "            action = agent.get_best_action(obs, env.unique_oh_actions, env.group)\n",
    "\n",
    "        next_obs, cost, done, _info = env.step(action)\n",
    "        rollouts.append((obs.squeeze(), action, cost, next_obs.squeeze(), done, env.group))\n",
    "        episode_cost += cost\n",
    "        obs = next_obs\n",
    "        episode_length += 1\n",
    "\n",
    "    # Without max_steps the loop only exits on done, so this is always truthy;\n",
    "    # with a cap it tells a finished episode apart from a truncated one.\n",
    "    success = done\n",
    "    return episode_length, success, episode_cost, rollouts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 1206/10000 [00:15<03:15, 44.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.95 Episode Length:  1735\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 1505/10000 [00:24<05:02, 28.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.9025 Episode Length:  447\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 1805/10000 [00:35<10:39, 12.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.8573749999999999 Episode Length:  2112\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 21%|██        | 2104/10000 [00:54<27:13,  4.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.8145062499999999 Episode Length:  4590\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 2404/10000 [01:12<09:45, 12.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.7737809374999999 Episode Length:  372\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|██▋       | 2704/10000 [01:30<07:01, 17.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.7350918906249998 Episode Length:  19\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3004/10000 [01:48<09:19, 12.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.6983372960937497 Episode Length:  593\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 3304/10000 [02:07<13:55,  8.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.6634204312890623 Episode Length:  1182\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 3603/10000 [05:56<79:05:15, 44.51s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.6302494097246091 Episode Length:  404757\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 39%|███▉      | 3903/10000 [06:14<10:48,  9.40it/s]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.5987369392383786 Episode Length:  477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 4202/10000 [08:00<15:00:28,  9.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.5688000922764596 Episode Length:  147409\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 4503/10000 [12:26<79:16:14, 51.91s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.5403600876626365 Episode Length:  390441\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 4803/10000 [22:54<92:26:57, 64.04s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 0.5133420832795047 Episode Length:  919353\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 51%|█████     | 5099/10000 [23:33<04:38, 17.62it/s]   "
     ]
    }
   ],
   "source": [
    "# Initial rollouts plus the two training environments.\n",
    "train_rollouts, train_env_bg, train_env_fg = generate_data(\n",
    "    init_experience_fg=10,\n",
    "    init_experience_bg=10,\n",
    "    bg_only=False,\n",
    "    structureless=True,\n",
    "    initialize_model=False,\n",
    ")\n",
    "\n",
    "# Non-contrastive, shallow network trained with Adam.\n",
    "nfq_net = ContrastiveNFQNetwork(\n",
    "    state_dim=train_env_bg.state_dim,\n",
    "    is_contrastive=False,\n",
    "    deep=False,\n",
    ")\n",
    "optimizer = optim.Adam(nfq_net.parameters(), lr=1e-2)\n",
    "nfq_agent = NFQAgent(nfq_net, optimizer)\n",
    "\n",
    "episodes = 10000\n",
    "losses = []\n",
    "eps = 1\n",
    "for ep in tqdm.tqdm(range(episodes)):\n",
    "    # One training step on a pattern set rebuilt from the current rollouts.\n",
    "    state_action_b, target_q_values, groups = nfq_agent.generate_pattern_set(train_rollouts)\n",
    "    loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "    losses.append(loss)\n",
    "\n",
    "    # New experience is only collected on every 300th iteration after the first 1000.\n",
    "    if not (ep > 1000 and ep % 300 == 0):\n",
    "        continue\n",
    "\n",
    "    # Decay exploration by 5%, roll out one epsilon-greedy episode in the\n",
    "    # background environment and add its transitions to the training data.\n",
    "    eps *= 0.95\n",
    "    episode_length, success, episode_cost, rollouts = evaluate_car(nfq_agent, train_env_bg, eps=eps)\n",
    "    print(\"Eps: \" + str(eps) + \" Episode Length: \", episode_length)\n",
    "    train_rollouts.extend(rollouts)\n",
    "\n",
    "    # Keep only the 5000 most recent transitions.\n",
    "    overflow = len(train_rollouts) - 5000\n",
    "    if overflow > 0:\n",
    "        train_rollouts = train_rollouts[overflow:]\n",
    "\n",
    "sns.displot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "    # NOTE(review): this cell is an indented function body without its `def` line\n",
    "    # (it even ends in `return`), so it cannot be executed as written. Several\n",
    "    # names it uses (hint_to_goal, epoch, render, verbose, is_contrastive,\n",
    "    # eval_env_bg, eval_env_fg) are not defined anywhere in this notebook.\n",
    "    # Presumably it was pasted from train_mountaincar.fqi as scratch -- TODO\n",
    "    # confirm, then either wrap it in a function or delete the cell.\n",
    "\n",
    "    # Optionally build the goal pattern sets for both environments (group 0 = bg,\n",
    "    # group 1 = fg) and convert states/actions and targets to float tensors.\n",
    "    # group_bg / group_fg are not converted here but are passed to torch.cat\n",
    "    # below -- presumably they are already tensors; TODO confirm.\n",
    "    if hint_to_goal:\n",
    "        (\n",
    "            goal_state_action_b_bg,\n",
    "            goal_target_q_values_bg,\n",
    "            group_bg,\n",
    "        ) = train_env_bg.get_goal_pattern_set(group=0)\n",
    "        (\n",
    "            goal_state_action_b_fg,\n",
    "            goal_target_q_values_fg,\n",
    "            group_fg,\n",
    "        ) = train_env_fg.get_goal_pattern_set(group=1)\n",
    "\n",
    "        goal_state_action_b_bg = torch.FloatTensor(goal_state_action_b_bg)\n",
    "        goal_target_q_values_bg = torch.FloatTensor(goal_target_q_values_bg)\n",
    "        goal_state_action_b_fg = torch.FloatTensor(goal_state_action_b_fg)\n",
    "        goal_target_q_values_fg = torch.FloatTensor(goal_target_q_values_fg)\n",
    "\n",
    "    \n",
    "\n",
    "    # Sliding windows holding the success flags of the 3 most recent evaluations.\n",
    "    bg_success_queue = [0] * 3\n",
    "    fg_success_queue = [0] * 3\n",
    "    evaluations = 5\n",
    "    losses = []\n",
    "    for k, ep in enumerate(tqdm.tqdm(range(epoch + 1))):\n",
    "        \n",
    "        # NOTE(review): state_action_b, target_q_values and groups are never\n",
    "        # (re)assigned inside this loop. With hint_to_goal set, the goal patterns\n",
    "        # are concatenated onto the previous iteration's tensors, so they grow\n",
    "        # every epoch. A generate_pattern_set call may be missing -- TODO confirm.\n",
    "        if hint_to_goal:\n",
    "            goal_state_action_b = torch.cat(\n",
    "                [goal_state_action_b_bg, goal_state_action_b_fg], dim=0\n",
    "            )\n",
    "            goal_target_q_values = torch.cat(\n",
    "                [goal_target_q_values_bg, goal_target_q_values_fg], dim=0\n",
    "            )\n",
    "            state_action_b = torch.cat([state_action_b, goal_state_action_b], dim=0)\n",
    "            target_q_values = torch.cat([target_q_values, goal_target_q_values], dim=0)\n",
    "            goal_groups = torch.cat([group_bg, group_fg], dim=0)\n",
    "            groups = torch.cat([groups, goal_groups], dim=0)\n",
    "\n",
    "        loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "        losses.append(loss)\n",
    "\n",
    "\n",
    "        # Evaluate once per environment after every training step and push the\n",
    "        # outcome into the matching 3-slot success window.\n",
    "        (\n",
    "            eval_episode_length_bg,\n",
    "            eval_success_bg,\n",
    "            eval_episode_cost_bg,\n",
    "        ) = nfq_agent.evaluate_car(eval_env_bg, render=render)\n",
    "        bg_success_queue = bg_success_queue[1:]\n",
    "        bg_success_queue.append(1 if eval_success_bg else 0)\n",
    "\n",
    "        (\n",
    "            eval_episode_length_fg,\n",
    "            eval_success_fg,\n",
    "            eval_episode_cost_fg,\n",
    "        ) = nfq_agent.evaluate_car(eval_env_fg, render=render)\n",
    "        fg_success_queue = fg_success_queue[1:]\n",
    "        fg_success_queue.append(1 if eval_success_fg else 0)\n",
    "\n",
    "        # Switch phase when the background has succeeded 3 evaluations in a row\n",
    "        # and the shared layers are not frozen yet, or unconditionally once ep\n",
    "        # reaches 75% of `epoch`.\n",
    "        if (sum(bg_success_queue) == 3 and not nfq_net.freeze_shared == True) or ep == int(epoch*0.75):\n",
    "            nfq_net.freeze_shared = True\n",
    "            if verbose:\n",
    "                print(\"FREEZING SHARED\")\n",
    "            if is_contrastive:\n",
    "                # Contrastive model: freeze the shared layers, keep fg layers trainable.\n",
    "                for param in nfq_net.layers_shared.parameters():\n",
    "                    param.requires_grad = False\n",
    "                for param in nfq_net.layers_last_shared.parameters():\n",
    "                    param.requires_grad = False\n",
    "                for param in nfq_net.layers_fg.parameters():\n",
    "                    param.requires_grad = True\n",
    "                for param in nfq_net.layers_last_fg.parameters():\n",
    "                    param.requires_grad = True\n",
    "            else:\n",
    "                # NOTE(review): in this branch the fg layers are frozen, yet the\n",
    "                # optimizer created below is built from those same fg parameters\n",
    "                # -- verify this is intended.\n",
    "                for param in nfq_net.layers_fg.parameters():\n",
    "                    param.requires_grad = False\n",
    "                for param in nfq_net.layers_last_fg.parameters():\n",
    "                    param.requires_grad = False\n",
    "\n",
    "            # Replace the agent's optimizer with a new Adam (lr=1e-1) over the\n",
    "            # fg-specific layers only.\n",
    "            optimizer = optim.Adam(\n",
    "                itertools.chain(\n",
    "                    nfq_net.layers_fg.parameters(),\n",
    "                    nfq_net.layers_last_fg.parameters(),\n",
    "                ),\n",
    "                lr=1e-1,\n",
    "            )\n",
    "            nfq_agent._optimizer = optimizer\n",
    "        # Stop training once the foreground has succeeded 3 evaluations in a row.\n",
    "        if sum(fg_success_queue) == 3:\n",
    "            if verbose:\n",
    "                print(\"FG Trained\")\n",
    "            break\n",
    "\n",
    "        # Progress report every 600 epochs: `evaluations` episodes per environment.\n",
    "        # NOTE(review): all four environments are closed inside the evaluation\n",
    "        # loop although the eval environments are used again on later iterations\n",
    "        # -- confirm close() is safe to call repeatedly.\n",
    "        if ep % 600 == 0:\n",
    "            perf_bg = []\n",
    "            perf_fg = []\n",
    "            for it in range(evaluations):\n",
    "                (\n",
    "                    eval_episode_length_bg,\n",
    "                    eval_success_bg,\n",
    "                    eval_episode_cost_bg,\n",
    "                ) = nfq_agent.evaluate_car(eval_env_bg, render=render)\n",
    "                (\n",
    "                    eval_episode_length_fg,\n",
    "                    eval_success_fg,\n",
    "                    eval_episode_cost_fg,\n",
    "                ) = nfq_agent.evaluate_car(eval_env_fg, render=render)\n",
    "                perf_bg.append(eval_episode_cost_bg)\n",
    "                perf_fg.append(eval_episode_cost_fg)\n",
    "                train_env_bg.close()\n",
    "                train_env_fg.close()\n",
    "                eval_env_bg.close()\n",
    "                eval_env_fg.close()\n",
    "            if verbose:\n",
    "                print(\n",
    "                    \"Evaluation bg: \" + str(perf_bg) + \" Evaluation fg: \" + str(perf_fg)\n",
    "                )\n",
    "    # Final evaluation: evaluations * 10 episodes per environment; the mean\n",
    "    # episode cost of each environment is what gets returned.\n",
    "    perf_bg = []\n",
    "    perf_fg = []\n",
    "    for it in range(evaluations * 10):\n",
    "        (\n",
    "            eval_episode_length_bg,\n",
    "            eval_success_bg,\n",
    "            eval_episode_cost_bg,\n",
    "        ) = nfq_agent.evaluate_car(eval_env_bg, render=render)\n",
    "        (\n",
    "            eval_episode_length_fg,\n",
    "            eval_success_fg,\n",
    "            eval_episode_cost_fg,\n",
    "        ) = nfq_agent.evaluate_car(eval_env_fg, render=render)\n",
    "        perf_bg.append(eval_episode_cost_bg)\n",
    "        perf_fg.append(eval_episode_cost_fg)\n",
    "        eval_env_bg.close()\n",
    "        eval_env_fg.close()\n",
    "    if verbose:\n",
    "        print(\n",
    "            \"Evaluation bg: \"\n",
    "            + str(sum(perf_bg) / len(perf_bg))\n",
    "            + \" Evaluation fg: \"\n",
    "            + str(sum(perf_fg) / len(perf_fg))\n",
    "        )\n",
    "    # Plot the distribution of the training losses (sns.distplot is deprecated\n",
    "    # in seaborn >= 0.11).\n",
    "    sns.distplot(losses)\n",
    "    return sum(perf_bg) / len(perf_bg), sum(perf_fg) / len(perf_fg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research [~/.conda/envs/research/]",
   "language": "python",
   "name": "conda_research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
