function [info] = hspg_fp_mixed_l1l2_solver(pname,fun,lambda,parms)
%=========================================================================
% Purpose : Half-Space-Based Proximal Stochastic Gradient (HSPG) solver for
%           minimizing a smooth loss plus mixed l1/l2 regularization:
%               F(x) = f(x) + lambda * Omega(x),
%           where Omega(x) = sum_g ||x_g||_2 over the groups produced by
%           utils.split_groups (evaluated via utils.mixed_l1_l2_norm; its
%           proximal operator is applied by prox_mapping_group).
%
%           Every epoch is either a Prox-SG epoch (proximal stochastic
%           gradient steps over shuffled mini-batches) or a Half-Space
%           epoch (a stochastic step computed for the currently nonzero
%           groups, followed by a half-space projection that may zero
%           whole groups). The first epoch is Prox-SG; afterwards the
%           solver alternates blocks of num_prox_sg Prox-SG epochs and
%           num_proj_sg Half-Space epochs (each block runs at least one
%           epoch). The step size alpha is multiplied by gamma after
%           every epoch.
%=========================================================================
% Input:
%   pname   problem name (accepted for interface compatibility; unused)
%   fun     loss object providing func(x, idx), grad(x, idx),
%           num_features and num_samples (e.g. LogRegCost, see LogRegCost.m)
%   lambda  weight of the regularizer: f(x) + lambda * sum_g ||x_g||_2
%   parms   structure of control parameters with fields max_epoch,
%           batch_size, num_groups, epislon, alpha, gamma, num_prox_sg and
%           num_proj_sg (see function fPlusL1_spec.m)
%
% Output:
%   info    structure with fields
%             Fval, fval, Omegaval  F(x), f(x) and Omega(x), averaged over
%                                   the last two per-epoch evaluations
%             nnz                   nonzero counts of those evaluations
%             density, sparsity     mean(nnz)/n and 1 - density
%             group_sparsity        utils.compute_group_sparsity value,
%                                   averaged the same way
%             idx_zero_groups       second output of
%                                   utils.compute_group_sparsity at the
%                                   final iterate
%             status                0 (maximum epoch reached)
%             runtime               wall-clock run time in seconds
%=========================================================================

% Get control parameters.
max_epoch     = parms.max_epoch;
batch_size    = parms.batch_size;
num_groups    = parms.num_groups;
epislon       = parms.epislon;      % half-space threshold (field is spelled 'epislon')

% Size of the problem (number of optimization variables and of samples).
n             = fun.num_features;
m             = fun.num_samples;
batch_size    = floor(max(batch_size, 1));
group_indexes = utils.split_groups(n, num_groups);

% Initial estimate of a solution.
x             = zeros(n,1);

indexes       = 1:m;
epoch         = 1;
alpha         = parms.alpha;        % step size, multiplied by gamma each epoch
gamma         = parms.gamma;
alg_start_time = tic;

% Per-epoch history; only the two most recent evaluations are kept.
info.Fval     = [];
info.fval     = [];
info.Omegaval = [];
info.nnz      = [];
info.group_sparsity = [];
num_prox_sg   = parms.num_prox_sg;
num_proj_sg   = parms.num_proj_sg;

% Phase state: from epoch 1 on, exactly one of do_prox_sg / do_proj_sg is true.
do_prox_sg      = false;
do_proj_sg      = false;
prox_sg_tracker = 0;   % epochs run in the current Prox-SG block
proj_sg_tracker = 0;   % epochs run in the current Half-Space block

