% example multiple tasks

% gradient descent ucl model
function [] = compute_theory_twotask(i,seed)
%COMPUTE_THEORY_TWOTASK  Kernel-theory prediction vs. gradient-descent run
%   for a gated two-layer network trained on two label sets at once.
%
%   compute_theory_twotask(i, seed)
%
%   Inputs
%     i     Index into nall = [20,50:50:500,800,1000]; selects the hidden
%           width N.
%     seed  Seed passed to rng(seed,'twister') before the trainable weights
%           are drawn; also embedded in the output file name.
%
%   Pipeline (all sizes are the constants set below)
%     1. Load MNIST from mnist.mat (variables trainX, trainY, testX, testY),
%        standardise every sample, push it through a fixed random ReLU layer
%        w0 (N_0 x 784) and standardise again.
%     2. Keep the first P/2 training and N_s/2 test examples of digits 0 and 1.
%     3. Build n input variants by randomly permuting the N_0 feature rows.
%     4. Duplicate the data set twice. The first copy is labelled with the
%        digit (1 -> +1, 0 -> -1), the second with the variant index
%        (first round(n/2) variants -> +1, the rest -> -1). Each copy is
%        combined with its own column of the random gate mask g1.
%     5. Theory: obtain H from numeric_saddlepoint_nonlinear (external), build
%        the gated kernels and compute predictor mean / variance on the test
%        set. The "*gp" quantities use sigma^2/M * I in place of H.
%     6. Simulation: train the network with squared loss until the summed
%        training error drops below 0.1 or `iteration` steps have been taken.
%     7. Save theory and simulation results to
%        ../kernel_decorrelate/kernelstructureseed<seed>N<N>.mat
%
%   Notes
%     * normrnd requires the Statistics and Machine Learning Toolbox.
%     * Two commented-out experiments that used to live in this function
%       (soft-k-means gating and a ReLU-kernel approximation) were removed;
%       recover them from version control if needed.

% ------------------------------------------------------------------ setup
nall = [20,50:50:500,800,1000];
if ~(isnumeric(i) && isscalar(i) && i == round(i) && i >= 1 && i <= numel(nall))
    error('compute_theory_twotask:badIndex', ...
        'i must be an integer between 1 and %d.', numel(nall));
end
N = nall(i);            % hidden-layer width

n = 2;                  % number of permuted input variants
iteration = 1e6;        % maximum number of gradient steps
eta = 0.008;            % learning rate actually used by the training loop
N_0 = 400;              % random-feature (input) dimension
P = 300;                % training examples per variant
N_s = 500;              % test examples per variant
M = 20;                 % number of gating units
sigma = 1;              % prior / initialisation standard deviation
beta = 1e4;             % inverse temperature; 1/beta regularises the kernels

% ------------------------------------------------------------------- data
% Load into a struct so the variables that enter the workspace are explicit.
mnist = load('mnist.mat','trainX','trainY','testX','testY');
trainX = double(mnist.trainX);
testX = double(mnist.testX);
% Force row vectors: the index concatenation and kron calls below need them.
trainY = mnist.trainY(:).';
testY = mnist.testY(:).';

rng(2,'twister');
w0 = normrnd(0,1,N_0,784);

