% Seed the global stream so the randomly generated test system is reproducible.
rng(11);
N=2e6;                          % episodes per Monte Carlo run (= 20*1e+5)



% ---- system 1 (disabled): 3-state example with fixed matrices ----
% A=[1.01,0.01,0;0.01,1.01,0.01;0,0.01,1.01];
% B=eye(3);
% Q=0.5*eye(3);
% R=eye(3);
% Q_T=0.5*eye(size(Q,1),size(Q,2));

% ---- system 2 (active): random 7-state / 4-input system ----
% The rand calls below must stay in this order: together they fix the system.
A=rand(7,7);
B=rand(7,4);
% Cost weights Q = Z1*M1*Z1' and R = Z2*M2*Z2': an orthonormal basis (orth)
% around a diagonal of rand entries, i.e. (mathematically) symmetric
% positive semi-definite by construction.
M1=diag(rand(size(A,2),1));
Z1=orth(rand(size(A,2),size(A,2)));
Q=Z1*M1*Z1';
M2=diag(rand(size(B,2),1));
Z2=orth(rand(size(B,2),size(B,2)));
R=Z2*M2*Z2';
Q_T=zeros(size(Q));             % zero terminal weight
% Seed column for the parfor concatenation of per-run regret curves.
regret_total=zeros(N,1);



