% Experiment setup: build a graph model, sample from it, and leave the two
% distance matrices (distance, distance2) in the workspace for the sections
% that follow.
close all
clear   % plain 'clear' removes the workspace variables; 'clear all' additionally
        % flushes compiled functions, which slows the next run and is not needed here

% Graph type handed to makeModel. Other values that have been used here:
%   'star', 'doubleStar', 'hmm', 'regular'
graph = '3cayley';
dist = 'gaussian';   % distribution name passed to both sampling helpers

m = 20;        % size argument for makeModel
n = 100000;    % passed to sampleFromTree -- presumably the number of samples; confirm
adjmat = makeModel(graph, m);

% Reorder the adjacency matrix so that its last row/column comes first, and
% set the new (1,1) entry to 0. Written as a 2x2 block matrix:
%   [ 0          , last row without its final entry ;
%     last column without its final entry, leading block ]
adjmat = [0,                   adjmat(end,1:end-1); ...
          adjmat(1:end-1,end), adjmat(1:end-1,1:end-1)];
m = m+1;   % NOTE(review): presumably makeModel returns m+1 nodes -- confirm

[distance_org,R] = sampleTreeParameters(adjmat,dist,m);
sample_distance = sampleFromTree(R, dist, n);
distance = sample_distance;          % matrix analysed by the sections below
distance2 = distance_org(1:m,1:m);   % leading m-by-m block of distance_org

%%

% Pairwise statistic for every node pair (i,j) with i < j, stored in the upper
% triangle of diff_log_ratio (every other entry keeps its initial value inf):
%   * pairs with distance(i,j) > 2*edgeD_min receive the fixed value 10;
%   * otherwise the value is the spread (max - min) over reference nodes k of
%     distance(i,k) - distance(j,k).
edgeD_min = -log(0.1);
% Reference nodes must be closer than this threshold to BOTH i and j.
% Alternative threshold that was used previously: relD_thres = 2*edgeD_min;
relD_thres = 0.5*log(n);
diff_log_ratio = inf*ones(m);

for i = 1:m
    for j = i+1:m
        if distance(i,j) > 2*edgeD_min
            diff_log_ratio(i,j) = 10;
            continue;
        end
        if m > 5
            % Number of columns that can ever qualify: anything selected by
            % "< dt" with a finite dt also satisfies "< inf", so widening dt
            % can never select more than this many columns.
            num_reachable = sum((distance(i,:) < inf) & (distance(j,:) < inf));
            dt = relD_thres;
            other_nodes = (distance(i,:) < dt) & (distance(j,:) < dt);
            % Widen the threshold until more than 5 columns qualify. The count
            % is taken before i and j are removed, so it may include them.
            % Stop as soon as every reachable column is already selected;
            % without this guard the loop would never terminate when too few
            % columns have comparable distances to both i and j.
            while (sum(other_nodes) <= 5) && (sum(other_nodes) < num_reachable)
                dt = dt + log(2);
                other_nodes = (distance(i,:) < dt) & (distance(j,:) < dt);
            end
        else
            % Small problems: use every node as a reference.
            other_nodes = true(1,m);
        end
        other_nodes([i,j]) = false;   % the pair itself is never a reference
        if ~any(other_nodes)
            % No reference node available: max/min of an empty vector is empty
            % and cannot be stored in a scalar entry, so keep the initial inf.
            continue;
        end
        log_ratio = distance(i,other_nodes) - distance(j,other_nodes);
        % Spread of the differences; std(log_ratio) was tried as an alternative.
        diff_log_ratio(i,j) = max(log_ratio) - min(log_ratio);
    end
end



