from __future__ import annotations

import json
import logging
import os
from collections import defaultdict
from time import time
from typing import Any

import tqdm
from datasets import Dataset, Features, Value, load_dataset

from mteb.encoder_interface import Encoder

from ..evaluation.evaluators import utils
from ..evaluation.evaluators.InstructionRetrievalEvaluator import (
    InstructionRetrievalEvaluator,
)
from .AbsTask import AbsTask, DescriptiveStatistics
from .AbsTaskRetrieval import HFDataLoader

import random
import numpy as np
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import LuceneIndexer


logger = logging.getLogger(__name__)


# Adapted from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/datasets/data_loader_hf.py#L10
class HFDataLoaderInstructions(HFDataLoader):
    """Loader for instruction-retrieval datasets (Hub repository or local files).

    Besides the corpus and the queries, it loads two sets of relevance
    judgements -- ``og`` (original instruction) and ``changed`` (changed
    instruction) -- and a per-query candidate list (``top_ranked``).
    """

    def __init__(
        self,
        hf_repo: str | None = None,
        hf_repo_qrels: str | None = None,
        data_folder: str | None = None,
        prefix: str | None = None,
        corpus_file: str = "corpus.jsonl",
        query_file: str = "queries.jsonl",
        qrels_folder: str = "qrels",
        qrels_file: str = "",
        streaming: bool = False,
        keep_in_memory: bool = False,
        top_ranked_file: str = "top_ranked.jsonl",
    ):
        """Record where the data lives; nothing is loaded until ``load`` is called.

        Args:
            hf_repo: Hub repository holding the ``queries``, ``top_ranked``,
                ``qrels_og`` and ``qrels_changed`` configurations. When empty,
                local files are used instead.
            hf_repo_qrels: Repository holding the qrels; defaults to ``hf_repo``.
            data_folder: Folder that contains the local files.
            prefix: Optional prefix prepended (with ``-``) to the query file
                and the qrels folder names.
            corpus_file: Local corpus file name (JSON lines).
            query_file: Local query file name (JSON lines).
            qrels_folder: Base name of the local qrels folders; ``load`` reads
                ``<qrels_folder>_og`` and ``<qrels_folder>_changed``.
            qrels_file: Stored as-is; not read by this class.
            streaming: Forwarded to ``load_dataset``.
            keep_in_memory: Forwarded to ``load_dataset``.
            top_ranked_file: Local JSON file with ``qid``/``pid`` columns, used
                by ``load_top_ranked`` when no ``hf_repo`` is given.
        """
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        self.og_instructions = {}
        self.changed_instructions = {}
        self.top_ranked = {}
        self.hf_repo = hf_repo
        if hf_repo:
            # By default fetch qrels from same repo not a second repo with "-qrels" like in original
            self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo
        else:
            # data folder would contain these files:
            # (1) fiqa/corpus.jsonl  (format: jsonlines)
            # (2) fiqa/queries.jsonl (format: jsonlines)
            # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
            if prefix:
                query_file = prefix + "-" + query_file
                qrels_folder = prefix + "-" + qrels_folder

            self.corpus_file = (
                os.path.join(data_folder, corpus_file) if data_folder else corpus_file
            )
            self.query_file = (
                os.path.join(data_folder, query_file) if data_folder else query_file
            )
            # Fall back to the bare folder name (like the two files above);
            # ``None`` here made ``load`` fail on ``None + "_og"``.
            self.qrels_folder = (
                os.path.join(data_folder, qrels_folder) if data_folder else qrels_folder
            )
            self.qrels_file = qrels_file
            # ``load_top_ranked`` reads this attribute on the local-file path.
            self.top_ranked_file = (
                os.path.join(data_folder, top_ranked_file)
                if data_folder
                else top_ranked_file
            )
        self.streaming = streaming
        self.keep_in_memory = keep_in_memory

    @staticmethod
    def _qrels_to_dict(qrels_ds) -> dict[str, dict[str, int]]:
        """Collect a qrels dataset into ``{query_id: {corpus_id: score}}``."""
        nested = defaultdict(dict)

        def collect(row):
            nested[row["query-id"]][row["corpus-id"]] = int(row["score"])

        # ``map`` is used purely to visit every row; ``collect`` returns None.
        qrels_ds.map(collect)
        return nested

    def load(
        self, split="test"
    ) -> tuple[
        Dataset,
        Dataset,
        dict[str, dict[str, int]],
        dict[str, dict[str, int]],
        Dataset,
    ]:
        """Load everything needed for one split.

        Args:
            split: Name of the split whose qrels are loaded.

        Returns:
            ``(corpus, queries, og_qrels, changed_qrels, top_ranked)``. The two
            qrels are nested dicts ``{query_id: {corpus_id: score}}``; queries
            are restricted to those that have original qrels.
        """
        if not self.hf_repo:
            self.og_qrels_file = os.path.join(self.qrels_folder + "_og", split + ".tsv")
            self.changed_qrels_file = os.path.join(
                self.qrels_folder + "_changed", split + ".tsv"
            )
            self.check(fIn=self.corpus_file, ext="jsonl")
            self.check(fIn=self.query_file, ext="jsonl")
            self.check(fIn=self.og_qrels_file, ext="tsv")
            self.check(fIn=self.changed_qrels_file, ext="tsv")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
            logger.info("Doc Example: %s", self.corpus[0])

        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()

        self._load_qrels(split, changed=False)
        self._load_qrels(split, changed=True)
        # Each dataset is converted with the same helper so the original and
        # changed judgements cannot be crossed over.
        self.og_qrels = self._qrels_to_dict(self.og_qrels)
        self.changed_qrels = self._qrels_to_dict(self.changed_qrels)
        # filter queries with no qrels
        self.queries = self.queries.filter(lambda x: x["id"] in self.og_qrels)
        logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
        logger.info("Query Example: %s", self.queries[0])

        # load top_ranked
        self.load_top_ranked()

        return (
            self.corpus,
            self.queries,
            self.og_qrels,
            self.changed_qrels,
            self.top_ranked,
        )

    def load_top_ranked(self) -> None:
        """Load the candidate list into ``self.top_ranked`` (columns ``qid``, ``pid``)."""
        if self.hf_repo:
            top_ranked_ds = load_dataset(
                self.hf_repo,
                "top_ranked",
                keep_in_memory=self.keep_in_memory,
                streaming=self.streaming,
            )
        else:
            top_ranked_ds = load_dataset(
                "json",
                data_files=self.top_ranked_file,
                streaming=self.streaming,
                keep_in_memory=self.keep_in_memory,
            )
        top_ranked_ds = next(iter(top_ranked_ds.values()))  # get first split
        top_ranked_ds = top_ranked_ds.cast_column("qid", Value("string"))
        top_ranked_ds = top_ranked_ds.cast_column("pid", Value("string"))
        top_ranked_ds = top_ranked_ds.remove_columns(
            [col for col in top_ranked_ds.column_names if col not in ["qid", "pid"]]
        )
        self.top_ranked = top_ranked_ds

    def _load_queries(self):
        """Load the queries into ``self.queries``, keeping only the known columns."""
        if self.hf_repo:
            queries_ds = load_dataset(
                self.hf_repo,
                "queries",
                keep_in_memory=self.keep_in_memory,
                streaming=self.streaming,
            )
        else:
            queries_ds = load_dataset(
                "json",
                data_files=self.query_file,
                streaming=self.streaming,
                keep_in_memory=self.keep_in_memory,
            )
        queries_ds = next(iter(queries_ds.values()))  # get first split
        queries_ds = queries_ds.cast_column("_id", Value("string"))
        queries_ds = queries_ds.rename_column("_id", "id")
        kept_columns = [
            "id",
            "text",
            "instruction_og",
            "instruction_changed",
            "keywords",
            "short_query",
        ]
        queries_ds = queries_ds.remove_columns(
            [col for col in queries_ds.column_names if col not in kept_columns]
        )
        self.queries = queries_ds

    def _load_qrels(self, split, changed=False):
        """Load one qrels set into ``self.og_qrels`` or ``self.changed_qrels``.

        Args:
            split: Split selected from the Hub dataset (unused for local files,
                whose path already encodes the split).
            changed: Load the ``changed`` judgements instead of the originals.
        """
        if self.hf_repo:
            qrels_ds = load_dataset(
                self.hf_repo_qrels,
                "qrels_og" if not changed else "qrels_changed",
                keep_in_memory=self.keep_in_memory,
                streaming=self.streaming,
            )[split]
        else:
            qrels_file = self.og_qrels_file if not changed else self.changed_qrels_file
            qrels_ds = load_dataset(
                "csv",
                data_files=qrels_file,
                delimiter="\t",
                keep_in_memory=self.keep_in_memory,
            )
        features = Features(
            {
                "query-id": Value("string"),
                "corpus-id": Value("string"),
                "score": Value("float"),
            }
        )
        qrels_ds = qrels_ds.cast(features)

        if changed:
            self.changed_qrels = qrels_ds
        else:
            self.og_qrels = qrels_ds


