from typing import List
import numpy as np
from .base import Dataset, Request


class RandomToken(Dataset):
    """Synthetic dataset of prompts built from runs of consecutive token ids.

    Each sample is a single-turn ``Request`` (``system_prompt=None``). Its text
    is produced in three steps:

    1. Decode ``int(input_len * 1.5)`` consecutive token ids, wrapping modulo
       ``tokenizer.vocab_size`` and starting at a pseudo-random offset.
    2. Re-encode that text without special tokens.
    3. Decode at most the first ``input_len`` of the re-encoded ids.

    The offsets come from a fixed seed (0), so the prompts are reproducible
    for a given tokenizer and arguments (assuming the tokenizer itself is
    deterministic).
    """

    def __init__(self, tokenizer, input_len, num_samples=20, **kwargs):
        """Build the dataset eagerly.

        Args:
            tokenizer: Object exposing ``vocab_size``,
                ``encode(text, add_special_tokens=False)`` and ``decode(ids)``.
                It is presumably a Hugging Face tokenizer; only these members
                are used.
            input_len: Maximum number of re-encoded tokens kept per prompt.
            num_samples: Number of requests to generate.
            **kwargs: Accepted for interface compatibility; ignored here.
        """
        self.data: List[Request] = []  # one single-turn Request per sample
        self.num_samples = num_samples
        self.input_len = input_len
        self._preprocess(tokenizer)

    def _preprocess(self, tokenizer):
        """Generate ``num_samples`` prompts and append them to ``self.data``."""
        vocab_size = tokenizer.vocab_size
        # Over-generate by 1.5x, presumably so that the decode/re-encode round
        # trip still leaves at least ``input_len`` tokens to truncate to.
        raw_len = int(self.input_len * 1.5)
        # A private RandomState seeded with 0 yields exactly the same offsets
        # as ``np.random.seed(0)`` followed by ``np.random.randint`` would.
        # Unlike that pair, it does not reset NumPy's global RNG for the rest
        # of the process.
        rng = np.random.RandomState(0)
        offsets = rng.randint(0, vocab_size, size=self.num_samples)
        for i, offset in enumerate(offsets):
            # Cast to a plain int so the tokenizer receives Python ints
            # rather than NumPy integer scalars (the values are unchanged).
            start = int(offset) + i
            token_ids = [(start + j) % vocab_size for j in range(raw_len)]
            text = tokenizer.decode(token_ids)
            re_encoded = tokenizer.encode(text, add_special_tokens=False)
            # The re-encoded sequence may be shorter than input_len; slicing
            # then keeps all of it.
            prompt = tokenizer.decode(re_encoded[:self.input_len])
            self.data.append(Request(system_prompt=None, turns=[prompt]))