import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn

def initialize(X, num_clusters):
    '''
    Choose starting centroids by sampling rows of X without replacement.

    X: indexable collection of samples (rows are taken along the first axis)
    num_clusters: how many distinct rows to draw; must not exceed len(X)
    returns: X indexed by the randomly chosen row positions
    '''
    chosen_rows = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen_rows]

def Bidding(
        X,
        num_clusters,
        upper,
        tol=1e-4,
        iol=30,
        device='cuda',
):
    '''
    Size-constrained clustering by greedy cosine-similarity "bidding".

    Each round, the (centroid, sample) pair with the highest row-normalised
    similarity exp(cos) is assigned first, then the next highest, and so on,
    subject to a per-cluster capacity. Centroids are then recomputed as the
    mean of their members. Rounds repeat until the squared total centroid
    shift drops below ``tol`` or ``iol`` rounds have run.

    X: the clustered feature; presumably a 2-D tensor (num_samples, dim),
       since rows are indexed on dim 0 and normalised over dim -1
    num_clusters: the number of clusters
    upper: slack added to the base share ``len(X) // num_clusters`` to form
       the per-cluster capacity. The capacity is raised to
       ``ceil(len(X) / num_clusters)`` when the requested value could not
       hold every sample.
    tol: convergence threshold on the squared summed centroid shift
    iol: maximum number of rounds
    device: device the computation runs on (default 'cuda')

    returns: (choice_cluster, centroids), both moved to CPU.
       choice_cluster is a float tensor holding one cluster id per sample.
       centroids has one row per cluster.
    '''
    # convert to float and transfer to the requested device
    X = X.float().to(device)
    num_samples = len(X)

    # initialize centroids from randomly chosen samples
    initial_state = initialize(X, num_clusters)
    k = initial_state.size(0)

    # Base share per cluster plus the caller's slack.
    K = num_samples // k
    capacity = K + upper
    # If k * capacity < num_samples, every cluster would fill up before all
    # samples are placed and the greedy loop below would spin forever.
    # Raise the cap to the smallest feasible value in that case.
    capacity = max(capacity, -(-num_samples // k))

    # Guard against zero-norm vectors, which would otherwise yield NaN cosines.
    eps = 1e-12
    X_normalized = X / X.norm(dim=-1, keepdim=True).clamp_min(eps)

    iteration = 0
    while True:
        centers_normalized = initial_state / initial_state.norm(
            dim=-1, keepdim=True).clamp_min(eps)
        # Only the centroid-vs-sample block of the cosine matrix is needed,
        # so compute it directly rather than the full (k+N)x(k+N) matrix.
        simility = torch.exp(torch.mm(centers_normalized, X_normalized.T))
        simility = simility / simility.sum(1)[:, None]

        choice_cluster = torch.full(
            (num_samples,), -1.0, device=simility.device)
        counts = [0] * k
        for _ in range(num_samples):
            # Flat argmax over (k, num_samples) -> (cluster row, sample column).
            label, des = divmod(torch.argmax(simility).item(), num_samples)
            while counts[label] >= capacity:
                # Cluster is full: withdraw all of its remaining bids.
                simility[label] = -2
                label, des = divmod(torch.argmax(simility).item(), num_samples)
            choice_cluster[des] = label
            counts[label] += 1
            # Sample is taken: remove it from every cluster's bids.
            simility[:, des] = -2

        initial_state_pre = initial_state.clone()
        for index in range(k):
            members = torch.nonzero(choice_cluster == index).flatten()
            if members.numel() == 0:
                # Empty cluster: keep its previous centroid instead of
                # writing a NaN mean that would poison the next round.
                continue
            initial_state[index] = X.index_select(0, members).mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        iteration = iteration + 1

        # stop when converged or when the round budget is exhausted
        if center_shift ** 2 < tol or iteration >= iol:
            break

    return choice_cluster.cpu(), initial_state.cpu()

