function [D,X,loss_vec,rec_vec,sample_vec] = SPIDER_init(A_train,Dstar,param)
% SPIDER_INIT  Stochastic conditional-gradient (Frank-Wolfe) with a SPIDER
% variance-reduced gradient estimator on the lower-level problem
%     min_{D,X}  0.5*||A_train - D*X||_F^2
% Each step moves toward the vertex returned by the linear minimisation
% oracles below: every column of D toward a unit-l2-norm vector, and every
% column of X toward a 1-sparse vector with l1 norm delta.
%
% Phase 1 updates D and X jointly for param.maxiter iterations. Phase 2 then
% keeps X fixed and runs param.maxiter further iterations on D only.
%
% Inputs
%   A_train : m x n_train dataset of the lower level
%   Dstar   : true dictionary, used only to report the recovery rate
%   param   : struct with fields
%       p       - number of atoms
%       delta   - radius of the per-column l1-norm constraint on X
%       maxiter - number of iterations in EACH phase
%       thres   - correlation threshold passed to recovery()
%       eta1    - (optional) step size of phase 1, default 5e-3
%       eta2    - (optional) step size of phase 2, default 5e-4
%       eps, eps2 - tolerances of the disabled FW-gap stopping rules (not read)
%
% Outputs
%   D, X       : final dictionary (m x p) and codes (p x n_train)
%   loss_vec   : training loss; initial value, then one entry per iteration
%                of phase 1 and of phase 2 (length 2*maxiter+1)
%   rec_vec    : recovery rate recorded at the same points as loss_vec
%   sample_vec : sample counter recorded at the same points (16 per iteration;
%                NOTE(review): meaning of the constant 16 kept from original)
[m,n_train] = size(A_train);
p = param.p;
delta = param.delta;
maxiter = param.maxiter;
thres = param.thres;
eta1 = 5e-3;
if isfield(param,'eta1'), eta1 = param.eta1; end
eta2 = 5e-4;
if isfield(param,'eta2'), eta2 = param.eta2; end

% Random start with unit-norm atoms and zero codes.
D = randn(m,p);
D = D./vecnorm(D);
X = zeros(p,n_train);
loss_train = @(D,X) norm(A_train-D*X,'fro')^2/2;
gD_train = @(D,X) (D*X-A_train)*X';
gX_train = @(D,X) D'*(D*X-A_train);

% Mini-batch size, which is also the period of the exact-gradient refresh.
s_train = ceil(sqrt(n_train));
% Rescales a mini-batch sum so it estimates the sum over all n_train samples.
scale = n_train/s_train;

loss_vec1 = zeros(maxiter+1,1);
rec_vec1 = zeros(maxiter+1,1);
sample_vec1 = zeros(maxiter+1,1);
loss_vec1(1) = loss_train(D,X);
rec_vec1(1) = recovery(D,Dstar,thres);

% ---- Phase 1: joint Frank-Wolfe on (D,X) ----
D_prev = D;
X_prev = X;
for iter = 1:maxiter
    if mod(iter-1,s_train)==0
        % Exact gradient restarts the recursive estimator.
        stogD = gD_train(D,X);
        stogX = gX_train(D,X);
    else
        idx = randperm(n_train,s_train);
        [dD,dX] = minibatch_grad_diff(A_train,D,X,D_prev,X_prev,idx,scale);
        stogD = stogD + dD;
        stogX = stogX + dX;
    end

    D_atom = lmo_D(stogD);
    X_atom = lmo_X(stogX,delta);

    % The SPIDER recursion needs the previous iterate at the next step.
    D_prev = D;
    X_prev = X;
    D = (1-eta1)*D + eta1*D_atom;
    X = (1-eta1)*X + eta1*X_atom;

    loss_vec1(iter+1) = loss_train(D,X);
    rec_vec1(iter+1) = recovery(D,Dstar,thres);
    sample_vec1(iter+1) = iter*16;
end

% ---- Phase 2: X fixed, Frank-Wolfe on D only ----
loss_vec2 = zeros(maxiter,1);
rec_vec2 = zeros(maxiter,1);
sample_vec2 = zeros(maxiter,1);
D_prev = D;
for iter = 1:maxiter
    if mod(iter-1,s_train)==0
        stogD = gD_train(D,X);
    else
        idx = randperm(n_train,s_train);
        % X does not change in this phase, so the previous X is X itself.
        stogD = stogD + minibatch_grad_diff(A_train,D,X,D_prev,X,idx,scale);
    end

    D_atom = lmo_D(stogD);
    % Refresh the previous iterate; otherwise the estimator drifts.
    D_prev = D;
    D = (1-eta2)*D + eta2*D_atom;

    loss_vec2(iter) = loss_train(D,X);
    rec_vec2(iter) = recovery(D,Dstar,thres);
    sample_vec2(iter) = 16*(maxiter+iter);
end

loss_vec = [loss_vec1;loss_vec2];
rec_vec = [rec_vec1;rec_vec2];
sample_vec = [sample_vec1;sample_vec2];
end

function [dD,dX] = minibatch_grad_diff(A,D,X,D_prev,X_prev,idx,scale)
% Scaled mini-batch gradient differences for the SPIDER recursion, over the
% sampled columns idx:
%   dD = scale*(gD_idx(D,X) - gD_idx(D_prev,X_prev))
%   dX = scale*(gX_idx(D,X) - gX_idx(D_prev,X_prev)), embedded in a zero
%        matrix the size of X (only columns idx are nonzero).
% dX is computed only when requested.
A_b = A(:,idx);
Xb = X(:,idx);
Xb_prev = X_prev(:,idx);
R = D*Xb - A_b;
R_prev = D_prev*Xb_prev - A_b;
dD = scale*(R*Xb' - R_prev*Xb_prev');
if nargout > 1
    dX = zeros(size(X));
    dX(:,idx) = scale*(D'*R - D_prev'*R_prev);
end
end

function S = lmo_D(G)
% Frank-Wolfe vertex for D: column k is -G(:,k)/||G(:,k)||_2.
% Adding eps avoids 0/0 on zero columns.
S = -G./(vecnorm(G)+eps);
end

function S = lmo_X(G,delta)
% Frank-Wolfe vertex for X: in each column, -delta*sign(G) at the entry of
% largest |G| and zero elsewhere. Column-wise even when G has a single row.
[p,n] = size(G);
[~,r] = max(abs(G),[],1);
S = zeros(p,n);
lin = sub2ind([p,n], r, 1:n);
S(lin) = -delta*sign(G(lin));
end

function rec = recovery(D,Dstar,thres)
% RECOVERY  Fraction of reference atoms matched by the learned dictionary.
%   The columns of D are first scaled to unit l2 norm. The absolute inner
%   products with the columns of Dstar are then reduced with MAX. For matrix
%   inputs this is the best match per column of Dstar. The function returns
%   the number of those maxima that exceed thres, divided by the number of
%   columns of Dstar.
Dn = D ./ vecnorm(D);
best = max(abs(Dn' * Dstar));
rec = nnz(best > thres) / size(Dstar, 2);
end