%% Fit Prior Model 2 by MCMC to data simulated from the warpMix model
% Inputs:
% main_basis: type of basis ('Fourier' or 'BSpline') for the fixed effect
% basisSize_m: number of basis functions for the fixed effect
% random_basis: type of basis ('Fourier' or 'BSpline') for the random effects
% basisSize_v: number of basis functions for the random effects
% nG: number of grid points used to model the phase (warping) functions in Prior Model 2
% dataidx: 1, 2 or 3, selecting one of three different fixed-effect functions
%          (data are read from ./data/warpMixData<dataidx>.csv)
% Output: no return value; posterior samples of the parameters of interest are
%         saved, together with the rest of the workspace, to a .mat file under ./result/
function FMMPriorModel2warpMix(main_basis,basisSize_m, random_basis, basisSize_v, nG, dataidx)
% Run the Prior Model 2 MCMC sampler on ./data/warpMixData<dataidx>.csv.
% The CSV holds one curve per column (rows = time points). Each iteration
% updates, in order: fixed-effect coefficients a, random-effect variance
% sigmac, error variance sigma, and one phase function per curve.
% The whole workspace (samples, acceptance counts, tuned proposal settings)
% is saved to ./result/ every IterUpdate iterations and once more at the end.
rng(1)
addpath("./Functions")

% load in data: for a non-.mat file, load() parses ASCII numeric text and,
% when called with an output argument, returns the matrix directly, so no
% eval-based renaming of the auto-created variable is needed
data = load(sprintf("./data/warpMixData%d.csv",dataidx));
N = size(data,2);   % number of curves
T = size(data,1);   % number of time points per curve
qt = cell(1,N);
for ii = 1:N
    qt{ii} = data(:,ii)';
end
datatype = sprintf("warpMix%d",dataidx);
outfile = sprintf('./result/Data_%s_PriorModel2MCMC_main_%s_random_%s_basisnum_%d_%d_nG_%d.mat', ...
    datatype, main_basis, random_basis, basisSize_m, basisSize_v, nG);

t = linspace(0,1,T);
tG = linspace(0,1,nG)';
% forward-difference operator: (mat*x)(k) = x(k+1) - x(k), size (nG-1) x nG
mat = (diag(ones(nG-1,1),1)+diag(-ones(nG,1)));
mat = mat(1:(end-1),:);

% orthonormalised bases for the fixed effect (U1) and random effects (U2)
U1 = gram_schmidt_orthonormalization(localBuildBasis(main_basis, basisSize_m, t));
U2 = gram_schmidt_orthonormalization(localBuildBasis(random_basis, basisSize_v, t));

% set up for MCMC
IterUpdate = 1000;    % period of proposal tuning, bookkeeping and checkpoints
StartTuning = 10000;  % proposals are adapted only from this iteration on
% The windowed acceptance rates below read acceptance_*(updatenum), which is
% first written at the end of iteration IterUpdate, so tuning must start later.
if StartTuning <= IterUpdate
    error('FMMPriorModel2warpMix:tuningStart', 'StartTuning must exceed IterUpdate.');
