import numpy as np
import json
import torch, os
from transformers import BertTokenizer, BertModel, AutoTokenizer
from tqdm import tqdm
import argparse
from utils.config import MODEL_DIR, DATA_DIR

def cosine_similarity1(v1, v2):
    """Return the cosine similarity of two 1-D vectors.

    The result is the dot product of ``v1`` and ``v2`` divided by the product
    of their Euclidean norms. No guard is applied for zero-length vectors:
    if either norm is zero the division yields NaN (NumPy emits a warning).
    """
    inner_product = np.dot(v1, v2)
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    return inner_product / norm_product
class SentenceEmbeddings:
    """Embed sentences with a BERT model and dump the vectors to a text file.

    The tokenizer and model are loaded from ``MODEL_DIR/<model_path>`` and the
    model is moved to the GPU when one is available, otherwise the CPU.
    """

    def __init__(self, model_path, vocab_size=1000):
        """Load the tokenizer and model.

        Args:
            model_path: Directory name (relative to ``MODEL_DIR``) holding the
                pretrained BERT tokenizer and weights.
            vocab_size: Maximum number of sentences embedded by
                ``generate_embeddings``. Defaults to 1000.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = os.path.join(MODEL_DIR, model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.model = BertModel.from_pretrained(model_dir).to(self.device)
        # Despite the name, this caps the number of sentences processed.
        self.vocab_size = vocab_size

    def get_embedding(self, sentence):
        """Generate embedding for a sentence.

        Args:
            sentence: The text to embed.

        Returns:
            A NumPy array holding the model's first output (the last hidden
            state) at sequence position 0, i.e. the ``[CLS]`` token, with a
            leading batch axis of size 1.
        """
        # Truncate to the model's positional limit; without this, an over-long
        # sentence fails inside the model with an index error on the position
        # embeddings.
        input_ids = self.tokenizer.encode(
            sentence,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings,
        ).to(self.device)
        with torch.no_grad():
            output = self.model(input_ids)
        return output[0][:, 0, :].cpu().numpy()

    def generate_embeddings(self, input_filename, output_filename):
        """Generate embeddings for all sentences in the input file.

        The input is read as JSON lines from ``DATA_DIR/<input_filename>``;
        each non-blank line must be a JSON object with a ``'text'`` key. At
        most ``self.vocab_size`` sentences are embedded. The stacked
        embeddings (one row per sentence) are written space-delimited to
        ``DATA_DIR/<output_filename>`` with ``np.savetxt``.

        Args:
            input_filename: Input file name, relative to ``DATA_DIR``.
            output_filename: Output file name, relative to ``DATA_DIR``.

        Raises:
            ValueError: If no sentence was embedded (empty input or a
                non-positive ``vocab_size``).
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record has no ``'text'`` field.
        """
        all_embeddings = []
        input_path = os.path.join(DATA_DIR, input_filename)

        # Stream the file instead of loading every line: only the first
        # ``vocab_size`` records are needed. The context managers ensure the
        # file and the progress bar are closed even if a record is malformed.
        with open(input_path, 'r', encoding='utf-8') as f:
            with tqdm(total=self.vocab_size, desc="Embeddings generated") as pbar:
                for line in f:
                    if len(all_embeddings) >= self.vocab_size:
                        break
                    if not line.strip():
                        # Tolerate blank lines rather than failing in json.loads.
                        continue
                    data = json.loads(line)
                    all_embeddings.append(self.get_embedding(data['text']))
                    pbar.update(1)

        if not all_embeddings:
            # np.vstack would raise a less helpful ValueError on an empty list.
            raise ValueError(f"No sentences to embed in {input_path!r}")

        all_embeddings = np.vstack(all_embeddings)
        np.savetxt(os.path.join(DATA_DIR, output_filename), all_embeddings, delimiter=" ")


def main():
    """Command-line entry point: parse arguments and write sentence embeddings.

    Three required string options are accepted: ``--input_filename``,
    ``--output_filename`` and ``--model_path``. The model path is handed to
    ``SentenceEmbeddings`` and the two file names to its
    ``generate_embeddings`` method.
    """
    required_options = (
        ('--input_filename', 'Input file path'),
        ('--output_filename', 'Output file path'),
        ('--model_path', 'Path to the BERT model'),
    )

    cli = argparse.ArgumentParser(description='Generate embeddings for sentences.')
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    options = cli.parse_args()

    embedder = SentenceEmbeddings(options.model_path)
    embedder.generate_embeddings(options.input_filename, options.output_filename)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
