import argparse
from collections import defaultdict, Counter
import itertools
import json
import logging
import os
import sys

# Default MLflow destination for this evaluation script. setdefault (rather
# than assignment) lets a caller redirect the run by exporting these variables
# before launching the script.
os.environ.setdefault("MLFLOW_EXPERIMENT_NAME", "seal_codebooks_eval")
os.environ.setdefault("MLFLOW_TRACKING_URI", "http://10.128.0.103:8080")

# Make the sibling autoencoder modules (modeling_autoencoder_v2, train_utils)
# importable. NOTE(review): a relative sys.path entry is resolved against the
# current working directory, not this file's location.
sys.path.append("../discrete_autoencoder/")
from more_itertools import batched
import ir_datasets
import ir_measures
from ir_measures import nDCG, MRR, Success, Recall
from pyserini.search.lucene import LuceneSearcher
import torch
from tqdm import tqdm
import pickle
import mlflow
import jsonlines

from modeling_autoencoder_v2 import Autoencoder
from train_utils import extract_latent_embeddings

# Module-wide logger. This is the *root* logger (the same object that
# logging.getLogger() returns), so its level also gates records coming from
# third-party libraries. It starts at ERROR; main() lowers it to INFO.
# NOTE(review): no handler is attached here. Unless some imported library
# configures one, Python's last-resort handler only emits WARNING and above,
# so the INFO messages may never be printed -- confirm.
logger = logging.root
logger.setLevel(logging.ERROR)


