%% Simulate data from Prior Model 2
% Draws N random curves from Prior Model 2 and saves the whole function
% workspace to a .mat file under ./data (the folder is created if missing).
%
% mandatory input:
% main_basis: type of basis for the fixed effect ('Fourier' or 'BSpline')
% basisSize_m: number of basis functions for the fixed effect
% random_basis: type of basis for the random effect ('Fourier' or 'BSpline')
% basisSize_v: number of basis functions for the random effect
% optional input (omit, or pass [] to use the default):
% nG: number of grid points for the phase function (default: 5)
% theta_g: precision parameter of the Dirichlet distribution used as the
%          prior for the phase function (default: 30)
% sigma: variance scale of the error process (default: 0.1)
% T: number of time points for discretized functions (default: 50)
% N: number of functions (default: 30)
%
% output (no return value):
% ./data/SimCurvePriorModel2_<settings>.mat containing every variable of
% the function workspace, including
%   t        - 1-by-T grid on [0,1]
%   a        - basisSize_m-by-1 fixed-effect coefficients
%   U1, U2   - fixed / random effect basis matrices (T rows, after
%              gram_schmidt_orthonormalization)
%   gam_true - T-by-N true phase functions
%   gam_dot  - T-by-N numerical derivatives of gam_true
%   qt       - 1-by-N cell; qt{i} is one simulated 1-by-T curve
%
% Notes:
% - rng(1) is called on entry, so identical inputs always give identical data.
% - For 'Fourier' with an odd basis size K, K+1 basis columns are built;
%   only the first K enter the simulated curves.
% - Helpers gram_schmidt_orthonormalization and qComposeGamma come from
%   ./Functions (added to the path below).
function SimCurvefromPriorModel2(main_basis,basisSize_m, ...
    random_basis,basisSize_v, nG, theta_g, sigma, T, N)
rng(1)
addpath("./Functions")
% set up parameters; a missing or empty optional argument takes its default
if nargin < 9 || isempty(N)
    N = 30;
end
if nargin < 8 || isempty(T)
    T = 50;
end
if nargin < 7 || isempty(sigma)
    sigma = 0.1;
end
if nargin < 6 || isempty(theta_g)
    theta_g = 30;
end
if nargin < 5 || isempty(nG)
    nG = 5;
end
sigmac = 0.25;   % random-effect variance scale (fixed, not an input)
t = linspace(0,1,T);
tG = linspace(0,1,nG)';
% (nG-1)-by-nG first-difference matrix, so mat*tG equals diff(tG)
mat = (diag(ones(nG-1,1),1)+diag(-ones(nG,1)));
mat = mat(1:(end-1),:);

% fixed effect coefficient
a = normrnd(0,1,basisSize_m,1);

% phase function: random increments on the tG grid, linearly interpolated
% onto t. Arrays are preallocated; the order of random draws is unchanged.
increment_true = zeros(nG-1,N);
gam_true_tG = zeros(nG,N);
gam_true = zeros(T,N);
gam_dot = zeros(T,N);
for i = 1:N
    % Gamma(alpha_k,1) draws normalized to sum to 1, i.e. a Dirichlet draw
    % with parameters theta_g*diff(tG)
    increment_true(:,i) = gamrnd(mat*tG*theta_g,1);
    increment_true(:,i) = increment_true(:,i)./sum(increment_true(:,i));
    gam_true_tG(:,i) = [0; cumsum(increment_true(:,i))];
    gam_true(:,i) = interp1(tG,gam_true_tG(:,i), t, "linear");
    gam_true(end,i) = 1;   % pin the endpoint exactly (guards against round-off)
    gam_dot(:,i) = gradient(gam_true(:,i), 1/(T-1));
end

% fixed and random effect bases (see build_basis below), then orthonormalized
U1 = build_basis(main_basis, basisSize_m, t, "main_basis");
U2 = build_basis(random_basis, basisSize_v, t, "random_basis");

U1 = gram_schmidt_orthonormalization(U1);
U2 = gram_schmidt_orthonormalization(U2);

% generate data
qt = cell(1,N);
for i = 1:N
    % compose every basis function with the i-th phase function
    % (assumes qComposeGamma returns T values -- TODO confirm)
    U1_i = zeros(T,basisSize_m);
    U2_i = zeros(T,basisSize_v);
    for j = 1:basisSize_m
        U1_i(:,j) = qComposeGamma(U1(:,j),t,gam_true(:,i));
    end
    for j = 1:basisSize_v
        U2_i(:,j) = qComposeGamma(U2(:,j),t,gam_true(:,i));
    end
    % covariance: error term scaled by gam_dot plus random-effect term
    var_q = sigma*diag(gam_dot(:,i)) + sigmac*U2_i*U2_i';
    mean_q = (U1_i*a)';
    qt{i} = mvnrnd(mean_q,var_q);
end

% NOTE(review): "%.1f" rounds sigma to one decimal (e.g. 0.05 and 0.1 give
% the same name) and "%d" prints a non-integer theta_g in exponential form.
% The format is kept unchanged so existing files and loaders still match.
if exist("./data","dir") ~= 7
    mkdir("./data");
end
save(sprintf("./data/SimCurvePriorModel2_sigma_%.1f_main_%s_random_%s_basisSize_m_%d_basisSize_v_%d_T_%d_nG_%d_thetag_%d.mat", ...
    sigma,main_basis, random_basis, basisSize_m, basisSize_v, T, nG,theta_g));
end

function U = build_basis(basisType, basisSize, t, argName)
% Evaluate basis functions of type basisType ("Fourier" or "BSpline") on the
% grid t. Returns a numel(t)-by-K matrix: K = basisSize for "BSpline";
% K = 2 + 2*numel(1:2:basisSize-2) for "Fourier" (basisSize+1 when
% basisSize is odd and >= 3). Any other type raises an error that names
% argName.
T = numel(t);
switch basisType
    case "Fourier"
        % two linear terms, then cosine/sine pairs of increasing frequency
        tc = t(:);
        nPairs = numel(1:2:basisSize-2);
        U = zeros(T, 2 + 2*nPairs);
        U(:,1) = sqrt(3)*tc;
        U(:,2) = sqrt(3)*(1-tc);
        for h = 1:nPairs
            U(:,2*h+1) = sqrt(2)*cos(2*h*pi*tc);
            U(:,2*h+2) = sqrt(2)*sin(2*h*pi*tc);
        end
    case "BSpline"
        % order-4 (cubic) B-splines; basisSize+4 knots give basisSize functions
        U = spcol([0,0,0,linspace(0,1,basisSize-2),1,1,1],4, t);
        U = U./(sqrt(diag(U'*U))*ones(1,T))';
        % rescale each column to unit L2 norm on [0,1] (trapezoid rule, step 1/(T-1))
        for k = 1:basisSize
            U(:,k) = U(:,k)/sqrt(trapz((U(:,k)).^2)/(T-1));
        end
    otherwise
        error("SimCurvefromPriorModel2:unknownBasis", ...
            "Unknown %s '%s': expected 'Fourier' or 'BSpline'.", ...
            argName, string(basisType));
end
end