% Fresh session: remove all workspace variables and close every figure window.
clearvars
close all

% RGB colour of each optimizer's line and shaded band in every panel below.
c_sgd     = [0.00 0.45 0.75];
c_tos     = [0.85 0.32 0.10];
c_adagrad = [0.93 0.70 0.13];
c_adam    = [0.50 0.18 0.55];
c_adaptos = zeros(1,3);          % black
% Alternative TOS colour (unused): [0.6,0.08,0.2]

% Load and aggregate the runs T = 0..19 of every optimizer.
% Produces the structs ADAPTOS, ADAGRAD, TOS, SGD and ADAM, each with fields
%   tr_acc, te_acc        - pickle_data.tr_acc / .te_acc, one column per run
%   weights               - |p| of every entry of every parameter array,
%                           sorted descending, one column per run
%   groupweights          - row-wise l1 sums sum(abs(P),2) of every parameter
%                           array, sorted descending, one column per run
%   mean_<f>, std_<f>     - row-wise mean / std over runs of the four above
% Each file runs/mnist_<prefix><T>.pkl.mat must contain a variable
% `pickle_data` with fields params (cell array), tr_acc and te_acc.
run_prefixes = { ...
    'runs/mnist_adaptos_1.0_';  ... % -> ADAPTOS
    'runs/mnist_adagrad_0.01_'; ... % -> ADAGRAD
    'runs/mnist_TOS_1.0_';      ... % -> TOS
    'runs/mnist_SGD_0.01_';     ... % -> SGD
    'runs/mnist_ADAM_0.001_'};      % -> ADAM
run_ids = 0:19;

results = cell(1,numel(run_prefixes));
for m = 1:numel(run_prefixes)
    R = struct('tr_acc',[],'te_acc',[],'weights',[],'groupweights',[]);
    for T = run_ids
        % Load into a struct: a bare load() would dump every variable of the
        % file into the workspace and, if `pickle_data` were missing, the
        % previous run's pickle_data would silently be reused.
        file = [run_prefixes{m},num2str(T),'.pkl.mat'];
        S = load(file,'pickle_data');
        if ~isfield(S,'pickle_data')
            error('Variable pickle_data not found in %s', file);
        end
        pickle_data = S.pickle_data;

        groupweights = [];
        weights = [];
        for t = 1:length(pickle_data.params)
            P = pickle_data.params{t};
            % Every second parameter array is transposed before its rows are
            % summed. NOTE(review): presumably this fixes the orientation of
            % arrays exported from Python -- confirm the intended grouping.
            if mod(t,2) == 0
                P = P';
            end
            groupweights = [groupweights; sum(abs(P),2)]; %#ok<AGROW>
            weights = [weights; abs(P(:))];               %#ok<AGROW>
        end

        R.tr_acc(:,end+1) = pickle_data.tr_acc;
        R.te_acc(:,end+1) = pickle_data.te_acc;
        R.weights(:,end+1) = sort(weights,'descend');
        R.groupweights(:,end+1) = sort(groupweights,'descend');
    end

    % Statistics across runs (dimension 2 = runs).
    R.mean_tr_acc = mean(R.tr_acc,2);
    R.mean_te_acc = mean(R.te_acc,2);
    R.mean_weights = mean(R.weights,2);
    R.mean_groupweights = mean(R.groupweights,2);
    R.std_tr_acc = std(R.tr_acc,0,2);
    R.std_te_acc = std(R.te_acc,0,2);
    R.std_weights = std(R.weights,0,2);
    R.std_groupweights = std(R.groupweights,0,2);

    results{m} = R;
end
[ADAPTOS, ADAGRAD, TOS, SGD, ADAM] = results{:};
clear results R S

%% Figure: group/weight magnitudes and training/validation accuracy
% A single figure holding a 1x4 row of panels, titled 'numerics-nonconvex'.
hfig = figure('Position',[100,100,1100,250], ...
              'name','numerics-nonconvex', ...
              'numbertitle','off');

