%% Escape from saddle points on quantum computers
% This script runs the 4 numerical experiments presented in the paper
% "Escape from Saddle Points on Quantum Computers"
% submitted to NeurIPS 2020 with paper ID 2851. Do not distribute.
% The computing environment is MATLAB 2019a. 

% In tests 0 and 1, we use the finite difference method (FDM) and the 
% leapfrog scheme (LFS) to solve the time-dependent Schrodinger equation. 
% The integrator 'leapfrogsolver' (available on MathWorks) was developed 
% by Mauger François. 


% In test 2, we use a sampler ('randsphere') to generate random points
% in an n-dim hypersphere. 
% It was developed by Roger Stafford, also available on MathWorks.

% The initialization PDE parameters are the same as follows.

%% Test 2: PGD & NCGD on non-quadratic function
% Tabulate the objective on a uniform grid strictly inside [-3,3]^2.
% X, Y and Z1 are read again by the contour plots further down.
n  = 512;                          % nominal number of grid nodes per axis
n1 = n + 1;                        % number of grid intervals per axis
a  = 6 / n1;                       % grid spacing
xaxis = (-3 + a) : a : (3 - a);    % end points -3 and 3 are excluded
yaxis = xaxis;                     % same nodes on both axes
[X, Y] = meshgrid(xaxis, yaxis);
% Element-wise form of the objective that the local function f (end of
% file) evaluates for a single point.
func1 = @(x, y) x.^3 ./ 2 - y.^3 ./ 2 - 3 .* x .* y + (x.^2 + y.^2).^2 ./ 2;
Z1 = func1(X, Y);
%%
% Hyper-parameters shared by both methods.
eta = .02;   % descent step size
r   = 0.01;  % radius of perturbation ball
M   = 300;   % # of samples

% Iteration budgets: T1 is passed to sgd and T2 to nc in the sampling loop.
T1 = 60;
T2 = 30;

% Per-sample result buffers, each 1-by-M.
%   *x / *y : coordinates of the returned point
%   (none)  : objective value f at the returned point
%   trailing 'h' : same quantity for the run with half the iteration budget
[outcome_psgd,   outcome_ncgd, ...
 outcome_psgdx,  outcome_ncgdx, ...
 outcome_psgdy,  outcome_ncgdy, ...
 outcome_psgdh,  outcome_ncgdh, ...
 outcome_psgdxh, outcome_ncgdxh, ...
 outcome_psgdyh, outcome_ncgdyh] = deal(zeros(1,M));

% sample
% NOTE(review): seed_psgd / seed_ncgd are not referenced anywhere else in
% this file; every run below starts from [0,0]. Confirm whether these
% samples were meant to be used as starting points.
seed_psgd = randsphere(M,2,r);
seed_ncgd = randsphere(M,2,r);

% PGD & qPGD
% Run M independent trials of both methods from the origin (grad vanishes
% there). For each trial, record the returned point and its objective
% value for the full iteration budget (T1 / T2) and, from a separate
% fresh run, for half of that budget (suffix 'h').
tic
x0 = [0,0];  % common starting point; sgd/nc receive it by value
for t = 1:M
    % psgd for type 0: full budget, then an independent half-budget run
    xT = sgd(x0,eta,T1,r,5);
    outcome_psgdx(t) = xT(1);
    outcome_psgdy(t) = xT(2);
    outcome_psgd(t) = f(xT);
    xT = sgd(x0,eta,T1/2,r,5);
    outcome_psgdxh(t) = xT(1);
    outcome_psgdyh(t) = xT(2);
    outcome_psgdh(t) = f(xT);

    % ncgd for type 0: full budget, then an independent half-budget run
    xT = nc(x0,eta,T2,r,5);
    outcome_ncgdx(t) = xT(1);
    outcome_ncgdy(t) = xT(2);
    outcome_ncgd(t) = f(xT);
    xT = nc(x0,eta,T2/2,r,5);
    outcome_ncgdxh(t) = xT(1);
    outcome_ncgdyh(t) = xT(2);
    outcome_ncgdh(t) = f(xT);
