function [c,G,id] = quadMapZonotope(c1,G1,id1, c2,G2,id2, n, k, m)
% quadMapZonotope - compute an enclosure of the product of two uncertain
%    matrices given in vectorized zonotope form, i.e., of
%    (C1 + sum_i G1_i*e1_i) * (C2 + sum_j G2_j*e2_j)
%    with C1, G1_i of size n x k and C2, G2_j of size k x m
%
% Syntax:
%    [c,G,id] = nnHelper.quadMapZonotope(c1,G1,id1,c2,G2,id2,n,k,m)
%
% Inputs:
%    c1,G1,id1 - first zonotope: center with n*k entries (column-major
%                vectorization of an n x k matrix), generator matrix with
%                n*k rows (one column per generator), and ids of the
%                first numel(id1) generators; the remaining generators
%                are treated as unlabelled
%    c2,G2,id2 - second zonotope: center with k*m entries (column-major
%                vectorization of a k x m matrix), generator matrix with
%                k*m rows, and ids of the first numel(id2) generators;
%                the remaining generators are treated as unlabelled
%    n - first dimension of the first matrix
%    k - second dimension of the first matrix = first dimension of the
%        second matrix
%    m - second dimension of the second matrix
%
% Outputs:
%    c - center of the result (n*m x 1, column-major vectorized n x m
%        matrix)
%    G - generators of the result: the merged linear generators followed
%        by an n*m x n*m diagonal block enclosing all bilinear terms
%        according to [1, Thm. 5]
%    id - ids returned by nnHelper.mergeLabelledGenerators for the merged
%         linear generators; no ids are appended for the diagonal block
%
% References:
%    [1] Bonaert et al. "Fast and Precise Certification of Transformers", 2021
%
% Other m-files required: nnHelper.mergeLabelledGenerators
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Tobias Ladner
% Written:       22-September-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% all nine inputs are required (n, k, m determine the matrix shapes)
if nargin < 9
    throw(CORAerror('CORA:notEnoughInputArgs',9));
end

% reshape the vectorized sets to (pages of) matrices:
% G1 -> n x k x h1 x 1 and G2 -> k x m x 1 x h2, so that pagemtimes
% broadcasts over all h1*h2 generator pairs
c1 = reshape(c1,n,k);
h1 = size(G1,2);
G1 = reshape(G1,n,k,h1,1);
c2 = reshape(c2,k,m);
h2 = size(G2,2);
G2 = reshape(G2,k,m,1,h2);

% expand the product term by term
c1c2 = c1 * c2;                 % constant term
G1c2 = pagemtimes(G1, c2);      % linear in the factors of the first set
c1G2 = pagemtimes(c1, G2);      % linear in the factors of the second set
G1G2 = pagemtimes(G1, G2);      % bilinear terms, n x m x h1 x h2

% vectorize again; column i + (j-1)*h1 of G1G2 holds G1_i * G2_j
c1c2 = reshape(c1c2,n*m,[]);
G1c2 = reshape(G1c2,n*m,[]);
c1G2 = reshape(c1G2,n*m,[]);
G1G2 = reshape(G1G2,n*m,[]);

% merge linear generators according to their ids
[cG,id] = nnHelper.mergeLabelledGenerators(G1c2,id1,c1G2,id2);

% determine quadratic terms:
% assign fresh ids (larger than every existing id) to unlabelled
% generators; the leading 0 guards against empty id vectors, for which
% max would return [] and the additions below would fail
id_ = max([0; id(:); id1(:); id2(:)]);
id1_ = id_ + (1:h1);
id2_ = (id_+h1) + (1:h2);
id1_(1:numel(id1)) = id1;
id2_(1:numel(id2)) = id2;
% pair (i,j) is quadratic iff both generators carry the same id; the
% column-major reshape matches the column order of G1G2
isQuadratic = reshape(id1_' == id2_, 1, []);
if isempty(isQuadratic)
    isQuadratic = false(1,size(G1G2,2));
end

% apply [1, Thm. 5]:
% quadratic term g*e^2 with e^2 in [0,1] lies in [min(0,g), max(0,g)],
% i.e., midpoint 0.5*g (signed!) and radius 0.5*|g|;
% cross term g*e_i*e_j (i ~= j) lies in [-|g|, |g|], i.e., radius |g|
G1G2quad = G1G2(:,isQuadratic);
c = c1c2 + 0.5 * sum(G1G2quad,2);
I_rad = sum(abs(G1G2(:,~isQuadratic)),2) + 0.5*sum(abs(G1G2quad),2);

% add interval bounds as axis-aligned generators
G = [cG diag(I_rad)];

end

% ------------------------------ END OF CODE ------------------------------
