% Riemannian Langevin Monte Carlo on dim x dim SPD matrices, driven by
% Rie_Grad_f / BM2 (local functions below).
%
% Phase 1: run `total` independent chains of `inner` iterates each, starting
%          from the identity with constant step size eta; every iterate is
%          stored vectorised in store_x_LMC_0005(step, chain, :).
% Phase 2: restart every chain from its LAST phase-1 iterate and run another
%          `inner` iterates with a decreasing step-size schedule, storing them
%          in store_x_LMC_ref(step, chain, :).
warning('off','all')
dim = 3;
sigma = 0.03;

count_G_03 = 0;   % NOTE(review): not used by the visible script
eta = 0.0001;
store = 0;        % NOTE(review): not used by the visible script
total = 1000;     % number of independent chains
inner = 100;      % iterates stored per chain (row 1 is the start point)
store_x_LMC_ref = zeros(inner, total, dim*dim);
store_x_LMC_0005 = zeros(inner, total, dim*dim);
check_nan = 0;    % NOTE(review): not used by the visible script

for i = 1 : total
    i
    % Start point sized from dim (was a hard-coded eye(3, 3)); the product
    % with its transpose keeps it SPD even if the start point is changed.
    start_pt = eye(dim);
    start_pt = start_pt * start_pt';
    x_k = start_pt;
    tostore = reshape(x_k, dim * dim, 1);
    store_x_LMC_0005(1, i, :) = tostore;
    for k = 2 : inner
        Grad =  Rie_Grad_f(x_k, dim, sigma);
        x_k = BM2(x_k, eta, - Grad, dim);
        tostore = reshape(x_k, dim * dim, 1);
        store_x_LMC_0005(k, i, :) = tostore;
    end
end

for i = 1 : total
    i
    % Restart from the final phase-1 iterate. Use row `inner` rather than a
    % hard-coded 100 so the script stays correct when `inner` changes.
    store_x_LMC_ref(1, i, :) = store_x_LMC_0005(inner, i, :);
    x_k = reshape(store_x_LMC_0005(inner, i, :), dim, dim);
    for k = 2 : inner
        % Piecewise-constant, decreasing step-size schedule.
        if k < 10
            eta = 0.0001;
        elseif k < 30
            eta = 0.00005;
        elseif k < 60
            eta = 0.000001;
        else
            eta = 0.0000001;
        end
        Grad =  Rie_Grad_f(x_k, dim, sigma);
        x_k = BM2(x_k, eta, - Grad, dim);
        tostore = reshape(x_k, dim * dim, 1);
        store_x_LMC_ref(k, i, :) = tostore;
    end

end

function nor_result = norm_square_manifold(X, S)
    % Squared norm of the tangent matrix S at base point X under the metric
    %   <S, S>_X = trace(X^-1 * S * X^-1 * S).
    Xinv = inv(X);
    nor_result = trace(Xinv * S * Xinv * S);
end

function y = exp_map(X, S)
    % Exponential map at base point X applied to tangent matrix S:
    %   Exp_X(S) = X^(1/2) * expm(X^(-1/2) * S * X^(-1/2)) * X^(1/2)
    % X is expected to be SPD and S symmetric.
    % The square root and its inverse are computed once and reused; the
    % previous version recomputed sqrtm(X) three times and inv(sqrtm(X)) twice.
    Xh = sqrtm(X);
    Xih = inv(Xh);
    y = Xh * expm(Xih * S * Xih) * Xh;
end

function y = exp_inv(X1, X2)
    % Inverse of exp_map (logarithm map) at base point X1 evaluated at X2:
    %   Log_X1(X2) = X1^(1/2) * logm(X1^(-1/2) * X2 * X1^(-1/2)) * X1^(1/2)
    % X1 and X2 are expected to be SPD.
    % The square root and its inverse are computed once and reused instead of
    % being recomputed for every factor.
    Xh = sqrtm(X1);
    Xih = inv(Xh);
    y = Xh * logm(Xih * X2 * Xih) * Xh;
end

function y = distance_manif(X1, X2)
    % Geodesic distance between X1 and X2:
    %   d(X1, X2) = sqrt( sum_i log(lambda_i)^2 ),
    % where lambda_i are the eigenvalues of inv(X1) * X2.
    lambdas = eig(inv(X1) * X2);
    acc = 0;
    % Iterate the eigenvalues one at a time (non-conjugate transpose keeps
    % any complex round-off values unchanged), same order as before.
    for lam = lambdas.'
        acc = acc + log(lam)^2;
    end
    y = sqrt(acc);
