from src.task_utils.task_origented_dataset_class import SingleInferenceValidityAssessmentDataset, \
    MultipleChoiceInferenceValidityAssessmentDataset, ConsequentGenerationDataset, \
    NegativeDefaultJustificationGenerationDataset, NegativeDefaultJustificationSelectionDataset, \
    DefaultJustificationGenerationDataset, TaskOrigentedDataset

# Registry of supported tasks: each task name maps to the dataset class that
# implements it. ``generate_dataset_according_to_task`` looks classes up here.
type2dataset = dict(
    SingleInferenceValidityAssessment=SingleInferenceValidityAssessmentDataset,
    MultipleChoiceInferenceValidityAssessment=MultipleChoiceInferenceValidityAssessmentDataset,
    ConsequentGeneration=ConsequentGenerationDataset,
    NegativeDefaultJustificationGeneration=NegativeDefaultJustificationGenerationDataset,
    NegativeDefaultJustificationSelection=NegativeDefaultJustificationSelectionDataset,
    DefaultJustificationGeneration=DefaultJustificationGenerationDataset,
)

def generate_dataset_according_to_task(task, data_path) -> TaskOrigentedDataset:
    """Instantiate the dataset class registered for ``task``.

    Args:
        task: Task name; must be one of the keys of ``type2dataset``.
        data_path: Passed unchanged as the sole positional argument to the
            selected dataset class's constructor.

    Returns:
        A new instance of the dataset class registered for ``task``.

    Raises:
        ValueError: If ``task`` is not a registered task name.
    """
    # Validate against the registry itself instead of a duplicated hard-coded
    # list, so registering a new task in ``type2dataset`` is the only change
    # needed. An explicit raise (rather than ``assert``) still fires under
    # ``python -O``, where assert statements are stripped.
    try:
        dataset_cls = type2dataset[task]
    except (KeyError, TypeError):
        # TypeError covers unhashable ``task`` values (e.g. a list), which are
        # just as invalid as an unknown name.
        valid = ", ".join(type2dataset)
        raise ValueError(
            f"Unknown task {task!r}; expected one of: {valid}"
        ) from None
    return dataset_cls(data_path)