# -*- coding: utf-8 -*-
"""


@author: Anonymous Author
"""

import numpy as np

def euclidean_distances(x, y, squared=True):
    """Return the matrix of pairwise Euclidean distances between rows.

    The distances are obtained from the expansion
    ``|a - b|^2 = |a|^2 - 2 * a.b + |b|^2`` so that the bulk of the work is
    a single matrix product.

    Parameters
    ----------
    x : np.ndarray
        2-D array, one vector per row.
    y : np.ndarray
        2-D array with the same number of columns as ``x``.
    squared : bool
        If true (the default) squared distances are returned, otherwise
        the square root is taken.

    Returns
    -------
    np.ndarray
        Array of shape ``(x.shape[0], y.shape[0])``.
    """
    assert isinstance(x, np.ndarray) and x.ndim == 2
    assert isinstance(y, np.ndarray) and y.ndim == 2
    assert x.shape[1] == y.shape[1]

    same_input = x is y

    # Squared row norms: a column for x, a row for y.
    x_norms = np.sum(x * x, axis=1, keepdims=True)
    if same_input:
        y_norms = x_norms.T
    else:
        y_norms = np.sum(y * y, axis=1, keepdims=True).T

    # Accumulate in place to avoid allocating temporaries.
    result = np.dot(x, y.T)
    result *= -2
    result += x_norms
    result += y_norms

    # Rounding can push tiny distances below zero; clamp them.
    np.maximum(result, 0, out=result)
    if same_input:
        # A vector is at distance exactly 0 from itself, whatever the
        # rounding error of the expansion says.
        np.fill_diagonal(result, 0.0)

    if not squared:
        np.sqrt(result, out=result)
    return result
    
def constrain_kmeans(totfeas, lbl_idx, lbl_target, num_cluster, max_itr = 1000):
    """Run k-means in which labelled samples are pinned to their cluster.

    Centroids of clusters that own labelled samples start at the mean of
    those samples; every other centroid starts at a randomly drawn
    unlabelled sample (``np.random``). During the updates a labelled sample
    always contributes to the centroid of its target cluster and never to
    any other one. The returned assignment itself is the plain
    nearest-centroid label for every row, labelled rows included.

    Parameters
    ----------
    totfeas : np.ndarray
        2-D feature matrix, one sample per row.
    lbl_idx : sequence of int
        Row indices of the labelled samples (may be empty).
    lbl_target : sequence
        Cluster id of each entry of ``lbl_idx`` (may be empty).
    num_cluster : int
        Number of centroids.
    max_itr : int
        Maximum number of update iterations.

    Returns
    -------
    tlabel : np.ndarray
        Index of the nearest centroid for every row of ``totfeas``.
    same_r : float
        Fraction of rows whose label did not change in the last iteration
        (``0.0`` when no iteration was run).

    Raises
    ------
    ValueError
        If a centroid needs a random seed but no unlabelled sample is left.
    """
    # Built exactly like before so the random draws hit the same samples.
    unlbl_idx = list(set(range(len(totfeas))) - set(lbl_idx))
    lbl_idx = np.asarray(lbl_idx, dtype=int)
    # As an array so that ``== ic`` is element-wise even for list input.
    lbl_target = np.asarray(lbl_target)

    # Labelled rows owned by each cluster (empty array when it has none).
    pinned = [lbl_idx[np.nonzero(lbl_target == ic)[0]]
              for ic in range(num_cluster)]

    mu = []
    for ic in range(num_cluster):
        if pinned[ic].size:
            mu.append(totfeas[pinned[ic], :].mean(axis=0))
        else:
            if not unlbl_idx:
                raise ValueError(
                    'not enough unlabelled samples to seed every cluster')
            seed = unlbl_idx[np.random.randint(len(unlbl_idx))]
            unlbl_idx.remove(seed)
            mu.append(totfeas[seed, :].copy())

    mu = np.array(mu)
    if not np.issubdtype(mu.dtype, np.floating):
        # Seeds copied from integer features would otherwise make the
        # centroid array integer and truncate every later mean.
        mu = mu.astype(float)

    # Rows that may never be counted through their nearest-centroid label.
    is_pinned = np.zeros(len(totfeas), dtype=bool)
    is_pinned[lbl_idx] = True

    tlabel = euclidean_distances(mu, totfeas, squared=False).argmin(axis=0)
    oldlabel = tlabel.copy()
    same_r = 0.0  # defined even when max_itr == 0

    for _ in range(max_itr):
        for ic in range(num_cluster):
            free = np.flatnonzero((tlabel == ic) & ~is_pinned)
            members = np.concatenate((free, pinned[ic]))
            if members.size == 0:
                # Keep the previous centroid: the mean of an empty set is
                # NaN, and a NaN row would capture every sample in argmin.
                continue
            mu[ic, :] = totfeas[members, :].mean(axis=0)

        tlabel = euclidean_distances(mu, totfeas, squared=False).argmin(axis=0)

        same_r = (tlabel == oldlabel).sum() / len(tlabel)
        if same_r == 1:
            break

        oldlabel = tlabel.copy()

    return tlabel, same_r

