import os.path

from src.dataset_utils.dataset_class import DefaultReasoningDataset, LabelClassification, LabelClassificationSubset, \
    LabelNMClassification, AnswerSetGeneration


def get_LabelClassification(dtype, base_path='logicalDatasets/checked_dataset/', **kwargs):
    """Build a ``LabelClassification`` dataset for one ``dtype`` variant.

    Args:
        dtype: One of ``random_str_symbolic``, ``random_word_symbolic``,
            ``related_word_symbolic``, ``random_str_language``,
            ``random_word_language`` or ``related_word_language``.
        base_path: Directory that contains the JSONL files.
        **kwargs: Forwarded unchanged to ``LabelClassification``.

    Returns:
        ``LabelClassification(path, dtype, **kwargs)``, where ``path`` is the
        JSONL file for ``dtype`` joined onto ``base_path``.

    Raises:
        KeyError: If ``dtype`` is not one of the six variant keys above.
    """
    # On disk, the "*_language" variants are spelled "*_natural_language".
    file_for = dict(
        random_str_symbolic='random_str_symbolic_dataset.jsonl',
        random_word_symbolic='random_word_symbolic_dataset.jsonl',
        related_word_symbolic='related_word_symbolic_dataset.jsonl',
        random_str_language='random_str_natural_language_dataset.jsonl',
        random_word_language='random_word_natural_language_dataset.jsonl',
        related_word_language='related_word_natural_language_dataset.jsonl',
    )
    jsonl_name = file_for[dtype]
    path = os.path.join(base_path, jsonl_name)
    return LabelClassification(path, dtype, **kwargs)

def get_LabelNMClassification(dtype, base_path='datasets/nm_dataset', **kwargs):
    """Build a ``LabelNMClassification`` dataset for one ``dtype`` variant.

    Args:
        dtype: One of ``random_str_symbolic``, ``random_word_symbolic``,
            ``related_word_symbolic``, ``random_str_language``,
            ``random_word_language`` or ``related_word_language``.
        base_path: Directory that contains the JSONL files.
        **kwargs: Forwarded unchanged to ``LabelNMClassification``.

    Returns:
        ``LabelNMClassification(path, dtype, **kwargs)``, where ``path`` is
        ``<base_path>/<dtype>.jsonl``.

    Raises:
        KeyError: If ``dtype`` is not one of the six variant keys above.
    """
    variants = (
        'random_str_symbolic',
        'random_word_symbolic',
        'related_word_symbolic',
        'random_str_language',
        'random_word_language',
        'related_word_language',
    )
    # Each variant is stored as "<variant>.jsonl", so derive the table from the keys.
    # Indexing the table still raises KeyError for unknown variants.
    jsonl_for = {variant: variant + '.jsonl' for variant in variants}
    path = os.path.join(base_path, jsonl_for[dtype])
    return LabelNMClassification(path, dtype, **kwargs)

def get_LabelClassificationSubset(dtype, **kwargs):
    """Build a ``LabelClassificationSubset`` for ``dtype``.

    ``dtype`` is passed as both of the first two positional arguments.
    NOTE(review): presumably the class resolves its own data location from
    ``dtype``; confirm against ``LabelClassificationSubset.__init__``.
    Extra keyword arguments are forwarded unchanged.
    """
    positional = (dtype, dtype)
    return LabelClassificationSubset(*positional, **kwargs)

def get_AnswerSetGeneration(dtype, **kwargs):
    """Build an ``AnswerSetGeneration`` dataset for ``dtype``.

    ``dtype`` is passed as both of the first two positional arguments.
    NOTE(review): presumably the class resolves its own data location from
    ``dtype``; confirm against ``AnswerSetGeneration.__init__``.
    Extra keyword arguments are forwarded unchanged.
    """
    positional = (dtype, dtype)
    return AnswerSetGeneration(*positional, **kwargs)

# Registry used by get_dataset: dataset name -> factory taking (dtype, **kwargs).
func_dict = dict(
    LabelClassification=get_LabelClassification,
    LabelNMClassification=get_LabelNMClassification,
    LabelClassificationSubset=get_LabelClassificationSubset,
    AnswerSetGeneration=get_AnswerSetGeneration,
)

def get_dataset(dname, dtype, **kwargs) -> DefaultReasoningDataset:
    """Build the dataset registered under ``dname`` in ``func_dict``.

    Args:
        dname: Name of a registered dataset loader (a key of ``func_dict``).
        dtype: Variant key forwarded as the first positional argument to the
            loader.
        **kwargs: Forwarded unchanged to the loader.

    Returns:
        Whatever the registered loader returns for ``(dtype, **kwargs)``.

    Raises:
        ValueError: If no loader is registered under ``dname``.
    """
    # Validate against the registry itself rather than a separate hard-coded
    # list. The old list accepted names such as 'CommonsenseQA' that have no
    # loader, and those then failed with a bare KeyError. Also raise
    # explicitly instead of using `assert`, which `python -O` strips.
    loader = func_dict.get(dname)
    if loader is None:
        raise ValueError('{} is not supported yet. Supported datasets: {}.'.format(
            dname, ', '.join(sorted(func_dict))))

    print('Loading {} dataset, {} set...'.format(dname, dtype))
    return loader(dtype, **kwargs)
