import argparse
import json
import logging
import os
import sys
import typing

sys.path.append("../discrete_autoencoder/")
import encoding_utils
from more_itertools import batched
import ir_datasets
import torch
from tqdm import tqdm

from modeling_autoencoder_v2 import Autoencoder

# Script-wide logger, pinned to INFO so progress messages are always emitted.
logger = logging.getLogger("create_bm25_index.py")
logger.setLevel(logging.INFO)
# Configure the root logger at INFO level as well.
logging.basicConfig(level=logging.INFO)


def parse_args():
    """Parse the command-line options of this script.

    Every option is optional. ``--dataset`` defaults to ``"msmarco-passage"``,
    ``--document_formatting`` to ``"{document}"`` and ``--assignment_size``
    to ``-1``; all remaining options default to ``None``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, keyword arguments for ``add_argument``), in declaration order.
    option_table = (
        ("--autoencoder_path", {}),
        ("--dataset", {"default": "msmarco-passage"}),
        ("--index_path", {}),
        ("--document_formatting", {"default": "{document}"}),
        ("--assignment_size", {"default": -1, "type": int}),
        ("--max_len_perc", {"default": None, "type": float}),
        ("--min_len_perc", {"default": None, "type": float}),
        ("--min_len", {"default": None, "type": int}),
        ("--length", {"default": None, "type": int}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def get_quantized_indices(autoencoder, batch_autoencoder, length):
    """Run the autoencoder on a batch and return its quantizer indices as lists.

    Args:
        autoencoder: Callable model; its output must expose
            ``quantizer_outputs``, whose element at index 1 is converted with
            ``.long().tolist()``.
        batch_autoencoder: The batch passed unchanged to ``autoencoder``.
        length: When not ``None``, each per-item index list is cut down to
            its first ``length`` entries.

    Returns:
        One list of integer indices per batch item.
    """
    model_output = autoencoder(batch_autoencoder)
    codes_per_item = model_output.quantizer_outputs[1].long().tolist()

    # No limit requested: hand back the full index lists.
    if length is None:
        return codes_per_item
    return [codes[:length] for codes in codes_per_item]


@torch.inference_mode()
def encode(
    autoencoder: Autoencoder,
    unprocessed_docs: typing.Iterable[typing.Any],
    batch_size: int,
    document_formatting: str = "{document}",
    num_docs: None | int = None,
    length: None | int = 256,
    max_length: int = 256,
) -> typing.Generator[tuple[str, str, str, str], None, None]:
    """Encode documents into space-separated quantizer-code strings.

    Documents are processed in batches of ``batch_size``. Each document text
    is inserted into ``document_formatting``, tokenized with
    ``autoencoder.tokenize`` and passed through ``get_quantized_indices``.
    Every quantizer index is shifted by ``autoencoder.tokenizer.vocab_size``
    before being rendered as text.

    Args:
        autoencoder: Model providing ``tokenize``, ``tokenizer.vocab_size``
            and the forward pass used by ``get_quantized_indices``.
        unprocessed_docs: Iterable of document objects; each must expose the
            attributes ``doc_id`` and ``text``.
        batch_size: Number of documents encoded per forward pass.
        document_formatting: ``str.format`` template with a ``{document}``
            placeholder that receives the document text.
        num_docs: Total shown by the progress bar. When ``None`` it is taken
            from ``len(unprocessed_docs)`` if the iterable has a length.
        length: Maximum number of codes kept per document; ``None`` keeps
            every code.
        max_length: ``max_length`` value forwarded to ``autoencoder.tokenize``.

    Yields:
        ``(doc_id, codes, position_codes, formatted_document)`` where
        ``codes`` is the shifted codes joined by spaces, ``position_codes``
        is the same sequence with each code prefixed by ``"<position>_"``,
        and ``formatted_document`` is the text after ``document_formatting``
        was applied. ``doc_id`` is passed through unchanged (presumably a
        string for ir_datasets documents).
    """
    if num_docs is None and hasattr(unprocessed_docs, "__len__"):
        num_docs = len(unprocessed_docs)

    # Loop-invariant offset added to every quantizer index.
    code_offset = autoencoder.tokenizer.vocab_size

    with tqdm(
        desc="Encoding documents",
        total=num_docs,
    ) as pbar:
        for batch in batched(unprocessed_docs, batch_size):
            ids = [doc.doc_id for doc in batch]
            documents = [
                document_formatting.format(document=doc.text) for doc in batch
            ]
            batch_autoencoder = autoencoder.tokenize(documents, max_length=max_length)
            quantizer_indices = get_quantized_indices(
                autoencoder, batch_autoencoder, length
            )

            for id_, document, codes in zip(ids, documents, quantizer_indices):
                shifted_codes = [str(code + code_offset) for code in codes]
                autoencoded = " ".join(shifted_codes)
                # Prefix each code with its position to make tokens unique.
                autoencoded_unique = " ".join(
                    f"{position}_{code}"
                    for position, code in enumerate(shifted_codes)
                )
                pbar.update(1)
                yield (id_, autoencoded, autoencoded_unique, document)


def main():
    """Encode a dataset with the autoencoder and write three JSONL files.

    Loads the ir_datasets corpus named by ``--dataset`` and the autoencoder
    checkpoint at ``--autoencoder_path``, then streams every document through
    ``encode``. Three files, each named ``documents_codebook.jsonl``, are
    written below ``--index_path``:

    * ``data/``: the space-joined codes.
    * ``data_documents_token_and_code/``: the position-prefixed codes
      followed by the formatted document text.
    * ``data_documents_unique_code/``: the position-prefixed codes only.

    Each line is a JSON object with the keys ``id`` and ``contents``.

    Raises:
        ValueError: If ``--index_path`` was not supplied.
    """
    args = parse_args()

    # Fail before the expensive dataset/model loading instead of hitting a
    # TypeError from os.path.abspath(None) afterwards.
    if args.index_path is None:
        raise ValueError("--index_path must be provided")

    logger.info(f"Loading dataset {args.dataset}")
    dataset = ir_datasets.load(args.dataset)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device {device}")

    logger.info(f"Loading the autoencoder from {args.autoencoder_path}")
    autoencoder = Autoencoder.load_checkpoint(
        path=args.autoencoder_path,
    ).to(device)
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.model.bfloat16()
    autoencoder.quantizer.float()
    autoencoder.add_special_tokens()
    autoencoder.eval()

    index_root = os.path.abspath(args.index_path)
    output_folder = os.path.join(index_root, "data")
    output_folder_documents_token_and_code = os.path.join(
        index_root, "data_documents_token_and_code"
    )
    output_folder_documents_unique_code = os.path.join(
        index_root, "data_documents_unique_code"
    )

    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(output_folder_documents_token_and_code, exist_ok=True)
    os.makedirs(output_folder_documents_unique_code, exist_ok=True)

    logger.info(f"Saving the index to {output_folder}")
    with open(
        os.path.join(output_folder, "documents_codebook.jsonl"), "w"
    ) as out, open(
        os.path.join(
            output_folder_documents_token_and_code, "documents_codebook.jsonl"
        ),
        "w",
    ) as out_code_and_token, open(
        os.path.join(output_folder_documents_unique_code, "documents_codebook.jsonl"),
        "w",
    ) as out_unique_code:
        encode_iter = encode(
            autoencoder=autoencoder,
            unprocessed_docs=dataset.docs_iter(),
            batch_size=200,
            document_formatting=args.document_formatting,
            length=args.length,
        )
        for doc_index, (
            doc_id,
            encoded_doc,
            encoded_unique_doc,
            document_text,
        ) in enumerate(encode_iter):
            # Sample by position rather than int(doc_id): document ids are
            # not guaranteed to be numeric for every dataset.
            if doc_index % 100000 == 0:
                logger.info(f"Processing document {doc_id}")
                logger.info(f"Document text: {document_text}")
                logger.info(f"Encoded document: {encoded_doc}")
                logger.info(f"Encoded unique document: {encoded_unique_doc}")

            out.write(
                json.dumps({"id": doc_id, "contents": encoded_doc}) + "\n"
            )
            out_code_and_token.write(
                json.dumps(
                    {
                        "id": doc_id,
                        "contents": encoded_unique_doc + " " + document_text,
                    }
                )
                + "\n"
            )
            out_unique_code.write(
                json.dumps({"id": doc_id, "contents": encoded_unique_doc}) + "\n"
            )
    logger.info("Processing finished!")


if __name__ == "__main__":
    main()