# alfeas = np.load('C:\\document\\code\\AS4L\\pretrain_encoder\\totfeas.npy')
# alidx = np.load()
# totlabel = np.load('C:\\document\\code\\AS4L\\pretrain_encoder\\totlabel.npy')
# # prediction for checking purity
# prediction = np.load()

def generate_clusters(alfeas, alidx, altarget, num_level = 13, num_class = 10):
    """Build a binary hierarchy of clusterings by repeated bisection.

    Level 0 is a ``constrain_kmeans`` run with ``num_class`` clusters seeded
    by the labelled samples. Every following level splits each cluster of
    the previous level in two with an unconstrained 2-means; the children of
    cluster ``j`` get the ids ``2*j`` and ``2*j + 1``. A cluster holding a
    single sample is passed down unchanged as ``2*j``.

    Parameters
    ----------
    alfeas : np.ndarray
        2-D feature matrix, one sample per row.
    alidx : sequence of int
        Row indices of the labelled samples.
    altarget : sequence
        Cluster id of each entry of ``alidx``.
    num_level : int
        Number of levels to produce (level 0 included).
    num_class : int
        Number of clusters at level 0.

    Returns
    -------
    list of np.ndarray
        One label vector per level. Level 0 is whatever ``constrain_kmeans``
        returns; deeper levels are float arrays of length ``len(alfeas)``.
    """
    cluster10, _ = constrain_kmeans(alfeas, alidx[:], altarget, num_class, max_itr = 1000)

    clusters = [cluster10]
    for _ in range(1, num_level):
        parent_of = clusters[-1]

        # Group the sample indices by parent id with one stable sort rather
        # than scanning the whole label vector once per candidate id.
        # Parents come out in ascending order and the members of each
        # group stay in ascending index order.
        order = np.argsort(parent_of, kind='stable')
        parents, starts = np.unique(parent_of[order], return_index=True)
        groups = np.split(order, starts[1:])

        child_of = np.zeros(len(alfeas))
        for parent, members in zip(parents, groups):
            if len(members) >= 2:
                halves, _ = constrain_kmeans(alfeas[members, :], [], [], 2, max_itr = 1000)
                child_of[members[halves == 0]] = 2 * parent
                child_of[members[halves == 1]] = 2 * parent + 1
            else:
                # A single sample cannot be split; hand it to the first child.
                child_of[members] = 2 * parent
        clusters.append(child_of)

    return clusters

def est_purity(clusters, prediction, num_class = 10):
    """Estimate, per cluster and per level, how many samples look mislabelled.

    The predicted class of a sample is the argmax of its row in
    ``prediction``. The purity of a cluster is the share of its samples that
    carry the most frequent predicted class; an absent cluster id counts as
    pure (1) with zero samples. Level ``i`` is evaluated for the cluster ids
    ``0 .. num_class * 2**i - 1``.

    Parameters
    ----------
    clusters : sequence of np.ndarray
        One 1-D label vector per hierarchy level.
    prediction : np.ndarray
        2-D score matrix, one row per sample.
    num_class : int
        Number of clusters at level 0.

    Returns
    -------
    totwnum : list of list
        ``totwnum[i][j]`` is ``size - purity * size`` of cluster ``j`` at
        level ``i``, stored as NumPy float scalars.
    totnum : list of np.ndarray
        ``totnum[i][j]`` is the number of samples of cluster ``j`` at
        level ``i``.
    """
    tpre = np.argmax(prediction, axis=1)

    totwnum = []
    totnum = []
    for level, labels in enumerate(clusters):
        labels = np.asarray(labels)

        # Group sample indices by cluster id with a single stable sort.
        # Looping over every sample and rescanning the label vector each
        # time would cost O(n^2) per level.
        order = np.argsort(labels, kind='stable')
        uniq, starts = np.unique(labels[order], return_index=True)
        members = dict(zip(uniq.tolist(), np.split(order, starts[1:])))
        print(level)  # progress output, one line per level

        numcluster = num_class * 2 ** level
        estpurity = np.ones(numcluster)  # absent clusters count as pure
        tnum = np.zeros(numcluster)
        for j in range(numcluster):
            idx = members.get(j)
            if idx is None:
                continue
            tnum[j] = len(idx)
            estpurity[j] = np.bincount(tpre[idx]).max() / len(idx)

        totnum.append(tnum)
        # Kept as a list of NumPy scalars, the container type the split
        # routines copy and extend.
        totwnum.append(list(tnum - estpurity * tnum))

    return totwnum, totnum

