% Start from a clean session: close every figure window, clear the
% workspace (the 'all' form also clears cached functions), then make the
% project toolbox functions visible on the MATLAB path.
close('all');
clear('all');
addpath('./toolbox');

% Ground-truth tree topology passed to makeModel. Other choices used with
% this script: 'star', 'doubleStar', 'hmm', '5cayley'.
% Left without a trailing semicolon so the selected topology is echoed.
% NOTE: this variable shadows MATLAB's built-in graph() within the script.
graph = 'doublebinary'
dist = 'gaussian';

% Problem size and experiment schedule.
m = 128;                           % number of observed variables
nset = 1000*[1 2 10 20 100 200];   % sample sizes to sweep
num_runs = 1;                      % raise to ~200 for reliable comparisons

% Behaviour switches.
useDistances = true;
save_result = false;
draw_figure = false;

% Generate the ground-truth latent tree together with the reference
% quantities (observed-node distances and bipartitions) that every
% estimated tree is scored against.
adjmat = makeModel(graph, m);
if(strcmp(graph,'5cayley')) % Observe the root node of the 5-cayley tree
    % Move the last node (last row/column) to index 1 so it falls inside the
    % observed range 1:m. Resulting layout:
    %   [ 0            , old last row (minus self)   ;
    %     old last col , remaining nodes in order    ]
    % NOTE(review): presumably makeModel stores the Cayley-tree root last --
    % confirm in makeModel.
    adjmat = [[0; adjmat(1:end-1,end)], [adjmat(end,1:end-1); adjmat(1:end-1,1:end-1)]]; 
    m = m+1; % the moved node is now counted among the observed nodes
end
% Tree distances between nodes, restricted to the first m (observed) nodes.
topo_distance_org = treeDistance(adjmat);
topo_distance_org = topo_distance_org(1:m,1:m);
% Reference bipartitions of the observed nodes. Assumes treePartition returns
% one 0/1 row per split over columns 1:m -- TODO confirm. Rows whose m-th
% entry is set are complemented so every split is written with node m on the
% 0 side; the estimated trees get the same treatment before rows are compared.
tree_partition_org = treePartition(adjmat,0,m);
ind = logical(tree_partition_org(:,m));
tree_partition_org(ind,:) = ~tree_partition_org(ind,:);