end
% hyperparameters
muA = 1e-4; tauA = 1e4;
alphaSc = 0.01; betaSc = 0.01;
alphaS = 0.01; betaS = 0.01;
thetaG = 30;
covPrior_a = diag(tauA*ones(basisSize_m,1));   % prior covariance of a (loop invariant)
% initial proposal variance
covProposalmu = 1e-8*eye(basisSize_m);
thetaProposal = 100*ones(N,1);
corrmu = eye(basisSize_m);
def_by = 10;
varProposalSigma = 1e-6;
varProposalSigmac = 1e-4;
% store posterior samples
burn = 200000;
iterations = 100000 + burn;
sigma_samples = zeros(iterations,1);
sigmac_samples = zeros(iterations,1);
a_samples = zeros(iterations, basisSize_m);
gamma_samples = zeros(iterations, nG, N);
% store acceptance rate for proposal variance tuning
accepted_gamma = zeros(N,1);
accepted_a = 0;
accepted_sigma = 0;
accepted_sigmac = 0;
acceptance_a = zeros(iterations/IterUpdate,1);
acceptance_sigma = zeros(iterations/IterUpdate,1);
acceptance_sigmac = zeros(iterations/IterUpdate,1);
acceptance_gamma = zeros(N,iterations/IterUpdate,1);
covProposalmu_record = zeros(iterations/IterUpdate,basisSize_m,basisSize_m);
varProposalSigma_record = zeros(iterations/IterUpdate,1);
varProposalSigmac_record = zeros(iterations/IterUpdate,1);
thetaProposal_record = zeros(iterations/IterUpdate,N);
accrate_a = zeros(iterations/IterUpdate,1);
accrate_sigma = zeros(iterations/IterUpdate,1);
accrate_sigmac = zeros(iterations/IterUpdate,1);
accrate_gamma = zeros(N,iterations/IterUpdate,1);
% initial values for parameters
sigma_cur = 0.01;
a_cur = zeros(basisSize_m,1);
sigmac_cur = 0.01;
gamma_cur = tG*ones(N,1)';        % every curve starts at the identity on the grid tG
gamma_cur_dot = zeros(nG,N);
gt_cur_dot = zeros(T,N);
gt_cur = repelem(t',1,N);
% U2 columns passed through qComposeGamma2 with each curve's current warp
U2_i = zeros(T,basisSize_v,N);
for ii = 1:N
    gamma_cur_dot(1:nG-1,ii) = (mat*gamma_cur(:,ii))./(mat*tG);
    gamma_cur_dot(nG,ii) = gamma_cur_dot(nG-1,ii);
    gt_cur_dot(:,ii) = interp1(tG,gamma_cur_dot(:,ii),t,'previous');
    for bn = 1:basisSize_v
        U2_i(:,bn,ii) = qComposeGamma2(U2(:,bn),t,gt_cur(:,ii),gt_cur_dot(:,ii));
    end
end
muTilde_cur = zeros(T,N);
muTilde_temp = zeros(T,N);        % candidate warped means in the a-step
U2_can = zeros(T,basisSize_v);    % candidate warped random basis in the phase step
loglik_pr_gam = zeros(N,1);
log_postDensity = zeros(iterations,N,1);
log_jointDensity = zeros(iterations,1);
% set seed
rng(2023);
%% MCMC algorithm
tic
for i = 1:iterations
    % MH step, coefficients for fixed effect
    % proposal covariance matrix tuning
    if((mod(i,IterUpdate) == 0) && i >= StartTuning && (i < burn))
        corrtemp = corr(a_samples(i-IterUpdate+1:(i-1),:));
        corrtemp(isnan(corrtemp)) = 0;
        % keep only strong correlations (|r| >= 0.6), shrunk towards 0 by 0.4
        for rr = 1:basisSize_m
            for cc = 1:basisSize_m
                if (rr==cc)
                    corrmu(rr,cc) = 1;
                elseif (abs(corrtemp(rr,cc))<0.6)
                    corrmu(rr,cc)=0;
                else
                    corrmu(rr,cc) = sign(corrtemp(rr,cc))*(abs(corrtemp(rr,cc))-0.4);
                end
            end
        end
        % NOTE: unlike the other blocks, this uses the cumulative acceptance
        % rate since iteration 1, and once def_by is raised to 20 it is never
        % lowered again.
        acc_a_temp = accepted_a/i;
        if acc_a_temp < 0.3
            def_by = 20;
        end
        vartemp = diag(cov(a_samples(i-IterUpdate+1:(i-1),:)))/def_by;
        if sum(vartemp == 0) > 0
            % coefficients that did not move: shrink their current proposal variance
            var_cur = diag(covProposalmu);
            vartemp(vartemp == 0) = max(var_cur(vartemp == 0)/def_by,1e-10);
        end
        for rr = 1:basisSize_m
            for cc = 1:basisSize_m
                covProposalmu(rr,cc) = corrmu(rr,cc)*sqrt(vartemp(rr)*vartemp(cc));
            end
        end
    end
    % proposal
    a_can = mvnrnd(a_cur,covProposalmu)';
    muTilde_can = U1*a_can;
    % evaluate acceptance probability
    log_pr_can = log_lik_normal(a_can, muA, covPrior_a);
    log_pr_cur = log_lik_normal(a_cur, muA, covPrior_a);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        Sigma_ii = localMarginalCov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii));
        muTilde_temp(:,ii) = qComposeGamma2(muTilde_can, t, gt_cur(:,ii),gt_cur_dot(:,ii));
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_temp(:,ii), Sigma_ii);
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), Sigma_ii);
    end
    loglik_pr_a = log_pr_cur;
    logaccprob_a = (loglik_can + log_pr_can)-(loglik_cur + log_pr_cur);
    % accept/reject
    if(log(rand()) < logaccprob_a)
        a_cur = a_can;
        muTilde_cur = muTilde_temp;
        accepted_a = accepted_a + 1;
        loglik_pr_a = log_pr_can;
    end

    % MH step, random effect variance
    % proposal variance tuning (target acceptance rate 0.34-0.44)
    if((mod(i,IterUpdate) == 0) && i >= StartTuning && (i < burn))
        acc_sigmac_temp = (accepted_sigmac - acceptance_sigmac(updatenum))/IterUpdate;
        if(acc_sigmac_temp>.44 || acc_sigmac_temp<.34)
            varProposalSigmac = max(min(1e10,varProposalSigmac*acc_sigmac_temp/.39 + 1e-10), 1e-10);
        end
    end
    % proposal: normal centred at the current value, truncated to (0, Inf)
    pd = makedist('Normal', 'mu', sigmac_cur, 'sigma', sqrt(varProposalSigmac));
    pd_trunc = truncate(pd, 0, Inf);
    sigmac_can = random(pd_trunc,1);
    % evaluate acceptance probability
    log_pr_can = log_lik_invgamma(sigmac_can, alphaSc,betaSc);
    log_pr_cur = log_lik_invgamma(sigmac_cur, alphaSc, betaSc);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            localMarginalCov(sigma_cur, sigmac_can, gt_cur_dot(:,ii), U2_i(:,:,ii)));
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            localMarginalCov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
    end
    % Hastings correction for the truncation: only the normalising constants differ
    loglikprop_cur_given_can = -log(normcdf(sigmac_can, 0, sqrt(varProposalSigmac)));
    loglikprop_can_given_cur = -log(normcdf(sigmac_cur, 0, sqrt(varProposalSigmac)));
    loglik_pr_sigmac = log_pr_cur;
    logaccprob_sigmac = (loglik_can + log_pr_can + loglikprop_cur_given_can)-(loglik_cur + log_pr_cur + loglikprop_can_given_cur);
    % accept/reject
    if(log(rand()) < logaccprob_sigmac)
        sigmac_cur = sigmac_can;
        accepted_sigmac = accepted_sigmac + 1;
        loglik_pr_sigmac = log_pr_can;
    end

    % MH step, variance of error process
    % proposal variance tuning (target acceptance rate 0.34-0.44)
    if((mod(i,IterUpdate) == 0) && i >= StartTuning && (i < burn))
        acc_sigma_temp = (accepted_sigma - acceptance_sigma(updatenum))/IterUpdate;
        if(acc_sigma_temp>.44 || acc_sigma_temp<.34)
            varProposalSigma = max(min(1e10,varProposalSigma*acc_sigma_temp/.39 + 1e-10), 1e-10);
        end
    end
    % proposal: normal centred at the current value, truncated to (0, Inf)
    pd = makedist('Normal', 'mu', sigma_cur, 'sigma', sqrt(varProposalSigma));
    pd_trunc = truncate(pd, 0, Inf);
    sigma_can = random(pd_trunc,1);
    % evaluate acceptance probability
    log_pr_can = log_lik_invgamma(sigma_can, alphaS,betaS);
    log_pr_cur = log_lik_invgamma(sigma_cur, alphaS, betaS);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            localMarginalCov(sigma_can, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            localMarginalCov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
    end
    loglikprop_cur_given_can = -log(normcdf(sigma_can, 0, sqrt(varProposalSigma)));
    loglikprop_can_given_cur = -log(normcdf(sigma_cur, 0, sqrt(varProposalSigma)));
    loglik_pr_sigma = log_pr_cur;
    logaccprob_sigma = (loglik_can + log_pr_can + loglikprop_cur_given_can)-(loglik_cur + log_pr_cur + loglikprop_can_given_cur);
    % accept/reject
    if(log(rand()) < logaccprob_sigma)
        sigma_cur = sigma_can;
        accepted_sigma = accepted_sigma + 1;
        loglik_pr_sigma = log_pr_can;
    end

    % MH step, phase function
    muTilde_nowarp = U1*a_cur;   % a_cur does not change during the phase updates
    for ii = 1:N
        % proposal precision tuning (target acceptance rate 0.15-0.25)
        if((mod(i,IterUpdate) == 0) && i >= StartTuning && (i < burn))
            acc_gamma_temp = (accepted_gamma(ii) - acceptance_gamma(ii,updatenum))/IterUpdate;
            if(acc_gamma_temp>.25 || acc_gamma_temp<.15)
                thetaProposal(ii) = min(max(1,thetaProposal(ii)*.2/(acc_gamma_temp + 1e-10)),100000);
            end
        end
        % proposal: Dirichlet-distributed increments (normalised gammas) give a
        % random warp gamma_tilde, composed with the current warp
        increments_cur = mat*gamma_cur(:,ii);
        increments_tilde = gamrnd(mat*tG*thetaProposal(ii),1);
        increments_tilde = increments_tilde./sum(increments_tilde);
        gamma_tilde = [0;cumsum(increments_tilde)];
        gamma_can = interp1(tG,gamma_cur(:,ii),gamma_tilde,'linear');
        gamma_can(end) = 1;   % guards against cumsum rounding just above 1 (interp1 -> NaN)
        increments_can = mat*gamma_can;
        % evaluate prior (unnormalised Dirichlet log-density of the increments)
        log_pr_can = sum(log(increments_can + eps).*(thetaG*mat*tG-1));
        log_pr_cur = sum(log(increments_cur + eps).*(thetaG*mat*tG-1));
        % evaluate likelihood
        gt_can = interp1(tG,gamma_can,t,'linear')';
        gt_can(end) = 1;
        gamma_can_dot = increments_can./(mat*tG);
        gamma_can_dot(nG) = gamma_can_dot(nG-1);
        gt_can_dot = interp1(tG,gamma_can_dot,t,'previous');
        for bn = 1:basisSize_v
            U2_can(:,bn) = qComposeGamma2(U2(:,bn),t,gt_can, gt_can_dot');
        end
        % computed once and reused if the candidate is accepted
        muTilde_can_ii = qComposeGamma2(muTilde_nowarp, t, gt_can, gt_can_dot');
        loglik_can = log_lik_normal(qt{ii}', ...
            muTilde_can_ii', ...
            localMarginalCov(sigma_cur, sigmac_cur, gt_can_dot, U2_can));
        loglik_cur = log_lik_normal(qt{ii}', ...
            qComposeGamma2(muTilde_nowarp, t, gt_cur(:,ii), gt_cur_dot(:,ii))', ...
            localMarginalCov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
        % evaluate proposal densities (reverse move uses the inverse of gamma_tilde)
        tilde_inverse = interp1(gamma_tilde,tG,tG,'linear');
        tilde_inverse(end) = 1;
        increments_tilde_inverse = mat*tilde_inverse;
        loglikprop_g_cur_given_can = sum(log(increments_tilde_inverse + eps).*(thetaProposal(ii)*mat*tG-1));
        loglikprop_g_can_given_cur = sum(log(increments_tilde + eps).*(thetaProposal(ii)*mat*tG-1));
        % evaluate Jacobian
        gamma_cur_inv = interp1(gamma_cur(:,ii),tG,tG);
        gamma_can_inv = interp1(gamma_can, tG,tG);
        gamma_cur_inv_dot = mat*gamma_cur_inv./(mat*tG);
        gamma_can_inv_dot = mat*gamma_can_inv./(mat*tG);
        gamma_cur_inv_dot(nG) = gamma_cur_inv_dot(nG-1);
        gamma_can_inv_dot(nG) = gamma_can_inv_dot(nG-1);
        grad_gamma_cur_inv = interp1(tG,gamma_cur_inv_dot,t,'previous');
        grad_gamma_can_inv = interp1(tG,gamma_can_inv_dot,t,'previous');
        l_J = sum(log(interp1(t,grad_gamma_can_inv,gamma_cur(2:(end-1),ii))))-sum(log(interp1(t,grad_gamma_cur_inv,gamma_can(2:(end-1)))));
        % evaluate acceptance probability
        logaccprob_gamma = (loglik_can+log_pr_can+loglikprop_g_cur_given_can)-(loglik_cur+log_pr_cur+loglikprop_g_can_given_cur) + l_J;
        loglik_pr_gam(ii) = log_pr_cur;
        % accept/reject
        if(log(rand()) < logaccprob_gamma)
            gamma_cur(:,ii) = gamma_can;
            gt_cur(:,ii) = gt_can;
            gt_cur_dot(:,ii) = gt_can_dot;
            muTilde_cur(:,ii) = muTilde_can_ii;
            U2_i(:,:,ii) = U2_can;
            accepted_gamma(ii) = accepted_gamma(ii) + 1;
            loglik_pr_gam(ii)= log_pr_can;
        end
    end

    % posterior density
    for ii = 1:N
        log_postDensity(i,ii) = log_lik_normal(qt{ii}', ...
            muTilde_cur(:,ii), ...
            localMarginalCov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii))) + ...
            loglik_pr_gam(ii);
    end
    log_jointDensity(i) = sum(log_postDensity(i,:)) + loglik_pr_a+ loglik_pr_sigma + loglik_pr_sigmac;

    % save samples
    a_samples(i,:) = a_cur;
    gamma_samples(i,:,:) = gamma_cur;
    sigmac_samples(i) = sigmac_cur;
    sigma_samples(i) = sigma_cur;
    % save acceptance counts, proposal settings and a workspace checkpoint
    if mod(i,IterUpdate)==0
        fprintf('iteration %d\n', i);
        updatenum = i/IterUpdate;
        acceptance_a(updatenum) = accepted_a;
        acceptance_sigma(updatenum) = accepted_sigma;
        acceptance_sigmac(updatenum) = accepted_sigmac;
        acceptance_gamma(:,updatenum) = accepted_gamma;
        covProposalmu_record(updatenum,:,:) = covProposalmu;
        varProposalSigmac_record(updatenum) = varProposalSigmac;
        varProposalSigma_record(updatenum) = varProposalSigma;
        thetaProposal_record(updatenum,:) = thetaProposal;
        if updatenum > 1
            accrate_a(updatenum) = (acceptance_a(updatenum) - acceptance_a(updatenum-1))/IterUpdate;
            accrate_sigma(updatenum) = (acceptance_sigma(updatenum) - acceptance_sigma(updatenum-1))/IterUpdate;
            accrate_sigmac(updatenum) = (acceptance_sigmac(updatenum) - acceptance_sigmac(updatenum-1))/IterUpdate;
            accrate_gamma(:,updatenum) = (acceptance_gamma(:,updatenum) - acceptance_gamma(:,updatenum-1))./IterUpdate;
            fprintf('acceptance rates (a, sigma, sigmac): %.3f %.3f %.3f\n', ...
                accrate_a(updatenum), accrate_sigma(updatenum), accrate_sigmac(updatenum));
        end
        save(outfile,'-v7.3')
    end
end

timespend = toc;
save(outfile,'-v7.3')
end

function U = localBuildBasis(basis_type, basis_size, t)
% Build a T x basis_size basis evaluated at the row vector t.
%   "Fourier": sqrt(3)*t, sqrt(3)*(1-t), then sqrt(2)*cos/sin(2*n*pi*t) pairs;
%              for odd basis_size the final sine column is dropped so that the
%              basis has exactly basis_size columns.
%   "BSpline": cubic B-splines (spcol, order 4) on equally spaced knots,
%              each column rescaled so that trapz(U(:,k).^2)/(T-1) == 1.
T = numel(t);
switch basis_type
    case "Fourier"
        nPairs = ceil(max(basis_size-2,0)/2);
        U = zeros(T, 2 + 2*nPairs);
        U(:,1) = sqrt(3)*t;
        U(:,2) = sqrt(3)*(1-t);
        for n = 1:nPairs
            U(:,2*n+1) = sqrt(2)*cos(2*n*pi*t);
            U(:,2*n+2) = sqrt(2)*sin(2*n*pi*t);
        end
        U = U(:,1:basis_size);
    case "BSpline"
        U = spcol([0,0,0,linspace(0,1,basis_size-2),1,1,1],4, t);
        U = U./(sqrt(diag(U'*U))*ones(1,T))';
        for k = 1:basis_size
            U(:,k) = U(:,k)/sqrt(trapz((U(:,k)).^2)/(T-1));
        end
    otherwise
        error('FMMPriorModel2warpMix:unknownBasis', ...
            'Unknown basis type "%s"; expected "Fourier" or "BSpline".', basis_type);
end
end

function S = localMarginalCov(sigma, sigmac, gdot, Ui)
% Marginal covariance of one curve: error variance sigma scaled by the warp
% derivative gdot on the diagonal, plus the random-effect term sigmac*Ui*Ui'.
S = sigma*diag(gdot) + sigmac*Ui*Ui';
end