def parse_args():
    """Parse the command-line options of the codebook retrieval evaluation.

    Options are registered in a fixed order so ``--help`` lists them
    consistently. Most options default to ``None``; ``--limit`` uses ``-1``
    to mean "no limit".

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()

    # Plain string options without defaults.
    for flag in ("--autoencoder_path", "--dataset", "--index_path"):
        parser.add_argument(flag)

    # Output directories; an empty string means "not given".
    for flag in ("--output_path", "--token_only_output_path"):
        parser.add_argument(flag, default="")

    parser.add_argument("--collection_path")
    parser.add_argument("--limit", default=-1, type=int)
    parser.add_argument("--query_formatting", default="{query}")
    search_types = [
        "tokenize_only",
        "encoding",
        "encoding_and_words",
        "unique_encoding",
        "hungarian_encoding",
    ]
    parser.add_argument("--search_type", choices=search_types)
    parser.add_argument("--rerank", action="store_true")

    # Numeric knobs that stay None unless explicitly provided.
    for flag in ("--max_len_perc", "--min_len_perc"):
        parser.add_argument(flag, default=None, type=float)
    for flag in ("--min_len", "--length", "--k", "--quantisation"):
        parser.add_argument(flag, default=None, type=int)

    return parser.parse_args()


def get_quantized_indices(
    autoencoder,
    batch_autoencoder,
    queries,
    length: int | None = None,
):
    # num_non_padded = batch_autoencoder["attention_mask"].sum(dim=1)
    output_q2q = autoencoder(
        batch_autoencoder,
    )
    quantizer_indices = output_q2q.quantizer_outputs[1].long().tolist()
    # for i, num_tokens in enumerate(num_non_padded):
    #     quantizer_indices[i] = quantizer_indices[i][:num_tokens]

    if length is not None:
        quantizer_indices = [
            quantizer_index[:length] for quantizer_index in quantizer_indices
        ]

    return quantizer_indices


def get_encoding_batch(
    autoencoder,
    batch_autoencoder,
    queries,
    unique=False,
    length: int | None = None,
):
    """Render each item's quantizer indices as a space-separated token string.

    Every code is shifted by the tokenizer vocabulary size. With ``unique``
    each token is written as ``"<position>_<shifted code>"``; otherwise the
    token is just ``"<shifted code>"``.

    Returns:
        One string per batch item.
    """
    rows = get_quantized_indices(
        autoencoder,
        batch_autoencoder,
        queries,
        length=length,
    )
    offset = autoencoder.tokenizer.vocab_size

    def render(row):
        if unique:
            return " ".join(f"{pos}_{code + offset}" for pos, code in enumerate(row))
        return " ".join(f"{code + offset}" for code in row)

    return [render(row) for row in rows]


def rescore_tokens_and_words(
    predicted_qrels_tokens_only: dict[str, dict[str, float]],
    predicted_qrels_words_only: dict[str, dict[str, float]],
    coeff: float = 0.5,
    k: int = 1000,
):
    """Fuse a token-based and a word-based ranking with weighted reciprocal ranks.

    Within each run, a document's rank is its 1-based position in the
    per-query dict's iteration order. The score values themselves are
    ignored, so each dict must already be ordered best-first. The fused
    score is::

        coeff * 1/rank_tokens + (1 - coeff) * 1/rank_words

    A document absent from one run gets ``1/k`` for that run's term. The
    weight ``coeff`` is the same for every query.

    Queries present in only one of the two runs are still fused, with the
    missing run treated as empty. Previously this raised KeyError (e.g. for
    a query with no word hits) or silently dropped the query.

    Args:
        predicted_qrels_tokens_only: query_id -> {doc_id: score}, best first.
        predicted_qrels_words_only: query_id -> {doc_id: score}, best first.
        coeff: weight of the token run; the word run gets ``1 - coeff``.
        k: retrieval depth; ``1/k`` is the fallback reciprocal rank.

    Returns:
        defaultdict query_id -> {doc_id: fused score}. Token-run documents
        come first in token order, then documents only found by the word run.
    """

    def reciprocal_ranks(ranking: dict[str, float]) -> dict[str, float]:
        return {doc_id: 1 / rank for rank, doc_id in enumerate(ranking, start=1)}

    missing = 1 / k
    new_scores = defaultdict(dict)
    # Ordered union of query ids: token-run queries first, then word-only ones.
    all_q_ids = dict.fromkeys([*predicted_qrels_tokens_only, *predicted_qrels_words_only])
    for q_id in all_q_ids:
        token_rr = reciprocal_ranks(predicted_qrels_tokens_only.get(q_id, {}))
        word_rr = reciprocal_ranks(predicted_qrels_words_only.get(q_id, {}))
        for doc_id in dict.fromkeys([*token_rr, *word_rr]):
            token_score = token_rr.get(doc_id, missing)
            word_score = word_rr.get(doc_id, missing)
            new_scores[q_id][doc_id] = token_score * coeff + word_score * (1 - coeff)

    return new_scores


@torch.inference_mode()
def encode(
    autoencoder: Autoencoder,
    unprocessed_queries: dict[str, str],  # query_id  # query
    batch_size: int,
    query_formatting: str = "{query}",
    search_type: str = "encoding",
    length: int | None = None,
) -> dict[str, str]:  # query_id
    """Turn raw queries into the whitespace-separated strings sent to BM25.

    Each query is formatted with ``query_formatting`` and tokenized with
    ``autoencoder.tokenize(..., max_length=200)``. The result then depends on
    ``search_type``:

    * ``"tokenize_only"``: the tokenizer input ids.
    * ``"encoding"``: quantizer codes shifted by the tokenizer vocab size.
    * ``"unique_encoding"``: the shifted codes as ``"<position>_<code>"``.
    * ``"encoding_and_words"``: the unique codes followed by the formatted
      query text.

    Args:
        autoencoder: model providing ``tokenize`` and the quantizer.
        unprocessed_queries: query_id -> raw query text.
        batch_size: number of queries encoded per forward pass.
        query_formatting: ``str.format`` template with a ``{query}`` field.
        search_type: one of the modes listed above.
        length: optional cap on the number of codes kept per query.

    Returns:
        query_id -> encoded query string.

    Raises:
        NotImplementedError: for ``"hungarian_encoding"``.
        ValueError: for any other unsupported ``search_type``, including
            ``None`` (the ``--search_type`` CLI default). Previously such
            values silently produced an empty result.
    """
    supported = ("tokenize_only", "encoding", "encoding_and_words", "unique_encoding")
    if search_type == "hungarian_encoding":
        raise NotImplementedError("search_type 'hungarian_encoding' is not implemented")
    if search_type not in supported:
        raise ValueError(
            f"Unsupported search_type {search_type!r}; expected one of {supported}"
        )

    encoded: dict[str, str] = {}

    with tqdm(desc="Encoding queries", total=len(unprocessed_queries)) as pbar:
        for batch in batched(unprocessed_queries.items(), batch_size):
            ids, raw_queries = zip(*batch)
            queries = [query_formatting.format(query=query) for query in raw_queries]
            batch_autoencoder = autoencoder.tokenize(queries, max_length=200)

            if search_type == "tokenize_only":
                batch_encoded = [
                    " ".join(str(token_id) for token_id in row)
                    for row in batch_autoencoder["input_ids"].tolist()
                ]
            else:
                # Only plain "encoding" uses non-positional code tokens.
                batch_encoded = get_encoding_batch(
                    autoencoder,
                    batch_autoencoder,
                    queries,
                    unique=search_type != "encoding",
                    length=length,
                )
                if search_type == "encoding_and_words":
                    batch_encoded = [
                        f"{codes} {query}" for codes, query in zip(batch_encoded, queries)
                    ]

            for id_, text in zip(ids, batch_encoded):
                encoded[id_] = text
                pbar.update(1)

    return encoded


def retrieve(
    bm25_index: LuceneSearcher,
    processed_queries: dict[str, str],  # query_id  # query
    k: int,  # num of hits to return
    batch_size: int = 100,
    threads: int = 20,
    output_file: str | None = None,
) -> dict[str, dict[str, float]]:  # query_id  # doc_id  # score
    """Search the BM25 index for every query and collect the hit scores.

    Queries are sent to ``bm25_index.batch_search`` in chunks of
    ``batch_size``. Queries that return no hits are omitted from the result.

    Results are built in memory. Earlier versions always wrote them to a
    hard-coded ``./predicted_qrels.json`` scratch file and read it back. That
    file could collide with the ``predicted_qrels.json`` cache that ``main``
    keeps in its output directory. Pass ``output_file`` to still get a
    JSON-lines dump (one ``{query_id: {doc_id: score}}`` object per line).

    Args:
        bm25_index: searcher exposing ``batch_search``.
        processed_queries: query_id -> query string.
        k: number of hits to return per query.
        batch_size: queries per ``batch_search`` call.
        threads: ``threads`` argument forwarded to ``batch_search``.
        output_file: optional path for a JSON-lines dump of the results.

    Returns:
        query_id -> {doc_id: score}, documents in the order returned by the
        searcher.
    """
    qids = list(processed_queries)
    queries = list(processed_queries.values())
    predicted_qrels_scores: dict[str, dict[str, float]] = {}

    for start in range(0, len(queries), batch_size):
        print("recovering", start)
        results = bm25_index.batch_search(
            queries=queries[start : start + batch_size],
            qids=qids[start : start + batch_size],
            k=k,
            threads=threads,
        )
        for q_id, hits in results.items():
            if hits:
                predicted_qrels_scores[q_id] = {hit.docid: hit.score for hit in hits}

    if output_file is not None:
        with open(output_file, "w") as f_out:
            for q_id, doc_scores in predicted_qrels_scores.items():
                f_out.write(json.dumps({q_id: doc_scores}) + "\n")

    return predicted_qrels_scores


def score(
    gold_qrels: dict[str, dict[str, int]],  # q_id, doc_id, relevant/not
    predicted_qrels: dict[str, dict[str, float]],
    append_string: str = "",
) -> dict[str, float]:
    """Compute aggregate IR metrics for a run and log each one.

    Metrics: nDCG@10, MRR@10, then Recall and Success at 3/10/20/100/1000.
    Metric names have "@" replaced by "-" (e.g. "nDCG-10").

    Args:
        gold_qrels: query_id -> {doc_id: relevance label}.
        predicted_qrels: query_id -> {doc_id: score}.
        append_string: prefix prepended to every returned metric name.

    Returns:
        prefixed metric name -> aggregate value.
    """
    cutoffs = (3, 10, 20, 100, 1000)
    measures = [nDCG @ 10, MRR @ 10]
    measures += [Recall @ cutoff for cutoff in cutoffs]
    measures += [Success @ cutoff for cutoff in cutoffs]

    aggregated = ir_measures.calc_aggregate(measures, gold_qrels, predicted_qrels)

    def metric_name(measure):
        return str(measure).replace("@", "-")

    for measure, value in aggregated.items():
        logger.info(f"{metric_name(measure)}: {value}")

    return {
        append_string + metric_name(measure): value
        for measure, value in aggregated.items()
    }


def _parse_code(token: str, vocab_size: int) -> int:
    """Convert one indexed code token to a codebook index.

    Tokens are either ``"<code>"`` or ``"<position>_<code>"``, where ``<code>``
    is the codebook index shifted by ``vocab_size``. The explicit split is
    needed because ``int()`` accepts underscores as digit separators:
    ``int("3_32100")`` silently yields 332100.
    """
    code = token.split("_")[1] if "_" in token else token
    return int(code) - vocab_size


def _load_collection_codes(collection_path: str, vocab_size: int):
    """Read every document's code list from a JSON-lines collection.

    Returns:
        ``(codes_by_doc_id, doc_freq, num_docs)``. ``doc_freq[c]`` counts
        the documents containing code ``c`` at least once; ``num_docs``
        counts collection lines.
    """
    codes_by_doc_id: dict[str, list[int]] = {}
    doc_freq: Counter = Counter()
    num_docs = 0
    with jsonlines.open(collection_path) as reader:
        for record in reader:
            codes = [_parse_code(tok, vocab_size) for tok in record["contents"].split(" ")]
            codes_by_doc_id[record["id"]] = codes
            # Document frequency (each code at most once per document) so the
            # IDF below is log(N / df); raw counts can exceed N.
            doc_freq.update(set(codes))
            num_docs += 1
    return codes_by_doc_id, doc_freq, num_docs


def _codebook_similarity(
    vectors: torch.Tensor,
    codebook_size: int,
    doc_freq: Counter,
    num_docs: int,
    use_idf: bool,
    min_score: float,
) -> torch.Tensor:
    """Return the flattened code-by-code cosine-similarity table.

    Entry ``q * codebook_size + d`` is the similarity of query code ``q`` and
    document code ``d``. It is optionally multiplied by the IDF of ``d`` and
    clamped below at ``min_score``. Assumes ``vectors`` has one row per code,
    i.e. ``codebook_size`` rows -- TODO confirm for every ``which_vectors``.
    """
    vectors = vectors / vectors.norm(dim=-1, keepdim=True)
    sims = torch.einsum("bh,Bh->bB", vectors, vectors)
    if use_idf:
        df = (
            torch.tensor(
                [doc_freq[i] for i in range(codebook_size)],
                device=vectors.device,
                dtype=torch.float,
            )
            + 1e-6
        )
        idf = torch.log(num_docs / df) + 1.0
        sims = sims * idf.unsqueeze(0)
    return sims.clamp_min(min_score).flatten()


def _score_documents(
    sims: torch.Tensor,
    query_codes: list[int],
    doc_code_lists: list[list[int]],
    codebook_size: int,
    reduce: str,
) -> list[float]:
    """Score each document against one query via the flattened similarity table.

    ``reduce="max"``: for each query code take its best document code.
    ``reduce="mean"``: for each query code average over the document codes.
    Either way the per-query-code values are summed. Negative codes (padding
    or non-latent tokens) contribute nothing.
    """
    device = sims.device
    n_query = len(query_codes)
    n_docs = len(doc_code_lists)
    width = max(1, max((len(codes) for codes in doc_code_lists), default=0))

    query = torch.tensor(query_codes, dtype=torch.long, device=device).view(1, n_query, 1, 1)
    docs = torch.tensor(
        [codes + [-1] * (width - len(codes)) for codes in doc_code_lists],
        dtype=torch.long,
        device=device,
    )
    docs4 = docs.view(1, 1, n_docs, width)

    invalid = (query < 0) | (docs4 < 0)  # [1, n_query, n_docs, width]
    pair_scores = sims[(codebook_size * query + docs4).masked_fill(invalid, 0)]

    if reduce == "max":
        per_code = pair_scores.masked_fill(invalid, float("-inf")).max(dim=3).values
        per_code = per_code.masked_fill(query.view(1, n_query, 1) < 0, 0.0)
    else:
        # Mask with `invalid` itself: the indices were already zero-filled, so
        # testing them for < 0 (as an earlier version did) never matched and
        # let code 0's similarity leak into every padded position.
        totals = pair_scores.masked_fill(invalid, 0.0).sum(dim=3)  # [1, n_query, n_docs]
        n_valid = (docs >= 0).sum(-1).clamp_min(1).view(1, 1, n_docs)
        per_code = totals / n_valid

    return per_code.sum(dim=1)[0].tolist()


def get_sim_scores(
    autoencoder: Autoencoder,
    processed_queries: dict[str, str],
    predicted_qrels: dict[str, dict[str, float]],
    collection_path: str,
    reduce: str = "mean",
    min_score: float = -1.0,
    which_vectors: str = "fsq",
    use_idf: bool = False,
    search_type: str = "encoding_and_words",
):
    """Rerank retrieved documents by similarity between latent codes.

    Each query's retrieved documents are rescored using the pairwise cosine
    similarity of the vectors returned by ``extract_latent_embeddings``. The
    similarity is optionally IDF-weighted and clamped below at
    ``min_score``. Documents are then sorted by the new score, highest first.

    Args:
        autoencoder: supplies the tokenizer vocab size (subtracted from code
            tokens), the quantizer codebook size and the latent vectors.
        processed_queries: query_id -> space-separated code tokens.
        predicted_qrels: query_id -> {doc_id: score}. Only the doc ids are
            used, and the dict is not modified (earlier versions overwrote
            it in place).
        collection_path: JSON-lines file whose records carry ``id`` and
            space-separated ``contents`` code tokens.
        reduce: "max" or "mean"; see ``_score_documents``.
        min_score: lower clamp applied to the similarity table.
        which_vectors: forwarded to ``extract_latent_embeddings`` as ``which``.
        use_idf: scale similarities by ``log(num_docs / df) + 1`` of the
            document code.
        search_type: accepted for backward compatibility. Code tokens are
            parsed in either "<code>" or "<position>_<code>" form regardless.

    Returns:
        query_id -> {doc_id: new score}. Each inner dict is sorted by score,
        descending; ties keep retrieval order.

    Raises:
        ValueError: if ``reduce`` is neither "max" nor "mean".
    """
    if reduce not in ("max", "mean"):
        raise ValueError(f"reduce must be 'max' or 'mean', got {reduce!r}")

    vocab_size = autoencoder.tokenizer.vocab_size
    codebook_size = autoencoder.quantizer.codebook_size
    codes_by_doc_id, doc_freq, num_docs = _load_collection_codes(collection_path, vocab_size)

    vectors = extract_latent_embeddings(
        which=which_vectors,
        autoencoder=autoencoder,
    )

    reranked: dict[str, dict[str, float]] = {}
    with torch.inference_mode():
        sims = _codebook_similarity(
            vectors, codebook_size, doc_freq, num_docs, use_idf, min_score
        )
        for q_id, doc_scores in tqdm(
            predicted_qrels.items(),
            desc="Computing sim scores",
            total=len(predicted_qrels),
        ):
            doc_ids = list(doc_scores)
            if not doc_ids:
                reranked[q_id] = {}
                continue
            query_codes = list(
                {_parse_code(tok, vocab_size) for tok in processed_queries[q_id].split(" ")}
            )
            # Each distinct code counts once per document.
            doc_code_lists = [list(set(codes_by_doc_id[doc_id])) for doc_id in doc_ids]
            new_scores = _score_documents(
                sims, query_codes, doc_code_lists, codebook_size, reduce
            )
            reranked[q_id] = dict(zip(doc_ids, new_scores))

    # Sort by the new score (item[1]); sorted() is stable even with
    # reverse=True, so equal scores keep their retrieval order.
    return {
        q_id: dict(sorted(doc_scores.items(), key=lambda item: item[1], reverse=True))
        for q_id, doc_scores in reranked.items()
    }


def _cache_file(directory: str | None, filename: str) -> str | None:
    """Return ``directory/filename``, or None when no directory is configured."""
    return None if directory is None else os.path.join(directory, filename)


def _save_json(path: str, payload, description: str) -> None:
    """Write *payload* as JSON to *path*, logging what is being saved."""
    with open(path, "w") as f:
        logger.info(f"Saving {description} to {path}")
        json.dump(payload, f)


def _load_or_compute(path: str | None, compute, description: str):
    """Load the JSON cached at *path*, or call ``compute()`` and cache its result.

    With ``path=None`` nothing is read or written; ``compute()`` is just called.
    """
    if path is not None and os.path.exists(path):
        with open(path, "r") as f:
            return json.load(f)
    result = compute()
    if path is not None:
        _save_json(path, result, description)
    return result


def _evaluate(args) -> None:
    """Run the full evaluation described by *args*; expects an active MLflow run."""
    # Token weights tried when fusing the token and word rankings.
    rescore_coeffs = [coeff / 10 for coeff in range(4, 9)]
    mlflow.log_params(
        {
            "index_name": args.index_path,
            "search_type": args.search_type,
            "search_speed": "slow",
            "length": args.length,
            "k": args.k,
            # Historical single-value parameter, kept for comparability with
            # earlier runs; the sweep actually uses rescore_coeffs.
            "coeff": 0.5,
            "rescore_coeffs": ",".join(str(c) for c in rescore_coeffs),
        }
    )
    logger.setLevel(logging.INFO)

    logger.info(f"Loading dataset {args.dataset}")
    dataset = ir_datasets.load(args.dataset)
    gold_qrels = {k: dict(v) for k, v in dataset.qrels_dict().items()}
    unprocessed_queries = {q.query_id: q.text for q in dataset.queries_iter()}

    if args.limit > -1:
        logger.info(f"Limiting the number of queries to {args.limit}")
        unprocessed_queries = dict(list(unprocessed_queries.items())[: args.limit])
        # Queries without judgements are skipped rather than raising KeyError.
        gold_qrels = {k: gold_qrels[k] for k in unprocessed_queries if k in gold_qrels}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device {device}")

    logger.info(f"Loading the autoencoder from {args.autoencoder_path}")
    autoencoder = Autoencoder.load_checkpoint(
        path=args.autoencoder_path,
    ).to(device)
    if args.quantisation is not None:
        autoencoder.levels = [args.quantisation] * len(autoencoder.levels)
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.model.bfloat16()
    autoencoder.add_special_tokens()
    autoencoder.quantizer.float()
    autoencoder.eval()

    logger.info(f"Loading the index from {args.index_path}")
    searcher = LuceneSearcher(args.index_path)

    logger.info("Encoding the queries")
    processed_queries = encode(
        autoencoder,
        unprocessed_queries,
        32,
        args.query_formatting,
        args.search_type,
        length=args.length,
    )
    k = int(args.k) if args.k else 1000

    # Cache/output directories. An empty option disables caching and saving.
    # Previously "" crashed os.makedirs, or cached into the current directory.
    output_dir = os.path.abspath(args.output_path) if args.output_path else None
    token_only_dir = (
        os.path.abspath(args.token_only_output_path)
        if args.token_only_output_path
        else None
    )
    for directory in (output_dir, token_only_dir):
        if directory is not None:
            os.makedirs(directory, exist_ok=True)

    logger.info("Retrieving the documents words+tokens")
    predicted_qrels_no_rerank = _load_or_compute(
        _cache_file(output_dir, "predicted_qrels.json"),
        lambda: retrieve(searcher, processed_queries, k),
        "predicted qrels",
    )
    logger.info("Scoring the results words+tokens")
    metrics = score(gold_qrels, predicted_qrels_no_rerank, "no_rerank_")

    logger.info("Retrieving the documents words and tokens separately")
    # Encoded queries start with the code tokens. encode() keeps at most
    # args.length codes, so split there. This assumes autoencoder.length is
    # the number of codes the quantizer emits per query -- TODO confirm.
    n_codes = (
        autoencoder.length
        if args.length is None
        else min(args.length, autoencoder.length)
    )
    processed_queries_tokens_only = {
        query_id: " ".join(query_tokens.split(" ")[:n_codes])
        for query_id, query_tokens in processed_queries.items()
    }
    processed_queries_words_only = {
        query_id: " ".join(query_tokens.split(" ")[n_codes:])
        for query_id, query_tokens in processed_queries.items()
    }

    predicted_qrels_tokens_only = _load_or_compute(
        _cache_file(token_only_dir, "predicted_qrels.json"),
        lambda: retrieve(searcher, processed_queries_tokens_only, k),
        "predicted qrels token only",
    )
    metrics |= score(gold_qrels, predicted_qrels_tokens_only, "token_only_")
    predicted_qrels_words_only = retrieve(searcher, processed_queries_words_only, k)
    logger.info("Scoring the results words only")
    metrics |= score(gold_qrels, predicted_qrels_words_only, "words_only_")

    predicted_qrels_rerank = None
    rerank_scores = None
    if args.rerank:
        logger.info("Rerank token only results")
        # Strip the "<position>_" prefix of unique encodings; plain codes
        # (no underscore) pass through unchanged.
        code_only_queries = {
            query_id: " ".join(tok.split("_")[-1] for tok in query_tokens.split(" "))
            for query_id, query_tokens in processed_queries_tokens_only.items()
        }
        predicted_qrels_rerank = _load_or_compute(
            _cache_file(token_only_dir, "rerank_predicted_qrels.json"),
            lambda: get_sim_scores(
                autoencoder,
                code_only_queries,
                predicted_qrels_tokens_only,
                collection_path=args.collection_path,
                reduce="max",
                min_score=-1.00,
                which_vectors="fsq",
                use_idf=False,
                search_type=args.search_type,
            ),
            "rerank predicted qrels token only",
        )
        logger.info("Scoring the results rerank")
        rerank_scores = score(gold_qrels, predicted_qrels_rerank)
        metrics |= rerank_scores

    # Fuse the token-side ranking (reranked when available) with the words-only
    # ranking. Every coefficient's metrics are kept, not just the last one.
    token_side = (
        predicted_qrels_rerank
        if predicted_qrels_rerank is not None
        else predicted_qrels_tokens_only
    )
    for coeff in rescore_coeffs:
        predicted_qrels_tokens_and_words_rescore = rescore_tokens_and_words(
            token_side,
            predicted_qrels_words_only,
            coeff=coeff,
            k=k,
        )
        logger.info(f"Scoring the results rescore token weight: {coeff}\n")
        metrics |= score(
            gold_qrels, predicted_qrels_tokens_and_words_rescore, f"rescored_{coeff}_"
        )

    mlflow.log_metrics(metrics, step=0)

    if output_dir is not None:
        # predicted_qrels.json was already written by the cache above.
        _save_json(
            os.path.join(output_dir, "encoded_queries.json"),
            processed_queries,
            "encoded queries",
        )
        _save_json(os.path.join(output_dir, "qrels.json"), gold_qrels, "qrels")
        if predicted_qrels_rerank is not None:
            _save_json(
                os.path.join(output_dir, "rerank_predicted_qrels.json"),
                predicted_qrels_rerank,
                "rerank predicted qrels",
            )
            _save_json(os.path.join(output_dir, "scores.json"), rerank_scores, "scores")


def main():
    """Evaluate codebook-based BM25 retrieval and log the metrics to MLflow.

    The evaluation runs inside ``mlflow.start_run`` used as a context
    manager, so the run is closed even if the evaluation raises.
    """
    args = parse_args()
    logger.info(args)
    with mlflow.start_run(log_system_metrics=True):
        _evaluate(args)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
