# Prompt pieces used to build multiple-choice questions.
# Optional guidance text prepended to prompts (empty by default).
question_guide = ""

# Instruction header placed before the question body.
question_head = "Choose the best answer of the question. Output A, B, C, etc.\n"

# Per-question template: filled with the question, the options and the answer.
question_format = (
    "Question: {}\n"
    "Options: {}\n"
    "The answer is: {}\n"
)


import torch
from datasets import load_dataset

# Evaluation task identifier and its available subsets.
eva_task = "piqa"
subsets = ["choice"]

# Split names for training / validation / testing.
# NOTE(review): validation reads the "test" split while testing reads
# "validation" -- possibly intentional (e.g. if one split lacks usable
# labels), but confirm this is not an accidental swap.
train_name, vali_name, test_name = "train", "test", "validation"

# =============================================================
def get_dataloader(name: str, subset: str, split: str, *,
                   data_path: str = "/data/share3/RTD/dscc/piqa/piqa.py"):
    """Build a batch-size-1 DataLoader over one split of the PIQA dataset.

    Args:
        name: Task name. Currently unused; kept so existing callers that
            pass it keep working.
        subset: Subset name. Currently unused; kept for the same reason.
        split: Dataset split to load, forwarded to ``load_dataset``.
        data_path: Path (or dataset id) handed to ``load_dataset``. The
            default is the original hard-coded local loading script, so
            callers that don't pass it are unaffected.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding one example per batch,
        restricted to the ``goal``, ``sol1``, ``sol2`` and ``label`` columns.
    """
    dataset = load_dataset(data_path, split=split)
    # Expose only the columns that get_data_lines / get_data_answer read.
    dataset.set_format(type="torch", columns=["goal", "sol1", "sol2", "label"])
    # batch_size=1 is load-bearing: get_data_lines and get_data_answer
    # read element [0] of each batched column.
    return torch.utils.data.DataLoader(dataset, batch_size=1)

# =============================================================
# Maps an answer token to a zero-based option index. Accepted tokens are the
# letters 'A'-'F', the digit strings '1'-'6' and the integers 1-6.
answer_convert = {
    **{letter: index for index, letter in enumerate("ABCDEF")},
    **{str(index + 1): index for index in range(6)},
    **{index + 1: index for index in range(6)},
}
def get_data_answer(data) -> int:
    """Return the gold label of the first example in a batch.

    ``data["label"]`` is indexed at position 0 and converted from a tensor
    element to a Python scalar via ``.item()``.
    """
    labels = data["label"]
    first_label = labels[0]
    return first_label.item()

def get_data_lines(data) -> tuple[str, list[str]]:
    """Extract the question text and its candidate solutions from a batch.

    Takes element 0 of the ``goal`` column as the question, and element 0
    of ``sol1`` and ``sol2`` (in that order) as the choices.

    Returns:
        ``(question, [choice1, choice2])``.
    """
    goal_text = data["goal"][0]
    solutions = [data[column][0] for column in ("sol1", "sol2")]
    return goal_text, solutions

# =============================================================
def view_dataset():
    """Interactively page through dataset examples.

    Repeatedly prompts for a dataset name (forwarded to ``get_dataloader``
    as ``subset``) and a split. For each example it prints the question
    and choices, then the label, and waits for Enter before moving on.
    The outer loop never ends on its own; interrupt the process to exit.
    """
    while True:
        subset_choice = input("Dataset name : ")
        split_choice = input("Dataset split: ")

        loader = get_dataloader(eva_task, subset_choice, split_choice)

        for batch in loader:
            lines = get_data_lines(batch)
            print(lines, "\n===")
            answer = get_data_answer(batch)
            print(answer)
            input()  # pause until the user presses Enter

if __name__ == "__main__":
    # Running this file as a script starts the interactive dataset viewer.
    view_dataset()