class InstructionRetrievalDescriptiveStatistics(DescriptiveStatistics):
    """Descriptive statistics for Instruction Retrieval tasks.

    The values are produced by
    ``AbsTaskInstructionRetrieval._calculate_metrics_from_split``, which
    measures every length with ``len()`` on the stored text values.

    Attributes:
        num_queries: Number of queries
        num_docs: Number of documents
        average_document_length: Average length of documents (title plus text)
        average_query_length: Average length of queries
        average_instruction_length: Average length of the original instructions
        average_changed_instruction_length: Average length of changed instructions
        average_relevant_docs_per_query: Average number of relevant docs
            (non-zero original judgements) per query
        average_top_ranked_per_query: Average number of top ranked docs per query
    """

    num_queries: int
    num_docs: int
    average_document_length: float
    average_query_length: float
    average_instruction_length: float
    average_changed_instruction_length: float
    average_relevant_docs_per_query: float
    average_top_ranked_per_query: float


class AbsTaskInstructionRetrieval(AbsTask):
    """Abstract class for retrieval tasks that use instructions. An example from Core17 would be
        query: What is the ongoing status of The Three Gorges Project?
        instruction: A relevant document will provide the projected or actual date of completion of the project, its estimated or actual total cost, or the estimated or ongoing electrical output of the finished project. Discussions of the social, political, or ecological impact of the project are not relevant.

    Child-classes must implement the following properties:
    self.corpus = Dict[corpus_id, Dict[str, str]] #id => dict with document datas like title and text
    self.queries = Dict[query_id, str] #id => query
    self.relevant_docs = Dict[query_id, Dict[corpus_id, int]]
    self.og_instructions = Dict[str, str] query => original instruction
    self.changed_instructions = Dict[str, str] query => changed instruction
    self.top_ranked = Dict[query_id, List[corpus_id]] #id => list of top ranked document ids

    See https://arxiv.org/abs/2403.15246 for more details
    """

    def __init__(
        self, index_data_path=None, custom_index_dir=None, random_ic_prompt=False, doc_only=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_length_ablation = kwargs.get("do_length_ablation", False)
        if self.do_length_ablation:
            logger.info("Running length ablation also...")
        self.index_data_path = index_data_path
        self.custom_index_dir = custom_index_dir
        self.random_ic_prompt = random_ic_prompt
        self.doc_only = doc_only

    def load_data(self, **kwargs):
        if self.data_loaded:
            return
        self.corpus, self.queries, self.og_relevant_docs, self.changed_relevant_docs = (
            {},
            {},
            {},
            {},
        )
        self.og_instructions, self.changed_instructions = {}, {}
        self.top_ranked = {}
        if self.do_length_ablation:
            self.keywords, self.short_instructions = {}, {}

        dataset_path = self.metadata_dict["dataset"]["path"]
        hf_repo_qrels = (
            dataset_path + "-qrels" if "clarin-knext" in dataset_path else None
        )
        for split in kwargs.get("eval_splits", self.metadata_dict["eval_splits"]):
            (
                corpus,
                queries,
                og_relevant_docs,
                changed_relevant_docs,
                top_ranked_init,
            ) = HFDataLoaderInstructions(
                hf_repo=dataset_path,
                hf_repo_qrels=hf_repo_qrels,
                streaming=False,
                keep_in_memory=False,
            ).load(split=split)

            # Conversion from DataSet
            top_ranked = defaultdict(list)
            [
                top_ranked[cur_inst["qid"]].append(cur_inst["pid"])
                for cur_inst in top_ranked_init
            ]
            og_instructions = {
                query["text"]: query["instruction_og"] for query in queries
            }
            changed_instructions = {
                query["text"]: query["instruction_changed"] for query in queries
            }
            if self.do_length_ablation:
                keywords = {query["text"]: query["keywords"] for query in queries}
                short_instructions = {
                    query["text"]: query["short_query"] for query in queries
                }
            
            new_queries = {query["id"]: query["text"] for query in queries}
            if self.n_ic_examples > 0:
                index_dir = os.path.join(os.environ.get("TRANSFORMERS_CACHE", "temp_index_eval"), "temp_index_eval", dataset_path.split("/")[-1])
                new_queries = self.encode_queries_with_ic_data(new_queries, index_dir, **kwargs)
            
            og_instructions = {
                new_queries[query["id"]]: query["instruction_og"] for query in queries
            }
            changed_instructions = {
                new_queries[query["id"]]: query["instruction_changed"] for query in queries
            }
            queries = new_queries

            corpus = {
                doc["id"]: {"title": doc["title"], "text": doc["text"]}
                for doc in corpus
            }
            assert (
                len(top_ranked) == len(queries)
            ), f"Top ranked not loaded properly! Expected {len(self.queries)} but got {len(self.top_ranked)}."

            (
                self.corpus[split],
                self.queries[split],
                self.og_relevant_docs[split],
                self.changed_relevant_docs[split],
            ) = corpus, queries, og_relevant_docs, changed_relevant_docs
            self.changed_instructions[split], self.og_instructions[split] = (
                changed_instructions,
                og_instructions,
            )
            self.top_ranked[split] = top_ranked

            if self.do_length_ablation:
                self.keywords[split], self.short_instructions[split] = (
                    keywords,
                    short_instructions,
                )

        self.data_loaded = True

    def _evaluate_subset_lang(
        self,
        retriever: InstructionRetrievalEvaluator,
        corpus: dict,
        queries: dict,
        og_relevant_docs: dict,
        changed_relevant_docs: dict,
        og_instructions: dict,
        changed_instructions: dict,
        top_ranked: dict,
        lang: str,
        split: str,
        keywords: dict | None = None,
        short_instructions: dict | None = None,
        **kwargs,
    ) -> dict[str, dict[str, float] | float]:
        corpus, queries = corpus[split], queries[split]
        og_relevant_docs, changed_relevant_docs = (
            og_relevant_docs[split],
            changed_relevant_docs[split],
        )
        og_instructions, changed_instructions = (
            og_instructions[split],
            changed_instructions[split],
        )

        top_ranked = top_ranked[split]
        scores_og, results_og = self._evaluate_subset(
            retriever,
            corpus,
            queries,
            og_relevant_docs,
            og_instructions,
            top_ranked,
            lang,
            **kwargs,
        )
        scores_changed, results_changed = self._evaluate_subset(
            retriever,
            corpus,
            queries,
            changed_relevant_docs,
            changed_instructions,
            top_ranked,
            lang,
            **kwargs,
        )

        newly_irrelevant_qrels = self.create_qrel_diff(
            og_relevant_docs,
            changed_relevant_docs,
        )
        overall_changed_scores = utils.evaluate_change(
            results_og, results_changed, newly_irrelevant_qrels
        )

        overall_changed_scores["individual"] = {
            "original": scores_og,
            "changed": scores_changed,
        }

        if self.do_length_ablation:
            keywords, short_instructions = (
                keywords[split],
                short_instructions[split],
            )
            scores_base, results_base = self._evaluate_subset(
                retriever,
                corpus,
                queries,
                og_relevant_docs,
                defaultdict(str),
                top_ranked,
                lang,
                **kwargs,
            )
            scores_w_keywords_scores, scores_w_keywords_results = self._evaluate_subset(
                retriever,
                corpus,
                queries,
                og_relevant_docs,
                keywords,
                top_ranked,
                lang,
                **kwargs,
            )
            (
                scores_w_short_instr_scores,
                scores_w_short_instr_result,
            ) = self._evaluate_subset(
                retriever,
                corpus,
                queries,
                og_relevant_docs,
                short_instructions,
                top_ranked,
                lang,
                **kwargs,
            )
            overall_changed_scores["length_ablation"] = {
                "keywords": scores_w_keywords_scores,
                "short_instructions": scores_w_short_instr_scores,
                "base": scores_base,
            }

        return overall_changed_scores

    def evaluate(
        self,
        model: Encoder,
        split: str = "test",
        *,
        encode_kwargs: dict[str, Any] = {},
        **kwargs,
    ) -> dict[str, dict[str, Any]]:
        retriever = InstructionRetrievalEvaluator(
            retriever=model,
            task_name=self.metadata.name,
            encode_kwargs=encode_kwargs,
            **kwargs,
        )

        scores = {}
        if self.is_multilingual:
            for lang in self.hf_subsets:
                logger.info(f"Language: {lang}")
                scores[lang] = self._evaluate_subset_lang(
                    retriever,
                    corpus=self.corpus[lang],
                    queries=self.queries[lang],
                    og_relevant_docs=self.og_relevant_docs[lang],
                    changed_relevant_docs=self.changed_relevant_docs[lang],
                    og_instructions=self.og_instructions[lang],
                    changed_instructions=self.changed_instructions[lang],
                    top_ranked=self.top_ranked[lang],
                    lang=lang,
                    split=split,
                    keywords=self.keywords[lang] if self.do_length_ablation else None,
                    short_instructions=self.short_instructions[lang]
                    if self.do_length_ablation
                    else None,
                    **kwargs,
                )
                self._add_main_score(scores[lang])
        else:
            lang = "default"
            scores[lang] = self._evaluate_subset_lang(
                retriever,
                corpus=self.corpus,
                queries=self.queries,
                og_relevant_docs=self.og_relevant_docs,
                changed_relevant_docs=self.changed_relevant_docs,
                og_instructions=self.og_instructions,
                changed_instructions=self.changed_instructions,
                top_ranked=self.top_ranked,
                lang=lang,
                split=split,
                keywords=self.keywords if self.do_length_ablation else None,
                short_instructions=self.short_instructions
                if self.do_length_ablation
                else None,
                **kwargs,
            )
            self._add_main_score(scores[lang])

        return scores

    def _add_main_score(self, scores: dict[str, dict[str, float]]) -> None:
        scores["main_score"] = scores[self.metadata.main_score]

    def _evaluate_subset(
        self,
        retriever: InstructionRetrievalEvaluator,
        corpus: dict[str, dict[str, str]],
        queries: dict[str, str],
        relevant_docs: dict[str, dict[str, int]],
        instructions: dict[str, str],
        top_ranked: dict[str, list[str]],
        lang=None,
        **kwargs,
    ) -> tuple[dict[str, float], dict[str, dict[str, float]]]:
        start_time = time()

        # do the results by query and relevant docs only
        all_results = []
        for query_id in tqdm.tqdm(list(queries.keys()), leave=False, desc="Retrieving"):
            cur_queries = {query_id: queries[query_id]}
            cur_instructions = {queries[query_id]: instructions[queries[query_id]]}
            cur_docs = {
                key: value
                for (key, value) in corpus.items()
                if key in top_ranked[query_id]
            }
            all_results.append(
                retriever(
                    cur_docs, cur_queries, instructions=cur_instructions, qid=query_id
                )
            )

        # combine all the results (which are {'qid' -> {'doc_id' -> score} mappings)
        # we know all are unique qids, so we can smash together
        results = {k: v for d in all_results for k, v in d.items()}

        end_time = time()
        logger.info(f"Time taken to retrieve: {end_time - start_time:.2f} seconds")

        if kwargs.get("save_predictions", False):
            output_folder = kwargs.get("output_folder", "results")
            if not os.path.isdir(output_folder):
                os.makedirs(output_folder)
            top_k = kwargs.get("top_k", None)
            if top_k is not None:
                for qid in list(results.keys()):
                    doc_ids = set(
                        sorted(
                            results[qid], key=lambda x: results[qid][x], reverse=True
                        )[:top_k]
                    )
                    results[qid] = {
                        k: v for k, v in results[qid].items() if k in doc_ids
                    }
            if lang is None:
                qrels_save_path = (
                    f"{output_folder}/{self.metadata_dict['name']}_predictions.json"
                )
            else:
                qrels_save_path = f"{output_folder}/{self.metadata_dict['name']}_{lang}_predictions.json"

            with open(qrels_save_path, "w") as f:
                json.dump(results, f)

        ndcg, _map, recall, precision, naucs = retriever.evaluate(
            relevant_docs,
            results,
            retriever.k_values,
            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
        )
        mrr, naucs = retriever.evaluate_custom(
            relevant_docs, results, retriever.k_values, "mrr"
        )
        scores = {
            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
            **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
        }
        return scores, results

    def create_qrel_diff(self, og_qrels, changed_qrels):
        newly_irrelevant_qrels = {}
        for qid in og_qrels:
            newly_irrelevant_qrels[qid] = []
            for doc_id in og_qrels[qid]:
                if changed_qrels[qid][doc_id] != og_qrels[qid][doc_id]:
                    newly_irrelevant_qrels[qid].append(doc_id)

        return newly_irrelevant_qrels

    def _calculate_metrics_from_split(
        self, split: str, hf_subset: str | None = None, compute_overall: bool = False
    ) -> InstructionRetrievalDescriptiveStatistics:
        if hf_subset:
            corpus = self.corpus[hf_subset][split]
            queries = self.queries[hf_subset][split]
            relevant_docs = self.og_relevant_docs[hf_subset][split]
            og_instructions = self.og_instructions[hf_subset][split]
            changed_instructions = self.changed_instructions[hf_subset][split]
            top_ranked = self.top_ranked[hf_subset][split]
        elif compute_overall:
            corpus = {}
            queries = {}
            relevant_docs = {}
            og_instructions = {}
            changed_instructions = {}
            top_ranked = {}
            for hf_subset in self.metadata.eval_langs:
                corpus.update(process_docs(self.corpus, hf_subset, split))
                queries.update(process_docs(self.queries, hf_subset, split))
                relevant_docs.update(
                    process_relevant_docs(self.og_relevant_docs, hf_subset, split)
                )
                og_instructions.update(
                    process_docs(
                        self.og_instructions,
                        hf_subset,
                        split,
                    )
                )
                changed_instructions.update(
                    process_docs(self.changed_instructions, hf_subset, split)
                )
                top_ranked.update(process_top_ranked(self.top_ranked, hf_subset, split))
        else:
            corpus = self.corpus[split]
            queries = self.queries[split]
            relevant_docs = self.og_relevant_docs[split]
            og_instructions = self.og_instructions[split]
            changed_instructions = self.changed_instructions[split]
            top_ranked = self.top_ranked[split]

        total_corpus_len = sum(
            [len(doc.get("title", "")) + len(doc["text"]) for doc in corpus.values()]
        )
        total_queries_len = sum([len(query) for query in queries.values()])
        total_instructions_len = sum(
            [len(instruction) for instruction in og_instructions.values()]
        )
        total_changed_instructions_len = sum(
            [len(instruction) for instruction in changed_instructions.values()]
        )
        num_qrels_non_zero = sum(
            sum(1 for doc_id in docs if docs[doc_id] != 0)
            for docs in relevant_docs.values()
        )
        qrels_per_doc = num_qrels_non_zero / len(relevant_docs) if len(queries) else 0
        top_ranked_per_query = (
            sum(len(docs) for docs in top_ranked.values()) / len(queries)
            if len(queries)
            else 0
        )
        return InstructionRetrievalDescriptiveStatistics(
            num_docs=len(corpus),
            num_queries=len(queries),
            average_document_length=(
                total_corpus_len / len(corpus) if len(corpus) else 0
            ),
            average_query_length=(
                total_queries_len / len(queries) if len(queries) else 0
            ),
            average_instruction_length=(
                total_instructions_len / len(queries) if len(queries) else 0
            ),
            average_changed_instruction_length=(
                total_changed_instructions_len / len(queries) if len(queries) else 0
            ),
            average_relevant_docs_per_query=qrels_per_doc,
            average_top_ranked_per_query=top_ranked_per_query,
        )

    def encode_queries_with_ic_data(self, queries, index_dir, **kwargs):
        """
        Encode queries with in-context data.

        Args:
        queries: list of queries [str]
        ic_data: list of in-context data [list of dicts]
        searcher: SimpleSearcher instance with pre-encoded index
        top_n: number of similar in-context examples to select

        Returns:
        encoded_queries: list of encoded queries [np.ndarray]
        """
        ic_corpus, ic_queries, relevant_docs = {}, {}, {}
        dataset_path = self.metadata_dict["dataset"]["path"]
        hf_repo_qrels = (
            dataset_path + "-qrels" if "clarin-knext" in dataset_path else None
        )
        
        dataset = load_dataset("jhu-clsp/FollowIR-train", split="train")

        ic_data = self.construct_ic_dataset(dataset)
        
        if self.random_ic_prompt:
            print("Using random/fixed in-context prompts...")
            encoded_queries = self.process_ic_data(queries, ic_data, searcher=None, top_n=self.n_ic_examples)
        else:
            searcher = self.encode_ic_data(ic_data, index_dir=index_dir)
            encoded_queries = self.process_ic_data(queries, ic_data, searcher=searcher, top_n=self.n_ic_examples)

        return encoded_queries
    
    def construct_ic_dataset(self, dataset):
        ic_data = {}
        for row in dataset:
            query = row['query']
            instruction = row['instruction']
            ic_example = {}
            ic_example["id"] = str(row["id"])
            ic_example["positive_ctxs"] = []
            max_rel_doc = {}
            max_rel_doc["title"] = ''
            max_rel_doc["text"] = row["document"]
            ic_example["positive_ctxs"].append(max_rel_doc)
            ic_example["negative_ctxs"] = []

            min_rel_row = random.choice(dataset)
            min_rel_doc = {}
            min_rel_doc["title"] = ''
            min_rel_doc["text"] = min_rel_row['document']
            ic_example["negative_ctxs"].append(min_rel_doc)
            ic_example["question"] = "Query: " + query + ' ' + instruction
            ic_data[str(row["id"])] = ic_example 
        
        return ic_data
    
    def encode_ic_data(self, ic_data, index_dir='temp_index'):
        """
        Encode the in-context data using pyserini and keep the index in memory.

        Args:
        ic_data: list of in-context data [list of dicts]

        Returns:
        searcher: SimpleSearcher instance with in-memory index
        ic_questions: List of questions from the in-context data
        """
        #ic_questions = [item["question"] for item in ic_data]
        
        documents = [{"id": f"doc{idx}", "contents": item["question"]} for idx, item in ic_data.items()]
        if os.path.exists(index_dir):
            searcher = SimpleSearcher(index_dir)
        else:
            os.makedirs(index_dir, exist_ok=True)
            max_wait_time = 60  # Maximum wait time in seconds

            start_time = time()

            # Loop to handle write lock with timeout
            while True:
                try:
                    # Attempt to create an indexer
                    indexer = LuceneIndexer(index_dir=index_dir, threads=1)
                    
                    # Add documents to the indexer
                    indexer.add_batch_dict(documents)
                    
                    # Close the indexer after adding documents
                    indexer.close()
                    
                    break  # Exit loop if successful
                except Exception as e:
                    # Check if the error is due to the write lock
                    if "Lock obtain timed out" in str(e):
                        elapsed_time = time() - start_time
                        if elapsed_time >= max_wait_time:
                            raise TimeoutError(f"Exceeded maximum wait time of {max_wait_time} seconds for lock release.")
                        print("Write lock detected, waiting for release...")
                        time_module.sleep(1)  # Wait for 1 second before retrying
                    else:
                        raise e  # Raise other exceptions

            # Proceed to search after the index is updated
            searcher = SimpleSearcher(index_dir)
        
        return searcher

    def process_ic_data(self, queries, ic_data, searcher=None, top_n=5, qrels=None):
        """
        Add in-context data to the queries based on similarity using BM25.

        Args:
        queries: dict of queries [dict]
        ic_data: dict of in-context data [dict of dicts]
        searcher: SimpleSearcher instance with pre-encoded index
        top_n: number of similar in-context examples to select

        Returns:
        new_queries: list of queries with context data added
        """
        new_queries = {}
        prefix = ""
        for key, query in queries.items():
            new_query = prefix
            if top_n > 0:
                if searcher is None:
                    keys = list(ic_data.keys())
                    random.shuffle(keys)
                    top_idxs = keys[:top_n]
                else:
                    hits = searcher.search(query, k=top_n+100)
                    hits = [hit for hit in hits if hit.docid.split('doc')[-1]]
                    top_idxs = [hit.docid.split('doc')[-1] for hit in hits]
                    if qrels is not None:
                        relevant_docs = qrels[key]
                        relevant_docs = [doc for doc in relevant_docs.keys()]
                        doc_ids = {idx: ic_data[idx]["id"] for idx in top_idxs}
                        top_idxs_filtered = [idx for idx in top_idxs if doc_ids[idx] not in relevant_docs]
                        print(f"Initial: {len(top_idxs)}")
                        print(f"Diff: {len(top_idxs) - len(top_idxs_filtered)}")
                        top_idxs = top_idxs_filtered
                
                top_idxs = top_idxs[:top_n]
                top_idxs = top_idxs[::-1]
                for k, idx in enumerate(top_idxs):
                    idx = str(idx)
                    question = ic_data[idx]["question"]
                    positive_ctx = ic_data[idx]["positive_ctxs"][0]["title"] + ". " + ic_data[idx]["positive_ctxs"][0]["text"]
                    if k == 0:
                        new_query += f"{question}\nPositive Document: {positive_ctx}\n"
                    else:
                        if self.doc_only:
                            new_query += f"Positive Document: {positive_ctx}\n"
                        else:
                            new_query += f"Query: {question}\nPositive Document: {positive_ctx}\n"
            new_query += "Query: " + query
            new_queries[key] = new_query
        return new_queries


def process_docs(
    collection: dict[str, dict[str, dict[str, str]]], hf_subset: str, split: str
) -> dict[str, str]:
    """Return one subset/split of ``collection`` with globally unique keys.

    Ids can repeat across splits and subsets, so each key is prefixed with
    ``<split>_<hf_subset>_``.
    """
    prefix = f"{split}_{hf_subset}_"
    entries = collection[hf_subset][split]
    return {f"{prefix}{key}": value for key, value in entries.items()}


def process_relevant_docs(
    collection: dict[str, dict[str, dict[str, dict[str, int]]]],
    hf_subset: str,
    split: str,
) -> dict[str, dict[str, int]]:
    """Return one subset/split of the judgements with globally unique ids.

    Ids can repeat across splits and subsets, so both the query ids and the
    document ids are prefixed with ``<split>_<hf_subset>_``.
    """
    prefix = f"{split}_{hf_subset}_"
    return {
        f"{prefix}{query_id}": {
            f"{prefix}{doc_id}": value for doc_id, value in relevant.items()
        }
        for query_id, relevant in collection[hf_subset][split].items()
    }


def process_top_ranked(
    collection: dict[str, dict[str, dict[str, list[str]]]], hf_subset: str, split: str
) -> dict[str, list[str]]:
    """Return one subset/split of the candidate lists with globally unique ids.

    Both the query ids and the listed document ids are prefixed with
    ``<split>_<hf_subset>_``; the order of each list is kept.
    """
    prefix = f"{split}_{hf_subset}_"
    return {
        f"{prefix}{query_id}": [f"{prefix}{doc_id}" for doc_id in docs_id]
        for query_id, docs_id in collection[hf_subset][split].items()
    }
