# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import torch
import numpy as np
from sklearn.cluster import MiniBatchKMeans

def cal_similarity_matrix(x, y):#labeled_embeddings:x ; unlabeled_embeddings: y
    """Pairwise squared Euclidean distances between the rows of ``x`` and ``y``.

    Despite the name this is a *distance* (larger = less similar), computed via
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        x: 2-D tensor (labeled embeddings), shape (n, d).
        y: 2-D tensor (unlabeled embeddings), shape (m, d).

    Returns:
        (n, m) tensor with the same dtype/device as the inputs.  Entries may be
        slightly negative due to floating-point round-off.
    """
    # Broadcasting replaces the former ``torch.ones`` outer products, which were
    # always float32 on the CPU and made ``torch.mm`` fail for float64/GPU input.
    x_sq = torch.sum(x * x, dim=1, keepdim=True)   # (n, 1)
    y_sq = torch.sum(y * y, dim=1).unsqueeze(0)    # (1, m)
    return x_sq + y_sq - 2 * torch.mm(x, y.t())
    
def acquire_new_sample(budget, candnamelist, emb_lab, emb_un):#select diverse samples based on embeddings
    """Select ``budget`` diverse candidates by greedy farthest-first traversal.

    At each step, the remaining unlabeled candidate is picked whose smallest
    squared Euclidean distance to the labeled embeddings and to every candidate
    picked so far is the largest.

    Args:
        budget: number of candidates to pick.
        candnamelist: list of candidate identifiers aligned with ``emb_un``.
            NOTE: picked entries are deleted from this list *in place*.
        emb_lab: iterable of torch tensors (labeled embeddings; presumably one
            row per sample - confirm with callers).
        emb_un: iterable of torch tensors (unlabeled embeddings) aligned with
            ``candnamelist``.

    Returns:
        List of picked identifiers taken from ``candnamelist``, in pick order.
    """
    def _as_tensor(rows):
        # Stack the rows into a fresh tensor via a numpy round-trip.
        return torch.from_numpy(np.array([row.numpy() for row in rows]))

    centers = _as_tensor(emb_lab)
    pool = _as_tensor(emb_un)
    picked = []

    # Rows of ``dmat``: distances from "centers" to every remaining pool item.
    # After the first pick it holds two rows: the running minimum and the
    # distances from the newest pick.
    dmat = cal_similarity_matrix(centers, pool)
    for _ in range(budget):
        nearest, _ = torch.min(dmat, dim=0)
        pos = np.argmax(nearest.numpy())
        picked.append(candnamelist[pos])
        del candnamelist[pos]

        newest = pool[pos, :].unsqueeze(0)
        pool = torch.cat([pool[:pos, :], pool[pos + 1:, :]], dim=0)
        nearest = torch.cat((nearest[:pos], nearest[pos + 1:]))
        dmat = torch.cat([nearest.unsqueeze(0), cal_similarity_matrix(newest, pool)])

    return picked

# label_feas = f1m0[train_lbl_idx,:]
# unlabel_feas = f1m0[train_unlbl_idx,:]
# new_sample = acquire_new_sample(50, train_unlbl_idx, torch.from_numpy(label_feas), torch.from_numpy(unlabel_feas))
# label_feas = totfeas[coreset,:]
# uidx = list( set([i for i in range(50000)]) - set(coreset.tolist()) )
# unlabel_feas = totfeas[uidx,:]
# new_sample = acquire_new_sample(5000, uidx, torch.from_numpy(label_feas), torch.from_numpy(unlabel_feas))


### random selection
def random_select(budget, unlbl_idx):
    """Randomly choose ``budget`` positions in ``unlbl_idx``.

    Note: the return value holds *positions* in ``unlbl_idx`` (0..len-1), not
    the entries themselves; callers presumably map them through ``unlbl_idx``
    - verify against callers.

    Positions are drawn without replacement whenever ``budget`` does not exceed
    the pool size, so no budget is wasted on duplicates.  With a larger budget
    they are drawn with replacement, as before.

    Returns:
        List of ``budget`` integer positions.
    """
    n = len(unlbl_idx)
    # The old ``randint(0, n - 1, ...)`` excluded the last position (high bound
    # is exclusive) and raised for a one-element pool.
    lbl_idx = np.random.choice(n, size=budget, replace=budget > n)
    return lbl_idx.tolist()

