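%Optimization On Grassmann Manifolds
%Contains functions for optimization and related geometric computations on the Grassmann manifold G_{n,p},
%whose points are represented by n-by-p matrices with orthonormal columns (points on the Stiefel manifold St(p, n)).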

%author: Wenqing Hu (Missouri S&T)

classdef Grassmann_Optimization
%GRASSMANN_OPTIMIZATION  Weighted center-of-mass computations on G_{n,p}.
%   A point of G_{n,p} is represented by an n-by-p matrix with orthonormal
%   columns (a point of St(p, n)); representatives with the same column
%   span describe the same point of G_{n,p}.

%class open variables
properties
    omega %the weight sequence; omega(k) is the weight of Seq(:,:,k)
    Seq   %n-by-p-by-m array; Seq(:,:,k) is a point on St(p, n) identified with a point on G_{n,p}
    threshold_gradnorm   %the threshold for gradient norm when using GD
    threshold_fixedpoint %the threshold for fixed-point iteration for average
    threshold_checkonGrassmann  %tolerance used by CheckOnGrassmann and CheckTangentGrassmann
end


%functions in the class
methods


function self = Grassmann_Optimization(omega, Seq, threshold_gradnorm, threshold_fixedpoint, threshold_checkonGrassmann)
%class constructor function
%   Called with no arguments, all properties are left empty.
%   Otherwise all five arguments are stored in the properties of the same name.
    if nargin > 0
        self.omega = omega;
        self.Seq = Seq;
        self.threshold_gradnorm = threshold_gradnorm;
        self.threshold_fixedpoint = threshold_fixedpoint;
        self.threshold_checkonGrassmann = threshold_checkonGrassmann;
    end
end


function [value, grad] = Center_Mass_function_gradient_pFrobenius(self, Y)
%Value and Riemannian gradient of the projected Frobenius center-of-mass function
%   f(A) = \sum_{k=1}^m w_k |AA^T - A_kA_k^T|_F^2,  with A_k = Seq(:,:,k), w_k = omega(k).
%Inputs:
%   Y     - n-by-p matrix, assumed to have orthonormal columns (a point on St(p, n))
%Outputs:
%   value - f(Y)
%   grad  - n-by-p gradient of f at Y projected onto the tangent space
%           T_Y G_{n,p} via (I - YY^T)
    A = Y;
    n = size(A, 1);
    p = size(A, 2);
    value = 0;
    egrad = zeros(n, p);
    for k = 1:length(self.omega)
        Ak = self.Seq(:,:,k);
        w = self.omega(k);
        value = value + w * (norm(A * A' - Ak * Ak', 'fro')^2);
        % Euclidean gradient of |AA^T - A_kA_k^T|_F^2 is 4(AA^T - A_kA_k^T)A,
        % i.e. 4(A - A_k A_k^T A) when A^T A = I. Computing A_k (A_k^T A)
        % avoids forming the n-by-n matrix A_k A_k^T.
        egrad = egrad + (4 * w) * (A - Ak * (Ak' * A));
    end
    % project onto the tangent space (the A-term above vanishes here when A^T A = I)
    grad = egrad - A * (A' * egrad);
end


function [pF_Center, value, grad] = Center_Mass_pFrobenius(self)
%Directly calculate the center of mass on G_{n,p} with respect to the projected Frobenius norm.
%   For orthonormal A and A_k, |AA^T - A_kA_k^T|_F^2 = 2p - 2 tr(A^T A_k A_k^T A), so minimizing f
%   amounts to maximizing tr(A^T M A) with M = \sum_k w_k A_kA_k^T / \sum_k w_k. The maximizer is
%   spanned by the p leading eigenvectors of the symmetric matrix M.
%   NOTE(review): the svd-based selection equals the eigen-decomposition ordering only when M is
%   positive semidefinite, i.e. assumes nonnegative weights omega.
%Outputs:
%   pF_Center - n-by-p orthonormal representative of the center of mass
%   value     - f(pF_Center)
%   grad      - projected gradient of f at pF_Center
    m = length(self.omega);
    n = size(self.Seq, 1);
    p = size(self.Seq, 2);
    total_weight = sum(self.omega);
    Mtx = zeros(n, n);
    for k = 1:m
        Ak = self.Seq(:,:,k);
        Mtx = Mtx + (self.omega(k) / total_weight) * (Ak * Ak');
    end
    % left singular vectors are returned in order of decreasing singular value
    [Q, ~, ~] = svd(Mtx);
    pF_Center = Q(:, 1:p);
    [value, grad] = self.Center_Mass_function_gradient_pFrobenius(pF_Center);
end



function [ifGrassmann, distance] = CheckOnGrassmann(self, Y)
%Test if the given matrix Y is on the Grassmann manifold G_{n,p},
%i.e. test that Y is on St(p, n) (has orthonormal columns).
%Outputs:
%   ifGrassmann - true when |Y^T Y - I|_F <= threshold_checkonGrassmann
%   distance    - |Y^T Y - I|_F
    p = size(Y, 2);
    distance = norm(Y' * Y - eye(p), 'fro');
    ifGrassmann = false;
    if distance <= self.threshold_checkonGrassmann
        ifGrassmann = true;
    end
end


function [ifTangentGrassmann, distance] = CheckTangentGrassmann(self, Y, H)
%Test if the given matrix H lies on the tangent space T_Y G_{n,p} of the Grassmann manifold,
%i.e. test that Y^T H = 0 up to threshold_checkonGrassmann.
%Outputs:
%   ifTangentGrassmann - true when sizes match and |Y^T H|_F <= threshold_checkonGrassmann
%   distance           - |Y^T H|_F, or Inf when Y and H have different sizes
    ifTangentGrassmann = false;
    distance = Inf; % always assigned so two-output calls do not error on a size mismatch
    if size(Y, 1) ~= size(H, 1) || size(Y, 2) ~= size(H, 2)
        return;
    end
    distance = norm(Y' * H, 'fro');
    if distance <= self.threshold_checkonGrassmann
        ifTangentGrassmann = true;
    end
end



function [prj_tg] = projection_tangent(self, Y, Z)
%Calculate the projection onto the tangent space of the Grassmann manifold G_{n,p}.
%Pi_{T, Y}(Z) = (I - YY^T) Z projects an n-by-p matrix Z onto the tangent space of G_{n,p}
%at the point Y in St(p, n). Returns the tangent vector prj_tg in T_Y(G(n,p)).
    % Y (Y^T Z) avoids forming the n-by-n matrix I - YY^T
    prj_tg = Z - Y * (Y' * Z);
end



end %end of class methods

end %end of class Grassmann_Optimization