% Driver script: runs `total` independent chains of length `inner` on the unit
% sphere in R^dim and records every state in store_x_RS_tmp(step, chain, :).
% Each chain step draws y = BM(x_k, ...), locates a point x with Find_mini3
% (started from a random unit vector), then keeps proposing R_Gauss(x, ...)
% until a uniform draw falls below eval_rho2(...).

% --- model parameters ---
kappa = 10;
dim = 3;
mu = [10; 0.1; 2];          % dim-by-1 column vector
eta = 0.05;                 % passed to BM / Find_mini3 / eval_rho2
t_proposal = 0.06;          % passed to R_Gauss as the proposal scale

% --- standard normal used to draw random directions ---
mean_gau = zeros(dim, 1);
var_gau = eye(dim);

% --- run sizes and storage ---
total = 1000;               % number of chains
inner = 50;                 % states stored per chain
store_x_RS_tmp = zeros(inner, total, dim);

% --- counters / flags (count_G_03, store, check_nan are not used below) ---
count_G_03 = 0;
store = 0;
check_nan = 0;
count_rej = 0;              % incremented once per proposal drawn (accepted or not)

for i = 1 : total
    i
    % Random unit vector as the chain's first state.
    start_pt = mvnrnd(mean_gau, var_gau);
    start_pt = (start_pt / norm(start_pt))';
    x_k = start_pt;
    store_x_RS_tmp(1, i, :) = x_k;

    for k = 1 : inner-1
        y = BM(x_k, eta, dim);

        % Random unit vector used as the starting point for Find_mini3.
        x = mvnrnd(mean_gau, var_gau);
        x = (x / norm(x))';
        [x, final] = Find_mini3(x, kappa, mu, y, eta, dim);

        % Propose around x until one proposal is accepted.
        while true
            proposal2 = R_Gauss(x, t_proposal, dim);
            rho = eval_rho2(proposal2, kappa, mu, y, eta, x, dim, t_proposal, final);

            u = rand();
            count_rej = count_rej + 1;
            if ~(u < rho)
                continue
            end
            x_k = proposal2;
            break
        end

        store_x_RS_tmp(k+1, i, :) = x_k;
    end
end

