{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import pickle\n",
        "from KL_uncertainity_evaluator import Robust_pol_Kl_uncertainity\n",
        "import time\n",
        "\n",
        "with open('rew_Cartpole.pkl', 'rb') as f:\n",
        "    R = pickle.load(f)\n",
        "f.close()\n",
        "\n",
        "with open('constraint_cost_Cartpole.pkl', 'rb') as f:\n",
        "    C = pickle.load(f)\n",
        "f.close()\n",
        "\n",
        "R = R +np.ones_like(R.shape)*0.1"
      ],
      "metadata": {
        "id": "TDN-5petHlUD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def softmax(z):\n",
        "    exp_z = np.exp(z)\n",
        "    return exp_z / np.sum(exp_z)\n",
        "\n",
        "def get_policy_from_theta(theta):\n",
        "    nS, nA = theta.shape\n",
        "    return np.array([softmax(theta[s]) for s in range(nS)])\n",
        "\n",
        "def kl_divergence(pi_new, pi_old):\n",
        "    kl = 0.0\n",
        "    for s in range(len(pi_new)):\n",
        "        kl += np.sum(pi_new[s] * (np.log(pi_new[s] + 1e-8) - np.log(pi_old[s] + 1e-8)))\n",
        "    return kl\n",
        "\n",
        "def flatten_grad(grad):\n",
        "    return grad.flatten()\n",
        "\n",
        "def reshape_grad(vec, shape):\n",
        "    return vec.reshape(shape)\n",
        "\n",
        "def natural_gradient_update(theta, grad, kl_lambda, alpha,ch_dep,ch,booster):\n",
        "    # Perform a natural gradient-like update based on the objective:\n",
        "    # max_\\theta_new grad^T (\\theta_new - \\theta) - \\lambda KL(\\theta_new || \\theta)\n",
        "    # This is equivalent to solving: \\theta_new = \\theta + 1/\\lambda * grad\n",
        "    alpha = alpha*booster\n",
        "    if(ch_dep==0):\n",
        "        return theta - alpha * grad / kl_lambda\n",
        "    elif(ch==1):\n",
        "        if(ch==0):\n",
        "           return theta + alpha * grad / kl_lambda\n",
        "        else:\n",
        "            return theta - alpha * grad / kl_lambda\n",
        "    else:\n",
        "        return theta + alpha * grad / kl_lambda"
      ],
      "metadata": {
        "id": "V9m79mbcJYWC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "env_dep = 1 # similar to river-swim type\n",
        "nS,nA = R.shape\n",
        "cost_list = [R, C]\n",
        "init_dist = np.exp(np.random.normal(0,1,nS))\n",
        "init_dist = init_dist/np.sum(init_dist)\n",
        "init_dist = init_dist.tolist()\n",
        "rpe = Robust_pol_Kl_uncertainity(nS, nA, cost_list, init_dist, alpha=0.000001)\n",
        "P = np.zeros((nA,nS,nS))\n",
        "for s in range(nS):\n",
        "  for a in range(nA):\n",
        "    mu,sigma = np.random.uniform(0,100),np.random.uniform(0,100)\n",
        "    P[a,s,:] = np.random.normal(mu,sigma,nS)\n",
        "    P[a,s,:] = np.exp(P[a,s,:])\n",
        "    P[a,s,:] = P[a,s,:]/np.sum(P[a,s,:])"
      ],
      "metadata": {
        "id": "tYA9GoQtKSK7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "C_KL = 0.5\n",
        "kl_lambda = 5\n",
        "alpha = 0.5\n",
        "T = 10\n",
        "b = 90\n",
        "booster = 1\n",
        "theta_old = np.random.randn(nS, nA)\n",
        "theta = np.copy(theta_old)"
      ],
      "metadata": {
        "id": "PbqxHNTbL1oq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "vf = []\n",
        "cf = []\n",
        "start = time.time()\n",
        "for t in range(T):\n",
        "    policy = get_policy_from_theta(theta)\n",
        "\n",
        "    # Get both objectives and gradients\n",
        "    J_v, grad_v = rpe.evaluate_policy(policy, P, C_KL, n=0, t=t)\n",
        "    J_c, grad_c = rpe.evaluate_policy(policy, P, C_KL, n=1, t=t)\n",
        "    vf.append(J_v)\n",
        "    cf.append(J_c)\n",
        "\n",
        "    # Choose which gradient to follow\n",
        "    if env_dep!=2:\n",
        "        ch = np.argmax([J_v, kl_lambda*(np.max(J_c-b,0))])\n",
        "    else:\n",
        "        ch = np.argmax([J_v, kl_lambda*(np.max(b-J_c,0))])\n",
        "    grad = grad_v if ch == 0 else grad_c\n",
        "\n",
        "    # Flatten gradient and apply natural-like update\n",
        "    grad_vec = flatten_grad(grad)\n",
        "    theta_vec = flatten_grad(theta)\n",
        "    theta_new_vec = natural_gradient_update(theta_vec, grad_vec, kl_lambda, alpha,env_dep,ch,booster)\n",
        "    theta_new = reshape_grad(theta_new_vec, (nS, nA))\n",
        "\n",
        "    # Check KL divergence\n",
        "    pi_old = get_policy_from_theta(theta)\n",
        "    pi_new = get_policy_from_theta(theta_new)\n",
        "    kl = kl_divergence(pi_new, pi_old)\n",
        "\n",
        "    # Accept the update\n",
        "    theta = theta_new\n",
        "    print(grad)\n",
        "    print(f\"[Iter {t}] J_v={J_v:.4f}, J_c={J_c:.4f}, KL={kl:.6f}, ch={ch}\")\n",
        "\n",
        "print(\"Time taken:\",time.time()-start)\n",
        "\n",
        "with open(\"Value_function_kl_lambda_CP\"+str(kl_lambda)+\".pkl\",\"wb\") as f:\n",
        "    pickle.dump(vf,f)\n",
        "f.close()\n",
        "\n",
        "with open(\"Cost_function_kl_lambda_CP\"+str(kl_lambda)+\".pkl\",\"wb\") as f:\n",
        "    pickle.dump(cf,f)\n",
        "f.close()\n",
        "# Final policy\n",
        "#final_policy = get_policy_from_theta(theta)\n",
        "#print(\"Final policy:\")\n",
        "#print(final_policy)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 634
        },
        "id": "naS8JPftMZPG",
        "outputId": "bb9a867c-a10b-4bb1-b11f-a45788ff7a1c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.09201173 0.09201173]\n",
            " [0.09050018 0.09050018]\n",
            " [0.09161896 0.09161896]\n",
            " ...\n",
            " [0.09074648 0.09074648]\n",
            " [0.0902476  0.0902476 ]\n",
            " [0.0896808  0.0896808 ]]\n",
            "[Iter 0] J_v=20.9859, J_c=116.9688, KL=0.000000, ch=1\n",
            "[[0.09201173 0.09201173]\n",
            " [0.09050018 0.09050018]\n",
            " [0.09161896 0.09161896]\n",
            " ...\n",
            " [0.09074648 0.09074648]\n",
            " [0.0902476  0.0902476 ]\n",
            " [0.0896808  0.0896808 ]]\n",
            "[Iter 1] J_v=20.9859, J_c=116.9688, KL=0.000000, ch=1\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-57-278ca4f4b6f3>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0;31m# Get both objectives and gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m     \u001b[0mJ_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate_policy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mC_KL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m     \u001b[0mJ_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_c\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate_policy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mC_KL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0mvf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mJ_v\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/content/KL_uncertainity_evaluator.py\u001b[0m in \u001b[0;36mevaluate_policy\u001b[0;34m(self, policy, P, C_KL, n, t)\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0ms_next\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m                 \u001b[0mT\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ms_next\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mP\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ms_next\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnA\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m         \u001b[0mI\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0mQ_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mI\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgamma\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcost_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/numpy/_core/fromnumeric.py\u001b[0m in \u001b[0;36msum\u001b[0;34m(a, axis, dtype, out, keepdims, initial, where)\u001b[0m\n\u001b[1;32m   2387\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2388\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2389\u001b[0;31m     return _wrapreduction(\n\u001b[0m\u001b[1;32m   2390\u001b[0m         \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sum'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2391\u001b[0m         \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeepdims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwhere\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/numpy/_core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapreduction\u001b[0;34m(obj, ufunc, method, axis, dtype, out, **kwargs)\u001b[0m\n\u001b[1;32m     84\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpasskwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mufunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpasskwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(get_policy_from_theta(theta))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "d7fCJirCWLgR",
        "outputId": "fa4daeff-4978-4ef9-d813-acc901c0ee05"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.91483864 0.08516136]\n",
            " [0.84496566 0.15503434]\n",
            " [0.34728545 0.65271455]\n",
            " ...\n",
            " [0.81128932 0.18871068]\n",
            " [0.7047552  0.2952448 ]\n",
            " [0.18485169 0.81514831]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "grad"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uluSjBdQXhKv",
        "outputId": "ed785883-d3c0-4d97-b446-595e34153de7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0.09164769, 0.09164769],\n",
              "       [0.09105194, 0.09105194],\n",
              "       [0.09180509, 0.09180509],\n",
              "       ...,\n",
              "       [0.09229802, 0.09229802],\n",
              "       [0.09166756, 0.09166756],\n",
              "       [0.09095145, 0.09095145]])"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ofQBwlH_YfxB"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}