{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mU9x38mK_p4p"
      },
      "source": [
        "# Defining the LDS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WMgNDGqySVOY"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as np\n",
        "import pandas as pd\n",
        "import numpy as onp\n",
        "import numpy.random as random\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "from scipy.linalg import solve_discrete_are as dare\n",
        "from jax import jit, grad\n",
        "from tqdm import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wdifxW3CSbwx",
        "outputId": "8afb8e19-1329-4d73-9e23-56a1e6c93162"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        }
      ],
      "source": [
        "# LDS specification\n",
        "#n, m, A, B = 2, 1, np.array([[1., 1.], [0., 1.]]), np.array([[0.], [1.]])\n",
        "n, m , A, B = 5,5, random.normal(size= (5,5)), random.normal(size= (5,5)) \n",
        "Q, R = np.eye(N = n), np.eye(N = m)\n",
        "x0, T = np.zeros((n, 1)), 1000\n",
        "alg_name = ['LQR/H2Control', 'GPC', 'BPC', 'RBPC']\n",
        "color_code = {'LQR/H2Control': 'blue', \n",
        "              'GPC': 'red', 'BPC': 'purple', 'RBPC': 'green'}\n",
        "\n",
        "quad_cost = lambda x, u: np.sum(x.T @ Q @ x + u.T @ R @ u)\n",
        "\n",
        "\n",
        "\n",
        "# Func: Evaluate a given policy\n",
        "def evaluate_scan(controller, W, cost_fn):\n",
        "    def step(carry, w):\n",
        "      x, ctrl = carry\n",
        "      u = ctrl.act(u)\n",
        "      loss = cost_fn(x, u)\n",
        "      x_new = A @ x + B @ u + w\n",
        "      carry_new = x_new, ctrl\n",
        "      return carry_new, loss\n",
        "    init = x0, controller\n",
        "    _, losses = jax.lax.scan(step, init, np.arange(T))\n",
        "    return losses\n",
        "\n",
        "def evaluate(controller, W, cost_fn):\n",
        "    x, loss = x0, [0. for _ in range(T)]\n",
        "    for t in range(T):\n",
        "        u = controller.act(x)\n",
        "        loss[t] = cost_fn(x, u)\n",
        "        x = A @ x + B @ u + W[t]\n",
        "        \n",
        "    return np.array(loss, dtype=np.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pAPXdYEg_0MU"
      },
      "source": [
        "# No Control, LQR, H-inf, GPC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YZi3HNM6TjGW"
      },
      "outputs": [],
      "source": [
        "# Run zero control\n",
        "class ZeroControl:\n",
        "    def __init__(self):\n",
        "        pass\n",
        "    def act(self,x):\n",
        "        return np.zeros((m, 1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mmIsMa7aS9wW"
      },
      "outputs": [],
      "source": [
        "# Solve H2 Control\n",
        "class H2Control:\n",
        "    def __init__(self, A, B, Q, R):\n",
        "        P = dare(A, B, Q, R)\n",
        "        self.K = np.linalg.inv(R + B.T @ P @ B) @ (B.T @ P @ A)\n",
        "    def act(self, x):\n",
        "        return -self.K @ x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6JS8gpSfmv5"
      },
      "outputs": [],
      "source": [
        "# Solve the non-stationary/finite-horizon version for H2 Control\n",
        "class H2ControlNonStat:\n",
        "    def __init__(self, A, B, Q, R, T):\n",
        "        n, m = B.shape\n",
        "        P, self.K, self.t = [np.zeros((n,n)) for _ in range(T+1)], [np.zeros((m, n)) for _ in range(T)], 0\n",
        "        P[T] = Q\n",
        "        for t in range(T-1, -1, -1):\n",
        "            P[t] = Q + A.T @ P[t+1] @ A - A.T @ P[t+1] @ B @ np.linalg.inv(R + B.T @ P[t+1] @ B) @ B.T @ P[t+1] @ A\n",
        "            self.K[t] = np.linalg.inv(R + B.T @ P[t] @ B) @ B.T @ P[t] @ A\n",
        "    def act(self, x):\n",
        "        u = -self.K[self.t] @ x\n",
        "        self.t += 1\n",
        "        return u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "juowNe1DDfQo"
      },
      "outputs": [],
      "source": [
        "# Solve H2 Control for Random Walk\n",
        "class ExtendedH2Control:\n",
        "    def __init__(self, A, B, Q, R, T):\n",
        "        Aprime = onp.block([[A, np.eye(n)], [np.zeros((n,n)), np.eye(n)]])\n",
        "        Bprime = onp.block([[B], [np.zeros((n,m))]])\n",
        "        Qprime = onp.block([[Q, np.zeros((n,n))], [np.zeros((n,n)), np.zeros((n,n))]])\n",
        "        Rprime = R\n",
        "        self.A, self.B = A, B\n",
        "        self.H2 = H2ControlNonStat(Aprime, Bprime, Qprime, Rprime, T)\n",
        "        self.x, self.u = np.zeros((n,1)), np.zeros((m,1))\n",
        "    def act(self, x):\n",
        "        W = x - self.A @ self.x - self.B @ self.u\n",
        "        self.x = x\n",
        "        self.u = self.H2.act(onp.block([[x],[W]]))\n",
        "        return self.u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JiZ2qQCIv6s5"
      },
      "outputs": [],
      "source": [
        "# Solve Hinf Control\n",
        "class HinfControl:\n",
        "    def __init__(self, A, B, Q, R, T, gamma):\n",
        "        P, self.K, self.W, self.t = [np.zeros((n, n)) for _ in range(T+1)], [np.zeros((m, n)) for _ in range(T)], [np.zeros((n,n)) for _ in range(T)], 0\n",
        "        P[T] = Q\n",
        "        for t in range(T-1, -1, -1):\n",
        "            P[t] = Q + A.T @ np.linalg.inv(np.linalg.inv(P[t+1]) + B @ np.linalg.inv(R) @ B.T - gamma**2 * np.eye(n)) @ A\n",
        "            Lambda = np.eye(n) + (B @ np.linalg.inv(R) @ B.T - gamma**2 * np.eye(n)) @ P[t+1]\n",
        "            self.K[t] = np.linalg.inv(R) @ B.T @ P[t+1] @ np.linalg.inv(Lambda) @ A\n",
        "            self.W[t] = (gamma**2)*P[t+1] @ np.linalg.inv(Lambda) @ A\n",
        "    def act(self, x):\n",
        "        u = self.K[self.t] @ x\n",
        "        self.t += 1\n",
        "        return u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i4u3Lr8XumOF"
      },
      "outputs": [],
      "source": [
        "# GPC definition\n",
        "class GPC:\n",
        "    def __init__(self, A, B, Q, R, x0, M, H, lr, cost_fn):\n",
        "        n, m = B.shape\n",
        "        self.lr, self.A, self.B, self.M = lr, A, B, M\n",
        "        self.x, self.u, self.off, self.t = x0, np.zeros((m, 1)), np.zeros((m, 1)), 1\n",
        "        self.K, self.E, self.W = H2Control(A, B, Q, R).K, np.zeros((M, m, n)), np.zeros((H+M, n, 1))\n",
        "\n",
        "        def counterfact_loss(E, W):\n",
        "            y = np.zeros((n, 1))\n",
        "            for h in range(H-1):\n",
        "                v = -self.K @ y + np.tensordot(E, W[h : h + M], axes = ([0, 2], [0, 1]))\n",
        "                y = A @ y + B @ v + W[h + M]\n",
        "            v = -self.K @ y + np.tensordot(E, W[h : h + M], axes = ([0, 2], [0, 1]))\n",
        "            cost = cost_fn(y, v)\n",
        "            return cost\n",
        "\n",
        "        self.grad = jit(grad(counterfact_loss))\n",
        "\n",
        "    def act(self, x):\n",
        "        # 1. Get new noise\n",
        "        self.W = self.W.at[0].set( x - self.A @ self.x - self.B @ self.u)\n",
        "        self.W = np.roll(self.W, -1, axis = 0)\n",
        "\n",
        "        # 2. Get gradients\n",
        "        delta_E = self.grad(self.E, self.W)\n",
        "\n",
        "        # 3. Execute updates\n",
        "        self.E -= self.lr * delta_E\n",
        "        #self.off -= self.lr * delta_off\n",
        "\n",
        "        # 4. Update x & t and get action\n",
        "        self.x, self.t = x, self.t + 1\n",
        "        self.u = -self.K @ x + np.tensordot(self.E, self.W[-self.M:], axes = ([0, 2], [0, 1])) #+ self.off\n",
        "        return self.u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nf7EhpBgTAwy"
      },
      "outputs": [],
      "source": [
        "# BPC definition\n",
        "class BPC:\n",
        "    def __init__(self, A, B, Q, R, x0, M, H, lr, delta, cost_fn):\n",
        "        n, m = B.shape\n",
        "        self.n, self.m = n, m\n",
        "        self.lr, self.A, self.B, self.M = lr, A, B, M\n",
        "        self.x, self.u, self.delta, self.t = x0, np.zeros((m, 1)), delta, 0\n",
        "        self.K, self.E, self.W = H2Control(A, B, Q, R).K, np.zeros((M, m, n)), np.zeros((M, n, 1))\n",
        "        self.cost_fn = cost_fn\n",
        "\n",
        "        def _generate_uniform(shape, norm=1.00):\n",
        "            v = random.normal(size=shape)\n",
        "            v = norm * v / np.linalg.norm(v)\n",
        "            return v\n",
        "        self._generate_uniform = _generate_uniform\n",
        "        self.eps = self._generate_uniform((M, M, m, n))\n",
        "\n",
        "    def act(self, x):\n",
        "        # 1. Get new noise\n",
        "        self.W = self.W.at[0].set(x - self.A @ self.x - self.B @ self.u)\n",
        "        self.W = np.roll(self.W, -1, axis = 0)\n",
        "        \n",
        "        # 2. Get gradient estimates\n",
        "        delta_E = self.cost_fn(self.x, self.u) * np.sum(self.eps, axis = 0)\n",
        "\n",
        "        # 3. Execute updates\n",
        "        self.E -= self.lr * delta_E\n",
        "\n",
        "        # 3. Ensure norm is good\n",
        "        norm = np.linalg.norm(self.E)\n",
        "        if norm > 1:\n",
        "           self.E *= 1 / norm\n",
        "            \n",
        "        # 4. Get new eps (after parameter update (4) or ...?)\n",
        "        self.eps = self.eps.at[0].set(self._generate_uniform(\n",
        "                    shape = (self.M, self.m, self.n), norm = np.sqrt(1 - np.linalg.norm(self.eps[1:])**2)))\n",
        "        self.eps = np.roll(self.eps, -1, axis = 0)\n",
        "\n",
        "        # 5. Update x & t and get action\n",
        "        self.x, self.t = x, self.t + 1\n",
        "        self.u = -self.K @ x + np.tensordot(self.E + self.delta * self.eps[-1], \\\n",
        "                            self.W[-self.M:], axes = ([0, 2], [0, 1])) \n",
        "              \n",
        "        return self.u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MVlkHTJOvoZP"
      },
      "outputs": [],
      "source": [
        "# RBPC definition\n",
        "class RBPC:\n",
        "    def __init__(self, A, B, Q, R, x0, M, H, lr, delta, cost_fn, noise_sd):\n",
        "        n, m = B.shape\n",
        "        self.n, self.m = n, m\n",
        "        self.lr, self.A, self.B, self.M, self.H, self.noise_sd = lr, A, B, M, H, noise_sd\n",
        "        self.x, self.u, self.delta, self.t = x0, np.zeros((m, 1)), delta, 0\n",
        "        self.K, self.E, self.W = H2Control(A, B, Q, R).K, np.zeros((M, n, m)), np.zeros((M + H -1, n))\n",
        "        self.cost_fn = cost_fn\n",
        "\n",
        "        def _generate_uniform(shape, norm=1.00):\n",
        "            v = random.normal(size=shape)\n",
        "            v = norm * v / (np.linalg.norm(v)+ 1e-6)\n",
        "            return v\n",
        "        self._generate_uniform = _generate_uniform\n",
        "        self.eps = self._generate_uniform((M, m))\n",
        "\n",
        "    def Egrad(self):\n",
        "        gE= np.zeros((self.M, self.n, self.m))\n",
        "        for i in range(self.H):\n",
        "          gE += np.einsum(\"ij, k->ijk\",self.W[i:self.M+i,:], self.eps[i])\n",
        "        return gE * self.cost_fn(self.x, self.u)\n",
        "\n",
        "    def act(self, x):\n",
        "        # 1. Get gradient estimates\n",
        "        delta_E = self.Egrad()\n",
        "\n",
        "        # 2. Execute updates\n",
        "        self.E -= self.lr * delta_E\n",
        "\n",
        "        # 3. Ensure norm is good\n",
        "        norm = np.linalg.norm(self.E)\n",
        "        if norm > 1:\n",
        "           self.E *= 1/ norm\n",
        "\n",
        "        # 4. Get new noise\n",
        "        w = x - self.A @ self.x - self.B @ self.u\n",
        "        w = w.reshape(self.n)\n",
        "        self.W = self.W.at[0].set(w)\n",
        "        self.W = np.roll(self.W, -1, axis = 0)\n",
        "            \n",
        "        # 5. Get new eps (after parameter update (4) or ...?)\n",
        "        self.eps = self.eps.at[0].set(self._generate_uniform(\n",
        "                    shape = (self.m), norm = self.noise_sd))\n",
        "        self.eps = np.roll(self.eps, -1, axis = 0)\n",
        "\n",
        "        # 5. Update x & t and get action\n",
        "        self.x, self.t = x, self.t + 1\n",
        "        u = -self.K @ x + np.tensordot(self.E , self.W[-self.M:], axes = ([0, 1], [0, 1])).reshape((m,1)) + self.eps[-1].reshape((m,1))\n",
        "        self.u = u.reshape((self.m, 1))\n",
        "              \n",
        "        return self.u\n",
        "\n",
        "    "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Q_22GMcAD8A"
      },
      "source": [
        "# Plot & repeat utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YY5Z2uLmzMog"
      },
      "outputs": [],
      "source": [
        "def benchmark(M, W, cost_fn = quad_cost, lr = 0.001, delta = 0.001, no_control = False, gamma = None, grw = False, rbpc_lr=1e-3, rbpc_noise_sd=1.0):\n",
        "    global A, B, Q, R, T\n",
        "    #loss_zero = evaluate(ZeroControl(), W, cost_fn) if no_control else onp.full(T, np.nan, dtype=float)\n",
        "    loss_h2 = evaluate(H2Control(A, B, Q, R), W, cost_fn)\n",
        "    #loss_hinf = evaluate(HinfControl(A, B, Q, R, T, gamma), W, cost_fn) if gamma else onp.full(T, np.nan, dtype=np.float32)\n",
        "    #loss_ogrw = evaluate(ExtendedH2Control(A, B, Q, R, T), W, cost_fn) if grw else onp.full(T, np.nan, dtype=np.float32)\n",
        "\n",
        "    H, M = 5, M\n",
        "    loss_gpc = evaluate(GPC(A, B, Q, R, x0, M, H, lr, cost_fn), W, cost_fn)\n",
        "    loss_bpc = evaluate(BPC(A, B, Q, R, x0, M, H, lr, delta, cost_fn), W, cost_fn)\n",
        "    loss_rbpc = evaluate(RBPC(A, B, Q, R, x0, M, H, rbpc_lr, delta, cost_fn, rbpc_noise_sd), W, cost_fn)\n",
        "    print(f\"loss_h2 ; {np.mean(loss_h2)}, loss_gpc: {np.mean(loss_gpc)}, loss_bpc: {np.mean(loss_bpc)}, loss_rbpc: {np.mean(loss_rbpc)}\")\n",
        "    return loss_h2, loss_gpc, loss_bpc, loss_rbpc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qV3hclQOxKoj"
      },
      "outputs": [],
      "source": [
        "cummean = lambda x: np.cumsum(x)/(np.arange(T)+1)\n",
        "\n",
        "def to_dataframe(alg, loss, avg_loss):\n",
        "    global T\n",
        "    return pd.DataFrame(data = {'Algorithm': alg, 'Time': np.arange(T, dtype=np.float32),\n",
        "                                'Instantaneous Cost': loss, 'Average Cost': avg_loss})\n",
        "\n",
        "def repeat_benchmark(M, Wgen, rep, cost_fn = quad_cost, lr = 0.001, \n",
        "                     delta = 0.001, no_control = False, gamma = None, grw = False, rbpc_lr=1e-3, rbpc_noise_sd=1.0):\n",
        "    all_data = []\n",
        "    for r in range(rep):\n",
        "        loss = benchmark(M, Wgen(), cost_fn, lr, delta, no_control, gamma, grw, rbpc_lr, rbpc_noise_sd)\n",
        "        avg_loss = list(map(cummean, loss))\n",
        "        data = pd.concat(list(map(lambda x: to_dataframe(*x), list(zip(alg_name, loss, avg_loss)))))\n",
        "        all_data.append(data)\n",
        "    all_data = pd.concat(all_data)\n",
        "    return all_data[all_data['Instantaneous Cost'].notnull()]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZV8YReDN0cZ8"
      },
      "outputs": [],
      "source": [
        "def plot(title, data, scale = 'linear'):\n",
        "    fig, axs = plt.subplots(ncols=1, figsize=(15,4))\n",
        "    axs.set_yscale(scale)\n",
        "    sns.lineplot(x = 'Time', y = 'Average Cost', hue = 'Algorithm', \n",
        "                 data = data, ax = axs, ci = 'sd', palette = color_code).set_title(title)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tdZe37z8AHcn"
      },
      "source": [
        "# Experiments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CBt5gfSWryrB"
      },
      "outputs": [],
      "source": [
        "# Sine perturbations\n",
        "Wgen = lambda: (np.sin(np.arange(T*m)/(20*np.pi)).reshape(T,m) @ np.ones((m, n))).reshape(T, n, 1)\n",
        "quad_cost = lambda x, u: np.sum(x.T @ Q @ x + u.T @ R @ u)\n",
        "\n",
        "# Time steps & Number of seeds/repetitions to test each method on!\n",
        "T = 2000\n",
        "rep = 25\n",
        "for M in [5]:\n",
        "    for lr in [5e-4]:\n",
        "        for sd in [0.5]:\n",
        "            print(\"running M = {}, lr = {}, delta = {}\".format(M, lr, sd))\n",
        "            data = repeat_benchmark(M, Wgen, rep=rep, cost_fn=quad_cost, lr = 0.007, delta = 0.05, rbpc_lr=lr, rbpc_noise_sd=sd)\n",
        "            plot('Sinusoidal Perturbations', data)\n",
        "            specs = str(T) + \"_\" + str(M) + \"_\" + str(lr) + \"_\" + str(sd)\n",
        "            plt.savefig(\"sin_quad_rbpc\" + specs + \".pdf\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4c4TS8KUnV20"
      },
      "outputs": [],
      "source": [
        "\n",
        "# geometric random walk\n",
        "def Wgen():\n",
        "    W = onp.zeros(shape = (T, n, 1))\n",
        "    W[0] = random.normal(size=(n, 1))\n",
        "    for i in range(1, T):\n",
        "        W[i] = random.normal(size=(n, 1)) + 0.7 *W[i-1]\n",
        "    return W\n",
        "rep = 25   \n",
        "T = 1000\n",
        "for M in [3]:\n",
        "  for lr in [1e-3]:\n",
        "    for sd in [0.3]: # gaussian random walk requires smaller deltas\n",
        "        data = repeat_benchmark(M, Wgen, rep, lr = 0.001, delta = 0.05, rbpc_lr=lr, rbpc_noise_sd=sd)\n",
        "        plot('Gaussian Random Walk Perturbations', data)\n",
        "        specs = str(T) + \"_\" + str(M) + \"_\" + str(lr) + \"_\" + str(sd)\n",
        "        plt.savefig(\"random_walk_quad_\" + specs + \".pdf\") \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OGLgD7przV67"
      },
      "outputs": [],
      "source": [
        "# Defining non-quadratic hinge loss with sine noise\n",
        "Wgen = lambda: (np.sin(np.arange(T*m)/(2*np.pi)).reshape(T,m) @ np.ones((m, n))).reshape(T, n, 1)\n",
        "hinge_loss = lambda x, u: np.sum(np.abs(x)) + np.sum(np.abs(u))\n",
        "\n",
        "T = 1000\n",
        "rep = 25\n",
        "for M in [3, 6, 10]:\n",
        "    for lr in [0.007, 0.003, 0.001]:\n",
        "        for delta in [0.5, 0.3, 0.1, 0.05, 0.01]:\n",
        "            data = repeat_benchmark(M, Wgen, rep=rep, cost_fn=hinge_loss, lr = lr, delta = delta)\n",
        "            plot('Sinusoidal Perturbations - Hinge Loss', data)\n",
        "            specs = str(T) + \"_\" + str(M) + \"_\" + str(lr) + \"_\" + str(delta)\n",
        "            plt.savefig(\"sin_hinge_\" + specs + \".pdf\") "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DexFe05I-cTP"
      },
      "outputs": [],
      "source": [
        "\n",
        "# constant\n",
        "def Wgen():\n",
        "    W = onp.zeros(shape = (T, n, 1))\n",
        "    z = np.array([1.0, 1.0]).reshape(2,1)\n",
        "    for i in range(0, T):\n",
        "        W[i] = z\n",
        "    return W\n",
        "rep = 25   \n",
        "T = 2000\n",
        "for M in [3]:\n",
        "  for lr in [1e-4]:\n",
        "    for sd in [1.0]: \n",
        "        data = repeat_benchmark(M, Wgen, rep, lr = 1e-4, delta = 1.0, rbpc_lr=lr, rbpc_noise_sd=sd)\n",
        "        plot('Constant Perturbations', data)\n",
        "        specs = str(T) + \"_\" + str(M) + \"_\" + str(lr) + \"_\" + str(sd)\n",
        "        plt.savefig(\"constant_2\" + specs + \".pdf\") "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XHOvdTzHmOZ2"
      },
      "outputs": [],
      "source": [
        "def f(x):\n",
        "  return min(x, 50)\n",
        "\n",
        "data['f'] = \n",
        "plot('Constant Perturbations', data[data['Average Cost'] < 50])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "jBKCA1Xamgfn"
      },
      "outputs": [],
      "source": [
        "# Sine perturbations\n",
        "Wgen = lambda: (np.sin(np.arange(T*m)/(20*np.pi)).reshape(T,m) @ np.ones((m, n))).reshape(T, n, 1)\n",
        "quad_cost = lambda x, u: np.sum(x.T @ Q @ x + u.T @ R @ u)\n",
        "\n",
        "# Time steps & Number of seeds/repetitions to test each method on!\n",
        "T = 10000\n",
        "rep = 3\n",
        "for M in [5]:\n",
        "    for lr in [5e-8]:\n",
        "        for sd in [1.0]:\n",
        "            print(\"running M = {}, lr = {}, delta = {}\".format(M, lr, sd))\n",
        "            data = repeat_benchmark(M, Wgen, rep=rep, cost_fn=quad_cost, lr = 1e-6, delta = sd, rbpc_lr=lr, rbpc_noise_sd=sd)\n",
        "            plot('Sinusoidal Perturbations', data)\n",
        "            specs = str(T) + \"_\" + str(M) + \"_\" + str(lr) + \"_\" + str(sd)\n",
        "            plt.savefig(\"sin_quad_rbpc\" + specs + \".pdf\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-3Z99tKapNpi"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}