from sklearn.cluster import KMeans
from matplotlib import pyplot as plt
import numpy as np
from anchor import utils
# Directory that anchor's dataset loader reads the raw files from.
dataset_folder = './dataset'

# Load the 'adult' dataset through anchor.utils, requesting class balancing
# and feature discretization, then report the shape of the training split.
dataset = utils.load_dataset(
    'adult',
    dataset_folder=dataset_folder,
    balance=True,
    discretize=True,
)
print(np.shape(dataset.train))


# Number of clusters requested from k-means.
cluster_num = 150

# Cluster over the training and test rows together (train rows first).
alldata = np.vstack([dataset.train, dataset.test])
print(np.shape(alldata))

# Fit k-means on the stacked data, allowing up to 1000 iterations per run.
clf1 = KMeans(n_clusters=cluster_num, max_iter=1000)
clf1.fit(alldata)
# Disabled debug output (references objects not defined in this section):
# print(pca2d.fit_transform(clf1.cluster_centers_))  # reduce to 2-D for easier visualization
# print(clf2.cluster_centers_)

# Fitted results: per-row cluster labels, the cluster centres, and the
# final inertia reported by the estimator.
label_pred, centroids, inertia = (
    clf1.labels_,
    clf1.cluster_centers_,
    clf1.inertia_,
)

# Disabled visualization: scatter of the first two feature columns,
# coloured by assigned cluster.
# fig, ax = plt.subplots(figsize=(12,6))
# plt.scatter(alldata[:, 0], alldata[:, 1], c=label_pred)
# plt.show()

print(np.shape(centroids))
print(np.shape(label_pred))


# Fixed sequence length (in tokens, including [CLS] and [SEP]) for every input.
L = 512
# Number of training sentences; the first N entries of the combined list
# below are the training set, the rest are the test set.
N = len(newsgroups_train)
bert_train, mask_train, seg_ids_train = [], [], []
all_sents = newsgroups_train + newsgroups_test
# NOTE(review): a DistilBERT checkpoint is loaded through the BERT tokenizer
# class -- confirm this pairing is intended.
tokenizer = BertTokenizer.from_pretrained('distilbert-base-cased')
for sent in tqdm(all_sents):
    # Truncate the wordpieces *before* adding the special tokens so that a
    # long sentence still ends with [SEP]; slicing afterwards would cut the
    # closing [SEP] off any sentence longer than L - 2 wordpieces.
    tokens = ['[CLS]'] + tokenizer.tokenize(sent)[:L - 2] + ['[SEP]']
    pad_len = L - len(tokens)
    padded_tokens = tokens + ['[PAD]'] * pad_len
    # 1 for real tokens, 0 for padding; derived from lengths rather than by
    # comparing token strings against '[PAD]'.
    attn_mask = [1] * len(tokens) + [0] * pad_len
    sent_ids = tokenizer.convert_tokens_to_ids(padded_tokens)
    # Single-segment input: every position belongs to segment 0.
    seg_ids = [0] * L
    bert_train.append(sent_ids)
    mask_train.append(attn_mask)
    seg_ids_train.append(seg_ids)
 

# Use the second GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): "cuda:1" is hard-coded -- confirm the host has at least two GPUs.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
data1 = NewDataset(bert_train, mask_train=mask_train, seg_ids_train=seg_ids_train)
# NOTE(review): a DistilBERT checkpoint is loaded through BertModel --
# confirm the weights actually match this architecture.
bert_model = BertModel.from_pretrained('distilbert-base-cased').to(device)

# Pooled sentence representations, one numpy vector per input sentence,
# collected in dataset order (shuffle=False keeps train rows before test rows).
reps = []
batchsize = 5
# Inference only: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for batch in tqdm(DataLoader(data1, shuffle=False, batch_size=batchsize)):
        # Distinct names so the full token lists built earlier are not clobbered.
        input_ids, attn_masks, segment_ids = batch
        # .to(device) works for both CPU and GPU; .cuda(device) raises when
        # the CPU fallback above was selected.
        output = bert_model(
            input_ids.to(device),
            attention_mask=attn_masks.to(device),
            token_type_ids=segment_ids.to(device),
        )
        reps.extend(output.pooler_output.cpu().numpy())

# Split back into train / test using the training-set size N.
reps_train = reps[:N]
reps_test = reps[N:]