%parameter setting
% Each tree node carries a d-dimensional variable. The edge transfer matrix A
% and the noise covariance Sigma_n are built from the eigendecomposition of a
% random symmetric positive-definite matrix Sigma_r so that
%   A*Sigma_r*A.' + Sigma_n == Sigma_r   (up to rounding).
d=3;
p = size(adjmat,1);    % total node count (observed + hidden); not used later in this script
% NOTE(review): the highest node index is passed to samplegeneration as the
% root. For '5cayley' the node originally stored last has been moved to
% index 1, so confirm this is the intended root.
root = size(adjmat,1);
rho=0.5;
temp=randn(d,d);
Sigma_r=temp*temp.'+5*eye(d);   % SPD: PSD Gram matrix plus 5*I, eigenvalues >= 5
[V,D]=eig(Sigma_r);             % eigenvectors V, eigenvalues on diag(D)
% Sized by d (was eye(3), which broke D - LambdaA*D*LambdaA for any d ~= 3).
LambdaA=rho*eye(d);
Lambdan=D-LambdaA*D*LambdaA;
A=V*LambdaA*V.';
Sigma_n=V*Lambdan*V.';
% exp(0.5*log(det(.))) equals sqrt(det(A*A.')) = |det(A)|; not used later in
% this script (only stored in the workspace).
delta_g=exp(0.5*log(det(A*A.')));


% Result accumulators: one row per sample size in nset and one column per
% algorithm (RG, SNJ, CLRG in the main loop).
num_samples = numel(nset);
num_algorithms = 3;
[num_recovery, num_hidden, computation_time, RFdistance] = ...
    deal(zeros(num_samples, num_algorithms));
% Per-algorithm outputs (estimated adjacency matrix, edge data) of the
% most recent run.
[adjmatT, edgeD] = deal(cell(1, num_algorithms));
    
% Progress-report formats: exactly one conversion per algorithm. fprintf stops
% at the first conversion that has no data, so a format with more conversions
% than values never reaches its '\n' and successive lines run together.
print_format1 = [repmat('%d ', 1, num_algorithms), '\n'];
print_format2 = [repmat('%f ', 1, num_algorithms), '\n'];

for np=1:num_samples
    n = nset(np);
    fprintf('Running for samples %dK\n',n/1000);

    for r = 1:num_runs
        % Draw n samples from the generative model and compute the sample
        % distances between the m observed variables.
        X=samplegeneration(adjmat,root,A,Sigma_r,Sigma_n,m,n);
        sample_distance=distancecomputing(X,m,d,0);

        % Run each structure-learning algorithm and record its wall time.
        t = zeros(1,num_algorithms);
        tic;
        [adjmatT{1},edgeD{1}] = RG(sample_distance, useDistances, n);
        t(1) = toc; tic;
        adjmatT{2} = SNJ(sample_distance, useDistances,n);
        t(2) = toc; tic;
        [adjmatT{3},edgeD{3}] = CLRG(sample_distance, useDistances, n);
        t(3) = toc; %tic;
        %[adjmatT{4},edgeD{4}] = CLNJ(sample_distance, useDistances);
        %t(4) = toc;

        computation_time(np,:) = computation_time(np,:) + t;
        for a=1:num_algorithms
            [is_exact, pes] = isExactRecovery(adjmatT{a}, topo_distance_org);
            % Canonicalize the estimated bipartitions exactly like the
            % reference ones, then count splits present in only one tree
            % (Robinson-Foulds distance = both one-sided counts summed).
            tree_partition = treePartition(adjmatT{a},0,m);
            ind = logical(tree_partition(:,m));
            tree_partition(ind,:) = ~tree_partition(ind,:);
            RF_distance1 = sum(~ismember(tree_partition_org, tree_partition, 'rows'));
            RF_distance2 = sum(~ismember(tree_partition, tree_partition_org, 'rows'));
            RFdistance(np,a) = RFdistance(np,a) + RF_distance1+RF_distance2;
            if(is_exact)
                num_recovery(np,a) = num_recovery(np,a)+1;
            end
            % Nodes beyond the m observed ones are latent nodes added by
            % the algorithm.
            num_hidden(np,a) = num_hidden(np,a) + size(adjmatT{a},1) - m;
        end
    end
    fprintf(['# recovery out of %d: ', print_format1], num_runs, num_recovery(np,:));
    fprintf(['# hidden vars: ', print_format2], num_hidden(np,:)/num_runs);
    fprintf(['Computation time: ', print_format2], computation_time(np,:)/num_runs);
    fprintf(['Robinson-Foulds distance: ', print_format2, '\n'],RFdistance(np,:)/num_runs);
end
% Convert the accumulated per-sample-size totals into per-run averages.
num_recovery = num_recovery/num_runs;
RFdistance = RFdistance/num_runs;
num_hidden = num_hidden/num_runs;
computation_time = computation_time/num_runs;

% For each metric: a header line with the mean over all sample sizes, then
% one line per sample size (one column per algorithm). The format has exactly
% num_algorithms conversions, so fprintf(print_format, X') emits one row of X
% per line instead of wrapping rows across lines or dropping the newline.
print_format = [repmat('%8.4f ', 1, num_algorithms), '\n'];
fprintf(['# recovery: ', print_format],sum(num_recovery,1)/num_samples);
fprintf(print_format,num_recovery');
fprintf(['Robinson-Foulds distance:' ,print_format],sum(RFdistance,1)/num_samples);
fprintf(print_format,RFdistance');
fprintf(['Number of hidden nodes:' ,print_format],sum(num_hidden,1)/num_samples);
fprintf(print_format,num_hidden');
fprintf(['Computation time: ' print_format], sum(computation_time,1)/num_samples);
fprintf(print_format,computation_time');

% Optionally save the whole workspace to ./results/<graph>_<m>.mat.
% save() errors when the target folder is missing, so create it first.
if(save_result)
    if ~exist('./results', 'dir')
        mkdir('./results');
    end
    file_name = ['./results/' graph '_' num2str(m)]   % no semicolon: echo the path
    save(file_name);
end
