from datasets import load_from_disk
import numpy as np
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
import argparse
import json


BATCH_SIZE = 4096

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Trainer Script')
    parser.add_argument('--config', type=str, default='/path/to/config', help='Config Path')
    args = parser.parse_args()
    config = json.load(open(args.config))
    ds = load_from_disk(config["dataset_path"])
    inference_data = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    SAVE_PATH = os.path.join(config["temp_dir"], "temp_embeddings.npy")
    # Create directory if it doesn't exist
    os.makedirs(config["temp_dir"], exist_ok=True)
    os.makedirs(config["working_dir"], exist_ok=True)
    
    # Initialize variables for incremental saving
    total_samples = len(ds)
    embedding_dim = None
    processed_samples = 0
    chunk_size = 10000  # Adjust based on your memory constraints
    
    # Check if we can resume from a previous run
    if os.path.exists(SAVE_PATH):
        print(f"Found existing embeddings at {SAVE_PATH}, resuming from there...")
        existing_embeddings = np.load(SAVE_PATH, mmap_mode='r')
        embedding_dim = existing_embeddings.shape[1]
        processed_samples = existing_embeddings.shape[0]
        print(f"Already processed {processed_samples}/{total_samples} samples")
    
    # Create or extend the embeddings file as needed
    if processed_samples == 0:
        # First run, we'll create the file after processing the first batch
        first_run = True
    else:
        # Resuming, copy existing embeddings to final destination
        embeddings_array = np.memmap(
            os.path.join(config["working_dir"], "full_dataset_embeddings.npy"),
            dtype='float32',
            mode='w+',
            shape=(total_samples, embedding_dim)
        )
        embeddings_array[:processed_samples] = existing_embeddings[:processed_samples]
        first_run = False
    
    # Process remaining batches
    batch_start_idx = processed_samples
    for batch_idx, batch in enumerate(tqdm(inference_data, desc="Generating embeddings")):
        # Skip already processed batches
        if batch_idx * BATCH_SIZE < processed_samples:
            continue
        
        text = batch[config["dataset_data_column_name"]]
        batch_embeddings = model.encode(text, batch_size=64)
        
        if first_run:
            embedding_dim = batch_embeddings.shape[1]
            # Initialize the memory-mapped array file
            embeddings_array = np.memmap(
                os.path.join(config["working_dir"], "full_dataset_embeddings.npy"),
                dtype='float32',
                mode='w+',
                shape=(total_samples, embedding_dim)
            )
            first_run = False
        
        # Calculate the actual indices for this batch
        end_idx = min(batch_start_idx + len(batch_embeddings), total_samples)
        curr_batch_size = end_idx - batch_start_idx
        
        # Store in the memory-mapped array
        embeddings_array[batch_start_idx:end_idx] = batch_embeddings[:curr_batch_size]
        
        # Update position for next batch
        batch_start_idx = end_idx
        
        # Periodically flush to disk
        if batch_idx % 10 == 0:
            embeddings_array.flush()
            
        # Save checkpoint every chunk_size samples
        if batch_start_idx % chunk_size == 0 or batch_start_idx == total_samples:
            # Save checkpoint by creating a copy of current progress
            np.save(SAVE_PATH, embeddings_array[:batch_start_idx], allow_pickle=False)
            print(f"Checkpoint saved: {batch_start_idx}/{total_samples} embeddings")
    
    # Ensure all data is written to disk
    embeddings_array.flush()
    print(f"Generated {batch_start_idx} embeddings with dimension {embedding_dim}")
    # Final save is redundant since we're using memmap, but useful for confirmation
    print(f"Embeddings saved to {os.path.join(config['working_dir'], 'full_dataset_embeddings.npy')}")
    
