import torch.nn.functional as F
import argparse
import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, set_seed
from datasets import load_from_disk
import tqdm
# Command-line options: which saved dataset (a `load_from_disk` path) to embed,
# and the seed handed to transformers' `set_seed`.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    default='imdb_goodbad_withpara_final',
    choices=[
        'imdb_goodbad_allpara',
        'imdb_goodbad_withpara_final',
        'imdb_horribleincredible',
        'imdb_horribleincredible_allpara_final',
    ],
)
parser.add_argument('--seed', default=0, type=int)

args = parser.parse_args()
# Seed the RNGs before any model work so runs are repeatable.
set_seed(args.seed)

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions the attention mask keeps.

    Args:
        model_output: model output whose first element (``model_output[0]``)
            holds the per-token embeddings; indexing with ``[:, :, None]``-style
            broadcasting below assumes shape (batch, seq, dim) — TODO confirm
            for other models.
        attention_mask: (batch, seq) mask; non-zero entries are averaged,
            zero entries (padding) are ignored.

    Returns:
        A float tensor of shape (batch, dim). Rows whose mask is all zero come
        out as zeros, because the divisor is clamped to 1e-9 instead of 0.
    """
    hidden = model_output[0]
    # Broadcast the mask over the embedding dimension as float weights.
    weights = attention_mask[..., None].float().expand_as(hidden)
    summed = (hidden * weights).sum(dim=1)
    # Clamp guards against division by zero for fully-masked rows.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

# Load the 'sentence-transformers/all-mpnet-base-v2' tokenizer and model from
# the HuggingFace Hub. Use the GPU when one is available; otherwise fall back to
# the CPU instead of failing outright on `.to('cuda')`.
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)

def _embed_texts(texts, batch_size=256):
    """Embed ``texts`` in batches with the module-level ``tokenizer``/``model``.

    Each text is mean-pooled over its non-padding tokens (``mean_pooling``) and
    then L2-normalised along the embedding dimension.

    Args:
        texts: list of strings to encode.
        batch_size: number of texts tokenized and run through the model at once.

    Returns:
        One embedding (a list of Python floats) per input text, in input order.
        Empty input returns ``[]`` without calling the model.
    """
    vectors = []
    with torch.no_grad():
        # Step through start offsets so every slice is non-empty. The previous
        # batch count ``len // batch_size + 1`` yielded an empty final batch
        # whenever the length was an exact multiple of batch_size, and that
        # empty batch crashed the tokenizer/model call.
        for start in tqdm.tqdm(range(0, len(texts), batch_size)):
            batch = tokenizer(texts[start:start + batch_size], padding=True,
                              truncation=True, return_tensors='pt')
            # Send inputs to wherever the model lives instead of hard-coding 'cuda'.
            batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
            outputs = model(**batch)
            pooled = mean_pooling(outputs, batch['attention_mask']).cpu()
            pooled = F.normalize(pooled, p=2, dim=1)
            # Convert the whole batch to nested Python lists in one call.
            vectors.extend(pooled.tolist())
    return vectors


dsdict = load_from_disk(args.dataset)
for dsk in dsdict:
    # Only splits whose names start with 'S' or 'T' are embedded; any other
    # split is saved back unchanged.
    if not dsk.startswith(('S', 'T')):
        continue
    ds = dsdict[dsk]
    print(ds)

    # 'text' is required; the paraphrase columns are embedded only if present.
    # Each embedded column is appended as "<column>_embeds", in this order.
    for col in ('text', 'Tpara1', 'Spara1'):
        if col != 'text' and col not in ds.column_names:
            continue
        embeds = _embed_texts(list(ds[col]))
        if embeds:  # an empty split gets no embedding columns
            ds = ds.add_column(f"{col}_embeds", embeds)
    dsdict[dsk] = ds

dsdict.save_to_disk(f"{args.dataset}_sentencebert")
