% Driver script: generates `total` independent chains of `inner` points each
% and stores every visited dim-by-dim matrix, flattened, in store_x_RS.
%
% Each step (k >= 2) of a chain does the following:
%   1. y = BM(x_k, eta, dim)             - random move away from the current point
%   2. [x, final] = Find_mini(...)       - approximate minimiser of f_d_eval(., y)
%   3. rejection loop: draw proposal = BM(x, t_proposal, dim) and accept it
%      as the next x_k when a uniform draw falls below eval_rho(...).
warning('off','all')
dim = 3;                 % matrices are dim-by-dim
sigma = 0.03;            % enters f_d_eval through the weight 1/(2*sigma^2)

eta = 0.001;             % step parameter for BM and weight 1/(2*eta) in f_d_eval
t_proposal = 0.001001;   % step parameter of the proposal draw inside the rejection loop
count_G_03 = 0;          % not used below; kept so the workspace is unchanged

store = 0;               % not used below; kept so the workspace is unchanged
total = 1000;            % number of independent chains
inner = 100;             % number of stored points per chain
store_x_RS = zeros(inner, total, dim*dim);
check_nan = 0;           % not used below; kept so the workspace is unchanged
count_cpu = 0;           % not used below; kept so the workspace is unchanged
start_cpu = cputime;     % CPU time at the start of the run

count_rej = 0;           % counts every proposal drawn (accepted ones included)

for i = 1 : total
    i                    % intentionally unsuppressed: echoes the chain index as progress
    % Use `dim` rather than a literal 3 so that the start point always matches
    % the reshape(x_k, dim * dim, 1) below.
    start_pt = eye(dim);
    start_pt = start_pt * start_pt';
    x_k = 2*start_pt;    % every chain starts at 2 * identity
    tostore = reshape(x_k, dim * dim, 1);
    store_x_RS(1, i, :) = tostore;
    for k = 2 : inner
        y = BM(x_k, eta, dim);
        % Find_mini overwrites its first argument immediately, so this draw
        % only advances the random number stream; it is kept so that the
        % sequence of random numbers is unchanged.
        x = rand(dim, dim);
        x = x * x';
        [x, final] = Find_mini(x, y, dim, sigma, eta);
        % Rejection loop: keep proposing around x until one proposal is accepted.
        while 1
            proposal = BM(x, t_proposal, dim);
            rho = eval_rho(proposal, y, dim, sigma, eta, x, final, t_proposal);
            u = rand();
            count_rej = count_rej + 1;
            if u < rho
                x_k = proposal;
                break
            end
        end
        tostore = reshape(x_k, dim * dim, 1);
        store_x_RS(k, i, :) = tostore;
    end
end

function nor_result = norm_square_manifold(X, S)
% NORM_SQUARE_MANIFOLD  Squared norm trace(X^-1 * S * X^-1 * S) of S at X.
%
%   nor_result = norm_square_manifold(X, S)
%
%   X - square invertible matrix (the base point)
%   S - square matrix of the same size as X
%
%   Returns the scalar trace(inv(X) * S * inv(X) * S).
    % Solve X * A = S once instead of forming inv(X) twice; this yields the
    % same quantity with one factorisation and less rounding error.
    A = X \ S;
    nor_result = trace(A * A);
end

function y = exp_map(X, S)
% EXP_MAP  Evaluate X^(1/2) * expm(X^(-1/2) * S * X^(-1/2)) * X^(1/2).
%
%   y = exp_map(X, S)
%
%   X - square matrix whose principal square root is invertible (base point)
%   S - square matrix of the same size as X (direction)
    % Compute the matrix square root once; it was previously recomputed four
    % times and explicitly inverted twice.
    R = sqrtm(X);
    % R \ S / R is inv(R) * S * inv(R) obtained through linear solves.
    y = R * expm(R \ S / R) * R;
end

function y = exp_inv(X1, X2)
% EXP_INV  Evaluate X1^(1/2) * logm(X1^(-1/2) * X2 * X1^(-1/2)) * X1^(1/2).
%
%   y = exp_inv(X1, X2)
%
%   X1 - square matrix whose principal square root is invertible (base point)
%   X2 - square matrix of the same size as X1 (target point)
%
%   This is the inverse of exp_map in its second argument:
%   exp_map(X1, exp_inv(X1, X2)) reproduces X2 up to rounding when the matrix
%   logarithm is well defined.
    % Compute the matrix square root once; it was previously recomputed four
    % times and explicitly inverted twice.
    R = sqrtm(X1);
    % R \ X2 / R is inv(R) * X2 * inv(R) obtained through linear solves.
    y = R * logm(R \ X2 / R) * R;