% Symmetric dissimilarity: each entry is the smaller of the (i,j) and (j,i)
% values of diff_log_ratio.
D = min(diff_log_ratio,diff_log_ratio');

% Column-wise minima, taken while the diagonal of D still holds its original
% (non-zeroed) values so that the self-entries do not dominate the minimum.
minD = min(D);

% Zero the main diagonal in one step (linear indices 1, m+2, 2m+3, ...).
D(1:m+1:m*m) = 0;

% Node indices ordered from largest to smallest column minimum.
[foo,sort_ind_minD] = sort(minD,'descend');
%%
% Cluster the m nodes into k groups under the dissimilarity D and print a
% silhouette-style score for each k and each initialisation.
% Reads D, minD, sort_ind_minD and m from the workspace.
for k = 1   % NOTE: the sweep over 2:m is currently disabled; only k = 1 runs
    if k == 1
        % Single cluster: report the largest column minimum relative to the
        % largest entry of D (printed under the same label as the other k).
        silh = max(minD)/max(D(:));
        fprintf('k = %d, mean silhouette = %f\n',k,silh);
        continue;
    end
    for init_ite = 1:5
        % Three initialisation schemes for the k centres:
        %   1: the k-1 nodes with the largest minD plus the node with the smallest
        %   2: the k nodes with the largest minD
        %   3-5: k nodes drawn at random
        if init_ite == 1
            centers = sort_ind_minD([1:k-1,end])';
        elseif init_ite == 2
            centers = sort_ind_minD(1:k)';
        else
            randpermm = randperm(m);
            centers = randpermm(1:k)';
        end

        prev_centers = centers;

        % At most 5 rounds of assign / re-centre.
        for ite = 1:5
            % One cell per cluster, each starting with its centre as first member.
            clusters = mat2cell(centers,ones(k,1));
            for i = 1:m
                if ismember(i,centers)
                    continue;
                end
                % Attach node i to the centre with the smallest dissimilarity.
                [foo, assignC] = min(D(i,centers));
                clusters{assignC}(end+1) = i;
            end

            % New centre of each cluster: the member whose largest
            % dissimilarity to the cluster's members is smallest.
            for c = 1:k
                % Start from the current centre so that 'center' is always
                % defined for this cluster and never carried over from the
                % previous one, even if no member's worst-case dissimilarity
                % is below inf.
                center = centers(c);
                minmaxD = inf;
                for j = 1:length(clusters{c})
                    i = clusters{c}(j);
                    maxD = max(D(i,clusters{c}));
                    if maxD < minmaxD
                        minmaxD = maxD;
                        center = i;
                    end
                end
                centers(c) = center;
            end
            % Stop once the set of centres no longer changes.
            if isempty(setdiff(centers,prev_centers))
                break;
            else
                prev_centers = centers;
            end
        end

        % Column c holds, for every node, its mean dissimilarity to the members
        % of cluster c (despite the name, this is a mean, not a sum).
        sumDcluster = zeros(m,k);
        for c = 1:k
            sumDcluster(:,c) = mean(D(:,clusters{c}),2);
            disp(clusters{c})
        end

        silh = zeros(m,1);
        for c = 1:k
            numMembers = length(clusters{c});
            otherClusterMembers = true(1,m);
            otherClusterMembers(clusters{c}) = false;
            for j = 1:numMembers
                i = clusters{c}(j);
                if numMembers > 1
                    % a: mean dissimilarity to the OTHER members of the own
                    % cluster. The stored mean includes D(i,i) = 0, so it is
                    % rescaled by numMembers/(numMembers-1).
                    a = sumDcluster(i,c)*numMembers/(numMembers-1);
                    % b: smallest dissimilarity to any node outside the cluster.
                    % Alternative tried: min(sumDcluster(i,[1:c-1,c+1:end])),
                    % i.e. the smallest mean dissimilarity to another cluster.
                    b = min(D(i,otherClusterMembers));
                    silh(i) = (b-a)/max(a,b);
                else
                    % Singleton clusters get a score of 0.
                    silh(i) = 0;
                end
            end
        end
        % Average over non-zero scores only, which drops the singleton clusters
        % (and any node whose score happens to be exactly 0).
        mean_silh = mean(silh(silh~=0));
        fprintf('k = %d, mean silhouette = %f\n',k,mean_silh);
    end

end