import json
from typing import Any, Dict

import torch
import transformers
from torch.utils.data import Dataset
from transformers import set_seed,AutoTokenizer, AutoModelForCausalLM

class InferenceDataset(Dataset):
    """Dataset of tokenized prompts for inference / generation.

    Each non-blank line of ``data_path`` is parsed as one JSON object. The
    record's ``'prefix'`` text is tokenized and prepended with the tokenized
    ``coh_prefix``; the concatenation becomes that example's ``input_ids``.

    Args:
        data_path: Path to a UTF-8 JSON-lines file. Blank lines are skipped.
        tokenizer: Object exposing ``encode(text)`` returning a list of token
            ids (a Hugging Face tokenizer is expected).
        coh_prefix: Text prepended to every record's prefix.

    Raises:
        KeyError: If a record lacks ``'prefix'`` or ``'generation_ids'``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "transformers.PreTrainedTokenizer",
        coh_prefix,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.coh_prefix = coh_prefix

        # The CoH prefix is identical for every record, so tokenize it once.
        coh_prefix_tokens = self.tokenizer.encode(coh_prefix)
        # NOTE(review): Hugging Face ``encode`` adds special tokens by default;
        # with a BOS-prepending tokenizer the concatenation below would place
        # a BOS token between the CoH prefix and the record prefix. Confirm
        # whether ``add_special_tokens=False`` is wanted for the second call.

        list_data_dict = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:  # stream lines rather than loading the whole file
                if not line.strip():
                    continue  # tolerate blank / trailing lines in the JSONL
                dic = json.loads(line)
                chosen_prefix_tokens = self.tokenizer.encode(dic['prefix'])
                dic['input_ids'] = coh_prefix_tokens + chosen_prefix_tokens
                dic['attention_masks'] = [1] * len(dic['input_ids'])
                list_data_dict.append(dic)

        # Column-wise views over the parsed records, aligned by index.
        self.input_ids = [dic['input_ids'] for dic in list_data_dict]
        # NOTE(review): assumes every record carries a 'generation_ids' field
        # (a missing one raises KeyError) -- confirm against the data producer.
        self.generation_ids = [dic['generation_ids'] for dic in list_data_dict]
        self.attention_masks = [dic['attention_masks'] for dic in list_data_dict]

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, Any]:
        """Return record ``i`` as ``{'input_ids': ..., 'generation_ids': ...}``.

        ``input_ids`` is the concatenated output of ``tokenizer.encode``; it is
        not converted to a tensor here.
        """
        return dict(input_ids=self.input_ids[i], generation_ids=self.generation_ids[i])