%% Implement Prior Model 1 on data simulated from the warpMix model
% Inputs:
% main_basis: type of basis ('Fourier' or 'BSpline') for the fixed effect
% basisSize_m: number of basis functions for the fixed effect
% random_basis: type of basis ('Fourier' or 'BSpline') for the random effect
% basisSize_v: number of basis functions for the random effect
% dataidx: 1, 2 or 3, corresponding to three different fixed-effect functions
% Output: posterior samples of the parameters of interest, saved to ./result/
function FMMPriorModel1warpMix(main_basis, basisSize_m, random_basis, basisSize_v, dataidx)
% FMMPriorModel1warpMix  Run the Prior Model 1 MCMC sampler on warpMix data.
%
%   FMMPriorModel1warpMix(main_basis, basisSize_m, random_basis, basisSize_v, dataidx)
%
%   Loads ./data/warpMixData<dataidx>.csv (one curve per column, one grid
%   point per row), builds bases for the fixed and random effects
%   (orthonormalized with gram_schmidt_orthonormalization) and runs a
%   Metropolis-within-Gibbs sampler for
%     a      - fixed-effect basis coefficients,
%     sigmac - random-effect variance,
%     sigma  - error-process variance,
%     alpha  - per-curve warping parameter, gamma(t) = t + alpha*t.*(t-1).
%
%   Inputs
%     main_basis   - "Fourier" or "BSpline", basis for the fixed effect
%     basisSize_m  - number of fixed-effect basis functions
%     random_basis - "Fourier" or "BSpline", basis for the random effect
%     basisSize_v  - number of random-effect basis functions
%     dataidx      - index of the data file to load
%
%   No output arguments: the whole function workspace (samples, acceptance
%   records, proposal settings, run time) is saved to ./result/ every
%   IterUpdate iterations and once more when sampling finishes.
%
%   Calls qComposeGamma2, log_lik_normal, log_lik_invgamma and
%   gram_schmidt_orthonormalization (presumably provided by ./Functions),
%   and spcol for the "BSpline" basis.
rng(1)
addpath("./Functions")

%% Load data and pre-process
% For a non-.mat file, load parses the file as ASCII; with an output
% argument it returns the numeric array directly, so no eval-based renaming
% of an auto-created variable is needed.
data = load(sprintf("./data/warpMixData%d.csv",dataidx));
N = size(data,2);   % number of curves (columns)
T = size(data,1);   % number of grid points per curve (rows)
qt = cell(1,N);
for ii = 1:N
    qt{ii} = data(:,ii)';
end
datatype = sprintf("warpMix%d",dataidx);
t = linspace(0,1,T);
resultFile = sprintf('./result/Data_%s_PriorModel1MCMC_main_%s_random_%s_basisnum_%d_%d.mat', ...
    datatype, main_basis, random_basis, basisSize_m, basisSize_v);

%% Bases for the fixed and random effects
U1 = gram_schmidt_orthonormalization(local_build_basis(main_basis, basisSize_m, t));
U2 = gram_schmidt_orthonormalization(local_build_basis(random_basis, basisSize_v, t));

