{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "sys.path.append('../simulated_fqi/')\n",
    "import seaborn as sns\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt \n",
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "import shap\n",
    "import configargparse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "\n",
    "from environments import CartPoleRegulatorEnv\n",
    "from environments import CartEnv\n",
    "from environments import AcrobotEnv\n",
    "from models.agents import NFQAgent\n",
    "from models.networks import NFQNetwork, ContrastiveNFQNetwork\n",
    "from util import get_logger, close_logger, load_models, make_reproducible, save_models\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "from train import fqi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runFQI(\n",
    "verbose=True, \n",
    "is_contrastive=False, \n",
    "epoch=1000, \n",
    "init_experience=200, \n",
    "evaluations=5, \n",
    "force_left=5, \n",
    "random_seed=None, \n",
    "reward_weights=np.asarray([0.1] * 5)\n",
    "):\n",
    "    # Setup environment\n",
    "    bg_cart_mass = 1.0\n",
    "    fg_cart_mass = 1.0\n",
    "    train_env_bg = CartPoleRegulatorEnv(\n",
    "        group=0,\n",
    "        masscart=bg_cart_mass,\n",
    "        mode=\"train\",\n",
    "        force_left=force_left,\n",
    "        is_contrastive=is_contrastive,\n",
    "    )\n",
    "    train_env_fg = CartPoleRegulatorEnv(\n",
    "        group=1,\n",
    "        masscart=fg_cart_mass,\n",
    "        mode=\"train\",\n",
    "        force_left=force_left,\n",
    "        is_contrastive=is_contrastive,\n",
    "    )\n",
    "    eval_env_bg = CartPoleRegulatorEnv(\n",
    "        group=0,\n",
    "        masscart=bg_cart_mass,\n",
    "        mode=\"eval\",\n",
    "        force_left=force_left,\n",
    "        is_contrastive=is_contrastive,\n",
    "    )\n",
    "    eval_env_fg = CartPoleRegulatorEnv(\n",
    "        group=1,\n",
    "        masscart=fg_cart_mass,\n",
    "        mode=\"eval\",\n",
    "        force_left=force_left,\n",
    "        is_contrastive=is_contrastive,\n",
    "    )\n",
    "\n",
    "    # Log to File, Console, TensorBoard, W&B\n",
    "    logger = get_logger()\n",
    "    \n",
    "    # NFQ Main loop\n",
    "    bg_rollouts = []\n",
    "    fg_rollouts = []\n",
    "    total_cost = 0\n",
    "    if init_experience > 0:\n",
    "        for _ in range(init_experience):\n",
    "            rollout_bg, episode_cost = train_env_bg.generate_rollout(\n",
    "                None, render=False, group=0\n",
    "            )\n",
    "            rollout_fg, episode_cost = train_env_fg.generate_rollout(\n",
    "                None, render=False, group=1\n",
    "            )\n",
    "            bg_rollouts.extend(rollout_bg)\n",
    "            fg_rollouts.extend(rollout_fg)\n",
    "            total_cost += episode_cost\n",
    "    bg_rollouts.extend(fg_rollouts)\n",
    "    all_rollouts = bg_rollouts.copy()\n",
    "\n",
    "    bg_rollouts_test = []\n",
    "    fg_rollouts_test = []\n",
    "    if init_experience > 0:\n",
    "        for _ in range(init_experience):\n",
    "            rollout_bg, episode_cost = eval_env_bg.generate_rollout(\n",
    "                None, render=False, group=0\n",
    "            )\n",
    "            rollout_fg, episode_cost = eval_env_fg.generate_rollout(\n",
    "                None, render=False, group=1\n",
    "            )\n",
    "            bg_rollouts_test.extend(rollout_bg)\n",
    "            fg_rollouts_test.extend(rollout_fg)\n",
    "    bg_rollouts_test.extend(fg_rollouts)\n",
    "    all_rollouts_test = bg_rollouts_test.copy()\n",
    "    # Setup agent\n",
    "    nfq_net = ContrastiveNFQNetwork(\n",
    "        state_dim=train_env_bg.state_dim, is_contrastive=is_contrastive\n",
    "    )\n",
    "\n",
    "    if is_contrastive:\n",
    "        optimizer = optim.Adam(\n",
    "            itertools.chain(\n",
    "                nfq_net.layers_shared.parameters(),\n",
    "                nfq_net.layers_last_shared.parameters(),\n",
    "            ),\n",
    "            lr=1e-1,\n",
    "        )\n",
    "    else:\n",
    "            optimizer = optim.Adam(nfq_net.parameters(), lr=1e-1)\n",
    "    nfq_agent = NFQAgent(nfq_net, optimizer)\n",
    "\n",
    "    bg_success_queue = [0] * 3\n",
    "    fg_success_queue = [0] * 3\n",
    "    epochs_fg = 0\n",
    "    eval_fg = 0\n",
    "    for epoch in range(epoch + 1):\n",
    "\n",
    "        state_action_b, target_q_values, groups = nfq_agent.generate_pattern_set(\n",
    "            all_rollouts, reward_weights=reward_weights\n",
    "        )\n",
    "        X = state_action_b\n",
    "        train_groups = groups\n",
    "\n",
    "        if not nfq_net.freeze_shared:\n",
    "            loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "\n",
    "        eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg = 0, 0, 0\n",
    "        if nfq_net.freeze_shared:\n",
    "            eval_fg += 1\n",
    "\n",
    "            if eval_fg > 50:\n",
    "                loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "        (\n",
    "            eval_episode_length_bg,\n",
    "            eval_success_bg,\n",
    "            eval_episode_cost_bg,\n",
    "        ) = nfq_agent.evaluate(eval_env_bg, render=False)\n",
    "        (\n",
    "            eval_episode_length_fg,\n",
    "            eval_success_fg,\n",
    "            eval_episode_cost_fg,\n",
    "        ) = nfq_agent.evaluate(eval_env_fg, render=False)\n",
    "\n",
    "        bg_success_queue = bg_success_queue[1:]\n",
    "        bg_success_queue.append(1 if eval_success_bg else 0)\n",
    "\n",
    "        fg_success_queue = fg_success_queue[1:]\n",
    "        fg_success_queue.append(1 if eval_success_fg else 0)\n",
    "\n",
    "        printed_bg = False\n",
    "        printed_fg = False\n",
    "\n",
    "        if sum(bg_success_queue) == 3 and not nfq_net.freeze_shared == True:\n",
    "            if epochs_fg == 0:\n",
    "                epochs_fg = epoch\n",
    "            printed_bg = True\n",
    "            nfq_net.freeze_shared = True\n",
    "            if verbose:\n",
    "                print(\"FREEZING SHARED\")\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "\n",
    "            optimizer = optim.Adam(\n",
    "                itertools.chain(\n",
    "                    nfq_net.layers_fg.parameters(),\n",
    "                    nfq_net.layers_last_fg.parameters(),\n",
    "                ),\n",
    "                lr=1e-1,\n",
    "            )\n",
    "            nfq_agent._optimizer = optimizer\n",
    "\n",
    "        # Print current status\n",
    "        if verbose:\n",
    "            logger.info(\n",
    "                \"Epoch {:4d} | Eval BG {:4d} / {:4f} | Eval FG {:4d} / {:4f} | Train Loss {:.4f}\".format(\n",
    "                    epoch,\n",
    "                    eval_episode_length_bg,\n",
    "                    eval_episode_cost_bg,\n",
    "                    eval_episode_length_fg,\n",
    "                    eval_episode_cost_fg,\n",
    "                    loss,\n",
    "                )\n",
    "            )\n",
    "        if sum(fg_success_queue) == 3:\n",
    "            printed_fg = True\n",
    "            break\n",
    "\n",
    "    eval_env_bg.step_number = 0\n",
    "    eval_env_fg.step_number = 0\n",
    "\n",
    "    eval_env_bg.max_steps = 1000\n",
    "    eval_env_fg.max_steps = 1000\n",
    "\n",
    "    performance_fg = []\n",
    "    performance_bg = []\n",
    "    num_steps_bg = []\n",
    "    num_steps_fg = []\n",
    "    total = 0\n",
    "    for it in range(evaluations):\n",
    "        (\n",
    "            eval_episode_length_bg,\n",
    "            eval_success_bg,\n",
    "            eval_episode_cost_bg,\n",
    "        ) = nfq_agent.evaluate(eval_env_bg, False)\n",
    "        if verbose:\n",
    "            print(eval_episode_length_bg, eval_success_bg)\n",
    "        num_steps_bg.append(eval_episode_length_bg)\n",
    "        performance_bg.append(eval_episode_length_bg)\n",
    "        total += 1\n",
    "        train_env_bg.close()\n",
    "        eval_env_bg.close()\n",
    "\n",
    "        (\n",
    "            eval_episode_length_fg,\n",
    "            eval_success_fg,\n",
    "            eval_episode_cost_fg,\n",
    "        ) = nfq_agent.evaluate(eval_env_fg, False)\n",
    "        if verbose:\n",
    "            print(eval_episode_length_fg, eval_success_fg)\n",
    "        num_steps_fg.append(eval_episode_length_fg)\n",
    "        performance_fg.append(eval_episode_length_fg)\n",
    "        total += 1\n",
    "        train_env_fg.close()\n",
    "        eval_env_fg.close()\n",
    "    print(\"Fg trained after \" + str(epochs_fg) + \" epochs\")\n",
    "    print(\"BG stayed up for steps: \", num_steps_bg)\n",
    "    print(\"FG stayed up for steps: \", num_steps_fg)"
   ]
  },
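  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick smoke test of `runFQI`. The argument values below are illustrative assumptions, not tuned settings: fewer epochs and rollouts than the defaults so the cell finishes quickly, with the contrastive path enabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative short run; values are assumptions, not tuned settings.\n",
    "# reward_weights must be 5-dimensional to match runFQI's default.\n",
    "runFQI(\n",
    "    verbose=True,\n",
    "    is_contrastive=True,\n",
    "    epochs=200,\n",
    "    init_experience=50,\n",
    "    evaluations=3,\n",
    "    force_left=5,\n",
    "    reward_weights=np.asarray([0.1] * 5),\n",
    ")"
   ]
  },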
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getRhos(W=False, PD=True, vset='test'):\n",
    "\n",
    "        rhos = np.ones([self.NV[vset], self.maxT])\n",
    "        epRhos = []\n",
    "\n",
    "        for i, v in enumerate(tqdm(self.visits[vset])):\n",
    "            \n",
    "            # load episode\n",
    "            states, actions, phis, rewards = self.getEpisode(v, vset)\n",
    "            epT = len(states)\n",
    "            T = min(epT, self.maxT)            \n",
    "            \n",
    "            # load action probabilities\n",
    "            prob_b = self.piB['cat'].predict_proba(states)[np.arange(epT), actions][:T]\n",
    "            if self.piE['cat'].predict_proba(states).shape[1] < 4:\n",
    "                #print('Warning: < 4 classes')\n",
    "                prob_e = np.zeros([T, 4])\n",
    "                cols = np.unique(self.piE['cat'].predict(states))\n",
    "                probs = self.piE['cat'].predict_proba(states)\n",
    "                for i, a in enumerate(cols):\n",
    "                    prob_e[:, a] = probs[:, i]\n",
    "                prob_e = prob_e[np.arange(epT), actions][:T]\n",
    "            else:\n",
    "                prob_e = self.piE['cat'].predict_proba(states)[np.arange(epT), actions][:T]\n",
    "            \n",
    "            # calculate importance weights\n",
    "            if PD:\n",
    "                # per-step cumulative weights\n",
    "                invprop = np.cumprod(prob_e/prob_b, axis=0)\n",
    "                # clip importance weights\n",
    "                invprop[invprop<1e-3] = 1e-3\n",
    "                invprop[invprop>1e3] = 1e3\n",
    "                \n",
    "                rhos[i, :len(invprop)] = list(invprop)\n",
    "                rhos[i, len(invprop):] = np.ones(self.maxT-len(invprop)) * rhos[i,len(invprop)-1]\n",
    "                epRhos.append({'s': states, 'a': actions, 'phi': phis,'r': rewards})  \n",
    "        norm = self.NV[vset]\n",
    "        if W: norm = np.sum(rhos, axis=0)    \n",
    "        for i in range(self.NV[vset]): \n",
    "            epRhos[i]['rho'] = rhos[i,:] / norm\n",
    "            \n",
    "        return epRhos"
   ]
  },
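  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the weighting in `getRhos` concrete: with `PD=True` the weight at step $t$ is the cumulative likelihood ratio $\\rho_t = \\prod_{k \\le t} \\pi_e(a_k \\mid s_k) / \\pi_b(a_k \\mid s_k)$, clipped to $[10^{-3}, 10^3]$. A minimal sketch on made-up action probabilities:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy per-decision importance weights; the probabilities are invented\n",
    "# purely for illustration.\n",
    "prob_e = np.array([0.50, 0.70, 0.20, 0.90])  # eval policy prob. of taken action\n",
    "prob_b = np.array([0.25, 0.50, 0.40, 0.30])  # behaviour policy prob. of same action\n",
    "\n",
    "rho = np.cumprod(prob_e / prob_b)  # cumulative ratio up to each step\n",
    "rho = np.clip(rho, 1e-3, 1e3)      # same clipping as getRhos\n",
    "print(rho)  # -> [2.0, 2.8, 1.4, 4.2]"
   ]
  },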
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_feature_expectations(rhos, behav=False, vset='train'):\n",
    "\n",
    "        gamma_vec = [self.gamma**(i+1) for i in range(self.maxT)]\n",
    "        T = [len(rhos[i]['phi']) for i in range(len(rhos))]  \n",
    "\n",
    "        if behav:\n",
    "            print('Simple averaging')\n",
    "            estimated_mu = np.mean(np.vstack([np.sum(rhos[i]['phi'] * np.array(gamma_vec[:T[i]])[:,np.newaxis], axis=0) \n",
    "                                      for i in range(self.NV[vset])]), axis=0)\n",
    "        else:\n",
    "            print('PDWIS estimate')\n",
    "            estimated_mu = np.mean(np.vstack([np.sum(rhos[i]['phi'] * (gamma_vec[:T[i]] * rhos[i]['rho'][:T[i]])[:,np.newaxis],\n",
    "                                             axis=0) for i in range(self.NV[vset])]), axis=0)\n",
    "\n",
    "        return estimated_mu"
   ]
  },
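  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up to the normalization applied in `getRhos` (by $N$, or by the per-step weight sums when `W=True`), the estimator above is the per-decision weighted importance-sampling feature expectation\n",
    "\n",
    "$$\\hat{\\mu} = \\frac{1}{N} \\sum_{i=1}^{N} \\sum_{t=0}^{T_i - 1} \\gamma^{t+1} \\, \\rho_{i,t} \\, \\phi(s_{i,t}, a_{i,t}),$$\n",
    "\n",
    "and with `behav=True` the weights $\\rho_{i,t}$ are dropped, giving a simple on-policy average of discounted features."
   ]
  },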
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs=10; learning_rate=0.5; init_w = [0.1]\n",
    "# Initialize reward weights:\n",
    "w_vecs = []\n",
    "if init_w is None:\n",
    "    w = np.ones(len(reward_weights))/float(len(reward_weights))\n",
    "else:\n",
    "    w = init_w\n",
    "\n",
    "muB = None\n",
    "\n",
    "# Find the difference between just training on the train set and \n",
    "\n",
    "for i in range(epochs):\n",
    "    print('Epoch', i, '- Train pi with current w=', w)\n",
    "    reward_weights = w\n",
    "    w_vecs.append(w)\n",
    "    try:\n",
    "        runFQI(reward_weights=reward_weights)\n",
    "\n",
    "        print('Evaluate feature expectations for pi')\n",
    "        # This gives us importance sampling. We should do it for all samples instead. \n",
    "        # epRhos = self.getRhos(vset='train')\n",
    "        mu = self.find_feature_expectations(epRhos, behav=False, vset='train')\n",
    "        print(mu)\n",
    "\n",
    "        print('Initialize behaviour mu:')\n",
    "        if muB is None:\n",
    "            muB = self.find_feature_expectations(epRhos, behav=True, vset='train')\n",
    "        print(muB)\n",
    "\n",
    "        print('Gradient update for new w')\n",
    "        grad = norm(muB) - norm(mu)\n",
    "    except:\n",
    "        print('Error - skip update')\n",
    "        grad = 0\n",
    "    w += learning_rate*(0.95**i) * grad\n",
    "    w = w/np.sum(np.abs(w))\n",
    "\n",
    "\n",
    "return w"
   ]
  },
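  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above keeps the full weight trajectory in `w_vecs` and renormalizes `w` to unit $L_1$ norm after every update, so only the relative trade-off between reward features matters, not their overall scale."
   ]
  }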
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research [~/.conda/envs/research/]",
   "language": "python",
   "name": "conda_research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