% Monte Carlo experiment: 50 independent runs of an episodic learner.
% Each run
%   1. solves the backward Riccati-type recursion for the TRUE (A,B) to get
%      the optimal gain K and the optimal cost Jstar;
%   2. starts from random estimates (A0,B0) and, in batches of m0, 2*m0,
%      4*m0, ... episodes (the last batch takes the remainder so the total
%      is exactly N), plays the gain K0 designed from the current estimates,
%      records the cost Jhat of K0 on the true system for every episode, and
%      re-estimates (A,B) by ridge regression on that batch;
%   3. appends its cumulative regret cumsum(J-Jstar) as a new column of
%      regret_total (parfor concatenation reduction).
parfor sample=1:50
    rng(sample)                 % per-run seed: results do not depend on worker scheduling
    n=size(A,2);                % state dimension
    d=size(B,2);                % input dimension
    T=3;                        % horizon
    sigma=1;                    % std of the i.i.d. Gaussian process noise
    gamma=0.1;                  % parameter of the cost being scored (Jstar, Jhat)
    gamma1=0.05;                % parameter used to DESIGN the learner's gain K0
    % NOTE(review): K0 is designed with gamma1 but scored with gamma. If that
    % mismatch is not intentional, Jhat will generally not approach Jstar even
    % with perfect estimates - confirm.
    m0=500;                     % size of the first batch of episodes
    minitial=m0;
    L=ceil(log2(N/m0+1))-1;     % number of doubling batches; batch L+1 takes the rest
    x_initial=zeros(n,1);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Optimal gain and Riccati recursion for the true system:
    %   Ptilde_{t+1} = (I - gamma*sigma^2*P_{t+1})^{-1} * P_{t+1}
    %   K_t          = -(B'*Ptilde_{t+1}*B + R)^{-1} * B'*Ptilde_{t+1}*A
    %   P_t          = Q + K_t'*R*K_t + (A+B*K_t)'*Ptilde_{t+1}*(A+B*K_t)
    K=zeros(d,n,T);
    P=zeros(n,n,T+1);
    P(:,:,T+1)=Q_T;
    Ptilde=zeros(n,n,T+1);      % slice 1 is never assigned and stays zero
    for t=T:-1:1
        Pnext=P(:,:,t+1);
        Ptilde(:,:,t+1)=eye(n)/(eye(n)-gamma*sigma^2*Pnext)*Pnext;
        Pt=Ptilde(:,:,t+1);
        K(:,:,t)=-eye(d)/(B'*Pt*B+R)*B'*Pt*A;
        Kt=K(:,:,t);
        P(:,:,t)=Q+Kt'*R*Kt+(A+B*Kt)'*Pt*(A+B*Kt);
    end
    % optimal value function
    Jstar=0.5*(x_initial'*P(:,:,1)*x_initial);
    for t=1:T
        Jstar=Jstar-0.5/gamma*log(det(eye(n)-gamma*sigma^2*P(:,:,t+1)));
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Constants of the regret bound.
    % NOTE: inside parfor these are temporaries, so Gamma/mL/mV/psi/reconst
    % are discarded after every iteration and are NOT available after the
    % loop. Run the body in a plain for-loop to inspect them.
    Gamma=zeros(2+3*T,1);
    Gamma(1)=norm(A,2);
    Gamma(2)=norm(B,2);
    for t=1:T
        Gamma(2+t)=norm(P(:,:,t),2);
        % NOTE(review): for t=1 this reads the never-assigned (zero) slice
        % Ptilde(:,:,1); possibly Ptilde(:,:,t+1) was intended - confirm.
        Gamma(2+T+t)=norm(Ptilde(:,:,t),2);
        Gamma(2+2*T+t)=norm(K(:,:,t),2);
    end
    Gatilde=max(Gamma)+1;
    mL=1/(1-gamma*sigma^2*Gatilde);
    mV=2*(mL+1)*Gatilde^3;
    psi=zeros(T,1);
    psi(T)=2*Gatilde^3;
    for t=T-1:-1:1
        psi(t)=2*Gatilde^3*(10*mV^2*mL*Gatilde^4)^(2*(T-t-1))+12*Gatilde^4*psi(t+1);
    end
    reconst=sum(sigma^2*n*mV^2*psi);

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Random initial estimates of the system matrices.
    A0=rand(size(A));
    B0=rand(size(B));

    % Gain designed from the estimates (K0, P0, Ptilde0; parameter gamma1) and
    % the cost of that same gain on the true system (P1, Ptilde1; parameter gamma).
    K0=zeros(d,n,T);
    P0=zeros(n,n,T+1);
    P0(:,:,T+1)=Q_T;
    Ptilde0=zeros(n,n,T+1);
    P1=zeros(n,n,T+1);
    P1(:,:,T+1)=Q_T;
    Ptilde1=zeros(n,n,T+1);
    J=zeros(N,1);               % cost of the gain in force at each episode
    num=0;                      % episodes completed so far
    for l=1:L+1
        % update estimated control and Riccati recursions
        for t=T:-1:1
            P0next=P0(:,:,t+1);
            P1next=P1(:,:,t+1);
            Ptilde0(:,:,t+1)=eye(n)/(eye(n)-gamma1*sigma^2*P0next)*P0next;
            Ptilde1(:,:,t+1)=eye(n)/(eye(n)-gamma*sigma^2*P1next)*P1next;
            Pt0=Ptilde0(:,:,t+1);
            Pt1=Ptilde1(:,:,t+1);
            K0(:,:,t)=-eye(d)/(B0'*Pt0*B0+R)*B0'*Pt0*A0;
            Kt=K0(:,:,t);
            P0(:,:,t)=Q+Kt'*R*Kt+(A0+B0*Kt)'*Pt0*(A0+B0*Kt);
            P1(:,:,t)=Q+Kt'*R*Kt+(A+B*Kt)'*Pt1*(A+B*Kt);
        end
        Jhat=0.5*(x_initial'*P1(:,:,1)*x_initial);
        for t=1:T
            Jhat=Jhat-0.5/gamma*log(det(eye(n)-gamma*sigma^2*P1(:,:,t+1)));
        end
        % Play m0 episodes with K0 and accumulate the regression statistics on
        % the fly. The summation order (episode k outer, step t inner) is the
        % same as storing every trajectory and summing afterwards, but no
        % (n+d)-by-(T+1)-by-m0 trajectory array is needed.
        zz=zeros(n+d);          % sum of z*z',        z = [x_t; u_t]
        zx=zeros(n+d,n);        % sum of z*x_{t+1}'
        for k=1:m0
            num=num+1;
            J(num)=Jhat;
            if mod(num,1000)==0
                fprintf('run %d, episode %d: Jstar=%g, Jhat=%g\n',sample,num,Jstar,Jhat);
            end
            x0=x_initial;
            for t=1:T
                u=K0(:,:,t)*x0;
                % N(0,sigma^2*I) process noise; equivalent to the former
                % mvnrnd(zeros(n,1),sigma^2*eye(n))' call without the
                % toolbox dependency and its per-call overhead.
                x1=A*x0+B*u+sigma*randn(n,1);
                z=[x0;u];
                zz=zz+z*z';
                zx=zx+z*x1';
                x0=x1;
            end
        end
        % Ridge (regularisation 1) least-squares estimate theta = [A B]'
        theta=eye(n+d)/(zz+eye(n+d))*zx;
        A0=theta(1:n,:)';
        B0=theta(n+1:n+d,:)';
        % next batch size: double up to batch L, then take the remainder
        if l<L
            m0=2*m0;
        elseif l==L
            m0=N-(2^L-1)*minitial;
        end
    end
    Regret=cumsum(J-Jstar);
    regret_total=[regret_total,Regret];
end
% Summarise the Monte Carlo runs: drop the zero seed column, then form the
% per-episode mean cumulative regret and a normal-approximation 95% band
% (mean +/- 1.96 standard errors, standard error = std / sqrt(#runs)).
dim=[size(regret_total,1),size(regret_total,2)];
regret_final=regret_total(1:N,2:end);
regret_mean=mean(regret_final,2);
regret_std=std(regret_final,[],2);
regret_upper=regret_mean+regret_std/sqrt(dim(2)-1)*1.96;
regret_lower=regret_mean-regret_std/sqrt(dim(2)-1)*1.96;