%% 1. Parameter Initialization (Ground Truth)
% Ground-truth number of binary latent attributes.
K_true = 3;

% Enumerate every latent profile: the integers 0 .. 2^K_true - 1 are handed
% to the project helper `binary` (presumably one K_true-bit profile per row
% -- confirm against its definition). The true model uses all of them.
A_all = binary((1:2^K_true) - 1, K_true);
A_true = A_all;


%% Ground-truth measurement setting
% Two alternative settings are defined below. Previously both were executed
% in sequence, so setting 1 was always overwritten by setting 2; pick one
% explicitly instead. The default (2) reproduces that previous behaviour.
%
% Row j of G_true appears to flag the latent attributes item j depends on:
% row j of theta_true (one column per latent profile) varies only with the
% flagged attributes, assuming profiles are enumerated with attribute 1 as
% the leading bit (TODO confirm against `binary`). Rows 1-4 of theta_true
% hold values in (0, 1); the labels on later rows come from the original.
setting = 2;

switch setting
    case 1
        % setting 1 - Double Triangular Gamma
        G_true = [[1 0 0]; [1 1 0]; [1 0 1]; [1 1 1]; [1 0 1]; [0 1 0]; [0 1 1]; [0 0 0]];
        theta_true = [
            [0.1 0.1 0.1 0.1 0.9 0.9 0.9 0.9];
            [0.05 0.05 0.4 0.4 0.7 0.7 0.95 0.95];
            [0.05 0.7 0.05 0.7 0.4 0.95 0.4 0.95];
            1-[0.1 1 2 3 4 5 6 6.9]/7;
            [-2 -0.5 -2 -0.5 2 0.5 2 0.5];          % Normal mean
            [-1.5 -1.5 1.5 1.5 -1.5 -1.5 1.5 1.5];  % Normal mean
            [0.5 -0.5 2 -2 0.5 -0.5 2 -2];          % Cauchy mean
            [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5];      % Cauchy mean (constant: G_true row is all zeros)
        ];
    case 2
        % setting 2 - Non-Double Triangular Gamma
        G_true = [[1 0 0]; [1 1 0]; [1 0 1]; [1 1 0]; [1 0 0]; [0 1 0]; [0 1 0]; [0 0 0]];
        theta_true = [
            [0.1 0.1 0.1 0.1 0.9 0.9 0.9 0.9];
            [0.05 0.05 0.4 0.4 0.7 0.7 0.95 0.95];
            [0.05 0.7 0.05 0.7 0.4 0.95 0.4 0.95];
            [0.85 0.85 0.6 0.6 0.35 0.35 0.1 0.1];
            [-2 -2 -2 -2 2 2 2 2];
            [-1.5 -1.5 1.5 1.5 -1.5 -1.5 1.5 1.5];
            [-1 -1 2 2 -1 -1 2 2];
            [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5];
        ];
    otherwise
        error('Unknown setting %d (expected 1 or 2).', setting);
end

% Number of items (one row of G_true per item).
[J, ~] = size(G_true);

% theta_true needs one row per item and one column per latent profile.
assert(isequal(size(theta_true), [J, 2^size(G_true, 2)]), ...
    'theta_true must be J-by-2^K for the selected G_true.');

% Proportion parameters for the linear (chain) latent graph.
% Assume profiles are ordered 000, 001, ..., 111 with attribute 1 as the
% leading bit (TODO confirm against `binary`). Under that assumption, each
% entry factorises as P(A1) * P(A2 | A1) * P(A3 | A2), where:
%   - A1 = 1 has probability p1;
%   - each later attribute equals its predecessor with probability p1 and
%     differs from it with probability p0.
p1 = 0.66;
p0 = 1 - p1;
type = "linear";   % NOTE: shadows the built-in `type` command in this workspace
nu_true = [ ...
    p0*p1^2, p0^2*p1, p0^3,    p0^2*p1, ...  % first four profiles
    p1^2*p0, p1*p0^2, p1^2*p0, p1^3 ...      % last four profiles
];

%% 2. Simulation Settings
% -- Monte Carlo design --
Nrep      = 300;        % Number of replications
N         = 5000;       % Sample size per replication; CHANGE accordingly
K_cand    = [2, 3, 4];  % Candidate K values to test
num_inits = 20;         % Number of random initializations

% -- Hyperparameter grids (a single value each) --
lambda3_list = 1000;
tau2_list    = 0.05;
rho2_list    = 1;

% -- Fitting controls --
eps_tol  = 5 * 1e-3;    % tolerance on the change in the mixing proportions
MAXITERS = 10;          % iteration cap for the fitting loop (see its while-condition)

% -- Per-replication outputs, filled in by the parallel loop --
time_est   = zeros(Nrep, 1);
selected_K = zeros(Nrep, 1);

