"""
We use the pre-trained language model (e.g. BERT) to encode the text data, and save the embeddings to the disk.
This step can faster the training process of the model.

"""

import os
import torch
import argparse
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer


class BertTextEmbHandler():
    """Encode text with a Hugging Face BERT model and return the [CLS] hidden state.

    Args:
        device: torch device (or device string) the model is moved to.
        model_id: model identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
    """

    def __init__(self, device, model_id='bert-base-uncased'):
        self.model_id = model_id
        self.tokenizer = BertTokenizer.from_pretrained(model_id)
        self.model = BertModel.from_pretrained(model_id).to(device)
        # This handler is inference-only; make sure dropout is disabled.
        self.model.eval()

    def get_sentence_embedding(self, sentences, max_length=256):
        """Return the final-layer hidden state at position 0 for each input.

        Args:
            sentences: a single string or a list of strings; a batch is padded
                to its longest member.
            max_length: inputs longer than this many tokens are truncated.

        Returns:
            A 2-D tensor (one row per input) on the model's device that is not
            attached to any autograd graph (``requires_grad`` is False).
        """
        inputs = self.tokenizer(
            sentences,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=max_length,
        ).to(self.model.device)
        # No autograd graph is needed: building one just to .detach() it
        # afterwards wastes memory proportional to batch size and sequence length.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Position 0 is the [CLS] token that the BERT tokenizer prepends.
        # NOTE: mean/max pooling over non-padding tokens (using
        # inputs['attention_mask']) is a possible alternative sentence representation.
        return outputs.last_hidden_state[:, 0, :]

# parser = argparse.ArgumentParser()
# parser.add_argument('--model', type=str, default='bert-base-uncased', help='pre-trained language model')
# parser.add_argument('--source_dir', type=str, default='data', help='source directory')

# args = parser.parse_args()

class SentenceTransformerHandler():
    """Wrap a sentence-transformers model behind the ``get_sentence_embedding`` API.

    Args:
        device: torch device (or device string) the model is moved to.
        model_id: name or path understood by ``SentenceTransformer``.
    """

    def __init__(self, device, model_id='paraphrase-MiniLM-L12-v2'):
        self.model_id = model_id
        loaded = SentenceTransformer(model_id)
        self.model = loaded.to(device)

    def get_sentence_embedding(self, sentences):
        """Encode *sentences* and return the embeddings as a tensor on the model's device."""
        encoded = self.model.encode(sentences, convert_to_tensor=True)
        target_device = self.model.device
        return encoded.to(target_device)

def emb_all_txt_files(handler, source_dir, cache_dir):
    for file in os.listdir(source_dir):
        if file.endswith('.txt') and (not ("furni" in file)):
            with open(os.path.join(source_dir, file), 'r') as f:
                sentences = f.readlines()
            embeddings = handler.get_sentence_embedding(sentences)
            # print("embedding size: ", embeddings.size())
            torch.save(embeddings.cpu(), os.path.join(cache_dir, file.replace('.txt', '.pt')))

if __name__ == '__main__':
    # Smoke test: embed one sentence with BERT-large and report whether the
    # returned tensor still requires grad.
    sentences = ["Circle, color shuffled.", "Line, color shuffled.", "Circle, color ordered.", "Line, color ordered."]
    if torch.cuda.is_available():
        run_device = torch.device('cuda')
    else:
        run_device = torch.device('cpu')

    # Alternative backend (pairwise similarity of the demo sentences):
    #   st_handler = SentenceTransformerHandler(run_device)
    #   st_emb = st_handler.get_sentence_embedding(sentences)
    #   print(st_emb.shape)
    #   print(st_handler.model.similarity(st_emb, st_emb))

    bert_handler = BertTextEmbHandler(run_device, model_id="bert-large-uncased")
    emb = bert_handler.get_sentence_embedding("You are so good!")
    print(emb.requires_grad)

