% Experiment-wide settings shared by every Monte-Carlo run below.
rng(11)                        % fixed seed so the random system is reproducible
NN=5e5;                        % number of learning episodes per run
lambda=0.8;                    % ridge term added to the least-squares Gram matrix
regret_total=zeros(NN,1);      % placeholder first column; per-run regret columns are appended to it



%system 1 (disabled alternative: small hand-written 3-state system)
%A=[1.01,0.01,0;0.01,1.01,0.01;0,0.01,1.01];
%B=eye(3);
%Q=0.5*eye(3);
%R=eye(3);
%Q_T=0.5*eye(size(Q,1),size(Q,2));


%system 2: random 7-state / 4-input system
A=rand(7,7);
B=rand(7,4);
nx=size(A,2);                  % state dimension
nu=size(B,2);                  % input dimension
% Cost weights are built as Z*M*Z' with M diagonal (entries drawn by rand)
% and Z an orthonormal basis returned by orth, giving symmetric weights.
% The order of the rand calls is kept fixed so the draws do not change.
M1=diag(rand(nx,1));
Z1=orth(rand(nx,nx));
Q=Z1*M1*Z1';
M2=diag(rand(nu,1));
Z2=orth(rand(nu,nu));
R=Z2*M2*Z2';
Q_T=zeros(size(Q));            % terminal cost weight (zero)