% Panel 1: sorted group magnitudes per optimizer, mean +/- std over runs.
h(1) = subplot(1,4,1);
ax = h(1);

% Normalised group rank in (0,1]. size(...,1) is the number of groups (rows);
% length() on the groups x runs matrix would return max(#groups, #runs).
n_groups = size(SGD.groupweights,1);
sparsity_weights = (1:n_groups)/n_groups;
x = sparsity_weights';

series = {SGD, ADAGRAD, ADAM, TOS, ADAPTOS};
colors = {c_sgd, c_adagrad, c_adam, c_tos, c_adaptos};
for k = 1:numel(series)
    y = series{k}.mean_groupweights;
    dy = series{k}.std_groupweights;
    % The y-axis is logarithmic: clamp the band to eps (the same offset the
    % mean line uses) so non-positive vertices (mean-std <= 0) are not dropped.
    band = max([y-dy; flipud(y+dy)], eps);
    fill([x;flipud(x)],band,colors{k},'linestyle','none','facealpha',0.4);
    hold on
    plot(x,y+eps,'Color',colors{k},'LineWidth',2);
end
ax.YScale = 'log';
ax.XScale = 'linear';

ylim([1e-15,10])
% NOTE(review): `grid minor` toggles the minor grid, so the two calls cancel;
% confirm whether a minor grid was actually wanted.
grid on; grid minor; grid minor;
ylabel('magnitude of group','Interpreter','latex','Fontsize',15)
xlabel('groups (ordered)','Interpreter','latex','Fontsize',15)
ax.FontSize = 14;
ax.TickLabelInterpreter = 'latex';
ax.TickDir = 'out';
ax.LineWidth = 0.75;
ax.TickLength = [0.02 0.02];
ax.Box = 'on';
ax.YTick = 10.^(-99:3:100);
ax.XTick = 0:0.2:1;
ax.XRuler.MinorTick = 'on';
ax.XRuler.MinorTickValues = 0.1:0.2:0.9; % midpoints between the major ticks


% Panel 2: sorted magnitudes of individual weights per optimizer,
% mean +/- std over runs. Also defines the line handles used by the legend.
h(2) = subplot(1,4,2);
ax = h(2);

% Normalised weight rank in (0,1]. size(...,1) is the number of weights (rows);
% length() on the weights x runs matrix would return max(#weights, #runs).
n_weights = size(SGD.weights,1);
sparsity_weights = (1:n_weights)/n_weights;
x = sparsity_weights';

series = {SGD, ADAGRAD, ADAM, TOS, ADAPTOS};
colors = {c_sgd, c_adagrad, c_adam, c_tos, c_adaptos};
handles = gobjects(1,numel(series));
for k = 1:numel(series)
    % mean_weights is the row-wise mean of columns that were each sorted
    % descending, so it is already non-increasing: no re-sort is needed.
    y = series{k}.mean_weights;
    dy = series{k}.std_weights;
    % The y-axis is logarithmic: clamp the band to eps (the same offset the
    % mean line uses) so non-positive vertices (mean-std <= 0) are not dropped.
    band = max([y-dy; flipud(y+dy)], eps);
    fill([x;flipud(x)],band,colors{k},'linestyle','none','facealpha',0.4);
    hold on
    handles(k) = plot(x,y+eps,'color',colors{k},'LineWidth',2);
end
hSGD = handles(1);
hADAGRAD = handles(2);
hADAM = handles(3);
hTOS = handles(4);
hADAPTOS = handles(5);
ax.YScale = 'log';
ax.XScale = 'linear';

ylim([1e-15,1])
% NOTE(review): `grid minor` toggles the minor grid, so the two calls cancel;
% confirm whether a minor grid was actually wanted.
grid on; grid minor; grid minor;
ylabel('magnitude of weight','Interpreter','latex','Fontsize',15)
xlabel('weights (ordered)','Interpreter','latex','Fontsize',15)
ax.FontSize = 14;
ax.TickLabelInterpreter = 'latex';
ax.TickDir = 'out';
ax.LineWidth = 0.75;
ax.TickLength = [0.02 0.02];
ax.Box = 'on';
ax.XTick = 0:0.2:1;
ax.YTick = 10.^(-99:3:100);
ax.XRuler.MinorTick = 'on';
ax.XRuler.MinorTickValues = 0.1:0.2:0.9; % midpoints between the major ticks


% Panel 3: training accuracy per epoch, mean +/- std over the runs of each optimizer.
h(3) = subplot(1,4,3);

% One row per optimizer: {result struct, colour}, in drawing order.
panel = { ...
    SGD,     c_sgd; ...
    ADAGRAD, c_adagrad; ...
    ADAM,    c_adam; ...
    TOS,     c_tos; ...
    ADAPTOS, c_adaptos};
for k = 1:size(panel,1)
    opt = panel{k,1};
    clr = panel{k,2};
    y = opt.mean_tr_acc;
    dy = opt.std_tr_acc;
    x = (1:length(y))';
    % Shaded mean +/- std band, then the mean curve on top of it.
    fill([x;flipud(x)],[y-dy;flipud(y+dy)],clr,'linestyle','none','facealpha',0.4);
    hold on
    plot(y,'color',clr);
    hold on
end

xlim([0,500])
ylim([0.94,1])
grid on, grid minor, grid minor
ylabel('training accuracy','Interpreter','latex','Fontsize',15)
xlabel('epoch','Interpreter','latex','Fontsize',15)
ax = gca;
set(findall(ax, 'Type', 'line'),'LineWidth',2);
set(ax,'FontSize',14,'TickLabelInterpreter','latex','TickDir','out');
grid on, grid minor, grid minor
set(ax,'LineWidth',0.75,'TickLength',[0.02 0.02],'Box','on', ...
    'YTick',0.9:0.01:1,'XTick',0:100:500);

% Panel 4: validation accuracy per epoch, mean +/- std over the runs of each optimizer.
h(4) = subplot(1,4,4);
ax = h(4);

% Series are drawn in this order: ADAM, AdaGrad, SGD, TOS, AdapTOS.
series = {ADAM, ADAGRAD, SGD, TOS, ADAPTOS};
colors = {c_adam, c_adagrad, c_sgd, c_tos, c_adaptos};
for k = 1:numel(series)
    y = series{k}.mean_te_acc;
    dy = series{k}.std_te_acc;
    x = (1:length(y))';
    fill([x;flipud(x)],[y-dy;flipud(y+dy)],colors{k},'linestyle','none','facealpha',0.4);
    hold on
    plot(x,y,'color',colors{k},'LineWidth',2);
end

xlim([0,500])
ylim([0.94,0.98])
% NOTE(review): `grid minor` toggles the minor grid, so the two calls cancel;
% confirm whether a minor grid was actually wanted.
grid on; grid minor; grid minor;
ylabel('validation accuracy','Interpreter','latex','Fontsize',15)
xlabel('epoch','Interpreter','latex','Fontsize',15)
ax.FontSize = 14;
ax.TickLabelInterpreter = 'latex';
ax.TickDir = 'out';
ax.LineWidth = 0.75;
ax.TickLength = [0.02 0.02];
ax.Box = 'on';
ax.XTick = 0:100:500; % epoch ticks, matching the training-accuracy panel

% Legend for all optimizers, placed on the weight-magnitude panel h(2) and
% built from the line handles created there.
subplot(h(2));
hl = legend(h(2),[hSGD,hADAGRAD,hADAM,hTOS,hADAPTOS],'SGD','AdaGrad','Adam','TOS', 'AdapTOS');
hl.Location = 'SouthEast';
hl.Interpreter = 'latex';
hl.FontSize = 14;
% Epoch ticks for the validation-accuracy panel. It is addressed through h(4)
% explicitly because the current axes here is h(2), whose x-range is [0,1].
h(4).XTick = 0:100:500;

