import argparse
import os
import shelve
import torch
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm

def parse_args():
    """Parse the command-line options for the embedding script.

    Recognised options, in the order they appear in ``--help``:
    ``--db_dir`` (required), ``--model_name``, ``--db_name``,
    ``--batch_size`` and ``--limit``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, add_argument keyword arguments) pairs, registered in order.
    option_table = (
        ('--db_dir', {
            'type': str,
            'required': True,
            'help': 'Directory containing the shelve database',
        }),
        ('--model_name', {
            'type': str,
            'default': 'all-mpnet-base-v2',
            'help': 'Sentence embedding model name from sentence-transformers',
        }),
        ('--db_name', {
            'type': str,
            'default': 'results_db',
            'help': 'Name of the shelve database file (without extension)',
        }),
        ('--batch_size', {
            'type': int,
            'default': 32,
            'help': 'Batch size for encoding',
        }),
        ('--limit', {
            'type': int,
            'default': None,
            'help': 'Number of keys to process',
        }),
    )

    cli = argparse.ArgumentParser(
        description='Compute sentence embeddings from shelve database'
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

def compute_embeddings(model, texts, batch_size=32):
    """
    Encode ``texts`` with ``model.encode`` and return whatever it returns.

    Entries equal to the empty string are swapped for a single space before
    encoding; all other entries are forwarded unchanged. The encoder is
    called with the progress bar disabled and ``convert_to_tensor=True``.
    """
    # Build the sanitized input list one entry at a time.
    prepared = []
    for entry in texts:
        prepared.append(" " if entry == "" else entry)

    # Keyword options forwarded to the encoder.
    encode_options = {
        "batch_size": batch_size,
        "show_progress_bar": False,
        "convert_to_tensor": True,
    }
    return model.encode(prepared, **encode_options)

def main():
    """Embed every entry of a shelve database and save the result to disk.

    Steps:
      1. Parse the command-line options with ``parse_args``.
      2. Load the SentenceTransformer model named by ``--model_name`` and
         move it to CUDA when available, otherwise CPU.
      3. Open ``<db_dir>/<db_name>`` read-only and, for each of the first
         ``--limit`` keys (all keys when no limit is given), encode the
         value stored under that key with ``compute_embeddings``.
      4. Save the ``{key: embeddings}`` dictionary to
         ``<db_dir>/embeddings.pt`` with ``torch.save``.

    Raises:
        ValueError: If ``--db_dir`` is not an existing directory.
        dbm.error: Raised by ``shelve.open`` when the database cannot be
            opened for reading (for example, it does not exist).
    """
    args = parse_args()

    # Check if directory exists
    if not os.path.isdir(args.db_dir):
        raise ValueError(f"Directory {args.db_dir} does not exist")

    # Full path to the shelve database
    db_path = os.path.join(args.db_dir, args.db_name)

    # Load the SentenceTransformer model
    print(f"Loading model {args.model_name}...")
    model = SentenceTransformer(args.model_name)

    # Use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(f"Using device: {device}")

    # Open the shelve database read-only. The default flag ('c') would
    # silently create an empty database when the name is wrong, which would
    # produce an empty embeddings file instead of an error.
    print(f"Opening shelve database at {db_path}...")
    with shelve.open(db_path, flag='r') as db:
        # Create a dictionary to store the embeddings
        embeddings_dict = {}

        # Get the keys to process; a limit of None keeps all of them
        keys = list(db.keys())[:args.limit]
        print(f"Found {len(keys)} keys in the database")

        # Process each key
        for key in tqdm(keys, desc="Computing embeddings"):
            responses = db[key]

            # Compute embeddings for all responses for this key
            embeddings = compute_embeddings(model, responses, batch_size=args.batch_size)

            # Keep results on the CPU: accumulated tensors then do not hold
            # accelerator memory for the whole run, and the saved file loads
            # on machines without CUDA.
            embeddings_dict[key] = embeddings.cpu()

    # Save the embeddings to a PyTorch file
    embeddings_path = os.path.join(args.db_dir, 'embeddings.pt')
    print(f"Saving embeddings to {embeddings_path}...")
    torch.save(embeddings_dict, embeddings_path)
    print("Done!")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
