function [info] = proxsg_fp_mixed_l1l2_solver(pname,fun,lambda,parms)
%=========================================================================
% Purpose : Proximal stochastic gradient method with SVRG-style variance
%           reduction for minimizing
%               F(x) = f(x) + lambda * sum_{g in G} ||x_g||_2
%           (smooth loss f plus mixed L1/L2 group regularizer).
%=========================================================================
% Input:
%   pname   problem name; accepted for interface compatibility but not
%           used inside this function
%   fun     object of type LogRegCost (see function LogRegCost.m); this
%           solver reads fun.num_features / fun.num_samples and calls
%           fun.setExpterm, fun.setSigmoid, fun.func and fun.grad
%   lambda  weighting parameter in objective function: f(x) +
%           lambda sum_{g in G}||x_g||_2
%   parms   structure of control parameters (see function
%           proxsg_fp_mixed_l1l2_spec.m); fields used here:
%             max_epoch   number of epochs to run
%             gamma       multiplicative factor applied to alpha per epoch
%             alpha       initial step size
%             batch_size  minibatch size (clamped to an integer >= 1)
%             num_groups  number of variable groups
%
% Output:
%   info    structure holding information parameters:
%             status          0 when the maximum epoch has been reached
%             Fval, fval, Omegaval, group_sparsity
%                             mean over the last (at most two) recorded
%                             epochs of F, f, the group norm and the
%                             group sparsity
%             nnz             nonzero counts of the last (at most two)
%                             recorded epochs
%             density         mean(info.nnz) / n
%             sparsity        1 - density
%             runtime         elapsed wall-clock time in seconds
%=========================================================================

% Get control parameters.
max_epoch     = parms.max_epoch;
gamma         = parms.gamma;
alpha         = parms.alpha;
batch_size    = parms.batch_size;
num_groups    = parms.num_groups;

% Size of the problem (n optimization variables, m samples).
n             = fun.num_features;
m             = fun.num_samples;

batch_size    = floor(max(batch_size, 1));
group_indexes = utils.split_groups(n, num_groups);

% Initial estimate of a solution (also the first snapshot point).
hat_x         = zeros(n,1);
indexes       = 1: m;

