{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BoogiPfyajn_",
        "outputId": "3b3dc309-dd9e-4e5d-ad44-70bea65ba57e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wall clock time taken: 1429.147206068039\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "from torch import nn,optim\n",
        "from torch.distributions import Categorical\n",
        "import gymnasium as gym\n",
        "import pandas as pd\n",
        "import time\n",
        "# Actor Network\n",
        "class Actor(nn.Module):\n",
        "    def __init__(self, state_dim, action_dim):\n",
        "        super().__init__()\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Linear(state_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, action_dim),\n",
        "            nn.Softmax(dim=-1)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n",
        "\n",
        "# Critic Network\n",
        "class Critic(nn.Module):\n",
        "    def __init__(self, state_dim):\n",
        "        super().__init__()\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Linear(state_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, 1)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x).squeeze(-1)\n",
        "\n",
        "def compute_returns(values, gamma=0.99):\n",
        "    returns = []\n",
        "    G = 0\n",
        "    for v in reversed(values):\n",
        "        G = v + gamma * G\n",
        "        returns.insert(0, G)\n",
        "    return torch.tensor(returns)\n",
        "\n",
        "def collect_trajectory(env, actor, max_steps=500):\n",
        "    state, _ = env.reset()\n",
        "    states, actions, rewards, costs, log_probs = [], [], [], [], []\n",
        "\n",
        "    for _ in range(max_steps):\n",
        "        state_tensor = torch.FloatTensor(state)\n",
        "        probs = actor(state_tensor)\n",
        "        dist = Categorical(probs)\n",
        "        action = dist.sample()\n",
        "        log_prob = dist.log_prob(action)\n",
        "\n",
        "        next_state, reward, done, truncated, _ = env.step(action.item())\n",
        "\n",
        "        cost = abs(state[0])  # custom cost: cart's distance from center\n",
        "\n",
        "        states.append(state_tensor)\n",
        "        actions.append(action)\n",
        "        rewards.append(reward)\n",
        "        costs.append(cost)\n",
        "        log_probs.append(log_prob)\n",
        "\n",
        "        if done or truncated:\n",
        "            break\n",
        "        state = next_state\n",
        "\n",
        "    return states, actions, rewards, costs, log_probs\n",
        "\n",
        "\n",
        "class EPIRC_PGS_CS:\n",
        "    def __init__(self,K,alpha,env,T,b_1):\n",
        "        self.dict_vf = {'vf':[],'cf':[]}\n",
        "        self.actor = Actor(4,2)\n",
        "        self.reward_critic = Critic(4)\n",
        "        self.constraint_critic = Critic(4)\n",
        "        self.ppg_ob = PPG(alpha,T,env,b_1,self.dict_vf,self.reward_critic,self.constraint_critic,self.actor)\n",
        "        self.K = K\n",
        "        self.i = 0\n",
        "        self.j = 500\n",
        "        self.b_1 = b_1\n",
        "        self.env = env\n",
        "    def run_algo(self):\n",
        "        for k in range(self.K):\n",
        "            b = (self.i+self.j)//2\n",
        "            V_o,V_c = self.ppg_ob.get_policy(b)\n",
        "            #V_o,V_c = self.cost_objects[0].get_vf(self.pi),self.cost_objects[1].get_vf(self.pi)\n",
        "            del_k = np.max([V_o - b,V_c-self.b_1])\n",
        "            if del_k>0:\n",
        "                self.i = b\n",
        "            else:\n",
        "                self.j = b\n",
        "        dF = pd.DataFrame(self.dict_vf)\n",
        "        dF.to_excel('EPIRC_PGS_CS_value_functions.xlsx')\n",
        "        torch.save(self.actor,'EPIRC_actor.pth')\n",
        "        return self.ppg_ob.get_policy(self.j)\n",
        "\n",
        "\n",
        "class PPG:\n",
        "    def __init__(self,alpha,T,env,b_1,dict_,rc,cc,ac):\n",
        "        self.alpha = alpha\n",
        "        self.T = T\n",
        "        #self.critic_objs = critic_objs\n",
        "        self.dict_vf = dict_\n",
        "        self.b_1 = b_1\n",
        "        self.env = env\n",
        "        self.reward_critic = rc\n",
        "        self.constraint_critic = cc\n",
        "        self.actor = ac\n",
        "        self.optim_r = optim.Adam(self.reward_critic.parameters(),lr=1e-3)\n",
        "        self.optim_c = optim.Adam(self.constraint_critic.parameters(),lr=1e-3)\n",
        "    def get_policy(self,b):\n",
        "        self.b = b\n",
        "        V_0,V_c =0,0\n",
        "        for t in range(self.T):\n",
        "            states, actions, rewards, costs, log_probs = collect_trajectory(self.env, self.actor)\n",
        "\n",
        "            states_tensor = torch.stack(states)\n",
        "            log_probs_tensor = torch.stack(log_probs)\n",
        "            reward_returns = compute_returns(rewards)\n",
        "            cost_returns = compute_returns(costs)\n",
        "            V_o = reward_returns[0].item()\n",
        "            V_c = cost_returns[0].item()\n",
        "            self.dict_vf['vf'].append(V_o)\n",
        "            self.dict_vf['cf'].append(V_c)\n",
        "\n",
        "            # === Train Critics ===\n",
        "            self.optim_r.zero_grad()\n",
        "            loss_r = nn.functional.mse_loss(self.reward_critic(states_tensor), reward_returns)\n",
        "            loss_r.backward()\n",
        "            self.optim_r.step()\n",
        "\n",
        "            self.optim_c.zero_grad()\n",
        "            loss_c = nn.functional.mse_loss(self.constraint_critic(states_tensor), cost_returns)\n",
        "            loss_c.backward()\n",
        "            self.optim_c.step()\n",
        "\n",
        "            # === Evaluate Constraint ===\n",
        "            total_cost = sum(costs)\n",
        "            violation = total_cost - self.b_1\n",
        "\n",
        "            if violation > 0:\n",
        "                critic = self.constraint_critic\n",
        "                sign = -1\n",
        "            else:\n",
        "                critic = self.reward_critic\n",
        "                sign = +1\n",
        "\n",
        "            # === Policy Gradient Step ===\n",
        "            advantage = compute_returns(critic(states_tensor).detach().tolist())\n",
        "            loss_pi = -sign * (log_probs_tensor * advantage).mean()\n",
        "\n",
        "            self.actor.zero_grad()\n",
        "            loss_pi.backward()\n",
        "            for p in self.actor.parameters():\n",
        "                p.data -= self.alpha * p.grad  # projected step\n",
        "        return V_0,V_c\n",
        "\n",
        "if __name__=='__main__':\n",
        "    env = gym.make('CartPole-v1')\n",
        "    b_1 = 200\n",
        "    K=20\n",
        "    T=500\n",
        "    alpha = 1e-4\n",
        "    start = time.time()\n",
        "    algo_obj = EPIRC_PGS_CS(K,alpha,env,T,b_1)\n",
        "    algo_obj.run_algo()\n",
        "    print(\"Wall clock time taken:\",time.time()-start)"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NP6MCYNXaqvZ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}