function final_pt = BM(x_k, eta, dim)
% BM  Draw a point near x_k by rejection sampling.
%   Repeatedly proposes y = R_Gauss(x_k, eta, dim) and accepts it when a
%   uniform draw u is below
%       0.08 * heat_l_eval(x_k, y, eta) / exp(-d^2 / (2*eta)),
%   where d = acos(x_k' * y).
%
%   x_k  - current point (column vector; assumed unit length - TODO confirm)
%   eta  - scale passed to R_Gauss and heat_l_eval
%   dim  - ambient dimension, forwarded to R_Gauss
%
%   Returns the first accepted proposal.
    while true
        u = rand();
        y = R_Gauss(x_k, eta, dim);

        % Round-off can push the inner product of two unit vectors slightly
        % outside [-1, 1], where acos returns a complex value. Clamp it first
        % (NaN is left untouched so it is not silently hidden).
        cos_angle = x_k' * y;
        if cos_angle > 1
            cos_angle = 1;
        elseif cos_angle < -1
            cos_angle = -1;
        end
        di = acos(cos_angle);

        rhoB = 0.08 * heat_l_eval(x_k, y, eta) / exp(-di^2 / (2*eta));
        if u < rhoB
            final_pt = y;
            break
        end
    end
end

function result = grad_heat(x, y, t)
% grad_heat  Scalar derivative of the series in heat_l_eval with respect to
%            the inner product x' * y.
%   Each entry T(l+1), l = 0..17, is
%       exp(-l*(l+1)*t/2) * (2*l + 1)/(4*pi) * q_l(inner_pro),
%   where q_l is the term-by-term derivative (coefficient * power * exponent)
%   of the degree-l polynomial used for the same l in heat_l_eval.
%   NOTE(review): those polynomials look like Legendre polynomials P_l, which
%   would make this the derivative of a truncated heat-kernel expansion on
%   the 2-sphere - confirm against the derivation.
%
%   x, y - vectors whose product x' * y is used as a scalar (presumably
%          column vectors of equal length - TODO confirm)
%   t    - parameter appearing in the exponential decay factors
%
%   Returns the sum of the 18 entries. This is a scalar, not a vector; the
%   caller Rie_Grad2 multiplies it by y.
    inner_pro = x' * y;
    T = zeros(18, 1);
    % The l = 0 term of heat_l_eval is constant, so its derivative is zero.
    T(1, 1) = 0; %l = 0
    T(2, 1) = exp(-1*2*t/2) * (2*1 + 1)/(4*pi) ;
    T(3, 1) = exp(-2*3*t/2) * (2*2 + 1)/(4*pi) * (3*inner_pro^1 * 2 )/2;
    T(4, 1) = exp(-3*4*t/2) * (2*3 + 1)/(4*pi) * (5*inner_pro^2 * 3 - 3)/2; % l = 3
    T(5, 1) = exp(-4*5*t/2) * (2*4 + 1)/(4*pi) * (35*inner_pro^3 * 4 - 30*inner_pro^1 * 2)/8; 
    T(6, 1) = exp(-5*6*t/2) * (2*5 + 1)/(4*pi) * (63*inner_pro^4 * 5 - 70*inner_pro^2 * 3 + 15)/8; 
    T(7, 1) = exp(-6*7*t/2) * (2*6 + 1)/(4*pi) * (231*inner_pro^5 * 6 - 315*inner_pro^3 * 4 + 105*inner_pro^1 * 2)/16; 
    T(8, 1) = exp(-7*8*t/2) * (2*7 + 1)/(4*pi) * (429*inner_pro^6 * 7 - 693*inner_pro^4 * 5 + 315*inner_pro^2 * 3 - 35)/16;
    T(9, 1) = exp(-8*9*t/2) * (2*8 + 1)/(4*pi) * (6435*inner_pro^7 * 8 - 12012*inner_pro^5 * 6 + 6930*inner_pro^3 * 4 - 1260*inner_pro^1 * 2)/128; 
    T(10, 1) = exp(-9*10*t/2) * (2*9 + 1)/(4*pi) * (12155*inner_pro^8 * 9 - 25740*inner_pro^6 * 7 + 18018*inner_pro^4 * 5 - 4620*inner_pro^2 * 3 + 315)/128; 
    T(11, 1) = exp(-10*11*t/2) * (2*10 + 1)/(4*pi) * (46189*inner_pro^9 * 10 - 109395*inner_pro^7 * 8 + 90090*inner_pro^5 * 6 - 30030*inner_pro^3 * 4 + 3465*inner_pro^1 * 2 )/256; 
    T(12, 1) = exp(-11*12*t/2) * (2*11 + 1)/(4*pi) * (88179 *inner_pro^10 * 11 - 230945 *inner_pro^8 * 9 + 218790 *inner_pro^6 * 7 - 90090 *inner_pro^4 * 5 + 15015 *inner_pro^2 * 3 - 693)/256; 

    T(13, 1) = exp(-12*13*t/2) * (2*12 + 1)/(4*pi) * (676039*inner_pro^11 * 12 - 1939938*inner_pro^9 * 10 + 2078505*inner_pro^7 * 8 - 1021020*inner_pro^5 * 6 + 225225*inner_pro^3 * 4 - 18018*inner_pro^1 * 2)/1024; 
    T(14, 1) = exp(-13*14*t/2) * (2*13 + 1)/(4*pi) * (1300075*inner_pro^12 * 13 - 4056234*inner_pro^10 * 11 + 4849845*inner_pro^8 * 9 - 2771340*inner_pro^6 * 7 + 765765*inner_pro^4 * 5 - 90090*inner_pro^2 * 3 + 3003)/1024; 
    T(15, 1) = exp(-14*15*t/2) * (2*14 + 1)/(4*pi) * (5014575 *inner_pro^13 * 14 - 16900975 *inner_pro^11 * 12 + 22309287 *inner_pro^9 * 10 - 14549535 *inner_pro^7 * 8 + 4849845 *inner_pro^5 * 6 - 765765 *inner_pro^3 * 4 + 45045 *inner_pro^1 * 2)/2048; 
    T(16, 1) = exp(-15*16*t/2) * (2*15 + 1)/(4*pi) * (9694845 *inner_pro^14 * 15 - 35102025 *inner_pro^12 * 13 + 50702925 *inner_pro^10 * 11 - 37182145 *inner_pro^8 * 9 + 14549535 *inner_pro^6 * 7 - 2909907 *inner_pro^4 * 5 + 255255 *inner_pro^2 * 3 - 6435)/2048; 
    T(17, 1) = exp(-16*17*t/2) * (2*16 + 1)/(4*pi) * (300540195 *inner_pro^15 * 16 - 1163381400 *inner_pro^13 * 14 + 1825305300 *inner_pro^11 * 12 - 1487285800 *inner_pro^9 * 10 + 669278610 *inner_pro^7 * 8 - 162954792 *inner_pro^5 * 6 + 19399380 *inner_pro^3 * 4 - 875160 *inner_pro^1 * 2)/32768; 
    T(18, 1) = exp(-17*18*t/2) * (2*17 + 1)/(4*pi) * (583401555 *inner_pro^16 * 17 - 2404321560 *inner_pro^14 * 15 + 4071834900 *inner_pro^12 * 13 - 3650610600 *inner_pro^10 * 11 + 1859107250 *inner_pro^8 * 9 - 535422888 *inner_pro^6 * 7 + 81477396 *inner_pro^4 * 5 - 5542680 *inner_pro^2 * 3 + 109395)/32768; 

    % Accumulate the terms in index order, starting from zero.
    result = 0;
    for l = 1 : 18
        result = result + T(l, 1);
    end 