% After the transpose every column is one sample (784 x nSamples), as
% required by the product w0*X inside random_relu_layer.
trainX = standardize_columns(trainX');
testX = standardize_columns(testX');

trainX = standardize_columns(random_relu_layer(w0,trainX));
testX = standardize_columns(random_relu_layer(w0,testX));

indtrain = find(trainY==0);
indtest = find(testY==0);
indtrain1 = find(trainY==1);
indtest1 = find(testY==1);
index = [indtrain(1:P/2),indtrain1(1:P/2)];
index_t = [indtest(1:N_s/2),indtest1(1:N_s/2)];

x_00 = trainX(:,index);
x_t0 = testX(:,index_t);

y = (trainY(index)==1);
y_t = (testY(index_t)==1);
y_1 = ((1:n)<=round(n/2));
yt_1 = ((1:n)<=round(n/2));

% One random permutation of the feature rows per variant. The same
% permutation is applied to the training and the test inputs.
x_0 = zeros(N_0,n*P);
x_t = zeros(N_0,n*N_s);
for k = 1:n
    ind2 = randperm(N_0,N_0);
    perm = zeros(1,N_0);
    perm(ind2) = 1:N_0;             % inverse permutation of ind2
    x_0(:,((k-1)*P+1):(k*P)) = x_00(perm,:);
    x_t(:,((k-1)*N_s+1):(k*N_s)) = x_t0(perm,:);
end

% ----------------------------------------------------------------- gating
rng(1,'twister');
th = 1.5;
g1 = (normrnd(0,1,M,2)<th);         % M x 2 mask, one column per data copy
gp1 = g1;

v = normrnd(0,1,M,N_0);
g = (v*x_0>0);                      % input-dependent gates, training set
gp = (v*x_t>0);                     % input-dependent gates, test set

% ------------------------------------------- duplicate data for two labels
x_0 = kron(ones(1,2),x_0);
x_t = kron(ones(1,2),x_t);

y = 2*kron(ones(1,n),double(y))-1;
y_t = 2*kron(ones(1,n),double(y_t))-1;
y_1 = 2*kron(double(y_1),ones(1,P))-1;
yt_1 = 2*kron(double(yt_1),ones(1,N_s))-1;

y = [y,y_1];
y_t = [y_t,yt_1];

g = kron(ones(1,2),g).*kron(g1,ones(1,n*P));
gp = kron(ones(1,2),gp).*kron(gp1,ones(1,n*N_s));

nTrain = P*n*2;                     % total number of training columns
half_t = n*N_s;                     % test columns per data copy
first = 1:half_t;
second = (half_t+1):(2*half_t);

% ----------------------------------------------------------------- theory
[~,H] = numeric_saddlepoint_nonlinear(N_0,nTrain,M,g,y,x_0,N,sigma,beta);

K_1 = sigma^2/N_0*(x_0'*x_0);       % train/train input kernel
K_1p = sigma^2/N_0*(x_t'*x_0);      % test/train input kernel
k_1p = sigma^2/N_0*(x_t'*x_t);      % test/test input kernel

K2 = (g'*H*g).*K_1 + 1/beta*eye(nTrain,nTrain);
K2gp = (sigma^2*g'*g/M).*K_1 + 1/beta*eye(nTrain,nTrain);

% Predictor with the saddle-point order parameter H. The pseudo-inverse is
% computed once and the product K2p*pinv(K2) is shared between mean and
% variance.
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;
proj = K2p*pinv(K2);
fp = (proj*y')';
fvarp = diag(k2p - proj*K2p');

% Same predictor with H replaced by sigma^2/M * I.
K2pgp = (sigma^2*gp'*g/M).*K_1p;
k2pgp = (sigma^2*gp'*gp/M).*k_1p;
projgp = K2pgp*pinv(K2gp);
fpgp = (projgp*y')';
fvarpgp = diag(k2pgp - projgp*K2pgp');

bias_th = mean((fp-y_t).^2);
ge = bias_th + mean(fvarp);
bias1 = mean((fp(first)-y_t(first)).^2);
bias2 = mean((fp(second)-y_t(second)).^2);
fvar1 = mean(fvarp(first));
fvar2 = mean(fvarp(second));

bias1gp = mean((fpgp(first)-y_t(first)).^2);
bias2gp = mean((fpgp(second)-y_t(second)).^2);
fvar1gp = mean(fvarpgp(first));
fvar2gp = mean(fvarpgp(second));

% ------------------------------------------------- gradient-descent run
rng(seed,'twister');
w = normrnd(0, sigma, N_0+M, N);
w_1 = w(1:N_0,:);                   % N_0 x N first-layer weights
w_2 = w(N_0+1:end,:)';              % N x M gated readout weights

outScale = 1/sqrt(N*M);
gStep = eta*1/sqrt(N*M)*g;          % loop-invariant part of the w_2 update
c1 = eta/sqrt(N*M*N_0);
converged = false;

for j0 = 1:iteration
    x_1 = 1/sqrt(N_0)*w_1'*x_0;     % N x nTrain hidden activations

    % f(j) = outScale * sum_m g(m,j) * (w_2'*x_1)(m,j). This equals
    % diag(g'*w_2'*x_1) without forming the nTrain x nTrain product.
    f = outScale*sum(g.*(w_2'*x_1),1);
    resid = y - f;
    E = sum(resid.^2);

    % Test predictions are only saved, never fed back into training, so
    % they are evaluated on the last pass only. They use the weights
    % *before* this pass's update.
    if E < 0.1 || j0 == iteration
        x_1t = 1/sqrt(N_0)*w_1'*x_t;
        f_t = outScale*sum(gp.*(w_2'*x_1t),1);
        bias = mean((f_t - y_t).^2);
    end

    % A.*resid scales column j by resid(j), i.e. A*diag(resid).
    w_2 = w_2 + ((gStep.*resid)*x_1')';
    % NOTE: the first-layer step uses the already-updated w_2; this order
    % is kept deliberately so the training dynamics are unchanged.
    w_1 = w_1 + (((c1*(w_2*g)).*resid)*x_0')';

    if mod(j0,1e1) == 1
        disp(['train',num2str(E)]);
    end
    if E < 0.1
        converged = true;
        break
    end
end

if ~converged
    warning('compute_theory_twotask:notConverged', ...
        'Training error %g did not drop below 0.1 within %d iterations; saving the final state.', ...
        E, iteration);
end

% Empirical kernel from the final readout weights. It is computed after the
% loop so that k2 exists even when the error threshold was never reached
% (otherwise save would fail on the missing variable).
h = w_2'*w_2/N;
k2 = (g'*h*g).*K_1 + 1/beta*eye(nTrain,nTrain);

% ------------------------------------------------------------------- save
outFile = ['../kernel_decorrelate/kernelstructure','seed',num2str(seed),'N',num2str(N),'.mat'];
outDir = fileparts(outFile);
if ~exist(outDir,'dir')
    [~,~] = mkdir(outDir);          % outputs captured to silence "already exists"
end
save(outFile,'f_t','y_t','bias','bias_th','bias1gp','bias2gp','fvar1gp','fvar2gp','bias1','bias2','fvar1','fvar2','ge','fp','fvarp','K2','K2gp','k2')

end

function A = standardize_columns(A)
% Subtract the mean and divide by the standard deviation of every column.
A = (A-mean(A))./std(A);
end

function F = random_relu_layer(w0,X)
% ReLU of the random projection w0*X, scaled by 1/sqrt(number of rows of w0).
z = w0*X;
F = (z>0).*(1/sqrt(size(w0,1))*z);
end