clc;
clear;
seed = 1234;
rng(seed); % fixed seed so the random matrix A, and hence every run, is reproducible
%% toy example input data
% Upper-level variable lambda lives on the probability simplex in R^n,
% lower-level variable theta lives in R^m.
x0 = [2;2];                        % target point of the upper-level objective
X = [0 2 0 -4 ; 3 0 -1 0];         % X = [x1,x2,x3,x4], one point per column
[m,n] = size(X);
mu_g = 1;                          % smallest eigenvalue of A'*A (strong convexity of g in theta)
L_g = 2;                           % largest eigenvalue of A'*A (smoothness of g in theta)
% Build A with prescribed singular values sqrt(L_g) ... sqrt(mu_g), so that
% the eigenvalues of A'*A span [mu_g, L_g]. For m = 2 this is exactly
% diag([sqrt(L_g); sqrt(mu_g)]).
A = randn(m,m);
[U,~,V] = svd(A'*A);
A = U*diag(sqrt(linspace(L_g, mu_g, m)))*V';

theta_init = zeros(m,1);
% Frank-Wolfe iterates on lambda must start feasible: use the uniform point
% of the simplex {lambda >= 0, sum(lambda) == 1}.
lambda_init = ones(n,1)/n;
maxiter = 1e2; 

%% function definition
% Naming used by the solvers: x <-> lambda (upper level, simplex-constrained),
% y <-> theta (lower level, unconstrained).

% upper-level objective  f(theta) = ||theta - x0||^2 / 2
fun_f = @(theta) sum_square(theta - x0)/2;
% lower-level objective  g(lambda,theta) = ||A*theta - X*lambda||^2 / 2
fun_g = @(lambda,theta) sum_square(A*theta - X*lambda)/2;

% f depends on theta (y) only, so its gradient w.r.t. lambda (x) is identically zero
grad_f_y = @(theta) theta - x0;
grad_f_x = zeros(n,1);

% gradient of g w.r.t. theta (y)
grad_g_y = @(lambda,theta) A'*(A*theta - X*lambda);

% second-order blocks of g; both are constant because g is quadratic
grad_g_yy = A.'*A;      % d^2 g / d theta^2
grad_g_yx = -(A.'*X);   % d^2 g / (d theta d lambda)
%% Finding optimal solution
% Reference solution computed with CVX. The lower-level problem is replaced by
% its first-order optimality condition A'*(A*theta - X*lambda) == 0, and
% lambda is restricted to the probability simplex.
cvx_begin
    variables theta_s(m,1) lambda_s(n,1)
    minimize( sum_square(theta_s - x0)/2 )
    subject to
        A'*(A*theta_s - (X*lambda_s)) == 0;
        lambda_s >= 0;
        sum(lambda_s) == 1;
cvx_end
% fstar is the baseline for every convergence curve below, so make a failed
% or infeasible solve visible instead of silently plotting against NaN/Inf.
if ~any(strcmp(cvx_status, {'Solved', 'Inaccurate/Solved'}))
    warning('CVX did not solve the reference problem (status: %s); fstar may be unreliable.', cvx_status);
end
fstar = cvx_optval; %optimal value
%% IBCG algorithm
% Step sizes passed to IBCG through the param struct.
alpha = 2/(mu_g + L_g);                        % step for y (theta)
eta = (1-(L_g-mu_g)/(L_g+mu_g))/mu_g;          % step for w; algebraically equal to 2/(L_g+mu_g)
gamma = .65*log(maxiter)/(maxiter);            % Frank-Wolfe step for x (lambda)

param.eta = eta;
param.gam = gamma;
param.alpha = alpha;
param.maxiter = maxiter;

[f_vec1, g_vec1, time_vec1, y1] = IBCG(fun_f, grad_f_y, grad_f_x, ...
    grad_g_y, grad_g_yx, grad_g_yy, param, theta_init, lambda_init);
disp('IBCG Solution Achieved!');

%% SBFW algorithm
% Reuse param and add the eigenvalue bounds of A'*A (the Hessian of g in theta).
param.lg = L_g;
param.mug = mu_g;
param.maxiter = maxiter;

[f_vec3,g_vec3,time_vec3, y3] = SBFW(fun_f, grad_f_y,grad_f_x, grad_g_y, grad_g_yx,...
    grad_g_yy,param,theta_init, lambda_init);
disp('SBFW Solution Achieved!');

%% SBFW with Hessian inverse algorithm
% Same parameters as SBFW; they are re-set so this section can be run on its own.
param.lg = L_g;
param.mug = mu_g;
param.maxiter = maxiter;

[f_vec4,g_vec4,time_vec4, y4] = SBFW_inv(fun_f, grad_f_y,grad_f_x, grad_g_y, grad_g_yx,...
    grad_g_yy,param,theta_init, lambda_init);
disp('SBFW_inv Solution Achieved!');

%% Figures
% trajectory of |f - f^*| per iteration for each method
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',2)
set(gcf,'DefaultLineMarkerSize',5);
% Handles are named h_* so they do not shadow the SBFW / SBFW_inv functions;
% otherwise, rerunning a solver section would index a graphics handle instead of
% calling the solver. Each x-range comes from that history's own length because
% the methods' histories need not have the same number of entries.
h_ibcg = semilogy(1:numel(f_vec1),abs(f_vec1-fstar),'DisplayName','IBCG','Color','blue');
hold on;
h_sbfw = semilogy(1:numel(f_vec3),abs(f_vec3-fstar),'DisplayName','SBFW','Color','red');
h_sbfw_inv = semilogy(1:numel(f_vec4),abs(f_vec4-fstar),'DisplayName','SBFW/Hessian inverse','Color','green');

ylabel('$|f(\lambda_k,\theta_k)- f^*| $')
xlabel('Iteration')
set(gca,'FontSize',18);
legend('Interpreter','latex','Location','southwest')
grid on;
grid minor;
% A logarithmic axis cannot show 0, so only positive tick values are listed.
yticks([1e-10, 1e-6, 1e-3, 1e0]);

pbaspect([1 0.8 1])

%% Figure
% trajectory of theta in the plane, drawn over the polygon through the columns of X
figure;
set(0,'defaulttextinterpreter','latex')
set(gcf,'DefaultLineLinewidth',2)
set(gcf,'DefaultLineMarkerSize',5);
% Close the polygon in a separate variable instead of overwriting the data matrix X.
X_poly = [X, X(:,1)];
plot(X_poly(1,:),X_poly(2,:),'HandleVisibility','off','Color','black');
axis([-4 2.5 -1 3]);
hold on;
% NOTE(review): assumes y1/y3/y4 hold one iterate per row with the two theta
% coordinates in columns 1 and 2 — confirm against IBCG/SBFW/SBFW_inv.
% Handles are named h_* so they do not shadow the SBFW / SBFW_inv functions.
h_ibcg = plot(y1(:,1), y1(:,2),'--*','DisplayName','IBCG','Color','blue');
h_sbfw = plot(y3(:,1), y3(:,2),'-.x','DisplayName', 'SBFW','Color','red');
h_sbfw_inv = plot(y4(:,1), y4(:,2),'DisplayName','SBFW/Hessian inverse','Color','green');

set(gca,'FontSize',18);
legend('Interpreter','latex','Location','northwest')
% reference markers, kept out of the legend
plot(x0(1),x0(2),'^','HandleVisibility','off')
text(x0(1)-0.5,x0(2),'$x_0$','FontSize',20)
plot(theta_init(1),theta_init(2),'o','HandleVisibility','off')
text(theta_init(1)-0.3,theta_init(2)-0.2,'start','FontSize',18)
plot(theta_s(1),theta_s(2),'*','HandleVisibility','off','Color','yellow' )
text(theta_s(1) +0.1,theta_s(2),'optimal','FontSize',18)