
%% Mirror Langevin with Different Discretizations

clear; close all;   % `clear` (not `clear all`) resets the workspace without flushing compiled code

%% Specify potential and parameters

% Box constraint: samples live in (-a,a) x (-b,b). Coordinate x(1) lies in
% (-a,a) and is drawn on the horizontal axis; x(2) lies in (-b,b) and is
% drawn on the vertical axis.
a = 0.01;
b = 1.0;

h = 1e-5;  % step size of the outer Langevin chains
% Scale of the log-barrier potential V; the mirror Langevin chains below
% target the density proportional to exp(-beta*V) on the box.
beta = 1e-4;

iter = 500;  % number of outer steps in each run

% iteration count and step size for the iterative inner solvers
% (mirror-map inversion / prox step)
inner_iter = 50;
stepsize = 1e-2;

repeat = 200;  % number of independent chains (= points per scatter plot)

num = 10;   % Euler-Maruyama sub-steps per outer step (forward/backward schemes)

%% Uniform draws from box
% Reference scatter: `repeat` i.i.d. uniform points on (-a,a) x (-b,b), so the
% picture has as many points as each sampler's scatter plot. Base-MATLAB
% `rand` is used so the Statistics Toolbox is not required.

x_axis = -a + 2*a*rand(1,repeat);
y_axis = -b + 2*b*rand(1,repeat);
figure;
plot(x_axis,y_axis,'r.')
title(['\fontsize{20}Uniform draws from box'])
set(gca,'FontSize',16)

%% EM discretization, target approximate distribution
% Mirror Langevin with mirror map phi = beta*V and potential f = beta*V, where
% V(x) = -log(a^2 - x1^2) - log(b^2 - x2^2) is the log barrier of the box
% (-a,a) x (-b,b) (x1 horizontal, matching xlim([-a,a]) below). One
% Euler-Maruyama step in the dual space, then a map back to the primal space:
%   y_{k+1} = grad phi(x_k) - h*grad f(x_k) + sqrt(2h)*(hess phi(x_k))^{1/2} z
%   x_{k+1} = (grad phi)^{-1}(y_{k+1})
%
% The barrier is separable, so grad phi is inverted coordinate-wise in closed
% form. Fixed-step gradient ascent is poorly conditioned for this map: the
% curvature beta*hess_V ranges from 2*beta/b^2 at the centre of the wide
% coordinate up to infinity at the walls. A single step size therefore
% either stalls in x2 or overshoots the walls in x1.

half_width_sq = [a^2; b^2];   % per-coordinate c: coordinate i lies in (-sqrt(c_i), sqrt(c_i))
grad_V      = @(x) 2*x ./ (half_width_sq - x.^2);
hess_V_diag = @(x) 2*(half_width_sq + x.^2) ./ (half_width_sq - x.^2).^2;  % hess_V is diagonal
% Solve s*grad_V(x) = y for x: root of y*x^2 + 2*s*x - y*c = 0 in (-sqrt(c), sqrt(c)),
% written in rationalised form so it is stable at y = 0.
grad_inv    = @(y, s) y .* half_width_sq ./ (s + sqrt(s^2 + y.^2 .* half_width_sq));

x_save = zeros(repeat,2);

for j = 1:repeat

    x = zeros(iter+1,2);  % trajectory; every chain starts at the origin

    for i = 1:iter
        x_cur = x(i,:)';
        z = randn(2,1);
        % drift + diffusion in the dual space; hess phi is diagonal, so its
        % matrix square root is the element-wise square root
        next_y = (1-h)*beta*grad_V(x_cur) + sqrt(2*h*beta*hess_V_diag(x_cur)) .* z;
        % exact map back to the primal space
        x(i+1,:) = grad_inv(next_y, beta)';
    end

    x_save(j,:) = x(end,:);  % final iterate x_{iter} (row iter+1)
end

figure();
plot(x_save(:,1),x_save(:,2),'r.')
ylim([-b,b])
xlim([-a,a])
title(['\fontsize{20}Discretize objective and geometry (EM)'])
set(gca,'FontSize',16)

%% Forward Discretization
% Mirror Langevin with phi = f = beta*V, V(x) = -log(a^2 - x1^2) - log(b^2 - x2^2)
% on the box (-a,a) x (-b,b) (x1 horizontal, matching xlim([-a,a]) below).
% Each outer step takes an explicit mirror-descent drift step in the dual
% space. It then integrates the diffusion
%   dY = sqrt(2 * hess phi((grad phi)^{-1}(Y))) dW
% over time h with `num` Euler-Maruyama sub-steps, and finally maps back
% to the primal space.
% The separable barrier lets grad phi be inverted exactly, coordinate-wise.