%% MCMC settings
IterUpdate = 1000;     % window length for adaptation / acceptance bookkeeping
StartTuning = 10000;   % no adaptation before this iteration
% hyperparameters
muA = 1e-4; tauA = 1e4;
alphaSc = 0.01; betaSc = 0.01;
alphaS = 0.01; betaS = 0.01;
priorCov_a = diag(tauA*ones(basisSize_m,1));   % prior covariance of a
% initial proposal settings
covProposalmu = 1e-8*eye(basisSize_m);
deltaProposal = 0.05*ones(N,1);   % half-width of the uniform alpha random walk
corrmu = eye(basisSize_m);
def_by = 10;
varProposalSigma = 1e-6;
varProposalSigmac = 1e-4;
% posterior sample storage
burn = 200000;
iterations = 100000 + burn;
sigma_samples = zeros(iterations,1);
sigmac_samples = zeros(iterations,1);
a_samples = zeros(iterations, basisSize_m);
alpha_samples = zeros(iterations, N);
% running acceptance counts and per-window records used for tuning
accepted_alpha = zeros(N,1);
accepted_a = 0;
accepted_sigma = 0;
accepted_sigmac = 0;
acceptance_a = zeros(iterations/IterUpdate,1);
acceptance_sigma = zeros(iterations/IterUpdate,1);
acceptance_sigmac = zeros(iterations/IterUpdate,1);
acceptance_alpha = zeros(N,iterations/IterUpdate,1);
deltaProposal_record = zeros(N,iterations/IterUpdate);
accrate_a = zeros(iterations/IterUpdate,1);
accrate_sigma = zeros(iterations/IterUpdate,1);
accrate_sigmac = zeros(iterations/IterUpdate,1);
accrate_alpha = zeros(N,iterations/IterUpdate,1);