end

function y = distance_manif(X1, X2)
% DISTANCE_MANIF  sqrt(sum(log(lambda_i)^2)) over the eigenvalues of X1^-1 * X2.
%
%   y = distance_manif(X1, X2)
%
%   X1 - square invertible matrix
%   X2 - square matrix of the same size as X1
%
%   Returns the square root of the sum of squared logarithms of the
%   eigenvalues of inv(X1) * X2.
    % Backslash solves X1 * Z = X2 directly rather than forming inv(X1).
    lambda = eig(X1 \ X2);
    y = sqrt(sum(log(lambda).^2));
end

function grad = Grad_of_distance(X1, X2)
% GRAD_OF_DISTANCE  Return -2 * exp_inv(X1, X2).
%
%   grad = Grad_of_distance(X1, X2)
%
%   X1, X2 - square matrices of equal size, forwarded unchanged to exp_inv.
    logDirection = exp_inv(X1, X2);
    grad = -2 * logDirection;
end
function grad = Rie_Grad(X, Y, dim, sigma, eta)
% RIE_GRAD  Gradient-type combination matching the two terms of f_d_eval.
%
%   grad = Rie_Grad(X, Y, dim, sigma, eta)
%
%   X     - current dim-by-dim point
%   Y     - dim-by-dim reference point
%   dim   - matrix size; the first term is taken relative to eye(dim)
%   sigma - weight 1/(2*sigma^2) on the term relative to the identity
%   eta   - weight 1/(2*eta) on the term relative to Y
    identityMat = eye(dim);
    % First term: Grad_of_distance towards the identity, scaled by
    % 2 * distance^2 (the operand order of the original expression is kept).
    identityTerm = (1/(2*sigma * sigma)) * Grad_of_distance(X, identityMat) * 2 * distance_manif(X, identityMat)^2;
    % Second term: Grad_of_distance towards Y.
    referenceTerm = (1/(2*eta)) * Grad_of_distance(X, Y);
    grad = identityTerm + referenceTerm;
end

function result = f_d_eval(X, Y, dim, sigma, eta)
% F_D_EVAL  Objective d(X, I)^4 / (2*sigma^2) + d(X, Y)^2 / (2*eta).
%
%   result = f_d_eval(X, Y, dim, sigma, eta)
%
%   X     - dim-by-dim point at which the objective is evaluated
%   Y     - dim-by-dim reference point
%   dim   - matrix size; I above is eye(dim)
%   sigma - scale of the fourth-power term
%   eta   - scale of the squared term
%
%   d(.,.) is distance_manif.
    identityMat = eye(dim);
    quarticTerm = (1/(2*sigma * sigma)) * distance_manif(X, identityMat)^4;
    quadraticTerm = (1/(2*eta)) * distance_manif(X, Y)^2;
    result = quarticTerm + quadraticTerm;
end
function [X, final] = Find_mini(X, Y, dim, sigma, eta)
% FIND_MINI  Approximately minimise f_d_eval(., Y, dim, sigma, eta) by
% fixed-step descent along exp_map.
%
%   [X, final] = Find_mini(X, Y, dim, sigma, eta)
%
%   X     - ignored on input (kept for interface compatibility); the search
%           always starts from Y plus a small random perturbation
%   Y     - dim-by-dim reference point
%   dim   - matrix size
%   sigma - forwarded to Rie_Grad / f_d_eval
%   eta   - forwarded to Rie_Grad / f_d_eval
%
%   Returns the last iterate X and final = f_d_eval(X, Y, dim, sigma, eta).
%
%   When a trial point comes out complex, the iterate is replaced by a fresh
%   random matrix and the iteration budget is reset. The original code tried
%   to do this with `iter = 1` inside a `for` loop, which MATLAB ignores; a
%   `while` loop makes the reset effective. The number of resets is capped
%   so the routine always terminates.
    maxIter = 60;        % descent steps allowed since the last restart
    maxRestarts = 20;    % cap on budget resets, guarantees termination
    gradTol = 0.001;     % stop once the squared gradient norm drops below this

    % Starting point: Y plus 0.01 times a random matrix of the form A * A'
    % scaled to unit 2-norm. Uses `dim` rather than a hard-coded 3.
    X = rand(dim, dim);
    X = X * X';
    X = X / norm(X);
    X = Y + 0.01*X;

    iter = 1;
    restarts = 0;
    while iter <= maxIter
        Grad = Rie_Grad(X, Y, dim, sigma, eta);
        gradNormSq = norm_square_manifold(X, Grad);

        % Take a shorter step while the gradient is large.
        if gradNormSq < 1000
            stepsize = 0.0001;
        else
            stepsize = 0.00002;
        end

        trial_pt = exp_map(X, - Grad * stepsize);
        if ~isreal(trial_pt)
            % The step produced a complex matrix: restart from a fresh
            % random matrix.
            X = rand(dim, dim);
            X = X * X';
            X = X / norm(X);
            if restarts < maxRestarts
                restarts = restarts + 1;
                iter = 1;           % give the restarted search a full budget
            else
                iter = iter + 1;    % restart cap reached: just use up an iteration
            end
            continue
        end

        X = trial_pt;
        if gradNormSq < gradTol
            break
        end
        iter = iter + 1;
    end
    final = f_d_eval(X, Y, dim, sigma, eta);
