import argparse

import numpy as np
import torch
from datasets import load_dataset
from more_itertools import batched
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from scripts.baseline_kmeans.utils import accumulate_token_embeddings_masked_select


def main(args: argparse.Namespace) -> None:
    """
    Encode MS MARCO passages and store their per-token embeddings in a memmap.

    Passages are tokenized and encoded in batches with ``args.model_name``.
    The last hidden states of the tokens selected by
    ``accumulate_token_embeddings_masked_select`` (using the attention mask)
    are written row by row into a memory-mapped array of shape
    ``(args.dstore_size, model.config.hidden_size)``. Writing stops when the
    memmap is full or the dataset is exhausted.

    Args:
    - args (argparse.Namespace): Parsed command-line arguments. Reads
      ``model_name``, ``memmap_file``, ``dstore_size``, ``batch_size``,
      ``max_length``, ``use_fp16`` and ``device``.

    Raises:
    - ValueError: if the loaded dataset has several splits and it is
      ambiguous which one to embed.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModel.from_pretrained(args.model_name)

    # Passage corpus to embed.
    dataset = load_dataset("irds/msmarco-passage", "docs")
    # Without ``split=...``, ``load_dataset`` returns a DatasetDict (a dict of
    # splits). Iterating it would yield split *names*, not records, so unwrap
    # the single split here.
    if isinstance(dataset, dict):
        if len(dataset) != 1:
            raise ValueError(
                f"Expected a single split, got {sorted(dataset)}; "
                "pass split=... to load_dataset."
            )
        dataset = next(iter(dataset.values()))

    # Storage precision is chosen on the command line. A similar datastore
    # layout is used in
    # https://github.com/jxhe/efficient-knnlm/blob/e614f1e2e7da9c83994e99c82fd10f054fa08b4b/build_dstore.py#L76-L81
    dtype = np.float16 if args.use_fp16 else np.float32
    memmap = np.memmap(
        args.memmap_file,
        dtype=dtype,
        mode="w+",
        shape=(args.dstore_size, model.config.hidden_size),
    )

    current_idx = 0  # index of the next free row in the memmap

    model.eval()
    model.to(args.device)

    try:
        # eval() only changes layer behaviour (e.g. dropout). no_grad() is what
        # stops autograd from recording the forward pass, which saves memory
        # and time during pure inference.
        with torch.no_grad(), tqdm(
            desc="Collecting embeddings", total=args.dstore_size
        ) as pbar:
            for batch_samples in batched(dataset, args.batch_size):
                documents = [doc["text"] for doc in batch_samples]

                inputs = tokenizer(
                    documents,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=args.max_length,
                ).to(args.device)
                token_embeddings = model(**inputs).last_hidden_state

                # Keep only the token embeddings selected by the attention mask.
                selected_embeddings = accumulate_token_embeddings_masked_select(
                    token_embeddings, inputs["attention_mask"]
                )

                # Write at most as many rows as still fit in the memmap.
                num_embeddings = min(
                    selected_embeddings.size(0), args.dstore_size - current_idx
                )
                features = selected_embeddings[:num_embeddings].detach().cpu().numpy()

                # Assignment casts to the memmap's dtype (float16 or float32).
                memmap[current_idx : current_idx + num_embeddings] = features

                current_idx += num_embeddings
                pbar.update(num_embeddings)

                if current_idx >= args.dstore_size:
                    print(f"Memmap is full. Stored {current_idx} token embeddings.")
                    break
            else:
                # Loop ended without ``break``: the corpus ran out first, so
                # rows from current_idx onward were never written.
                print(
                    f"Dataset exhausted. Stored {current_idx} of "
                    f"{args.dstore_size} token embeddings; the remaining rows "
                    "were not written."
                )
    finally:
        # Persist whatever was written, even if encoding failed mid-way, and
        # drop the reference so the underlying mmap can be released.
        memmap.flush()
        del memmap


if __name__ == "__main__":
    # Command-line entry point: parse and validate arguments, then run main().
    parser = argparse.ArgumentParser(
        description="Populate a memmap file with token embeddings from sentences."
    )

    # Model and tokenizer arguments
    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/contriever",
        help="Pretrained model name or path",
    )

    # Memmap arguments
    parser.add_argument(
        "--memmap_file",
        type=str,
        # np.memmap cannot create a file without a path, so require it up front
        # instead of failing later inside main().
        required=True,
        help="File path to store the memmap",
    )
    parser.add_argument(
        "--dstore_size",
        type=int,
        default=10000,
        help="Total number of embeddings to store",
    )
    # Processing arguments
    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size for processing sentences"
    )
    parser.add_argument(
        "--max_length", type=int, default=200, help="Maximum length of input sequences"
    )
    parser.add_argument(
        "--use_fp16", action="store_true", help="Use float16 for embedding storage"
    )
    parser.add_argument(
        "--device",
        type=str,
        # Fall back to CPU so the script still runs on machines without CUDA.
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cuda", "cpu"],
        help="Device to run the model on (defaults to cuda when available)",
    )

    args = parser.parse_args()
    # Reject sizes that would otherwise fail deep inside np.memmap / batched().
    if args.dstore_size <= 0:
        parser.error("--dstore_size must be a positive integer")
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    main(args)
