from utils_proof.file import *

class GraphDataset:
    """Train/test records of one processed dataset plus few-shot prompt building.

    The records are read with ``load_json`` from
    ``processed_dataset/<dataset_name>/train.json`` and ``.../test.json`` and
    are kept on the instance as ``self.trains`` and ``self.tests``.
    Both are indexed by the ids passed to :meth:`construct_prompt`, and each
    record is expanded with ``**`` into ``str.format`` (presumably lists of
    dicts; confirm against the data files).
    """

    def __init__(self, dataset_name, llm_name=None, draw=False) -> None:
        """Load the train and test splits of ``dataset_name``.

        Args:
            dataset_name: Directory name under ``processed_dataset/``; it also
                selects the template file ``prompts/<dataset_name>.json``
                used by :meth:`construct_prompt`.
            llm_name: Accepted for call compatibility; not read by the code
                in this class.
            draw: Stored as ``self.draw``; not otherwise read in this class.
        """
        self.draw = draw
        self.dataset_name = dataset_name
        self.dataset_path = f"processed_dataset/{dataset_name}"
        self.trains = load_json(f"{self.dataset_path}/train.json")
        self.tests = load_json(f"{self.dataset_path}/test.json")
        # Parsed prompt templates, keyed by dataset name (filled lazily).
        self._template_cache = {}

    def _load_templates(self):
        """Return the prompt templates for the current ``dataset_name``.

        The template file is parsed once per dataset name and then served
        from the instance cache, so repeated prompt construction does not
        re-read it from disk.
        """
        # getattr keeps this working for instances that skipped __init__.
        cache = getattr(self, '_template_cache', None)
        if cache is None:
            cache = self._template_cache = {}
        name = self.dataset_name
        if name not in cache:
            cache[name] = load_json(f'prompts/{name}.json')
        return cache[name]

    def construct_prompt(self, train_ids, test_id):
        """Build a few-shot prompt for one test record.

        Each in-context example contributes its question (``q_template``)
        and answer (``a_template``), separated by blank lines. The test
        record comes last: its question followed by the ``a_test`` template.
        Every template is formatted with the fields of the record plus ``i``,
        the 1-based position of the record in the prompt. The ``instruction``
        template is prepended and surrounding whitespace is stripped.

        The stored train and test records are not modified; formatting and
        the returned record work on copies, so results from earlier calls
        are never overwritten by later ones.

        Args:
            train_ids: Indices into ``self.trains`` selecting the in-context
                examples, in prompt order.
            test_id: Index into ``self.tests`` of the record to ask about.

        Returns:
            A shallow copy of the test record with three added keys:
            ``i`` (position of the test question in the prompt), ``prompt``
            (the full prompt text) and ``ctxs`` (``train_ids`` as passed in).

        Raises:
            KeyError: If a required template, or a field a template refers
                to, is missing.
            IndexError: If an id is out of range for a list-backed split.
        """
        test_item = self.tests[test_id]
        examples = [self.trains[idx] for idx in train_ids]
        examples.append(test_item)

        templates = self._load_templates()
        instruction = templates['instruction']
        q_template = templates['q_template']
        a_template = templates['a_template']
        a_test_template = templates['a_test']

        last = len(examples) - 1
        pieces = []
        for pos, example in enumerate(examples):
            # Add the running index to a copy so the dataset record stays clean.
            fields = {**example, 'i': pos + 1}
            pieces.append(q_template.format(**fields) + '\n\n')
            if pos < last:
                pieces.append(a_template.format(**fields) + '\n\n')
            else:
                # The test record ends with the open answer slot.
                pieces.append(a_test_template.format(**fields))

        # Return a copy so each call yields an independent record. 'i' is
        # kept in the result because earlier versions exposed it as well.
        result = dict(test_item)
        result['i'] = len(examples)
        result['prompt'] = (instruction + ''.join(pieces)).strip()
        result['ctxs'] = train_ids
        return result

# Script entry point: intentionally a no-op, so running this file directly
# does nothing beyond the imports and class definition above.
if __name__ == "__main__":
    pass