%% Initial state
sigma_cur = 0.01;
a_cur = zeros(basisSize_m,1);
sigmac_cur = 0.01;
alpha_cur = zeros(N,1);
gt_cur = repelem(t',1,N);          % identity warps, i.e. alpha = 0
gt_cur_dot = zeros(T,N);           % derivative of each current warp
U2_i = zeros(T,basisSize_v,N);     % random-effect basis composed with each curve's warp
for ii = 1:N
    gt_cur_dot(:,ii) = 1 - alpha_cur(ii) + 2*alpha_cur(ii)*t;
    for bn = 1:basisSize_v
        U2_i(:,bn,ii) = qComposeGamma2(U2(:,bn),t,gt_cur(:,ii),gt_cur_dot(:,ii));
    end
end
muTilde_cur = zeros(T,N);          % current warped fixed-effect mean per curve
muTilde_temp = zeros(T,N);
U2_can = zeros(T,basisSize_v);
% Each alpha has a Uniform(-1,1) prior, i.e. constant log density log(1/2).
loglik_pr_alpha = repmat(log(1/2),N,1);
log_postDensity = zeros(iterations,N,1);
log_jointDensity = zeros(iterations,1);
% set seed
rng(2023);

%% MCMC algorithm
tic
for i = 1:iterations
    % Proposals adapt only at window boundaries after StartTuning and before
    % the end of burn-in. On such iterations updatenum still indexes the
    % previous window (it is advanced at the bottom of this iteration), so
    % accepted_* minus acceptance_*(updatenum) counts the acceptances in the
    % last IterUpdate iterations.
    isTuningIter = (mod(i,IterUpdate) == 0) && i > StartTuning && (i < burn);

    %% MH step: fixed-effect coefficients a
    if isTuningIter
        win_idx = i-IterUpdate+1:(i-1);
        % Proposal correlation from the recent samples: |r| < 0.6 is
        % dropped, stronger correlations are shrunk towards 0 by 0.4.
        corrtemp = corr(a_samples(win_idx,:));
        corrtemp(isnan(corrtemp)) = 0;
        corrmu = sign(corrtemp).*(abs(corrtemp)-0.4);
        corrmu(abs(corrtemp) < 0.6) = 0;
        corrmu(1:basisSize_m+1:end) = 1;
        % Cumulative acceptance rate of a; once it drops below 0.3 the
        % window variances are divided by 20 instead of 10 from then on.
        acc_a_temp = accepted_a/i;
        if acc_a_temp < 0.3
            def_by = 20;
        end
        vartemp = diag(cov(a_samples(win_idx,:)))/def_by;
        if any(vartemp == 0)
            % Coefficients that did not move in the window fall back to a
            % deflated copy of their current proposal variance (>= 1e-10).
            var_cur = diag(covProposalmu);
            vartemp(vartemp == 0) = max(var_cur(vartemp == 0)/def_by,1e-10);
        end
        covProposalmu = corrmu.*sqrt(vartemp*vartemp');
    end
    % Gaussian random-walk proposal (symmetric, so no Hastings term).
    a_can = mvnrnd(a_cur,covProposalmu)';
    muTilde_can = U1*a_can;
    log_pr_can = log_lik_normal(a_can, muA, priorCov_a);
    log_pr_cur = log_lik_normal(a_cur, muA, priorCov_a);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        S_ii = local_marginal_cov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii));
        muTilde_temp(:,ii) = qComposeGamma2(muTilde_can, t, gt_cur(:,ii),gt_cur_dot(:,ii));
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_temp(:,ii), S_ii);
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), S_ii);
    end
    loglik_pr_a = log_pr_cur;
    logaccprob_a = (loglik_can + log_pr_can)-(loglik_cur + log_pr_cur);
    if(log(rand()) < logaccprob_a)
        a_cur = a_can;
        muTilde_cur = muTilde_temp;
        accepted_a = accepted_a + 1;
        loglik_pr_a = log_pr_can;
    end

    %% MH step: random-effect variance sigmac
    if isTuningIter
        acc_sigmac_temp = (accepted_sigmac - acceptance_sigmac(updatenum))/IterUpdate;
        varProposalSigmac = local_tune_variance(varProposalSigmac, acc_sigmac_temp);
    end
    % Normal random walk truncated to (0, Inf).
    pd = makedist('Normal', 'mu', sigmac_cur, 'sigma', sqrt(varProposalSigmac));
    pd_trunc = truncate(pd, 0, Inf);
    sigmac_can = random(pd_trunc,1);
    log_pr_can = log_lik_invgamma(sigmac_can, alphaSc,betaSc);
    log_pr_cur = log_lik_invgamma(sigmac_cur, alphaSc, betaSc);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            local_marginal_cov(sigma_cur, sigmac_can, gt_cur_dot(:,ii), U2_i(:,:,ii)));
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            local_marginal_cov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
    end
    % Hastings correction for the truncation: a proposal centred at x has
    % normalising constant normcdf(x, 0, sd).
    loglikprop_cur_given_can = -log(normcdf(sigmac_can, 0, sqrt(varProposalSigmac)));
    loglikprop_can_given_cur = -log(normcdf(sigmac_cur, 0, sqrt(varProposalSigmac)));
    loglik_pr_sigmac = log_pr_cur;
    logaccprob_sigmac = (loglik_can + log_pr_can + loglikprop_cur_given_can)-(loglik_cur + log_pr_cur + loglikprop_can_given_cur);
    if(log(rand()) < logaccprob_sigmac)
        sigmac_cur = sigmac_can;
        accepted_sigmac = accepted_sigmac + 1;
        loglik_pr_sigmac = log_pr_can;
    end

    %% MH step: error-process variance sigma
    if isTuningIter
        acc_sigma_temp = (accepted_sigma - acceptance_sigma(updatenum))/IterUpdate;
        varProposalSigma = local_tune_variance(varProposalSigma, acc_sigma_temp);
    end
    % Normal random walk truncated to (0, Inf).
    pd = makedist('Normal', 'mu', sigma_cur, 'sigma', sqrt(varProposalSigma));
    pd_trunc = truncate(pd, 0, Inf);
    sigma_can = random(pd_trunc,1);
    log_pr_can = log_lik_invgamma(sigma_can, alphaS,betaS);
    log_pr_cur = log_lik_invgamma(sigma_cur, alphaS, betaS);
    loglik_can = 0;
    loglik_cur = 0;
    for ii = 1:N
        loglik_can = loglik_can + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            local_marginal_cov(sigma_can, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
        loglik_cur = loglik_cur + log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            local_marginal_cov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
    end
    % Hastings correction for the truncation (see the sigmac step).
    loglikprop_cur_given_can = -log(normcdf(sigma_can, 0, sqrt(varProposalSigma)));
    loglikprop_can_given_cur = -log(normcdf(sigma_cur, 0, sqrt(varProposalSigma)));
    loglik_pr_sigma = log_pr_cur;
    logaccprob_sigma = (loglik_can + log_pr_can + loglikprop_cur_given_can)-(loglik_cur + log_pr_cur + loglikprop_can_given_cur);
    if(log(rand()) < logaccprob_sigma)
        sigma_cur = sigma_can;
        accepted_sigma = accepted_sigma + 1;
        loglik_pr_sigma = log_pr_can;
    end

    %% MH step: per-curve warping parameter alpha
    % a is not updated during this sweep, so the unwarped fixed-effect mean
    % is computed once.
    muTilde_nowarp = U1*a_cur;
    for ii = 1:N
        if isTuningIter
            acc_alpha_temp = (accepted_alpha(ii) - acceptance_alpha(ii,updatenum))/IterUpdate;
            % The dead-band test uses the same windowed rate as the update
            % (as the sigma/sigmac tuners do); the cumulative rate would
            % still be dominated by the independence phase before StartTuning.
            if (acc_alpha_temp > .34 || acc_alpha_temp < .24)
                deltaProposal(ii) = max(min(0.99,deltaProposal(ii)/.29*(acc_alpha_temp + 1e-10)),1e-10);
            end
        end
        if i <= StartTuning
            % Independence proposal drawn from the Uniform(-1,1) prior.
            alpha_can = unifrnd(-1,1,1);
        else
            % Uniform random walk of half-width deltaProposal(ii).
            alpha_can = unifrnd(alpha_cur(ii)-deltaProposal(ii), alpha_cur(ii)+deltaProposal(ii),1);
        end
        if (alpha_can >= 1) || (alpha_can <= -1)
            logaccprob_alpha = -Inf;   % outside the prior support
        else
            gt_can = t + alpha_can*t.*(t-1);
            gt_can_dot = 1 - alpha_can + 2*alpha_can*t;
            for bn = 1:basisSize_v
                U2_can(:,bn) = qComposeGamma2(U2(:,bn),t,gt_can,gt_can_dot);
            end
            muTilde_can_ii = qComposeGamma2(muTilde_nowarp, t, gt_can, gt_can_dot);
            loglik_can = log_lik_normal(qt{ii}', muTilde_can_ii, ...
                local_marginal_cov(sigma_cur, sigmac_cur, gt_can_dot, U2_can));
            % NOTE(review): gt_can is a row but gt_cur(:,ii) a column; the
            % transpose presumably aligns this mean's orientation with the
            % candidate mean above - confirm against qComposeGamma2.
            loglik_cur = log_lik_normal(qt{ii}', ...
                qComposeGamma2(muTilde_nowarp, t, gt_cur(:,ii), gt_cur_dot(:,ii))', ...
                local_marginal_cov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii)));
            logaccprob_alpha = loglik_can - loglik_cur;
        end
        if(log(rand()) < logaccprob_alpha)
            alpha_cur(ii) = alpha_can;
            gt_cur(:,ii) = gt_can;
            gt_cur_dot(:,ii) = gt_can_dot;
            muTilde_cur(:,ii) = muTilde_can_ii;
            U2_i(:,:,ii) = U2_can;
            accepted_alpha(ii) = accepted_alpha(ii) + 1;
        end
    end

    %% Log likelihood plus log priors of the current state
    for ii = 1:N
        log_postDensity(i,ii) = log_lik_normal(qt{ii}', muTilde_cur(:,ii), ...
            local_marginal_cov(sigma_cur, sigmac_cur, gt_cur_dot(:,ii), U2_i(:,:,ii))) + ...
            loglik_pr_alpha(ii);
    end
    log_jointDensity(i) = sum(log_postDensity(i,:)) + loglik_pr_a + loglik_pr_sigma + loglik_pr_sigmac;

    %% Store samples
    a_samples(i,:) = a_cur;
    alpha_samples(i,:) = alpha_cur;
    sigmac_samples(i) = sigmac_cur;
    sigma_samples(i) = sigma_cur;

    %% Per-window acceptance bookkeeping and checkpoint
    if mod(i,IterUpdate)==0
        fprintf('iteration %d\n', i);
        updatenum = i/IterUpdate;
        acceptance_a(updatenum) = accepted_a;
        acceptance_sigma(updatenum) = accepted_sigma;
        acceptance_sigmac(updatenum) = accepted_sigmac;
        acceptance_alpha(:,updatenum) = accepted_alpha;
        deltaProposal_record(:,updatenum) = deltaProposal;
        if updatenum > 1
            accrate_a(updatenum) = (acceptance_a(updatenum) - acceptance_a(updatenum-1))/IterUpdate;
            accrate_sigma(updatenum) = (acceptance_sigma(updatenum) - acceptance_sigma(updatenum-1))/IterUpdate;
            accrate_sigmac(updatenum) = (acceptance_sigmac(updatenum) - acceptance_sigmac(updatenum-1))/IterUpdate;
            accrate_alpha(:,updatenum) = (acceptance_alpha(:,updatenum) - acceptance_alpha(:,updatenum-1))./IterUpdate;
            fprintf('  window acceptance (a, sigma, sigmac): %.3f %.3f %.3f\n', ...
                accrate_a(updatenum), accrate_sigma(updatenum), accrate_sigmac(updatenum));
        end
        % Checkpoint the full workspace so long runs can be inspected mid-run.
        save(resultFile)
    end
