import json

from datasets import load_dataset


def _records_from_rows(rows):
    """Return ``[{"support": ..., "question": ...}, ...]`` built from an iterable of mappings."""
    return [{"support": row["support"], "question": row["question"]} for row in rows]


def _read_local_jsonl(file_path):
    """Read one JSONL split file and return ``(records, answers)``.

    Every non-blank line must be a JSON object with ``"support"``,
    ``"question"`` and ``"answer"`` keys; other keys are ignored. ``answer`` is
    kept exactly as stored (a string or a list), and ``answers[i]`` belongs to
    ``records[i]``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks one of the required keys.
    """
    records = []
    answers = []
    with open(file_path, "r", encoding="utf-8") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing with json.JSONDecodeError.
                continue
            obj = json.loads(line)
            records.append({"support": obj["support"], "question": obj["question"]})
            answers.append(obj["answer"])
    return records, answers


def get_dataset(dataset="sciq", load_from_local=True, data_dir="data"):
    """Load a QA dataset as ``(train_data, train_answers, test_data, test_answers)``.

    Each element of ``train_data`` / ``test_data`` is a dict with the keys
    ``"support"`` and ``"question"``; the matching ``*_answers`` list is
    aligned with it index-for-index.

    Args:
        dataset: Dataset name.
            * Local mode: any name containing ``"squad"`` or ``"sciq"`` is
              accepted and read from
              ``{data_dir}/{dataset}/{dataset}_train.jsonl`` and
              ``{data_dir}/{dataset}/{dataset}_test.jsonl``.
            * Remote mode: only ``"sciq"`` is supported; it is loaded with
              ``load_dataset("allenai/sciq", ...)`` and answers come from the
              ``correct_answer`` column.
        load_from_local: Truthy -> read local JSONL files; falsy -> load
            remotely via ``load_dataset``.
        data_dir: Root directory for local files (default ``"data"``,
            relative to the current working directory).

    Returns:
        A 4-tuple ``(train_data, train_answers, test_data, test_answers)``,
        all plain lists.

    Raises:
        ValueError: if ``dataset`` is not supported by the selected source.
        OSError / json.JSONDecodeError / KeyError: for missing or malformed
            local files (see ``_read_local_jsonl``).
    """
    if load_from_local:
        if "squad" not in dataset and "sciq" not in dataset:
            raise ValueError(
                f"Invalid dataset: {dataset!r}. Local loading expects a name "
                "containing 'squad' or 'sciq'."
            )
        prefix = f"{data_dir}/{dataset}/{dataset}"
        train_data, train_answers = _read_local_jsonl(f"{prefix}_train.jsonl")
        test_data, test_answers = _read_local_jsonl(f"{prefix}_test.jsonl")
        return train_data, train_answers, test_data, test_answers

    # Validate before any download is attempted.
    if dataset != "sciq":
        raise ValueError(
            f"Invalid dataset: {dataset!r}. Remote loading supports only 'sciq'."
        )
    ds_train = load_dataset("allenai/sciq", split="train")
    ds_test = load_dataset("allenai/sciq", split="test")
    # list(...) so both branches return plain lists; depending on the installed
    # `datasets` version, ds["col"] may be a column object rather than a list.
    return (
        _records_from_rows(ds_train),
        list(ds_train["correct_answer"]),
        _records_from_rows(ds_test),
        list(ds_test["correct_answer"]),
    )
