import argparse
import json
import logging
import os

import ir_datasets
import mlflow
import torch
from more_itertools import batched
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from scripts.baseline_kmeans.discretizers.factory import (
    get_clustering_discretizer,
    clustering_model_choices,
)
from scripts.baseline_kmeans.utils import (
    discretize_text_batch,
)
import ir_measures
from ir_measures import nDCG, MRR, Success
from pyserini.search.lucene import LuceneSearcher

# Module-wide logger. Passing name=None selects the root logger, so the level
# set here is also the effective level of every logger that has not set its
# own. It starts at ERROR; main() raises the verbosity to INFO.
logger = logging.getLogger(name=None)
logger.setLevel("ERROR")


def retrieve(
    bm25_index: LuceneSearcher,
    processed_queries: dict[str, str],  # query_id -> query string
    k: int,  # number of hits to return per query
    threads: int = 10,  # worker threads used by batch_search
) -> dict[str, dict[str, float]]:  # query_id -> doc_id -> score
    """Search *bm25_index* for every query in a single batched call.

    Args:
        bm25_index: Pyserini searcher over the discretized-document index.
        processed_queries: Mapping of query id to (already encoded) query text.
        k: Maximum number of hits to keep per query.
        threads: Number of threads passed to ``batch_search``. Defaults to 10,
            the value this function previously hard-coded.

    Returns:
        Run dictionary ``{query_id: {doc_id: score}}``. Queries for which the
        searcher returned no hits are omitted rather than mapped to ``{}``.
        If a doc id appears twice for one query, the later hit's score wins.
    """
    if not processed_queries:
        return {}

    hits_by_qid = bm25_index.batch_search(
        queries=list(processed_queries.values()),
        qids=list(processed_queries.keys()),
        k=k,
        threads=threads,
    )
    return {
        q_id: {hit.docid: hit.score for hit in hits}
        for q_id, hits in hits_by_qid.items()
        # Zero-hit queries are left out entirely, matching the old behaviour.
        if len(hits) > 0
    }


def score(
    gold_qrels: dict[str, dict[str, int]],  # q_id -> doc_id -> relevance label
    predicted_qrels: dict[str, dict[str, float]],  # q_id -> doc_id -> score
    measures: "list | None" = None,
) -> dict[str, float]:
    """Compute aggregate IR metrics for a run and log each one.

    Args:
        gold_qrels: Relevance judgements ``{query_id: {doc_id: label}}``.
        predicted_qrels: Run ``{query_id: {doc_id: score}}``, e.g. as
            returned by ``retrieve``.
        measures: ir_measures measure objects to compute. Defaults to
            nDCG@10, MRR@10 and Success@3/10/20/100.

    Returns:
        ``{measure_name: value}`` where every ``@`` in a measure name is
        replaced by ``-`` (e.g. ``"nDCG-10"``). The replacement is presumably
        there so the keys are accepted by ``mlflow.log_metrics``.
    """
    if measures is None:
        measures = [
            nDCG @ 10,
            MRR @ 10,
            Success @ 3,
            Success @ 10,
            Success @ 20,
            Success @ 100,
        ]

    aggregated = ir_measures.calc_aggregate(measures, gold_qrels, predicted_qrels)

    # Rename once and reuse the result for both logging and the return value.
    named_scores = {
        str(measure).replace("@", "-"): value for measure, value in aggregated.items()
    }
    for name, value in named_scores.items():
        logger.info(f"{name}: {value}")
    return named_scores


# Hits retrieved per query. This must be at least the deepest cutoff that
# score() evaluates by default (Success@100).
_RETRIEVAL_DEPTH = 100