def random_per_class(labels, budget, n_class):
    """Randomly pick ``budget // n_class`` sample indices from each class.

    Args:
        labels: numpy array of integer class ids, one per sample.
        budget: total labeling budget (split evenly, remainder dropped).
        n_class: number of classes; ids 0..n_class-1 are visited in order.

    Returns:
        List of sample indices grouped by class (class 0 first).  A class with
        fewer members than its quota contributes all of its members.
    """
    quota = budget // n_class
    picked = []
    for cls in range(n_class):
        members = np.flatnonzero(labels == cls)
        np.random.shuffle(members)
        picked += list(members[:quota])
    return picked

### kMedoids
def kMedoids(D, k, tmax=100):
    """Select k medoids with Voronoi-iteration k-medoids on a distance matrix.

    Args:
        D: square numpy distance matrix; ``D[i, j]`` is the distance between
            points i and j.
        k: number of medoids to select.
        tmax: maximum number of assignment/update iterations.

    Returns:
        List of k point indices (the selected samples).  Entry ``kappa`` is the
        medoid of cluster ``kappa``.

    Raises:
        Exception: if ``k`` exceeds the number of points.
    """
    _, n = D.shape
    if k > n:
        raise Exception('too many medoids')

    # Random initial medoids, stored in ascending order.
    M = np.arange(n)
    np.random.shuffle(M)
    M = np.sort(M[:k])
    Mnew = np.copy(M)

    for _ in range(tmax):
        # Assign every point to its closest current medoid (cluster id = position in M).
        J = np.argmin(D[:, M], axis=1)
        for kappa in range(k):
            members = np.flatnonzero(J == kappa)
            if members.size == 0:
                # Possible with duplicate points / ties.  np.argmin of an empty
                # array raises, so keep the current medoid instead.
                continue
            # New medoid: the member with the smallest mean distance to its cluster.
            within = np.mean(D[np.ix_(members, members)], axis=1)
            Mnew[kappa] = members[np.argmin(within)]
        # Mnew is deliberately not re-sorted: it must stay aligned with the
        # cluster ids used above (the former ``np.sort(Mnew)`` was a no-op).
        if np.array_equal(M, Mnew):
            break
        M = np.copy(Mnew)

    return M.tolist()


###GMM
def GMM(budget, totfeas, dist, random_state=None):
    """Pick one representative sample per Gaussian-mixture component.

    Fits a ``budget``-component GaussianMixture on ``totfeas`` and hard-assigns
    every row with ``predict``.  For each component it returns the member whose
    summed distance (taken from ``dist``) to the other members is smallest.

    Args:
        budget: number of mixture components (= maximum number of picks).
        totfeas: 2-D feature array, one row per sample.
        dist: square pairwise distance matrix over the rows of ``totfeas``.
        random_state: forwarded to GaussianMixture for reproducibility.  The
            default None keeps the previous behaviour.

    Returns:
        List of row indices, at most one per component.  A component that
        receives no rows from ``predict`` is skipped, so fewer than ``budget``
        indices may be returned.
    """
    from sklearn.mixture import GaussianMixture
    gmm = GaussianMixture(n_components=budget, random_state=random_state).fit(totfeas)
    labels = gmm.predict(totfeas)

    new_sample = []
    for ic in range(budget):
        sidx = np.flatnonzero(labels == ic)
        if sidx.size == 0:
            continue  # argmin of an empty array would raise ValueError
        sumdist = dist[np.ix_(sidx, sidx)].sum(axis=0)
        new_sample.append(sidx[sumdist.argmin()])
    return new_sample


###TypiClust_rp

