% context_generalization

function [] = teacher_gd_ndependence_decrease(i,seed)
% TEACHER_GD_NDEPENDENCE_DECREASE  Gradient-descent training of a gated
% network on a noisy ReLU-teacher task, compared with a kernel predictor.
%
%   i     index into N_all; selects the hidden width N of the trained network
%   seed  RNG seed used only for the initial weights of the trained network
%
% The inputs, gates and teacher are generated from fixed seeds (1 and 3), so
% runs that differ only in `seed` share the same data set.  The function
% returns nothing; it writes bias, fvar_th, f_t, bias_th, N_all, y_t, fp and H
% to
%   ../example_h_generalization_gd/decrease_example_H_ind<ind>_seed<seed>_N<N>_P<P>.mat
%
% Requires numeric_saddlepoint_nonlinear on the path.

beta = 1e4;          % 1/beta is added to the diagonal of the train kernel K2

eta = 0.01;          % weight of per-sample noise vs. cluster centre in the inputs
N_0 = 200;           % input dimension
M = 20;              % number of gating units (rows of u)
N_1 = 1000;          % hidden width of the teacher
n = 20;              % number of cluster centres
m = 10;              % number of input blocks, each of size N_0/m

P = 1000;            % number of training samples
N_s = 1000;          % number of test samples
iteration = 5e6;     % maximum number of gradient-descent steps
lr = 1e-2;           % gradient-descent step size

N_all = [20:20:80,100:100:3000];
N = N_all(i);

ind = 1;

% ---- training inputs: every block of every sample is a noisy cluster centre
rng(1,'twister');
center = normrnd(0,1,N_0/m,n);
sigma = 0.1;         % std of the initial weights; also scales the kernels below
index = randi(n,1,P*m);
x_0 = sqrt(1-eta)*center(:,index)+sqrt(eta)*normrnd(0,1,N_0/m,P*m);
x_0 = reshape(x_0,N_0,P);

% ---- test inputs: same centres, fresh cluster assignment and noise
rng(3,'twister');
index_t = randi(n,1,N_s*m);
x_t = sqrt(1-eta)*center(:,index_t)+sqrt(eta)*normrnd(0,1,N_0/m,N_s*m);
x_t = reshape(x_t,N_0,N_s);

% ---- random gates: the same u0 is applied to each input block (block-diagonal u)
u0 = normrnd(0,1,M/m,N_0/m);
u = kron(eye(m,m),u0);

g = (1/sqrt(N_0)*u*x_0>0);    % M x P binary gates on the training set
gp = (1/sqrt(N_0)*u*x_t>0);   % M x N_s binary gates on the test set

% ---- labels from a noisy ReLU teacher (noise is added to training labels only)
% The three pieces are kept as separate draws so the random stream is consumed
% in the same order as before.
w0 = [normrnd(0,1,N_1,N_0/m*(ind-1)),normrnd(0,1,N_1,N_0/m),normrnd(0,1,N_1,N_0-N_0/m*ind)];
a0 = normrnd(0,1,1,N_1);
w0_scale = mean(sqrt(sum(w0'.^2)));   % mean row norm of w0
y = 1/sqrt(N_1)*a0*((w0*x_0/w0_scale).*(w0*x_0>0)) + 0.03*normrnd(0,1,1,P);
y_t = 1/sqrt(N_1)*a0*((w0*x_t/w0_scale).*(w0*x_t>0));

% First output is unused; it used to be bound to a variable called `error`,
% which shadowed the builtin.
% H is used below as g'*H*g, so it is presumably M x M -- confirm against
% numeric_saddlepoint_nonlinear.
[~,H] = numeric_saddlepoint_nonlinear(N_0,P,M,g,y,x_0,N,sigma,beta);

% ---- kernel predictor built from H
K_1 = sigma^2/N_0*(x_0'*x_0);     % train/train input kernel
K_1p = sigma^2/N_0*(x_t'*x_0);    % test/train input kernel
k_1p = sigma^2/N_0*(x_t'*x_t);    % test/test input kernel

K_ff = (g'*H*g).*K_1;
K2 = K_ff + 1/beta*eye(P,P);
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;

K2_inv = pinv(K2);                % computed once and reused for all three products
fp = K2p*K2_inv*y';               % predictor mean on the test inputs
ft = K_ff*K2_inv*y';              % predictor mean on the training inputs
fvar_th = mean(diag(k2p - K2p*K2_inv*K2p'));   % mean predictor variance on the test set

% Stopping threshold for the training loop.
% NOTE(review): ft lives on the training inputs but is compared with the test
% labels y_t (this only runs because P == N_s), and gt is a mean while the
% training error it is compared against is a sum.  Kept as is because changing
% it would change when training stops -- confirm the intended criterion.
gt = mean((ft-y_t').^2);

% ---- gradient descent
rng(seed,'twister');
w = normrnd(0, sigma, N_0+M, N);
w_1 = w(1:N_0, :);                % N_0 x N first-layer weights
w_2 = w(N_0+1:end, :)';           % N x M second-layer weights

gd = double(g);
gpd = double(gp);
out_scale = 1/sqrt(N*M);

for j0 = 1:iteration

    x_1 = 1/sqrt(N_0)*w_1'*x_0;                  % N x P hidden activations
    % Column sums give diag(g'*w_2'*x_1)' without forming the P x P product.
    f = out_scale*sum((w_2*gd).*x_1, 1);
    resid = y - f;
    train_err = sum(resid.^2);

    if mod(j0,1e3) == 1
        disp(['train' num2str(train_err)]);
    end

    % Stop before updating: the saved test prediction is the one produced by
    % the weights that generated this training error.
    if train_err < gt || j0 == iteration
        break
    end

    % bsxfun scales column k by resid(k), replacing the P x P diag(resid).
    w_2 = w_2 + lr*out_scale*(bsxfun(@times, x_1, resid)*gd');
    % NOTE(review): this step uses the already-updated w_2 (sequential rather
    % than simultaneous update); order preserved from the original code.
    w_1 = w_1 + lr/sqrt(N*M*N_0)*(x_0*bsxfun(@times, w_2*gd, resid)');

end

% ---- test-set evaluation of the trained network and of the kernel predictor
x_1t = 1/sqrt(N_0)*w_1'*x_t;
f_t = out_scale*sum((w_2*gpd).*x_1t, 1);
bias = mean((f_t - y_t).^2);
bias_th = mean((fp' - y_t).^2);

% Create the output folder if needed so a long run is not lost at the save step.
out_dir = '../example_h_generalization_gd';
if exist(out_dir,'dir') ~= 7
    mkdir(out_dir);
end
save([out_dir,'/decrease_example_H_ind',num2str(ind),'_seed',num2str(seed),'_N',num2str(N),'_P',num2str(P),'.mat'],'bias','fvar_th','f_t','bias_th','N_all','y_t','fp','H');

end