% context_generalization

function [] = teacher_gd_ndependence_sample(i,seed)
% TEACHER_GD_NDEPENDENCE_SAMPLE  Noisy gradient training of a gated two-layer
% model on data from a noisy ReLU teacher, compared against two kernel
% predictors, for one hidden-layer width N.
%
%   teacher_gd_ndependence_sample(i, seed)
%
%   i    - index into N_all = [20:20:80,100:100:3000]; selects the width N.
%   seed - seed passed to rng(seed,'twister') before the weights are drawn
%          and the noisy updates are run. The data, gates and teacher use
%          the fixed seeds 1 and 3 and therefore do not depend on `seed`.
%
%   Nothing is returned. The results are written to
%   ../example_h_generalization/increase_example_H_ind<ind>_seed<seed>_N<N>_P<P>.mat
%   (the directory is created if it does not exist).
%
%   Requires numeric_saddlepoint_nonlinear on the MATLAB path.

beta = 1e4;             % enters the 1/beta ridge, the decay term and the noise scale

eta = 0.01;             % mixing weight between cluster centre and per-sample noise
N_0 = 100;              % input dimension
M = 50;                 % number of gating units
N_1 = 1000;             % number of teacher hidden units
n = 20;                 % number of cluster centres
m = 5;                  % number of blocks the input is split into

P = 200;                % number of training examples
N_s = 1000;             % number of test examples
iteration = 6e6;        % total number of update steps
start = 4e6;            % steps discarded before test statistics are accumulated

% alpha_all = [0.2:0.2:2,3:10,15,20];
% alpha = alpha_all(i);
% N = round(P/alpha);
N_all = [20:20:80,100:100:3000];
N = N_all(i);

ind = 1;                % which input block the teacher depends on at full strength

% Make sure the output location exists before the long run starts, so that
% the final save cannot fail on a missing directory.
out_dir = '../example_h_generalization';
if exist(out_dir,'dir') ~= 7
    mkdir(out_dir);
end
out_file = [out_dir,'/increase_example_H_ind',num2str(ind),'_seed',num2str(seed),'_N',num2str(N),'_P',num2str(P),'.mat'];

% ---- training inputs -------------------------------------------------------
rng(1,'twister');
center = normrnd(0,1,N_0/m,n);
sigma = 0.1;
% clustered statistics of x_0: every block of every example is a randomly
% chosen centre plus noise; m consecutive blocks are stacked into one input
index = randi(n,1,P*m);
x_0 = sqrt(1-eta)*center(:,index)+sqrt(eta)*normrnd(0,1,N_0/m,P*m);
x_0 = reshape(x_0,N_0,P);

% ---- test inputs (same centres, different seed) ----------------------------
rng(3,'twister');
index_t= randi(n,1,N_s*m);
x_t = sqrt(1-eta)*center(:,index_t)+sqrt(eta)*normrnd(0,1,N_0/m,N_s*m);
x_t = reshape(x_t,N_0,N_s);

% ---- random gates: block-diagonal projection followed by a threshold at 0 --
u0 = normrnd(0,1,M/m,N_0/m);
u = kron(eye(m,m),u0);

g = (1/sqrt(N_0)*u*x_0>0);      % M x P logical, training gates
gp = (1/sqrt(N_0)*u*x_t>0);     % M x N_s logical, test gates