def euclidean_distances(x, y, squared=True):
    """Pairwise Euclidean distances between the rows of two 2-D arrays.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.  Round-off
    negatives are clipped to 0, and when ``x`` and ``y`` are the same object the
    diagonal is forced to exactly 0.

    Args:
        x: 2-D numpy array, shape (n, d).
        y: 2-D numpy array, shape (m, d).
        squared: return squared distances if True, plain distances otherwise.

    Returns:
        (n, m) numpy array of distances.
    """
    assert isinstance(x, np.ndarray) and x.ndim == 2
    assert isinstance(y, np.ndarray) and y.ndim == 2
    assert x.shape[1] == y.shape[1]

    self_pairs = x is y
    norms_x = (x * x).sum(axis=1)[:, np.newaxis]
    norms_y = norms_x.T if self_pairs else (y * y).sum(axis=1)[np.newaxis, :]

    out = np.dot(x, y.T)
    # In-place arithmetic avoids allocating extra (n, m) temporaries.
    out *= -2
    out += norms_x
    out += norms_y
    np.maximum(out, 0, out=out)
    if self_pairs:
        # A point's distance to itself may not be exactly 0 after round-off.
        np.fill_diagonal(out, 0.0)
    if not squared:
        np.sqrt(out, out=out)
    return out

def constrain_kmeans(totfeas, lbl_idx, lbl_target, num_cluster, max_itr = 1000):
    """k-means in which labeled samples are tied to their own class's centroid.

    Centroid ``i`` is initialised as the mean of the labeled rows whose target
    is ``i`` if any exist, otherwise as a randomly chosen unlabeled row.  On
    each update a labeled row contributes only to the centroid of its target
    class, while unlabeled rows contribute to their nearest centroid.  The
    returned labels, however, are plain nearest-centroid assignments for every
    row, labeled rows included.

    Args:
        totfeas: 2-D numpy feature array, one row per sample.
        lbl_idx: row indices of labeled samples (list or array; may be empty).
        lbl_target: class id for each entry of ``lbl_idx`` (list or array).
        num_cluster: number of clusters.
        max_itr: maximum number of update iterations.

    Returns:
        (tlabel, same_r): ``tlabel`` is the nearest-centroid id of each row.
        ``same_r`` is the fraction of rows whose id did not change in the last
        iteration: 1.0 means converged, and it is 0.0 if no iteration ran.
    """
    # Accept plain lists: ``list == i`` is a single bool, which broke the
    # argwhere-based indexing the old code used.
    lbl_idx = np.asarray(lbl_idx, dtype=int)
    lbl_target = np.asarray(lbl_target)
    lbl_set = set(lbl_idx.tolist())

    # Labeled rows tied to each cluster id (loop-invariant, computed once).
    pinned = [lbl_idx[np.flatnonzero(lbl_target == ic)].tolist() if ic in lbl_target else []
              for ic in range(num_cluster)]

    unlbl_idx = list(set(range(len(totfeas))) - lbl_set)
    mu = []
    for ic in range(num_cluster):
        if pinned[ic]:
            mu.append(totfeas[pinned[ic], :].mean(axis=0))
        else:
            pick = unlbl_idx[np.random.randint(len(unlbl_idx))]
            unlbl_idx.remove(pick)
            mu.append(totfeas[pick, :].copy())
    mu = np.array(mu)

    tdist = euclidean_distances(mu, totfeas, squared=False)
    tlabel = tdist.argmin(axis=0)
    oldlabel = tlabel.copy()
    same_r = 0.0  # defined even when max_itr <= 0 (was a NameError)

    for _ in range(max_itr):
        for ic in range(num_cluster):
            members = list(set(np.flatnonzero(tlabel == ic).tolist()) - lbl_set) + pinned[ic]
            if members:
                mu[ic, :] = totfeas[members, :].mean(axis=0)
            # else: empty cluster - keep the old centroid.  The mean of no rows
            # is NaN, and a NaN centroid wins every argmin below, which would
            # collapse all points into this cluster.

        tdist = euclidean_distances(mu, totfeas, squared=False)
        tlabel = tdist.argmin(axis=0)

        same_r = (tlabel == oldlabel).sum() / len(tlabel)
        if same_r == 1:
            break
        oldlabel = tlabel.copy()

    return tlabel, same_r