end 

function grad = Rie_Grad2(x, kappa, mu, y, eta, dim)
% Rie_Grad2  Gradient of the objective in f_d_eval2 with the component along
%            x removed.
%   Forms the ambient-space vector
%       -kappa * mu - (grad_heat / heat_l_eval) * y
%   and subtracts its projection onto x, so the result is orthogonal to x.
%
%   dim is accepted for signature compatibility and is not used here.
    log_heat_slope = grad_heat(x, y, eta)/heat_l_eval(x, y, eta);
    ambient = - kappa * mu - log_heat_slope * y ;

    % Component of `ambient` along x (x need not be exactly unit length).
    along_x = (ambient' * x) / norm(x)^2 * x;

    grad = ambient - along_x;
end

function y = exp_map(x, v)
% exp_map  Move from x along tangent vector v on the unit sphere.
%   Computes  y = x * cos(|v|) + (v / |v|) * sin(|v|).
%
%   x - base point (column vector; assumed unit length - TODO confirm)
%   v - step vector of the same size as x (assumed orthogonal to x by the
%       callers in this file)
%
%   A zero-length v returns x unchanged; without this guard v/|v| would be
%   0/0 and the result would be all NaN.
    step_len = norm(v);
    if step_len == 0
        y = x;
        return
    end
    y = x * cos(step_len) + v / step_len * sin(step_len);
end

function result = f_d_eval2(x, kappa, mu, y, eta, dim)
% f_d_eval2  Objective value at x:
%       -kappa * mu' * x  -  log(heat_l_eval(x, y, eta)).
%   The result is complex whenever heat_l_eval returns a negative value
%   (log of a negative number); Find_mini3 checks for that case.
%
%   dim is accepted for signature compatibility and is not used here.
    linear_part = - kappa * mu' * x;
    heat_value = heat_l_eval(x, y, eta);
    result = linear_part - log(heat_value);
end

function result = eval_rho2(x, kappa, mu, y, eta, x_star, dim, t, final)
% eval_rho2  Acceptance ratio for a proposal x drawn around x_star.
%   Returns
%       exp(-f_d_eval2(x, ...) + final) * exp(d^2 / (2*t)),
%   where d = acos(x' * x_star).
%
%   x       - proposed point (column vector)
%   x_star  - centre of the proposal (column vector)
%   t       - proposal scale
%   final   - objective value supplied by the caller (the script passes the
%             value returned by Find_mini3 at x_star)
%   kappa, mu, y, eta, dim - forwarded to f_d_eval2
    K = 1/(2*t);
    current = f_d_eval2(x, kappa, mu, y, eta, dim);

    % Round-off can push the inner product of two unit vectors slightly
    % outside [-1, 1], where acos returns a complex value. Clamp it first
    % (NaN is left untouched so it is not silently hidden).
    cos_angle = x' * x_star;
    if cos_angle > 1
        cos_angle = 1;
    elseif cos_angle < -1
        cos_angle = -1;
    end
    geo_dist = acos(cos_angle);

    result = exp( - current + final) * exp( K * geo_dist^2);
end

function result = heat_l_eval(x, y, t)
% heat_l_eval  Evaluate an 18-term series in the inner product x' * y.
%   Each entry T(l+1), l = 0..17, is
%       exp(-l*(l+1)*t/2) * (2*l + 1)/(4*pi) * p_l(inner_pro),
%   with p_l a hard-coded polynomial of degree l.
%   NOTE(review): the p_l look like Legendre polynomials P_l, which would
%   make this a truncated heat-kernel expansion on the 2-sphere - confirm
%   against the derivation.
%
%   x, y - vectors whose product x' * y is used as a scalar (presumably
%          column vectors of equal length - TODO confirm)
%   t    - parameter appearing in the exponential decay factors
%
%   Returns 0.0002 plus the sum of the 18 entries.
    inner_pro = x' * y;
    T = zeros(18, 1);
    T(1, 1) = 1/(4*pi); %l = 0
    T(2, 1) = exp(-1*2*t/2) * (2*1 + 1)/(4*pi) * inner_pro;
    T(3, 1) = exp(-2*3*t/2) * (2*2 + 1)/(4*pi) * (3*inner_pro^2 - 1)/2;
    T(4, 1) = exp(-3*4*t/2) * (2*3 + 1)/(4*pi) * (5*inner_pro^3 - 3*inner_pro)/2; % l = 3
    T(5, 1) = exp(-4*5*t/2) * (2*4 + 1)/(4*pi) * (35*inner_pro^4 - 30*inner_pro^2 + 3)/8; 
    T(6, 1) = exp(-5*6*t/2) * (2*5 + 1)/(4*pi) * (63*inner_pro^5 - 70*inner_pro^3 + 15*inner_pro)/8; 
    T(7, 1) = exp(-6*7*t/2) * (2*6 + 1)/(4*pi) * (231*inner_pro^6 - 315*inner_pro^4 + 105*inner_pro^2 - 5)/16; 
    T(8, 1) = exp(-7*8*t/2) * (2*7 + 1)/(4*pi) * (429*inner_pro^7 - 693*inner_pro^5 + 315*inner_pro^3 - 35*inner_pro)/16; 
    T(9, 1) = exp(-8*9*t/2) * (2*8 + 1)/(4*pi) * (6435*inner_pro^8 - 12012*inner_pro^6 + 6930*inner_pro^4 - 1260*inner_pro^2 + 35)/128; 
    T(10, 1) = exp(-9*10*t/2) * (2*9 + 1)/(4*pi) * (12155*inner_pro^9 - 25740*inner_pro^7 + 18018*inner_pro^5 - 4620*inner_pro^3 + 315*inner_pro)/128; 
    T(11, 1) = exp(-10*11*t/2) * (2*10 + 1)/(4*pi) * (46189*inner_pro^10 - 109395*inner_pro^8 + 90090*inner_pro^6 - 30030*inner_pro^4 + 3465*inner_pro^2 - 63)/256; 
    T(12, 1) = exp(-11*12*t/2) * (2*11 + 1)/(4*pi) * (88179 *inner_pro^11 - 230945 *inner_pro^9 + 218790 *inner_pro^7 - 90090 *inner_pro^5 + 15015 *inner_pro^3 - 693 *inner_pro)/256; 
    T(13, 1) = exp(-12*13*t/2) * (2*12 + 1)/(4*pi) * (676039*inner_pro^12 - 1939938*inner_pro^10 + 2078505*inner_pro^8 - 1021020*inner_pro^6 + 225225*inner_pro^4 - 18018*inner_pro^2 + 231)/1024; 
    T(14, 1) = exp(-13*14*t/2) * (2*13 + 1)/(4*pi) * (1300075*inner_pro^13 - 4056234*inner_pro^11 + 4849845*inner_pro^9 - 2771340*inner_pro^7 + 765765*inner_pro^5 - 90090*inner_pro^3 + 3003*inner_pro)/1024; 
    T(15, 1) = exp(-14*15*t/2) * (2*14 + 1)/(4*pi) * (5014575 *inner_pro^14 - 16900975 *inner_pro^12 + 22309287 *inner_pro^10 - 14549535 *inner_pro^8 + 4849845 *inner_pro^6 - 765765 *inner_pro^4 + 45045 *inner_pro^2 - 429)/2048; 
    T(16, 1) = exp(-15*16*t/2) * (2*15 + 1)/(4*pi) * (9694845 *inner_pro^15 - 35102025 *inner_pro^13 + 50702925 *inner_pro^11 - 37182145 *inner_pro^9 + 14549535 *inner_pro^7 - 2909907 *inner_pro^5 + 255255 *inner_pro^3 - 6435 *inner_pro)/2048; 
    T(17, 1) = exp(-16*17*t/2) * (2*16 + 1)/(4*pi) * (300540195 *inner_pro^16 - 1163381400 *inner_pro^14 + 1825305300 *inner_pro^12 - 1487285800 *inner_pro^10 + 669278610 *inner_pro^8 - 162954792 *inner_pro^6 + 19399380 *inner_pro^4 - 875160 *inner_pro^2 + 6435)/32768; 
    T(18, 1) = exp(-17*18*t/2) * (2*17 + 1)/(4*pi) * (583401555 *inner_pro^17 - 2404321560 *inner_pro^15 + 4071834900 *inner_pro^13 - 3650610600 *inner_pro^11 + 1859107250 *inner_pro^9 - 535422888 *inner_pro^7 + 81477396 *inner_pro^5 - 5542680 *inner_pro^3 + 109395 *inner_pro)/32768; 

    % The accumulator starts at 0.0002 rather than 0, so every returned value
    % carries that constant offset.
    % NOTE(review): presumably a floor that keeps the value positive for the
    % log in f_d_eval2 and the division in Rie_Grad2 - confirm.
    result = 0.0002;
    for l = 1 : 18
        result = result + T(l, 1);
    end 
end 

function [x, final] = Find_mini3(x, kappa, mu, y, eta, dim)
% Find_mini3  Gradient descent on the sphere for the objective f_d_eval2.
%   Starting from x, repeats  x <- exp_map(x, -stepsize * Rie_Grad2(x, ...))
%   for at most 1600 iterations. The step size is 0.01 for the first 100
%   iterations and 0.005 afterwards. Stops early once the gradient norm
%   (measured before the step) drops below 0.001.
%
%   x      - starting point (column vector of length dim)
%   kappa, mu, y, eta - forwarded to Rie_Grad2 / f_d_eval2
%   dim    - ambient dimension; also used for the random restart below
%
%   Returns the final point, renormalised to unit length, and the objective
%   value there.
    stepsize = 0.01;
    for iter = 1 : 1600
        if iter > 100
            stepsize = 0.005;
        end

        Grad = Rie_Grad2(x, kappa, mu, y, eta, dim);
        x = exp_map(x, - Grad * stepsize);
        final = f_d_eval2(x, kappa, mu, y, eta, dim);

        if ~isreal(final)
            % The objective became complex (log of a negative value):
            % restart from a fresh random unit vector of the requested
            % dimension instead of a hard-coded 3-vector.
            x = mvnrnd(zeros(dim, 1), eye(dim));
            x = x / norm(x);
            x = x';
            % Grad belongs to the discarded point, so it must not be used
            % to declare convergence for the restarted one.
            continue
        end

        if norm(Grad) < 0.001
            break
        end
    end
    x = x / norm(x);
    final = f_d_eval2(x, kappa, mu, y, eta, dim);
end

function final_pt = R_Gauss(x_k, eta, dim)
% R_Gauss  Random step from x_k on the sphere.
%   1. Builds dim-1 orthonormal vectors orthogonal to x_k by Gram-Schmidt
%      on standard normal draws.
%   2. Draws standard normal coordinates r_v in that basis, redrawing while
%      the scaled length |sqrt(eta) * r_v| exceeds pi, and accepts with
%      probability sin(len)/len.
%   3. Returns exp_map(x_k, sqrt(eta) * (basis * r_v)).
%
%   x_k - base point (column vector of length dim)
%   eta - variance-like scale; the step is scaled by eta^0.5
%   dim - ambient dimension
    zero_mean = zeros(dim, 1);
    unit_cov = eye(dim);

    % --- 1. orthonormal basis of the subspace orthogonal to x_k ---
    basis = zeros(dim, dim - 1);
    for col = 1 : dim - 1
        draw = mvnrnd(zero_mean, unit_cov)';
        vec = draw - (draw' * x_k) / norm(x_k)^2 * x_k;
        for prev = 1 : col - 1
            b = basis(:, prev);
            vec = vec - (vec' * b) / norm(b)^2 * b;
        end
        basis(:, col) = vec / norm(vec);
    end

    % --- 2. rejection-sample the coordinates ---
    scale = (eta)^(0.5);
    coord_mean = zeros(dim - 1, 1);
    coord_cov = eye(dim - 1);
    while true
        coords = mvnrnd(coord_mean, coord_cov)';
        step_len = norm(scale * coords);
        while step_len > pi
            coords = mvnrnd(coord_mean, coord_cov)';
            step_len = norm(scale * coords);
        end
        u = rand();
        if u < (sin(step_len) / step_len)
            break
        end
    end

    % --- 3. assemble the step vector and map it onto the sphere ---
    direction = zeros(dim, 1);
    for col = 1 : dim - 1
        direction = direction + basis(:, col) * coords(col, 1);
    end
    final_pt = exp_map(x_k, scale * direction);
end

