%Optimization on Stiefel Manifolds
%Provides functions for optimization and related geometric computations on the Stiefel manifold St(p, n)


classdef Stiefel_Optimization
%STIEFEL_OPTIMIZATION  Helpers for optimization on the Stiefel manifold St(p, n).
%   Points are handled as n-by-p matrices Y; the membership test used by this
%   class is Y'*Y = I_p (see CheckOnStiefel). The object stores a weighted
%   sample of points (omega, Seq) together with several numerical tolerances.

%class open variables
properties
    omega                     %the weight sequence; omega(k) is the weight of Seq(:, :, k)
    Seq                       %the sequence of points on St(p, n), stored as an n-by-p-by-m array
    threshold_gradnorm        %the threshold for gradient norm when using GD
    threshold_fixedpoint      %the threshold for fixed-point iteration for average
    threshold_checkonStiefel  %the tolerance used by CheckOnStiefel and CheckTangentStiefel
    threshold_logStiefel      %the threshold for calculating the Stiefel logarithmic map via iterative method
end


%functions in the class
methods


function self = Stiefel_Optimization(omega, Seq, threshold_gradnorm, threshold_fixedpoint, threshold_checkonStiefel, threshold_logStiefel)
%class constructor function
%Called with no arguments, every property is left at its default (empty)
%value; otherwise all six arguments are expected and copied into the
%properties of the same name.
    if nargin > 0
        self.omega = omega;
        self.Seq = Seq;
        self.threshold_gradnorm = threshold_gradnorm;
        self.threshold_fixedpoint = threshold_fixedpoint;
        self.threshold_checkonStiefel = threshold_checkonStiefel;
        self.threshold_logStiefel = threshold_logStiefel;
    end
end


function [Q] = Complete_SpecialOrthogonal(self, A)
%given the matrix A in St(p, n), complete it into Q = [A B] in SO(n)
%   A : n-by-p matrix, assumed to satisfy A'*A = I_p (not checked here)
%   Q : n-by-n orthogonal matrix with det(Q) > 0; its first p columns
%       reproduce A when A'*A = I_p
%An error is raised when p == n and det(A) < 0, because then no column is
%left to flip and no completion in SO(n) exists.
    n = size(A, 1);
    p = size(A, 2);
    %full SVD: the trailing n-p columns of U span the orthogonal complement of range(A)
    [U, ~, V] = svd(A);
    V_ext = [V zeros(p, n-p); zeros(n-p, p) eye(n-p)];
    Q = U * V_ext';
    if det(Q) < 0
        if p >= n
            error('Stiefel_Optimization:Complete_SpecialOrthogonal:noCompletion', ...
                  'A is square with negative determinant; it cannot be completed to an element of SO(n).');
        end
        %flipping one complement column changes the sign of det(Q) and keeps the first p columns
        Q(:, p+1) = -Q(:, p+1);
    end
end


function [f, gradf] = Center_Mass_function_gradient_Euclid(self, Y)
%calculate the function value and the gradient on Stiefel manifold St(p, n) of the Euclidean center of mass function f_F(A)=\sum_{k=1}^m w_k \|A-A_k\|_F^2
%   Y     : n-by-p evaluation point
%   f     : \sum_k omega(k) * \|Y - Seq(:,:,k)\|_F^2
%   gradf : \sum_k 2*omega(k) * (R_k - Y*R_k'*Y) with R_k = Y - Seq(:,:,k)
    m = length(self.omega);
    f = 0;
    gradf = zeros(size(Y));
    for k = 1:m
        R = Y - self.Seq(:, :, k);   %residual to the k-th sample point
        f = f + self.omega(k) * (norm(R, 'fro')^2);
        gradf = gradf + 2 * self.omega(k) * (R - Y * R' * Y);
    end
end


function [Euclid_Center, value, gradnorm] = Center_Mass_Euclid(self)
%directly calculate the Euclidean center of mass that is the St(p, n) minimizer of f_F(A)=\sum_{k=1}^m w_k\|A-A_k\|_F^2, via the SVD of the weighted sum of the samples
%   Euclid_Center : U*V', where B = \sum_k omega(k)*Seq(:,:,k) = U*S*V' is the economy SVD
%   value         : f_F evaluated at Euclid_Center
%   gradnorm      : Frobenius norm of the gradient returned by
%                   Center_Mass_function_gradient_Euclid at Euclid_Center
%NOTE: when B is rank deficient the minimizer is not unique; the one
%returned depends on the singular vectors chosen by svd.
    m = length(self.omega);
    n = size(self.Seq, 1);
    p = size(self.Seq, 2);
    B = zeros(n, p);
    for k = 1:m
        B = B + self.omega(k) * self.Seq(:, :, k);
    end
    %economy SVD gives the n-by-p factor directly, equivalent to U_full*[I_p; 0]
    [U, ~, V] = svd(B, 'econ');
    Euclid_Center = U * V';
    [value, grad] = self.Center_Mass_function_gradient_Euclid(Euclid_Center);
    gradnorm = norm(grad, 'fro');
end


function [ifStiefel, distance] = CheckOnStiefel(self, Y)
%test if the given matrix Y is on the Stiefel manifold St(p, n)
%   Y         : matrix to be tested
%   distance  : \|Y^TY-I_p\|_F
%   ifStiefel : true if distance <= self.threshold_checkonStiefel, false otherwise
    p = size(Y, 2);
    distance = norm(Y' * Y - eye(p), 'fro');
    ifStiefel = false;
    if distance <= self.threshold_checkonStiefel
        ifStiefel = true;
    end
end


function [ifTangentStiefel] = CheckTangentStiefel(self, Y, H)
%test if the given matrix H is on the tangent space of Stiefel manifold T_Y St(p, n)
%   Y, H             : base point and candidate tangent vector
%   ifTangentStiefel : true if Y and H have the same number of rows and
%                      columns and \|Y^TH+H^TY\|_F <= self.threshold_checkonStiefel,
%                      false otherwise
    ifTangentStiefel = false;
    sameShape = (size(Y, 1) == size(H, 1)) && (size(Y, 2) == size(H, 2));
    if sameShape
        %Y'*H + H'*Y is already symmetric, so it must not be symmetrized
        %again: adding its transpose would double the measured defect
        distance = norm(Y' * H + H' * Y, 'fro');
        if distance <= self.threshold_checkonStiefel
            ifTangentStiefel = true;
        end
    end
end



function [prj_tg] = projection_tangent(self, Y, Z)
%calculate the projection onto tangent space of Stiefel manifold St(p, n)
%Pi_{T, Y}(Z) = Y*skew(Y'*Z) + (I - Y*Y')*Z projects matrix Z of size n by p
%towards the tangent space of St(p, n) at point Y
%returns prj_tg, an n-by-p matrix
    YtZ = Y' * Z;
    skew = (YtZ - YtZ') / 2;
    %(I - Y*Y')*Z evaluated as Z - Y*(Y'*Z) to avoid forming an n-by-n matrix
    prj_tg = Y * skew + (Z - Y * YtZ);
end


end %end of class methods

end %end of class