% Monte-Carlo study: 50 independent runs of the adaptive controller.
% Each run appends one NN-by-1 cumulative-regret column to regret_total
% (parfor concatenation reduction).
% Broadcast inputs: A, B, Q, R, Q_T, NN, lambda.
parfor sample=1:50
    rng(sample)                 % per-run seed
    T=3;                        % horizon length of one episode
    sigma=1;                    % noise standard deviation
    gamma=0.1;                  % parameter used for the reference cost and the evaluation of the learned policy
    gamma1=0.05;                % parameter used when designing the gain from the estimated model
    % NOTE(review): the design recursion uses gamma1 while Jstar/Jhat use
    % gamma -- confirm this mismatch is intentional.
    n=size(A,2);                % state dimension
    m=size(B,2);                % input dimension
    In=eye(n);
    x_initial=zeros(n,1);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %The optimal control and Riccati equation (true A, B)
    K=zeros(m,n,T);
    P=zeros(n,n,T+1);
    P(:,:,T+1)=Q_T;
    Ptilde=zeros(n,n,T+1);      % only slices 2..T+1 are filled by the recursion

    for t=T:-1:1
        % Linear solves replace the explicit inverses eye(n)/M*X.
        Ptilde(:,:,t+1)=(In-gamma*sigma^2*P(:,:,t+1))\P(:,:,t+1);
        K(:,:,t)=-((B'*Ptilde(:,:,t+1)*B+R)\(B'*Ptilde(:,:,t+1)*A));
        Acl=A+B*K(:,:,t);
        P(:,:,t)=Q+K(:,:,t)'*R*K(:,:,t)+Acl'*Ptilde(:,:,t+1)*Acl;
    end
    %optimal value function
    Jstar=0.5*(x_initial'*P(:,:,1)*x_initial);
    for t=1:T
        Jstar=Jstar-0.5/gamma*log(det(In-gamma*sigma^2*P(:,:,t+1)));
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %calculate the constants in the regret bound
    %(reconst is a loop-local temporary; it is not used further in this loop)
    Gamma=zeros(2+3*T,1);
    Gamma(1)=norm(A,2);
    Gamma(2)=norm(B,2);
    for t=1:T
        Gamma(2+t)=norm(P(:,:,t),2);
        % Ptilde is stored at index t+1; slice 1 is never written, so
        % reading slice t would count an all-zero matrix and skip slice T+1.
        Gamma(2+T+t)=norm(Ptilde(:,:,t+1),2);
        Gamma(2+2*T+t)=norm(K(:,:,t),2);
    end
    Gatilde=max(Gamma)+1;
    mL=1/(1-gamma*sigma^2*Gatilde);
    mV=2*(mL+1)*Gatilde^3;
    psi=zeros(T,1);
    psi(T)=2*Gatilde^3;
    for t=T-1:-1:1
        psi(t)=2*Gatilde^3*(10*mV^2*mL*Gatilde^4)^(2*(T-t-1))+12*Gatilde^4*psi(t+1);
    end
    reconst=sum(sigma^2*n*mV^2*psi);

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    %random initial guess of the system matrices
    A0=rand(size(A,1),size(A,2));
    B0=rand(size(B,1),size(B,2));

    %estimated control and riccati equation
    K0=zeros(m,n,T);
    P0=zeros(n,n,T+1);
    P0(:,:,T+1)=Q_T;
    Ptilde0=zeros(n,n,T+1);
    %Riccati equation under estimated control but true system matrices
    P1=zeros(n,n,T+1);
    P1(:,:,T+1)=Q_T;
    Ptilde1=zeros(n,n,T+1);
    %record the cost of the learned policy in every episode
    J=zeros(NN,1);
    num=0;
    %running sums for the ridge least-squares estimate of [A B]
    zz=zeros(n+m);
    zx=zeros(n+m,n);

    %loop-invariant quantities hoisted out of the episode loop
    Iz=eye(n+m);
    mu_u=zeros(m,1);
    cov_u=sigma^2*eye(m);
    mu_w=zeros(n,1);
    cov_w=sigma^2*In;

    for l=1:NN
        %update estimated control and riccati equations
        for t=T:-1:1
            Ptilde0(:,:,t+1)=(In-gamma1*sigma^2*P0(:,:,t+1))\P0(:,:,t+1);
            Ptilde1(:,:,t+1)=(In-gamma*sigma^2*P1(:,:,t+1))\P1(:,:,t+1);
            K0(:,:,t)=-((B0'*Ptilde0(:,:,t+1)*B0+R)\(B0'*Ptilde0(:,:,t+1)*A0));
            Kt=K0(:,:,t);
            KRK=Kt'*R*Kt;
            Acl0=A0+B0*Kt;      % closed loop under the estimated model
            Acl1=A+B*Kt;        % closed loop under the true model
            P0(:,:,t)=Q+KRK+Acl0'*Ptilde0(:,:,t+1)*Acl0;
            P1(:,:,t)=Q+KRK+Acl1'*Ptilde1(:,:,t+1)*Acl1;
        end
        %cost of the current gains K0 evaluated on the true system
        Jhat=0.5*(x_initial'*P1(:,:,1)*x_initial);
        for t=1:T
            Jhat=Jhat-0.5/gamma*log(det(In-gamma*sigma^2*P1(:,:,t+1)));
        end

        num=num+1;
        J(num)=Jhat;
        %progress output every 1000 episodes
        if mod(num,1000)==0
            num
            Jstar
            Jhat
        end

        %evolve the system dynamics under the estimated control;
        %the exploration noise on the input shrinks as l^(-1/4)
        Z=zeros(n+m,T+1);
        x0=x_initial;
        for t=1:T
            u=K0(:,:,t)*x0+1/(l^(1/4))*mvnrnd(mu_u,cov_u)';
            %u=K0(1:size(B,2),1:size(A,2),t)*x0+mvnrnd(zeros(size(B,2),1),sigma^2*eye(size(B,2)))';
            x1=A*x0+B*u+mvnrnd(mu_w,cov_w)';
            Z(:,t)=[x0;u];
            x0=x1;
        end
        Z(:,T+1)=[x0;zeros(m,1)];

        %estimate the new system matrices (ridge-regularised least squares)
        for t=1:T
            zz=zz+Z(:,t)*Z(:,t)';
            zx=zx+Z(:,t)*Z(1:n,t+1)';
        end
        theta=(zz+lambda*Iz)\zx;
        A0=theta(1:n,:)';
        B0=theta(n+1:n+m,:)';
    end
    Regret=cumsum(J-Jstar);
    regret_total=[regret_total,Regret];
end
% Summarise the per-run cumulative regret curves.
% Column 1 of regret_total is the all-zero placeholder it was initialised
% with, so it is dropped before taking statistics.
dim=size(regret_total);
n_runs=dim(2)-1;                              % number of appended run columns
regret_final=regret_total(1:NN,2:end);
regret_mean=mean(regret_final,2);             % mean across runs, per episode
regret_std=std(regret_final,0,2);             % sample standard deviation across runs
% Half-width of the band: 1.96 standard errors around the mean.
ci_half=regret_std/sqrt(n_runs)*1.96;
regret_upper=regret_mean+ci_half;
regret_lower=regret_mean-ci_half;