% reset the command window, the workspace and any open figure windows
clc; clear; close all;
%% data
pp = 125 * 2.^(0:4);                 % dimensions p: 125, 250, 500, 1000, 2000
error_result = zeros(numel(pp), 1);  % one spectral-norm error per dimension
%% experiment settings (identical for every p, so set once)
s      = 0.1;   % sigma^2: weight of the implicit (recurrent) term
K      = 400;   % number of independent weight draws averaged into G
m      = 1000;  % width of the implicit layer
max_it = 15;    % cap on fixed-point iterations per weight draw
tol    = 1e-5;  % stop iterating once norm(Z - Z_new) <= tol

%% coefficients of the approximation G_ (they depend only on s, not on p)
g = @(x) exp(-x.^2/2)/sqrt(2*pi);          % standard Gaussian density
ed1f1 = integral(@(x) 1.*g(x), 0, +Inf);   % integral of g over [0, Inf) = 1/2
ed2f2 = integral(@(x) 2.*g(x), 0, +Inf);   % twice that integral = 1
g4 = ed2f2^2;
g1 = ed1f1^2;
g2 = 0;  % with g2 = 0, alpha2 = alpha3 = 0 and the V*C*V' term below vanishes
alpha1 = (1-s*g1)^(-1)*(1-s);
alpha4 = (1-s/2*g4)^(-1)*(1-s);
alpha2 = s/4 * (1-s*g1)^(-1) * g2 * alpha4^2;
alpha3 = s/2 * (1-s*g1)^(-1) * g2 * alpha1^2;

% trace(A*B) without forming the p-by-p product: O(p^2) instead of O(p^3)
trprod = @(A, B) sum(sum(A .* B.'));

%% for each dimension p: empirical kernel G vs. its approximation G_
for j = 1:length(pp)
    p = pp(j);
    n = p*4/5; % p/n = 5:4
    % each class gets n/2 samples, so n must be an even integer
    assert(mod(n, 2) == 0, 'p*4/5 must be an even integer');

    % two-class data: x_i = (mu_a + z_i)/sqrt(p), z_i ~ N(0, c_a*I_p)
    a1 = 1; a2 = 2;
    mu1 = [zeros(8*(a1-1),1); 8; zeros(p-8*a1+7,1)];
    mu2 = [zeros(8*(a2-1),1); 8; zeros(p-8*a2+7,1)]; % two means
    c1 = 1+8*(a1-1)/sqrt(p); c2 = 1+8*(a2-1)/sqrt(p); % two variances
    % c1 = 1; c2 = 1;
    Ip = eye(p);
    C1 = c1*Ip; C2 = c2*Ip;
    Z1 = randn(p,n/2) * sqrt(c1); Z2 = randn(p,n/2) * sqrt(c2);
    X1 = Z1 /sqrt(p) + repmat(mu1 /sqrt(p),1,n/2);
    X2 = Z2 /sqrt(p) + repmat(mu2 /sqrt(p),1,n/2);
    X = [X1, X2];
    XtX = X'*X;   % reused by both G and G_

    % statistics entering G_
    C0 = 1/2*C1 + 1/2*C2;
    C10 = C1 - C0; C20 = C2 - C0;
    t = [trace(C10); trace(C20)]/sqrt(p);
    T = [trprod(C10,C10), trprod(C10,C20); trprod(C20,C10), trprod(C20,C20)]/p;
    % Psi_i = ||z_i||^2/p - tr(C_a)/p: vecnorm returns the 2-norm, so square it
    Psi = [vecnorm(Z1).^2/p - trace(C1)/p, vecnorm(Z2).^2/p - trace(C2)/p];
    J = [repmat([1,0],n/2,1); repmat([0,1],n/2,1)];  % class-membership indicators
    tau0 = sqrt(trace(C0)/p);

    % fixed-point iteration for tau: tau = sqrt(s*Ef2(tau) + (1-s)*tau0^2)
    tau = 0;
    for it = 1:20
        tau = sqrt(s*Ef2(tau) + (1-s)*tau0^2);
    end
    phi = @(x) max(0,x) - tau/sqrt(2*pi);  % ReLU shifted by tau/sqrt(2*pi)

    % empirical conjugate kernel G, averaged over K weight draws
    ZtZ = zeros(n, n);
    for k = 1:K
        A = randn(m,m);
        B = randn(m,p);
        As = sqrt(s)*A;
        U = sqrt(1-s)*B*X;   % input term does not change across fixed-point iterations
        Z = zeros(m,n);
        for it = 1:max_it
            Z_new = phi(As*Z + U)/sqrt(m);
            step = norm(Z - Z_new);
            Z = Z_new;       % keep the latest iterate, also when it has converged
            if step <= tol
                break
            end
        end
        ZtZ = ZtZ + Z'*Z;
    end
    G = s*ZtZ/K + (1-s)*XtX;

    % approximation G_ and its spectral-norm distance to G
    V = [J/sqrt(p), Psi'];
    C = [alpha2*(t*t') + alpha3*T, alpha2*t; alpha2*t', alpha2];
    G_ = alpha1*XtX + V*C*V' + (tau^2 - tau0^2*alpha1 - tau0^4*alpha3)*eye(n);
    error_result(j) = norm(G - G_, 2);
end
%% functions
function y = Ef2(t)
%EF2  Gaussian second moment of the ReLU at scale t, shifted by t/sqrt(2*pi).
%   y = EF2(t) returns E[(max(t*x,0) - t/sqrt(2*pi))^2] for x ~ N(0,1),
%   computed by adaptive quadrature over the whole real line. For t >= 0,
%   t/sqrt(2*pi) equals E[max(t*x,0)], so this is the variance of the ReLU
%   output max(t*x,0).
%   NOTE(review): assumes scalar t, since t is broadcast against the
%   vector of quadrature nodes.
shifted_relu = @(u) max(t.*u, 0) - t/sqrt(2*pi);
density      = @(u) exp(-u.^2/2)/sqrt(2*pi);
y = integral(@(u) shifted_relu(u).^2 .* density(u), -Inf, +Inf);
end