{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "from train_cnfqi import run\n",
    "import seaborn as sns\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt \n",
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "import shap\n",
    "from environments import AcrobotEnv\n",
    "from models.agents import NFQAgent\n",
    "from models.networks import NFQNetwork, ContrastiveNFQNetwork\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import itertools\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_contrastive = True\n",
    "init_experience = 200\n",
    "epoch = 1000\n",
    "evaluations = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_env_bg = AcrobotEnv(group=0, mode=\"train\")\n",
    "train_env_fg = AcrobotEnv(group=1, mode=\"train\")\n",
    "eval_env_bg = AcrobotEnv(group=0, mode=\"eval\")\n",
    "eval_env_fg = AcrobotEnv(group=1, mode=\"eval\")\n",
    "\n",
    "# Setup agent\n",
    "nfq_net = ContrastiveNFQNetwork(state_dim=train_env_bg.state_dim, is_contrastive=is_contrastive)\n",
    "\n",
    "if is_contrastive:\n",
    "    optimizer = optim.Adam(itertools.chain(nfq_net.layers_shared.parameters(), nfq_net.layers_last_shared.parameters()), lr=1e-1)\n",
    "else:\n",
    "    optimizer = optim.Adam(nfq_net.parameters(), lr=1e-1)\n",
    "\n",
    "nfq_agent = NFQAgent(nfq_net, optimizer)\n",
    "\n",
    "# NFQ Main loop\n",
    "bg_rollouts = []\n",
    "fg_rollouts = []\n",
    "total_cost = 0\n",
    "if init_experience > 0:\n",
    "    for _ in range(init_experience):\n",
    "        rollout_bg, episode_cost = train_env_bg.generate_rollout(\n",
    "            None, render=False, group=0\n",
    "        )\n",
    "        rollout_fg, episode_cost = train_env_fg.generate_rollout(\n",
    "            None, render=False, group=1\n",
    "        )\n",
    "        bg_rollouts.extend(rollout_bg)\n",
    "        fg_rollouts.extend(rollout_fg)\n",
    "        total_cost += episode_cost\n",
    "bg_rollouts.extend(fg_rollouts)\n",
    "all_rollouts = bg_rollouts.copy()\n",
    "\n",
    "bg_rollouts_test = []\n",
    "fg_rollouts_test = []\n",
    "if init_experience > 0:\n",
    "    for _ in range(init_experience):\n",
    "        rollout_bg, episode_cost = eval_env_bg.generate_rollout(\n",
    "            None, render=False, group=0\n",
    "        )\n",
    "        rollout_fg, episode_cost = eval_env_fg.generate_rollout(\n",
    "            None, render=False, group=1\n",
    "        )\n",
    "        bg_rollouts_test.extend(rollout_bg)\n",
    "        fg_rollouts_test.extend(rollout_fg)\n",
    "bg_rollouts_test.extend(fg_rollouts)\n",
    "all_rollouts_test = bg_rollouts_test.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1001/1001 [12:24<00:00,  1.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fg trained after 0 epochs\n",
      "BG stayed up for steps:  [False, False, False, False, False, False, False, False, False, False]\n",
      "FG stayed up for steps:  [False, False, False, False, False, False, False, False, False, False]\n"
     ]
    }
   ],
   "source": [
    "state_action_b, target_q_values, groups = nfq_agent.generate_pattern_set(all_rollouts_test)\n",
    "\n",
    "bg_success_queue = [0] * 3\n",
    "fg_success_queue = [0] * 3\n",
    "epochs_fg = 0\n",
    "eval_fg = 0\n",
    "for ep in enumerate(tqdm.tqdm(range(epoch + 1))):\n",
    "    state_action_b, target_q_values, groups = nfq_agent.generate_pattern_set(all_rollouts)\n",
    "    X = state_action_b\n",
    "    train_groups = groups\n",
    "\n",
    "    if not nfq_net.freeze_shared:\n",
    "        loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "\n",
    "    eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg = 0, 0, 0\n",
    "    if nfq_net.freeze_shared:\n",
    "        eval_fg += 1\n",
    "\n",
    "        if eval_fg > 50:\n",
    "            loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "\n",
    "    if is_contrastive:\n",
    "        if nfq_net.freeze_shared:\n",
    "            eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg = nfq_agent.evaluate(\n",
    "                eval_env_fg, render=False\n",
    "            )\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                assert param.requires_grad == True\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                assert param.requires_grad == True\n",
    "            for param in nfq_net.layers_shared.parameters():\n",
    "                assert param.requires_grad == False\n",
    "            for param in nfq_net.layers_last_shared.parameters():\n",
    "                assert param.requires_grad == False\n",
    "        else:\n",
    "\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                assert param.requires_grad == False\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                assert param.requires_grad == False\n",
    "            for param in nfq_net.layers_shared.parameters():\n",
    "                assert param.requires_grad == True\n",
    "            for param in nfq_net.layers_last_shared.parameters():\n",
    "                assert param.requires_grad == True\n",
    "            eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg = nfq_agent.evaluate(\n",
    "                eval_env_bg, render=False\n",
    "            )\n",
    "\n",
    "\n",
    "    else:\n",
    "        eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg = nfq_agent.evaluate(\n",
    "            eval_env_bg, render=False\n",
    "        )\n",
    "        eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg = nfq_agent.evaluate(\n",
    "            eval_env_fg, render=False\n",
    "        )\n",
    "\n",
    "    bg_success_queue = bg_success_queue[1:]\n",
    "    bg_success_queue.append(1 if eval_success_bg else 0)\n",
    "\n",
    "    fg_success_queue = fg_success_queue[1:]\n",
    "    fg_success_queue.append(1 if eval_success_fg else 0)\n",
    "\n",
    "    printed_bg = False\n",
    "    printed_fg = False\n",
    "\n",
    "    if sum(bg_success_queue) == 3 and not nfq_net.freeze_shared == True:\n",
    "        if epochs_fg == 0:\n",
    "            epochs_fg = epoch\n",
    "        printed_bg = True\n",
    "        nfq_net.freeze_shared = True\n",
    "        if is_contrastive:\n",
    "            for param in nfq_net.layers_shared.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_last_shared.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                param.requires_grad = True\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                param.requires_grad = True\n",
    "        else:\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "\n",
    "            optimizer = optim.Adam(itertools.chain(nfq_net.layers_fg.parameters(), nfq_net.layers_last_fg.parameters()), lr=1e-1)\n",
    "            nfq_agent._optimizer = optimizer\n",
    "\n",
    "    if sum(fg_success_queue) == 3:\n",
    "        printed_fg = True\n",
    "        break\n",
    "\n",
    "eval_env_bg.step_number = 0\n",
    "eval_env_fg.step_number = 0\n",
    "\n",
    "eval_env_bg.max_steps = 1000\n",
    "eval_env_fg.max_steps = 1000\n",
    "\n",
    "performance_fg = []\n",
    "performance_bg = []\n",
    "total = 0\n",
    "for it in range(evaluations):\n",
    "    eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg = nfq_agent.evaluate(eval_env_bg, False)\n",
    "    performance_bg.append(eval_success_bg)\n",
    "    total += 1\n",
    "    train_env_bg.close()\n",
    "    eval_env_bg.close()\n",
    "    eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg = nfq_agent.evaluate(eval_env_fg, False)\n",
    "    performance_fg.append(eval_success_fg)\n",
    "    total += 1\n",
    "    train_env_fg.close()\n",
    "    eval_env_fg.close()\n",
    "print(\"Fg trained after \" + str(epochs_fg) + \" epochs\")\n",
    "print(\"BG stayed up for steps: \", performance_bg)\n",
    "print(\"FG stayed up for steps: \", performance_fg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research [~/.conda/envs/research/]",
   "language": "python",
   "name": "conda_research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
