{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4515e863",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.linalg import multi_dot\n",
    "from numpy import linalg as LA\n",
    "from six.moves import xrange\n",
    "from scipy import linalg as la\n",
    "import scipy.stats as st\n",
    "import random\n",
    "from numpy.linalg import multi_dot\n",
    "import math\n",
    "from math import pow\n",
    "import seaborn as sns\n",
    "import pandas as pd "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b9b1491",
   "metadata": {},
   "source": [
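    "# Zeroth-order policy gradient for LQR\n",
    "\n",
    "`DL_LQR_Solver` estimates the LQR policy gradient only from the costs of rollouts under randomly perturbed gains $K + U$, and runs natural policy gradient while tracking $\\|K - K^*\\|_F$ against the Riccati-optimal gain $K^*$ (`Optimal_K`). A model-based gradient (`ModelBased_PolicyGradient`) is included for comparison."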
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45cd89d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DL_LQR_Solver:\n",
    "    \"\"\"Policy-gradient solvers for discrete-time LQR with the linear policy u = -K x.\n",
    "\n",
    "    Closed loop: x_{t+1} = (A - B K) x_t; per-step cost: x^T (Q + K^T R K) x.\n",
    "    Policy_gradient is zeroth-order (it only uses costs of sampled rollouts),\n",
    "    ModelBased_PolicyGradient uses A and B directly, and Optimal_K returns the\n",
    "    Riccati-optimal reference gain.\n",
    "\n",
    "    Args:\n",
    "        A: (d, d) state matrix.\n",
    "        B: (d, k) input matrix.\n",
    "        Q: (d, d) state-cost matrix.\n",
    "        R: (k, k) input-cost matrix.\n",
    "        k: number of inputs (rows of K).\n",
    "        d: number of states (columns of K).\n",
    "        m: number of perturbed rollouts per gradient estimate.\n",
    "        l: rollout length (number of states whose cost is summed).\n",
    "        r: Frobenius-norm radius of the random perturbation added to K.\n",
    "        eta: step size.\n",
    "        T: number of iterations run by Policy_gradient.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,A,B,Q,R,k,d,m,l,r,eta,T):\n",
    "        self.A = A\n",
    "        self.B = B\n",
    "        self.Q = Q\n",
    "        self.R = R\n",
    "        self.k = k\n",
    "        self.d = d\n",
    "        self.m = m\n",
    "        self.l = l\n",
    "        self.r = r\n",
    "        self.eta = eta\n",
    "        self.T = T\n",
    "\n",
    "    def _spectral_radius(self, K):\n",
    "        \"\"\"Largest |eigenvalue| of the closed-loop matrix A - B K, over all d eigenvalues.\"\"\"\n",
    "        return np.max(np.abs(np.linalg.eigvals(self.A - np.dot(self.B, K))))\n",
    "\n",
    "    def Choose_K(self, max_tries=100000, rho_max=0.8):\n",
    "        \"\"\"Randomly sample a gain K whose closed loop A - B K is strictly stable.\n",
    "\n",
    "        Entries are 0.4 * (U1 - U2) with U1, U2 ~ Uniform[0, 1), i.e. in (-0.4, 0.4).\n",
    "        The first K with spectral radius below rho_max is printed and returned\n",
    "        (a smaller radius may converge faster).\n",
    "\n",
    "        Args:\n",
    "            max_tries: maximum number of samples to draw.\n",
    "            rho_max: required upper bound on the closed-loop spectral radius.\n",
    "\n",
    "        Returns:\n",
    "            The accepted K. If none is accepted, prints 'choose K wrong' and\n",
    "            returns the last sample, which is NOT guaranteed to be stabilizing.\n",
    "        \"\"\"\n",
    "        K = None\n",
    "        for _ in range(max_tries):\n",
    "            K_11 = 0.8*np.random.rand(self.k,self.d)\n",
    "            K_22 = -0.8*np.random.rand(self.k,self.d)\n",
    "            K = (K_11+K_22)/2\n",
    "            # Use all d eigenvalues so the test is valid for any state dimension.\n",
    "            if self._spectral_radius(K) < rho_max:\n",
    "                print(K)\n",
    "                break\n",
    "        else:\n",
    "            print('choose K wrong')\n",
    "        return K\n",
    "\n",
    "    def Sample_Policy(self,K):\n",
    "        \"\"\"Perturb K by a random direction U scaled to Frobenius norm r.\n",
    "\n",
    "        Returns:\n",
    "            (K + U, U).\n",
    "        \"\"\"\n",
    "        # NOTE(review): entries are (U1 - U2)/2 with U1, U2 ~ Uniform[0, 1), then\n",
    "        # normalised. This is not uniform on the sphere (normalising a Gaussian\n",
    "        # would be); E[vec(U) vec(U)^T] is still isotropic, so the smoothing\n",
    "        # estimator is unbiased to first order, but higher-order bias differs.\n",
    "        U_11 = np.random.rand(self.k,self.d)\n",
    "        U_22 = -np.random.rand(self.k,self.d)\n",
    "        U = (U_11+U_22)/2\n",
    "        U = (U/np.linalg.norm(U))*self.r\n",
    "        return K+U, U\n",
    "\n",
    "    def Reward(self,K,x):\n",
    "        \"\"\"One-step cost x^T (Q + K^T R K) x of state x under gain K.\"\"\"\n",
    "        C = self.Q+np.dot(np.dot(K.T,self.R),K)\n",
    "        return np.dot(np.dot(x,C),x)\n",
    "\n",
    "    def State(self,K,x):\n",
    "        \"\"\"Next state (A - B K) x.\"\"\"\n",
    "        return np.dot(self.A-np.dot(self.B,K),x)\n",
    "\n",
    "    def Trajectory(self,K):\n",
    "        \"\"\"Roll out K for l steps from a random unit-norm initial state.\n",
    "\n",
    "        Returns:\n",
    "            (C, Sigma): the summed cost over x_0..x_{l-1} and the sum of x_t x_t^T.\n",
    "        \"\"\"\n",
    "        x = np.random.multivariate_normal(np.zeros(self.d), np.eye(self.d))\n",
    "        x = x/np.linalg.norm(x)  # x_0 uniform on the unit sphere\n",
    "        C = 0\n",
    "        Sigma = np.zeros((self.d,self.d))\n",
    "        for _ in range(self.l):\n",
    "            C += self.Reward(K,x)\n",
    "            Sigma += np.outer(x, x)\n",
    "            x = self.State(K,x)\n",
    "        return C,Sigma\n",
    "\n",
    "    def Estimator(self,K):\n",
    "        \"\"\"Zeroth-order estimates of grad C(K) and Sigma_K from m perturbed rollouts.\n",
    "\n",
    "        grad_est = (d*k / (m*r^2)) * sum_i C(K + U_i) U_i. Both the cost and the\n",
    "        state covariance come from trajectories of the perturbed gains K + U_i\n",
    "        (K itself is never rolled out).\n",
    "\n",
    "        Returns:\n",
    "            (C_est, Sigma_est) with shapes (k, d) and (d, d).\n",
    "        \"\"\"\n",
    "        C_total = np.zeros((self.k,self.d))\n",
    "        Sigma_total = np.zeros((self.d,self.d))\n",
    "        for _ in range(self.m):\n",
    "            K_1,U_1 = self.Sample_Policy(K)\n",
    "            C_1,Sigma = self.Trajectory(K_1)\n",
    "            C_total += C_1*U_1\n",
    "            Sigma_total += Sigma\n",
    "        C_est = (C_total*self.d*self.k)/(self.m*self.r*self.r)\n",
    "        Sigma_est = Sigma_total/self.m\n",
    "        return C_est,Sigma_est\n",
    "\n",
    "    def ModelBased_PolicyGradient(self, K=None, n_iters=10000, horizon=10):\n",
    "        \"\"\"Model-based policy gradient: K <- K - eta * 2 E_K Sigma_K.\n",
    "\n",
    "        E_K = (R + B^T P_K B) K - B^T P_K A, where P_K solves the closed-loop\n",
    "        Lyapunov equation P = Q + K^T R K + (A - B K)^T P (A - B K), and\n",
    "        Sigma_K is estimated from one rollout x_0, ..., x_{horizon-1} with\n",
    "        x_0 ~ N(0, I) (not normalised, unlike Trajectory).\n",
    "\n",
    "        Args:\n",
    "            K: initial gain; if None (or omitted), one is drawn with Choose_K().\n",
    "            n_iters: number of gradient steps.\n",
    "            horizon: number of states summed into the covariance estimate.\n",
    "\n",
    "        Returns:\n",
    "            The final gain. If an update makes A - B K unstable, prints\n",
    "            'model_based_wrong' and returns that (unstable) gain.\n",
    "        \"\"\"\n",
    "        if K is None:\n",
    "            K = self.Choose_K()\n",
    "        for _ in range(n_iters):\n",
    "            A_cl = self.A-np.dot(self.B,K)\n",
    "            # Solve for P_K exactly (scipy solves X = a X a^T + q; here a = A_cl^T)\n",
    "            # instead of a single fixed-point step that lags behind the changing K.\n",
    "            P_k = la.solve_discrete_lyapunov(A_cl.T, self.Q+multi_dot([K.T,self.R,K]))\n",
    "            x = np.random.multivariate_normal(np.zeros(self.d), np.eye(self.d))\n",
    "            sigmak = np.zeros((self.d,self.d))\n",
    "            for _t in range(horizon):\n",
    "                # Include x_0: Sigma_K sums x_t x_t^T from t = 0, as in Trajectory.\n",
    "                sigmak += np.outer(x, x)\n",
    "                x = self.State(K,x)\n",
    "            E_k = np.dot(self.R+multi_dot([self.B.T,P_k,self.B]),K) - multi_dot([self.B.T,P_k,self.A])\n",
    "            grad_C = 2*np.dot(E_k,sigmak)\n",
    "            K = K - self.eta*grad_C\n",
    "            if self._spectral_radius(K) > 1:\n",
    "                print('model_based_wrong')\n",
    "                break\n",
    "        return K\n",
    "\n",
    "    def Optimal_K(self):\n",
    "        \"\"\"Riccati-optimal gain K* = (R + B^T P B)^{-1} B^T P A, with P from the DARE.\"\"\"\n",
    "        P = la.solve_discrete_are(self.A, self.B, self.Q, self.R)\n",
    "        # Solve the linear system instead of forming the explicit inverse.\n",
    "        return np.linalg.solve(self.R + multi_dot([self.B.T, P, self.B]), multi_dot([self.B.T, P, self.A]))\n",
    "\n",
    "    def Policy_gradient(self, verbose=True):\n",
    "        \"\"\"Zeroth-order natural policy gradient started from a random stabilizing K.\n",
    "\n",
    "        Each iteration estimates grad C(K) and Sigma_K with Estimator() and steps\n",
    "        K <- K - eta * grad C(K) Sigma_K^{-1}. Stops early, printing '-wrong-',\n",
    "        if the proposed gain would make A - B K unstable.\n",
    "\n",
    "        Args:\n",
    "            verbose: print ||K - K*||_F after every accepted step.\n",
    "\n",
    "        Returns:\n",
    "            (K, data): the last accepted gain and a length-T float array of\n",
    "            ||K - K*||_F per iteration. Entries for iterations not run because of\n",
    "            an early stop are NaN, so the array always has length T.\n",
    "        \"\"\"\n",
    "        K_opt = self.Optimal_K()\n",
    "        K_2 = self.Choose_K()\n",
    "        data = np.full(self.T, np.nan)\n",
    "        for i in range(self.T):\n",
    "            C_est_1, Sigma_est_1 = self.Estimator(K_2)\n",
    "            # Natural gradient C_est Sigma^{-1}; Sigma is symmetric, so solve\n",
    "            # Sigma X = C_est^T instead of inverting Sigma.\n",
    "            nat_grad = np.linalg.solve(Sigma_est_1, C_est_1.T).T\n",
    "            K_3 = K_2 - self.eta*nat_grad\n",
    "            if self._spectral_radius(K_3) > 1:\n",
    "                print('-wrong-')\n",
    "                break\n",
    "            K_2 = K_3\n",
    "            data[i] = np.linalg.norm(K_2-K_opt)\n",
    "            if verbose:\n",
    "                print(data[i])\n",
    "        return K_2, data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db941225",
   "metadata": {},
   "source": [
    "# Example 1: two states, two inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66343623",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 1: 2 states, 2 inputs; A and B are the same swap matrix.\n",
    "A = np.array([[0, 1],\n",
    "              [1, 0]])\n",
    "B = A.copy()\n",
    "Q = np.array([[9, 2],\n",
    "              [2, 1]])\n",
    "R = np.array([[1, 2],\n",
    "              [2, 8]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de1b9e82",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_RUNS, T_ITERS = 10, 1000\n",
    "np.random.seed(0)  # make the runs reproducible\n",
    "# NaN (not 0) marks iterations after an early stop, so a failed run cannot\n",
    "# be mistaken for perfect convergence (||K - K*|| = 0).\n",
    "data = np.full((N_RUNS, T_ITERS), np.nan)\n",
    "for i in range(N_RUNS):\n",
    "    DL_LQR = DL_LQR_Solver(A,B,Q,R,k=2,d=2,m=5000,l=20,r=0.1,eta=0.01,T=T_ITERS)\n",
    "    K_next, errors = DL_LQR.Policy_gradient()\n",
    "    # The trace can be shorter than T_ITERS if the run stopped early.\n",
    "    data[i, :len(errors)] = errors\n",
    "np.savetxt('Zeroth-order.txt', data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f680c6e",
   "metadata": {},
   "source": [
    "# Example 2: four states, three inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c67d6fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 2: 4 states, 3 inputs; every entry is written as a float.\n",
    "A = np.array([\n",
    "    [0.2, 0.1, 1.0, 0.0],\n",
    "    [0.2, 0.1, 0.1, 0.0],\n",
    "    [0.0, 0.1, 0.5, 0.0],\n",
    "    [0.0, 0.0, 0.0, 0.5],\n",
    "])\n",
    "B = np.array([\n",
    "    [0.3, 0.0, 0.0],\n",
    "    [0.2, 0.0, 0.3],\n",
    "    [1.0, 1.0, 0.3],\n",
    "    [0.3, 0.1, 0.1],\n",
    "])\n",
    "Q = np.array([\n",
    "    [1.0, 0.0, 0.2, 0.0],\n",
    "    [0.0, 1.0, 0.1, 0.0],\n",
    "    [0.2, 0.1, 1.0, 0.1],\n",
    "    [0.0, 0.0, 0.1, 1.0],\n",
    "])\n",
    "R = np.array([\n",
    "    [1.0, 0.1, 1.0],\n",
    "    [0.1, 1.0, 0.5],\n",
    "    [1.0, 0.5, 2.0],\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b10adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_RUNS, T_ITERS = 10, 1000\n",
    "np.random.seed(0)  # make the runs reproducible\n",
    "# NaN (not 0) marks iterations after an early stop, so a failed run cannot\n",
    "# be mistaken for perfect convergence (||K - K*|| = 0).\n",
    "data = np.full((N_RUNS, T_ITERS), np.nan)\n",
    "for i in range(N_RUNS):\n",
    "    DL_LQR = DL_LQR_Solver(A,B,Q,R,k=3,d=4,m=20000,l=50,r=0.1,eta=0.01,T=T_ITERS)\n",
    "    K_next, errors = DL_LQR.Policy_gradient()\n",
    "    # The trace can be shorter than T_ITERS if the run stopped early.\n",
    "    data[i, :len(errors)] = errors\n",
    "np.savetxt('Zeroth-order2.txt', data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c468ab8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd90332e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