end
tEnd=toc;
% toc returns fractional seconds. With '%d' fprintf falls back to
% exponential notation for non-integer values, so use a floating-point
% conversion instead.
fprintf('Test quartic_decrease, running time = %.2f s\n',tEnd);

%% plot 
figure(1)

% Panels 1 and 2: end points of both methods on top of the contour lines
% of the objective. Each row holds: PSGD x, PSGD y, SNCGD x, SNCGD y, title.
% NOTE(review): the iteration counts and batch size in the titles are
% hard-coded text; they are not derived from T1, T2 or the value passed
% to sgd/nc, so they must be edited by hand if those change.
contourLevels = [-1.4,-0.9,0,2,7,15];
panels = { ...
    outcome_psgdxh, outcome_psgdyh, outcome_ncgdxh, outcome_ncgdyh, ...
        't_{PSGD} = 30, t_{SNCGD} = 15,  batchsize = 5'; ...
    outcome_psgdx,  outcome_psgdy,  outcome_ncgdx,  outcome_ncgdy, ...
        't_{PSGD} = 60, t_{SNCGD} = 30, batchsize = 5'};
for p = 1:size(panels,1)
    subplot(1,3,p)
    scatter(panels{p,1},panels{p,2},50,'.','r');
    hold on
    scatter(panels{p,3},panels{p,4},50,'.','b');
    contour(X,Y,Z1,contourLevels,'ShowText','on');
    set(gca,...
        'TickDir','out','TickLength',[0.02 0.02],...
        'FontSize',12,'FontName','Times');
    set(gcf, 'units','points','position',[0 0 1000 400]);
    xticks([-2,-1,0,1,2]);
    yticks([-2,-1,0,1,2]);
    xlabel('X')
    ylabel('Y')
    % Exactly three graphics objects exist on these axes (two scatters and
    % one contour), so exactly three labels are supplied; a fourth label
    % would be dropped with an "extra legend entries" warning.
    legend('PSGD-sampling','SNCGD-sampling','landscape')
    title(panels{p,5});
end

% Panel 3: grouped bar chart of the objective values reached with the
% full iteration budget.
subplot(1,3,3)
binRange = -1.5:0.3:0;
% Bin edges are binRange followed by Inf, so the last bin is open-ended.
hcx = histcounts(outcome_ncgd(:),[binRange Inf]);
hcy = histcounts(outcome_psgd(:),[binRange Inf]);
b = bar(binRange,[hcx;hcy]);
b(1).FaceColor = [0 .3 .9];   % series built from hcx (SNCGD counts)
b(2).FaceColor = [.9 .2 .1];  % series built from hcy (PSGD counts)
% Print each count above its bar.
% NOTE(review): XEndPoints / YEndPoints appear to need a MATLAB release
% newer than the 2019a mentioned in the file header -- confirm.
xtips1 = b(1).XEndPoints;
ytips1 = b(1).YEndPoints;
labels1 = string(b(1).YData);
text(xtips1,ytips1,labels1,'HorizontalAlignment','center',...
    'VerticalAlignment','bottom')
xtips2 = b(2).XEndPoints;
ytips2 = b(2).YEndPoints;
labels2 = string(b(2).YData);
text(xtips2,ytips2,labels2,'HorizontalAlignment','center',...
    'VerticalAlignment','bottom')
set(gca,...
    'XDir','reverse',...
    'FontSize',12,'FontName','Times');
xticks([-1.5 -1.2 -0.9 -0.6 -0.3 0]);
xticklabels({'(-1.5,-1.2]','(-1.2,-0.9]','(-0.9,-0.6]','(-0.6,-0.3]','(-0.3,0]','>0'});
xlabel('Descent value')
ylabel('Frequency')
legend('SNCGD','PSGD','Location','northwest');
title('descent value sampling at t_{PSGD} = 60, t_{SNCGD} = 30')



