import json
import datasets
import numpy as np

# Maps each supported dataset name to its JSON file. The paths are relative,
# so they resolve against the process's current working directory.
_DATASET_PATHS = {
    'ambigqa': 'logs/dataset/ambigqa/ambigqa_dev_balance.json',
    'ambig_inst': 'logs/dataset/ambiginst.json',
}


def load_data(dataset_name: str):
    """Load a dataset from its JSON file by name.

    Args:
        dataset_name: One of the keys of ``_DATASET_PATHS``
            (``'ambigqa'`` or ``'ambig_inst'``).

    Returns:
        Whatever ``json.load`` decodes from the file. The structure of the
        records (list vs. dict, field names) is not visible here; confirm
        it against the callers.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset. Previously
            this case silently returned ``None``.
        OSError: If the dataset file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    try:
        data_path = _DATASET_PATHS[dataset_name]
    except KeyError:
        # Fail loudly here instead of returning None, which would only
        # surface later as a confusing error far from the typo.
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; "
            f"expected one of {sorted(_DATASET_PATHS)}"
        ) from None
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)

