%% Test: PAGD & ANCGD on non-quadratic triangle function
% Evaluate the objective on a square grid; X, Y and Z1 are used further
% down to draw the contour lines of the landscape.
n = 512;
n1 = n+1;
a = 6/n1;                          % grid spacing
xaxis = (-1.5+a):a:(1.5-a);        % sample points strictly inside (-1.5, 1.5)
yaxis = xaxis;                     % same sample points along the y direction
[X,Y] = meshgrid(xaxis,yaxis);
% Array-valued form of the objective implemented by the local function f.
func1 = @(x,y) cos(pi*x)/2 + (y + cos(2*pi*x)/2 - 1/2).^2/2 - 1/2;
Z1 = func1(X,Y);
%%
% Sampling experiment: run agd and anc from M random starting points and
% record the final iterate and its objective value, both after the full
% iteration budget (T1 / T2) and after half of it (suffix "h").
eta = .01; % descent step size
r = 0.1; % radius of perturbation ball
M = 300; % # of samples
T1 = 80; % iteration budget passed to agd
T2 = 20; % iteration budget passed to anc

% Results after the full budget: objective value, x-coordinate, y-coordinate.
outcome_pagd = zeros(1,M);
outcome_ancgd = zeros(1,M);
outcome_pagdx = zeros(1,M);
outcome_ancgdx = zeros(1,M);
outcome_pagdy = zeros(1,M);
outcome_ancgdy = zeros(1,M);
% Results after half of the budget.
outcome_pagdh = zeros(1,M);
outcome_ancgdh = zeros(1,M);
outcome_pagdxh = zeros(1,M);
outcome_ancgdxh = zeros(1,M);
outcome_pagdyh = zeros(1,M);
outcome_ancgdyh = zeros(1,M);

% Starting points, one sample per row (row t is used by iteration t below).
% NOTE(review): randsphere is not defined in this file; presumably it draws
% M points in the 2-D ball of radius r -- confirm against its source.
seed_pagd = randsphere(M,2,r);
seed_ancgd = randsphere(M,2,r);

tic
for t = 1:M
    % PAGD: accelerated gradient descent from the t-th starting point
    x0 = seed_pagd(t,:)';
    xT = agd(x0,eta,T1);
    outcome_pagdx(t) = xT(1);
    outcome_pagdy(t) = xT(2);
    outcome_pagd(t) = f(xT);
    xT = agd(x0,eta,T1/2);
    outcome_pagdxh(t) = xT(1);
    outcome_pagdyh(t) = xT(2);
    outcome_pagdh(t) = f(xT);

    % ANCGD: anc iteration from its own t-th starting point
    x0 = seed_ancgd(t,:)';
    xT = anc(x0,eta,T2,r);
    outcome_ancgdx(t) = xT(1);
    outcome_ancgdy(t) = xT(2);
    outcome_ancgd(t) = f(xT);
    xT = anc(x0,eta,T2/2,r);
    outcome_ancgdxh(t) = xT(1);
    outcome_ancgdyh(t) = xT(2);
    outcome_ancgdh(t) = f(xT);
end
tEnd=toc;
% toc returns fractional seconds; a %d conversion would fall back to
% exponential notation, so print with a fixed-point format instead.
fprintf('Test cosine_decrease, running time = %.4f\n',tEnd);

%% plot
% Figure with three panels: sampled end points after half of the iteration
% budget, after the full budget, and a histogram of the final objective
% values. Titles are derived from T1/T2 so they follow the experiment setup.
figure(1)
set(gcf, 'units','points','position',[0 0 1000 400]);

% Panel 1: end points after T1/2 (agd) and T2/2 (anc) iterations.
subplot(1,3,1)
scatter(outcome_pagdxh,outcome_pagdyh,50,'.','r');
hold on
scatter(outcome_ancgdxh,outcome_ancgdyh,50,'.','b');
contour(X,Y,Z1, [-0.8,0,0.6,1.5],'ShowText','on');
plot([1 -1],[0 0],'k*'); % markers at (1,0) and (-1,0)
set(gca,...
    'TickDir','out','TickLength',[0.02 0.02],...
    'FontSize',12,'FontName','Times');
xticks([-1.5,-1,-0.5,0,0.5,1,1.5]);
yticks([-1.5,-1,-0.5,0,0.5,1,1.5]);
xlabel('X')
ylabel('Y')
legend('PAGD-sampling','ANCGD-sampling','landscape','minimizer')
title(sprintf('t_{PAGD} = %d,   t_{ANCGD} = %d',T1/2,T2/2));

% Panel 2: end points after the full T1 (agd) and T2 (anc) iterations.
subplot(1,3,2)
scatter(outcome_pagdx,outcome_pagdy,50,'.','r');
hold on
scatter(outcome_ancgdx,outcome_ancgdy,50,'.','b');
contour(X,Y,Z1, [-0.8,0,0.6,1.5],'ShowText','on');
plot([1 -1],[0 0],'k*'); % markers at (1,0) and (-1,0)
set(gca,...
    'TickDir','out','TickLength',[0.02 0.02],...
    'FontSize',12,'FontName','Times');
