import argparse
from collections import defaultdict
import json
import logging
import os
import sys
import typing

sys.path.append("../discrete_autoencoder/")
from more_itertools import batched
import ir_datasets
import torch
from tqdm import tqdm

from modeling_autoencoder_v2 import Autoencoder

# Script-wide logging: configure the root handler and this script's named
# logger at the same level so INFO messages are emitted.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger("create_bm25_index.py")
logger.setLevel(_LOG_LEVEL)


def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what a plain ``parse_args()`` call does.

    Returns:
        An ``argparse.Namespace`` with one attribute per option below.
    """
    parser = argparse.ArgumentParser(
        description="Encode documents with a discrete autoencoder and write JSONL corpora."
    )
    parser.add_argument(
        "--autoencoder_path",
        help="Checkpoint path passed to Autoencoder.load_checkpoint.",
    )
    parser.add_argument(
        "--dataset",
        default="msmarco-passage/dev/small",
        help="ir_datasets identifier of the dataset to encode.",
    )
    parser.add_argument(
        "--index_path",
        help="Output root; the JSONL files are written into sub-folders of it.",
    )
    parser.add_argument(
        "--document_formatting",
        default="{document}",
        help="str.format template applied to each document text via the "
        "'{document}' field.",
    )
    parser.add_argument(
        "--num_docs_per_query",
        default=50,
        # Without ``type`` a user-supplied value would stay a string while the
        # default is an int.
        type=int,
        help="Number of documents per query (not read by main() in this script).",
    )

    parser.add_argument(
        "--max_len_perc",
        default=None,
        type=float,
        help="Not read by main() in this script.",
    )
    parser.add_argument(
        "--min_len_perc",
        default=None,
        type=float,
        help="Not read by main() in this script.",
    )
    parser.add_argument(
        "--min_len",
        default=None,
        type=int,
        help="Not read by main() in this script.",
    )
    parser.add_argument(
        "--quantisation",
        default=None,
        type=int,
        help="If set, replace every entry of the autoencoder's `levels` with this value.",
    )
    parser.add_argument(
        "--length",
        default=None,
        type=int,
        help="If set, keep only the first LENGTH codes of each document.",
    )

    return parser.parse_args(argv)


def get_quantized_indices(autoencoder, batch_autoencoder, length):
    """Run ``autoencoder`` on a tokenized batch and return its code indices.

    Args:
        autoencoder: Callable model whose output exposes ``quantizer_outputs``;
            element ``[1]`` of it is converted to integer code indices.
        batch_autoencoder: Tokenized batch, forwarded unchanged to ``autoencoder``.
        length: When not ``None``, each row of indices is cut to its first
            ``length`` entries.

    Returns:
        A list with one list of ``int`` indices per row (assumes a 2-D index
        tensor — one row per document).
    """
    code_tensor = autoencoder(batch_autoencoder).quantizer_outputs[1]
    rows = code_tensor.long().tolist()
    if length is None:
        return rows
    return [row[:length] for row in rows]


@torch.inference_mode()
def encode(
    autoencoder: Autoencoder,
    unprocessed_docs: typing.Iterable[tuple[str, str]],  # (doc_id, doc_text)
    batch_size: int,
    document_formatting: str = "{document}",
    num_docs: None | int = None,
    length: None | int = 100,
    max_length: int = 256,
) -> typing.Generator[tuple[str, str, str, str], None, None]:
    """Encode documents into space-separated autoencoder code strings.

    Documents are processed in batches of ``batch_size``: each text is passed
    through ``document_formatting``, tokenized, and run through the
    autoencoder; every resulting code is shifted by the tokenizer's
    ``vocab_size`` (presumably so codes never collide with ordinary token ids).

    Args:
        autoencoder: Model providing ``tokenizer``, ``tokenize`` and the
            forward pass used by ``get_quantized_indices``.
        unprocessed_docs: Iterable of ``(doc_id, doc_text)`` pairs.
        batch_size: Number of documents encoded per forward pass.
        document_formatting: ``str.format`` template; the text is supplied as
            the ``document`` field.
        num_docs: Total for the progress bar; taken from
            ``len(unprocessed_docs)`` when omitted and the input is sized.
        length: If not ``None``, keep only the first ``length`` codes per doc.
        max_length: ``max_length`` forwarded to ``autoencoder.tokenize``.

    Yields:
        ``(doc_id, codes, position_codes, formatted_document)`` where
        ``codes`` is ``"c0 c1 ..."`` (shifted codes), ``position_codes`` is
        ``"0_c0 1_c1 ..."`` (the same codes prefixed by their position), and
        ``formatted_document`` is the text after ``document_formatting``.
    """
    # Read once: the tokenizer is not modified while encoding.
    vocab_size = autoencoder.tokenizer.vocab_size
    logger.info("Autoencoder tokenizer vocab size: %s", vocab_size)
    if num_docs is None and hasattr(unprocessed_docs, "__len__"):
        num_docs = len(unprocessed_docs)

    with tqdm(desc="Encoding documents", total=num_docs) as pbar:
        for batch in batched(unprocessed_docs, batch_size):
            ids, raw_texts = zip(*batch)
            documents = [
                document_formatting.format(document=text) for text in raw_texts
            ]
            batch_autoencoder = autoencoder.tokenize(documents, max_length=max_length)
            quantizer_indices = get_quantized_indices(
                autoencoder, batch_autoencoder, length
            )
            for id_, codes, document in zip(ids, quantizer_indices, documents):
                shifted = [str(code + vocab_size) for code in codes]
                encoded = " ".join(shifted)
                encoded_unique = " ".join(
                    f"{position}_{token}" for position, token in enumerate(shifted)
                )
                pbar.update(1)
                yield id_, encoded, encoded_unique, document


def _collect_query_documents(dataset, docs_per_query=100):
    """Gather ``docs_per_query`` ``(doc_id, doc_text)`` pairs per query.

    Each query gets up to ``docs_per_query - 1`` documents from
    ``scoreddocs_iter`` plus one judged-relevant document from ``qrels_iter``.
    Only queries that end up with exactly ``docs_per_query`` documents are
    returned; all others are discarded.
    """
    docs_store = dataset.docs_store()
    query_to_docs = defaultdict(list)
    # Keep the first ``docs_per_query - 1`` retrieved docs in iteration order.
    # NOTE(review): this treats scoreddocs iteration order as rank order —
    # confirm the dataset yields them sorted by score.
    for scoreddoc in dataset.scoreddocs_iter():
        docs = query_to_docs[scoreddoc.query_id]
        if len(docs) < docs_per_query - 1:
            docs.append((scoreddoc.doc_id, docs_store.get(scoreddoc.doc_id).text))
    # Fill the last slot with a relevant document; judgements with
    # relevance <= 0 are not positives.
    # NOTE(review): the positive may already be among the retrieved docs, in
    # which case it appears twice in that query's list.
    for qrel in dataset.qrels_iter():
        if getattr(qrel, "relevance", 1) <= 0:
            continue
        docs = query_to_docs[qrel.query_id]
        if len(docs) < docs_per_query:
            docs.append((qrel.doc_id, docs_store.get(qrel.doc_id).text))
    return {
        query_id: docs
        for query_id, docs in query_to_docs.items()
        if len(docs) == docs_per_query
    }


def _unique_documents(query_to_docs):
    """Flatten per-query doc lists into ``(doc_id, doc_text)`` pairs, first occurrence wins."""
    unique = {}
    for docs in query_to_docs.values():
        for doc_id, text in docs:
            unique.setdefault(doc_id, text)
    return list(unique.items())


def _write_record(handle, doc_id, contents):
    """Write one ``{"id": ..., "contents": ...}`` JSON object as a line of ``handle``."""
    handle.write(json.dumps({"id": doc_id, "contents": contents}))
    handle.write("\n")


def main():
    """Encode the dataset's candidate documents and write three JSONL corpora.

    Files written under ``<index_path>`` (one ``{"id", "contents"}`` object
    per line):
      * ``data/documents_codebook.jsonl`` — shifted codes only;
      * ``data_documents_token_and_code/documents_codebook.jsonl`` —
        position-prefixed codes followed by the formatted document text;
      * ``data_documents_unique_code/documents_codebook.jsonl`` —
        position-prefixed codes only.
    """
    args = parse_args()
    if args.index_path is None:
        raise SystemExit("--index_path is required")

    logger.info(f"Loading dataset {args.dataset}")
    dataset = ir_datasets.load(args.dataset)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device {device}")

    logger.info(f"Loading the autoencoder from {args.autoencoder_path}")
    autoencoder = Autoencoder.load_checkpoint(
        path=args.autoencoder_path,
    ).to(device)
    if args.quantisation is not None:
        # Use the same number of levels for every entry of ``levels``.
        autoencoder.levels = [args.quantisation] * len(autoencoder.levels)
    for param in autoencoder.parameters():
        param.requires_grad = False
    # Cast the model to bfloat16, then (re)cast the quantizer to float32.
    autoencoder.model.bfloat16()
    autoencoder.add_special_tokens()
    autoencoder.quantizer.float()
    autoencoder.eval()

    index_root = os.path.abspath(args.index_path)
    output_folder = os.path.join(index_root, "data")
    output_folder_documents_token_and_code = os.path.join(
        index_root, "data_documents_token_and_code"
    )
    output_folder_documents_unique_code = os.path.join(
        index_root, "data_documents_unique_code"
    )
    for folder in (
        output_folder,
        output_folder_documents_token_and_code,
        output_folder_documents_unique_code,
    ):
        os.makedirs(folder, exist_ok=True)

    logger.info(f"Saving the index to {output_folder}")
    # Only queries with a full set of 100 documents are kept, and only their
    # documents are encoded; a document shared by several queries is encoded
    # and written once so the index has no duplicate ids.
    query_to_docs = _collect_query_documents(dataset, docs_per_query=100)
    documents_text = _unique_documents(query_to_docs)
    if not query_to_docs:
        logger.warning("No query has a full set of documents; outputs will be empty.")
    logger.info(
        "Scoring on %d queries (%d unique documents)",
        len(query_to_docs),
        len(documents_text),
    )

    with open(
        os.path.join(output_folder, "documents_codebook.jsonl"), "w", encoding="utf-8"
    ) as out, open(
        os.path.join(
            output_folder_documents_token_and_code, "documents_codebook.jsonl"
        ),
        "w",
        encoding="utf-8",
    ) as out_code_and_token, open(
        os.path.join(output_folder_documents_unique_code, "documents_codebook.jsonl"),
        "w",
        encoding="utf-8",
    ) as out_unique_code:
        encode_iter = encode(
            autoencoder=autoencoder,
            unprocessed_docs=documents_text,
            batch_size=200,
            document_formatting=args.document_formatting,
            length=args.length,
        )
        for doc_id, encoded_doc, encoded_unique_doc, doc in encode_iter:
            _write_record(out, doc_id, encoded_doc)
            _write_record(out_code_and_token, doc_id, encoded_unique_doc + " " + doc)
            _write_record(out_unique_code, doc_id, encoded_unique_doc)
    logger.info("Processing finished!")


if __name__ == "__main__":
    # Run the encoding pipeline only when executed as a script, not on import.
    main()
