import pyterrier as pt

# Start PyTerrier once at import time; skipped when it has already been started.
if not pt.started():
    pt.init()
import ir_datasets as irds
import pandas as pd
from fire import Fire
import os
import logging
import glob


def load_bi_encoder(
    checkpoint: str = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco",
    batch_size: int = 64,
    **kwargs,
):
    """Build a verbose ``BiScorer`` around a Hugging Face bi-encoder checkpoint.

    Args:
        checkpoint: Name or path passed to ``from_pretrained`` for both the
            model and its tokenizer.
        batch_size: Batch size forwarded to ``BiScorer``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        A ``BiScorer`` wrapping an ``HgfBiEncoder`` whose model has been moved
        to CUDA and switched to eval mode.
    """
    # Heavy dependencies are imported here so they are only loaded on use.
    from transformers import AutoModel, AutoTokenizer
    from pyterrier_dr import HgfBiEncoder, BiScorer

    encoder_model = AutoModel.from_pretrained(checkpoint)
    encoder_model = encoder_model.cuda()
    encoder_model = encoder_model.eval()
    encoder_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # The encoder is told to use whatever device the model ended up on.
    bi_encoder = HgfBiEncoder(
        encoder_model,
        encoder_tokenizer,
        {},
        device=encoder_model.device,
    )
    scorer = BiScorer(bi_encoder, batch_size=batch_size, verbose=True)
    return scorer


def load_cross_encoder(
    checkpoint: str = "cross-encoder/ms-marco-TinyBERT-L-6",
    batch_size: int = 64,
    **kwargs,
):
    """Load a verbose ``CatTransformer`` cross-encoder from a checkpoint.

    Args:
        checkpoint: Name or path passed to ``CatTransformer.from_pretrained``.
        batch_size: Batch size forwarded to ``from_pretrained``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The object returned by ``CatTransformer.from_pretrained``.
    """
    # Imported here so the dependency is only loaded when a cross-encoder is requested.
    from rankers import CatTransformer

    loader_options = {"batch_size": batch_size, "verbose": True}
    return CatTransformer.from_pretrained(checkpoint, **loader_options)


def load_run(system_name, dataset_id):
    """Read the run file stored for a system/dataset pair.

    The file is looked up at ``data/beir.run.<system_name>.<dataset_id>.txt``
    relative to the current working directory and parsed with
    ``pt.io.read_results``.
    """
    return pt.io.read_results(f"data/beir.run.{system_name}.{dataset_id}.txt")


def get_latest_checkpoint(model_name_or_path):
    """Return the newest complete ``checkpoint-<N>`` directory under a path.

    A checkpoint is a sub-directory whose name starts with ``checkpoint-`` and
    ends, after the last ``-``, in an integer. It counts as complete when it
    contains both ``config.json`` and ``model.safetensors``. Candidates are
    tried from the highest number downwards and the first complete one wins.

    Args:
        model_name_or_path: Directory expected to contain the checkpoint
            sub-directories.

    Returns:
        The path of the selected checkpoint directory, or the message
        ``"Model not found at specified path <path>!"`` when the path cannot
        be listed or holds no complete checkpoint. The message (rather than an
        exception) is kept because callers detect failure from it.
    """
    not_found = f"Model not found at specified path {model_name_or_path}!"

    try:
        entries = os.listdir(model_name_or_path)
    except OSError:
        # Missing path, plain file or unreadable directory: report it the same
        # way as "no checkpoint" instead of crashing the caller.
        return not_found

    candidates = []
    for entry in entries:
        if not entry.startswith("checkpoint-"):
            continue
        try:
            step = int(entry.split("-")[-1])
        except ValueError:
            # e.g. "checkpoint-best" or "checkpoint-": not a numbered
            # checkpoint, so it cannot be ordered and is ignored.
            continue
        entry_path = os.path.join(model_name_or_path, entry)
        if os.path.isdir(entry_path):
            candidates.append((step, entry_path))

    # Highest checkpoint number first.
    candidates.sort(key=lambda item: item[0], reverse=True)

    for _, checkpoint_path in candidates:
        has_config = os.path.exists(os.path.join(checkpoint_path, "config.json"))
        has_weights = os.path.exists(os.path.join(checkpoint_path, "model.safetensors"))
        if has_config and has_weights:
            return checkpoint_path

    return not_found


