import math
import numpy as np

from scipy.spatial.distance import squareform
from partitions._nn_chain_improved import nn_chain_improved


def improved_modified_hierarchical_clustering(corrs, beta=0.1):
    """Cluster variables by hierarchical merging on absolute-correlation distance.

    The distance between variables ``i`` and ``j`` is ``1 - |corrs[i, j]|``,
    taken from the upper triangle of ``corrs`` via ``squareform``.  The merge
    sequence is computed by ``nn_chain_improved``, which receives
    ``floor(p * beta)`` as its ``max_size`` argument (presumably an upper
    bound on cluster size -- confirm against ``nn_chain_improved``).  The
    returned merge record is then replayed to build the final clusters.

    Parameters
    ----------
    corrs : array_like, shape (p, p)
        Square correlation matrix. Only the upper triangle is read.
    beta : float, default 0.1
        Fraction of ``p`` used to compute ``max_size``.

    Returns
    -------
    list of list of int
        Non-empty clusters of variable indices, sorted by ascending size
        (ties keep their relative order, since ``list.sort`` is stable).

    Raises
    ------
    ValueError
        If ``corrs`` is not a square 2-D matrix.
    AssertionError
        If a row of the merge record disagrees with the cluster size rebuilt
        from the preceding merges.
    """
    corrs = np.asarray(corrs)
    if corrs.ndim != 2 or corrs.shape[0] != corrs.shape[1]:
        # squareform would silently treat a 1-D input as a condensed vector
        # and expand it into a matrix, so reject anything non-square here.
        raise ValueError(
            f"corrs must be a square 2-D matrix, got shape {corrs.shape}"
        )
    p = corrs.shape[0]
    if p < 2:
        # Nothing can be merged: every variable is its own cluster.
        return [[i] for i in range(p)]

    # Condensed (upper-triangle) distance vector: strongly correlated or
    # anti-correlated variables are close to each other.
    dist_vec = 1 - np.abs(squareform(corrs, checks=False))

    # p * beta can land just below an integer in floating point
    # (e.g. 100 * 0.29 == 28.999999999999996); the tiny tolerance keeps
    # floor() from dropping a whole unit in that case.
    max_size = math.floor(p * beta + 1e-9)
    Z = nn_chain_improved(dist_vec, p, max_size)

    # Replay the merge record. Clusters live in slots indexed by the original
    # variable index; each merge moves the members of slot x into slot y.
    clusters = [[i] for i in range(p)]
    for i in range(p - 1):
        x, y, _, size = Z[i]
        if size == 0:
            # NOTE(review): rows with size 0 are skipped -- presumably merges
            # that nn_chain_improved did not perform; confirm.
            continue
        src, dst = int(x), int(y)
        clusters[dst].extend(clusters[src])
        clusters[src].clear()
        # Explicit check (not a bare assert) so it still runs under python -O.
        if len(clusters[dst]) != size:
            raise AssertionError(
                f"merge {i}: cluster {dst} has {len(clusters[dst])} members, "
                f"but the merge record says {size}"
            )

    clusters = [cluster for cluster in clusters if cluster]
    clusters.sort(key=len)
    return clusters