def _resolve_device(requested: str) -> str:
    """Return the torch device used for both the model and its inputs.

    Honors *requested* (the ``--device`` flag), but falls back to ``"cpu"``
    when a CUDA device is requested and CUDA is not available.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        logger.warning(f"Device {requested!r} requested but CUDA is unavailable; using cpu")
        return "cpu"
    return requested


def _load_queries_and_qrels(
    dataset_name: str, limit: int
) -> tuple[dict[str, str], dict[str, dict[str, int]]]:
    """Load ``{query_id: text}`` and gold qrels from an ir_datasets dataset.

    When ``limit >= 0``, only the first *limit* queries are kept and the qrels
    are restricted to them. Kept queries without any qrels are skipped in
    the restricted qrels instead of raising ``KeyError``.
    """
    dataset = ir_datasets.load(dataset_name)
    gold_qrels = {q_id: dict(rels) for q_id, rels in dataset.qrels_dict().items()}
    queries = {q.query_id: q.text for q in dataset.queries_iter()}

    if limit > -1:
        logger.info(f"Limiting the number of queries to {limit}")
        queries = dict(list(queries.items())[:limit])
        gold_qrels = {q_id: gold_qrels[q_id] for q_id in queries if q_id in gold_qrels}

    return queries, gold_qrels


def _format_encoded_query(codes, tokens, encoding: str) -> str:
    """Render one query's cluster codes (and tokens) as a BM25 query string.

    ``"code"`` joins every code in order. ``"code_plus_text"`` appends the
    tokenizer tokens after the codes. ``"code_unique"`` keeps each distinct
    code once, in set-iteration order.

    Raises:
        ValueError: if *encoding* is not one of the supported modes. This
            previously surfaced as an ``UnboundLocalError``.
    """
    if encoding == "code":
        return " ".join(map(str, codes))
    if encoding == "code_plus_text":
        return " ".join(map(str, codes)) + " " + " ".join(tokens)
    if encoding == "code_unique":
        return " ".join(map(str, set(codes)))
    raise ValueError(f"Unknown query encoding: {encoding!r}")


def _encode_queries(
    queries: dict[str, str],
    *,
    tokenizer,
    model,
    clustering_model,
    device: str,
    batch_size: int,
    query_formatting: str,
    max_length: int,
    encoding: str,
) -> dict[str, str]:
    """Discretize query texts in batches; return ``{query_id: BM25 query}``."""
    encoded = {}
    with tqdm(desc="Encoding queries", total=len(queries)) as pbar:
        for batch in batched(queries.items(), batch_size):
            ids, texts = zip(*batch)
            formatted = [query_formatting.format(query=text) for text in texts]

            discretized_batch, tokenized_batch = discretize_text_batch(
                tokenizer=tokenizer,
                model=model,
                clustering_model=clustering_model,
                sentences_batch=formatted,
                # Must match the device the model was moved to. Previously the
                # model used an auto-detected device while this used
                # args.device.
                device=device,
                max_length=max_length,
            )

            for id_, codes, tokens in zip(ids, discretized_batch, tokenized_batch):
                encoded[id_] = _format_encoded_query(codes, tokens, encoding)
                pbar.update(1)
    return encoded


def _retrieve_all(
    searcher, encoded_queries: dict[str, str], batch_size: int, k: int
) -> dict[str, dict[str, float]]:
    """Run ``retrieve`` over the queries in batches and merge the runs."""
    predicted = {}
    with tqdm(desc="Retrieving documents", total=len(encoded_queries)) as pbar:
        for batch_ids in batched(encoded_queries, batch_size):
            batch = {q_id: encoded_queries[q_id] for q_id in batch_ids}
            predicted.update(retrieve(searcher, batch, k))
            pbar.update(len(batch))
    return predicted


def _save_outputs(output_dir: str, encoded_queries, predicted_qrels, scores) -> None:
    """Write encoded queries, the predicted run and the scores as JSON files."""
    output_dir = os.path.abspath(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    for file_name, label, payload in (
        ("encoded_queries.json", "encoded queries", encoded_queries),
        ("qrels.json", "qrels", predicted_qrels),
        ("scores.json", "scores", scores),
    ):
        path = os.path.join(output_dir, file_name)
        logger.info(f"Saving {label} to {path}")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f)


def main(args):
    """Evaluate BM25 retrieval over discretized queries and log to MLflow.

    Steps:
    1. Load queries and gold qrels from ``args.dataset``, optionally
       truncated to ``args.limit`` queries.
    2. Embed and discretize each query with the clustering model.
    3. Search the BM25 index with the resulting strings.
    4. Score the run and log the metrics to MLflow.
    5. Optionally write the encoded queries, run and scores as JSON under
       ``args.output_path``.

    Args:
        args: Parsed command-line namespace from the script's argparse setup.
    """
    # Raise verbosity before the first info() call. Previously the
    # "Loading dataset" message was logged while the level was still ERROR.
    logger.setLevel(logging.INFO)

    # The context manager ends the run on exit and marks it FAILED if an
    # exception escapes, instead of leaving it open.
    with mlflow.start_run(log_system_metrics=True):
        mlflow.log_params({"index_name": args.index_path})

        logger.info(f"Loading dataset {args.dataset}")
        unprocessed_queries, gold_qrels = _load_queries_and_qrels(args.dataset, args.limit)

        # One device for both the model and its inputs.
        device = _resolve_device(args.device)
        logger.info(f"Using device {device}")

        logger.info(f"Loading embedding model: {args.model}")
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModel.from_pretrained(args.model).to(device).eval()

        logger.info(f"Loading clustering model from: {args.clustering_model_path}")
        clustering_model = get_clustering_discretizer(model=args.clustering_model).load(
            args.clustering_model_path
        )

        logger.info(f"Loading BM25 index from: {args.index_path}")
        searcher = LuceneSearcher(args.index_path)

        logger.info("Encoding the queries")
        encoded_queries = _encode_queries(
            unprocessed_queries,
            tokenizer=tokenizer,
            model=model,
            clustering_model=clustering_model,
            device=device,
            batch_size=args.batch_size,
            query_formatting=args.query_formatting,
            max_length=args.max_length,
            encoding=args.encoding,
        )

        logger.info("Retrieving the documents")
        predicted_qrels = _retrieve_all(
            searcher, encoded_queries, args.batch_size, _RETRIEVAL_DEPTH
        )

        logger.info("Scoring the results")
        scores = score(gold_qrels, predicted_qrels)
        mlflow.log_metrics(scores, step=0)

        if args.output_path:
            _save_outputs(args.output_path, encoded_queries, predicted_qrels, scores)


if __name__ == "__main__":
    # If no handler is attached to the root logger, Python's last-resort
    # handler drops records below WARNING, so main()'s INFO progress messages
    # would not be shown. basicConfig() is a no-op if a handler already exists.
    logging.basicConfig()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="bert-base-cased",
        help="Pretrained embedding model name or path.",
    )
    parser.add_argument(
        "--clustering_model",
        default="faiss",
        choices=clustering_model_choices,
        type=str,
        help="Pretrained clustering model name.",
    )
    parser.add_argument(
        "--clustering_model_path",
        type=str,
        # Required: there is no default, and main() passes this straight to
        # the discretizer's load(). Fail fast here rather than after loading
        # the dataset and embedding model.
        required=True,
        help="Path to the pretrained clustering model.",
    )
    parser.add_argument(
        "--dataset",
        default="msmarco-passage/dev/small",
        help="MSMARCO dataset split to evaluate on, e.g., 'msmarco-passage/dev'.",
    )
    parser.add_argument(
        "--index_path",
        # Required: there is no default, and main() opens it with
        # LuceneSearcher. Fail fast here.
        required=True,
        help="Path to the BM25 index created from discretized documents.",
    )
    parser.add_argument(
        "--output_path",
        help="Directory to save the encoded queries, qrels, and scores.",
    )
    parser.add_argument(
        "--limit",
        # NOTE(review): by default only the first 120 queries are evaluated;
        # main() applies no limit when this is -1.
        default=120,
        type=int,
        help="Limit the number of queries to process.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Number of sentences per batch."
    )
    parser.add_argument(
        "--query_formatting",
        default="{query}",
        help="Formatting applied to each query before tokenization.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation ('cuda' or 'cpu').",
    )
    parser.add_argument(
        "--max_length", type=int, default=200, help="Maximum length of input sequences"
    )
    # How the discretized codes are turned into a BM25 query string.
    parser.add_argument(
        "--encoding",
        default="code_unique",
        choices=["code", "code_plus_text", "code_unique"],
        help="Encoding method for the queries.",
    )

    args = parser.parse_args()
    # Pretty-print the parsed arguments.
    for name, value in vars(args).items():
        print(f"{name}: {value}")
    main(args)
