from .base import SamplingAlgorithmBase


class Sampling(SamplingAlgorithmBase):
    """Submit ``sample_num`` generation requests for one input and yield each result.

    Config keys read from ``self.cfg``:
        sample_num: number of requests submitted per ``sample`` call
            (default 32).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.sample_num = self.cfg.get('sample_num', 32)

    def sample(self, data, **kwargs):
        """Submit ``self.sample_num`` requests for ``data`` and yield their outputs.

        All requests are submitted to the scheduler up front; outputs are then
        fetched one by one in submission order.

        Args:
            data: Input item; passed through ``self._preprocess_data`` once
                per submitted request.
            **kwargs: Accepted for interface compatibility; unused here.

        Yields:
            ``(outputs, info)`` pairs, where ``outputs`` is the value returned by
            ``self.scheduler.generator_get_request_outputs`` for one request and
            ``info`` is ``self._post_sample_info(cost=k)`` with ``k`` the 1-based
            count of results yielded so far.
        """
        # Submit every request before fetching any output.
        # NOTE(review): _preprocess_data is intentionally called per request
        # rather than hoisted, in case it returns a per-request (e.g. randomly
        # sampled few-shot) or mutable prompt -- confirm before hoisting.
        request_ids = [
            self.scheduler.generator_submit_request(
                # add few-shot prompts
                self._preprocess_data(data),
            )
            for _ in range(self.sample_num)
        ]
        for done, request_id in enumerate(request_ids, start=1):
            outputs = self.scheduler.generator_get_request_outputs(request_id)
            yield outputs, self._post_sample_info(cost=done)
            # Periodic progress report (the final count is not printed here).
            # A falsy log_interval (0 / None) disables reporting instead of
            # raising ZeroDivisionError / TypeError in the modulo.
            if (
                done < self.sample_num
                and self.log_interval
                and done % self.log_interval == 0
            ):
                self.process_print('Progress: {} / {}'.format(
                    done, self.sample_num
                ))