epoch         = 1;
info.Fval     = [];
info.fval     = [];
info.Omegaval = [];
info.nnz      = [];
info.group_sparsity = [];
alg_start_time = tic;
%========================
% Begin: main while loop.
%========================
while(1)

    % Check progress at the snapshot point hat_x. The full gradient hat_v
    % computed here is reused by every minibatch step of this epoch.
    fun.setExpterm(hat_x, indexes);
    fun.setSigmoid();
    f = fun.func(indexes);
    norm_l1l2_x = utils.mixed_l1_l2_norm(hat_x, group_indexes);
    F = f + lambda * norm_l1l2_x;
    hat_v = fun.grad(indexes);
    % Named num_nonzero (not nnz) to avoid shadowing the MATLAB builtin.
    num_nonzero = sum(hat_x~=0);
    g_sparsity = utils.compute_group_sparsity(hat_x, group_indexes);
    fprintf('Epoch: %d, F value: %.4f, f value: %.4f, ||x_g||_2: %f, norm of grad_f: %.4f, nnz: %d out of %d, g_sparsity: %f.\n', epoch, F, f, norm_l1l2_x, norm(hat_v), num_nonzero, n, g_sparsity);

    % Keep a sliding window holding the statistics of the last two epochs.
    if epoch <= 2
        info.Fval = [info.Fval F];
        info.fval = [info.fval f];
        info.Omegaval = [info.Omegaval norm_l1l2_x];
        info.nnz  = [info.nnz num_nonzero];
        info.group_sparsity = [info.group_sparsity g_sparsity];
    else
        info.Fval = [info.Fval(2) F];
        info.fval = [info.fval(2) f];
        info.Omegaval = [info.Omegaval(2) norm_l1l2_x];
        info.nnz  = [info.nnz(2) num_nonzero];
        info.group_sparsity = [info.group_sparsity(2) g_sparsity];
    end

    % Termination condition: report averages over the sliding window.
    if epoch > max_epoch
        info.status = 0;
        info.Fval = mean(info.Fval);
        info.fval = mean(info.fval);
        info.Omegaval = mean(info.Omegaval);
        info.density = mean(info.nnz) / n;
        info.sparsity = 1.0 - info.density;
        info.group_sparsity = mean(info.group_sparsity);
        alg_end_time = toc(alg_start_time);
        info.runtime = alg_end_time;
        fprintf('Maximum epoch has been reached. Run time %f\n',alg_end_time);
        break
    end

    % Shuffle the sample indexes for this epoch.
    shuffled_indexes = utils.indexes_shuffle(indexes);
    num_epoch_samples = length(shuffled_indexes);

    % Number of minibatches. ceil() guarantees that every sample is
    % visited (the last batch may be shorter) and that at least one batch
    % exists whenever there is at least one sample, even if batch_size
    % exceeds the number of samples.
    num_batches = ceil(num_epoch_samples / batch_size);

    % Inner loop: variance-reduced proximal steps. Only the current
    % iterate and the running sum of iterates are stored, since the next
    % snapshot is the average of the iterates.
    x_cur = hat_x;
    x_sum = zeros(n, 1);
    for i = 1 : num_batches
        start_idx = 1 + (i-1) * batch_size;
        end_idx = min( i * batch_size, num_epoch_samples );
        minibatch_idxes = shuffled_indexes(start_idx:end_idx);

        % Minibatch gradient at the current iterate.
        fun.setExpterm(x_cur, minibatch_idxes);
        fun.setSigmoid();
        grad_f = fun.grad(minibatch_idxes);

        % Minibatch gradient at the snapshot point.
        fun.setExpterm(hat_x, minibatch_idxes);
        fun.setSigmoid();
        grad_f_hat = fun.grad(minibatch_idxes);

        % Variance-reduced gradient estimate followed by the group prox.
        v = grad_f - grad_f_hat + hat_v;
        x_cur = prox_mapping_group(x_cur, v, lambda, alpha, group_indexes);
        x_sum = x_sum + x_cur;
    end
    alpha = alpha * gamma;
    epoch = epoch + 1;

    % New snapshot: average of the inner iterates. When there were no
    % batches (no samples) the snapshot is left unchanged instead of
    % becoming NaN from an average over an empty set.
    if num_batches > 0
        hat_x = x_sum / num_batches;
    end

end
%========================
% END: main while loop.
%========================


end

function new_x = prox_mapping_group(x, grad_f, lambda, alpha, group_indexes)
%
% Group-wise shrinkage step for Omega(x) = sum_g ||x_g||_2.
%
% Takes a gradient step x - alpha * grad_f and then scales every group of
% the result by max(0, 1 - alpha*lambda / (||group||_2 + 1e-6)).
%
%   x              current point (vector)
%   grad_f         gradient (estimate) used for the step, same size as x
%   lambda         regularization weight
%   alpha          step size
%   group_indexes  cell array; entry i holds the indexes of group i
%
%   new_x          shrunk point, same size as x
%

threshold   = alpha * lambda;
stepped     = x - alpha * grad_f;
new_x       = zeros(size(x));
group_count = length(group_indexes);

for k = 1 : group_count
    members = group_indexes{k};
    block   = stepped(members);
    % The small constant keeps the division finite for an all-zero block.
    scale   = max(0.0, 1.0 - threshold/(norm(block)+1e-6));
    new_x(members) = scale * block;
end

end

function d = calculate_d(x, grad_f, lambda, alpha)
%
% Calculate d for Omega(x) = ||x||_1
%
% Applies element-wise shrinkage by alpha*lambda to the gradient step
% x - alpha * grad_f and returns the displacement from x to that point.
%
shifted_down = x - alpha * grad_f - alpha * lambda;
shifted_up   = x - alpha * grad_f + alpha * lambda;

% Entries that are neither positive after shifting down nor negative
% after shifting up stay at zero.
shrunk = zeros(size(x));

keep_pos = (shifted_down > 0);
shrunk(keep_pos) = shifted_down(keep_pos);

keep_neg = (shifted_up < 0);
shrunk(keep_neg) = shifted_up(keep_neg);

d = shrunk - x;

end

