function [gc, gG] = backpropZonotopeBatch(obj, c, G, gc, gG, options, updateWeights)
% backpropZonotopeBatch - compute the backpropagation for the previous input
%    with batches of zonotopes
%
% Syntax:
%    [gc,gG] = layeri.backpropZonotopeBatch(c,G,gc,gG,options,updateWeights);
%
% Inputs:
%    c, G - batch of input zonotopes; [n,q+1,b] = size([c G]); if
%       options.nn.interval_center is set, c stores the lower and upper
%       center bounds in its columns 1 and 2, i.e., [n,2,b] = size(c)
%    gc, gG - batch of zonotope gradients w.r.t. the output;
%       [n,q+1,b] = size([gc gG]) (with an interval center, gc stores the
%       lower/upper gradients in columns 1 and 2), where n is the number of
%       dims, q the number of generators, and b the batch size
%       NOTE(review): without an interval center the code multiplies gc and
%       c element-wise with the slope m, which is used as an [n,b] matrix;
%       gc and c thus appear to be [n,b] there -- confirm against callers.
%       gG may hold more columns than G (e.g., the approximation-error
%       generators addressed by GdIdx); genIds selects the columns that
%       belong to the input generators.
%    options - training parameters
%    updateWeights - only relevant for layers with learnable parameters
%       (not used by this layer)
%
% Outputs:
%    gc, gG - zonotope gradients w.r.t the input
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork/backpropZonotopeBatch

% Authors:       Lukas Koller
% Written:       12-August-2025
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% Obtain indices of active generators.
genIds = obj.backprop.store.genIds;
% Obtain stored slope.
m = obj.backprop.store.coeffs;

