from dataset.data_wrapper import ScanFamilyDatasetWrapper
from dataset.threed_sav import ThreeDSAVDataset
from pipeline.registry import registry


@registry.register_dataset("3dsav")
def get_3dsav_dataset(split='train', **args):
    """Construct a ``ThreeDSAVDataset`` for the given split.

    The function is passed to ``registry.register_dataset`` under the
    name ``"3dsav"``.

    Args:
        split: Forwarded to ``ThreeDSAVDataset`` as its ``split`` keyword.
            Defaults to ``'train'``.
        **args: Any further keyword arguments, forwarded unchanged to the
            ``ThreeDSAVDataset`` constructor.

    Returns:
        The newly constructed ``ThreeDSAVDataset`` instance.
    """
    dataset = ThreeDSAVDataset(split=split, **args)
    return dataset


@registry.register_dataset("3dsav_task")
def get_3dsav_task_dataset(split='train', tokenizer=None, txt_seq_length=50, pc_seq_length=80, **args):
    """Construct a ``ThreeDSAVDataset`` wrapped in ``ScanFamilyDatasetWrapper``.

    The function is passed to ``registry.register_dataset`` under the
    name ``"3dsav_task"``.

    Args:
        split: Forwarded to ``ThreeDSAVDataset`` as its ``split`` keyword.
            Defaults to ``'train'``.
        tokenizer: Value handed to ``registry.get_language_model``; the
            object that lookup returns is called with no arguments to
            produce the tokenizer given to the wrapper.
            NOTE(review): the default ``None`` is passed straight to the
            lookup -- confirm the registry accepts it.
        txt_seq_length: Passed to the wrapper as ``max_seq_length``.
        pc_seq_length: Passed as ``max_obj_len`` to both the underlying
            dataset and the wrapper.
        **args: Any further keyword arguments, forwarded unchanged to the
            ``ThreeDSAVDataset`` constructor.

    Returns:
        A ``ScanFamilyDatasetWrapper`` around the constructed dataset.
    """
    # Resolve and instantiate the tokenizer first, before the dataset is built.
    tokenizer_factory = registry.get_language_model(tokenizer)
    text_tokenizer = tokenizer_factory()

    base_dataset = ThreeDSAVDataset(
        split=split,
        max_obj_len=pc_seq_length,
        **args,
    )

    return ScanFamilyDatasetWrapper(
        dataset=base_dataset,
        tokenizer=text_tokenizer,
        max_seq_length=txt_seq_length,
        max_obj_len=pc_seq_length,
    )
    
# Script entry guard: running this file directly does nothing beyond the
# imports and decorated definitions above; the guard body is a no-op.
if __name__ == "__main__":
    pass
    
