import argparse
from collections import defaultdict, Counter
import itertools
import json
import logging
import os
import sys

# Default MLflow experiment / tracking server for this evaluation script.
# ``setdefault`` only fills these in when they are absent, so a caller can
# point the run elsewhere via the environment instead of being overridden.
os.environ.setdefault("MLFLOW_EXPERIMENT_NAME", "seal_codebooks_eval")
os.environ.setdefault("MLFLOW_TRACKING_URI", "http://10.128.0.103:8080")

# Make ``modeling_autoencoder_v2`` / ``train_utils`` (imported below) importable.
# NOTE: this path is resolved against the current working directory, not this file.
sys.path.append("../discrete_autoencoder/")
from more_itertools import batched
import ir_datasets
import ir_measures
from ir_measures import nDCG, MRR, Success
from pyserini.search.lucene import LuceneSearcher
import torch
from tqdm import tqdm
import pickle
import mlflow
import jsonlines

from modeling_autoencoder_v2 import Autoencoder
from train_utils import extract_latent_embeddings

# Module-wide logger. ``getLogger(None)`` is the *root* logger, so the level set
# here also gates records that imported libraries propagate up to root.
logger = logging.getLogger(None)
logger.setLevel("ERROR")


def parse_args():
    """Define the evaluation CLI and parse ``sys.argv``.

    Options without an explicit default are ``None`` when omitted; ``--rerank``
    is a boolean flag (``False`` unless given).

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    search_type_choices = [
        "tokenize_only",
        "encoding",
        "encoding_and_tokenisation",
        "unique_encoding",
        "hungarian_encoding",
    ]
    # Declaration order is preserved so ``--help`` lists options as before.
    option_specs = {
        "--autoencoder_path": {},
        "--dataset": {},
        "--index_path": {},
        "--output_path": {"default": ""},
        "--collection_path": {},
        "--limit": {"default": -1, "type": int},
        "--query_formatting": {"default": "{query}"},
        "--search_type": {"choices": search_type_choices},
        "--rerank": {"action": "store_true"},
        "--max_len_perc": {"default": None, "type": float},
        "--min_len_perc": {"default": None, "type": float},
        "--min_len": {"default": None, "type": int},
        "--length": {"default": None, "type": int},
        "--k": {"default": None, "type": int},
        "--quantisation": {"default": None, "type": int},
    }

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs.items():
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def get_quantized_indices(
    autoencoder,
    batch_autoencoder,
    queries,
    length: int | None = None,
):
    # num_non_padded = batch_autoencoder["attention_mask"].sum(dim=1)
    output_q2q = autoencoder(
        batch_autoencoder,
    )
    quantizer_indices = output_q2q.quantizer_outputs[1].long().tolist()
    # for i, num_tokens in enumerate(num_non_padded):
    #     quantizer_indices[i] = quantizer_indices[i][:num_tokens]

    if length is not None:
        quantizer_indices = [
            quantizer_index[:length] for quantizer_index in quantizer_indices
        ]

    return quantizer_indices


def get_encoding_batch(
    autoencoder,
    batch_autoencoder,
    queries,
    unique=False,
    length: int | None = None,
):
    """Render each row's quantizer indices as a space-separated string of terms.

    Every code is shifted by ``autoencoder.tokenizer.vocab_size`` (presumably to
    keep code terms disjoint from ordinary token ids). With ``unique=True`` each
    term is additionally prefixed by its position, ``"<pos>_<shifted code>"``,
    so the same code at different positions becomes a different term.

    Returns:
        One string per row of the batch.
    """
    code_rows = get_quantized_indices(
        autoencoder,
        batch_autoencoder,
        queries,
        length=length,
    )

    def render(codes):
        offset = autoencoder.tokenizer.vocab_size
        shifted = [code + offset for code in codes]
        if unique:
            return " ".join(f"{pos}_{value}" for pos, value in enumerate(shifted))
        return " ".join(map(str, shifted))

    return [render(codes) for codes in code_rows]


@torch.inference_mode()
def encode(
    autoencoder: Autoencoder,
    unprocessed_queries: dict[str, str],  # query_id  # query
    batch_size: int,
    query_formatting: str = "{query}",
    search_type: str = "encoding",
    length: int | None = None,
    max_length: int = 200,
) -> dict[str, str]:  # query_id
    """Turn raw queries into the space-separated term strings used for BM25 search.

    Args:
        autoencoder: model providing ``tokenize`` and, for the encoding modes,
            the quantizer used by ``get_encoding_batch``.
        unprocessed_queries: ``{query_id: query_text}``.
        batch_size: number of queries tokenized/encoded together.
        query_formatting: ``str.format`` template applied with ``query=<text>``.
        search_type: ``"tokenize_only"`` (raw token ids), ``"encoding"``
            (shifted codes) or ``"unique_encoding"`` (position-prefixed codes).
        length: truncate code sequences to this many codes (encoding modes only).
        max_length: ``max_length`` forwarded to ``autoencoder.tokenize``.

    Returns:
        ``{query_id: encoded_query_string}``.

    Raises:
        NotImplementedError: for ``"hungarian_encoding"`` and
            ``"encoding_and_tokenisation"``, which are accepted by the CLI but not
            implemented here.
        ValueError: for any other unknown ``search_type`` (including ``None``).
            Previously these silently produced an empty result.
    """
    if search_type in ("hungarian_encoding", "encoding_and_tokenisation"):
        raise NotImplementedError(f"search_type {search_type!r} is not implemented")
    if search_type not in ("tokenize_only", "encoding", "unique_encoding"):
        raise ValueError(f"Unknown search_type {search_type!r}")

    encoded = {}

    with tqdm(desc="Encoding queries", total=len(unprocessed_queries)) as pbar:
        for batch in batched(unprocessed_queries.items(), batch_size):
            ids, queries = zip(*batch)
            queries = [query_formatting.format(query=query) for query in queries]
            batch_autoencoder = autoencoder.tokenize(queries, max_length=max_length)

            if search_type == "tokenize_only":
                # Assumes ``input_ids`` is a tensor (it is ``.tolist()``-ed) — TODO confirm.
                rendered = [
                    " ".join(str(token_id) for token_id in row)
                    for row in batch_autoencoder["input_ids"].tolist()
                ]
            else:
                rendered = get_encoding_batch(
                    autoencoder,
                    batch_autoencoder,
                    queries,
                    unique=search_type == "unique_encoding",
                    length=length,
                )

            for id_, text in zip(ids, rendered):
                encoded[id_] = text
            pbar.update(len(ids))

    return encoded


def retrieve(
    bm25_index: LuceneSearcher,
    processed_queries: dict[str, str],  # query_id  # query
    k: int,  # num of hits to return
    batch_size: int = 100,
    threads: int = 20,
    output_file: str | None = "predicted_qrels.json",
) -> dict[str, dict[str, float]]:  # query_id  # doc_id  # score
    """Run batched BM25 search for every processed query.

    Args:
        bm25_index: searcher whose ``batch_search`` is called once per batch.
        processed_queries: ``{query_id: query_text}``.
        k: number of hits requested per query.
        batch_size: number of queries sent per ``batch_search`` call.
        threads: ``threads`` argument forwarded to ``batch_search``.
        output_file: if not ``None``, the raw hits are also written here as
            JSON lines, one ``{query_id: {doc_id: score}}`` object per query.
            A relative path resolves against the current working directory.
            Pass ``None`` to skip writing it.

    Returns:
        ``{query_id: {doc_id: score}}``. Queries that returned no hits are
        omitted.
    """
    qids = list(processed_queries)
    hits_per_query: dict[str, dict[str, float]] = {}

    for start in range(0, len(qids), batch_size):
        batch_qids = qids[start : start + batch_size]
        logger.info(
            f"Retrieving queries {start}-{start + len(batch_qids)} of {len(qids)}"
        )
        batch_hits = bm25_index.batch_search(
            queries=[processed_queries[q_id] for q_id in batch_qids],
            qids=batch_qids,
            k=k,
            threads=threads,
        )
        for q_id, hits in batch_hits.items():
            hits_per_query[q_id] = {hit.docid: hit.score for hit in hits}

    # The results are already in memory; the file is only an on-disk copy
    # (previously it was written and then re-read to build the same dict).
    if output_file is not None:
        with open(output_file, "w", encoding="utf-8") as f_out:
            for q_id, doc_scores in hits_per_query.items():
                f_out.write(json.dumps({q_id: doc_scores}) + "\n")

    return {q_id: doc_scores for q_id, doc_scores in hits_per_query.items() if doc_scores}


def score(
    gold_qrels: dict[str, dict[str, int]],  # q_id, doc_id, relevant/not
    predicted_qrels: dict[str, dict[str, float]],
    append_string: str = "",
    measures=None,
) -> dict[str, float]:
    """Aggregate IR metrics for a run against gold judgements.

    Args:
        gold_qrels: ``{query_id: {doc_id: relevance}}``.
        predicted_qrels: ``{query_id: {doc_id: score}}`` run to evaluate.
        append_string: prefix prepended to every returned metric name
            (e.g. ``"no_rerank_"``).
        measures: ``ir_measures`` measures to compute. Defaults to nDCG@10,
            MRR@10 and Success@{3,10,20,100,1000}.

    Returns:
        ``{append_string + measure_name: value}``, where ``@`` in the measure
        name is replaced by ``-`` (MLflow-friendly keys).
    """
    if measures is None:
        measures = [
            nDCG @ 10,
            MRR @ 10,
            Success @ 3,
            Success @ 10,
            Success @ 20,
            Success @ 100,
            Success @ 1000,
        ]
    aggregated = ir_measures.calc_aggregate(measures, gold_qrels, predicted_qrels)

    results = {}
    for measure, value in aggregated.items():
        name = append_string + str(measure).replace("@", "-")
        logger.info(f"{name}: {value}")
        results[name] = value
    return results


def _load_collection_codes(
    collection_path: str,
    vocab_size: int,
    search_type: str,
    keep_ids: set[str],
) -> tuple[dict[str, list[int]], Counter, int]:
    """Read the encoded collection once; keep codes only for ``keep_ids``.

    Each JSON line is expected to have ``id`` and ``contents``. ``contents`` is
    parsed as space-separated terms: ``"<pos>_<code + vocab_size>"`` when
    ``search_type == "unique_encoding"``, otherwise ``"<code + vocab_size>"``.
    Terms are shifted back by ``vocab_size`` to raw code indices.

    Returns:
        ``(encodings, counts, count_docs)``:
        - ``encodings``: code lists for the documents in ``keep_ids``.
        - ``counts``: occurrences of each code over the *whole* collection.
        - ``count_docs``: number of documents read.
    """
    encodings: dict[str, list[int]] = {}
    counts: Counter = Counter()
    count_docs = 0
    with jsonlines.open(collection_path) as reader:
        for json_line in reader:
            terms = json_line["contents"].split(" ")
            if search_type == "unique_encoding":
                codes = [int(term.split("_")[1]) - vocab_size for term in terms]
            else:
                codes = [int(term) - vocab_size for term in terms]
            # Statistics span every document (they feed the optional IDF).
            # NOTE(review): these are token-occurrence counts, not document
            # frequencies — confirm that is what the IDF is meant to use.
            counts.update(codes)
            count_docs += 1
            if json_line["id"] in keep_ids:
                encodings[json_line["id"]] = codes
    return encodings, counts, count_docs


def get_sim_scores(
    autoencoder: Autoencoder,
    processed_queries: dict[str, str],
    predicted_qrels: dict[str, dict[str, float]],
    collection_path: str,
    reduce: str = "mean",
    min_score: float = -1.0,
    which_vectors: str = "fsq",
    use_idf: bool = False,
    same_position: bool = False,
    search_type: str = "unique_encoding",
):
    """Rerank retrieved documents by codebook-vector similarity to the query.

    Code-pair similarity is ``1 - ||v_q - v_d||_2`` between the vectors returned
    by ``extract_latent_embeddings(which=which_vectors, ...)``. It is optionally
    multiplied by the document-side code's IDF, then clamped below at
    ``min_score``.

    Args:
        autoencoder: provides ``tokenizer.vocab_size``, ``quantizer.codebook_size``
            and the code vectors.
        processed_queries: ``{query_id: "code+vocab_size code+vocab_size ..."}``
            (bare integers, no position prefixes).
        predicted_qrels: first-stage run ``{query_id: {doc_id: score}}``. It is
            only read; it is not modified.
        collection_path: JSON-lines file holding each document's encoded terms.
        reduce: ``"max"`` sums, over query codes, the best-matching document
            code. ``"mean"`` sums, over query codes, the mean similarity to the
            document's codes.
        min_score: lower clamp applied to every code-pair similarity.
        which_vectors: forwarded to ``extract_latent_embeddings``.
        use_idf: weight similarities by ``log(N / count(doc_code)) + 1``.
        same_position: with ``reduce="max"``, match each query position only
            against the document code at the same position (others get -1e6).
            Requires equal query and (padded) document lengths.
        search_type: format of ``collection_path`` terms (see
            ``_load_collection_codes``).

    Returns:
        New ``{query_id: {doc_id: similarity}}``. Each inner dict is ordered by
        descending similarity; ties keep retrieval order.

    Raises:
        ValueError: unknown ``reduce``, or ``same_position`` with mismatched lengths.
        KeyError: a retrieved doc_id is absent from ``collection_path``.
    """
    if reduce not in ("max", "mean"):
        raise ValueError(f"Unknown reduce {reduce!r}; expected 'max' or 'mean'")

    wanted_docs: set[str] = set()
    for doc_scores in predicted_qrels.values():
        wanted_docs.update(doc_scores)

    vocab_size = autoencoder.tokenizer.vocab_size
    encodings, counts, count_docs = _load_collection_codes(
        collection_path, vocab_size, search_type, wanted_docs
    )

    codebook_size = autoencoder.quantizer.codebook_size
    # Assumes one row per code (codebook_size rows) — TODO confirm.
    vectors = extract_latent_embeddings(
        which=which_vectors,
        autoencoder=autoencoder,
    )

    reranked: dict[str, dict[str, float]] = {}
    with torch.inference_mode():
        scores = 1.0 - torch.cdist(
            vectors[None, :, :],
            vectors[None, :, :],
        ).squeeze(0)
        if use_idf:
            counts_tensor = (
                torch.tensor(
                    [counts[i] for i in range(codebook_size)],
                    device=vectors.device,
                    dtype=torch.float,
                )
                + 1e-6
            )
            idf_tensor = torch.log(count_docs / counts_tensor) + 1.0
            # Column j is the document-side code, so this weights by doc-code IDF.
            scores = scores * idf_tensor.unsqueeze(0)
        # Flattened: similarity of (query code q, doc code d) is scores[q * codebook_size + d].
        scores = scores.clamp_min(min_score).flatten()

        for q_id, doc_scores in tqdm(
            predicted_qrels.items(),
            desc="Computing sim scores",
            total=len(predicted_qrels),
        ):
            doc_ids = list(doc_scores)
            if not doc_ids:
                reranked[q_id] = {}
                continue

            query_codes = [
                int(term) - vocab_size for term in processed_queries[q_id].split(" ")
            ]
            query_tensor = torch.tensor(query_codes, device=scores.device, dtype=torch.long)

            documents = [encodings[doc_id] for doc_id in doc_ids]
            max_document_size = max(len(doc) for doc in documents)
            # Right-pad with -1; negative codes are treated as padding below.
            documents_tensor = torch.tensor(
                [doc + [-1] * (max_document_size - len(doc)) for doc in documents],
                dtype=torch.long,
                device=scores.device,
            )

            qtsz = query_tensor.size(0)
            dbsz, dtsz = documents_tensor.shape
            if same_position and qtsz != dtsz:
                raise ValueError(
                    f"same_position requires equal lengths, got query {qtsz} "
                    f"and documents {dtsz} for query {q_id!r}"
                )

            query_view = query_tensor.view(1, qtsz, 1, 1)
            doc_view = documents_tensor.view(1, 1, dbsz, dtsz)
            # [1, qtsz, dbsz, dtsz]: True where the query or document code is padding.
            pad_mask = (query_view < 0) | (doc_view < 0)
            pair_indices = (codebook_size * query_view + doc_view).masked_fill(pad_mask, 0)
            pair_scores = scores[pair_indices]

            if reduce == "max":
                if same_position:
                    off_diagonal = 1.0 - torch.eye(
                        qtsz, device=scores.device, dtype=scores.dtype
                    )
                    pair_scores = pair_scores + (off_diagonal * -1e6).view(1, qtsz, 1, dtsz)
                # Padding can never be the best match.
                pair_scores = pair_scores.masked_fill(pad_mask, float("-inf"))
                per_query_code = pair_scores.max(dim=3).values  # [1, qtsz, dbsz]
                # Padded query positions contribute nothing.
                per_query_code = per_query_code.masked_fill(
                    (query_tensor < 0).view(1, qtsz, 1), 0.0
                )
                doc_sims = per_query_code.sum(dim=1)  # [1, dbsz]
            else:
                # Mask with the real padding mask; the old ``indices < 0`` test
                # never fired because padded indices had already been set to 0.
                pair_scores = pair_scores.masked_fill(pad_mask, 0.0)
                doc_lengths = (documents_tensor >= 0).sum(dim=-1).view(1, 1, dbsz)
                doc_sims = (pair_scores.sum(dim=3) / doc_lengths).sum(dim=1)  # [1, dbsz]

            # Sort by similarity (the dict value), best first; sort is stable.
            ranked = sorted(
                zip(doc_ids, doc_sims[0].tolist()),
                key=lambda item: item[1],
                reverse=True,
            )
            reranked[q_id] = dict(ranked)

    return reranked


def _load_gold_qrels(dataset) -> defaultdict:
    """Build a judged candidate pool of up to 100 documents per query.

    ``scoreddocs`` are added with label 0 while a query's pool has fewer than
    99 entries. Then ``qrels`` documents are added (or relabelled) with label 1
    while the pool has fewer than 100 entries. The qrel's own relevance grade
    is not used.

    Returns:
        ``defaultdict(dict)`` mapping ``query_id -> {doc_id: 0 | 1}``.
    """
    gold_qrels = defaultdict(dict)
    for scoredoc in dataset.scoreddocs_iter():
        pool = gold_qrels[scoredoc.query_id]
        if len(pool) < 99:
            pool[scoredoc.doc_id] = 0
    for qrel in dataset.qrels_iter():
        pool = gold_qrels[qrel.query_id]
        if len(pool) < 100:
            pool[qrel.doc_id] = 1
    return gold_qrels


def _load_autoencoder(path: str, device: str, quantisation: int | None) -> Autoencoder:
    """Load a frozen autoencoder for inference.

    Overrides ``levels`` when ``quantisation`` is given. Casts the model to
    bfloat16, keeps the quantizer in float32, adds special tokens and switches
    to eval mode.
    """
    autoencoder = Autoencoder.load_checkpoint(
        path=path,
    ).to(device)
    if quantisation is not None:
        autoencoder.levels = [quantisation] * len(autoencoder.levels)
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.model.bfloat16()
    autoencoder.quantizer.float()
    autoencoder.add_special_tokens()
    autoencoder.eval()
    return autoencoder


def _strip_unique_positions(processed_queries: dict[str, str]) -> dict[str, str]:
    """Turn ``"<pos>_<code>"`` terms (unique_encoding format) into bare ``"<code>"`` terms."""
    return {
        q_id: " ".join(term.split("_")[1] for term in text.split(" "))
        for q_id, text in processed_queries.items()
    }


def _dump_json(obj, path: str, what: str) -> None:
    """Write *obj* to *path* as JSON, logging ``Saving <what> to <path>``."""
    logger.info(f"Saving {what} to {path}")
    with open(path, "w") as f:
        json.dump(obj, f)


def main():
    """Evaluate codebook-encoded queries end to end.

    Steps:
        1. Encode the queries.
        2. Retrieve with BM25.
        3. Score the first-stage run.
        4. Optionally rerank by code similarity (``--rerank``) and score again.
        5. Log params/metrics to one MLflow run.
        6. If ``--output_path`` is set, save these JSON artifacts there:
           ``predicted_qrels``, ``encoded_queries``, ``qrels``,
           ``rerank_predicted_qrels`` (rerank only) and ``scores``. ``scores``
           holds every metric logged to MLflow.
    """
    args = parse_args()
    # Raise verbosity first: the module leaves the root logger at ERROR, which
    # would otherwise swallow this and the other info messages.
    logger.setLevel(logging.INFO)
    logger.info(args)

    # Validate/create the output directory before any expensive work. An empty
    # --output_path (the default) means "do not save artifacts".
    output_path = os.path.abspath(args.output_path) if args.output_path else None
    if output_path is not None:
        os.makedirs(output_path, exist_ok=True)
    k = args.k or 1000

    # The ``with`` ensures the MLflow run is ended even if a step below raises.
    with mlflow.start_run(log_system_metrics=True):
        mlflow.log_params(
            {
                "index_name": args.index_path,
                "search_type": args.search_type,
                "search_speed": "fast",
                "length": args.length,
                "k": args.k,
            }
        )

        logger.info(f"Loading dataset {args.dataset}")
        dataset = ir_datasets.load(args.dataset)
        gold_qrels = _load_gold_qrels(dataset)
        # Queries whose candidate pool is complete. NOTE(review): only used for
        # the count printed below; scoring uses every query in gold_qrels.
        considered_query_ids_and_doc_ids = {
            q_id: pool for q_id, pool in gold_qrels.items() if len(pool) == 100
        }
        unprocessed_queries = {q.query_id: q.text for q in dataset.queries_iter()}

        if args.limit > -1:
            logger.info(f"Limiting the number of queries to {args.limit}")
            unprocessed_queries = dict(
                itertools.islice(unprocessed_queries.items(), args.limit)
            )
            gold_qrels = {q_id: gold_qrels[q_id] for q_id in unprocessed_queries}

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device {device}")

        logger.info(f"Loading the autoencoder from {args.autoencoder_path}")
        autoencoder = _load_autoencoder(args.autoencoder_path, device, args.quantisation)

        logger.info(f"Loading the index from {args.index_path}")
        searcher = LuceneSearcher(args.index_path)

        print("Scoring on", len(considered_query_ids_and_doc_ids), "queries")

        logger.info("Encoding the queries")
        processed_queries = encode(
            autoencoder,
            unprocessed_queries,
            32,
            args.query_formatting,
            args.search_type,
            length=args.length,
        )

        logger.info("Retrieving the documents")
        predicted_qrels_no_rerank = retrieve(searcher, processed_queries, k)
        if output_path is not None:
            _dump_json(
                predicted_qrels_no_rerank,
                os.path.join(output_path, "predicted_qrels.json"),
                "predicted qrels",
            )

        logger.info("Scoring the results no rerank")
        no_rerank_scores = score(gold_qrels, predicted_qrels_no_rerank, "no_rerank_")

        # The reranker parses each query term with int(), so drop position prefixes.
        if args.search_type == "unique_encoding":
            processed_queries = _strip_unique_positions(processed_queries)

        predicted_qrels_rerank = None
        metrics = no_rerank_scores
        if args.rerank:
            predicted_qrels_rerank = get_sim_scores(
                autoencoder,
                processed_queries,
                predicted_qrels_no_rerank,
                collection_path=args.collection_path,
                reduce="max",
                min_score=-1.00,
                which_vectors="fsq",
                use_idf=False,
                same_position=True,
                search_type=args.search_type,
            )
            logger.info("Scoring the results rerank")
            metrics = score(gold_qrels, predicted_qrels_rerank) | no_rerank_scores

        mlflow.log_metrics(metrics, step=0)

    if output_path is not None:
        _dump_json(
            processed_queries,
            os.path.join(output_path, "encoded_queries.json"),
            "encoded queries",
        )
        _dump_json(gold_qrels, os.path.join(output_path, "qrels.json"), "qrels")
        # Previously this (and scores.json) raised NameError when --rerank was off.
        if predicted_qrels_rerank is not None:
            _dump_json(
                predicted_qrels_rerank,
                os.path.join(output_path, "rerank_predicted_qrels.json"),
                "rerank predicted qrels",
            )
        _dump_json(metrics, os.path.join(output_path, "scores.json"), "scores")


# Run the evaluation only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