end

function nor_result = inner_prod(X, U, V, dim)
% INNER_PROD  Inner product trace(X^-1 * U * X^-1 * V) of U and V at X.
%
%   nor_result = inner_prod(X, U, V, dim)
%
%   X   - dim-by-dim invertible matrix (the base point)
%   U   - array with dim*dim elements, reshaped to dim-by-dim
%   V   - array with dim*dim elements, reshaped to dim-by-dim
%   dim - matrix size
    U = reshape(U, dim, dim);
    V = reshape(V, dim, dim);
    % Use linear solves instead of forming inv(X) twice.
    nor_result = trace((X \ U) * (X \ V));
end

function final_pt = BM(X, eta, dim)
% BM  Take one random step from X with step parameter eta.
%
%   final_pt = BM(X, eta, dim)
%
%   X   - dim-by-dim starting point
%   eta - step parameter; the random direction is scaled by sqrt(eta)
%   dim - matrix size
%
%   Builds dim*(dim+1)/2 random symmetric directions, orthonormalises them
%   with respect to inner_prod at X, combines them with independent standard
%   normal coefficients, and moves along the result with exp_map.
    matrixdim = dim;
    nBasis = matrixdim * (matrixdim + 1) / 2;   % number of basis directions
    tangent = zeros(matrixdim*matrixdim, nBasis);
    parti = 1;                                  % number of sub-steps
    for i = 1 : parti
        % Gram-Schmidt on random symmetric matrices, using inner_prod at X.
        for d = 1 : nBasis
            tmp = rand(matrixdim, matrixdim);
            tmp = tmp * tmp';
            ou = reshape(tmp, matrixdim*matrixdim, 1);
            for dd = 1 : d-1
                prev = tangent(:, dd);
                u = inner_prod(X, ou, prev, matrixdim) / inner_prod(X, prev, prev, matrixdim) * prev;
                ou = ou - u;
            end
            ou = ou / sqrt(inner_prod(X, ou, ou, matrixdim));
            tangent(:, d) = ou;
        end

        % Independent standard normal coefficients, one per direction.
        % randn replaces mvnrnd with zero mean and identity covariance,
        % which avoids the Statistics Toolbox dependency.
        r_v = randn(nBasis, 1);

        % Combine ALL nBasis directions. The original loop stopped at
        % nBasis - 1 and so never used the last direction.
        r = reshape(tangent * r_v, matrixdim, matrixdim);
        X = exp_map(X, sqrt(eta / parti) * r);
    end
    final_pt = X;
end

function result = eval_rho(X, Y, dim, sigma, eta, X_star, final, t)
% EVAL_RHO  Acceptance ratio exp(final - f(X)) * exp(d(X, X_star)^2 / (2*t)).
%
%   result = eval_rho(X, Y, dim, sigma, eta, X_star, final, t)
%
%   X      - proposed dim-by-dim point
%   Y, dim, sigma, eta - forwarded to f_d_eval to evaluate f at X
%   X_star - point from which the distance to X is measured
%   final  - reference objective value added back in the exponent
%   t      - proposal step parameter; the distance term is weighted by 1/(2*t)
%
%   f is f_d_eval and d is distance_manif.
    objectiveAtX = f_d_eval(X, Y, dim, sigma, eta);
    invTwoT = 1/(2*t);
    % The two exponentials are kept as separate factors, as in the original.
    objectiveFactor = exp( - objectiveAtX + final);
    distanceFactor = exp( invTwoT * distance_manif(X, X_star)^2);
    result = objectiveFactor * distanceFactor;
end