import datasets
import json

import os
def preprocess_arc_sample(sample):
    """Convert ``sample["answerKey"]`` into a zero-based integer index.

    Numeric keys "1".."5" are first translated to the letters "A".."E"; the
    resulting letter is then replaced by its position in "A".."E"
    (A -> 0, B -> 1, ..., E -> 4).

    Args:
        sample: A dict with an ``"answerKey"`` entry. It is modified in place.

    Returns:
        The same ``sample`` dict, with ``"answerKey"`` now an int.

    Raises:
        ValueError: If the key is neither one of "1".."5" nor "A".."E".
    """
    letters = ["A", "B", "C", "D", "E"]
    digit_to_letter = {str(pos): letter for pos, letter in enumerate(letters, start=1)}
    raw_key = sample["answerKey"]
    # Keys that are already letters (or unknown) pass through unchanged.
    letter = digit_to_letter.get(raw_key, raw_key)
    sample["answerKey"] = letters.index(letter)
    return sample

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON value. Whitespace-only lines
    (e.g. a stray empty line at the end of the file) are skipped instead of
    making ``json.loads`` fail.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]
def get_dataset(dataset="arc-easy", load_from_local=True):
    """Load the train and test splits of a local JSONL dataset.

    Reads ``data/<dataset>/<dataset>_train.jsonl`` and
    ``data/<dataset>/<dataset>_test.jsonl``. For some datasets, examples are
    kept only if ``len(example["choices"]["label"])`` equals an expected count:

    * the listed binary-sentiment datasets, and any name containing "qnli"
      or "winogrande": exactly 2 labels;
    * any name containing "mnli": exactly 3 labels.

    Every other dataset is returned unfiltered.

    Args:
        dataset: Dataset name; used as both the sub-directory under ``data/``
            and the file-name stem.
        load_from_local: Currently unused. Data is always read from local
            JSONL files. Kept for backward compatibility.

    Returns:
        A ``(train, test)`` tuple of lists of example dicts.
    """
    data_dir = os.path.join("data", dataset)
    dataset_train = load_jsonl(os.path.join(data_dir, dataset + "_train.jsonl"))
    dataset_test = load_jsonl(os.path.join(data_dir, dataset + "_test.jsonl"))

    # Decide once how many choice labels a valid example must have
    # (None means "keep everything"); the branch order matches the checks below.
    if dataset in ("amazon_polarity", "yelp_polarity", "glue-sst2",
                   "customer_reviews", "imdb", "sst2-1000", "sst2-all"):
        num_labels = 2
    elif "mnli" in dataset:
        num_labels = 3
    elif "qnli" in dataset or "winogrande" in dataset:
        num_labels = 2
    else:
        num_labels = None

    if num_labels is not None:
        dataset_train = [example for example in dataset_train
                         if len(example["choices"]["label"]) == num_labels]
        dataset_test = [example for example in dataset_test
                        if len(example["choices"]["label"]) == num_labels]
    return dataset_train, dataset_test
def get_dataset_question_answer(dataset="arc-easy", load_from_local=True):
    """Load a dataset and return parallel question/answer lists per split.

    Loading and per-dataset label-count filtering are delegated to
    ``get_dataset``, so both functions always apply the same rules. Datasets
    without a filtering rule, including the default ``"arc-easy"``, are
    returned unfiltered rather than yielding ``None``.

    Args:
        dataset: Dataset name under ``data/``.
        load_from_local: Forwarded to ``get_dataset``. It is currently unused
            there; data is always read from local files.

    Returns:
        ``(train_question, train_answer, test_question, test_answer)``. These
        are lists of each example's ``"question"`` and ``"answerKey"`` values,
        index-aligned within each split.

    Raises:
        KeyError: If an example lacks ``"question"`` or ``"answerKey"``.
    """
    dataset_train, dataset_test = get_dataset(dataset, load_from_local=load_from_local)

    train_question = [sample["question"] for sample in dataset_train]
    train_answer = [sample["answerKey"] for sample in dataset_train]
    test_question = [sample["question"] for sample in dataset_test]
    test_answer = [sample["answerKey"] for sample in dataset_test]

    return train_question, train_answer, test_question, test_answer