# Prompt templates for multiple-choice evaluation.
# question_guide: system-style preamble; ``{}`` is filled with the subject/domain.
# Fixed typo: the role word was "export" instead of "expert".
question_guide = '''You are an expert in the field of {}, please follow the guide.
'''

# Instruction line placed before the question(s).
question_head = '''Choose the best answer of the question. Output A, B, C, etc.
'''
# One question block; placeholders are, in order: question text, options text,
# and the answer (e.g. a letter for few-shot examples, or empty for the query).
question_format = '''Question: {}
Options: {}
The answer is: {}
'''

import torch

from datasets import load_dataset
from datasets.utils.info_utils import VerificationMode

# Identifier of the evaluation task handled by this module (OpenBookQA).
eva_task = "obqa"
# OpenBookQA subsets available on disk (one sub-directory each).
subsets = ["main", "additional"]

# Canonical split identifiers.
train_name = "train"
vali_name = "validation"
test_name = "test"
# =============================================================
def get_dataloader(name: str, subset: str, split: str,
                   data_root: str = "/data1/shilh/codes/RTD/dscc/openbookqa"):
    """Build a batch-size-1 DataLoader over one split of a local OpenBookQA subset.

    Args:
        name: Unused; kept so existing call sites keep working.
        subset: Sub-directory of ``data_root`` to read (e.g. ``"main"``).
        split: Which split to load: ``"train"``, ``"validation"`` or ``"test"``.
        data_root: Directory containing ``<subset>/<split>-00000-of-00001.parquet``
            files. Defaults to the previously hard-coded location.

    Returns:
        A ``torch.utils.data.DataLoader`` (batch size 1) over the selected split,
        formatted to expose only ``question_stem``, ``choices`` and ``answerKey``.
    """
    # Register all three split files (as the original did) and let
    # load_dataset pick the requested one.
    data_files = {
        split_name: f"{data_root}/{subset}/{split_name}-00000-of-00001.parquet"
        for split_name in ("train", "test", "validation")
    }
    dataset = load_dataset("parquet", data_files=data_files, split=split)
    # Keep only the columns the prompt/answer helpers read.
    dataset.set_format(type="torch", columns=["question_stem", "choices", "answerKey"])
    return torch.utils.data.DataLoader(dataset, batch_size=1)

# =============================================================
# Maps an answer key to a 0-based option index. Accepted forms: letters
# 'A'..'F', digit strings '1'..'6', and 1-based ints 1..6.
answer_convert = {
    **{letter: idx for idx, letter in enumerate("ABCDEF")},
    **{str(idx + 1): idx for idx in range(6)},
    **{idx + 1: idx for idx in range(6)},
}


def get_data_answer(data):
    """Return the 0-based index of the correct option for a batch-size-1 sample.

    ``data["answerKey"]`` is indexed with ``[0]`` to take the single batch
    element; an unrecognised key raises ``KeyError``.
    """
    key = data["answerKey"][0]
    return answer_convert[key]

def get_data_lines(data) -> tuple[str, list[str]]:
    """Return ``(question_text, option_texts)`` for a batch-size-1 sample.

    Every string in ``data`` is wrapped in a length-1 batch, so each one is
    unwrapped with ``[0]``: the question stem directly, and every entry of
    ``data["choices"]["text"]`` individually.
    """
    stem = data["question_stem"][0]
    options = [batched_text[0] for batched_text in data["choices"]["text"]]
    return stem, options

# =============================================================
def view_dataset():
    """Interactively browse dataset samples on the console.

    Repeatedly asks for a subset (one of ``subsets``) and a split, then prints
    each sample's ``(question, options)`` and its 0-based answer index, pausing
    for Enter between samples.

    Controls:
        * empty subset name, EOF (Ctrl-D) or Ctrl-C: quit the viewer;
        * ``q`` at the per-sample pause: stop browsing this split and choose again.
    """
    valid_splits = (train_name, vali_name, test_name)
    try:
        while True:
            dataset_name = input(f"Dataset subset {subsets} (empty to quit): ").strip()
            if not dataset_name:
                return
            dataset_split = input(f"Dataset split {list(valid_splits)}: ").strip()
            # Validate up front instead of failing deep inside load_dataset.
            if dataset_name not in subsets or dataset_split not in valid_splits:
                print(f"Unknown subset/split: {dataset_name!r}/{dataset_split!r}")
                continue

            # Pass the task id rather than the stale "mmlu/" leftover; get_dataloader
            # does not use this argument.
            dataloader = get_dataloader(eva_task, dataset_name, dataset_split)

            for data in dataloader:
                print(get_data_lines(data), "\n===")
                print(get_data_answer(data))
                if input().strip().lower() == "q":
                    break
    except (EOFError, KeyboardInterrupt):
        # stdin closed or user interrupted: exit quietly instead of a traceback.
        print()

if __name__ == "__main__":
    # Launch the interactive viewer only when executed as a script, not on import.
    view_dataset()