%% 3. Main Parallel Loop
% Each replication n does the following:
%   1. Simulate a data set from the ground truth.
%   2. Fit the model for every candidate K and every hyperparameter
%      combination, from num_inits random starts each.
%   3. Record the K (and its mixing proportions) whose best fit has the
%      smallest BIC.
% generate_X_general, get_d, update_s_pi, ADMM and BIC are project helpers
% defined elsewhere.

% Preallocate the sliced cell output (one proportion vector per replication)
% so its type and size are fixed before entering parfor.
selected_pi = cell(Nrep, 1);

parfor n = 1:Nrep
    rng(n);  % replication-specific seed for data generation

    % Generate True Data
    X = generate_X_general(N, nu_true, G_true, theta_true, A_true);

    best_bic_overall = Inf; best_pi_overall = 1;
    best_K_rep = -1;  % stays -1 if no candidate K ever beats a BIC of Inf

    % Explicit timer handle: each iteration times itself without relying on
    % the default (shared) tic/toc timer.
    t_start = tic;
    % Iterate over different K
    for k_idx = 1:length(K_cand)
        K_fit = K_cand(k_idx);
        M_fit = 2^K_fit;  % number of latent profiles for K_fit attributes

        best_bic_for_K = Inf; best_pi_K = 1;

        % Iterate over hyperparameters; each value is read once in its own loop
        for j = 1:length(lambda3_list)
            lambda3 = lambda3_list(j);
            for k = 1:length(tau2_list)
                tau2 = tau2_list(k);
                for l = 1:length(rho2_list)
                    rho2 = rho2_list(l);

                    for init = 1:num_inits
                        % The seed depends only on (n, k_idx, init), so every
                        % hyperparameter combination reuses the same starts.
                        % Seeds stay distinct across all (n, k_idx, init)
                        % while numel(K_cand) <= 9 and num_inits <= 999.
                        rng(n * 10000 + k_idx * 1000 + init);

                        theta = rand(J, M_fit);         % random initial item parameters
                        pi_var = ones(1, M_fit)/M_fit;  % uniform initial proportions
                        d = get_d(theta);

                        ite = 1;
                        e = 10;  % any value > eps_tol, so the loop is entered

                        % Alternate update_s_pi and ADMM until the change in
                        % pi_var is at most eps_tol.
                        % NOTE(review): ite starts at 1 and the guard is
                        % ite < MAXITERS, so at most MAXITERS-1 sweeps run;
                        % confirm this cap is intended.
                        while e > eps_tol && ite < MAXITERS
                            pi_old = pi_var;
                            [s, pi_var] = update_s_pi(X, pi_var, theta, 0);
                            [theta, d] = ADMM(X, d, theta, s, lambda3, rho2, tau2);
                            e = norm(pi_old - pi_var);
                            ite = ite + 1;
                        end

                        current_bic = BIC(X, pi_var, theta);

                        % Update best fit for the current K (on a tie the
                        % later start wins)
                        if current_bic <= best_bic_for_K
                            best_bic_for_K = current_bic;
                            best_pi_K = pi_var;
                        end
                    end
                end
            end
        end

        % Compare against other K values in this replication (strict <, so a
        % tie keeps the earlier candidate in K_cand)
        if best_bic_for_K < best_bic_overall
            best_bic_overall = best_bic_for_K;
            best_K_rep = K_fit;
            best_pi_overall = best_pi_K;
        end
    end

    time_est(n) = toc(t_start);
    selected_K(n) = best_K_rep;
    selected_pi{n} = best_pi_overall;

    fprintf('Rep %d finished. Selected K = %d\n', n, best_K_rep);
end

%% 4. Reporting Accuracy
% Fraction of replications in which BIC selected the true K.
n_correct = sum(selected_K == K_true);
accuracy = n_correct / Nrep;
fprintf('Selection accuracy (K_true = %d): %.3f (%d/%d)\n', ...
    K_true, accuracy, n_correct, Nrep);

% How often each candidate K was selected.
for k = K_cand
    fprintf('  K = %d: %d times\n', k, sum(selected_K == k));
end

% Replications where no candidate gave a finite BIC keep selected_K = -1.
% Report them so the counts above visibly add up to Nrep.
n_unselected = sum(~ismember(selected_K, K_cand));
if n_unselected > 0
    fprintf('  no K selected: %d times\n', n_unselected);
end

% Tie the file name to the sample size so changing N does not silently
% overwrite or mislabel results. For N = 5000 this is the original
% "K_selection_accuracy_ablation_5k.mat".
out_file = sprintf("K_selection_accuracy_ablation_%gk.mat", N / 1000);
save(out_file, "selected_K", "time_est", "accuracy", "K_true");

% Also keep the per-replication mixing proportions selected in the main
% loop, which were previously computed but never saved.
if exist("selected_pi", "var")
    save(out_file, "selected_pi", "-append");
end