import argparse
import json
import os
from scripts.baseline_kmeans.discretizers.factory import (
    get_clustering_discretizer,
    clustering_model_choices,
)
from datasets import load_dataset
from more_itertools import batched
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from scripts.baseline_kmeans.utils import discretize_text_batch


def write_standard_discretized(
    doc_id: int, quantizer_indices: list[int], out_file
) -> None:
    """
    Append one JSON line holding the document's codes as a space-separated string.

    Args:
    - doc_id (int): Value stored under the "id" key, written as given.
    - quantizer_indices (list[int]): Discrete codes of the document, kept in order
      (duplicates included).
    - out_file: Writable text file object that receives the line.
    """
    record = {"id": doc_id, "contents": " ".join(map(str, quantizer_indices))}
    line = json.dumps(record) + "\n"
    out_file.write(line)


def write_token_and_code(
    doc_id: int, quantizer_indices: list[int], tokenizer_tokens: list[str], out_file
) -> None:
    """
    Write the combined code-and-token representation of one document as a JSON line.

    The "contents" field holds the discrete codes followed by the tokenizer
    tokens, all separated by single spaces.

    Args:
    - doc_id (int): The document ID, stored under the "id" key as given.
    - quantizer_indices (list[int]): The quantizer indices (discretized tokens).
    - tokenizer_tokens (list[str]): The tokenizer tokens appended after the codes.
    - out_file: The writable text file object to append the line to.
    """
    # Join one flat sequence so that an empty code list or an empty token list
    # does not leave a stray leading/trailing space in "contents".
    parts = [str(i) for i in quantizer_indices]
    parts.extend(tokenizer_tokens)
    out_file.write(json.dumps({"id": doc_id, "contents": " ".join(parts)}) + "\n")


def write_unique_code(doc_id: int, quantizer_indices: list[int], out_file) -> None:
    """
    Write the distinct codes of one document as a JSON line.

    Each code appears once, in the order of its first occurrence in
    ``quantizer_indices``, so the output is reproducible and independent of
    set/hash-table iteration order.

    Args:
    - doc_id (int): The document ID, stored under the "id" key as given.
    - quantizer_indices (list[int]): The quantizer indices (discretized tokens).
    - out_file: The writable text file object to append the line to.
    """
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    unique_codes = " ".join(str(i) for i in dict.fromkeys(quantizer_indices))
    out_file.write(json.dumps({"id": doc_id, "contents": unique_codes}) + "\n")


def main(args):
    """
    Discretize every passage of the dataset and write three JSONL corpora.

    For each document one line is written to each of:
    - ``<index_path>/data/documents_codebook.jsonl``: the discretized codes.
    - ``<index_path>/data_documents_token_and_code/documents_codebook.jsonl``:
      the codes followed by the tokenizer tokens.
    - ``<index_path>/data_documents_unique_code/documents_codebook.jsonl``:
      the distinct codes only.

    Args:
    - args (argparse.Namespace): Parsed arguments; reads ``model_name``,
      ``clustering_model``, ``clustering_model_path``, ``index_path``,
      ``batch_size``, ``max_length`` and ``device``.

    Raises:
    - ValueError: If ``args.index_path`` is None.
    """
    # Fail before the expensive model/dataset loading below; otherwise the
    # missing path would only surface later as a TypeError from os.path.abspath.
    if args.index_path is None:
        raise ValueError("index_path must be provided.")

    # Load the pretrained tokenizer and model, and put the model in eval mode
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModel.from_pretrained(args.model_name)
    model.to(args.device)
    model.eval()

    # Load the clustering model
    clustering_model = get_clustering_discretizer(model=args.clustering_model).load(
        args.clustering_model_path
    )

    # NOTE(review): `load_dataset` without `split=` normally returns a DatasetDict;
    # the loop below assumes iterating `dataset` yields rows with "doc_id" and
    # "text" keys — confirm for this dataset.
    dataset = load_dataset("irds/msmarco-passage", "docs")

    # Set up output directories and file paths
    index_root = os.path.abspath(args.index_path)
    f_code = os.path.join(index_root, "data", "documents_codebook.jsonl")
    f_code_token = os.path.join(
        index_root, "data_documents_token_and_code", "documents_codebook.jsonl"
    )
    f_code_unique = os.path.join(
        index_root, "data_documents_unique_code", "documents_codebook.jsonl"
    )
    for path in (f_code, f_code_token, f_code_unique):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # Ceiling division: number of batches, used only for the progress bar total
    num_batches = (len(dataset) + args.batch_size - 1) // args.batch_size

    with (
        open(f_code, "w", encoding="utf-8") as out_code,
        open(f_code_token, "w", encoding="utf-8") as out_code_token,
        open(f_code_unique, "w", encoding="utf-8") as out_code_unique,
    ):
        for batch_samples in tqdm(
            batched(dataset, args.batch_size),
            total=num_batches,
            desc="Processing batches",
        ):
            doc_ids = [doc["doc_id"] for doc in batch_samples]
            documents = [doc["text"] for doc in batch_samples]

            # Discretize the current batch
            discretized_batch, tokenized_batch = discretize_text_batch(
                tokenizer=tokenizer,
                model=model,
                clustering_model=clustering_model,
                sentences_batch=documents,
                device=args.device,
                max_length=args.max_length,
            )

            # strict=True: a length mismatch between IDs and discretizer outputs
            # would otherwise silently drop documents from all three corpora.
            for doc_id, quantizer_indices, tokens in zip(
                doc_ids, discretized_batch, tokenized_batch, strict=True
            ):
                write_standard_discretized(doc_id, quantizer_indices, out_code)
                write_token_and_code(doc_id, quantizer_indices, tokens, out_code_token)
                write_unique_code(doc_id, quantizer_indices, out_code_unique)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Discretize text and create a BM25 index using a pretrained model and clustering in batches."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/contriever",
        help="Pretrained HF model name or path.",
    )
    parser.add_argument(
        "--clustering_model",
        default="faiss",
        choices=clustering_model_choices,
        type=str,
        help="Pretrained clustering model name.",
    )
    parser.add_argument(
        "--clustering_model_path",
        type=str,
        # NOTE(review): no default and not required, so this may reach the
        # discretizer's `load` as None — confirm that is supported.
        help="Path to the pretrained clustering model.",
    )
    parser.add_argument(
        "--index_path",
        type=str,
        # Required: main() derives every output path from it and there is no
        # default, so omitting it could only fail after models are loaded.
        required=True,
        help="Path to save the BM25 index files.",
    )
    parser.add_argument(
        "--max_length", type=int, default=200, help="Maximum length of input sequences"
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Number of sentences per batch."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation ('cuda' or 'cpu').",
    )

    args = parser.parse_args()
    main(args)
