function [info] = rdagl_fp_mixed_l1l2_solver(pname,fun,lambda,parms)
%=========================================================================
% Purpose : Regularized dual averaging (RDA) for the group lasso, i.e. for
%           minimizing a logistic loss plus a mixed L1/L2 (group) penalty.
%=========================================================================
% Input:
%   pname   problem name (not used inside this function)
%   fun     object of type LogRegCost (see function LogRegCost.m)
%   lambda  weighting parameter in the objective function
%               F(x) = f(x) + lambda * sum_{g in G} ||x_g||_2
%   parms   structure of control parameters (see function proxsg_fp_mixed_l1l2_spec.m);
%           fields read here: max_epoch, rda_gamma, batch_size, num_groups
%
% Output:
%   info    structure holding information parameters:
%             Fval, fval, Omegaval  mean of F, f and the mixed L1/L2 norm over
%                                   the last (up to) two full-data evaluations
%             nnz                   nonzero counts of x for those evaluations
%             density, sparsity     mean(nnz)/n and 1 - density
%             group_sparsity        mean group sparsity for those evaluations
%             status                0 when the maximum epoch is reached
%             runtime               elapsed time in seconds (tic/toc)
%
% Notes:
%   The dual average bar_g and the iteration counter t persist across
%   epochs. bar_g is therefore the running mean of every mini-batch
%   gradient seen so far, and the RDA step scales with sqrt(t) over all
%   iterations rather than restarting at every epoch.
%=========================================================================

% Get control parameters.
% (parms.alpha / parms.gamma are not needed by the RDA update and are not read.)
max_epoch     = parms.max_epoch;
rda_gamma     = parms.rda_gamma;
batch_size    = parms.batch_size;
num_groups    = parms.num_groups;

% Size of the problem (the number of optimization variables)
n             = fun.num_features;
m             = fun.num_samples;

batch_size    = floor(max(batch_size, 1));
group_indexes = utils.split_groups(n-1, num_groups);  % bias term does not take part in grouping

% Initial estimate of a solution along with function values.
x             = zeros(n,1);
indexes       = 1: m;
bar_g         = zeros(n,1);   % running mean of all stochastic gradients so far
t             = 0;            % global RDA iteration counter (not reset per epoch)

epoch         = 1;
info.Fval     = [];
info.fval     = [];
info.Omegaval = [];
info.nnz      = [];
info.group_sparsity = [];
alg_start_time = tic;
%========================
% Begin: main while loop.
%========================
while(1)
    
    % Check progress on the full data set.
    fun.setExpterm(x, indexes);
    fun.setSigmoid();
    f = fun.func(indexes); 
    norm_l1l2_x = utils.mixed_l1_l2_norm(x, group_indexes);
    F = f + lambda * norm_l1l2_x;
    grad_f = fun.grad(indexes);
    nnz_x  = sum(x~=0);
    g_sparsity = utils.compute_group_sparsity(x, group_indexes);
    fprintf('Epoch: %d, F value: %.4f, f value: %.4f, ||x_g||_2: %f, norm of grad_f: %.4f, nnz: %d out of %d, g_sparsity: %f.\n', epoch, F, f, norm_l1l2_x, norm(grad_f), nnz_x, n, g_sparsity);    
 
    % Keep only the two most recent evaluations; they are averaged on exit.
    if epoch <= 2
        info.Fval = [info.Fval F];
        info.fval = [info.fval f];
        info.Omegaval = [info.Omegaval norm_l1l2_x];
        info.nnz  = [info.nnz nnz_x];
        info.group_sparsity = [info.group_sparsity g_sparsity];
    else
        info.Fval = [info.Fval(2:2) F];
        info.fval = [info.fval(2:2) f];
        info.Omegaval = [info.Omegaval(2:2) norm_l1l2_x];
        info.nnz  = [info.nnz(2:2) nnz_x];
        info.group_sparsity = [info.group_sparsity(2:2) g_sparsity];
    end
    
    % Termination condition:
    if epoch > max_epoch
        info.status = 0;
        info.Fval = mean(info.Fval);
        info.fval = mean(info.fval);
        info.Omegaval = mean(info.Omegaval);
        info.density = mean(info.nnz) / n;
        info.sparsity = 1.0 - info.density;
        info.group_sparsity = mean(info.group_sparsity);
        alg_end_time = toc(alg_start_time);
        info.runtime = alg_end_time;
        fprintf('Maximum epoch has been reached. Run time %f\n',alg_end_time);
        break
    end
    
    % Shuffle the sample indexes for this epoch's mini-batches.
    shuffled_indexes = utils.indexes_shuffle(indexes);
    num_samples_ep   = length(shuffled_indexes);
    num_batches      = ceil(num_samples_ep / batch_size);
     
    for i = 1 : num_batches
        start_idx = 1 + (i-1) * batch_size;
        end_idx = min( i * batch_size, num_samples_ep );
        minibatch_idxes = shuffled_indexes(start_idx:end_idx);
        
        % Stochastic gradient on the mini-batch.
        fun.setExpterm(x, minibatch_idxes);
        fun.setSigmoid();
        grad_f = fun.grad(minibatch_idxes);
        
        % RDA update: incremental mean over ALL iterations so far (t is
        % global). Using the per-epoch index i here would zero out bar_g's
        % history at the start of every epoch.
        t = t + 1;
        bar_g = (t-1)/t*bar_g + 1/t*grad_f;
        x = rdagl_update(t, bar_g, lambda, rda_gamma, group_indexes);
    end
    
    epoch = epoch + 1;
    
end
%========================
% END: main while loop.
%========================


end


function x = rdagl_update(i, bar_g, lambda, rda_gamma, group_indexes)
%
% RDA group lasso update (closed form).
%
% For each group g, with step = -sqrt(i)/rda_gamma:
%   x_g = 0                                         if ||bar_g_g||_2 <= lambda
%   x_g = step * (1 - lambda/||bar_g_g||_2) * bar_g_g   otherwise
% The last entry (bias) is never shrunk: x(end) = -bar_g(end)*sqrt(i)/rda_gamma.
%
% Input:
%   i              iteration counter used for the sqrt(i) step scaling
%   bar_g          averaged gradient (vector)
%   lambda         group lasso weight
%   rda_gamma      RDA regularization constant
%   group_indexes  cell array of index vectors, one per group
% Output:
%   x              new iterate, column vector with length(bar_g) entries
%
n_vars = length(bar_g);
x      = zeros(n_vars, 1);
step   = -sqrt(i)/rda_gamma;   % scale factor shared by every group

for k = 1 : length(group_indexes)
    idx   = group_indexes{k};
    g_blk = bar_g(idx);
    g_nrm = norm(g_blk);
    if g_nrm <= lambda
        % The whole group is thresholded to zero.
        x(idx) = 0;
    else
        x(idx) = step*(1 - lambda/g_nrm)*g_blk;
    end
end

% The bias (last entry) is excluded from shrinkage. It is written last, so it
% overrides any group that happens to contain it.
x(end) = (-bar_g(end)*sqrt(i))/rda_gamma;
end

