from typing import Optional, Dict
from datasets import load_dataset
from dataclasses import dataclass

from utils.string_utils import truncate
from _datasets.utils import prepare_dataset


@dataclass
class DatasetConfig:
    """Describe which supported Hugging Face dataset to load and how much of it.

    Attributes:
        name: Key into the table of supported datasets in ``get_dataset``
            (for example ``"reddit"`` or ``"paul_graham"``).
        num_examples: Keep only the first ``num_examples`` rows. ``None`` or
            ``0`` keeps every row.
        split: Split to load. When unset, the dataset's own default split from
            the table is used, falling back to ``"train"``.
    """

    name: str
    num_examples: Optional[int] = None
    split: Optional[str] = None

    def get_dataset(self, to_truncate=False, max_length=0):
        """Load the configured dataset, prepare it, and optionally trim it.

        The dataset is loaded with ``load_dataset``, converted with
        ``to_pandas()``, and passed through ``prepare_dataset``. The result is
        then cut down to ``num_examples`` rows (if set) and, when requested,
        every value of the ``"original"`` column is passed through
        ``truncate(value, max_length)``.

        Args:
            to_truncate: When true, truncate the ``"original"`` column.
            max_length: Length handed to ``truncate``; must be non-zero when
                ``to_truncate`` is true.

        Returns:
            The prepared dataset as returned by ``prepare_dataset``
            (presumably a pandas DataFrame with an ``"original"`` column --
            confirm against ``prepare_dataset``).

        Raises:
            ValueError: If ``name`` is not a supported dataset, or if
                ``to_truncate`` is true while ``max_length`` is ``0``.
        """
        # Keyword arguments for ``load_dataset`` per supported dataset. An
        # optional "split" entry is that dataset's default split, not a
        # keyword to forward as-is.
        hf_datasets: Dict[str, Dict] = {
            "scientific_papers": {
                "path": "scientific_papers",
                "name": "pubmed",
                "trust_remote_code": True,
            },
            "wikipedia": {  # This dataset seems too long for our purposes
                "path": "wikipedia",
                "name": "20220301.en",
            },
            "paul_graham": {"path": "sgoel9/paul_graham_essays"},
            "amazon_polarity": {
                "path": "amazon_polarity",
            },
            "arxiv-clustering-p2p": {
                "path": "mteb/arxiv-clustering-p2p",
                "split": "test",
            },
            "arguana": {
                "path": "BeIR/arguana",
                "name": "corpus",
                "split": "corpus",
            },
            "sts22": {
                "path": "mteb/sts22-crosslingual-sts",
            },
            "reddit": {
                "path": "mteb/reddit-clustering-p2p",
                "split": "test",
            },
        }

        # Validate everything up front so bad arguments fail before the
        # (potentially slow) dataset load instead of after it.
        if self.name not in hf_datasets:
            raise ValueError(f"Unrecognized Dataset {self.name}")
        if to_truncate and max_length == 0:
            # Raised explicitly rather than asserted so the check still runs
            # under ``python -O``.
            raise ValueError(f"Specify the max length {max_length}")

        kwargs = hf_datasets[self.name].copy()
        # Always remove the table's default split from kwargs, even when
        # ``self.split`` overrides it; otherwise ``split`` would be passed to
        # ``load_dataset`` twice and raise a TypeError.
        default_split = kwargs.pop("split", None)
        split = self.split or default_split or "train"

        ds = load_dataset(**kwargs, split=split).to_pandas()
        ds = prepare_dataset(self.name, ds)

        if self.num_examples:
            # Copy the slice so the column write below targets an independent
            # object instead of a slice of the full frame (avoids pandas'
            # SettingWithCopyWarning; assumes ``prepare_dataset`` returns a
            # DataFrame -- TODO confirm).
            ds = ds[: self.num_examples].copy()
        if to_truncate:
            ds["original"] = ds["original"].apply(lambda x: truncate(x, max_length))
        return ds