end

timespend = toc;
save(resultFile)
end

function U = local_build_basis(basisType, basisSize, t)
% Evaluate a "Fourier" or "BSpline" basis on the row grid t (length T) and
% return it as a T-by-nBasis matrix with one basis function per column.
% NOTE(review): the "Fourier" branch adds sin/cos pairs, so an odd
% basisSize yields basisSize+1 columns.
T = numel(t);
tc = t(:);
switch basisType
    case "Fourier"
        % Two linear functions, then sin/cos pairs of increasing frequency.
        U = zeros(T, basisSize);
        U(:,1) = sqrt(3)*tc;
        U(:,2) = sqrt(3)*(1-tc);
        for k = 1:2:basisSize-2
            n = ceil(k/2);
            U(:,k+2) = sqrt(2)*cos(2*n*pi*tc);
            U(:,k+3) = sqrt(2)*sin(2*n*pi*tc);
        end
    case "BSpline"
        % Order-4 (cubic) B-splines: triple end knots plus basisSize-2
        % equally spaced breakpoints give basisSize functions.
        U = spcol([0,0,0,linspace(0,1,basisSize-2),1,1,1], 4, t);
        U = U./(sqrt(diag(U'*U))*ones(1,T))';
        % Rescale each column to unit L2 norm on [0,1], using the
        % trapezoidal rule with spacing 1/(T-1).
        for k = 1:basisSize
            U(:,k) = U(:,k)/sqrt(trapz(U(:,k).^2)/(T-1));
        end
    otherwise
        error('FMMPriorModel1warpMix:unknownBasis', ...
            'Unknown basis type "%s"; expected "Fourier" or "BSpline".', basisType);
end
end

function v = local_tune_variance(v, accRate)
% Rescale a random-walk proposal variance towards a 0.39 acceptance rate
% when the windowed rate leaves [0.34, 0.44]; the result is clamped to
% [1e-10, 1e10].
if accRate > .44 || accRate < .34
    v = max(min(1e10, v*accRate/.39 + 1e-10), 1e-10);
end
end

function S = local_marginal_cov(sig, sigc, gdot, Ui)
% Covariance passed to log_lik_normal for one curve: sig times the warp
% derivative gdot on the diagonal, plus sigc*Ui*Ui' from the warped
% random-effect basis Ui.
S = sig*diag(gdot) + sigc*Ui*Ui';
end