def split_clusters(totwnum, totnum, num_split, clusters, totlabel, num_class = 10):
    """Greedily refine the level-0 clusters until ``num_split`` leaves exist.

    The leaf with the largest score (``totwnum``) is replaced by its two
    children ``2*id`` and ``2*id + 1`` of the next level when both hold at
    least one sample. When only one child is populated the leaf descends to
    that child with score 0. A leaf that has no populated child, or that
    already sits on the deepest available level, is marked as unsplittable.
    The loop stops once ``num_split`` leaves exist or nothing can be split
    any more. One split is always attempted before the leaf count is
    checked, so the result can hold more than ``num_split`` leaves when
    ``num_split <= num_class``.

    Parameters
    ----------
    totwnum : sequence of sequence
        Per-level, per-cluster scores (see ``est_purity``).
    totnum : sequence of sequence
        Per-level, per-cluster sample counts.
    num_split : int
        Requested number of leaves.
    clusters : sequence of np.ndarray
        One label vector per hierarchy level.
    totlabel : np.ndarray
        Non-negative integer reference labels, used only for the printed
        purity.
    num_class : int
        Number of clusters at level 0.

    Returns
    -------
    np.ndarray
        Integer vector giving, for every sample, the index of its leaf.
    """
    unsplittable = -np.inf

    # Three parallel lists describing the current leaves.
    levels = [0] * num_class
    ids = list(range(num_class))
    scores = list(totwnum[0])

    while True:
        pick = int(np.argmax(scores))
        if scores[pick] == unsplittable:
            # Every leaf has been tried: num_split cannot be reached.
            break

        child_level = levels[pick] + 1
        if child_level >= len(totnum):
            # No deeper level available for this leaf.
            scores[pick] = unsplittable
        else:
            left = 2 * ids[pick]
            populated = [child for child in (left, left + 1)
                         if totnum[child_level][child] >= 1]
            if len(populated) == 2:
                child_scores = [totwnum[child_level][child] for child in populated]
            else:
                child_scores = [0] * len(populated)

            if populated:
                levels += [child_level] * len(populated)
                ids += populated
                scores += child_scores
                # Drop the parent now that its children were appended.
                del levels[pick], ids[pick], scores[pick]
            else:
                scores[pick] = unsplittable

        if len(levels) >= num_split:
            break

    # Give every sample the index of the leaf that contains it.
    fusepl = np.zeros(len(clusters[0]), dtype=int)
    for fuseid, (level, cid) in enumerate(zip(levels, ids)):
        fusepl[clusters[level] == cid] = fuseid

    # Purity against the reference labels, over the leaves actually built.
    totlabel = np.asarray(totlabel)
    totpurity = []
    for fuseid in range(len(levels)):
        members = totlabel[fusepl == fuseid]
        if members.size:
            totpurity.append(np.bincount(members).max())

    print('cluster pseudo labels purity: ', np.sum(totpurity) / len(clusters[0]))

    return fusepl