def cal_typi(dist, candidx, clusteridx, K = 20):
    """Typicality of one candidate: inverse mean distance to its K nearest cluster members.

    Args:
        dist: 2-D numpy distance matrix.
        candidx: list whose first entry is the candidate's row in ``dist``
            (only the first entry is used).
        clusteridx: columns of ``dist`` belonging to the candidate's cluster.
        K: number of nearest members to average over (fewer if the cluster is
            smaller).

    Returns:
        ``1 / mean`` of the K smallest distances.  If the candidate itself is
        in ``clusteridx``, its own distance is among those considered.
    """
    row = dist[candidx, :][:, clusteridx][0]
    k_nearest = np.sort(row)[:K]
    return 1 / np.mean(k_nearest)

def query_typiclust_first(cluster, dist, B = 20, K = 20):
    """Pick the most typical sample of each cluster id 0..B-1.

    Args:
        cluster: numpy array of cluster ids, one per row of ``dist``.
        dist: square pairwise distance matrix over all samples.
        B: number of cluster ids to visit (ids 0..B-1).
        K: neighbourhood size passed to ``cal_typi`` (previously fixed at 20).

    Returns:
        List of sample indices, one per non-empty cluster.  Empty clusters are
        skipped, so fewer than ``B`` indices may be returned.
    """
    new_sample = []
    for ib in range(B):
        clusteridx = np.argwhere(cluster == ib)[:, 0]
        if clusteridx.size == 0:
            continue  # np.argmax([]) would raise on an empty cluster
        tottypi = [cal_typi(dist, [icand], clusteridx, K=K) for icand in clusteridx]
        new_sample.append(clusteridx[np.argmax(tottypi)])
    return new_sample

def query_typiclust_first2(dist, cidx, K = 20):#dist within one cluster
    """Return the most typical member of a single cluster.

    Args:
        dist: square distance matrix among the cluster's members only;
            row/column ``j`` corresponds to ``cidx[j]``.
        cidx: global sample indices of the cluster's members.
        K: neighbourhood size passed to ``cal_typi`` (previously fixed at 20).

    Returns:
        ``[cidx[j]]`` for the member with the highest typicality, or ``[]`` if
        the cluster is empty (previously ``np.argmax([])`` raised).
    """
    if len(cidx) == 0:
        return []
    local = list(range(len(cidx)))
    tottypi = [cal_typi(dist, [icand], local, K=K) for icand in local]
    return [cidx[np.argmax(tottypi)]]

def query_typiclust(cluster, candcluster, dist, B = 20, K = 20):
    """Pick the most typical sample of each cluster listed in ``candcluster``.

    Args:
        cluster: numpy array of cluster ids, one per row of ``dist``.
        candcluster: cluster ids to sample from; must have length ``B``.
        dist: square pairwise distance matrix over all samples.
        B: expected number of candidate clusters (checked by assertion).
        K: neighbourhood size passed to ``cal_typi`` (previously fixed at 20).

    Returns:
        List of sample indices, one per non-empty candidate cluster.  Empty
        clusters are skipped, so fewer than ``B`` may be returned.
    """
    new_sample = []
    assert B == len(candcluster), 'wrong cand cluster / budget'
    for ib in candcluster:
        clusteridx = np.argwhere(cluster == ib)[:, 0]
        if clusteridx.size == 0:
            continue  # np.argmax([]) would raise on an empty cluster
        tottypi = [cal_typi(dist, [icand], clusteridx, K=K) for icand in clusteridx]
        new_sample.append(clusteridx[np.argmax(tottypi)])
    return new_sample

