{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bdb5952",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.linalg import multi_dot\n",
    "from numpy import linalg as LA\n",
    "from six.moves import xrange\n",
    "from scipy import linalg as la\n",
    "import scipy.stats as st\n",
    "import random\n",
    "from numpy.linalg import multi_dot\n",
    "import math\n",
    "from math import pow\n",
    "import seaborn as sns\n",
    "import pandas as pd "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5550b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TTS_LQR_Solver:\n",
    "    def __init__(self,A,B,Q,R,d,k,numstep,alpha,beta,gamma,sigma):\n",
    "        self.A = A\n",
    "        self.B = B\n",
    "        self.Q = Q\n",
    "        self.R = R\n",
    "        self.d = d\n",
    "        self.k = k\n",
    "        #self.thres = thres\n",
    "        self.numstep = numstep\n",
    "        self.alpha = alpha\n",
    "        self.beta = beta\n",
    "        self.gamma = gamma\n",
    "        self.sigma = sigma\n",
    "        \n",
    "    def Choose_K(self):\n",
    "        for j in range(100000):\n",
    "            K_11 = 0.8*np.random.rand(self.k,self.d)\n",
    "            K_22 = -0.8*np.random.rand(self.k,self.d)\n",
    "            K=(K_11+K_22)/2\n",
    "            #K=np.eye(self.d)\n",
    "            O_1=self.A-np.dot(self.B,K)\n",
    "            results1=np.linalg.eig(O_1)\n",
    "            #print(O_1)\n",
    "            #if 0<results[0][0]<1 and 0<results[0][1]<1:\n",
    "            if abs(results1[0][0])<0.8 and abs(results1[0][1])<0.8: #Smallier eigenvalue may convergence fast\n",
    "                #print(K)\n",
    "                break\n",
    "            if j>99998:\n",
    "                print('Choose_K wrong')\n",
    "                break\n",
    "        return K\n",
    "    \n",
    "    def svec(self,X): # I have tested that this is right\n",
    "        n=len(X)\n",
    "        c=int(n*(n+1)/2)\n",
    "        x=np.zeros(c)\n",
    "        for j in range(n):\n",
    "            for i in range(j+1):\n",
    "                l=int((j+1)*j/2+i)\n",
    "                if i!=j:\n",
    "                    x[l]=math.sqrt(2)*X[i][j]\n",
    "                else:\n",
    "                    x[l]=X[i][j]\n",
    "        return x\n",
    "    \n",
    "    def smat(self,x): #I have tested that this is right\n",
    "        k=len(x)\n",
    "        n=int((math.sqrt(1+8*k)-1)/2)\n",
    "        X=np.zeros((n,n))\n",
    "        for j in range(n):\n",
    "            for i in range(j+1):\n",
    "                t=int((j+1)*j/2+i)\n",
    "                if i!=j:\n",
    "                    X[i][j]=x[t]/math.sqrt(2)\n",
    "                else:\n",
    "                    X[i][j]=x[t]\n",
    "        X_T=np.transpose(X)\n",
    "        X=X+X_T\n",
    "        for i in range(n):\n",
    "            X[i][i]=X[i][i]/2\n",
    "        return X\n",
    "    \n",
    "    def block_mat(self,O): #I have tested that this is right\n",
    "        O_21=np.zeros((self.k,self.d))\n",
    "        O_22=np.zeros((self.k,self.k))\n",
    "        O_11=np.zeros((self.d,self.d))\n",
    "        O_12=np.zeros((self.d,self.k))\n",
    "        #for i in range(self.d):\n",
    "            #for j in range(self.d):\n",
    "                #O_11[i][j]=O[i,j]\n",
    "            #for l in range(self.d,self.d+self.k):\n",
    "                #O_12[i][l-self.d]=O[i][l]\n",
    "        for i in range(self.d,self.d+self.k):\n",
    "            for j in range(self.d):\n",
    "                O_21[i-self.d][j]=O[i,j]\n",
    "            for l in range(self.d,self.d+self.k):\n",
    "                O_22[i-self.d][l-self.d]=O[i][l]\n",
    "        return O_22,O_21\n",
    "    \n",
    "    \n",
    "    def Get_phi(self,x,u): # I have tested that this is right\n",
    "        aa=np.zeros(self.d+self.k)\n",
    "        for i in range(self.d):\n",
    "            aa[i]=x[i]\n",
    "        for j in range(self.k):\n",
    "            aa[j+self.d]=u[j]\n",
    "        pp_1=np.dot(aa[:,None],aa[None,:])\n",
    "        return self.svec(pp_1)\n",
    "    \n",
    "    def Project(self,omega):\n",
    "        omega=omega/np.linalg.norm(omega)\n",
    "        return omega\n",
    "        \n",
    "    def Stationary(self,K): #I have tested that this is right\n",
    "        I=np.eye((self.d))\n",
    "        phi=I\n",
    "        phi_si=phi+self.sigma*self.sigma*np.dot(B,B.T)\n",
    "        C_k=phi_si\n",
    "        for i in range(10000):\n",
    "            C_next=phi_si+np.dot(np.dot((A-np.dot(B,K)),C_k),(A-np.dot(B,K)).T)\n",
    "            C_next+=np.transpose(C_next)\n",
    "            C_next*=0.5\n",
    "            if np.abs(C_k - C_next).max() < 1e-10:\n",
    "                #print('convergence')\n",
    "                #print(phi_si+np.dot(np.dot((A-np.dot(B,K)),C_next),(A-np.dot(B,K)).T))\n",
    "                #print(C_next)\n",
    "                break\n",
    "            if i>99998:\n",
    "                print('Not Stabile')\n",
    "                break\n",
    "            C_k=C_next\n",
    "        return C_next\n",
    "    \n",
    "    def Sample_x(self,K):\n",
    "        mean=np.zeros(self.d)\n",
    "        sigma=self.Stationary(K)\n",
    "        x=np.random.multivariate_normal(mean,sigma)\n",
    "        return x\n",
    "    \n",
    "    def Rollout(self,K):\n",
    "        x=self.Sample_x(K)\n",
    "        u=-np.dot(K,x)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "        r=np.dot(np.dot(x,self.Q),x)+np.dot(np.dot(u,self.R),u)\n",
    "        x_next=np.dot(self.A,x)+np.dot(B,u)+np.random.multivariate_normal(np.zeros(self.d),np.eye((self.d)))\n",
    "        u_next=-np.dot(K,x_next)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "        return x,u,r,x_next,u_next\n",
    "    \n",
    "    def Optimal_K(self):\n",
    "        P = la.solve_discrete_are(self.A, self.B, self.Q, self.R)\n",
    "        K_opt = multi_dot([np.linalg.inv(self.R + multi_dot([self.B.T, P, self.B])), self.B.T, P, self.A])\n",
    "        return K_opt\n",
    "    \n",
    "    def J_K(self,K):\n",
    "        P_k=self.solve_P_k(K)\n",
    "        phi_sigma=np.eye(self.d)+self.sigma*self.sigma*(np.dot(self.B,self.B.T))\n",
    "        J=np.dot(P_k,phi_sigma)\n",
    "        return J.trace()+self.sigma*self.sigma*self.R.trace()\n",
    "    \n",
    "    def solve_P_k(self, K): # I have tested that this is right\n",
    "        \"\"\"Solves the Bellman equation by iteration.\n",
    "        Bellman Equation\n",
    "        if rho(A-BK) < 1\n",
    "        ```none\n",
    "        P_{t+1} = Q + K'RK + (A-BK)'P_{t}(A-BK)\n",
    "        ```\n",
    "        Returns:\n",
    "            A numpy array, a postive definite matrix P which is the solution to the Bellman equation.\n",
    "        Raises:\n",
    "            RuntimeError: If the computed P matrix is not symmetric and\n",
    "            positive-definite.\n",
    "        \"\"\"\n",
    "        a_b_k = self.A - np.dot(self.B, K)\n",
    "        if False in (np.absolute(LA.eigvals(a_b_k)) < 1):\n",
    "            return False\n",
    "\n",
    "        p = self.Q \n",
    "        p_list = []\n",
    "        for step in xrange(10000):\n",
    "            # p_next = self.Q + np.dot(np.transpose(K), np.dot(self.R, K)) + np.dot(np.transpose(a_b_k), np.dot(p, a_b_k))\n",
    "            p_next = self.Q + multi_dot([K.T, self.R, K]) + multi_dot([a_b_k.T, p, a_b_k]) # P_K=Q+K.TRK+(A-BK).TP_K(A-BK)\n",
    "            p_next += np.transpose(p_next)\n",
    "            p_next *= .5 # to make it symmetric\n",
    "            if np.abs(p - p_next).max() < 1e-9: #close enough\n",
    "                # print(\"step of pk = \" + str(step))\n",
    "                break\n",
    "            p = p_next\n",
    "            p_list += [p]\n",
    "        \n",
    "        if np.abs(p - p_next).max() > 1e-9:\n",
    "            logging.warn('DARE solver did not converge')\n",
    "        try:\n",
    "            # Check that the result is symmetric and positive-definite.\n",
    "            np.linalg.cholesky(p_next)\n",
    "        except np.linalg.LinAlgError:\n",
    "            raise RuntimeError('ARE solver failed: P matrix is not symmetric and '\n",
    "                            'positive-definite.')\n",
    "        return p_next\n",
    "    \n",
    "    def Omega_K(self,K):\n",
    "        P_K=self.solve_P_k(K)\n",
    "        #Omega_21=np.zeros((self.k,self.d))\n",
    "        #Omega_22=np.zeros((self.k,self.k))\n",
    "        #Omega_11=np.zeros((self.d,self.d))\n",
    "        #Omega_12=np.zeros((self.d,self.k))\n",
    "        Omega=np.zeros((self.d+self.k,self.d+self.k))\n",
    "        Omega_21=multi_dot([self.B.T,P_K,self.A])\n",
    "        Omega_22=self.R+multi_dot([self.B.T,P_K,self.B])\n",
    "        Omega_11=self.Q+multi_dot([self.A.T, P_K, self.A])\n",
    "        Omega_12=multi_dot([self.A.T, P_K, self.B])\n",
    "        for i in range(self.d):\n",
    "            for j in range(self.d):\n",
    "                Omega[i,j]=Omega_11[i,j]\n",
    "            for j in range(self.d,self.d+self.k):\n",
    "                Omega[i,j]=Omega_12[i,j-self.d]\n",
    "        for i in range(self.d,self.d+self.k):\n",
    "            for j in range(self.d):\n",
    "                Omega[i,j]=Omega_21[i-self.d,j]\n",
    "            for j in range(self.d,self.d+self.k):\n",
    "                Omega[i,j]=Omega_22[i-self.d,j-self.d]\n",
    "        return Omega       \n",
    "    \n",
    "    def Update(self):\n",
    "        K=self.Choose_K()\n",
    "        #data=np.array([])\n",
    "        data_aver=np.array([])\n",
    "        data_aver_crit=np.array([])\n",
    "        data_aver_cost=np.array([])\n",
    "        total=0\n",
    "        total_crit=0\n",
    "        total_cost=0\n",
    "        error=0\n",
    "        error_crit=0\n",
    "        error_cost=0\n",
    "        diff=0\n",
    "        diff_cost=0\n",
    "        diff_crit=0\n",
    "        eta=0\n",
    "        omega=np.zeros(int((self.d+self.k)*(self.d+self.k+1)/2))\n",
    "        K_opt=self.Optimal_K()\n",
    "        JK=self.J_K(K_opt)\n",
    "        #print(JK)\n",
    "        for i in range(self.numstep):\n",
    "            x,u,r,x_next,u_next=self.Rollout(K)\n",
    "            phi=self.Get_phi(x,u)\n",
    "            phi_next=self.Get_phi(x_next,u_next)\n",
    "            delta=r-eta+np.dot(phi_next,omega)-np.dot(phi,omega)\n",
    "            eta_next=eta+(self.gamma/pow(i+1,0.5))*(r-eta)\n",
    "            omega_next=omega+(self.beta/pow((i+10000),0.5))*delta*phi\n",
    "            #omega_next=omega+(self.beta)*delta*phi\n",
    "            #omega_next=self.Project(omega_next)\n",
    "            mat_omega=self.smat(omega)\n",
    "            ture_omega=self.Omega_K(K)\n",
    "            diff_crit=np.linalg.norm(mat_omega-ture_omega)**2\n",
    "            total_crit+=diff_crit\n",
    "            error_crit=total_crit/(i+1)\n",
    "            data_aver_crit = np.concatenate((data_aver_crit, [error_crit]), axis=0)\n",
    "            mat_22,mat_21=self.block_mat(mat_omega)\n",
    "            K_next=K-(self.alpha/pow(i+1000,0.5))*(np.dot(mat_22,K)-mat_21)\n",
    "            #K_next=K-(self.alpha)*(np.dot(mat_22,K)-mat_21)\n",
    "            #J_K=self.J_K(K_next)\n",
    "            J_K=self.J_K(K)\n",
    "            diff=J_K-JK\n",
    "            diff_cost=(eta-J_K)**2\n",
    "            #data = np.concatenate((data, [diff]), axis=0)\n",
    "            #data_cost= np.concatenate((data_cost, [diff_cost]), axis=0)\n",
    "            #print(eta_next)\n",
    "            total+=J_K-JK\n",
    "            total_cost+=diff_cost\n",
    "            error=total/(i+1)\n",
    "            error_cost=total_cost/(i+1)\n",
    "            data_aver = np.concatenate((data_aver, [error]), axis=0)\n",
    "            data_aver_cost = np.concatenate((data_aver_cost, [error_cost]), axis=0)\n",
    "            #print('average error:',error)\n",
    "            K=K_next\n",
    "            #print(np.linalg.norm(K-K_opt))\n",
    "            omega=omega_next\n",
    "            eta=eta_next\n",
    "        return K_next,data_aver_cost,data_aver,data_aver_crit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dc67d8a",
   "metadata": {},
   "source": [
    "# Example 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f21c3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "A=np.array([[0,1],[1,0]])\n",
    "B=np.array([[0,1],[1,0]])\n",
    "Q=np.array([[9,2],[2,1]])\n",
    "R=np.array([[1,2],[2,8]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701f3ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_cost=np.zeros((10,1000000))\n",
    "data_crit=np.zeros((10,1000000))\n",
    "data_actor=np.zeros((10,1000000))\n",
    "for i in range(10):\n",
    "    print(i)\n",
    "    TTS_LQR=TTS_LQR_Solver(A,B,Q,R,2,2,1000000,0.005,0.01,0.1,1)\n",
    "    K_next,data_cost[i],data_actor[i],data_crit[i]=TTS_LQR.Update()\n",
    "np.savetxt('cost_ex1.txt',data_cost)\n",
    "np.savetxt('critic_ex1.txt',data_crit)\n",
    "np.savetxt('actor_ex1.txt',data_actor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12d4a2b1",
   "metadata": {},
   "source": [
    "# Example 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa017c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "A= np.array([[0.2,0.1,1.0,0],[0.2,0.1,0.1,0],[0,0.1,0.5,0],[0,0,0,0.5]])\n",
    "B= np.array([[0.3,0,0],[0.2,0,0.3],[1.0,1.0,0.3],[0.3,0.1,0.1]])\n",
    "Q= np.array([[1.0,0,0.2,0], [0,1.0,0.1,0], [0.2,0.1,1.0,0.1],[0,0,0.1,1]])\n",
    "R= np.array([[1.0,0.1,1.0], [0.1,1.0,0.5],[1.0,0.5,2]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3d5b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_cost=np.zeros((10,1000000))\n",
    "data_crit=np.zeros((10,1000000))\n",
    "data_actor=np.zeros((10,1000000))\n",
    "for i in range(10):\n",
    "    print(i)\n",
    "    TTS_LQR=TTS_LQR_Solver(A,B,Q,R,4,3,1000000,0.005,0.01,0.1,1)\n",
    "    K_next,data_cost[i],data_actor[i],data_crit[i]=TTS_LQR.Update()\n",
    "np.savetxt('cost_ex2.txt',data_cost)\n",
    "np.savetxt('critic_ex2.txt',data_crit)\n",
    "np.savetxt('actor_ex2.txt',data_actor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f473d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