xticks([-1.5,-1,-0.5,0,0.5,1,1.5]);
yticks([-1.5,-1,-0.5,0,0.5,1,1.5]);
xlabel('X')
ylabel('Y')
legend('PAGD-sampling','ANCGD-sampling','landscape','minimizer')
title(sprintf('t_{PAGD} = %d,   t_{ANCGD} = %d',T1,T2));

% Panel 3: grouped histogram of the objective values after the full budget.
subplot(1,3,3)
binRange = -1:0.1:-0.6;
% Bin edges are binRange followed by Inf, i.e. one bin per entry of binRange.
hcx = histcounts(outcome_ancgd(:),[binRange Inf]);
hcy = histcounts(outcome_pagd(:),[binRange Inf]);
% One row per bin, one column per method: b(1) is ANCGD, b(2) is PAGD.
b = bar(binRange,[hcx;hcy].');
b(1).FaceColor = [0 .3 .9];
b(2).FaceColor = [.9 .2 .1];
% Print each count just above its bar.
xtips1 = b(1).XEndPoints;
ytips1 = b(1).YEndPoints;
labels1 = string(b(1).YData);
text(xtips1,ytips1,labels1,'HorizontalAlignment','center',...
    'VerticalAlignment','bottom')
xtips2 = b(2).XEndPoints;
ytips2 = b(2).YEndPoints;
labels2 = string(b(2).YData);
text(xtips2,ytips2,labels2,'HorizontalAlignment','center',...
    'VerticalAlignment','bottom')
set(gca,...
    'XDir','reverse',...
    'FontSize',12,'FontName','Times');
xticks([-1 -0.9 -0.8 -0.7 -0.6]);
% histcounts bins contain their left edge and exclude their right edge
% (only the last bin is closed), so label the intervals accordingly.
xticklabels({'[-1,-0.9)','[-0.9,-0.8)','[-0.8,-0.7)','[-0.7,-0.6)','>=-0.6'});
xlabel('Descent value')
ylabel('Frequency')
legend('ANCGD','PAGD','Location','northwest');
title(sprintf('descent value sampling at t_{PAGD} = %d, t_{ANCGD} = %d',T1,T2))

%%


function x = agd(x0,eta,T)
% AGD  Gradient descent with a momentum term and fixed step size.
%   x = AGD(x0,eta,T) runs T iterations starting from x0 and returns the
%   last iterate. Each iteration extrapolates along the previous
%   displacement, scaled by (1-theta), and then takes a gradient step of
%   size eta at the extrapolated point, using the local function grad.
%
%   x0   starting point
%   eta  step size
%   T    number of iterations; for T < 1 the starting point is returned
v = 0;          % previous displacement (no momentum before the first step)
theta = 0.3;    % momentum is weighted by (1-theta)
x = x0;         % guarantees the output is assigned when the loop is skipped
for k = 1:T
    y = x + (1-theta)*v;        % extrapolated point
    xNext = y - eta * grad(y);  % gradient step from the extrapolated point
    v = xNext - x;
    x = xNext;
end
end

function x = anc(x0,eta,T,r)
% ANC  Momentum gradient iteration whose iterate is rescaled to norm r.
%   x = ANC(x0,eta,T,r) repeats T times: add a gradient step of size eta
%   (via the local function grad) to the momentum, move the current point
%   by the momentum, then multiply both the new point and the momentum by
%   r/norm(new point). The returned x is the last point before rescaling,
%   divided by its norm.
%
%   T must be at least 1; otherwise x is never assigned.
momentum = 0;
current = x0;
for iter = 1:T
    momentum = momentum - eta.*grad(current);
    x = current + momentum;
    len = norm(x);
    momentum = momentum * r / len;  % scale momentum by the same factor
    current = x * r / len;          % pull the point back to norm r
end
x = x / norm(x);
end

function y = f(x)
% F  Objective value at the point x, where x(1) and x(2) are its coordinates.
%   Sum of a cosine term in x(1) and half the squared residual
%   x(2) + cos(2*pi*x(1))/2 - 1/2, shifted by -1/2.
cosTerm = cos(pi*x(1))/2;
residual = x(2)+cos(2*pi*x(1))/2-1/2;
y = cosTerm+residual^2/2-1/2;
end

function v = grad(x)
% GRAD  Gradient of the objective f at x, returned as a 2-by-1 column.
%   Both components share the residual x(2) + cos(2*pi*x(1))/2 - 1/2,
%   which is the partial derivative with respect to x(2).
residual = x(2)+cos(2*pi*x(1))/2-1/2;
dfdx = -pi*sin(pi*x(1))/2-pi*sin(2*pi*x(1))*residual;
v = [dfdx ; residual];
end