def find_uncover_cluster(lblidx, cluster, B = 20):
    """Choose B cluster ids to sample from in the next round.

    If at least B clusters contain no labeled sample ("uncovered"), the B
    largest uncovered clusters are returned.  Otherwise - the fallback when
    there are not enough newly uncovered clusters - the B clusters containing
    the fewest labeled samples are returned, which puts uncovered ones first.

    Args:
        lblidx: indices of already-labeled rows.
        cluster: 1-D integer numpy array of cluster ids per row (ids must be
            non-negative, as ``np.bincount`` requires).
        B: number of cluster ids to return.

    Returns:
        numpy array of at most B cluster ids, each of which occurs in ``cluster``.
    """
    covered = set(cluster[lblidx].tolist())
    uncover_class = list(set(cluster.tolist()) - covered)

    if len(uncover_class) >= B:
        totnum = [(cluster == i).sum() for i in uncover_class]
        order = np.argsort(totnum)
        # ``order[-B:]`` would return everything for B == 0.
        candcluster = np.array(uncover_class)[order[len(order) - B:]]
    else:
        # Count labeled samples only for ids that actually occur.  The former
        # ``minlength=len(np.unique(cluster))`` broke on non-contiguous ids:
        # it could return empty ids and miss real uncovered ones.
        cluster_ids = np.unique(cluster)
        counts = np.bincount(cluster[lblidx], minlength=cluster_ids.max() + 1)
        candcluster = cluster_ids[np.argsort(counts[cluster_ids])[:B]]

    return candcluster

# dist = euclidean_distances(totfeas, totfeas, squared=False)
# #bootstrap
# cluster20,_ = constrain_kmeans(totfeas, [], [], 20, max_itr = 1000)
# lblidx = query_typiclust_first(cluster20, dist, B = 20)
# #t+1 cycle
# num_budget_gap = 20
# num_cluster = 20
# B = 20
# for ial in range(4):
#     num_cluster += num_budget_gap
#     cluster,_ = constrain_kmeans(totfeas, [], [], num_cluster, max_itr = 1000)
#     candcluster = find_uncover_cluster(lblidx, cluster, B)
#     new = query_typiclust(cluster, candcluster, dist, B)
#     lblidx += new

# def sampling_typicluster(totfeas, alidx, num_budget):
    
#     dist = euclidean_distances(totfeas, totfeas, squared=False)
#     if len(alidx) == 0:
#         cluster,_ = constrain_kmeans(totfeas, [], [], num_budget, max_itr = 1000)
#         newidx = query_typiclust_first(cluster, dist, B = num_budget)
#     else:
#         num_cluster = len(alidx) + num_budget
#         if num_cluster > 500:
#             num_cluster = 500
#         cluster,_ = constrain_kmeans(totfeas, [], [], num_cluster, max_itr = 1000)
#         candcluster = find_uncover_cluster(alidx, cluster, num_budget)
#         newidx = query_typiclust(cluster, candcluster, dist, num_budget)

#     return newidx

### using minibatch kmeans for large scale data
def sampling_typicluster(totfeas, alidx, num_budget, max_cluster=1000):
    """TypiClust selection using MiniBatchKMeans (for large datasets).

    First round (``alidx`` empty): cluster into ``num_budget`` clusters and take
    the most typical member of each.  Later rounds: cluster into
    ``min(len(alidx) + num_budget, max_cluster)`` clusters, choose
    ``num_budget`` clusters with ``find_uncover_cluster``, and take the most
    typical member of each.

    Args:
        totfeas: 2-D numpy feature array, one row per sample.
        alidx: indices already labeled (may be empty).
        num_budget: number of samples to select.
        max_cluster: cap on the number of clusters in later rounds
            (previously hard-coded to 1000).

    Returns:
        List of selected sample indices.  Empty clusters are skipped, so it
        may be shorter than ``num_budget``.
    """
    if len(alidx) == 0:
        kmeans = MiniBatchKMeans(n_clusters=num_budget, batch_size=256)
        cluster = kmeans.fit_predict(totfeas)
        target_clusters = range(num_budget)
    else:
        num_cluster = min(len(alidx) + num_budget, max_cluster)
        kmeans = MiniBatchKMeans(n_clusters=num_cluster, batch_size=256)
        cluster = kmeans.fit_predict(totfeas)
        target_clusters = find_uncover_cluster(alidx, cluster, num_budget)

    # Initialised for both branches; the old else-branch used ``newidx``
    # without defining it (UnboundLocalError on every non-first round).
    newidx = []
    for ic in target_clusters:
        cidx = np.argwhere(cluster == ic)[:, 0]
        if cidx.size == 0:
            continue  # empty cluster: nothing to pick
        members = totfeas[cidx, :]
        # Passing the same object lets euclidean_distances zero the diagonal exactly.
        dist = euclidean_distances(members, members, squared=False)
        newidx += query_typiclust_first2(dist, cidx)

    return newidx


