import pickle
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm
import os
import torch
from sentence_transformers import SentenceTransformer, util
from datasets import load_from_disk
import torch
import os
import numpy as np
import random
import gc
import argparse
import json

# Read the path of the JSON configuration file from the command line and load
# it into the module-level ``config`` dict used by the rest of the script.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='/path/to/config', help='Config Path')
args = parser.parse_args()
# Use a context manager so the file handle is closed deterministically
# instead of being left for the garbage collector.
with open(args.config, encoding="utf-8") as config_file:
    config = json.load(config_file)


# Seed every random number generator this script touches (PyTorch CPU,
# PyTorch CUDA, NumPy, and the stdlib ``random`` module) with the same value.
seed = 23
for _seed_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_rng(seed)

# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so this
# presumably only matters for child processes spawned later -- confirm intent.
os.environ['PYTHONHASHSEED'] = str(seed)


if __name__ == "__main__":
    # For each of the ``clustering_k`` dataset splits, encode the configured
    # text column, average the normalised embeddings, re-normalise the mean,
    # and finally stack all per-cluster vectors into one tensor saved as
    # <working_dir>/clusterEmbeddings.pt.
    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Use the GPU when one is available, otherwise fall back to the CPU so the
    # script does not crash on CUDA-less machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.save does not create missing directories.
    os.makedirs(config["temp_dir"], exist_ok=True)

    # Compute the mean embedding for each cluster
    for i in tqdm(range(config["clustering_k"])):
        split_path = os.path.join(config["working_dir"], "dataset_split", str(i))
        cluster_data = load_from_disk(split_path)[config["dataset_data_column_name"]]
        embeddings = model.encode(cluster_data)
        embeddingsTensor = torch.tensor(embeddings).to(device)
        embeddingsNormalized = util.normalize_embeddings(embeddingsTensor)
        meanEmbedding = embeddingsNormalized.mean(dim=0).cpu()
        # normalize_embeddings is applied to a (1, dim) batch here, hence the
        # unsqueeze; the saved tensor keeps that leading dimension.
        meanNorm = util.normalize_embeddings(meanEmbedding.unsqueeze(0))
        torch.save(meanNorm, os.path.join(config["temp_dir"], str(i) + ".pt"))

        # Release the per-cluster buffers before processing the next split to
        # keep peak memory down.
        del embeddingsTensor, embeddingsNormalized, cluster_data, embeddings
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # Combine all the embeddings
    allEmbeddings = []
    for i in tqdm(range(config["clustering_k"])):
        # torch.load already returns a tensor; wrapping it in torch.tensor()
        # would only make a redundant copy and emit a UserWarning.
        data = torch.load(
            os.path.join(config["temp_dir"], str(i) + ".pt"),
            map_location=torch.device('cpu'),
        )
        allEmbeddings.append(data)

    # Each entry is (1, dim), so concatenation yields (clustering_k, dim).
    res = torch.cat(allEmbeddings)

    print(res.shape)
    torch.save(res, os.path.join(config["working_dir"], "clusterEmbeddings.pt"))




