function func = data_quadratic_on_simplex(n)
% DATA_QUADRATIC_ON_SIMPLEX  Random convex quadratic over the probability simplex.
%
%   func = data_quadratic_on_simplex(n) builds the composite problem
%       min_x F(x) = f(x) + g(x),
%       f(x) = 0.5 * <x - x^*, Q (x - x^*)>,   g = \chi_{\Delta_n},
%   where \chi_{\Delta_n} is the indicator of the probability simplex
%   {x >= 0, sum(x) = 1} and x^* = ones(n,1)/n, so F(x^*) = f_true = 0.
%
%   The reference (mirror) function is
%       h(x) = 0.5 * <x, D x>,   D = diag(diag(Q)),
%   which induces the same Bregman distance 0.5*<x - y, D (x - y)> as the
%   centered form 0.5 * <x - x^*, D (x - x^*)>.
%
%   Q = A0' * A0, where A0 = randn(n) with column j scaled by j. The global
%   random number generator is reseeded with rng(2): Q is reproducible for
%   a given n, but the caller's RNG state is changed.
%
%   Fields of the returned struct (as assigned below):
%     n, f, h, gradf, dh, dh_inv   problem size and function handles
%     prox                         Euclidean projection onto the simplex
%     prox_mir                     D-norm projection of dh_inv(y) onto the simplex
%     subprbm                      handle to subproblem_solver
%                                  NOTE(review): not defined in this file;
%                                  presumably resolved on the MATLAB path
%     L, mu                        lambda_max(Q) and 0
%     L_relative, mu_relative      lambda_max(D^{-1/2} Q D^{-1/2}) and 0
%     x_true, f_true               minimizer x^* and optimal value 0
%     mu0, mu0_relative, m0        L/100, L_relative/100, 9
%
%   mu and mu_relative are set to 0 (no strong-convexity constant is
%   supplied), even though Q = A0'*A0 is generically nonsingular.

%% Problem data
rng(2);                          % reseeds the global generator (caller-visible)
A0 = randn(n);
A0 = A0 .* linspace(1, n, n);    % scale column j by j (implicit expansion)
Q = A0' * A0;
% The returned nested-function handles keep this workspace alive, so large
% temporaries are not left in it.
clear A0;

diagQ = diag(Q);                 % D = diag(diagQ)
x_true = ones(n, 1) / n;         % barycenter of the simplex, minimizer of F

L = lambda_max_sym(Q);
mu = 0;
% Relative smoothness constant w.r.t. h: lambda_max(D^{-1/2} Q D^{-1/2}).
% Entry (i,j) is Q(i,j) / sqrt(d_i * d_j), which is exactly symmetric when Q
% is, and no dense n-by-n scaling matrix is stored in the shared workspace.
L_relative = lambda_max_sym(Q ./ sqrt(diagQ .* diagQ'));
mu_relative = 0;

func = struct('n',n,'f',@f,'h',@h,'gradf',@grad,'dh',@dh,'dh_inv',@dh_inv, 'prox',@projection,...
    'subprbm',@subproblem_solver, 'L',L,'mu',mu, 'L_relative', L_relative, 'mu_relative', mu_relative, ...
    'x_true',x_true, 'prox_mir', @projection_D, 'f_true', 0, ...
    'mu0_relative', L_relative/100, 'mu0', L/100, 'm0', 9);

%% Nested helpers
% NOTE: nested functions share the parent workspace, so their local names
% deliberately avoid every variable name used in the parent above.

function lam = lambda_max_sym(M)
% Largest eigenvalue of the symmetric part (M + M')/2 of M.
% Symmetrizing makes eig take its symmetric path and return real
% eigenvalues, even if M carries round-off asymmetry.
    lam = max(eig((M + M') / 2));
end

function x = projection(y, ~)
% Euclidean projection of y onto the probability simplex
%     {x : x >= 0, sum(x) = 1}
% using the standard sort-and-threshold algorithm. The second argument
% (e.g. a step size) is ignored. Row and column inputs are both accepted;
% the output has the same shape as y.
    u = sort(y(:), 'descend');
    cu = cumsum(u);
    k = (1:numel(u))';
    % rho = largest k with u_k + (1 - sum_{i<=k} u_i) / k > 0
    rho = find(u + (1 - cu) ./ k > 0, 1, 'last');
    theta = (cu(rho) - 1) / rho;
    x = max(y - theta, 0);
end

function x = projection_D(y, ~)
% Projection onto the probability simplex in the D-norm, D = diag(diagQ).
%
% Forms z = dh_inv(y) = D^{-1} y and solves
%     min_x 0.5 * ||x - z||_D^2   s.t.   x >= 0, sum(x) = 1.
% The KKT conditions give x(t) = max(0, z + t ./ diagQ), where the scalar
% multiplier t is chosen so that sum(x(t)) = 1. Since sum(x(t)) is
% nondecreasing in t, t is found by bisection. sum(x) equals 1 only up to
% the bisection accuracy. The second argument is ignored. Row and column
% inputs are both accepted; the output has the same shape as y.
    z = dh_inv(y(:));
    mass = @(t) sum(max(0, z + t ./ diagQ));

    % Bracket the multiplier: at lo every component of z + lo./diagQ is
    % <= 0 (sum = 0), and at hi every component of z + hi./diagQ is >= 1
    % (sum >= n >= 1).
    lo = min(diagQ .* (-z));
    hi = max(diagQ .* (1 - z));

    tol = 1e-10;     % absolute tolerance on the multiplier
    max_iter = 100;  % cap: for large |t|, tol can be below the double spacing
    for iter = 1:max_iter
        if hi - lo <= tol
            break;
        end
        t = (lo + hi) / 2;
        if mass(t) > 1
            hi = t;
        else
            lo = t;
        end
    end

    t = (lo + hi) / 2;
    x = reshape(max(0, z + t ./ diagQ), size(y));
end

function v = grad(x)
% Gradient of f: Q * (x - x^*).
    v = Q * (x - x_true);
end

function y = f(x)
% Smooth part f(x) = 0.5 * (x - x^*)' * Q * (x - x^*).
    y = 0.5 * (x - x_true)' * Q * (x - x_true);
end

function y = h(x)
% Reference function h(x) = 0.5 * <x, D x> with D = diag(diagQ).
    y = 0.5 * sum(diagQ .* (x.^2));
end

function v = dh(x)
% Gradient of h: D * x.
    v = diagQ .* x;
end

function v = dh_inv(x)
% Inverse of dh: D^{-1} * x.
    v = (1 ./ diagQ) .* x;
end

end