#probcover
###TODO: thresh estimation
def estimate_thresh(dist, cluster, thresh = 0.95, num_bins = 20):
    """Estimate the ProbCover ball radius for a target cluster purity.

    Radii ``k * gap`` (``gap = 0.7 * max(dist) / num_bins``, k = 1, 2, ...) are
    tried in increasing order.  For each radius, purity is the fraction of
    (point, neighbour-within-radius) pairs that share a cluster id.

    Args:
        dist: square pairwise distance matrix.
        cluster: cluster id per row of ``dist``.
        thresh: minimum acceptable purity.
        num_bins: controls the step size ``gap``.

    Returns:
        The largest tried radius whose purity is still >= ``thresh``.  That is
        the radius just before the first failing one, which can be 0.  The
        search stops once the radius reaches ``max(dist)``: from then on every
        ball contains all points and purity cannot change.  Previously it
        looped forever in that case.
    """
    cluster = np.asarray(cluster)
    maxdist = dist.max()
    num_gap = maxdist * 0.7 / num_bins
    idist = 1
    while True:
        tdist = num_gap * idist
        purity, tnum = 0, 0
        for i in range(len(cluster)):
            idx = np.flatnonzero(dist[i, :] <= tdist)
            purity += (cluster[i] == cluster[idx]).sum()
            tnum += len(idx)
        if purity / tnum < thresh:
            break
        if tdist >= maxdist:
            return tdist  # larger radii cannot change purity
        idist += 1

    return tdist - num_gap

def construct_dg(totfeas, num_class, thresh, radius=0.17978):
    """Build the ProbCover ball graph: each sample maps to the samples within ``radius``.

    Args:
        totfeas: 2-D numpy feature array, one row per sample.
        num_class: currently unused (kept for interface compatibility).
        thresh: currently unused.  The original body overwrote it with a
            hard-coded distance; it is presumably the purity level for
            ``estimate_thresh`` - confirm before wiring it back in.
        radius: ball radius.  Defaults to the previously hard-coded 0.17978.

    Returns:
        Dict mapping every row index to the list of row indices whose
        Euclidean distance to it is <= ``radius`` (each row includes itself).
    """
    dist = euclidean_distances(totfeas, totfeas, squared=False)
    print('probcover purity 0.95 is equal to dist ', radius)

    # One entry per row.  The old hard-coded ``range(50000)`` crashed on
    # smaller datasets and silently dropped rows on larger ones.
    return {i: np.flatnonzero(dist[i, :] <= radius).tolist() for i in range(dist.shape[0])}

def find_new_one(dg):
    """Return the node whose neighbour list is longest, and that length.

    Ties go to the node seen first in iteration order.  If ``dg`` is empty or
    every list is empty, ``(0, 0)`` is returned.
    """
    best_node, best_size = 0, 0
    for node, neighbours in dg.items():
        size = len(neighbours)
        if size > best_size:
            best_node, best_size = node, size
    return best_node, best_size

def update_dg(dg, covered_idx):
    """Remove everything covered by ``covered_idx`` from the ball graph, in place.

    Every neighbour of ``covered_idx`` is dropped from all neighbour lists, and
    the ``covered_idx`` node itself is deleted.  The same dict is returned.

    Raises:
        KeyError: if ``covered_idx`` is not a node of ``dg``.
    """
    covered = set(dg[covered_idx])
    dg.update({node: list(set(neigh) - covered) for node, neigh in dg.items()})
    dg.pop(covered_idx)
    return dg

