{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a57e42bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.linalg import multi_dot\n",
    "from numpy import linalg as LA\n",
    "from six.moves import xrange\n",
    "from scipy import linalg as la\n",
    "import scipy.stats as st\n",
    "import random\n",
    "from numpy.linalg import multi_dot\n",
    "import math\n",
    "from math import pow\n",
    "import seaborn as sns\n",
    "import pandas as pd "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "217637c8",
   "metadata": {},
   "source": [
    "# Double-loop AC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755d5888",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.linalg import multi_dot\n",
    "from numpy import linalg as LA\n",
    "from six.moves import xrange\n",
    "from scipy import linalg as la\n",
    "import random\n",
    "from numpy.linalg import multi_dot\n",
    "import math\n",
    "from math import pow\n",
    "\n",
    "class TTS_LQR_Solver:\n",
    "    def __init__(self,A,B,Q,R,d,k,numstep,alpha,beta,gamma,sigma,outer_numstep,outer_lr):\n",
    "        self.A = A\n",
    "        self.B = B\n",
    "        self.Q = Q\n",
    "        self.R = R\n",
    "        self.d = d\n",
    "        self.k = k\n",
    "        self.numstep = numstep\n",
    "        self.alpha = alpha\n",
    "        self.beta = beta\n",
    "        self.gamma = gamma\n",
    "        self.sigma = sigma\n",
    "        self.outer_numstep=outer_numstep\n",
    "        self.outer_lr=outer_lr\n",
    "        \n",
    "    def Choose_K(self):\n",
    "        for j in range(100000):\n",
    "            K_11 = 0.8*np.random.rand(self.k,self.d)\n",
    "            K_22 = -0.8*np.random.rand(self.k,self.d)\n",
    "            K=(K_11+K_22)/2\n",
    "            O_1=self.A-np.dot(self.B,K)\n",
    "            results1=np.linalg.eig(O_1)\n",
    "            if abs(results1[0][0])<0.8 and abs(results1[0][1])<0.8: #Smallier eigenvalue may convergence fast\n",
    "                #print(K)\n",
    "                break\n",
    "            if j>99998:\n",
    "                print('Choose_K wrong')\n",
    "                break\n",
    "        return K\n",
    "    \n",
    "    def svec(self,X):\n",
    "        n=len(X)\n",
    "        c=int(n*(n+1)/2)\n",
    "        x=np.zeros(c)\n",
    "        for j in range(n):\n",
    "            for i in range(j+1):\n",
    "                l=int((j+1)*j/2+i)\n",
    "                if i!=j:\n",
    "                    x[l]=math.sqrt(2)*X[i][j]\n",
    "                else:\n",
    "                    x[l]=X[i][j]\n",
    "        return x\n",
    "    \n",
    "    \n",
    "    def smat(self,x): \n",
    "        k=len(x)\n",
    "        n=int((math.sqrt(1+8*k)-1)/2)\n",
    "        X=np.zeros((n,n))\n",
    "        for j in range(n):\n",
    "            for i in range(j+1):\n",
    "                t=int((j+1)*j/2+i)\n",
    "                if i!=j:\n",
    "                    X[i][j]=x[t]/math.sqrt(2)\n",
    "                else:\n",
    "                    X[i][j]=x[t]\n",
    "        X_T=np.transpose(X)\n",
    "        X=X+X_T\n",
    "        for i in range(n):\n",
    "            X[i][i]=X[i][i]/2\n",
    "        return X\n",
    "    \n",
    "    def block_mat(self,O): \n",
    "        O_21=np.zeros((self.k,self.d))\n",
    "        O_22=np.zeros((self.k,self.k))\n",
    "        O_11=np.zeros((self.d,self.d))\n",
    "        O_12=np.zeros((self.d,self.k))\n",
    "        for i in range(self.d,self.d+self.k):\n",
    "            for j in range(self.d):\n",
    "                O_21[i-self.d][j]=O[i,j]\n",
    "            for l in range(self.d,self.d+self.k):\n",
    "                O_22[i-self.d][l-self.d]=O[i][l]\n",
    "        return O_22,O_21\n",
    "    \n",
    "    \n",
    "    def Get_phi(self,x,u): \n",
    "        aa=np.zeros(self.d+self.k)\n",
    "        for i in range(self.d):\n",
    "            aa[i]=x[i]\n",
    "        for j in range(self.k):\n",
    "            aa[j+self.d]=u[j]\n",
    "        pp_1=np.dot(aa[:,None],aa[None,:])\n",
    "        return self.svec(pp_1)\n",
    "    \n",
    "    def Project(self,omega):\n",
    "        if np.linalg.norm(omega)>1000:\n",
    "            omega=(omega/np.linalg.norm(omega))*1000\n",
    "        return omega\n",
    "        \n",
    "    def Stationary(self,K):\n",
    "        I=np.eye((self.d))\n",
    "        phi=I\n",
    "        phi_si=phi+self.sigma*self.sigma*np.dot(B,B.T)\n",
    "        C_k=phi_si\n",
    "        for i in range(10000):\n",
    "            C_next=phi_si+np.dot(np.dot((A-np.dot(B,K)),C_k),(A-np.dot(B,K)).T)\n",
    "            C_next+=np.transpose(C_next)\n",
    "            C_next*=0.5\n",
    "            if np.abs(C_k - C_next).max() < 1e-10:\n",
    "                break\n",
    "            if i>99998:\n",
    "                print('Not Stabile')\n",
    "                break\n",
    "            C_k=C_next\n",
    "        return C_next\n",
    "    \n",
    "    def Sample_x(self,K):\n",
    "        mean=np.zeros(self.d)\n",
    "        sigma=self.Stationary(K)\n",
    "        x=np.random.multivariate_normal(mean,sigma)\n",
    "        return x\n",
    "    \n",
    "    def Rollout(self,K):\n",
    "        x=self.Sample_x(K)\n",
    "        u=-np.dot(K,x)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "        r=np.dot(np.dot(x,self.Q),x)+np.dot(np.dot(u,self.R),u)\n",
    "        x_next=np.dot(self.A,x)+np.dot(B,u)+np.random.multivariate_normal(np.zeros(self.d),np.eye((self.d)))\n",
    "        u_next=-np.dot(K,x_next)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "        return x,u,r,x_next,u_next\n",
    "    \n",
    "    def Optimal_K(self):\n",
    "        P = la.solve_discrete_are(self.A, self.B, self.Q, self.R)\n",
    "        K_opt = multi_dot([np.linalg.inv(self.R + multi_dot([self.B.T, P, self.B])), self.B.T, P, self.A])\n",
    "        return K_opt\n",
    "    \n",
    "    def J_K(self,K):\n",
    "        P_k=self.solve_P_k(K)\n",
    "        phi_sigma=np.eye(self.d)+self.sigma*self.sigma*(np.dot(self.B,self.B.T))\n",
    "        J=np.dot(P_k,phi_sigma)\n",
    "        return J.trace()+self.sigma*self.sigma*self.R.trace()\n",
    "    \n",
    "    def solve_P_k(self, K): # I have tested that this is right\n",
    "        \"\"\"Solves the Bellman equation by iteration.\n",
    "        Bellman Equation\n",
    "        if rho(A-BK) < 1\n",
    "        ```none\n",
    "        P_{t+1} = Q + K'RK + (A-BK)'P_{t}(A-BK)\n",
    "        ```\n",
    "        Returns:\n",
    "            A numpy array, a postive definite matrix P which is the solution to the Bellman equation.\n",
    "        Raises:\n",
    "            RuntimeError: If the computed P matrix is not symmetric and\n",
    "            positive-definite.\n",
    "        \"\"\"\n",
    "        a_b_k = self.A - np.dot(self.B, K)\n",
    "        if False in (np.absolute(LA.eigvals(a_b_k)) < 1):\n",
    "            return False\n",
    "\n",
    "        p = self.Q \n",
    "        p_list = []\n",
    "        for step in xrange(10000):\n",
    "            # p_next = self.Q + np.dot(np.transpose(K), np.dot(self.R, K)) + np.dot(np.transpose(a_b_k), np.dot(p, a_b_k))\n",
    "            p_next = self.Q + multi_dot([K.T, self.R, K]) + multi_dot([a_b_k.T, p, a_b_k]) # P_K=Q+K.TRK+(A-BK).TP_K(A-BK)\n",
    "            p_next += np.transpose(p_next)\n",
    "            p_next *= .5 # to make it symmetric\n",
    "            if np.abs(p - p_next).max() < 1e-9: #close enough\n",
    "                # print(\"step of pk = \" + str(step))\n",
    "                break\n",
    "            p = p_next\n",
    "            p_list += [p]\n",
    "        \n",
    "        if np.abs(p - p_next).max() > 1e-9:\n",
    "            logging.warn('DARE solver did not converge')\n",
    "        try:\n",
    "            # Check that the result is symmetric and positive-definite.\n",
    "            np.linalg.cholesky(p_next)\n",
    "        except np.linalg.LinAlgError:\n",
    "            raise RuntimeError('ARE solver failed: P matrix is not symmetric and '\n",
    "                            'positive-definite.')\n",
    "        return p_next\n",
    "    \n",
    "    def Omega_K(self,K):\n",
    "        P_K=self.solve_P_k(K)\n",
    "        Omega=np.zeros((self.d+self.k,self.d+self.k))\n",
    "        Omega_21=multi_dot([self.B.T,P_K,self.A])\n",
    "        Omega_22=self.R+multi_dot([self.B.T,P_K,self.B])\n",
    "        Omega_11=self.Q+multi_dot([self.A.T, P_K, self.A])\n",
    "        Omega_12=multi_dot([self.A.T, P_K, self.B])\n",
    "        for i in range(self.d):\n",
    "            for j in range(self.d):\n",
    "                Omega[i,j]=Omega_11[i,j]\n",
    "            for j in range(self.d,self.d+self.k):\n",
    "                Omega[i,j]=Omega_12[i,j-self.d]\n",
    "        for i in range(self.d,self.d+self.k):\n",
    "            for j in range(self.d):\n",
    "                Omega[i,j]=Omega_21[i-self.d,j]\n",
    "            for j in range(self.d,self.d+self.k):\n",
    "                Omega[i,j]=Omega_22[i-self.d,j-self.d]\n",
    "        return Omega\n",
    "    \n",
    "    def Update(self):\n",
    "        K=self.Choose_K()\n",
    "        total=0\n",
    "        error=0\n",
    "        eta=0\n",
    "        diff=0\n",
    "        data=np.array([])\n",
    "        data_aver=np.array([])\n",
    "        omega=np.zeros(int((self.d+self.k)*(self.d+self.k+1)/2))\n",
    "        K_opt=self.Optimal_K()\n",
    "        JK=self.J_K(K_opt)\n",
    "        print(JK)\n",
    "        for i in range(self.numstep):\n",
    "            x,u,r,x_next,u_next=self.Rollout(K)\n",
    "            phi=self.Get_phi(x,u)\n",
    "            phi_next=self.Get_phi(x_next,u_next)\n",
    "            delta=r-eta+np.dot(phi_next,omega)-np.dot(phi,omega)\n",
    "            eta_next=eta+(self.gamma/pow(i+1,0.4))*(r-eta)\n",
    "            omega_next=omega+(self.beta/pow((i+1000),0.4))*delta*phi\n",
    "            mat_omega=self.smat(omega)\n",
    "            mat_22,mat_21=self.block_mat(mat_omega)\n",
    "            K_next=K-(self.alpha/pow(i+1000,0.6))*(np.dot(mat_22,K)-mat_21)\n",
    "            #K_next=K-(self.alpha)*(np.dot(mat_22,K)-mat_21)\n",
    "            J_K=self.J_K(K_next)\n",
    "            diff=J_K-JK\n",
    "            data = np.concatenate((data, [diff]), axis=0)\n",
    "            print(eta_next)\n",
    "            total+=J_K-JK\n",
    "            error=total/(i+1)\n",
    "            data_aver = np.concatenate((data_aver, [error]), axis=0)\n",
    "            print('average error:',error)\n",
    "            K=K_next\n",
    "            omega=omega_next\n",
    "            eta=eta_next\n",
    "        return K_next,data,data_aver\n",
    "    def Yang_Rollout(self,K,v1,v2,omega1,omega2):\n",
    "        x=self.Sample_x(K)\n",
    "        print(self.J_K(K))\n",
    "        data1=np.array([])\n",
    "        data2=v2[np.newaxis,:]\n",
    "        u=-np.dot(K,x)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "        r=np.dot(np.dot(x,self.Q),x)+np.dot(np.dot(u,self.R),u)\n",
    "        x_next=np.dot(self.A,x)+np.dot(B,u)+np.random.multivariate_normal(np.zeros(self.d),np.eye((self.d)))\n",
    "        a_tot=0\n",
    "        b_tot=0\n",
    "        a_tot_n=0\n",
    "        b_tot_n=0\n",
    "        for i in range(self.numstep):\n",
    "            u_next=-np.dot(K,x_next)+ self.sigma*np.random.multivariate_normal(np.zeros(self.k),np.eye((self.k)))\n",
    "            r_next=np.dot(np.dot(x_next,self.Q),x_next)+np.dot(np.dot(u_next,self.R),u_next)\n",
    "            x_next2=np.dot(self.A,x_next)+np.dot(B,u_next)+np.random.multivariate_normal(np.zeros(self.d),np.eye((self.d)))\n",
    "            phi=self.Get_phi(x,u)\n",
    "            phi_next=self.Get_phi(x_next,u_next)\n",
    "            delta=v1-r+np.dot(phi,v2)-np.dot(phi_next,v2)\n",
    "            v1_next=v1-(self.alpha/pow(i+1,0.5))*(omega1+np.dot(phi,omega2))\n",
    "            v2_next=v2-(self.alpha/pow(i+1,0.5))*np.dot(phi,omega2)*(phi-phi_next)\n",
    "            omega1_next=(1-self.alpha/pow(i+1,0.5))*omega1+(self.alpha/pow(i+1,0.5))*(v1-r)\n",
    "            omega2_next=(1-self.alpha/pow(i+1,0.5))*omega2+(self.alpha/pow(i+1,0.5))*delta*phi\n",
    "            v1_next=self.Project(v1_next)\n",
    "            v2_next=self.Project(v2_next)\n",
    "            omega1_next=self.Project(omega1_next)\n",
    "            omega2_next=self.Project(omega2_next)\n",
    "            x=x_next\n",
    "            x_next=x_next2\n",
    "            r=r_next\n",
    "            u=u_next\n",
    "            v1=v1_next\n",
    "            v2=v2_next\n",
    "            omega1=omega1_next\n",
    "            omega2=omega2_next\n",
    "            a_tot=a_tot_n\n",
    "            b_tot=b_tot_n\n",
    "            data1=np.concatenate((data1, [v1]), axis=0)\n",
    "            data2=np.concatenate((data2, v2[np.newaxis,:]), axis=0)\n",
    "        v1_total=0\n",
    "        v2_total=np.zeros(int((self.d+self.k)*(self.d+self.k+1)/2))\n",
    "        step_total=0\n",
    "        for i in range(int(99*self.numstep/100),self.numstep):\n",
    "            v1_total=v1_total+(self.alpha/pow(i+1,0.5))*data1[i]\n",
    "            step_total=step_total+self.alpha/pow(i+1,0.5)\n",
    "            v2_total=v2_total+(self.alpha/pow(i+1,0.5))*data2[i+1]\n",
    "        v1_final=v1_total/step_total\n",
    "        v2_final=v2_total/step_total\n",
    "        ture_omega=self.Omega_K(K)\n",
    "        mat_omega=self.smat(v2_final)\n",
    "        mat_v2=self.smat(v2_final)\n",
    "        mat_22,mat_21=self.block_mat(mat_v2)\n",
    "        \n",
    "        return mat_22,mat_21,v1_final,v2_final,omega1,omega2\n",
    "    def Yang_Update(self):\n",
    "        K=self.Choose_K()\n",
    "        data=np.array([])\n",
    "        diff=0\n",
    "        v1=20\n",
    "        v2=np.zeros(int((self.d+self.k)*(self.d+self.k+1)/2))\n",
    "        K_opt=self.Optimal_K()\n",
    "        for i in range(self.outer_numstep):\n",
    "            omega1=0\n",
    "            omega2=np.zeros(int((self.d+self.k)*(self.d+self.k+1)/2))\n",
    "            O_22,O_21,v1,v2,omega1,omega2=self.Yang_Rollout(K,v1,v2,omega1,omega2)\n",
    "            K_next=K-self.outer_lr*(np.dot(O_22,K)-O_21)\n",
    "            diff=np.linalg.norm(K_next-K_opt)\n",
    "            print(diff)\n",
    "            data=np.concatenate((data,[diff]),axis=0)\n",
    "            K=K_next\n",
    "        return K_next,data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9052088a",
   "metadata": {},
   "source": [
    "# Example 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "152ffe61",
   "metadata": {},
   "outputs": [],
   "source": [
    "A= np.array([[0,1],[1,0]])\n",
    "B= np.array([[0,1],[1,0]])\n",
    "Q= np.array([[9,2],[2,1]])\n",
    "R= np.array([[1,2],[2,8]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6d21ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One run of the double-loop method on Example 1. This cell is executed once per run;\n",
    "# the original note says it was run separately 10 times because a run may \"crack\"\n",
    "# (presumably crash or diverge), and that a larger inner-loop numstep avoids this at\n",
    "# the cost of a much longer run time.\n",
    "TTS_LQR=TTS_LQR_Solver(\n",
    "    A,B,Q,R,\n",
    "    d=2,k=2,\n",
    "    numstep=500000,\n",
    "    alpha=0.01,beta=0.01,gamma=0.1,sigma=0.2,\n",
    "    outer_numstep=100,outer_lr=0.05,\n",
    ")\n",
    "K,data=TTS_LQR.Yang_Update()\n",
    "np.savetxt('Double-loop AC.txt',data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "289c2feb",
   "metadata": {},
   "source": [
    "# Example 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50aca415",
   "metadata": {},
   "outputs": [],
   "source": [
    "A= np.array([[0.2,0.1,1.0,0],[0.2,0.1,0.1,0],[0,0.1,0.5,0],[0,0,0,0.5]])\n",
    "B= np.array([[0.3,0,0],[0.2,0,0.3],[1.0,1.0,0.3],[0.3,0.1,0.1]])\n",
    "Q= np.array([[1.0,0,0.2,0], [0,1.0,0.1,0], [0.2,0.1,1.0,0.1],[0,0,0.1,1]])\n",
    "R= np.array([[1.0,0.1,1.0], [0.1,1.0,0.5],[1.0,0.5,2]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eab32715",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One run of the double-loop method on Example 2. A is 4x4 and B is 4x3, so the state\n",
    "# and control dimensions are read from the matrices; passing d=2, k=2 here makes\n",
    "# Choose_K build a 2x2 gain that cannot be multiplied with B.\n",
    "# The original note says each run was executed separately 10 times because a run may\n",
    "# \"crack\" (presumably crash or diverge), and that a larger inner-loop numstep avoids\n",
    "# this at the cost of a much longer run time.\n",
    "TTS_LQR=TTS_LQR_Solver(A,B,Q,R,A.shape[0],B.shape[1],500000,0.01,0.01,0.1,0.2,100,0.05)\n",
    "K,data=TTS_LQR.Yang_Update()\n",
    "np.savetxt('Double-loop AC2.txt',data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e11c78a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
