from torch.utils.data import DataLoader
from data.sampler import SubsetSequentialSampler
from sklearn.cluster import KMeans
from functions import *
import torch
import numpy as np

def _dataset_labels(data_train):
    """Return ``data_train.labels`` (or ``data_train.targets`` if absent) as a tensor."""
    labels = data_train.labels if hasattr(data_train, 'labels') else data_train.targets
    return torch.as_tensor(labels)


def _rebalance_predictions(preds, scores, n_classes, per_class):
    """Return a copy of ``preds`` in which under-filled classes take samples from over-filled ones.

    For every class ``c`` with fewer than ``per_class`` predictions, candidates are
    visited in ``scores[:, c].argsort()`` order. A candidate is moved to ``c`` only if
    its current class has *more* than ``per_class`` members. This stops once ``c``
    reaches the quota. ``preds`` itself is not modified.
    """
    from collections import Counter

    balanced = preds.clone()
    # Running per-class counts kept in sync with ``balanced``, so each candidate
    # check is O(1) instead of a full re-scan of the prediction vector.
    counts = Counter(balanced.tolist())
    for c in range(n_classes):
        if counts[c] >= per_class:
            continue
        # NOTE(review): ascending argsort visits the LOWEST-scoring candidates for
        # class ``c`` first. If a higher score means "more likely class c", this
        # should probably be ``argsort(descending=True)``. Confirm intent.
        for cand in scores[:, c].argsort():
            src = balanced[cand].item()
            if counts[src] > per_class:
                balanced[cand] = c
                counts[src] -= 1
                counts[c] += 1
            if counts[c] >= per_class:
                break
    return balanced


def _class_balanced_pool(arg, preds, n_classes, per_class):
    """Take the first ``per_class`` entries of ``arg`` predicted as each class, then shuffle them.

    Raises:
        RuntimeError: if some class has fewer than ``per_class`` predicted entries.
    """
    per_class_args = [arg[(preds == c).nonzero()[:, 0]][:per_class] for c in range(n_classes)]
    short = [c for c, members in enumerate(per_class_args) if len(members) < per_class]
    if short:
        raise RuntimeError(
            f'classes {short} still have fewer than {per_class} predicted samples after rebalancing')
    pool = torch.cat(per_class_args)
    assert len(pool.unique()) == per_class * n_classes
    return pool[torch.randperm(len(pool))]


def _extract_embeddings(data_train, pool, batch_size, backbone, all_labels):
    """Embed ``data_train`` at the indices in ``pool``, in order, with ``backbone`` on CUDA.

    Row ``i`` of the returned NumPy array corresponds to ``pool[i]``. Only the
    second output of ``backbone`` is used; each sample's embedding is flattened.
    NOTE(review): the backbone is not switched to ``eval()`` here; presumably the
    caller does that.
    """
    loader = DataLoader(data_train, batch_size=batch_size,
                        sampler=SubsetSequentialSampler(pool), pin_memory=True)
    chunks = []
    with torch.no_grad():
        for i, data in enumerate(loader):
            # Sanity check: the sequential sampler must yield pool[i*B:(i+1)*B] in order.
            expected = all_labels[pool[i * batch_size:(i + 1) * batch_size]]
            assert (expected == data[1]).all()
            _, embedding, _ = backbone(data[0].cuda())
            # Flatten per sample and move off the GPU at once. ``torch.cat`` (unlike
            # ``torch.stack``) tolerates a smaller final batch.
            chunks.append(embedding.reshape(embedding.size(0), -1).cpu())
    return torch.cat(chunks).numpy()


def _kmeans_representatives(features, n):
    """Fit ``n``-means on ``features``; return, per non-empty cluster, the row nearest its center."""
    learner = KMeans(n_clusters=n)
    learner.fit(features)
    assignment = learner.predict(features)
    sq_dist = ((features - learner.cluster_centers_[assignment]) ** 2).sum(axis=1)
    rows = np.arange(features.shape[0])
    reps = [rows[assignment == k][sq_dist[assignment == k].argmin()]
            for k in range(n) if (assignment == k).sum() > 0]
    return np.array(reps, dtype=np.int64)


def Kmeans_Selection(argTuple, data_train, BATCH, models, n, **kwargs):
    """Select ``n`` samples via class-balanced pooling followed by K-means.

    Steps:
      1. Rebalance the predicted classes so that each of the ``nC`` classes has at
         least ``int(n / nC * kwargs['clsm'])`` members. Samples are taken only from
         classes that exceed that quota.
      2. Take exactly that many entries of ``arg`` per (rebalanced) predicted class
         and shuffle them into one pool.
      3. Embed the pool with ``models['backbone']`` (second output) on CUDA.
      4. Run K-means with ``n`` clusters. Keep, for every non-empty cluster, the pool
         sample closest to its center. If fewer than ``n`` are found, top up with
         random pool samples not chosen yet.

    Args:
        argTuple: ``(arg, arg_preds, arg_label, arg_scores)``.
            - ``arg`` holds indices into ``data_train`` (presumably a 1-D tensor).
            - ``arg_preds`` and ``arg_label`` are the predicted and true class of
              each entry of ``arg``.
            - ``arg_scores`` is a 2-D tensor of per-class scores whose last
              dimension is the number of classes.
        data_train: dataset exposing ``labels`` or ``targets``.
        BATCH: batch size for the embedding pass.
        models: mapping whose ``'backbone'`` entry returns a 3-tuple; the second
            element is used as the embedding.
        n: number of samples to return; also the number of K-means clusters.
        **kwargs: must contain ``'clsm'``, the per-class quota multiplier.

    Returns:
        A list of ``n`` dataset indices drawn from ``arg``.

    Raises:
        AssertionError: if ``arg_label`` disagrees with the dataset labels, or a
            loaded batch's labels disagree with the expected ones.
        KeyError: if ``kwargs`` lacks ``'clsm'``.
        RuntimeError: if some class cannot be filled to its quota.
    """
    arg, arg_preds, arg_label, arg_scores = argTuple
    arg_preds, arg_label = torch.as_tensor(arg_preds), torch.as_tensor(arg_label)
    # Built once and reused for every consistency check below.
    all_labels = _dataset_labels(data_train)
    assert (arg_label == all_labels[arg]).all()
    nC, clsm = arg_scores.size(-1), kwargs['clsm']
    clsSelectNum = int(n / nC * clsm)

    balanced_preds = _rebalance_predictions(arg_preds, arg_scores, nC, clsSelectNum)
    IID_idces = _class_balanced_pool(arg, balanced_preds, nC, clsSelectNum)
    Pre_features = _extract_embeddings(data_train, IID_idces, BATCH, models['backbone'], all_labels)

    # Positions into IID_idces (equivalently, rows of Pre_features).
    Pre_Kmeans_Idx = _kmeans_representatives(Pre_features, n)
    if len(Pre_Kmeans_Idx) < n:
        print(f'len(Pre_Kmeans_Idx) {len(Pre_Kmeans_Idx)} is less than n {n}')
        r_nums = n - len(Pre_Kmeans_Idx)
        # Top up from pool positions that have not been selected yet.
        r_Idxs = np.setdiff1d(np.arange(len(IID_idces)), Pre_Kmeans_Idx)
        chosen = np.random.choice(r_Idxs, r_nums, replace=False)
        Pre_Kmeans_Idx = np.concatenate((Pre_Kmeans_Idx, chosen), axis=0)
        print(f'{len(chosen)} idxes are sampled to construct set whose length is {len(Pre_Kmeans_Idx)}')

    return IID_idces[torch.as_tensor(Pre_Kmeans_Idx)].tolist()
