from typing import Dict
from dataclasses import dataclass
from beir.datasets.data_loader import GenericDataLoader
from beir import util as beirUtil
import os
from loguru import logger
@dataclass
class DatasetInfo:
    """Container for one loaded retrieval dataset.

    Attributes:
        corpus: Maps each document id to that document's fields. The
            fields are themselves a string-to-string mapping (BEIR
            documents carry e.g. ``"title"`` and ``"text"``), so this is a
            nested dict rather than a flat id-to-text mapping.
        queries: Maps each query id to the query text.
        qrels: Maps each query id to a ``{document id: relevance}`` dict.
        corpus_ids2index: Maps each document id to an integer index.
    """

    corpus: Dict[str, Dict[str, str]]
    queries: Dict[str, str]
    qrels: Dict[str, Dict[str, int]]
    corpus_ids2index: Dict[str, int]
class BEIRDataLoader:
    """Fetch and load the test split of a supported BEIR dataset.

    Attributes:
        dataset_path: Directory handed to BEIR's ``download_and_unzip`` as
            the output directory. It must already exist when
            ``load_dataset`` is called.
    """

    # Names accepted by ``load_dataset``; anything else raises ValueError.
    # Kept as a tuple so membership uses ``==`` exactly like the original list.
    SUPPORTED_DATASETS = (
        "scifact",
        "nfcorpus",
        "nq",
        "cqadupstack",
        "arguana",
        "scidocs",
        "fiqa",
    )

    # Archive location for a dataset; ``{name}`` is the dataset name.
    DATASET_URL_TEMPLATE = (
        "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{name}.zip"
    )

    def __init__(self, dataset_path: str):
        """Remember the directory where dataset archives are stored.

        Args:
            dataset_path: Directory used as the download/unzip target.
        """
        self.dataset_path = dataset_path

    def load_dataset(self, dataset_name: str) -> DatasetInfo:
        """Load the ``test`` split of ``dataset_name``.

        The archive is fetched through BEIR's ``download_and_unzip`` into
        ``self.dataset_path`` and then read with ``GenericDataLoader``.

        Args:
            dataset_name: One of ``SUPPORTED_DATASETS``.

        Returns:
            A ``DatasetInfo`` holding the corpus, queries, qrels, and a
            mapping from each corpus id to its position in the corpus's
            iteration order.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
            FileNotFoundError: If ``self.dataset_path`` does not exist.
        """
        logger.info(f"Loading dataset: {dataset_name} from path: {self.dataset_path}")

        # Guard clauses: reject bad input before touching the network/disk.
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Dataset {dataset_name} not supported.")
        if not os.path.exists(self.dataset_path):
            raise FileNotFoundError(f"Dataset path {self.dataset_path} does not exist.")

        url = self.DATASET_URL_TEMPLATE.format(name=dataset_name)
        data_path = beirUtil.download_and_unzip(url, self.dataset_path)
        if dataset_name == "cqadupstack":
            # Only the "english" subfolder of this archive is loaded.
            # NOTE(review): presumably the archive holds one folder per
            # sub-forum -- confirm if other subsets are ever needed.
            data_path = os.path.join(data_path, "english")

        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

        # Position of each document id in the corpus's iteration order.
        corpus_ids2index = {doc_id: position for position, doc_id in enumerate(corpus)}

        dataset = DatasetInfo(
            corpus=corpus,
            queries=queries,
            qrels=qrels,
            corpus_ids2index=corpus_ids2index,
        )
        logger.info(f"Dataset loaded successfully: {len(dataset.corpus)} corpus items, {len(dataset.queries)} queries")
        return dataset
