import datasets
import json

import os
def preprocess_arc_sample(sample):
    """Replace ``sample["answerKey"]`` with a zero-based choice index.

    Numeric labels "1".."5" are treated as the letters "A".."E"; the letter
    is then converted to its position in A-E (so "A"/"1" -> 0, "E"/"5" -> 4).

    Args:
        sample: Mapping with an ``"answerKey"`` entry. It is modified in place.

    Returns:
        The same ``sample`` object, with ``"answerKey"`` now an ``int``.

    Raises:
        ValueError: If the answer key is neither "1".."5" nor "A".."E".
    """
    letters = ["A", "B", "C", "D", "E"]
    digit_to_letter = dict(zip(["1", "2", "3", "4", "5"], letters))

    raw_key = sample["answerKey"]
    # Keys that are already letters (or unknown) pass through unchanged here;
    # unknown ones are rejected by ``index`` below.
    letter = digit_to_letter.get(raw_key, raw_key)
    sample["answerKey"] = letters.index(letter)
    return sample

def load_jsonl(file_path):
    """Read a JSON Lines file into a list of decoded objects.

    Args:
        file_path: Path of the file to read; each non-blank line must hold
            one JSON document.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        # Skip blank lines (e.g. a trailing empty line) instead of letting
        # json.loads fail on them.
        return [json.loads(line) for line in file if line.strip()]
def get_dataset(dataset="arc-easy"):
    """Load the train and test splits of a locally stored dataset.

    Splits are read from ``data/<dataset>/<dataset>_train.jsonl`` and
    ``data/<dataset>/<dataset>_test.jsonl`` and then restricted to examples
    whose ``example["choices"]["label"]`` has the length expected for the
    dataset family (4 for ARC / OpenBookQA / HellaSwag, 5 for
    CommonsenseQA, 3 for "sociali_qa", 2 for WinoGrande).

    Args:
        dataset: Dataset name. Accepted: "arc-easy", "arc-challenge",
            "hellaswag", "commonsense_qa", "sociali_qa", "openbook_qa",
            "commonsense_qa-all", or any name containing "winogrande" or
            "hellaswag".

    Returns:
        ``(dataset_train, dataset_test)`` as two lists of decoded examples.

    Raises:
        ValueError: If ``dataset`` is not an accepted name.
    """
    known_names = ["arc-easy", "arc-challenge", "hellaswag", "commonsense_qa", "sociali_qa", "openbook_qa", "commonsense_qa-all"]
    recognised = dataset in known_names or "winogrande" in dataset or "hellaswag" in dataset
    if not recognised:
        raise ValueError("Invalid dataset.")

    path_stem = "data/" + dataset + "/" + dataset
    train_rows = load_jsonl(path_stem + "_train.jsonl")
    test_rows = load_jsonl(path_stem + "_test.jsonl")

    # Expected number of answer choices; the first matching family wins.
    if dataset in ["arc-easy", "arc-challenge", "openbook_qa"] or "hellaswag" in dataset:
        wanted = 4
    elif "commonsense_qa" in dataset:
        wanted = 5
    elif dataset == "sociali_qa":
        wanted = 3
    elif "winogrande" in dataset:
        wanted = 2
    else:
        wanted = None

    if wanted is not None:
        def has_wanted_choices(example):
            return len(example["choices"]["label"]) == wanted

        train_rows = list(filter(has_wanted_choices, train_rows))
        test_rows = list(filter(has_wanted_choices, test_rows))

    return train_rows, test_rows


def get_dataset_question_answer(dataset="arc-easy"):
    """Load a local dataset and return its questions and answer keys.

    Splits are read from ``data/<dataset>/<dataset>_train.jsonl`` and
    ``data/<dataset>/<dataset>_test.jsonl``, restricted to examples whose
    ``example["choices"]["label"]`` has the length expected for the dataset
    family (4 for ARC / OpenBookQA / HellaSwag, 5 for CommonsenseQA, 3 for
    "sociali_qa", 2 for WinoGrande), and then reduced to the ``"question"``
    and ``"answerKey"`` fields.

    Args:
        dataset: Dataset name. Accepted: "arc-easy", "arc-challenge",
            "hellaswag", "openbook_qa", "commonsense_qa", "sociali_qa",
            "commonsense_qa-all", or any name containing "winogrande" or
            "hellaswag".

    Returns:
        ``(train_question, train_answer, test_question, test_answer)``,
        four lists aligned per split.

    Raises:
        ValueError: If ``dataset`` is not an accepted name.
    """
    # NOTE(review): debug trace kept as-is so console output is unchanged.
    print("Hello", dataset)

    known_names = ["arc-easy", "arc-challenge", "hellaswag", "openbook_qa", "commonsense_qa", "sociali_qa", "commonsense_qa-all"]
    if not (dataset in known_names or "winogrande" in dataset or "hellaswag" in dataset):
        # Fail loudly instead of implicitly returning None, which callers
        # unpacking four values would only notice as an unrelated TypeError.
        raise ValueError("Invalid dataset.")

    # Expected number of answer choices; the first matching family wins.
    if dataset in ["arc-easy", "arc-challenge", "openbook_qa"] or "hellaswag" in dataset:
        choice_count = 4
    elif "commonsense_qa" in dataset:
        choice_count = 5
    elif dataset == "sociali_qa":
        choice_count = 3
    elif "winogrande" in dataset:
        choice_count = 2
    else:
        choice_count = None

    columns = []
    for split in ("train", "test"):
        rows = load_jsonl("data/" + dataset + "/" + dataset + "_" + split + ".jsonl")
        if choice_count is not None:
            rows = [row for row in rows if len(row["choices"]["label"]) == choice_count]
        columns.append([row["question"] for row in rows])
        columns.append([row["answerKey"] for row in rows])

    train_question, train_answer, test_question, test_answer = columns
    return train_question, train_answer, test_question, test_answer
def get_datase_origin(dataset="arc-e"):
    """Load an ARC dataset through the ``datasets`` library and index its answers.

    Both splits are fetched with ``datasets.load_dataset("ai2_arc", ...)`` and
    every example is passed through ``preprocess_arc_sample`` so that
    ``"answerKey"`` becomes a zero-based choice index. For "arc-c" the splits
    are first restricted to examples with exactly four choice labels; "arc-e"
    is not filtered.

    Args:
        dataset: "arc-e" for the "ARC-Easy" configuration or "arc-c" for
            "ARC-Challenge".

    Returns:
        ``(dataset_train, dataset_test)`` as returned by ``Dataset.map``.

    Raises:
        ValueError: If ``dataset`` is neither "arc-e" nor "arc-c".
    """
    if dataset == "arc-e":
        config_name, four_choices_only = "ARC-Easy", False
    elif dataset == "arc-c":
        config_name, four_choices_only = "ARC-Challenge", True
    else:
        raise ValueError("Invalid dataset.")

    processed = []
    for split in ("train", "test"):
        data = datasets.load_dataset("ai2_arc", config_name, split=split)
        if four_choices_only:
            # make sure only 4 choices
            data = data.filter(lambda example: len(example["choices"]["label"]) == 4)
        # Use map() for both configurations: a Dataset has no item
        # assignment, so writing back with ``data[i] = ...`` raises TypeError.
        processed.append(data.map(preprocess_arc_sample))

    dataset_train, dataset_test = processed
    return dataset_train, dataset_test