end

function grad = Grad_of_distance(X1, X2)
    % Gradient of the squared distance d(X1, X2)^2 with respect to X1
    % (X2 held fixed):  grad = -2 * Log_X1(X2), with Log given by exp_inv.
    log_dir = exp_inv(X1, X2);
    grad = -2 * log_dir;
end
function grad = Rie_Grad_f(X, dim, sigma)
    % Riemannian gradient used by the driver script (which steps along -grad).
    % Built from the distance to the dim x dim identity I:
    %   grad = (1/(2*sigma^2)) * grad_X(d(X, I)^2) * 2 * d(X, I)^2
    % NOTE(review): by the chain rule this is the gradient of
    % d(X, I)^4 / (2*sigma^2), not d(X, I)^2 / (2*sigma^2); confirm the
    % intended potential.
    ref_pt = eye(dim);
    g_sq = Grad_of_distance(X, ref_pt);
    d_sq = distance_manif(X, ref_pt)^2;
    grad = (1/(2*sigma * sigma)) * g_sq * 2 * d_sq;
end

function nor_result = inner_prod(X, U, V, dim)
    % Metric inner product at base point X:
    %   <U, V>_X = trace(X^-1 * U * X^-1 * V).
    % U and V may be given in vectorised (dim*dim x 1) form; they are reshaped
    % to dim x dim matrices before use.
    Xinv = inv(X);
    Umat = reshape(U, dim, dim);
    Vmat = reshape(V, dim, dim);
    nor_result = trace(Xinv * Umat * Xinv * Vmat);
end

function final_pt = BM(X, eta, dim)
    % One random geodesic step of "time" eta starting at the dim x dim matrix X.
    %
    % 1. Build a basis of n_tan = dim*(dim+1)/2 tangent directions that is
    %    orthonormal w.r.t. inner_prod at X, by Gram-Schmidt on random
    %    matrices tmp*tmp' (retrying if a normalisation is not real/finite).
    % 2. Draw n_tan i.i.d. standard normal coefficients, combine the basis
    %    with them into r, and move to exp_map(X, sqrt(eta/parti) * r).
    % The step is split into `parti` sub-steps (currently 1).
    %
    % Fixes relative to the previous version:
    %  * `check` is reset at the start of every retry. Before, a single failed
    %    attempt left check == 1 permanently, so the while-loop never exited.
    %  * All n_tan directions enter the increment. The old loop ran over
    %    1 : n_tan - 1, silently dropping the last basis vector and its
    %    normal coefficient.
    %  * mvnrnd(zeros, eye) is replaced by the equivalent randn draw.
    X = (X + X')/2;                        % symmetrize against round-off
    matrixdim = dim;
    n_tan = matrixdim * (matrixdim + 1)/2; % number of tangent directions
    tangent = zeros(matrixdim*matrixdim, n_tan);
    parti = 1;
    for i = 1 : parti
        while 1
            check = 0;                     % reset on every attempt
            for d = 1 : n_tan
                tmp = rand(matrixdim, matrixdim);
                tmp = tmp * tmp';
                ou = reshape(tmp, matrixdim*matrixdim, 1);
                % Remove components along the already-accepted directions.
                for dd = 1 : d-1
                    u = inner_prod(X, ou, tangent(:, dd), matrixdim) / inner_prod(X, tangent(:, dd), tangent(:, dd), matrixdim) * tangent(:, dd);
                    ou = ou - u;
                end
                ou = ou / sqrt(inner_prod(X, ou, ou, matrixdim));
                % A negative/zero squared norm (numerical breakdown) yields a
                % complex or non-finite vector: discard and start over.
                if ~isreal(ou) || any(~isfinite(ou))
                    check = 1;
                    X
                    break
                end
                tangent(:, d) = ou;
            end
            if check == 0
                break
            end
        end
        % Standard normal coefficients (zero mean, identity covariance).
        r_v = randn(n_tan, 1);
        r = reshape(tangent * r_v, matrixdim, matrixdim);
        X = exp_map(X, (eta/parti)^(0.5) * r);
    end
    final_pt = X;
end
function final_pt = BM2(x_k, eta, grad, dim)
    % One Langevin-type step: first move along the geodesic from x_k in the
    % direction eta * grad (the driver passes the negative gradient), then
    % apply the random step BM with time length 2 * eta.
    drifted = exp_map(x_k, eta * grad);
    final_pt = BM(drifted, 2*eta, dim);
end