import numpy as np

from evaluation.evaluate_squadshift import get_bm25_context
from datasets import Dataset, load_dataset


class ContextDistractorDataset:
    """Endless, shuffled sampler over a pool of context paragraphs.

    The pool is visited in a random order (shuffled with NumPy's global
    RNG). Once every paragraph has been handed out, the order is reshuffled
    and iteration starts over, so sampling never runs dry as long as the
    pool is non-empty.
    """

    def __init__(self, contexts):
        """Store *contexts* (an indexable, sized collection) and shuffle."""
        self.contexts = contexts
        # Length is captured once; later mutation of `contexts` is not tracked.
        self.len = len(self.contexts)
        self.indices = list(range(self.len))
        self.reset()

    def __len__(self):
        """Return the number of paragraphs captured at construction time."""
        return self.len

    def reset(self):
        """Rewind the cursor and reshuffle the visiting order in place."""
        self.ptr = 0
        np.random.shuffle(self.indices)

    def _sample(self, paragraphs=2):
        """Return the next *paragraphs* items in the current shuffled order.

        A non-positive *paragraphs* yields an empty list.

        Raises:
            IndexError: if paragraphs are requested from an empty pool.
        """
        if paragraphs > 0 and self.len == 0:
            # Same exception type as the bare list lookup would raise, but
            # with a message that points at the actual cause.
            raise IndexError(
                "cannot sample from an empty ContextDistractorDataset"
            )
        context = []
        while paragraphs > 0:
            context.append(self.contexts[self.indices[self.ptr]])
            self.ptr += 1
            if self.ptr == self.len:
                # Pool exhausted: start a new pass with a fresh order.
                self.reset()
            paragraphs -= 1
        return context

    def sample(self, paragraphs=2):
        """Return a list of *paragraphs* distractor paragraphs (default 2)."""
        return self._sample(paragraphs)


def build_distractor_dataset(
    dataset: str = None,
    n_distractor_dataset_items: int = 1000,
    dataset_family: str = None,
):
    """Build a distractor pool from the test split of a SQuADShifts config.

    Args:
        dataset: Config name passed to ``load_dataset("squadshifts", ...)``.
        n_distractor_dataset_items: Number of leading dataset items to
            examine. The cap applies before de-duplication, so fewer
            distinct contexts may be collected.
        dataset_family: Second argument forwarded to ``get_bm25_context``.
            Defaults to ``dataset``.

    Returns:
        A ``ContextDistractorDataset`` over all collected paragraphs.
    """
    # NOTE(review): the original body referenced an undefined name
    # `dataset_family` (NameError on every call). Falling back to the config
    # name is presumably what was intended -- confirm against
    # `get_bm25_context`.
    if dataset_family is None:
        dataset_family = dataset

    # Keep the loaded split in its own name so the config name stays usable.
    test_split = load_dataset("squadshifts", dataset, trust_remote_code=True)["test"]

    prev_context = None
    contexts = []
    for i, item in enumerate(test_split):
        if i >= n_distractor_dataset_items:
            break
        context = get_bm25_context(item, dataset_family)
        # Skip a context identical to the one from the immediately preceding
        # item; only consecutive repeats are filtered.
        if context == prev_context:
            continue
        prev_context = context
        contexts.extend(context)
    return ContextDistractorDataset(contexts)
