% MATLAB code accompanying the ICLR submission titled
% "Minibatch vs Local SGD with Shuffling: Tight Convergence Bounds and Beyond".
% This script generates the plots in Figures 1, 2, and 3.

clear
close all 

% ---- Problem parameters -------------------------------------------------
% L, mu and nu are forwarded to grad/objval (defined elsewhere); the names
% suggest smoothness / strong-convexity / noise constants -- TODO confirm.
L  = 100;
mu = 1;
nu = 1;
M  = 16;                             % number of machines
N  = 768;                            % component indices 1..N permuted per epoch
Bval = [1;4;16;64;256];              % communication intervals B to sweep
Kval = [1;3;5;7;10;30;50;70;100;300;500;700;1000];  % numbers of epochs K
% Kval = [1;3];                      % smaller sweep for quick tests
T = 20;                              % repetitions per (B, K) configuration

% ---- Result buffers -----------------------------------------------------
% One entry per K for every algorithm: the mean final objective value
% (numel(Kval)-by-1) and its 25th/75th percentiles (numel(Kval)-by-2).
% They are refilled for each B before the figures are drawn.
[obj_mean_mb, obj_mean_lc, obj_mean_mb_ss, obj_mean_lc_ss, ...
 obj_mean_sing, obj_mean_mb_w, obj_mean_lc_w] = deal(zeros(numel(Kval), 1));

[obj_errbar_mb, obj_errbar_lc, obj_errbar_mb_ss, obj_errbar_lc_ss, ...
 obj_errbar_sing, obj_errbar_mb_w, obj_errbar_lc_w] = deal(zeros(numel(Kval), 2));

% Sanity checks on the problem sizes. Every communication round consumes B
% consecutive entries of each permutation, so N must be a multiple of every B;
% otherwise the loops over 1:(N/B) would silently drop the tail of each epoch.
% Synchronized shuffling shifts a shared permutation by multiples of N/M, which
% therefore has to be an integer.
assert(all(mod(N, Bval) == 0), 'N must be divisible by every entry of Bval.');
assert(mod(N, M) == 0, 'N must be divisible by M for synchronized shuffling.');

% Plot helper: mean final objective vs. K with an interquartile error bar.
% errorbar(x,y,neg,pos) expects the LENGTHS of the bars below/above y, not their
% endpoints, so the stored quartiles q = [q25 q75] are converted into offsets
% from the mean; the drawn bar then spans exactly [q25, q75].
plot_iqr = @(y, q, spec, color) errorbar(Kval, y, y - q(:,1), q(:,2) - y, ...
    spec, 'LineWidth', 1.5, 'MarkerSize', 5, 'MarkerFaceColor', color);