% ---- targets: noisy relu teacher, different dependence on different parts of x
% Columns belonging to block `ind` have unit scale, all others are scaled by 0.01.
w0 = [0.01*normrnd(0,1,N_1,N_0/m*(ind-1)),normrnd(0,1,N_1,N_0/m),0.01*normrnd(0,1,N_1,N_0-N_0/m*ind)];
a0 = normrnd(0,1,1,N_1);
w0_norm = mean(sqrt(sum(w0'.^2)));   % mean Euclidean norm of the rows of w0
y = 1/sqrt(N_1)*a0*((w0*x_0/w0_norm).*(w0*x_0>0)) + 0.01*normrnd(0,1,1,P);
y_t = 1/sqrt(N_1)*a0*((w0*x_t/w0_norm).*(w0*x_t>0)) + 0.01*normrnd(0,1,1,N_s);

% ---- theory: kernel predictor built from the matrix H ----------------------
% The first output is not used here; it is named sp_error so that it does not
% shadow the builtin error().
[sp_error,H] = numeric_saddlepoint_nonlinear(N_0,P,M,g,y,x_0,N,sigma,beta); %#ok<ASGLU>

K_1 = sigma^2/N_0*(x_0'*x_0);    % train-train input kernel
K_1p = sigma^2/N_0*(x_t'*x_0);   % test-train input kernel
k_1p = sigma^2/N_0*(x_t'*x_t);   % test-test input kernel

K2 = (g'*H*g).*K_1 + 1/beta*eye(P,P);
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;
fp =  K2p*pinv(K2)*y';
fvar_th = mean(diag(k2p - K2p*pinv(K2)*K2p'));   % scalar: mean predicted variance

% ---- reference predictor: same construction with H replaced by sigma^2/M ---
K2gp = (sigma^2*g'*g/M).*K_1 + 1/beta*eye(P,P);
K2pgp = (sigma^2*gp'*g/M).*K_1p;
k2pgp = (sigma^2*gp'*gp/M).*k_1p;
fpgp =  K2pgp*pinv(K2gp)*y';
fvar_gp = diag(k2pgp - K2pgp*pinv(K2gp)*K2pgp');

ge = sum((fp-y_t').^2)/N_s + fvar_th;
gegp = sum((fpgp-y_t').^2)/N_s + mean(fvar_gp);

% Independent of the training loop, so computed once.
bias_th = mean((fp' - y_t).^2);

% ---- noisy gradient training ----------------------------------------------
rng(seed,'twister');
w = normrnd(0, 1, N_0+M, N);
w_1 = w(1:N_0, :);          % N_0 x N
w_2 = w(N_0+1:end, :)';     % N x M
f_average = zeros(1,N_s);
f_var = zeros(1,N_s);
lr = 1e-4;                  % step size (separate from the data mixing weight eta)
for j0 = 1:iteration

    collecting = j0 > start;

    % Training outputs. sum((w_2*g).*x_1,1) equals diag(g'*w_2'*x_1)' without
    % forming the P x P product (results may differ at rounding level).
    x_1 = 1/sqrt(N_0)*w_1'*x_0;
    f = 1/sqrt(N*M)*sum((w_2*g).*x_1,1);
    residual = y - f;
    train_err = sum(residual.^2);

    % Test outputs are only needed once statistics are being accumulated.
    % They must be evaluated here, before the weights are updated.
    if collecting
        x_1t = 1/sqrt(N_0)*w_1'*x_t;
        f_t = 1/sqrt(N*M)*sum((w_2*gp).*x_1t,1);
    end

    % Each update is: gradient step + decay term + Gaussian noise of scale
    % sqrt(2*lr/beta). w_2 is updated first and the w_1 step uses the new w_2.
    % The order and shapes of the normrnd calls fix the random stream.
    w_2 = w_2 + (lr*1/sqrt(N*M)*g*diag(residual)*x_1')' - lr/(beta*sigma^2)*w_2 + sqrt(2*lr/(beta))*normrnd(0,1,M,N)';
    w_1 = w_1 + (lr/sqrt(N*M*N_0)*(w_2*g)*diag(residual)*x_0')' - lr/(beta*sigma^2)*w_1 + sqrt(2*lr/beta)*normrnd(0,1,N_0,N);

    if collecting
        k = j0 - start;     % number of samples accumulated, including this one
        f_old = f_average;
        f_average = (f_average*(k-1)+f_t)/k;                  % running mean
        % recursive running variance estimate of the test outputs
        f_var = k/(k+1)*f_var+(k-1)/(k+1)*(f_average-f_old).^2+1/(k+1)*(f_t-f_average).^2;

        if mod(j0,1e3) == 1
            bias = mean((f_average - y_t).^2);
            disp(['bias' num2str(bias),'theory',num2str(bias_th)]);
            disp(['train' num2str(train_err)]);
            disp(['variance' num2str(mean(f_var)),'theory',num2str(fvar_th)]);
        end
    end
end

% Squared error of the time-averaged test output after the last step.
bias = mean((f_average - y_t).^2);

save(out_file,'fvar_th','f_var','bias','bias_th','N_all','ge','gegp','y_t','fp','fpgp','fvar_gp','H');

end