%%
function x = sgd(x0,eta,T,r,b)
% SGD  Noisy gradient descent with a Gaussian perturbation after each step.
%
%   x = sgd(x0,eta,T,r,b) performs T iterations starting from x0. Each
%   iteration
%     1. draws a noisy gradient from a normal distribution with mean
%        grad(x) and covariance (norm(x)^2/(2*b^2)) * I,
%     2. takes a step of size eta against that noisy gradient,
%     3. replaces the result by a normal draw centred on it with
%        covariance (r^2/2) * I.
%
%   Inputs
%     x0  - starting point (the callers in this file pass a 1-by-2 row)
%     eta - step size
%     T   - number of iterations; for T < 1 the loop does not run and x0
%           is returned unchanged
%     r   - scale of the per-step perturbation
%     b   - divisor of the gradient-noise standard deviation (the plot
%           titles in this file call it "batchsize")
%
%   Output
%     x   - final iterate

x = x0;                          % guarantees a defined output when T < 1
Sigma0 = diag([r^2/2, r^2/2]);   % perturbation covariance, loop-invariant
for k = 1:T
    % Gradient noise scales with the distance of the iterate from the origin.
    gradVar = norm(x)^2/(2*b^2);
    sgrad = mvnrnd(grad(x),diag([gradVar, gradVar]));
    % Descent step followed by the random perturbation.
    x = mvnrnd(x - eta.*sgrad,Sigma0);
end
end

function x = nc(x0,eta,T,r,b)
% NC  Noisy gradient iteration re-normalised to radius r after each step.
%
%   x = nc(x0,eta,T,r,b) runs T iterations. Each iteration draws a noisy
%   gradient (mean grad(x0), covariance (r^2/(2*b^2)) * I), takes a step
%   of size eta against it, perturbs the result with a normal draw of
%   covariance (r^2/(2*l^2)) * I, updates the running factor l, and
%   rescales the iterate to norm r. After the loop the iterate is rescaled
%   to norm 1.3, so only its direction is carried into the output.
%
%   Inputs
%     x0  - starting point (the callers in this file pass a 1-by-2 row)
%     eta - step size
%     T   - number of iterations (x is not assigned when T < 1)
%     r   - radius the iterate is reset to after each step; also sets the
%           noise scales
%     b   - divisor of the gradient-noise standard deviation
%
%   Output
%     x   - point of norm 1.3 along the final iterate's direction
%
%   NOTE(review): l is multiplied by norm(x)/norm(x0). The callers in this
%   file start from x0 = [0,0], so the first update divides by zero; l
%   then becomes Inf and the perturbation variance r^2/(2*l^2) is zero on
%   all later iterations. Confirm that this is intended.

l = 1;                                        % running growth factor
gradCov = diag([r^2/(2*b^2), r^2/(2*b^2)]);   % gradient-noise covariance
for k = 1:T
    noisyGrad = mvnrnd(grad(x0),gradCov);
    stepped = x0 - eta.*noisyGrad;
    pertVar = r^2/(2*l^2);                    % shrinks as l grows
    x = mvnrnd(stepped,diag([pertVar, pertVar]),1);
    l = l * norm(x)/norm(x0);                 % accumulate the norm growth
    x = x * r / norm(x);                      % back onto the radius-r circle
    x0 = x;
end
x = 1.3 * x / norm(x);                        % fixed output radius of 1.3
end

function y = f(x)
% F  Objective value at a single point.
%   y = f(x) reads the two coordinates x(1) and x(2) and returns
%   x1^3/2 - x2^3/2 - 3*x1*x2 + (x1^2 + x2^2)^2/2.
x1 = x(1);
x2 = x(2);
cubicPart = x1^3/2 - x2^3/2;
y = cubicPart - 3*x1*x2 + (x1^2 + x2^2)^2/2;
end

function v = grad(x)
% GRAD  Gradient of the objective f at a single point.
%   v = grad(x) reads x(1) and x(2) and returns the 2-by-1 column
%   [df/dx1; df/dx2].
x1 = x(1);
x2 = x(2);
rsq2 = 2*(x1^2 + x2^2);   % factor shared by both components
v = [3*x1^2/2 - 3*x2 + rsq2*x1 ; -3*x2^2/2 - 3*x1 + rsq2*x2];
end


