{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7b9OzxoB6TUa"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torch.nn as nn\n",
        "from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler\n",
        "import numpy as np\n",
        "\n",
        "# ----------------------------------------------\n",
        "# Helper Functions and Base Networks\n",
        "# ----------------------------------------------\n",
        "\n",
        "def orthogonal_init(layer, gain=1.0):\n",
        "    nn.init.orthogonal_(layer.weight, gain=gain)\n",
        "    nn.init.constant_(layer.bias, 0)\n",
        "\n",
        "class Actor_Beta(nn.Module):\n",
        "    def __init__(self, args):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(args.state_dim, args.hidden_width)\n",
        "        self.fc2 = nn.Linear(args.hidden_width, args.hidden_width)\n",
        "        self.alpha_layer = nn.Linear(args.hidden_width, args.action_dim)\n",
        "        self.beta_layer = nn.Linear(args.hidden_width, args.action_dim)\n",
        "        self.activate_func = [nn.ReLU(), nn.Tanh()][args.use_tanh]\n",
        "        if args.use_orthogonal_init:\n",
        "            orthogonal_init(self.fc1)\n",
        "            orthogonal_init(self.fc2)\n",
        "            orthogonal_init(self.alpha_layer, gain=0.01)\n",
        "            orthogonal_init(self.beta_layer, gain=0.01)\n",
        "\n",
        "    def forward(self, s):\n",
        "        s = self.activate_func(self.fc1(s))\n",
        "        s = self.activate_func(self.fc2(s))\n",
        "        alpha = F.softplus(self.alpha_layer(s)) + 1.0\n",
        "        beta = F.softplus(self.beta_layer(s)) + 1.0\n",
        "        return alpha, beta\n",
        "\n",
        "    def get_dist(self, s):\n",
        "        from torch.distributions import Beta\n",
        "        alpha, beta = self.forward(s)\n",
        "        return Beta(alpha, beta)\n",
        "\n",
        "    def mean(self, s):\n",
        "        alpha, beta = self.forward(s)\n",
        "        return alpha / (alpha + beta)\n",
        "\n",
        "class Actor_Gaussian(nn.Module):\n",
        "    def __init__(self, args):\n",
        "        super().__init__()\n",
        "        self.max_action = args.max_action\n",
        "        self.fc1 = nn.Linear(args.state_dim, args.hidden_width)\n",
        "        self.fc2 = nn.Linear(args.hidden_width, args.hidden_width)\n",
        "        self.mean_layer = nn.Linear(args.hidden_width, args.action_dim)\n",
        "        self.log_std = nn.Parameter(torch.zeros(1, args.action_dim))\n",
        "        self.activate_func = [nn.ReLU(), nn.Tanh()][args.use_tanh]\n",
        "        if args.use_orthogonal_init:\n",
        "            orthogonal_init(self.fc1)\n",
        "            orthogonal_init(self.fc2)\n",
        "            orthogonal_init(self.mean_layer, gain=0.01)\n",
        "\n",
        "    def forward(self, s):\n",
        "        s = self.activate_func(self.fc1(s))\n",
        "        s = self.activate_func(self.fc2(s))\n",
        "        mean = self.max_action * torch.tanh(self.mean_layer(s))\n",
        "        return mean\n",
        "\n",
        "    def get_dist(self, s):\n",
        "        from torch.distributions import Normal\n",
        "        mean = self.forward(s)\n",
        "        log_std = self.log_std.expand_as(mean)\n",
        "        std = torch.exp(log_std)\n",
        "        return Normal(mean, std)\n",
        "\n",
        "class Critic(nn.Module):\n",
        "    def __init__(self, args):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(args.state_dim, args.hidden_width)\n",
        "        self.fc2 = nn.Linear(args.hidden_width, args.hidden_width)\n",
        "        self.fc3 = nn.Linear(args.hidden_width, 1)\n",
        "        self.activate_func = [nn.ReLU(), nn.Tanh()][args.use_tanh]\n",
        "        if args.use_orthogonal_init:\n",
        "            orthogonal_init(self.fc1)\n",
        "            orthogonal_init(self.fc2)\n",
        "            orthogonal_init(self.fc3)\n",
        "\n",
        "    def forward(self, s):\n",
        "        s = self.activate_func(self.fc1(s))\n",
        "        s = self.activate_func(self.fc2(s))\n",
        "        return self.fc3(s)\n",
        "\n",
        "# ----------------------------------------------\n",
        "# PPO Continuous with Cost Constraint\n",
        "# ----------------------------------------------\n",
        "\n",
        "class PPO_continuous():\n",
        "    def __init__(self, args):\n",
        "        self.policy_dist = args.policy_dist\n",
        "        self.max_action = args.max_action\n",
        "        self.batch_size = args.batch_size\n",
        "        self.mini_batch_size = args.mini_batch_size\n",
        "        self.max_train_steps = args.max_train_steps\n",
        "        self.lr_a = args.lr_a\n",
        "        self.lr_c = args.lr_c\n",
        "        self.gamma = args.gamma\n",
        "        self.lamda = args.lamda\n",
        "        self.epsilon = args.epsilon\n",
        "        self.K_epochs = args.K_epochs\n",
        "        self.entropy_coef = args.entropy_coef\n",
        "        self.set_adam_eps = args.set_adam_eps\n",
        "        self.use_grad_clip = args.use_grad_clip\n",
        "        self.use_lr_decay = args.use_lr_decay\n",
        "        self.use_adv_norm = args.use_adv_norm\n",
        "        # Constraint-specific\n",
        "        self.lam_constraint = args.lam_constraint     # λ (lambda) for constraint penalty\n",
        "        self.cost_limit = args.cost_limit             # b (budget/cost bound)\n",
        "\n",
        "        if self.policy_dist == \"Beta\":\n",
        "            self.actor = Actor_Beta(args)\n",
        "        else:\n",
        "            self.actor = Actor_Gaussian(args)\n",
        "        self.critic_objective = Critic(args)\n",
        "        self.critic_cost = Critic(args)\n",
        "\n",
        "        if self.set_adam_eps:\n",
        "            self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_a, eps=1e-5)\n",
        "            self.optimizer_critic_objective = torch.optim.Adam(self.critic_objective.parameters(), lr=self.lr_c, eps=1e-5)\n",
        "            self.optimizer_critic_cost = torch.optim.Adam(self.critic_cost.parameters(), lr=self.lr_c, eps=1e-5)\n",
        "        else:\n",
        "            self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_a)\n",
        "            self.optimizer_critic_objective = torch.optim.Adam(self.critic_objective.parameters(), lr=self.lr_c)\n",
        "            self.optimizer_critic_cost = torch.optim.Adam(self.critic_cost.parameters(), lr=self.lr_c)\n",
        "\n",
        "    def evaluate(self, s):\n",
        "        s = torch.unsqueeze(torch.tensor(s, dtype=torch.float), 0)\n",
        "        with torch.no_grad():\n",
        "            if self.policy_dist == \"Beta\":\n",
        "                a = self.actor.mean(s).detach().numpy().flatten()\n",
        "            else:\n",
        "                a = self.actor(s).detach().numpy().flatten()\n",
        "        return a\n",
        "\n",
        "    def choose_action(self, s):\n",
        "        s = torch.unsqueeze(torch.tensor(s, dtype=torch.float), 0)\n",
        "        with torch.no_grad():\n",
        "            dist = self.actor.get_dist(s)\n",
        "            a = dist.sample()\n",
        "            if self.policy_dist == \"Beta\":\n",
        "                a_logprob = dist.log_prob(a)\n",
        "            else:\n",
        "                a = torch.clamp(a, -self.max_action, self.max_action)\n",
        "                a_logprob = dist.log_prob(a)\n",
        "        return a.numpy().flatten(), a_logprob.numpy().flatten()\n",
        "\n",
        "    def update(self, replay_buffer, total_steps):\n",
        "        # Unpack: make sure your replay_buffer provides costs!\n",
        "        s, a, a_logprob, r, cost, s_, dw, done = replay_buffer.numpy_to_tensor()\n",
        "\n",
        "        adv_obj, adv_cost = [], []\n",
        "        gae_obj, gae_cost = 0, 0\n",
        "\n",
        "        with torch.no_grad():\n",
        "            v_obj = self.critic_objective(s)\n",
        "            v_obj_ = self.critic_objective(s_)\n",
        "            v_cost = self.critic_cost(s)\n",
        "            v_cost_ = self.critic_cost(s_)\n",
        "\n",
        "            delta_obj = r + self.gamma * (1.0 - dw) * v_obj_ - v_obj\n",
        "            delta_cost = cost + self.gamma * (1.0 - dw) * v_cost_ - v_cost\n",
        "\n",
        "            for dlt_o, dlt_c, d in zip(reversed(delta_obj.flatten().numpy()),\n",
        "                                      reversed(delta_cost.flatten().numpy()),\n",
        "                                      reversed(done.flatten().numpy())):\n",
        "                gae_obj = dlt_o + self.gamma * self.lamda * gae_obj * (1.0 - d)\n",
        "                gae_cost = dlt_c + self.gamma * self.lamda * gae_cost * (1.0 - d)\n",
        "                adv_obj.insert(0, gae_obj)\n",
        "                adv_cost.insert(0, gae_cost)\n",
        "\n",
        "            adv_obj = torch.tensor(adv_obj, dtype=torch.float).view(-1, 1)\n",
        "            adv_cost = torch.tensor(adv_cost, dtype=torch.float).view(-1, 1)\n",
        "            v_target_obj = adv_obj + v_obj\n",
        "            v_target_cost = adv_cost + v_cost\n",
        "\n",
        "            if self.use_adv_norm:\n",
        "                adv_obj = (adv_obj - adv_obj.mean()) / (adv_obj.std() + 1e-5)\n",
        "                adv_cost = (adv_cost - adv_cost.mean()) / (adv_cost.std() + 1e-5)\n",
        "\n",
        "        for _ in range(self.K_epochs):\n",
        "            for index in BatchSampler(SubsetRandomSampler(range(self.batch_size)), self.mini_batch_size, False):\n",
        "                dist_now = self.actor.get_dist(s[index])\n",
        "                dist_entropy = dist_now.entropy().sum(1, keepdim=True)\n",
        "                a_logprob_now = dist_now.log_prob(a[index])\n",
        "                ratios = torch.exp(a_logprob_now.sum(1, keepdim=True) - a_logprob[index].sum(1, keepdim=True))\n",
        "\n",
        "                # Compute channel (ch): 0=objective, 1=constraint\n",
        "                v_o = self.critic_objective(s[index]).detach()\n",
        "                v_c = self.critic_cost(s[index]).detach()\n",
        "                v_concat = torch.cat([\n",
        "                    v_o,\n",
        "                    self.lam_constraint * (v_c - self.cost_limit)\n",
        "                ], dim=1)\n",
        "                ch = torch.argmax(v_concat, dim=1, keepdim=True)\n",
        "\n",
        "                # Select which advantage signal to use for each sample\n",
        "                adv_signal = torch.where(\n",
        "                    ch == 0,                 # If argmax selects objective\n",
        "                    adv_obj[index],\n",
        "                    self.lam_constraint * adv_cost[index]\n",
        "                )\n",
        "\n",
        "                surr1 = ratios * adv_signal\n",
        "                surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * adv_signal\n",
        "\n",
        "                actor_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy\n",
        "\n",
        "                self.optimizer_actor.zero_grad()\n",
        "                actor_loss.mean().backward()\n",
        "                if self.use_grad_clip:\n",
        "                    torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)\n",
        "                self.optimizer_actor.step()\n",
        "\n",
        "                # Update critics\n",
        "                v_s_obj = self.critic_objective(s[index])\n",
        "                critic_obj_loss = F.mse_loss(v_target_obj[index], v_s_obj)\n",
        "                self.optimizer_critic_objective.zero_grad()\n",
        "                critic_obj_loss.backward()\n",
        "                if self.use_grad_clip:\n",
        "                    torch.nn.utils.clip_grad_norm_(self.critic_objective.parameters(), 0.5)\n",
        "                self.optimizer_critic_objective.step()\n",
        "\n",
        "                v_s_cost = self.critic_cost(s[index])\n",
        "                critic_cost_loss = F.mse_loss(v_target_cost[index], v_s_cost)\n",
        "                self.optimizer_critic_cost.zero_grad()\n",
        "                critic_cost_loss.backward()\n",
        "                if self.use_grad_clip:\n",
        "                    torch.nn.utils.clip_grad_norm_(self.critic_cost.parameters(), 0.5)\n",
        "                self.optimizer_critic_cost.step()\n",
        "\n",
        "        if self.use_lr_decay:\n",
        "            self.lr_decay(total_steps)\n",
        "\n",
        "    def lr_decay(self, total_steps):\n",
        "        lr_a_now = self.lr_a * (1 - total_steps / self.max_train_steps)\n",
        "        lr_c_now = self.lr_c * (1 - total_steps / self.max_train_steps)\n",
        "        for p in self.optimizer_actor.param_groups:\n",
        "            p['lr'] = lr_a_now\n",
        "        for p in self.optimizer_critic_objective.param_groups:\n",
        "            p['lr'] = lr_c_now\n",
        "        for p in self.optimizer_critic_cost.param_groups:\n",
        "            p['lr'] = lr_c_now"
      ]
    }
  ]
}