def split_clusters_minconstr(totwnum, totnum, num_split, clusters, totlabel, num_min_cluster = 50, num_class = 10):
    """Greedily refine the level-0 clusters under a minimum-size constraint.

    The leaf with the largest score (``totwnum``) is replaced by its two
    children ``2*id`` and ``2*id + 1`` of the next level, but only when both
    hold at least ``num_min_cluster`` samples. A leaf whose split is
    rejected, or that already sits on the deepest available level, is
    marked as unsplittable and never picked again. The loop stops once
    ``num_split`` leaves exist or no leaf can be split any more, in which
    case fewer than ``num_split`` leaves are returned. One split is always
    attempted before the leaf count is checked.

    Parameters
    ----------
    totwnum : sequence of sequence
        Per-level, per-cluster scores (see ``est_purity``).
    totnum : sequence of sequence
        Per-level, per-cluster sample counts.
    num_split : int
        Requested number of leaves.
    clusters : sequence of np.ndarray
        One label vector per hierarchy level.
    totlabel : np.ndarray
        Non-negative integer reference labels, used only for the printed
        purity.
    num_min_cluster : int
        Minimum size required of both children for a split to be accepted.
    num_class : int
        Number of clusters at level 0.

    Returns
    -------
    np.ndarray
        Integer vector giving, for every sample, the index of its leaf.
    """
    # Rejected leaves get a score below any real one. Resetting them to 0
    # instead would let argmax return the same rejected leaf forever once
    # all remaining scores are 0.
    unsplittable = -np.inf

    # Three parallel lists describing the current leaves.
    levels = [0] * num_class
    ids = list(range(num_class))
    scores = list(totwnum[0])

    while True:
        pick = int(np.argmax(scores))
        if scores[pick] == unsplittable:
            # Every leaf has been tried: num_split cannot be reached.
            break

        child_level = levels[pick] + 1
        left = 2 * ids[pick]
        right = left + 1
        if child_level >= len(totnum):
            print('require splitting deeper hierachy', len(levels))
            scores[pick] = unsplittable
        elif (totnum[child_level][left] >= num_min_cluster
              and totnum[child_level][right] >= num_min_cluster):
            levels += [child_level, child_level]
            ids += [left, right]
            scores += [totwnum[child_level][left], totwnum[child_level][right]]
            # Drop the parent now that its children were appended.
            del levels[pick], ids[pick], scores[pick]
        else:
            scores[pick] = unsplittable

        if len(levels) >= num_split:
            break

    # Give every sample the index of the leaf that contains it.
    fusepl = np.zeros(len(clusters[0]), dtype=int)
    for fuseid, (level, cid) in enumerate(zip(levels, ids)):
        fusepl[clusters[level] == cid] = fuseid

    # Purity against the reference labels, over the leaves actually built.
    totlabel = np.asarray(totlabel)
    totpurity = []
    for fuseid in range(len(levels)):
        members = totlabel[fusepl == fuseid]
        if members.size:
            totpurity.append(np.bincount(members).max())

    print('cluster pseudo labels purity: ', np.sum(totpurity) / len(clusters[0]))

    return fusepl

def generate_cluster_pl(alfeas, alidx, altarget, prediction, num_split, totlabel, num_level = 13, num_min_cluster = 50, num_class = 10):
    """Produce cluster pseudo-labels for every sample.

    Chains the three stages of this module: build the cluster hierarchy
    (``generate_clusters``), score every cluster from the predictions
    (``est_purity``) and greedily split the worst clusters under the
    minimum-size constraint (``split_clusters_minconstr``).

    Parameters
    ----------
    alfeas : np.ndarray
        Feature matrix, one sample per row.
    alidx, altarget
        Indices and cluster ids of the labelled samples, forwarded to
        ``generate_clusters``.
    prediction : np.ndarray
        Per-sample score matrix, forwarded to ``est_purity``.
    num_split : int
        Requested number of pseudo-label groups.
    totlabel : np.ndarray
        Reference labels, forwarded to ``split_clusters_minconstr``.
    num_level : int
        Number of hierarchy levels to build.
    num_min_cluster : int
        Minimum size of a child cluster for a split to be accepted.
    num_class : int
        Number of clusters at level 0.

    Returns
    -------
    np.ndarray
        The vector returned by ``split_clusters_minconstr``.

    Raises
    ------
    ValueError
        If ``alfeas`` is ``None``.
    """
    if alfeas is None:
        # ValueError is still an Exception for existing handlers, but it
        # tells the caller which argument is at fault.
        raise ValueError('alfeas must not be None')

    clusters = generate_clusters(alfeas, alidx, altarget, num_level = num_level, num_class = num_class)
    totwnum, totnum = est_purity(clusters, prediction, num_class)
    fusepl = split_clusters_minconstr(totwnum, totnum, num_split, clusters, totlabel, num_min_cluster, num_class)

    return fusepl

# def cal_purity(cluster, totlabel, num_cluster):
    
#     totpurity = []
#     num1 = []
#     for i in range(num_cluster):
#         idx = np.argwhere(cluster == i)[:,0]
#         totpurity += [ (np.bincount(totlabel[idx])).max()]
#         num1 += [len(idx)]
#     print(np.sum(totpurity) / len(totlabel))
    
#     return np.sum(totpurity) /  len(totlabel)