def run_topics(
    file: str,
    model_name_or_path: str,
    out_path: str,
    batch_size: int = 256,
    cat: bool = False,
    dont_overwrite: bool = False,
    depth: int = 100,
):
    """Re-score the top results of a BEIR run file with a neural model.

    The run file name must follow ``data/beir.run.<system_name>.<dataset_id>.txt``;
    the system name and dataset id are taken from its dot-separated parts.
    Query and document texts are attached from the matching ``beir/...``
    dataset in ``ir_datasets``, the model scores the pairs, and the result is
    written to ``<out_path>/beir_<dataset>_<model>_<system>.res.gz``.

    Args:
        file: Path of the run file to re-score.
        model_name_or_path: Directory containing ``config.json``, or a
            directory of ``checkpoint-<N>`` sub-directories. In the second
            case the latest complete checkpoint is used and the last component
            of ``out_path`` is prefixed with ``tmp-<N>-``.
        out_path: Directory the result file is written to; created if missing.
        batch_size: Batch size handed to the model loader.
        cat: Load a cross-encoder when true, a bi-encoder otherwise.
        dont_overwrite: When true, return without doing any work if the
            output file already exists. When false, existing output is
            replaced.
        depth: Number of top-scored documents kept per query before scoring.

    Returns:
        A status message when results were written, when no model was found,
        or when the model failed to load; ``None`` when existing output was
        kept or the dataset or its queries could not be loaded.
    """
    if not os.path.exists(f"{model_name_or_path}/config.json"):
        # No model at the top level: fall back to the most recent checkpoint.
        checkpoint_path = get_latest_checkpoint(model_name_or_path)
        # The helper returns a message instead of a path on failure; a real
        # checkpoint is always an existing directory.
        if not os.path.isdir(checkpoint_path):
            return f"Model not found at specified path {model_name_or_path}!"
        model_name_or_path = checkpoint_path

        # Keep intermediate-checkpoint output apart from final-model output.
        output_directory = os.path.dirname(out_path)
        basename = os.path.basename(out_path)
        checkpoint_number = model_name_or_path.split("-")[-1]
        out_path = os.path.join(output_directory, f"tmp-{checkpoint_number}-{basename}")

    print(f"Processing {file}")
    # format is data/beir.run.{system_name}.{dataset_id}.txt
    name_parts = file.split(".")
    system_name = name_parts[2]
    dataset_id = name_parts[3]
    formatted_dataset = dataset_id.replace("/", "-")
    # List of all BEIR datasets that include a test split
    _beir_with_test = {
        "dbpedia-entity",
        "fever",
        "fiqa",
        "hotpotqa",
        "nfcorpus",
        "quora",
        "scifact",
    }

    if dataset_id in _beir_with_test:
        dataset_id = f"{dataset_id}/test"

    # normpath drops a trailing slash, which would otherwise give an empty
    # model name and make different models share one output file.
    model_name = os.path.basename(os.path.normpath(model_name_or_path))
    out_file = os.path.join(out_path, f"beir_{formatted_dataset}_{model_name}_{system_name}.res.gz")
    print(f"Output file: {out_file}")
    if 'cqadupstack' in dataset_id:
        # File names use "cqadupstack-<forum>"; the dataset id uses "/".
        dataset_id = '/'.join(dataset_id.split('-'))

    # Cheap check first, so nothing is loaded when the output is being kept.
    if dont_overwrite and os.path.exists(out_file):
        logging.info(f"File already exists at {out_file}")
        return

    try:
        dataset = irds.load(f"beir/{dataset_id}")
    except Exception as e:
        logging.error(f"Error loading dataset {dataset_id}: {e}")
        return

    run = pt.io.read_results(file)
    # Keep only the `depth` highest-scored documents of every query.
    run = run.sort_values(["qid", "score"], ascending=[True, False]).groupby("qid").head(depth)
    try:
        queries = (
            pd.DataFrame(dataset.queries_iter()).set_index("query_id").text.to_dict()
        )
    except Exception as e:
        logging.error(f"Error loading queries for dataset {dataset_id}: {e}")
        return
    run["query"] = run["qid"].map(lambda qid: queries[qid])
    docstore = dataset.docs_store()
    run['text'] = run['docno'].map(lambda docno: docstore.get(docno).text)

    try:
        model = (
            load_bi_encoder(model_name_or_path, batch_size=batch_size)
            if not cat
            else load_cross_encoder(model_name_or_path, batch_size=batch_size)
        )
    except Exception as e:
        logging.error(f"Error loading model {model_name_or_path}: {e}")
        return f"Error loading model: {e}"

    res = model.transform(run)

    # Create the target directory so the write cannot fail after inference.
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    pt.io.write_results(res, out_file)

    logging.info(f"Results written to {out_file}")
    return f"Results written to {out_file}"


if __name__ == "__main__":
    # Configure the root logger at INFO level so the logging.info messages
    # emitted by run_topics are not filtered out.
    logging.basicConfig(level=logging.INFO)
    # Expose run_topics as the command-line entry point via python-fire.
    Fire(run_topics)