def sampling_probcover(dg, num_budget):
    """Greedy ProbCover selection on a ball graph.

    Each step selects the node that covers the most not-yet-covered samples,
    then removes those samples from the graph via ``update_dg``.  NOTE: ``dg``
    is consumed in place.

    When nothing is left to cover, ``find_new_one`` falls back to key 0, which
    may already have been removed (previously a KeyError).  In that case any
    remaining node is taken instead.  Selection stops early once every node
    has been selected.

    Returns:
        List of selected node keys, at most ``num_budget`` long.
    """
    alidx = []
    for _ in range(num_budget):
        if not dg:
            break  # every node has already been selected
        tidx, num = find_new_one(dg)
        if num == 0:
            tidx = next(iter(dg))
        alidx.append(tidx)
        dg = update_dg(dg, tidx)

    return alidx

###badge
def kmeans_plus(uidx, emb, num_budget):
    """k-means++ seeding over the rows ``emb[uidx]`` (the BADGE selection step).

    The first centre is drawn uniformly from ``uidx``.  Each further centre is
    drawn with probability proportional to its squared Euclidean distance to
    the *nearest* centre chosen so far.  The old code used the distance to the
    first centre only.  If all remaining distances are 0, the draw is uniform.

    Args:
        uidx: candidate row indices into ``emb``.
        emb: 2-D numpy array of embeddings.
        num_budget: number of centres to select.

    Returns:
        List of selected indices.  It is shorter than ``num_budget`` only when
        ``uidx`` has fewer distinct entries; at least one index is always
        returned.
    """
    def _sq_dist(center, rows):
        # Squared distances from emb[center] to each emb[rows].
        diff = emb[rows, :] - emb[center, :]
        return (diff * diff).sum(axis=1)

    uidx = list(uidx)
    first = uidx[np.random.randint(0, len(uidx))]
    newidx = [first]
    restidx = np.array(sorted(set(uidx) - {first}), dtype=int)
    # Running minimum distance to the chosen centres, so only the newest
    # centre needs a distance computation per step.
    mind = _sq_dist(first, restidx)

    for _ in range(num_budget - 1):
        if restidx.size == 0:
            break
        total = mind.sum()
        prob = mind / total if total > 0 else None  # None -> uniform draw
        pos = np.random.choice(restidx.size, p=prob)
        chosen = int(restidx[pos])
        newidx.append(chosen)
        restidx = np.delete(restidx, pos)
        mind = np.minimum(np.delete(mind, pos), _sq_dist(chosen, restidx))

    return newidx

### uncertainty min max softmax output
def uncertainty_sampling(totpre, uidx, num_budget):
    """Least-confidence sampling: pick the rows whose top predicted probability is lowest.

    Args:
        totpre: 2-D numpy array of predicted class probabilities, one row per sample.
        uidx: candidate row indices (unlabeled pool).
        num_budget: number of samples to return.

    Returns:
        List of up to ``num_budget`` entries of ``uidx``, least confident first.
    """
    pool = np.array(uidx)
    confidence = totpre[pool, :].max(axis=1)
    least_confident = np.argsort(confidence)[:num_budget]
    return pool[least_confident].tolist()

### entropy
def entropy_sampling(totpre, uidx, num_budget):
    """Entropy sampling: pick the rows whose predictive entropy is highest.

    Args:
        totpre: 2-D numpy array of predicted class probabilities, one row per sample.
        uidx: candidate row indices (unlabeled pool).
        num_budget: number of samples to return.

    Returns:
        List of up to ``num_budget`` entries of ``uidx``, in ascending entropy
        order (the highest-entropy sample is last).
    """
    uidx = np.array(uidx)
    prob = totpre[uidx, :]
    # Treat 0 * log(0) as 0.  Previously it became NaN, which argsort ranks
    # highest, so fully confident predictions were selected first.
    with np.errstate(divide='ignore', invalid='ignore'):
        plogp = np.where(prob > 0, prob * np.log(prob), 0.0)
    entropy = -plogp.sum(axis=1)

    order = np.argsort(entropy)
    # Explicit start index: ``order[-0:]`` returned everything for num_budget == 0.
    start = max(len(order) - num_budget, 0)
    newidx = uidx[order[start:]]

    return newidx.tolist()