% Check if there is an exact backpropagation through the image enclosure.
if options.nn.train.exact_backprop
    % Obtain the stored gradients of the slope (presumably w.r.t. the lower
    % and upper input bounds).
    m_l = obj.backprop.store.m_l;
    m_u = obj.backprop.store.m_u;
    % Obtain the stored gradients of the approximation-error bounds.
    dl_l = obj.backprop.store.dl_l;
    dl_u = obj.backprop.store.dl_u;
    du_l = obj.backprop.store.du_l;
    du_u = obj.backprop.store.du_u;
    % Compute the sign of the active generators.
    r_G = sign(G(:,genIds,:));
    % Compute the gradient of the slope w.r.t. the generators.
    m_G = permute(m_u - m_l,[1 3 2]).*r_G;
    % Obtain the number of dimensions and the batch size.
    [n,~,bs] = size(G);

    if options.nn.interval_center
        % Obtain the bounds of the interval center and their gradients.
        cl = reshape(c(:,1,:),[n bs]);
        cu = reshape(c(:,2,:),[n bs]);
        gl = reshape(gc(:,1,:),[n bs]);
        gu = reshape(gc(:,2,:),[n bs]);

        % Precompute the element-wise (Hadamard) products of gradients and
        % inputs, summed over the generators.
        hadProd_l = permute(gl.*cl,[1 3 2]);
        hadProd_u = permute(gu.*cu,[1 3 2]);
        hadProd = hadProd_l + hadProd_u + sum(gG(:,genIds,:).*G,2);

        % Backprop gradients through the (input-dependent) slope.
        rgc_l = gl.*m + m_l.*reshape(hadProd,[n bs]);
        rgc_u = gu.*m + m_u.*reshape(hadProd,[n bs]);
        rgG = gG(:,genIds,:).*permute(m,[1 3 2]) + m_G.*hadProd;

        if options.nn.use_approx_error
            % Obtain indices of the approximation errors; used as a linear
            % index into the [n,b] dims.
            dDimsIdx = obj.backprop.store.dDimsIdx;
            % Compute gradients w.r.t. center and generators.
            du_G = permute(du_u - du_l,[3 1 2]);
            dl_G = permute(dl_u - dl_l,[3 1 2]);

            % Add the gradients through the approximation errors.
            rgc_l(dDimsIdx) = rgc_l(dDimsIdx) ...
                + du_l.*gu(dDimsIdx) + dl_l.*gl(dDimsIdx);
            rgc_u(dDimsIdx) = rgc_u(dDimsIdx) ...
                + du_u.*gu(dDimsIdx) + dl_u.*gl(dDimsIdx);

            % Move the generator dimension to the front so that dDimsIdx
            % addresses the merged [n,b] dims as columns.
            rgG = permute(rgG,[2 1 3]);
            r_G = permute(r_G,[2 1 3]);
            rgG(:,dDimsIdx) = rgG(:,dDimsIdx) ...
                + (du_G(:,:).*gu(dDimsIdx(:))' ...
                    + dl_G(:,:).*gl(dDimsIdx(:))').*r_G(:,dDimsIdx);
            rgG = permute(rgG,[2 1 3]);
        end
        % Assign results (also without approximation errors, so that the
        % gradients are never passed through unmodified); stack lower and
        % upper center gradients into the [n,2,b] layout of the center.
        gc = permute(cat(3,rgc_l,rgc_u),[1 3 2]);
        gG = rgG;
    else
        % Compute the gradient of the slope w.r.t. the center.
        m_c = (m_u + m_l);
        % Precompute the element-wise (Hadamard) products of gradients and
        % inputs, summed over the generators.
        hadProd = permute(gc.*c,[1 3 2]) + sum(gG(:,genIds,:).*G,2);

        % Backprop gradients through the (input-dependent) slope.
        rgc = gc.*m + m_c.*reshape(hadProd,[n bs]);
        rgG = gG(:,genIds,:).*permute(m,[1 3 2]) + m_G.*hadProd;

        if options.nn.use_approx_error
            % Obtain indices of the approximation errors in the generator
            % matrix (GdIdx) and of the (non-)affected dimensions.
            GdIdx = obj.backprop.store.GdIdx;
            dDimsIdx = obj.backprop.store.dDimsIdx;
            notdDimsIdx = obj.backprop.store.notdDimsIdx;
            % Compute gradients w.r.t. center and generators.
            dc_c = 1/2*(du_u + du_l + dl_u + dl_l);
            dc_G = 1/2*permute(du_u - du_l + dl_u - dl_l,[3 1 2]);

            d_c = 1/2*(du_u + du_l - dl_u - dl_l);
            d_G = 1/2*permute(du_u - du_l - dl_u + dl_l,[3 1 2]);

            % Add the gradients through the approximation errors.
            rgc(dDimsIdx) = rgc(dDimsIdx) ...
                + dc_c.*gc(dDimsIdx) + d_c.*gG(GdIdx);
            % Dimensions without an approximation error use obj.df at the
            % center (presumably the activation's derivative).
            rgc(notdDimsIdx) = obj.df(c(notdDimsIdx)).*gc(notdDimsIdx);

            % Move the generator dimension to the front so that dDimsIdx
            % addresses the merged [n,b] dims as columns.
            rgG = permute(rgG,[2 1 3]);
            r_G = permute(r_G,[2 1 3]);
            rgG(:,dDimsIdx) = rgG(:,dDimsIdx) ...
                + dc_G(:,:).*r_G(:,dDimsIdx).*gc(dDimsIdx(:))' ...
                + d_G(:,:).*r_G(:,dDimsIdx).*gG(GdIdx(:))';
            rgG = permute(rgG,[2 1 3]);
        end
        % Assign results (also without approximation errors, so that the
        % gradients are never passed through unmodified).
        gc = rgc;
        gG = rgG;
    end
else
    % Consider the linear approximation as fixed. Use the slope of the
    % approximation for backpropagation.
    if options.nn.interval_center
        gc = permute(m,[1 3 2]).*gc;
    else
        gc = gc.*m;
    end
    gG = gG(:,genIds,:).*permute(m,[1 3 2]);
end

end

% ------------------------------ END OF CODE ------------------------------
