{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import variable\n",
    "import time\n",
    "import os\n",
    "import numpy as np\n",
    "import gym\n",
    "import shutil\n",
    "\n",
    "from tqdm import tqdm\n",
    "from random import uniform, randint, sample, random, choices\n",
    "from collections import deque\n",
    "\n",
    "import io\n",
    "import base64\n",
    "from IPython.display import HTML\n",
    "\n",
    "from models.feature_q_model import feature_q_model\n",
    "from memory.memory import ReplayBuffer, ReplayBuffer_decom\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "from tensorboardX import SummaryWriter\n",
    "import math\n",
    "import matplotlib as mpl\n",
    "from datetime import datetime\n",
    "%matplotlib inline\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor\n",
    "Writer = SummaryWriter(log_dir=\"CartPole_summary\")\n",
    "mpl.style.use('bmh')\n",
    "plt.rcParams[\"font.family\"] = \"Arial\"#\"Helvetica\"\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "ENV_NAME = 'CartPole_ESP'\n",
    "env = gym.make('CartPole-v1')\n",
    "\n",
    "# Move left, Move right\n",
    "ACTION_DICT = {\n",
    "    \"LEFT\": 0,\n",
    "    \"RIGHT\": 1\n",
    "}\n",
    "action_name = {\n",
    "    0: 'push left',\n",
    "    1: 'push right',\n",
    "}\n",
    "FEATRUESNAME = [\"F1 Cart Right Boundary\", \"F2 Cart Left Boundary\", \"F3 Cart Right Velocity\", \"F4 Cart Left Velocity\", \n",
    "                \"F5 Pole Angle Right\", \"F6 Pole Angle Left\", \"F7 Pole Right Velocity\", \"F8 Pole Left Velocity\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_floder = ENV_NAME\n",
    "result_floder_exp = \"{}/{}\".format(result_floder, ENV_NAME + \"_exp\")\n",
    "result_file = ENV_NAME + \"/results.txt\"\n",
    "result_file_GVFs_loss = ENV_NAME + \"/results_GVFs_loss.txt\"\n",
    "if not os.path.isdir(result_floder):\n",
    "    os.mkdir(result_floder)\n",
    "if not os.path.isdir(result_floder_exp):\n",
    "    os.mkdir(result_floder_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MSX(vector):\n",
    "    vector = np.array(vector)\n",
    "    indeces = np.argsort(vector)[::-1]\n",
    "    negative_sum = sum(vector[vector < 0])\n",
    "    pos_sum = 0\n",
    "    MSX_idx = []\n",
    "    for idx in indeces:\n",
    "        pos_sum += vector[idx]\n",
    "        MSX_idx.append(idx)\n",
    "        if pos_sum > abs(negative_sum):\n",
    "            break\n",
    "    return MSX_idx, vector[MSX_idx]\n",
    "\n",
    "def print_Cartpole(state):\n",
    "    print(\"Cart Position: \", state[0], end = \"  \")\n",
    "    print(\"Cart Velocity: \", state[1])\n",
    "    print(\"Pole Angle: \", state[2], end = \"  \")\n",
    "    print(\"Pole Velocity At Tip: \", state[3])\n",
    "def plot_action_group(values, group, elements = [], title = 'decomposition values', y_label = \"\", q_values = None, IGX_action = None, exp_count = -1):\n",
    "    plt.clf()\n",
    "    x = np.arange(len(group))  # the label locations\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    # Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "    plt.ylabel(y_label)\n",
    "    plt.title(title)\n",
    "\n",
    "    # set width of bar\n",
    "    length = len(values[0])\n",
    "    barWidth = 1 / (len(values) + 1)\n",
    "    x_labels_feature = [\"F{}\".format(x + 1) for x in range(len(values))]\n",
    "    for i in range(len(values)):\n",
    "        r = [j + barWidth * i for j in range(length)]\n",
    "        if len(elements) > 0:\n",
    "            plt.bar(r, values[i] , width=barWidth,  label=elements[i])\n",
    "#             for j, rr in enumerate(r):\n",
    "#                 plt.text(rr - barWidth / 4, values[i][j] if values[i][j] > 0 else 0, x_labels_feature[i], fontsize=5)\n",
    "        else:\n",
    "            plt.bar(r, values[i], width=barWidth)\n",
    "    \n",
    "    print(barWidth)\n",
    "    y_lim = plt.gca().get_ylim()\n",
    "    gap = (y_lim[1] - 0) / 30\n",
    "    for i in range(len(values)):\n",
    "        r = [j + barWidth * i for j in range(length)]\n",
    "        for j, rr in enumerate(r):\n",
    "            if values[i][j] != 0:\n",
    "                plt.text(rr - barWidth / 7, -gap * 1.5 if values[i][j] > 0 else gap, x_labels_feature[i], fontsize=8)\n",
    "    \n",
    "    r = [j + barWidth * (i + 1) for j in range(length)]\n",
    "    for rr in r[:-1]:\n",
    "        plt.axvline(x = rr, alpha = 0.5, linestyle='--')\n",
    "\n",
    "    # Add xticks on the middle of the group bars\n",
    "    plt.xlabel('action', fontweight='bold')\n",
    "    center_pos = (1 - barWidth * 2) / 2\n",
    "    \n",
    "    if IGX_action is not None:\n",
    "        group = [\"\\\"{}\\\"\\n greater than\".format(IGX_action)] + group\n",
    "        x_lim_left = plt.gca().get_xlim()[0]\n",
    "        pos = [x_lim_left] + [r + center_pos for r in range(length)]\n",
    "        plt.xticks(pos, group, ha='center')\n",
    "    else:\n",
    "        plt.xticks([r + center_pos for r in range(length)], group, ha='center')\n",
    "    \n",
    "    plt.gca().xaxis.grid(False)\n",
    "    plt.gca().yaxis.grid(True)\n",
    "    plt.gca().set_axisbelow(True)\n",
    "    if np.min(values) >= -gap * 3:\n",
    "        plt.gca().set_ylim(bottom = -gap * 3)\n",
    "    plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))\n",
    "    \n",
    "    y_lim = plt.gca().get_ylim()\n",
    "    np_values = np.array(values).T\n",
    "    txt_pos = y_lim[1] - (y_lim[1] - y_lim[0]) / 25\n",
    "    max_values = np_values.max(axis = 1)\n",
    "    sum_values = np_values.sum(axis = 1)\n",
    "    \n",
    "    for i, v in enumerate(max_values):\n",
    "        if q_values is None:\n",
    "            plt.gca().text(i + center_pos, txt_pos, \"sum ≈ {}\".format(np.round(sum_values[i], 4)), color='black', fontweight='ultralight', ha='center')\n",
    "        else:\n",
    "            plt.gca().text(i + center_pos, txt_pos, \"Q_v ≈ {}\".format(np.round(q_values[i], 4)), color='black', fontweight='ultralight', ha='center')\n",
    "\n",
    "    plt.savefig(\"{}/{}-{}.png\".format(result_floder_exp, title, exp_count))\n",
    "    \n",
    "def display_frames_as_gif(frames, video_name):\n",
    "    \"\"\"\n",
    "    Displays a list of frames as a gif, with controls\n",
    "    \"\"\"\n",
    "    Writer = animation.writers['ffmpeg']\n",
    "    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)\n",
    "    #plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi = 72)\n",
    "    patch = plt.imshow(frames[0])\n",
    "    plt.axis('off')\n",
    "\n",
    "    def animate(i):\n",
    "        patch.set_data(frames[i])\n",
    "\n",
    "    anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)\n",
    "#     display(display_animation(anim, default_mode='loop'))\n",
    "    anim.save(result_floder + '/' + video_name, writer=writer)"
   ]
  },
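  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, illustrative sanity check of `MSX` on a toy contribution vector (the numbers below are made up, not produced by the trained model): positive contributions are taken in descending order until they outweigh the total negative contribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy example: |sum of negative contributions| = 0.3, so the single largest positive\n",
    "# contribution (0.5 at index 0) is already sufficient.\n",
    "toy_idx, toy_values = MSX([0.5, -0.2, 0.3, -0.1])\n",
    "print(toy_idx, toy_values)  # expected: [0] [0.5]"
   ]
  },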
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_IG_MSX(values, values_msx, group, elements = [], title = 'decomposition values', y_label = \"\", q_values = None, IGX_action = None, exp_count = -1):\n",
    "    plt.clf()\n",
    "    x = np.arange(len(group))  # the label locations\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    # Add some text for labels, title and custom x-axis tick labels, etc.\n",
    "    plt.ylabel(y_label)\n",
    "    plt.title(title)\n",
    "\n",
    "    # set width of bar\n",
    "    length = len(values[0])\n",
    "    barWidth = 1 / (len(values) + 1)\n",
    "    x_labels_feature = [\"F{}\".format(x + 1) for x in range(len(values))]\n",
    "    for i in range(len(values)):\n",
    "        r = [j + barWidth * i for j in range(length)]\n",
    "        \n",
    "        for j in range(len(values[i])):\n",
    "            if values_msx[i][j] != 0:\n",
    "                plt.bar(r[j], values[i][j], width=barWidth, hatch=\"////\",\n",
    "                        color = plt.rcParams['axes.prop_cycle'].by_key()['color'][i])\n",
    "            else:\n",
    "                plt.bar(r[j], values[i][j], width=barWidth,\n",
    "                        color = plt.rcParams['axes.prop_cycle'].by_key()['color'][i])\n",
    "    \n",
    "    y_lim = plt.gca().get_ylim()\n",
    "    gap = (y_lim[1] - 0) / 30\n",
    "    for i in range(len(values)):\n",
    "        r = [j + barWidth * i for j in range(length)]\n",
    "        for j, rr in enumerate(r):\n",
    "            if values[i][j] != 0:\n",
    "                plt.text(rr - barWidth / 4, -gap * 1.5 if values[i][j] > 0 else gap, x_labels_feature[i], fontsize=5)\n",
    "    \n",
    "    r = [j + barWidth * (i + 1) for j in range(length)]\n",
    "    for rr in r[:-1]:\n",
    "        plt.axvline(x = rr, alpha = 0.5, linestyle='--')\n",
    "\n",
    "    # Add xticks on the middle of the group bars\n",
    "    plt.xlabel('action', fontweight='bold')\n",
    "    center_pos = (1 - barWidth * 2) / 2\n",
    "    \n",
    "    plt.xticks([r + center_pos for r in range(length)], group, ha='center')\n",
    "    \n",
    "    plt.gca().xaxis.grid(False)\n",
    "    plt.gca().yaxis.grid(True)\n",
    "    plt.gca().set_axisbelow(True)\n",
    "    if np.min(values) >= -gap * 3:\n",
    "        plt.gca().set_ylim(bottom = -gap * 3)\n",
    "    plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))\n",
    "    \n",
    "    y_lim = plt.gca().get_ylim()\n",
    "    np_values = np.array(values).T\n",
    "    txt_pos = y_lim[1] - (y_lim[1] - y_lim[0]) / 25\n",
    "    max_values = np_values.max(axis = 1)\n",
    "    sum_values = np_values.sum(axis = 1)\n",
    "    \n",
    "    for i, v in enumerate(max_values):\n",
    "        if q_values is None:\n",
    "            plt.gca().text(i + center_pos, txt_pos, \"{} > {}\".format(IGX_action, group[i]), color='black', fontweight='ultralight', ha='center')\n",
    "        else:\n",
    "            plt.gca().text(i + center_pos, txt_pos, \"Q_v ≈ {}\".format(np.round(q_values[i], 4)), color='black', fontweight='ultralight', ha='center')\n",
    "\n",
    "    plt.savefig(\"{}/{}-{}.png\".format(result_floder_exp, title, exp_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams_CarPole = {\n",
    "    'epsilon_decay_steps' : 200000, \n",
    "    'final_epsilon' : 0.05,\n",
    "    'batch_size' : 128, \n",
    "    'update_steps' : 5, \n",
    "    'memory_size' : 200000, \n",
    "    'beta' : 0.99, \n",
    "    'model_replace_freq' : 1,\n",
    "    'learning_rate' : 0.00001,\n",
    "    'privous_state' : 1,\n",
    "    'decom_reward_len': 8,\n",
    "    'soft_tau': 5e-4\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ESP_agent(object):\n",
    "    def __init__(self, env, hyper_params, action_space = len(ACTION_DICT)):\n",
    "        \n",
    "        self.env = env\n",
    "        self.max_episode_steps = env._max_episode_steps\n",
    "        \n",
    "        \"\"\"\n",
    "            beta: The discounted factor of Q-value function\n",
    "            (epsilon): The explore or exploit policy epsilon. \n",
    "            initial_epsilon: When the 'steps' is 0, the epsilon is initial_epsilon, 1\n",
    "            final_epsilon: After the number of 'steps' reach 'epsilon_decay_steps', \n",
    "                The epsilon set to the 'final_epsilon' determinately.\n",
    "            epsilon_decay_steps: The epsilon will decrease linearly along with the steps from 0 to 'epsilon_decay_steps'.\n",
    "        \"\"\"\n",
    "        self.beta = hyper_params['beta']\n",
    "        self.initial_epsilon = 1\n",
    "        self.final_epsilon = hyper_params['final_epsilon']\n",
    "        self.epsilon_decay_steps = hyper_params['epsilon_decay_steps']\n",
    "        self.soft_tau = hyper_params['soft_tau']\n",
    "\n",
    "        \"\"\"\n",
    "            episode: Record training episode\n",
    "            steps: Add 1 when predicting an action\n",
    "            learning: The trigger of agent learning. It is on while training agent. It is off while testing agent.\n",
    "            action_space: The action space of the current environment, e.g 2.\n",
    "        \"\"\"\n",
    "        self.episode = 0\n",
    "        self.steps = 0\n",
    "        self.best_reward = -1000\n",
    "        self.learning = True\n",
    "        self.action_space = action_space\n",
    "        self.privous_state = hyper_params['privous_state']\n",
    "\n",
    "        \"\"\"\n",
    "            input_len The input length of the neural network. It equals to the length of the state vector.\n",
    "            output_len: The output length of the neural network. It is equal to the action space.\n",
    "            eval_model: The model for predicting action for the agent.\n",
    "            target_model: The model for calculating Q-value of next_state to update 'eval_model'.\n",
    "        \"\"\"\n",
    "        self.decom_reward_len = hyper_params['decom_reward_len']\n",
    "        state = env.reset()\n",
    "        self.state_len = len(state)\n",
    "        input_len = self.state_len * self.privous_state + action_space\n",
    "        output_len = 1\n",
    "        \n",
    "        self.action_vector = self.get_action_vector()\n",
    "        self.eval_model = feature_q_model(input_len, self.decom_reward_len, output_len, learning_rate = hyper_params['learning_rate'])\n",
    "        self.target_model = feature_q_model(input_len, self.decom_reward_len, output_len, learning_rate = hyper_params['learning_rate'])\n",
    "#         memory: Store and sample experience replay.\n",
    "        self.memory = ReplayBuffer_decom(hyper_params['memory_size'])\n",
    "        \n",
    "        \"\"\"\n",
    "            batch_size: Mini batch size for training model.\n",
    "            update_steps: The frequence of traning model\n",
    "            model_replace_freq: The frequence of replacing 'target_model' by 'eval_model'\n",
    "        \"\"\"\n",
    "        self.exp_count = 0\n",
    "        self.batch_size = hyper_params['batch_size']\n",
    "        self.update_steps = hyper_params['update_steps']\n",
    "        self.model_replace_freq = hyper_params['model_replace_freq']\n",
    "        self.q_value_gd = []\n",
    "        self.v_feature_input= []\n",
    "        self.loss_accumulate = deque(maxlen=1000)\n",
    "        \n",
    "        self.optimizer_com = self.eval_model.optimizer_com\n",
    "        self.loss_fn = self.eval_model.loss_fn\n",
    "        \n",
    "    # Linear decrease function for epsilon\n",
    "    def linear_decrease(self, initial_value, final_value, curr_steps, final_decay_steps):\n",
    "        decay_rate = curr_steps / final_decay_steps\n",
    "        if decay_rate > 1:\n",
    "            decay_rate = 1\n",
    "        return initial_value - (initial_value - final_value) * decay_rate\n",
    "    \n",
    "    def get_action_vector(self):\n",
    "        action_vector = np.zeros((self.action_space, self.action_space))\n",
    "        for i in range(len(action_vector)):\n",
    "            action_vector[i, i] = 1\n",
    "        \n",
    "        return FloatTensor(action_vector)\n",
    "    \n",
    "    def concat_state_action(self, states, actions = None, is_full_action = False):\n",
    "        if is_full_action:\n",
    "            com_state = FloatTensor(states).repeat((1, self.action_space)).view((-1, self.state_len))\n",
    "            actions = self.action_vector.repeat((len(states), 1))\n",
    "        else:\n",
    "            com_state = states.clone()\n",
    "            actions = actions.clone()\n",
    "        state_action = torch.cat((com_state, actions), 1)\n",
    "        return state_action\n",
    "        \n",
    "    def explore_or_exploit_policy(self, state):\n",
    "        p = uniform(0, 1)\n",
    "        # Get decreased epsilon\n",
    "        epsilon = self.linear_decrease(self.initial_epsilon, \n",
    "                               self.final_epsilon,\n",
    "                               self.steps,\n",
    "                               self.epsilon_decay_steps)\n",
    "        self.epsilon = epsilon\n",
    "        \n",
    "        if p < epsilon:\n",
    "            #return action, None\n",
    "            return randint(0, self.action_space - 1)\n",
    "        else:\n",
    "            #return action, Q-value\n",
    "            return self.greedy_policy(state)[0]\n",
    "        \n",
    "    def greedy_policy(self, state):\n",
    "        state_ft = FloatTensor(state).view(-1, self.state_len)\n",
    "        state_action = self.concat_state_action(state_ft, is_full_action = True)\n",
    "        feature_vectors, q_values = self.eval_model.predict_batch(state_action)\n",
    "        q_v, best_action = q_values.max(0)\n",
    "        \n",
    "        return best_action.item(), q_v, feature_vectors[best_action.item()]\n",
    "    \n",
    "    def update_batch(self):\n",
    "        if len(self.memory) < self.batch_size or self.steps % self.update_steps != 0:\n",
    "            return\n",
    "\n",
    "        batch = self.memory.sample(self.batch_size)\n",
    "\n",
    "        (states_actions, _, reward, next_states,\n",
    "         is_terminal, rewards_decom) = batch\n",
    "        \n",
    "        next_states = FloatTensor(next_states)\n",
    "        terminal = FloatTensor([1 if t else 0 for t in is_terminal])\n",
    "        reward = FloatTensor(reward)\n",
    "        rewards_decom = FloatTensor(rewards_decom)\n",
    "        batch_index = torch.arange(self.batch_size,\n",
    "                                   dtype=torch.long)\n",
    "        \n",
    "        # Current Q Values\n",
    "        feature_vector, q_values = self.eval_model.predict_batch(states_actions)\n",
    "        next_state_actions = self.concat_state_action(next_states, is_full_action = True)\n",
    "        feature_vector_next, q_next = self.target_model.predict_batch(next_state_actions)\n",
    "        q_next = q_next.view((-1, self.action_space))\n",
    "        feature_vector_next = feature_vector_next.view((-1, self.action_space, self.decom_reward_len))\n",
    "        q_max, idx = q_next.max(1)\n",
    "\n",
    "        q_max = (1 - terminal) * q_max\n",
    "        q_target = reward + self.beta * q_max\n",
    "        q_target = q_target.unsqueeze(1)\n",
    "        \n",
    "        feature_vector_max = feature_vector_next[batch_index, idx, :]\n",
    "        \n",
    "        feature_vector_max = (1 - terminal.view(-1, 1)) * feature_vector_max\n",
    "        feature_vector_target = rewards_decom + feature_vector_max * self.beta\n",
    "        \n",
    "        self.eval_model.fit(q_values, q_target, feature_vector, feature_vector_target)\n",
    "        \n",
    "    def learn_and_evaluate(self, training_episodes, test_interval):\n",
    "        test_number = training_episodes // test_interval\n",
    "        all_results = []\n",
    "        \n",
    "        for i in range(test_number):\n",
    "            # learn\n",
    "            self.learn(test_interval)\n",
    "            f = open(result_file, \"a+\")\n",
    "            f.write(str(\"\\n***{}\".format((i + 1) * test_interval) + \"\\n\"))\n",
    "            f.close()\n",
    "            f = open(result_file_GVFs_loss, \"a+\")\n",
    "            f.write(str(\"\\n***{}\".format((i + 1) * test_interval) + \"\\n\"))\n",
    "            f.close()\n",
    "            # evaluate\n",
    "            avg_reward = self.evaluate((i + 1) * test_interval)\n",
    "            all_results.append(avg_reward)\n",
    "            \n",
    "        return all_results\n",
    "    \n",
    "    def get_features_decom(self, state, next_state, done):\n",
    "        \n",
    "        threshold_x = 1\n",
    "        threshold_c_v = 1\n",
    "        threshold_angle = 0.07\n",
    "        threshold_p_v = 0.7\n",
    "        \n",
    "        features_decom = np.ones(self.decom_reward_len)\n",
    "#         cart_position, cart_velocity, pole_angle, pole_velocity = state\n",
    "        next_cart_position, next_cart_velocity, next_pole_angle, next_pole_velocity = next_state\n",
    "            \n",
    "        if threshold_x < next_cart_position:\n",
    "            features_decom[0] = -1\n",
    "        if -threshold_x > next_cart_position:\n",
    "            features_decom[1] = -1        \n",
    "    \n",
    "        if threshold_c_v < next_cart_velocity:\n",
    "            features_decom[2] = -1\n",
    "        if -threshold_c_v > next_cart_velocity:\n",
    "            features_decom[3] = -1   \n",
    "            \n",
    "        if threshold_angle < next_pole_angle:\n",
    "            features_decom[4] = -1\n",
    "        if -threshold_angle > next_pole_angle:\n",
    "            features_decom[5] = -1   \n",
    "        \n",
    "        if threshold_p_v < next_pole_velocity:\n",
    "            features_decom[6] = -1\n",
    "        if -threshold_p_v > next_pole_velocity:\n",
    "            features_decom[7] = -1   \n",
    "        return features_decom\n",
    "        \n",
    "    def learn(self, test_interval):\n",
    "        \n",
    "        for episode in tqdm(range(test_interval), desc=\"Training\"):\n",
    "            state = self.env.reset()\n",
    "            done = False\n",
    "            steps = 0\n",
    "            \n",
    "            while steps < self.max_episode_steps and not done:\n",
    "                steps += 1\n",
    "                self.steps += 1\n",
    "                \n",
    "                action = self.explore_or_exploit_policy(state)\n",
    "                next_state, reward, done, _ = self.env.step(action)\n",
    "                \n",
    "                features_decom = self.get_features_decom(state, next_state, steps < self.max_episode_steps and done)\n",
    "                action_vector = np.zeros(self.action_space)\n",
    "                action_vector[action] = 1\n",
    "                \n",
    "                self.memory.add(np.concatenate((state.copy(), action_vector.copy()), axis=0), -1, reward, next_state, steps < self.max_episode_steps and done, features_decom)\n",
    "                self.update_batch()\n",
    "                \n",
    "                if self.steps % self.model_replace_freq == 0:\n",
    "                    if self.model_replace_freq == 1:\n",
    "                        self.target_model.replace_soft(self.eval_model, tau = self.soft_tau)\n",
    "                    else:\n",
    "                        self.target_model.replace(self.eval_model)\n",
    "                state = next_state\n",
    "                \n",
    "    def evaluate(self, episode_num, trials = 100, loss_max_steps = 100):\n",
    "        total_reward = 0\n",
    "        total_steps = 0\n",
    "        all_GVFs_loss = 0\n",
    "        for _ in tqdm(range(trials), desc=\"Evaluating\"):\n",
    "            state = self.env.reset()\n",
    "            done = False\n",
    "            steps = 0\n",
    "            all_features_gt = FloatTensor(np.zeros((loss_max_steps, self.decom_reward_len)))\n",
    "            all_features_predict = FloatTensor(np.zeros((loss_max_steps, self.decom_reward_len)))\n",
    "            discounted_para = FloatTensor(np.zeros((loss_max_steps, 1)))\n",
    "            while steps < self.max_episode_steps and not done:\n",
    "                steps += 1\n",
    "                action, _, fv = self.greedy_policy(state)\n",
    "                next_state, reward, done, _ = self.env.step(action)\n",
    "                total_reward += reward\n",
    "                \n",
    "                if steps < loss_max_steps:\n",
    "                    measure_steps = steps\n",
    "                    all_features_predict[measure_steps - 1] = fv\n",
    "                    \n",
    "                features_decom = self.get_features_decom(state, next_state, steps < self.max_episode_steps and done)\n",
    "                \n",
    "                all_features_gt[:measure_steps] += (FloatTensor(features_decom) * (self.beta ** discounted_para))[:measure_steps]\n",
    "                    \n",
    "                discounted_para[:measure_steps] += 1\n",
    "                state = next_state\n",
    "            total_steps += steps\n",
    "            with torch.no_grad():\n",
    "                all_GVFs_loss += self.eval_model.loss_fn(all_features_gt[:measure_steps], all_features_predict[:measure_steps]).item()\n",
    "        avg_reward = total_reward / trials\n",
    "        avg_GVFs_loss = all_GVFs_loss / trials\n",
    "        print(\"avg score: {}\".format(avg_reward))\n",
    "        print(\"avg GVF loss: {}\".format(avg_GVFs_loss))\n",
    "        f = open(result_file, \"a+\")\n",
    "        f.write(str(avg_reward) + \"\\n\")\n",
    "        f.close()\n",
    "        f = open(result_file_GVFs_loss, \"a+\")\n",
    "        f.write(str(avg_GVFs_loss) + \"\\n\")\n",
    "        f.close()\n",
    "        if avg_reward >= self.best_reward:\n",
    "            self.best_reward = avg_reward\n",
    "            self.save_model()\n",
    "            print(\"save\")\n",
    "        Writer.add_scalars(main_tag='CartPole/GQF discrete',\n",
    "                                tag_scalar_dict = {'GQF discrete 1':avg_reward}, \n",
    "#                                 scalar_value=,\n",
    "                                global_step=episode_num)\n",
    "        return avg_reward\n",
    "#################################################################################################\n",
    "#################################################################################################  \n",
    "############################################### MSX #############################################  \n",
    "#################################################################################################  \n",
    "#################################################################################################  \n",
    "    def explore_or_exploit_policy_eval(self, state, epsilon = 0):\n",
    "        p = uniform(0, 1)\n",
    "        if p < epsilon:\n",
    "            #return action, None\n",
    "            return randint(0, self.action_space - 1)\n",
    "        else:\n",
    "            #return action, Q-value\n",
    "            return self.greedy_policy(state)[0]\n",
    "\n",
    "    def evaluate_combination_action_large_diff(self, example_num = 10, epsilon = 0, states = []):\n",
    "        self.eval_model.eval_mode()\n",
    "        loss_fn = nn.MSELoss()\n",
    "        \n",
    "        all_states = []\n",
    "        all_frames = []\n",
    "        state = self.env.reset()\n",
    "        done = False\n",
    "        steps = 0\n",
    "        \n",
    "        state_len = len(state)\n",
    "        \n",
    "#         q_vs = []\n",
    "        while steps < self.max_episode_steps and not done:\n",
    "            steps += 1\n",
    "            self.steps += 1\n",
    "\n",
    "            action = self.explore_or_exploit_policy_eval(state, epsilon = epsilon)\n",
    "#             q_vs.append(q_v.item())\n",
    "            state, reward, done, _ = self.env.step(action)\n",
    "            \n",
    "            all_frames.append(env.render(mode='rgb_array'))\n",
    "            \n",
    "            all_states.append(state.copy())\n",
    "        for s in states:\n",
    "            all_states.append(s)\n",
    "            all_frames.append(np.zeros_like(all_frames[0]))\n",
    "        examples = FloatTensor(all_states)\n",
    "        eval_frames = np.array(all_frames)\n",
    "\n",
    "        com_examples = self.concat_state_action(examples, is_full_action = True)\n",
    "        v_features, q_value = self.eval_model.predict_batch(com_examples)\n",
    "        q_value = q_value.view((-1, self.action_space))\n",
    "        v_features = v_features.view((-1, self.action_space, self.decom_reward_len))\n",
    "        \n",
    "        q_best_values, predict_best_action = q_value.max(1)\n",
    "        q_worst_values, predict_worst_action = q_value.min(1)\n",
    "        \n",
    "        q_diff = q_best_values - q_worst_values\n",
    "        \n",
    "        q_diff_idx = q_diff.argsort(descending = True)[:example_num]\n",
    "        if len(states) > 0:\n",
    "            q_diff_idx = torch.cat((q_diff_idx, LongTensor(list(range(len(all_frames) - len(states), len(all_frames))))))\n",
    "        q_diff_idx_np = np.array(q_diff_idx.tolist())\n",
    "        q_diff = q_diff[q_diff_idx]\n",
    "        print(\"diff example length: {}\".format(len(q_diff)))\n",
    "        q_value = q_value[q_diff_idx]\n",
    "        v_features = v_features[q_diff_idx]\n",
    "        q_best_values, predict_best_action = q_value.max(1)\n",
    "        q_worst_values, predict_worst_action = q_value.min(1)\n",
    "        \n",
    "        eval_frames = eval_frames[q_diff_idx_np]\n",
    "        examples = examples[q_diff_idx_np]\n",
    "        \n",
    "        total_acc_1 = 0\n",
    "        total_acc_2 = 0\n",
    "    \n",
    "        print(\"diff action value: {}\".format(q_diff))\n",
    "        count = 0\n",
    "        for i, vf in enumerate(v_features):\n",
    "            p_a = predict_best_action[i].item()\n",
    "            sub_actions = q_value[i].argsort()[:-1]\n",
    "    \n",
    "            print(\"=========================================================================\")\n",
    "            \n",
    "#             print(\"true target action: {}\\n true baseline action: {}\".format(action_name[t_a], action_name[true_sub_action]))\n",
    "            show_image = True\n",
    "            print_Cartpole(examples[i].tolist())\n",
    "        \n",
    "            IGs = []\n",
    "            ans = []\n",
    "            MSX_values = []\n",
    "            baseline_values = []\n",
    "            GVFs = []\n",
    "            q_print_values = []\n",
    "            GVFs.append(v_features[i][p_a].tolist())\n",
    "            ans.append(action_name[p_a])\n",
    "            q_print_values.append(q_value[i][p_a].item())\n",
    "            print(\"target action {}, value: {}\".format(action_name[p_a], q_best_values[i]))\n",
    "            for sb in sub_actions.tolist()[::-1]:\n",
    "#                 print(\"=========================================================================\")\n",
    "                sub_action = sb\n",
    "#                 print(\"target action: {}\\n baseline action: {}\".format(action_name[p_a], action_name[sub_action]))\n",
    "                ans.append(action_name[sub_action])\n",
    "                GVFs.append(v_features[i][sub_action].tolist())\n",
    "                msx_idx, msx_value, intergated_grad = self.differenc_vector_2(self.eval_model.q_model, \n",
    "                                      (v_features[i][p_a], eval_frames[i], q_value[i][p_a])\n",
    "                                      ,(v_features[i][sub_action], eval_frames[i], q_value[i][sub_action]), False, iteration = 30, show_image = show_image)\n",
    "                ig = np.array(intergated_grad[0].tolist())\n",
    "                IGs.append(ig)\n",
    "                baseline_values.append(q_value[i][sub_action].item())\n",
    "                q_print_values.append(q_value[i][sub_action].item())\n",
    "                orginal_pos_MSX = np.zeros(len(ig))\n",
    "                orginal_pos_MSX[msx_idx] = ig[msx_idx]\n",
    "                MSX_values.append(orginal_pos_MSX)\n",
    "                if show_image:\n",
    "                    show_image = False\n",
    "                    \n",
    "            for v, k in zip(baseline_values, ans):\n",
    "                print(\"baseline action {}, value: {}\".format(k, v))\n",
    "            \n",
    "            plot_action_group(np.array(GVFs).T, ans, FEATRUESNAME, title = \"GVFs\", y_label = \"GVFs Values\", q_values = q_print_values, exp_count = self.exp_count)\n",
    "            plot_IG_MSX(np.array(IGs).T, np.array(MSX_values).T, ans[1:], FEATRUESNAME, title = \"IG & MSX\", y_label = \"IGX\", IGX_action = ans[0], exp_count = self.exp_count)\n",
    "        \n",
    "    def differenc_vector_2(self, model, target, baseline, verbose = True, iteration = 100, show_image = True):\n",
    "        t_feature, t_frame, t_value = target\n",
    "        b_feature, b_frame, b_value = baseline\n",
    "        \n",
    "        if verbose or show_image:\n",
    "            plt.clf()\n",
    "            print(\"target value: {}\".format(t_value.item()))\n",
    "            print(\"baseline value: {}\".format(b_value.item()))\n",
    "            self.exp_count += 1\n",
    "            plt.imshow(t_frame)\n",
    "            plt.axis('off')\n",
    "            plt.savefig(\"{}/game_state-{}.png\".format(result_floder_exp, self.exp_count))\n",
    "        (msx_idx, msx_value), intergated_grad = self.intergated_gradients(model, t_feature, baseline = b_feature, verbose = verbose, iteration = iteration)\n",
    "        return msx_idx, msx_value, intergated_grad\n",
    "    def intergated_gradients(self, model, x, iteration = 100, baseline = None, verbose = True):\n",
    "        y_baseline = model(baseline).item()\n",
    "        x = x.view(1, -1)\n",
    "        x.size()[1]\n",
    "        if baseline is None:\n",
    "            baseline = torch.zeros_like(x)\n",
    "        elif verbose:\n",
    "            print(\"baseline: {}\".format(baseline))\n",
    "        \n",
    "        intergated_grad = torch.zeros_like(x)\n",
    "    \n",
    "        for i in range(iteration):\n",
    "            new_input = baseline + ((i + 1) / iteration * (x - baseline))\n",
    "            new_input = new_input.clone().detach().requires_grad_(True)\n",
    "    #         print(new_input)\n",
    "\n",
    "            y = model(new_input)\n",
    "            loss = abs(y_baseline - y)\n",
    "            self.optimizer_com.zero_grad()\n",
    "            loss.backward()\n",
    "            intergated_grad += (new_input.grad) / iteration\n",
    "            \n",
    "        intergated_grad *= x - baseline\n",
    "        \n",
    "        print(\"target GVFs: {}\".format(x.tolist()))\n",
    "        print(\"baseline GVFs: {}\".format(baseline.tolist()))\n",
    "        MSX_idx, MSX_values = MSX(intergated_grad.tolist()[0])\n",
    "        if verbose:\n",
    "            msx_vector = np.zeros(self.decom_reward_len)\n",
    "            msx_vector[MSX_idx] = MSX_values\n",
    "            plot_msx_plus(msx_vector, tittle = 'MSX+')\n",
    "    \n",
    "        return (MSX_idx, MSX_values), intergated_grad\n",
    "    \n",
    "    def generate_video(self, video_name):\n",
    "        self.eval_model.eval_mode()\n",
    "        loss_fn = nn.MSELoss()\n",
    "        total_reward = 0\n",
    "        all_frames = []\n",
    "        state = self.env.reset()\n",
    "        done = False\n",
    "        steps = 0\n",
    "        \n",
    "        while steps < self.max_episode_steps and not done:\n",
    "            steps += 1\n",
    "\n",
    "            action, q_v, feature_vector = self.greedy_policy(state)\n",
    "            state, reward, done, _ = self.env.step(action)\n",
    "            total_reward += reward\n",
    "            all_frames.append(env.render(mode='rgb_array'))\n",
    "        print(total_reward)\n",
    "        display_frames_as_gif(all_frames, video_name)\n",
    "        print(\"video generted\")\n",
    "\n",
    "    def save_model(self, path = None, name = \"best_model.pt\"):\n",
    "        if path is None:\n",
    "            path = result_floder\n",
    "        self.eval_model.save(\"{}/{}\".format(path, name))\n",
    "        \n",
    "    # load model\n",
    "    def load_model(self, path = \"\", name = \"best_model.pt\"):\n",
    "        if path is None:\n",
    "            path = result_floder\n",
    "        self.eval_model.load(\"{}/{}\".format(path, name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_episodes, test_interval = 10000, 100\n",
    "for i in range(10):\n",
    "    print(\"*************{}*************\".format(i))\n",
    "    agent = ESP_agent(env, hyperparams_CarPole)\n",
    "    result = agent.learn_and_evaluate(training_episodes, test_interval)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_ckpt_path = \"{}/best_check_point\".format(result_floder)\n",
    "agent = ESP_agent(env, hyperparams_CarPole)\n",
    "agent.load_model(path = best_ckpt_path)\n",
    "agent.evaluate_combination_action_large_diff(example_num = 5, epsilon = 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_ckpt_path = \"{}/best_check_point\".format(result_floder)\n",
    "agent = ESP_agent(env, hyperparams_CarPole)\n",
    "agent.load_model(path = best_ckpt_path)\n",
    "agent.evaluate(0)"
   ]
  },
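  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the loaded agent can be rolled out and saved as a video with `generate_video` (a usage sketch; the file name below is arbitrary and ffmpeg must be available to matplotlib's animation writers)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out one greedy episode and save it under the result folder (requires ffmpeg).\n",
    "agent.generate_video(\"best_agent.mp4\")"
   ]
  },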
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!jupyter nbconvert --to script CP_ESP.ipynb"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