%========================
% Begin: main while loop.
%========================
while true
    % Full-data objective, regularizer and gradient at the current iterate
    % (used for monitoring and for the reported statistics).
    f = fun.func(x, indexes);
    norm_l1l2_x = utils.mixed_l1_l2_norm(x, group_indexes);
    F = f + lambda * norm_l1l2_x;
    grad_f = fun.grad(x, indexes);
    num_nonzeros = sum(x ~= 0);   % named so it does not shadow the built-in nnz
    [g_sparsity, idx_zero_groups] = utils.compute_group_sparsity(x, group_indexes);
    fprintf('Epoch: %d, F value: %.4f, f value: %.4f, ||x_g||_2: %f, norm of grad_f: %.4f, nnz: %d out of %d, g_sparsity: %f.\n', epoch, F, f, norm_l1l2_x, norm(grad_f), num_nonzeros, n, g_sparsity);

    info.Fval           = append_keep_last_two(info.Fval, F);
    info.fval           = append_keep_last_two(info.fval, f);
    info.Omegaval       = append_keep_last_two(info.Omegaval, norm_l1l2_x);
    info.nnz            = append_keep_last_two(info.nnz, num_nonzeros);
    info.group_sparsity = append_keep_last_two(info.group_sparsity, g_sparsity);

    % Termination condition: report the mean of the last two evaluations.
    if epoch > max_epoch
        info.status = 0;
        info.Fval = mean(info.Fval);
        info.fval = mean(info.fval);
        info.Omegaval = mean(info.Omegaval);
        info.density = mean(info.nnz) / n;
        info.sparsity = 1.0 - info.density;
        info.group_sparsity = mean(info.group_sparsity);
        info.idx_zero_groups = idx_zero_groups;
        alg_end_time = toc(alg_start_time);
        info.runtime = alg_end_time;
        fprintf('Maximum epoch has been reached. Run time %f\n',alg_end_time);
        break
    end

    % Phase switch: start with Prox-SG, then alternate the two phases once
    % the current phase has run its allotted number of epochs.
    if epoch == 1
        do_prox_sg = true;
        prox_sg_tracker = 0;
        proj_sg_tracker = 0;
    elseif do_prox_sg && prox_sg_tracker >= num_prox_sg
        do_prox_sg = false;
        do_proj_sg = true;
        prox_sg_tracker = 0;
        proj_sg_tracker = 0;
    elseif do_proj_sg && proj_sg_tracker >= num_proj_sg
        do_prox_sg = true;
        do_proj_sg = false;
        prox_sg_tracker = 0;
        proj_sg_tracker = 0;
    end

    if do_prox_sg
        prox_sg_tracker = prox_sg_tracker + 1;

        shuffled_indexes = utils.indexes_shuffle(indexes);
        num_shuffled = length(shuffled_indexes);
        % ceil so that the trailing partial mini-batch is also processed.
        num_batches = ceil(num_shuffled / batch_size);

        for i = 1 : num_batches
            start_idx = 1 + (i-1) * batch_size;
            end_idx = min(i * batch_size, num_shuffled);
            minibatch_idxes = shuffled_indexes(start_idx:end_idx);

            % Mini-batch gradient followed by the group proximal step.
            grad_f = fun.grad(x, minibatch_idxes);
            x = prox_mapping_group(x, grad_f, lambda, alpha, group_indexes);
        end
    elseif do_proj_sg
        proj_sg_tracker = proj_sg_tracker + 1;

        shuffled_indexes = utils.indexes_shuffle(indexes);
        num_shuffled = length(shuffled_indexes);
        % ceil so that the trailing partial mini-batch is also processed.
        num_batches = ceil(num_shuffled / batch_size);

        % Exploit Half-Space Step:
        for i = 1 : num_batches
            start_idx = 1 + (i-1) * batch_size;
            end_idx = min(i * batch_size, num_shuffled);
            minibatch_idxes = shuffled_indexes(start_idx:end_idx);

            % Groups that are currently nonzero ("free to move").
            idxes_group_free = nonzero_group_indexes(x, group_indexes);

            grad_f = fun.grad(x, minibatch_idxes);

            % Trial step; presumably restricted to the free groups by the
            % subsolver -- see subprobsolvers.gradientdescent.
            hat_x = subprobsolvers.gradientdescent(x, grad_f, alpha, lambda, idxes_group_free, group_indexes);

            % Zero every free group whose trial value leaves the half-space.
            x = project(hat_x, x, idxes_group_free, group_indexes, epislon);
        end
    end

    alpha = alpha * gamma;
    epoch = epoch + 1;
end
%========================
% END: main while loop.
%========================

end

%% Append value to a history row vector and keep only its last two entries.
function history = append_keep_last_two(history, value)
history = [history value];
history = history(max(1, end-1):end);
end

%% Row vector of the indexes of groups whose entries in x are not all zero.
function idxes = nonzero_group_indexes(x, group_indexes)
is_free = false(1, numel(group_indexes));
for j = 1 : numel(group_indexes)
    is_free(j) = norm(x(group_indexes{j})) > 0;
end
idxes = find(is_free);
end

%% Proximal Mapping for mixed l1/l2
function new_x = prox_mapping_group(x, grad_f, lambda, alpha, group_indexes)
% PROX_MAPPING_GROUP  One proximal stochastic gradient step for the mixed
% l1/l2 regularizer lambda * sum_g ||x_g||_2.
%
%   Takes the gradient step z = x - alpha*grad_f and applies exact group
%   soft thresholding to every group g in group_indexes:
%       new_x_g = max(0, 1 - alpha*lambda/||z_g||) * z_g,
%   so group g is set to zero exactly when ||z_g|| <= alpha*lambda.
%   Coordinates that belong to no group are unregularized and therefore
%   keep the plain gradient step z.
%
% Input:
%   x, grad_f      current iterate and (stochastic) gradient, same size
%   lambda         regularization weight
%   alpha          step size
%   group_indexes  cell array; each cell holds the indexes of one group
% Output:
%   new_x          updated iterate, same size as x

threshold = alpha * lambda;
trial_x = x - alpha * grad_f;
new_x = trial_x;
for i = 1 : numel(group_indexes)
    group = group_indexes{i};
    % Norms are taken from trial_x (not new_x) so overlapping groups are
    % all thresholded against the same gradient step.
    group_norm = norm(trial_x(group));
    if group_norm == 0 || group_norm <= threshold
        new_x(group) = 0;
    else
        new_x(group) = (1 - threshold / group_norm) * trial_x(group);
    end
end

end

%% Project trial_x based on the half-space
function proj_x = project(trial_x, x, idxes_group_free, group_indexes, epislon)
% PROJECT  Half-space projection used by the HSPG Half-Space step.
%
%   For every group g listed in idxes_group_free, group g of trial_x is
%   set to zero when it falls outside the half-space
%       { z : z_g' * x_g >= epislon * ||x_g||^2 },
%   and is kept unchanged otherwise. With 0 <= epislon < 1 the current
%   iterate x always lies inside this set, and the test is invariant to
%   rescaling of x_g. Groups not listed in idxes_group_free are copied
%   from trial_x unchanged.
%   NOTE(review): zero groups are not forced to zero here; this relies on
%   the caller's trial step leaving them at zero -- confirm.
%
% Input:
%   trial_x           trial point (same size as x)
%   x                 current iterate defining the half-space
%   idxes_group_free  indexes into group_indexes of the groups to test
%   group_indexes     cell array; each cell holds the indexes of one group
%   epislon           half-space threshold
% Output:
%   proj_x            projected point

proj_x = trial_x;
for i = 1 : numel(idxes_group_free)
    group = group_indexes{idxes_group_free(i)};
    x_g = x(group);
    % norm(x_g)^2 works for both row and column orientations of x.
    if dot(trial_x(group), x_g) < epislon * norm(x_g)^2
        proj_x(group) = 0;
    end
end

end