half_width_sq = [a^2; b^2];   % per-coordinate c: coordinate i lies in (-sqrt(c_i), sqrt(c_i))
grad_V      = @(x) 2*x ./ (half_width_sq - x.^2);
hess_V_diag = @(x) 2*(half_width_sq + x.^2) ./ (half_width_sq - x.^2).^2;  % hess_V is diagonal
% Solve s*grad_V(x) = y for x (rationalised root, stable at y = 0).
grad_inv    = @(y, s) y .* half_width_sq ./ (s + sqrt(s^2 + y.^2 .* half_width_sq));

x_save = zeros(repeat,2);

for j = 1:repeat

    x = zeros(iter+1,2);  % initialize with (0,0)

    for i = 1:iter

        % mirror descent step: y_0 = grad phi(x_k) - h*grad f(x_k)
        next_y = (1-h)*beta*grad_V(x(i,:)');

        % diffusion part: num Euler-Maruyama sub-steps of length h/num, with
        % the Hessian evaluated at the primal point of the current dual iterate
        for m = 1:num
            x_sub = grad_inv(next_y, beta);
            next_y = next_y + sqrt(2*(h/num)*beta*hess_V_diag(x_sub)) .* randn(2,1);
        end

        % invert the final dual iterate to get x_{k+1}
        x(i+1,:) = grad_inv(next_y, beta)';
    end

    x_save(j,:) = x(end,:);  % final iterate x_{iter} (row iter+1)
end

figure();
plot(x_save(:,1),x_save(:,2),'r.')
ylim([-b,b])
xlim([-a,a])
title(['\fontsize{20}Discretize objective not geometry (forward)'])
set(gca,'FontSize',16)


%% Backward Discretization
% Mirror Langevin with phi = f = beta*V, V(x) = -log(a^2 - x1^2) - log(b^2 - x2^2)
% on the box (-a,a) x (-b,b) (x1 horizontal, matching xlim([-a,a]) below).
% Each outer step first integrates the dual diffusion
%   dY = sqrt(2 * hess phi((grad phi)^{-1}(Y))) dW
% from y = grad phi(x_k) with `num` Euler-Maruyama sub-steps. It then takes
% an implicit (prox) step: solve grad phi(x) + h*grad f(x) = y for x_{k+1},
% i.e. (1+h)*beta*grad_V(x) = y.
% Both the inversion and the prox equation are solved exactly, coordinate-wise.

half_width_sq = [a^2; b^2];   % per-coordinate c: coordinate i lies in (-sqrt(c_i), sqrt(c_i))
grad_V      = @(x) 2*x ./ (half_width_sq - x.^2);
hess_V_diag = @(x) 2*(half_width_sq + x.^2) ./ (half_width_sq - x.^2).^2;  % hess_V is diagonal
% Solve s*grad_V(x) = y for x (rationalised root, stable at y = 0).
grad_inv    = @(y, s) y .* half_width_sq ./ (s + sqrt(s^2 + y.^2 .* half_width_sq));

x_save = zeros(repeat,2);

for j = 1:repeat

    x = zeros(iter+1,2);  % initialize with (0,0)

    for i = 1:iter

        % start the diffusion from the dual point of the last iterate
        x_sub = x(i,:)';
        next_y = beta*grad_V(x_sub);

        % diffusion part: num Euler-Maruyama sub-steps of length h/num; the
        % Hessian is taken at the primal point of the current dual iterate
        for m = 1:num
            next_y = next_y + sqrt(2*(h/num)*beta*hess_V_diag(x_sub)) .* randn(2,1);
            x_sub = grad_inv(next_y, beta);
        end

        % prox step for next x: (1+h)*beta*grad_V(x) = next_y
        x(i+1,:) = grad_inv(next_y, (1+h)*beta)';
    end

    x_save(j,:) = x(end,:);  % final iterate x_{iter} (row iter+1)
end

figure();
plot(x_save(:,1),x_save(:,2),'r.')
ylim([-b,b])
xlim([-a,a])
title(['\fontsize{20}Discretize objective not geometry (backward)'])
set(gca,'FontSize',16)



%% Projected Langevin
% target uniform distribution directly (no bias but slow)
% The potential is the indicator of the box [-a,a] x [-b,b]. Its gradient is
% zero in the interior, so each step is a pure Gaussian increment followed by
% the Euclidean projection onto the box, which is a coordinate-wise clamp.

box_lo = [-a, -b];
box_hi = [ a,  b];
x_save = zeros(repeat,2);

for j = 1:repeat

    x = zeros(iter+1,2);  % full trajectory, preallocated; start at (0,0)

    for i = 1:iter
        next_x = x(i,:) + sqrt(2*h)*randn(1,2);
        x(i+1,:) = min(max(next_x, box_lo), box_hi);   % projection onto the box
    end

    x_save(j,:) = x(end,:);  % final iterate x_{iter} (row iter+1)
end

figure();
plot(x_save(:,1),x_save(:,2),'r.')
ylim([-b,b])
xlim([-a,a])
title(['\fontsize{20}Projected Langevin'])
set(gca,'FontSize',16)


%% Sample from ill-conditioned gaussian (Newton Langevin with EM)
% no bias but helps with change of basis
% Target N(mu, V) with V = diag(1,...,dim), i.e. condition number dim.
% Newton Langevin is mirror Langevin with mirror map phi = f, where
% f(x) = (x-mu)'*V_inv*(x-mu)/2. grad f(x) = V_inv*(x-mu) is affine, so the
% mirror map has the exact inverse (grad f)^{-1}(y) = mu + V*y and no inner
% solver is needed.

dim = 50;
repeat = 200;
iter = 500;  % length of each run
x_save = zeros(repeat,dim,iter+1);   % iterates 0..iter of every chain
h = 1e-3;  % step size

% potential function: f(x) = (x-mu)'*V_inv*(x-mu)/2 (target covariance V)
V_inv = diag(1./(1:dim));
V = inv(V_inv);
mu = ones(dim,1);
sqrt_hess = sqrtm(V_inv);   % hess f = V_inv is constant: factor it once

for j = 1:repeat

    x = zeros(iter+1,dim);
    x(1,:) = randn(1,dim);   % x_0 ~ N(0, I)

    for i = 1:iter
        grad = V_inv*(x(i,:)'-mu);
        z = randn(dim,1);
        % dual step: y = grad f(x) - h*grad f(x) + sqrt(2h)*(hess f)^{1/2} z
        next_y = (1-h)*grad + sqrt(2*h)*sqrt_hess*z;
        % exact inverse of the mirror map
        x(i+1,:) = (mu + V*next_y)';
    end

    x_save(j,:,:) = reshape(x', 1, dim, iter+1);
end


% Calculate the mean difference and the relative covariance error of the
% empirical distribution across the $repeat$ chains, at every stored
% iterate 0..iter (slice 1 of x_save is the initialization).
% norm() of a matrix is the spectral (2-)norm.
num_snapshots = size(x_save,3);
mean_err_newton = zeros(num_snapshots,1);
covariance_err_newton = zeros(num_snapshots,1);

for i = 1:num_snapshots
    samples = x_save(:,:,i);               % repeat x dim
    estimate_mean = mean(samples,1);
    mean_err_newton(i) = norm(estimate_mean-mu');
    % cov(., 1) normalises by repeat (population covariance about the sample mean)
    covariance_err_newton(i) = norm(cov(samples,1) - V)/norm(V);
end


%% Same gaussian setup, with mirror map = I (ULA)
% Unadjusted Langevin algorithm: x_{k+1} = x_k - h*grad f(x_k) + sqrt(2h)*z,
% with grad f(x) = V_inv*(x-mu).

h = 1e-3;  % step size
x_save = zeros(repeat,dim,iter+1);   % iterates 0..iter of every chain

for j = 1:repeat

    x = zeros(iter+1,dim);
    x(1,:) = randn(1,dim);   % x_0 ~ N(0, I)

    for i = 1:iter
        grad = V_inv*(x(i,:)'-mu);
        x(i+1,:) = (x(i,:)' - h*grad + sqrt(2*h)*randn(dim,1))';
    end

    x_save(j,:,:) = reshape(x', 1, dim, iter+1);
end


% Calculate the mean difference and the relative covariance error of the
% empirical distribution across the $repeat$ chains, at every stored
% iterate 0..iter (slice 1 of x_save is the initialization).
% norm() of a matrix is the spectral (2-)norm.
num_snapshots = size(x_save,3);
mean_err = zeros(num_snapshots,1);
covariance_err = zeros(num_snapshots,1);

for i = 1:num_snapshots
    samples = x_save(:,:,i);               % repeat x dim
    estimate_mean = mean(samples,1);
    mean_err(i) = norm(estimate_mean-mu');
    % cov(., 1) normalises by repeat (population covariance about the sample mean)
    covariance_err(i) = norm(cov(samples,1) - V)/norm(V);
end

% plot comparison
% Entry 1 of each error vector is measured at the initialization, so the
% horizontal axis is the iteration count k = 0, 1, 2, ...
iters_ula = 0:numel(mean_err)-1;
iters_nla = 0:numel(mean_err_newton)-1;

figure();
semilogy(iters_ula, mean_err)
hold on
semilogy(iters_nla, mean_err_newton)
legend('ULA','NLA','Location','southwest')
xlabel('iteration')
title(['\fontsize{20}Mean error over iterations'])
set(gca,'FontSize',16)

figure();
semilogy(0:numel(covariance_err)-1, covariance_err)
hold on
semilogy(0:numel(covariance_err_newton)-1, covariance_err_newton)
legend('ULA','NLA','Location','southwest')
xlabel('iteration')
title(['\fontsize{20}Covariance error over iterations'])
set(gca,'FontSize',16)


