{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sensitivity of regularized NPG primal-dual for finite constrained MDPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nSet-up:\\n1) Softmax policy \\n2) Bounded rewards\\n3) Many states and actions\\n4) Regularized (Natural) policy gradient + primal-dual method\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "Set-up:\n",
    "1) Softmax policy \n",
    "2) Bounded rewards\n",
    "3) Many states and actions\n",
    "4) Regularized (Natural) policy gradient + primal-dual method\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.set_printoptions(formatter={'float': lambda x: \"{0:0.6f}\".format(x)})\n",
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linprog"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Random Seed\n",
    "np.random.seed(10) \n",
    "## Problem Setup\n",
    "gamma = 0.9\n",
    "n, m = 20, 5 # s, a\n",
    "'''\n",
    "Randomly generated probability transition matrix P((s,a) -> s') in R^{|S||A| x |S|}\n",
    "Each row sums up to one\n",
    "'''\n",
    "raw_transition = np.random.uniform(0,1,size=(n*m,n))\n",
    "prob_transition = raw_transition/raw_transition.sum(axis=1,keepdims=1)\n",
    "'''\n",
    "Random positive rewards\n",
    "'''\n",
    "reward = np.random.uniform(0,1,size=(n*m))\n",
    "\n",
    "'''\n",
    "Random utilities between -1 and +1\n",
    "'''\n",
    "utility = np.random.uniform(-1,1,size=(n*m))\n",
    "\n",
    "'''\n",
    "Start state distribution\n",
    "'''\n",
    "# uniform\n",
    "rho = np.ones(n)/n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Input: theta as an array and \n",
    "Ouput: array of probabilites corresponding to each state: [\\pi_{s_1}(.), ...., \\pi_{s_n}(.)]\n",
    "'''\n",
    "def theta_to_policy(theta,n,m):\n",
    "    prob = []\n",
    "    for i in range(n):\n",
    "        norm = np.sum(np.exp(theta[m*i:m*(i+1)]))\n",
    "        for j in range(m*i,m*(i+1)):\n",
    "            prob.append(np.exp(theta[j])/norm)\n",
    "            \n",
    "    return np.asarray(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Input: theta as an array and \n",
    "Ouput: array of probabilites corresponding to each state: [\\pi_{s_1}(.), ...., \\pi_{s_n}(.)]\n",
    "'''\n",
    "def theta_to_policy_naive(theta,n,m):\n",
    "    prob = []\n",
    "    for i in range(n):\n",
    "        norm = np.sum(theta[m*i:m*(i+1)])\n",
    "        for j in range(m*i,m*(i+1)):\n",
    "            prob.append(theta[j]/norm)\n",
    "            \n",
    "    return np.asarray(prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "License: BSD\n",
    "Author: Mathieu Blondel\n",
    "Implements three algorithms for projecting a vector onto the simplex: sort, pivot and bisection.\n",
    "For details and references, see the following paper:\n",
    "Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex\n",
    "Mathieu Blondel, Akinori Fujino, and Naonori Ueda.\n",
    "ICPR 2014.\n",
    "http://www.mblondel.org/publications/mblondel-icpr2014.pdf\n",
    "\"\"\"\n",
    "\n",
    "def projection_simplex_sort(v, z=1):\n",
    "    n_features = v.shape[0]\n",
    "    u = np.sort(v)[::-1]\n",
    "    cssv = np.cumsum(u) - z\n",
    "    ind = np.arange(n_features) + 1\n",
    "    cond = u - cssv / ind > 0\n",
    "    rho = ind[cond][-1]\n",
    "    theta = cssv[cond][-1] / float(rho)\n",
    "    w = np.maximum(v - theta, 0)\n",
    "    return w\n",
    "\n",
    "\n",
    "def projection_simplex_pivot(v, z=1, random_state=None):\n",
    "    rs = np.random.RandomState(random_state)\n",
    "    n_features = len(v)\n",
    "    U = np.arange(n_features)\n",
    "    s = 0\n",
    "    rho = 0\n",
    "    while len(U) > 0:\n",
    "        G = []\n",
    "        L = []\n",
    "        k = U[rs.randint(0, len(U))]\n",
    "        ds = v[k]\n",
    "        for j in U:\n",
    "            if v[j] >= v[k]:\n",
    "                if j != k:\n",
    "                    ds += v[j]\n",
    "                    G.append(j)\n",
    "            elif v[j] < v[k]:\n",
    "                L.append(j)\n",
    "        drho = len(G) + 1\n",
    "        if s + ds - (rho + drho) * v[k] < z:\n",
    "            s += ds\n",
    "            rho += drho\n",
    "            U = L\n",
    "        else:\n",
    "            U = G\n",
    "    theta = (s - z) / float(rho)\n",
    "    return np.maximum(v - theta, 0)\n",
    "\n",
    "\n",
    "def projection_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000):\n",
    "    func = lambda x: np.sum(np.maximum(v - x, 0)) - z\n",
    "    lower = np.min(v) - z / len(v)\n",
    "    upper = np.max(v)\n",
    "\n",
    "    for it in range(max_iter):\n",
    "        midpoint = (upper + lower) / 2.0\n",
    "        value = func(midpoint)\n",
    "\n",
    "        if abs(value) <= tau:\n",
    "            break\n",
    "\n",
    "        if value <= 0:\n",
    "            upper = midpoint\n",
    "        else:\n",
    "            lower = midpoint\n",
    "\n",
    "    return np.maximum(v - midpoint, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_to_policy(theta,n,m):\n",
    "    '''\n",
    "    Turn a parameter vector into a tabular policy by Euclidean projection.\n",
    "\n",
    "    Input:  theta - array of length n*m laid out state by state\n",
    "            n, m  - number of states and number of actions\n",
    "    Output: flat array of probabilities [pi_{s_1}(.), ..., pi_{s_n}(.)]; each\n",
    "            block of m entries is projected onto the probability simplex\n",
    "            with projection_simplex_sort\n",
    "    '''\n",
    "    policy = []\n",
    "    for state in range(n):\n",
    "        block = theta[m*state:m*(state+1)]\n",
    "        projected = projection_simplex_sort(block, z=1)\n",
    "        policy.extend(projected[action] for action in range(m))\n",
    "\n",
    "    return np.asarray(policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Get \\Pi_{\\pi}((s) -> (s,a)) in R^{|S| x |S||A|} matrix corresponding to the policy \\pi using the prob vector\n",
    "'''\n",
    "def get_Pi(prob,n,m):\n",
    "    Pi = np.zeros((n,n*m))\n",
    "    for i in range(n):\n",
    "        Pi[i,i*m:(i+1)*m] = prob[i*m:(i+1)*m]\n",
    "    \n",
    "    return Pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_state_action(prob,state,action):\n",
    "    '''\n",
    "    Gradient of pi_theta(action|state) with respect to theta.\n",
    "\n",
    "    Input:  prob   - flat probability vector laid out state by state\n",
    "            state  - state index, 0 .. n-1\n",
    "            action - action index, 0 .. m-1\n",
    "    Output: vector of length n*m; only the m entries of the given state\n",
    "            are filled in, the rest stay zero\n",
    "\n",
    "    Uses the notebook-level n and m.\n",
    "    '''\n",
    "    gradient = np.zeros(n*m)\n",
    "    base = m*state\n",
    "    for j in range(m):\n",
    "        idx = base + j\n",
    "        if j != action:\n",
    "            gradient[idx] = -prob[base + action]*prob[idx]\n",
    "        else:\n",
    "            gradient[idx] = prob[idx]*(1-prob[idx])\n",
    "\n",
    "    return gradient\n",
    "\n",
    "def grad_state(qvals,prob,state):\n",
    "    '''\n",
    "    Sum over actions of qvals[state, a] * grad_state_action(prob, state, a).\n",
    "    Uses the notebook-level m.\n",
    "    '''\n",
    "    weighted = []\n",
    "    for action in range(m):\n",
    "        weighted.append(qvals[state*m + action]*grad_state_action(prob,state,action))\n",
    "    return np.sum(weighted,axis=0)\n",
    "\n",
    "def grad(qvals,prob,d_pi):\n",
    "    '''\n",
    "    Sum over states of d_pi[s] * grad_state(qvals, prob, s).\n",
    "    Uses the notebook-level n.\n",
    "    '''\n",
    "    per_state = [d_pi[state]*grad_state(qvals,prob,state) for state in range(n)]\n",
    "    return np.sum(per_state,axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Input: probability vector\n",
    "Output: Fisher information matrix\n",
    "        \\nabla_{\\theta} \\pi_{\\theta}(s,a) x {\\nabla_{\\theta} \\pi_{\\theta}(s,a)}^T\n",
    "'''\n",
    "def Fisher_info(prob,d_pi):\n",
    "    qvals_one = np.ones(n*m)\n",
    "    grad = np.sum([d_pi[i]*grad_state(qvals_one,prob,i) for i in range(0,n)],axis=0)\n",
    "    fisher = np.outer(grad,grad)+1e-3*np.identity(n*m)\n",
    "    return fisher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "The overall reward function \\ell(\\theta)\n",
    "'''\n",
    "def ell(qvals,prob,rho):\n",
    "    V = np.zeros(n)\n",
    "    for i in range(n):\n",
    "        V[i] = np.sum([qvals[i*m + j]*prob[i*m + j] for j in range(m)])\n",
    "    \n",
    "    ell = np.dot(V,rho)\n",
    "    return ell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avals(qvals,prob):\n",
    "    '''\n",
    "    Advantage function A(s,a) = Q(s,a) - V(s), where\n",
    "    V(s) = sum_a pi(a|s) Q(s,a).\n",
    "\n",
    "    Input:  qvals - flat array of Q-values laid out state by state\n",
    "            prob  - flat array of action probabilities, same layout\n",
    "    Output: new flat array of advantages of length n*m\n",
    "\n",
    "    The result is a fresh array; qvals itself is never written to, so the\n",
    "    caller's Q-values stay valid after the call.\n",
    "    Uses the notebook-level n and m.\n",
    "    '''\n",
    "    q = np.asarray(qvals, dtype=float)[:n*m].reshape(n, m)\n",
    "    pi = np.asarray(prob, dtype=float)[:n*m].reshape(n, m)\n",
    "    V = (q*pi).sum(axis=1, keepdims=True)\n",
    "    return (q - V).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "The projection function\n",
    "Input: a scalar \n",
    "Output: a scalar in the interval [0 C]\n",
    "'''\n",
    "def proj(scalar,gamma):\n",
    "    offset = 1000/(1-gamma)\n",
    "    if scalar < 0:\n",
    "        scalar = 0\n",
    "\n",
    "    if scalar > offset:\n",
    "        scalar = offset\n",
    "\n",
    "    return scalar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Policy Iteration to check feasibility "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stochastic policy used as the starting point of policy iteration:\n",
    "# one probability vector over the m actions per state, flattened state by state.\n",
    "raw_vec = np.random.uniform(low=0, high=1, size=(n, m))\n",
    "prob_vec = raw_vec / raw_vec.sum(axis=1, keepdims=True)\n",
    "init_policy = prob_vec.flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Policy iteration function\n",
    "'''\n",
    "def policy_iter(q_vals,n,m):\n",
    "    new_policy = np.zeros(n*m)\n",
    "    for i in range(n):\n",
    "        idx = np.argmax(q_vals[i*m:(i+1)*m])\n",
    "        new_policy[i*m + idx] = 1\n",
    "    \n",
    "    return new_policy       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting policy [0.022861 0.022257 0.496791 0.417151 0.040940 0.260055 0.020741 0.239421\n",
      " 0.113255 0.366528 0.077009 0.252897 0.293405 0.147306 0.229383 0.234235\n",
      " 0.114754 0.214643 0.266094 0.170273 0.019934 0.289380 0.444991 0.163370\n",
      " 0.082325 0.184145 0.550146 0.115828 0.008301 0.141579 0.193071 0.128147\n",
      " 0.056816 0.408135 0.213831 0.164747 0.271480 0.219975 0.227132 0.116667\n",
      " 0.170695 0.097732 0.193326 0.297322 0.240925 0.204368 0.076399 0.025505\n",
      " 0.397828 0.295900 0.232097 0.206203 0.026091 0.237565 0.298044 0.098219\n",
      " 0.185154 0.083181 0.392678 0.240769 0.392924 0.090288 0.311557 0.079063\n",
      " 0.126168 0.106657 0.096738 0.356238 0.324030 0.116337 0.184156 0.268226\n",
      " 0.030368 0.122968 0.394283 0.506945 0.038054 0.101757 0.114022 0.239222\n",
      " 0.075823 0.464715 0.061802 0.241150 0.156510 0.268392 0.088919 0.131564\n",
      " 0.277810 0.233316 0.377841 0.216636 0.191506 0.064988 0.149028 0.467008\n",
      " 0.035160 0.104154 0.047115 0.346563]\n",
      "Final policy [0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 1.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000 0.000000\n",
      " 0.000000 1.000000 0.000000 0.000000 1.000000 0.000000 0.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 0.000000 1.000000 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000 0.000000 1.000000 1.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000 0.000000\n",
      " 0.000000 0.000000 1.000000 0.000000 0.000000 1.000000 0.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000]\n"
     ]
    }
   ],
   "source": [
    "# Policy iteration on the utility signal alone: alternate exact policy\n",
    "# evaluation with greedy improvement until the policy stops changing.\n",
    "# curr_policy starts as a random placeholder so that the first comparison\n",
    "# fails; the draw also advances the global RNG, so it must stay in place to\n",
    "# keep later random numbers unchanged.\n",
    "curr_policy = np.random.uniform(0,1,size=(n*m))\n",
    "new_policy = init_policy\n",
    "print('Starting policy',init_policy)\n",
    "\n",
    "while np.any(curr_policy - new_policy != 0):\n",
    "    curr_policy = new_policy\n",
    "    # policy evaluation: q = (I - gamma * P * Pi)^{-1} utility\n",
    "    Pi = get_Pi(curr_policy, n, m)\n",
    "    mat = np.identity(n*m) - gamma*np.matmul(prob_transition, Pi)\n",
    "    q_vals_utility = np.dot(np.linalg.inv(mat), utility)\n",
    "    # policy improvement: act greedily with respect to q\n",
    "    new_policy = policy_iter(q_vals_utility, n, m)\n",
    "\n",
    "print('Final policy',new_policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.556458336351498\n"
     ]
    }
   ],
   "source": [
    "# Utility value, under the start distribution rho, of the policy found by\n",
    "# policy iteration above.\n",
    "ell_utility_star = ell(q_vals_utility, new_policy, rho)\n",
    "print(ell_utility_star)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute the optimal reward value from LP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal reward value: 8.163862517858446\n",
      "Optimal policy: [0.000000 1.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000 0.000000\n",
      " 0.999999 0.000001 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000\n",
      " 0.000000 1.000000 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000001\n",
      " 0.000000 0.999999 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000\n",
      " 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.350830 0.000000\n",
      " 0.649170 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000\n",
      " 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000 1.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 0.000000\n",
      " 0.000000 0.000000 0.000000 1.000000]\n"
     ]
    }
   ],
   "source": [
    "## linear programming solver\n",
    "#\n",
    "# Unknown x: discounted state-action occupancy measure of length n*m.\n",
    "#\n",
    "#   maximize    reward @ x\n",
    "#   subject to  utility @ x >= 0\n",
    "#               sum_a x(s,a) - gamma * sum_{s',a'} P(s | s',a') x(s',a') = rho(s)\n",
    "#               x >= 0\n",
    "#\n",
    "# linprog minimizes c @ x such that A_ub @ x <= b_ub, A_eq @ x == b_eq and\n",
    "# lb <= x <= ub, hence the sign flips on reward and utility below.\n",
    "\n",
    "c = -reward\n",
    "A_ub = -utility.reshape(1, n*m)\n",
    "b_ub = np.zeros(1)\n",
    "\n",
    "prob_transition_lp = np.transpose(prob_transition)\n",
    "\n",
    "# E_sum[s, (s,a)] = 1: adds up the occupancy of state s over its actions\n",
    "E_sum = np.zeros_like(prob_transition_lp)\n",
    "for i in range(n):\n",
    "    E_sum[i, m*i:m*(i+1)] = 1\n",
    "\n",
    "A_eq = E_sum - gamma*prob_transition_lp\n",
    "b_eq = rho\n",
    "\n",
    "# the equality constraints force sum(x) = 1/(1-gamma), so no single entry\n",
    "# can exceed that value\n",
    "lb = np.zeros(n*m)\n",
    "ub = np.ones(n*m)/(1-gamma)\n",
    "bounds = np.transpose([lb, ub])\n",
    "\n",
    "res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds)\n",
    "\n",
    "# fail loudly here instead of much later on an unusable res.x\n",
    "if not res.success:\n",
    "    raise RuntimeError('LP solver failed: ' + str(res.message))\n",
    "\n",
    "optimal_reward_value_LP = -res.fun\n",
    "# normalising the occupancy measure per state recovers the policy\n",
    "optimal_policy_LP = theta_to_policy_naive(res.x, n, m)\n",
    "\n",
    "print('Optimal reward value:',optimal_reward_value_LP)\n",
    "\n",
    "print('Optimal policy:',optimal_policy_LP)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularized NPG primal dual method "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Reg_NPG_primal_dual(n, m, prob_transition, gamma, reward, utility, stepsize, tau, total_iterates,optimal_policy_LP, init_dist=None):\n",
    "    '''\n",
    "    Policy search via the regularized NPG primal-dual method.\n",
    "\n",
    "    Input:  n, m              - number of states and number of actions\n",
    "            prob_transition   - transition matrix P((s,a) -> s'), shape (n*m, n)\n",
    "            gamma             - discount factor\n",
    "            reward, utility   - reward and utility vectors of length n*m\n",
    "            stepsize          - step size shared by the primal and dual updates\n",
    "            tau               - regularization weight (entropy term in the\n",
    "                                primal step, shrinkage in the dual step)\n",
    "            total_iterates    - number of iterations\n",
    "            optimal_policy_LP - reference policy of length n*m for the\n",
    "                                distance metric\n",
    "            init_dist         - start state distribution of length n; when\n",
    "                                omitted the notebook-level rho is used\n",
    "    Output: three lists with one entry per iteration: reward value, utility\n",
    "            value, and squared Euclidean distance between the current\n",
    "            policy and optimal_policy_LP\n",
    "\n",
    "    theta is initialised from the global NumPy RNG, so the result depends on\n",
    "    the RNG state at call time.\n",
    "    '''\n",
    "    if init_dist is None:\n",
    "        # historical behaviour: fall back to the notebook-level rho\n",
    "        init_dist = rho\n",
    "\n",
    "    theta = np.random.uniform(0,1,size=n*m)\n",
    "    dual = 0\n",
    "    reward_value = []\n",
    "    utility_value = []\n",
    "    policy_distance = []\n",
    "    identity = np.identity(n*m)\n",
    "\n",
    "    for k in range(total_iterates):\n",
    "        prob = theta_to_policy(theta,n,m)\n",
    "\n",
    "        # Q-values solve (I - gamma*P*Pi) q = signal; invert once and reuse\n",
    "        # the inverse for the reward, utility and entropy signals\n",
    "        Pi = get_Pi(prob,n,m)\n",
    "        mat_inv = np.linalg.inv(identity - gamma*np.matmul(prob_transition,Pi))\n",
    "        qvals_reward = np.dot(mat_inv,reward)\n",
    "        qvals_utility = np.dot(mat_inv,utility)\n",
    "        qvals_entropy = np.dot(mat_inv,-np.log(prob))\n",
    "\n",
    "        # record the values of the current policy (every iteration)\n",
    "        avg_reward = ell(qvals_reward,prob,init_dist)\n",
    "        avg_utility = ell(qvals_utility,prob,init_dist)\n",
    "        reward_value.append(avg_reward)\n",
    "        utility_value.append(avg_utility)\n",
    "\n",
    "        policy_gap = prob - optimal_policy_LP\n",
    "        policy_distance.append(np.inner(policy_gap, policy_gap))\n",
    "\n",
    "        # natural gradient ascent on the regularized Lagrangian\n",
    "        qvals = qvals_reward + dual*qvals_utility + tau*qvals_entropy\n",
    "        theta += stepsize*avals(qvals,prob)\n",
    "\n",
    "        # dual descent: the utility value is the gradient with respect to the\n",
    "        # multiplier, (1 - stepsize*tau) shrinks it, proj clips it to [0, C]\n",
    "        dual = (1-stepsize*tau)*dual - stepsize*avg_utility\n",
    "        dual = proj(dual,gamma)\n",
    "\n",
    "    return reward_value, utility_value, policy_distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'policy_distance' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-96d1f0978c40>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mtau\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mreward_value_reg_3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mutility_value_reg_3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_distance_reg_3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mReg_NPG_primal_dual\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprob_transition\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mutility\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstepsize\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtau\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal_iterates\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimal_policy_LP\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0mtau\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.05\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-19-2c9b8f39a8b6>\u001b[0m in \u001b[0;36mReg_NPG_primal_dual\u001b[0;34m(n, m, prob_transition, gamma, reward, utility, stepsize, tau, total_iterates, optimal_policy_LP)\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m             \u001b[0mpolicy_norm_squared\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprob\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0moptimal_policy_LP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprob\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0moptimal_policy_LP\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m             \u001b[0mpolicy_distance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpolicy_norm_squared\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     53\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mreward_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mutility_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_distance\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'policy_distance' is not defined"
     ]
    }
   ],
   "source": [
    "# call regularized NPG primal dual method, once per regularization weight\n",
    "#\n",
    "# All runs draw their random initial theta from the same global RNG stream,\n",
    "# so the order tau = 0.1, 0.05, 0.01 is part of the experiment.\n",
    "\n",
    "total_iterates = 2000\n",
    "stepsize = 0.1\n",
    "\n",
    "runs = []\n",
    "for tau in (0.1, 0.05, 0.01):\n",
    "    runs.append(Reg_NPG_primal_dual(n, m, prob_transition, gamma, reward, utility, stepsize, tau, total_iterates, optimal_policy_LP))\n",
    "\n",
    "# suffixes: reg_3 <-> tau = 0.1, reg_2 <-> tau = 0.05, reg_1 <-> tau = 0.01\n",
    "reward_value_reg_3, utility_value_reg_3, policy_distance_reg_3 = runs[0]\n",
    "reward_value_reg_2, utility_value_reg_2, policy_distance_reg_2 = runs[1]\n",
    "reward_value_reg_1, utility_value_reg_1, policy_distance_reg_1 = runs[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot reward values: gap to the LP optimum per iteration, one curve per tau\n",
    "\n",
    "num_every = 20\n",
    "num_grads = np.arange(0, total_iterates, num_every)\n",
    "\n",
    "reward_value_reg_3 = np.array(reward_value_reg_3)\n",
    "reward_value_reg_2 = np.array(reward_value_reg_2)\n",
    "reward_value_reg_1 = np.array(reward_value_reg_1)\n",
    "\n",
    "# absolute difference to the optimal reward value from LP\n",
    "# (res.fun is the minimum of -reward @ x, i.e. minus the optimal value)\n",
    "error_reward_value_reg_3 = np.absolute(reward_value_reg_3 + res.fun)\n",
    "error_reward_value_reg_2 = np.absolute(reward_value_reg_2 + res.fun)\n",
    "error_reward_value_reg_1 = np.absolute(reward_value_reg_1 + res.fun)\n",
    "\n",
    "# create the figure explicitly: a bare plt.figure (without the call) is\n",
    "# only a reference to the function and creates nothing\n",
    "f_value, ax = plt.subplots()\n",
    "\n",
    "# convert y-axis to Logarithmic scale\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.plot(num_grads, error_reward_value_reg_3[::num_every], 'k-', linewidth=2, label='tau = 0.1')\n",
    "ax.plot(num_grads, error_reward_value_reg_2[::num_every], 'b--', linewidth=2, label='tau = 0.05')\n",
    "ax.plot(num_grads, error_reward_value_reg_1[::num_every], 'r:', linewidth=2, label='tau = 0.01')\n",
    "ax.set_xlabel('iteration')\n",
    "ax.set_ylabel('|reward value - LP optimum|')\n",
    "ax.legend()\n",
    "ax.grid(axis='y', color='0.85')\n",
    "f_value.savefig('NPG_primal_dual_sensitivity_reward_tau_rate.png',dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot utility values: distance to the constraint boundary per iteration,\n",
    "# one curve per tau\n",
    "\n",
    "num_every = 20\n",
    "num_grads = np.arange(0, total_iterates, num_every)\n",
    "\n",
    "utility_value_reg_3 = np.array(utility_value_reg_3)\n",
    "utility_value_reg_2 = np.array(utility_value_reg_2)\n",
    "utility_value_reg_1 = np.array(utility_value_reg_1)\n",
    "\n",
    "# absolute difference to the optimal utility value 0\n",
    "error_utility_value_reg_3 = np.absolute(utility_value_reg_3)\n",
    "error_utility_value_reg_2 = np.absolute(utility_value_reg_2)\n",
    "error_utility_value_reg_1 = np.absolute(utility_value_reg_1)\n",
    "\n",
    "# create the figure explicitly: a bare plt.figure (without the call) is\n",
    "# only a reference to the function and creates nothing\n",
    "g_value, ax = plt.subplots()\n",
    "\n",
    "# convert y-axis to Logarithmic scale\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.plot(num_grads, error_utility_value_reg_3[::num_every], 'k-', linewidth=2, label='tau = 0.1')\n",
    "ax.plot(num_grads, error_utility_value_reg_2[::num_every], 'b--', linewidth=2, label='tau = 0.05')\n",
    "ax.plot(num_grads, error_utility_value_reg_1[::num_every], 'r:', linewidth=2, label='tau = 0.01')\n",
    "ax.set_xlabel('iteration')\n",
    "ax.set_ylabel('|utility value|')\n",
    "ax.legend()\n",
    "ax.grid(axis='y', color='0.85')\n",
    "g_value.savefig('NPG_primal_dual_sensitivity_utility_tau_rate.png',dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot squared norm of policy gap: squared Euclidean distance between the\n",
    "# current policy and the LP policy per iteration, one curve per tau\n",
    "\n",
    "num_every = 20\n",
    "num_grads = np.arange(0, total_iterates, num_every)\n",
    "\n",
    "# case #1 (tau = 0.01)\n",
    "policy_distance_reg_1 = np.array(policy_distance_reg_1)\n",
    "\n",
    "# case #2 (tau = 0.05)\n",
    "policy_distance_reg_2 = np.array(policy_distance_reg_2)\n",
    "\n",
    "# case #3 (tau = 0.1)\n",
    "policy_distance_reg_3 = np.array(policy_distance_reg_3)\n",
    "\n",
    "# create the figure explicitly: a bare plt.figure (without the call) is\n",
    "# only a reference to the function and creates nothing\n",
    "p_distance, ax = plt.subplots()\n",
    "\n",
    "# convert y-axis to Logarithmic scale\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.plot(num_grads, policy_distance_reg_1[::num_every], 'k:', linewidth=2, label='tau = 0.01')\n",
    "ax.plot(num_grads, policy_distance_reg_2[::num_every], 'b-.', linewidth=2, label='tau = 0.05')\n",
    "ax.plot(num_grads, policy_distance_reg_3[::num_every], 'r-', linewidth=2, label='tau = 0.1')\n",
    "ax.set_xlabel('iteration')\n",
    "ax.set_ylabel('squared distance to LP policy')\n",
    "ax.legend()\n",
    "ax.grid(axis='y', color='0.85')\n",
    "p_distance.savefig('NPG_primal_dual_squared_policy_distance_rate.png',dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
