{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TD3-FORK-BipedalWalkerHardcore",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "J-123tqbrE6u",
        "outputId": "0323380b-e53c-4945-95ea-786e5a546077",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import gym\n",
        "from collections import namedtuple, deque\n",
        "import torch.optim as optim\n",
        "import random\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline \n",
        "\"\"\"Uncomment these two lines to get access to Google Drive\"\"\"\n",
        "#from google.colab import drive\n",
        "#drive.mount('/content/drive')\n",
        "import os\n",
        "import time"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ],
          "name": "stdout"
        }
      ]
    },
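    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "TD3-FORK combines TD3 (twin critics, target policy smoothing, delayed policy updates) with a learned system model (\"FORK\") that predicts the next state from a (state, action) pair and gives the actor a forward-looking loss. In the update code below, the critic is trained toward the clipped double-Q target\n",
        "\n",
        "$$y = r + \\gamma\\,(1 - d)\\,\\min\\big(Q'_1(s', \\tilde{a}),\\; Q'_2(s', \\tilde{a})\\big),$$\n",
        "\n",
        "where $\\tilde{a}$ is the target actor's action plus clipped Gaussian smoothing noise."
      ]
    },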
    {
      "cell_type": "code",
      "metadata": {
        "id": "qBxyUyM5rL0r"
      },
      "source": [
        "# Actor Neural Network\n",
        "class Actor(nn.Module):\n",
        "    def __init__(self, state_size, action_size, seed, fc_units=256, fc1_units=256):\n",
        "        super(Actor, self).__init__()\n",
        "        self.seed = torch.manual_seed(seed)\n",
        "        self.fc1 = nn.Linear(state_size, fc_units)\n",
        "        self.fc2 = nn.Linear(fc_units, fc1_units)\n",
        "        self.fc3 = nn.Linear(fc1_units, action_size)\n",
        "\n",
        "    def forward(self, state):\n",
        "        \"\"\"Build an actor (policy) network that maps states -> actions.\"\"\"\n",
        "        x = F.relu(self.fc1(state))\n",
        "        x = F.relu(self.fc2(x))\n",
        "        return F.torch.tanh(self.fc3(x))\n",
        "\n",
        "# Q1-Q2-Critic Neural Network  \n",
        "  \n",
        "class Critic(nn.Module):\n",
        "    def __init__(self, state_size, action_size, seed, fc1_units=256, fc2_units=256):\n",
        "        super(Critic, self).__init__()\n",
        "        self.seed = torch.manual_seed(seed)\n",
        "        # Q1 architecture\n",
        "        self.l1 = nn.Linear(state_size + action_size, fc1_units)\n",
        "        self.l2 = nn.Linear(fc1_units, fc2_units)\n",
        "        self.l3 = nn.Linear(fc2_units, 1)\n",
        "\n",
        "        # Q2 architecture\n",
        "        self.l4 = nn.Linear(state_size + action_size, fc1_units)\n",
        "        self.l5 = nn.Linear(fc1_units, fc2_units)\n",
        "        self.l6 = nn.Linear(fc2_units, 1)\n",
        "\n",
        "    def forward(self, state, action):\n",
        "        \"\"\"Build a critic (value) network that maps (state, action) pairs -> Q-values.\"\"\"\n",
        "        xa = torch.cat([state, action], 1)\n",
        "\n",
        "        x1 = F.relu(self.l1(xa))\n",
        "        x1 = F.relu(self.l2(x1))\n",
        "        x1 = self.l3(x1)\n",
        "\n",
        "        x2 = F.relu(self.l4(xa))\n",
        "        x2 = F.relu(self.l5(x2))\n",
        "        x2 = self.l6(x2)\n",
        "\n",
        "        return x1, x2\n",
        "\n",
        "\n",
        "class SysModel(nn.Module):\n",
        "    def __init__(self, state_size, action_size, seed, fc1_units=400, fc2_units=300):\n",
        "        super(SysModel, self).__init__()\n",
        "        self.seed = torch.manual_seed(seed)\n",
        "        self.l1 = nn.Linear(state_size + action_size, fc1_units)\n",
        "        self.l2 = nn.Linear(fc1_units, fc2_units)\n",
        "        self.l3 = nn.Linear(fc2_units, state_size)\n",
        "\n",
        "\n",
        "    def forward(self, state, action):\n",
        "        \"\"\"Build a system model to predict the next state at a given state.\"\"\"\n",
        "        xa = torch.cat([state, action], 1)\n",
        "\n",
        "        x1 = F.relu(self.l1(xa))\n",
        "        x1 = F.relu(self.l2(x1))\n",
        "        x1 = self.l3(x1)\n",
        "\n",
        "        return x1"
      ],
      "execution_count": null,
      "outputs": []
    },
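    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick shape sanity check (a minimal sketch, not part of training): instantiate the three networks with BipedalWalker's 24-dimensional state and 4-dimensional action and push a dummy batch through them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional sanity check (assumes BipedalWalker's 24-dim state and 4-dim action)\n",
        "_actor = Actor(state_size=24, action_size=4, seed=0)\n",
        "_critic = Critic(state_size=24, action_size=4, seed=0)\n",
        "_sys = SysModel(state_size=24, action_size=4, seed=0)\n",
        "_s = torch.zeros(5, 24)        # dummy batch of 5 states\n",
        "_a = _actor(_s)                # -> (5, 4), bounded to [-1, 1] by tanh\n",
        "_q1, _q2 = _critic(_s, _a)     # -> (5, 1) each\n",
        "_ns = _sys(_s, _a)             # -> (5, 24) predicted next states\n",
        "print(_a.shape, _q1.shape, _q2.shape, _ns.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },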
    {
      "cell_type": "code",
      "metadata": {
        "id": "85Nxvts3rR2w"
      },
      "source": [
        "class TD3_FORK:\n",
        "    def __init__(\n",
        "        self,name,env,\n",
        "        load = False,\n",
        "        gamma = 0.99, #discount factor\n",
        "        lr_actor = 3e-4,\n",
        "        lr_critic = 3e-4,\n",
        "        lr_sysmodel = 3e-4,\n",
        "        batch_size = 100,\n",
        "        buffer_capacity = 1000000,\n",
        "        tau = 0.02,  #target network update factor\n",
        "        cuda = True,\n",
        "        policy_noise=0.2, \n",
        "        std_noise = 0.1,\n",
        "        noise_clip=0.5,\n",
        "        policy_freq=2, #target network update period\n",
        "    ):\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        self.env = env\n",
        "        self.create_actor()\n",
        "        self.create_critic()\n",
        "        self.create_sysmodel()\n",
        "        self.act_opt = optim.Adam(self.actor.parameters(), lr=lr_actor)\n",
        "        self.crt_opt = optim.Adam(self.critic.parameters(), lr=lr_critic)\n",
        "        self.sys_opt = optim.Adam(self.sysmodel.parameters(), lr=lr_sysmodel)\n",
        "        self.set_weights()\n",
        "        self.replay_memory_buffer = deque(maxlen = buffer_capacity)\n",
        "        self.batch_size = batch_size\n",
        "        self.tau = tau\n",
        "        self.policy_freq = policy_freq\n",
        "        self.gamma = gamma\n",
        "        self.name = name\n",
        "        self.upper_bound = self.env.action_space.high[0] #action space upper bound\n",
        "        self.lower_bound = self.env.action_space.low[0]  #action space lower bound\n",
        "        self.obs_upper_bound = self.env.observation_space.high[0] #state space upper bound\n",
        "        self.obs_lower_bound = self.env.observation_space.low[0]  #state space lower bound\n",
        "        self.policy_noise = policy_noise\n",
        "        self.noise_clip = noise_clip\n",
        "        self.std_noise = std_noise   \n",
        "        self.sys_updates = 0\n",
        "\n",
        "    \n",
        "    def create_actor(self):\n",
        "        params = {\n",
        "            'state_size':      self.env.observation_space.shape[0],\n",
        "            'action_size':     self.env.action_space.shape[0],\n",
        "            'seed':            88\n",
        "        }\n",
        "        self.actor = Actor(**params).to(self.device)\n",
        "        self.actor_target = Actor(**params).to(self.device)\n",
        "\n",
        "    def create_critic(self):\n",
        "        params = {\n",
        "            'state_size':      self.env.observation_space.shape[0],\n",
        "            'action_size':     self.env.action_space.shape[0],\n",
        "            'seed':            88\n",
        "        }\n",
        "        self.critic = Critic(**params).to(self.device)\n",
        "        self.critic_target = Critic(**params).to(self.device)\n",
        "\n",
        "    def create_sysmodel(self):\n",
        "        params = {\n",
        "            'state_size':      self.env.observation_space.shape[0],\n",
        "            'action_size':     self.env.action_space.shape[0],\n",
        "            'seed':            88\n",
        "        }\n",
        "        self.sysmodel = SysModel(**params).to(self.device)\n",
        "\n",
        "    def set_weights(self):\n",
        "        self.actor_target.load_state_dict(self.actor.state_dict())\n",
        "        self.critic_target.load_state_dict(self.critic.state_dict())\n",
        "      \n",
        "    def load_weight(self):\n",
        "        self.actor.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/actor.pth', map_location=self.device))\n",
        "        self.critic.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/critic.pth', map_location=self.device))\n",
        "        self.actor_target.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/actor_t.pth', map_location=self.device))\n",
        "        self.critic_target.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/critic_t.pth', map_location=self.device))\n",
        "        self.sysmodel.load_state_dict(torch.load('/content/drive/My Drive/bipedal/weights/hardcore/sysmodel.pth', map_location=self.device))\n",
        "        \n",
        "\n",
        "    def add_to_replay_memory(self, transition, buffername):\n",
        "        #add samples to replay memory\n",
        "        buffername.append(transition)\n",
        "\n",
        "    def get_random_sample_from_replay_mem(self, buffername):\n",
        "        #random samples from replay memory\n",
        "        random_sample = random.sample(buffername, self.batch_size)\n",
        "        return random_sample\n",
        "\n",
        "\n",
        "    def learn_and_update_weights_by_replay(self,training_iterations, weight, totrain):\n",
        "        \"\"\"Update policy and value parameters using given batch of experience tuples.\n",
        "        where:\n",
        "            actor_target(state) -> action\n",
        "            critic_target(state, action) -> Q-value\n",
        "        \"\"\"\n",
        "        if len(self.replay_memory_buffer) < 1e4:\n",
        "            return \n",
        "        for it in range(training_iterations):\n",
        "            mini_batch = self.get_random_sample_from_replay_mem(self.replay_memory_buffer)\n",
        "            state_batch = torch.from_numpy(np.vstack([i[0] for i in mini_batch])).float().to(self.device)\n",
        "            action_batch = torch.from_numpy(np.vstack([i[1] for i in mini_batch])).float().to(self.device)\n",
        "            reward_batch = torch.from_numpy(np.vstack([i[2] for i in mini_batch])).float().to(self.device)\n",
        "            next_state_batch = torch.from_numpy(np.vstack([i[3] for i in mini_batch])).float().to(self.device)\n",
        "            done_list = torch.from_numpy(np.vstack([i[4] for i in mini_batch]).astype(np.uint8)).float().to(self.device)\n",
        "\n",
        "            # Training and updating Actor & Critic networks.\n",
        "            \n",
        "            #Train Critic\n",
        "            target_actions = self.actor_target(next_state_batch)\n",
        "            offset_noises = torch.FloatTensor(action_batch.shape).data.normal_(0, self.policy_noise).to(self.device)\n",
        "\n",
        "            #clip noise\n",
        "            offset_noises = offset_noises.clamp(-self.noise_clip, self.noise_clip)\n",
        "            target_actions = (target_actions + offset_noises).clamp(self.lower_bound, self.upper_bound)\n",
        "\n",
        "            #Compute the target Q value\n",
        "            Q_targets1, Q_targets2 = self.critic_target(next_state_batch, target_actions)\n",
        "            Q_targets = torch.min(Q_targets1, Q_targets2)\n",
        "            Q_targets = reward_batch + self.gamma * Q_targets * (1 - done_list)\n",
        "\n",
        "            #Compute current Q estimates\n",
        "            current_Q1, current_Q2 = self.critic(state_batch, action_batch)\n",
        "            # Compute critic loss\n",
        "            critic_loss = F.mse_loss(current_Q1, Q_targets.detach()) + F.mse_loss(current_Q2, Q_targets.detach())\n",
        "            # Optimize the critic\n",
        "            self.crt_opt.zero_grad()\n",
        "            critic_loss.backward()\n",
        "            self.crt_opt.step()\n",
        "\n",
        "            self.soft_update_target(self.critic, self.critic_target)\n",
        "\n",
        "            #Train sysmodel and reward model\n",
        "            predict_next_state = self.sysmodel(state_batch, action_batch) * (1-done_list)\n",
        "            next_state_batch = next_state_batch * (1 - done_list)\n",
        "            sysmodel_loss = F.mse_loss(predict_next_state, next_state_batch.detach())\n",
        "            self.sys_opt.zero_grad()\n",
        "            sysmodel_loss.backward()\n",
        "            self.sys_opt.step()\n",
        "            self.sysmodel_loss = sysmodel_loss.item()\n",
        "\n",
        "            s_flag = 1 if sysmodel_loss.item() < 0.020  else 0\n",
        "\n",
        "            #Train Actor\n",
        "            # Delayed policy updates\n",
        "            if it % self.policy_freq == 0 and totrain == 1:\n",
        "                actions = self.actor(state_batch) *self.upper_bound\n",
        "                actor_loss1,_ = self.critic_target(state_batch, actions)\n",
        "                actor_loss1 =  - actor_loss1.mean()\n",
        "                \n",
        "                if s_flag == 1:\n",
        "                    p_actions = self.actor(state_batch)\n",
        "                    p_next_state = self.sysmodel(state_batch, p_actions).clamp(self.obs_lower_bound,self.obs_upper_bound)\n",
        "                    \n",
        "                    p_actions2 = self.actor(p_next_state.detach()) * self.upper_bound\n",
        "                    actor_loss2,_ = self.critic_target(p_next_state.detach(), p_actions2)\n",
        "                    actor_loss2 = actor_loss2.mean()\n",
        "\n",
        "                    p_next_state2 = self.sysmodel(p_next_state.detach(), p_actions2).clamp(self.obs_lower_bound,self.obs_upper_bound)\n",
        "                    p_actions3 = self.actor(p_next_state2.detach()) * self.upper_bound\n",
        "                    actor_loss3,_ = self.critic_target(p_next_state2.detach(), p_actions3)\n",
        "                    actor_loss3 = actor_loss3.mean() \n",
        "                    actor_loss =   (actor_loss1  - weight * actor_loss2 - 0.5 * weight * actor_loss3)                  \n",
        "                    self.sys_updates += 1\n",
        "                else:\n",
        "                    actor_loss =  actor_loss1 \n",
        "\n",
        "\n",
        "                self.act_opt.zero_grad()\n",
        "                actor_loss.backward()\n",
        "                self.act_opt.step()\n",
        "\n",
        "                #Soft update target models\n",
        "               \n",
        "                self.soft_update_target(self.actor, self.actor_target)\n",
        "                \n",
        "                \n",
        "\n",
        "    def soft_update_target(self,local_model,target_model):\n",
        "        \"\"\"Soft update model parameters.\n",
        "        θ_target = τ*θ_local + (1 - τ)*θ_target\n",
        "        Params\n",
        "        ======\n",
        "            local_model: PyTorch model (weights will be copied from)\n",
        "            target_model: PyTorch model (weights will be copied to)\n",
        "            tau (float): interpolation parameter\n",
        "        \"\"\"\n",
        "        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):\n",
        "            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)\n",
        "\n",
        "    def policy(self,state):\n",
        "        \"\"\"select action with noise based on ACTOR\"\"\"\n",
        "        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)\n",
        "        self.actor.eval()\n",
        "        with torch.no_grad():\n",
        "            actions = self.actor(state).cpu().data.numpy()\n",
        "        self.actor.train()\n",
        "        # Adding noise to action\n",
        "        shift_action = np.random.normal(0, self.std_noise, size=self.env.action_space.shape[0])\n",
        "        sampled_actions = (actions + shift_action)\n",
        "        # We make sure action is within bounds\n",
        "        legal_action = np.clip(sampled_actions,self.lower_bound,self.upper_bound)\n",
        "        return np.squeeze(legal_action)\n",
        "\n",
        "\n",
        "    def select_action(self,state):\n",
        "        \"\"\"select action based on ACTOR\"\"\"\n",
        "        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)\n",
        "        self.actor.eval()\n",
        "        with torch.no_grad():\n",
        "            actions = self.actor(state).cpu().data.numpy()\n",
        "        self.actor.train()\n",
        "        return np.squeeze(actions)\n",
        "\n",
        "\n",
        "    def eval_policy(self, env_name, seed, eval_episodes):\n",
        "        eval_env = env_name\n",
        "        eval_env.seed(seed + 100)\n",
        "        \n",
        "        avg_reward = 0.\n",
        "        for _ in range(eval_episodes):\n",
        "            state, done = eval_env.reset(), False\n",
        "            while not done:\n",
        "                action = self.select_action(np.array(state))\n",
        "                state, reward, done, _ = eval_env.step(action)\n",
        "                avg_reward += reward\n",
        "        avg_reward /= eval_episodes\n",
        "\n",
        "        print(\"---------------------------------------\")\n",
        "        print(f\"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}\")\n",
        "        print(\"---------------------------------------\")\n",
        "        return avg_reward"
      ],
      "execution_count": null,
      "outputs": []
    },
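    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When the system model is accurate enough (`sysmodel_loss < 0.020`), the delayed actor update minimizes the forward-looking loss\n",
        "\n",
        "$$L_{\\pi} = -Q'(s, \\pi(s)) - w\\,Q'(\\hat{s}_1, \\pi(\\hat{s}_1)) - 0.5\\,w\\,Q'(\\hat{s}_2, \\pi(\\hat{s}_2)),$$\n",
        "\n",
        "where $Q'$ is the target critic's first head, $\\hat{s}_1$ and $\\hat{s}_2$ are the one- and two-step states predicted by the system model, and $w$ is the adaptive weight supplied by the training loop; otherwise the actor falls back to the plain TD3 loss $-Q'(s, \\pi(s))$. Target networks track the online networks by Polyak averaging, $\\theta' \\leftarrow \\tau\\theta + (1-\\tau)\\theta'$ with $\\tau = 0.02$."
      ]
    },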
    {
      "cell_type": "code",
      "metadata": {
        "id": "_pwToClSrd0P",
        "outputId": "e56c42c0-9b3f-4f3f-979a-7e9bb13b6c2c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        }
      },
      "source": [
        "\"\"\"Install Environment\"\"\"\n",
        "\"\"\"You may need to restart the kernel to use BipedalWalker.\"\"\"\n",
        "%pip install Box2D \n"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting Box2D\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/a9/0b/d48d42dd9e19ce83a3fb4eee074e785b6c6ea612a2244dc2ef69427d338b/Box2D-2.3.10-cp36-cp36m-manylinux1_x86_64.whl (1.3MB)\n",
            "\r\u001b[K     |▎                               | 10kB 19.2MB/s eta 0:00:01\r\u001b[K     |▌                               | 20kB 6.7MB/s eta 0:00:01\r\u001b[K     |▊                               | 30kB 7.5MB/s eta 0:00:01\r\u001b[K     |█                               | 40kB 7.2MB/s eta 0:00:01\r\u001b[K     |█▎                              | 51kB 7.5MB/s eta 0:00:01\r\u001b[K     |█▌                              | 61kB 8.4MB/s eta 0:00:01\r\u001b[K     |█▊                              | 71kB 8.2MB/s eta 0:00:01\r\u001b[K     |██                              | 81kB 9.1MB/s eta 0:00:01\r\u001b[K     |██▎                             | 92kB 8.7MB/s eta 0:00:01\r\u001b[K     |██▌                             | 102kB 8.8MB/s eta 0:00:01\r\u001b[K     |██▊                             | 112kB 8.8MB/s eta 0:00:01\r\u001b[K     |███                             | 122kB 8.8MB/s eta 0:00:01\r\u001b[K     |███▏                            | 133kB 8.8MB/s eta 0:00:01\r\u001b[K     |███▌                            | 143kB 8.8MB/s eta 0:00:01\r\u001b[K     |███▊                            | 153kB 8.8MB/s eta 0:00:01\r\u001b[K     |████                            | 163kB 8.8MB/s eta 0:00:01\r\u001b[K     |████▏                           | 174kB 8.8MB/s eta 0:00:01\r\u001b[K     |████▌                           | 184kB 8.8MB/s eta 0:00:01\r\u001b[K     |████▊                           | 194kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████                           | 204kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████▏                          | 215kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████▌                          | 225kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████▊                          | 235kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████                          | 245kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████▏                         | 256kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████▍                         | 266kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████▊                         | 276kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████                         | 286kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████▏                        | 296kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████▍                        | 307kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████▊                        | 317kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████                        | 327kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████▏                       | 337kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████▍                       | 348kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████▊                       | 358kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████                       | 368kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████▏                      | 378kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████▍                      | 389kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████▋                      | 399kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████                      | 409kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████▏                     | 419kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████▍                     | 430kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████▋                     | 440kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████                     | 450kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████▏                    | 460kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████▍                    | 471kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████▋                    | 481kB 8.8MB/s eta 0:00:01\r\u001b[K     
|████████████                    | 491kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████▏                   | 501kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████▍                   | 512kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 522kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████▉                   | 532kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████▏                  | 542kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████▍                  | 552kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████▋                  | 563kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████▉                  | 573kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████▏                 | 583kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████▍                 | 593kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████▋                 | 604kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████▉                 | 614kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████▏                | 624kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████▍                | 634kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████▋                | 645kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████▉                | 655kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████                | 665kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████▍               | 675kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████▋               | 686kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████▉               | 696kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████               | 706kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████▍              | 716kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████▋              | 727kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████▉              | 737kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 747kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████▎             | 757kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████▋             | 768kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████▉             | 778kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 788kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████▎            | 798kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████▋            | 808kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████▉            | 819kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████            | 829kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████▎           | 839kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████▋           | 849kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████▉           | 860kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 870kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████▎          | 880kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████▌          | 890kB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████▉          | 901kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 911kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████▎         | 921kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████▌         | 931kB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████▉         | 942kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 952kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████▎        | 962kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████▌     
   | 972kB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████▉        | 983kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 993kB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 1.0MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████▌       | 1.0MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████▊       | 1.0MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████       | 1.0MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▎      | 1.0MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▌      | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▊      | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████      | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▎     | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▌     | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▊     | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████     | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▎    | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▌    | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▊    | 1.1MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████    | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▎   | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▌   | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▊   | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▎  | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▌  | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▊  | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████  | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▎ | 1.2MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▊ | 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████ | 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▏| 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▊| 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.3MB 8.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.3MB 8.8MB/s \n",
            "\u001b[?25hInstalling collected packages: Box2D\n",
            "Successfully installed Box2D-2.3.10\n"
          ],
          "name": "stdout"
        }
      ]
    },
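    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training loop below adds two tricks on top of the agent. Reward shaping: the raw $-100$ fall penalty is clipped to $-5$ and every other reward is scaled by $5$, so the fall signal does not dominate the Q-targets. Selective replay: an episode is always stored if the walker fell or scored below 250, while other episodes are stored only about half the time (gated by a fall counter), biasing the buffer toward informative failures. The actor-loss weight is annealed with performance, $w = 1 - \\mathrm{clip}(\\bar{R}_{100}/300,\\; 0,\\; 1)$, where $\\bar{R}_{100}$ is the 100-episode moving average reward."
      ]
    },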
    {
      "cell_type": "code",
      "metadata": {
        "id": "Fa3h5ntQBgBR"
      },
      "source": [
        "\"\"\"Training the agent\"\"\"\n",
        "gym.logger.set_level(40)\n",
        "max_steps = 3000\n",
        "falling_down = 0\n",
        "import random\n",
        "SEED = 12\n",
        "import random\n",
        "\n",
        "\n",
        "#Adaptive weight by using weight =  1 - np.clip(sysmodel_loss.item()/0.020, 0, 1)\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    env = gym.make('BipedalWalkerHardcore-v3')\n",
        "    agent = TD3_FORK('ddpg_Bipedal', env, batch_size = 100)\n",
        "    total_episodes = 100000\n",
        "    start_timestep=0                #time_step to select action based on Actor\n",
        "    time_start = time.time()        # Init start time\n",
        "    ep_reward_list = []\n",
        "    avg_reward_list = []\n",
        "    total_timesteps = 0\n",
        "    sys_loss = 0\n",
        "    numtrainedexp = 0\n",
        "\n",
        "\n",
        "    #Set random seed\n",
        "    random.seed(SEED)\n",
        "    env.seed(SEED)\n",
        "    torch.manual_seed(SEED)\n",
        "    np.random.seed(SEED)\n",
        "\n",
        "    save_time = 0\n",
        "    expcount = 0\n",
        "    totrain = 0\n",
        "\n",
        "    for ep in range(total_episodes):\n",
        "        state = env.reset()\n",
        "        episodic_reward = 0\n",
        "        timestep = 0\n",
        "        agent.sys_updates = 0\n",
        "        temp_replay_buffer = []\n",
        "\n",
        "        for st in range(max_steps):\n",
        "        # Uncomment this to see the Actor in action But not in a python notebook.\n",
        "            #env.render()\n",
        "            timestep += 1 \n",
        "            # Select action randomly or according to policy\n",
        "            if total_timesteps < start_timestep:\n",
        "                action = env.action_space.sample()\n",
        "            else:\n",
        "                action = agent.policy(state)\n",
        "\n",
        "            # Recieve state and reward from environment.\n",
        "            next_state, reward, done, info = env.step(action)\n",
        "            #change original reward from -100 to -5 and 5*reward for other values\n",
        "            episodic_reward += reward\n",
        "            if reward == -100:\n",
        "                reward = -5\n",
        "                fall = -1\n",
        "                falling_down += 1\n",
        "                expcount += 1\n",
        "            else:\n",
        "                reward = 5 * reward\n",
        "                fall = 0\n",
        "            temp_replay_buffer.append((state, action, reward, next_state, done))        \n",
        "            \n",
        "            # End this episode when `done` is True\n",
        "            if done:\n",
        "                if fall == -1 or episodic_reward < 250:            \n",
        "                    totrain = 1\n",
        "                    for temp in temp_replay_buffer: \n",
        "                        agent.add_to_replay_memory(temp, agent.replay_memory_buffer)\n",
        "                elif expcount > 0 and np.random.rand() > 0.5:\n",
        "                    totrain = 1\n",
        "                    expcount -= 10\n",
        "                    for temp in temp_replay_buffer: \n",
        "                        agent.add_to_replay_memory(temp, agent.replay_memory_buffer)\n",
        "                break\n",
        "            state = next_state    \n",
        "            total_timesteps += 1\n",
        "\n",
        "        ep_reward_list.append(episodic_reward)\n",
        "        # Mean of last 100 episodes\n",
        "        avg_reward = np.mean(ep_reward_list[-100:])\n",
        "        avg_reward_list.append(avg_reward)\n",
        "\n",
        "        if avg_reward > 294:\n",
        "            test_reward = agent.eval_policy(env, seed=88, eval_episodes=10)\n",
        "            if test_reward > 300:\n",
        "                final_test_reward = agent.eval_policy(env, seed=88, eval_episodes=100)\n",
        "                if final_test_reward > 300:\n",
        "                    torch.save(agent.actor.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/actor.pth')\n",
        "                    torch.save(agent.critic.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/critic.pth')\n",
        "                    torch.save(agent.actor_target.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/actor_t.pth')\n",
        "                    torch.save(agent.critic_target.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/critic_t.pth')\n",
        "                    torch.save(agent.sysmodel.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/sysmodel.pth')\n",
        "                    \n",
        "                    print(\"===========================\")\n",
        "                    print('Task Solved')\n",
        "                    print(\"===========================\")\n",
        "                    break\n",
        "                    \n",
        "        s = (int)(time.time() - time_start)\n",
        "        #Training agent only when new experiences are added to the replay buffer\n",
        "        weight =  (1 - np.clip(np.mean(ep_reward_list[-100:])/300, 0, 1)) \n",
        "        if totrain == 1:\n",
        "            agent.learn_and_update_weights_by_replay(timestep, weight, totrain)\n",
        "        else: \n",
        "            agent.learn_and_update_weights_by_replay(100, weight, totrain)\n",
        "        totrain = 0\n",
        "\n",
        "        print('Ep. {}, Timestep {},  Ep.Timesteps {}, Episode Reward: {:.2f}, Moving Avg.Reward: {:.2f}, Time: {:02}:{:02}:{:02} , Falling down: {}, Weight: {}, Update_time: {}'\n",
        "                .format(ep, total_timesteps, timestep,\n",
        "                      episodic_reward, avg_reward, s//3600, s%3600//60, s%60, falling_down, weight, agent.sys_updates)) \n",
        "        \n",
        "        if s // 1800 == save_time:           \n",
        "            torch.save(agent.actor.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/actor-time{}.pth'.format(save_time))\n",
        "            torch.save(agent.critic.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/critic-time{}.pth'.format(save_time))\n",
        "            torch.save(agent.actor_target.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/actor_t-time{}.pth'.format(save_time))\n",
        "            torch.save(agent.critic_target.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/critic_t-time{}.pth'.format(save_time))\n",
        "            torch.save(agent.sysmodel.state_dict(), '/content/drive/My Drive/bipedal/weights/hardcore/sysmodel-time{}.pth'.format(save_time))        \n",
        "            print(\"===========================\")\n",
        "            print('Saving Successfully!')\n",
        "            print(\"===========================\")\n",
        "            save_time += 1\n",
        "        \n",
        "# Plotting graph\n",
        "# Episodes versus Avg. Rewards\n",
        "plt.plot(avg_reward_list)\n",
        "plt.xlabel(\"Episode\")\n",
        "plt.ylabel(\"Avg. Epsiodic Reward\")\n",
        "plt.show()\n",
        "env.close()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}