% for the communication interval B in Bval...
for l=1:length(Bval)
    B = Bval(l) % no semicolon on purpose: echoes the current B as progress
    % run the algorithms for # epochs K in Kval
    for p=1:length(Kval)
        K = Kval(p);

        % grad(x,L,mu,nu,N,idx) and objval(x,L,mu) are defined elsewhere;
        % presumably grad returns the gradient of the idx-th component
        % function at x and objval the objective at x -- TODO confirm.

        % run minibatch RR for K epochs; repeat T times and take average.
        % In each round every machine evaluates B gradients at the shared
        % iterate x_mb (scaled by eta/B); the server averages the M results.
        eta = B*log(M*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_mb = 0; % global iterate
            sigma = zeros(M,N); % row m = permutation used by machine m
            for k=1:K
                % independent permutations
                for m=1:M
                    sigma(m,:) = randperm(N);
                end

                % update for N/B communication rounds
                for i=1:(N/B)
                    z = ones(M,1)*x_mb; % local iterates start at x_mb
                    for m=1:M
                        for j=1:B
                            z(m)=z(m)-eta/B*grad(x_mb,L,mu,nu,N,sigma(m,(i-1)*B+j));
                        end
                    end
                    x_mb = mean(z);
                end
            end
            objvals(t) = objval(x_mb,L,mu);
        end
        obj_mean_mb(p) = mean(objvals);
        obj_errbar_mb(p,:) = quantile(objvals,[0.25 0.75]);

        % run local RR for K epochs; repeat T times and take average.
        % In each round every machine takes B sequential steps on its own
        % iterate z(m); the server averages the M local iterates.
        eta = log(M*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_lc = 0; % global iterate
            sigma = zeros(M,N); % row m = permutation used by machine m
            for k=1:K
                % independent permutations
                for m=1:M
                    sigma(m,:) = randperm(N);
                end

                % update for N/B communication rounds
                for i=1:(N/B)
                    z = ones(M,1)*x_lc; % local iterates start at x_lc
                    for m=1:M
                        for j=1:B
                            z(m)=z(m)-eta*grad(z(m),L,mu,nu,N,sigma(m,(i-1)*B+j));
                        end
                    end
                    x_lc = mean(z);
                end
            end
            objvals(t) = objval(x_lc,L,mu);
        end
        obj_mean_lc(p) = mean(objvals);
        obj_errbar_lc(p,:) = quantile(objvals,[0.25 0.75]);

        % run minibatch RR + synchronized shuffling for K epochs
        % repeat T times and take average
        eta = B*log(M^2*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_mb_ss = 0; % global iterate
            sigma = zeros(M,N); % row m = permutation used by machine m
            for k=1:K
                % synchronized shuffling: machine 1 draws a random
                % permutation; machine m reuses it cyclically shifted by
                % N/M*(perm(m)-perm(1)) positions, where perm is a random
                % ordering of the machines. (Not named `pi`, which would
                % shadow the built-in constant.)
                sigma(1,:) = randperm(N);
                perm = randperm(M);
                for m=2:M
                    for i=1:N
                        sigma(m,i) = sigma(1,mod(i+N/M*(perm(m)-perm(1))-1,N)+1);
                    end
                end

                % update for N/B communication rounds
                for i=1:(N/B)
                    z = ones(M,1)*x_mb_ss; % local iterates start at x_mb_ss
                    for m=1:M
                        for j=1:B
                            z(m)=z(m)-eta/B*grad(x_mb_ss,L,mu,nu,N,sigma(m,(i-1)*B+j));
                        end
                    end
                    x_mb_ss = mean(z);
                end
            end
            objvals(t) = objval(x_mb_ss,L,mu);
        end
        obj_mean_mb_ss(p) = mean(objvals);
        obj_errbar_mb_ss(p,:) = quantile(objvals,[0.25 0.75]);

        % run local RR + synchronized shuffling for K epochs
        % repeat T times and take average
        eta = log(M^2*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_lc_ss = 0; % global iterate
            sigma = zeros(M,N); % row m = permutation used by machine m
            for k=1:K
                % synchronized shuffling (same scheme as above)
                sigma(1,:) = randperm(N);
                perm = randperm(M);
                for m=2:M
                    for i=1:N
                        sigma(m,i) = sigma(1,mod(i+N/M*(perm(m)-perm(1))-1,N)+1);
                    end
                end

                % update for N/B communication rounds
                for i=1:(N/B)
                    z = ones(M,1)*x_lc_ss; % local iterates start at x_lc_ss
                    for m=1:M
                        for j=1:B
                            z(m)=z(m)-eta*grad(z(m),L,mu,nu,N,sigma(m,(i-1)*B+j));
                        end
                    end
                    x_lc_ss = mean(z);
                end
            end
            objvals(t) = objval(x_lc_ss,L,mu);
        end
        obj_mean_lc_ss(p) = mean(objvals);
        obj_errbar_lc_ss(p,:) = quantile(objvals,[0.25 0.75]);

        % run single-machine RR for K epochs; repeat T times and take average.
        % Each of the N/B iterations per epoch subtracts eta times the sum of
        % B gradients, all evaluated at the current iterate x_sing.
        eta = log(N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_sing = 0; % global iterate
            for k=1:K
                sigma = randperm(N);

                % update for N/B iterations, with batch size B
                for i=1:(N/B)
                    z = x_sing;
                    for j=1:B
                        z = z-eta*grad(x_sing,L,mu,nu,N,sigma((i-1)*B+j));
                    end
                    x_sing = z;
                end
            end
            objvals(t) = objval(x_sing,L,mu);
        end
        obj_mean_sing(p) = mean(objvals);
        obj_errbar_sing(p,:) = quantile(objvals,[0.25 0.75]);

        % run minibatch SGD for K*N/B communication rounds
        % repeat T times and take average
        eta = B*log(M*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_mb_w = 0; % global iterate
            for k=1:(K*N/B)
                z = ones(M,1)*x_mb_w; % local iterates start at x_mb_w
                for m=1:M
                    % sampling with replacement
                    idx = randi(N,[B,1]);
                    for j=1:B
                        z(m)=z(m)-eta/B*grad(x_mb_w,L,mu,nu,N,idx(j));
                    end
                end
                x_mb_w = mean(z);
            end
            objvals(t) = objval(x_mb_w,L,mu);
        end
        obj_mean_mb_w(p) = mean(objvals);
        obj_errbar_mb_w(p,:) = quantile(objvals,[0.25 0.75]);

        % run local SGD for K*N/B communication rounds
        % repeat T times and take average
        eta = log(M*N*K^2)/mu/N/K; % step size
        objvals = zeros(T,1); % final objective value of each repetition
        for t=1:T
            x_lc_w = 0; % global iterate
            for k=1:(K*N/B)
                z = ones(M,1)*x_lc_w; % local iterates start at x_lc_w
                for m=1:M
                    % sampling with replacement
                    idx = randi(N,[B,1]);
                    for j=1:B
                        z(m)=z(m)-eta*grad(z(m),L,mu,nu,N,idx(j));
                    end
                end
                x_lc_w = mean(z);
            end
            objvals(t) = objval(x_lc_w,L,mu);
        end
        obj_mean_lc_w(p) = mean(objvals);
        obj_errbar_lc_w(p,:) = quantile(objvals,[0.25 0.75]);
    end

    % generate and save plots for Fig 1
    figure; hold on;
    plot_iqr(obj_mean_mb,    obj_errbar_mb,    'r-s',  'red');
    plot_iqr(obj_mean_lc,    obj_errbar_lc,    'b--s', 'blue');
    plot_iqr(obj_mean_mb_ss, obj_errbar_mb_ss, 'g:s',  'green');
    plot_iqr(obj_mean_lc_ss, obj_errbar_lc_ss, 'k-.s', 'black');
    set(gca, 'XScale','log', 'YScale','log')
    legend('MinibatchRR','LocalRR','MinibatchRR+SyncShuf','LocalRR+SyncShuf','FontSize',18);
    axis([min(Kval) max(Kval) 1e-9 1e-1]); hold off;
    set(gcf,'position',[0,0,800,600]);
    set(gca,'FontSize',15)
    saveas(gcf,"EB-Exp1-N"+N+"-M"+M+"-B"+B,'epsc');

    % generate and save plots for Fig 2
    figure; hold on;
    plot_iqr(obj_mean_mb,   obj_errbar_mb,   'r-s',  'red');
    plot_iqr(obj_mean_lc,   obj_errbar_lc,   'b--s', 'blue');
    plot_iqr(obj_mean_sing, obj_errbar_sing, 'g:s',  'green');
    set(gca, 'XScale','log', 'YScale','log')
    legend('MinibatchRR','LocalRR','SingleMachineRR','FontSize',18);
    axis([min(Kval) max(Kval) 1e-9 1e-1]); hold off;
    set(gcf,'position',[400,0,800,600]);
    set(gca,'FontSize',15)
    saveas(gcf,"EB-Exp2-N"+N+"-M"+M+"-B"+B,'epsc');

    % generate and save plots for Fig 3
    figure; hold on;
    plot_iqr(obj_mean_mb,   obj_errbar_mb,   'r-s',  'red');
    plot_iqr(obj_mean_lc,   obj_errbar_lc,   'b--s', 'blue');
    plot_iqr(obj_mean_mb_w, obj_errbar_mb_w, 'g:s',  'green');
    plot_iqr(obj_mean_lc_w, obj_errbar_lc_w, 'k-.s', 'black');
    set(gca, 'XScale','log', 'YScale','log')
    legend('MinibatchRR','LocalRR','MinibatchSGD','LocalSGD','FontSize',18);
    axis([min(Kval) max(Kval) 1e-9 1e-1]); hold off;
    set(gcf,'position',[800,0,800,600]);
    set(gca,'FontSize',15)
    saveas(gcf,"EB-Exp3-N"+N+"-M"+M+"-B"+B,'epsc');
end