from datasets import load_dataset ,DatasetDict 
from .base import BaseDataModule 


def cleanup_cd(example):
    """Restructure one raw record in place and return it.

    A raw record looks like::

        {'input': '15,44,79,50', 'output': '44-15=29,79-29=50'}

    After the call, the same dict holds:

    * ``"input"``      -- every comma-separated integer except the last,
      as a list of ``int`` (``[15, 44, 79]`` above);
    * ``"answer"``     -- the last integer of the original ``"input"``
      string (``50`` above);
    * ``"trajectory"`` -- the ``"output"`` string split on commas
      (``['44-15=29', '79-29=50']`` above).

    The ``"output"`` entry itself is left untouched.
    """
    # Parse everything before mutating so a malformed "input" string
    # raises without leaving the record half-updated.
    *operands, goal = map(int, example["input"].split(","))

    example["input"] = operands
    example["answer"] = goal
    example["trajectory"] = example["output"].split(",")
    return example


class CountDownDataModule(BaseDataModule):
    """Data module that loads and reshapes a train/test pair of splits.

    The defaults point at the ``"countdown"`` configuration under the
    ``"reasoning-datasets"`` path, using the ``"cd4_train"`` and
    ``"cd4_test"`` splits.
    """

    def __init__(
        self,
        data_path: str = "reasoning-datasets",
        data_name: str = "countdown",
        train_split: str = "cd4_train",
        test_split: str = "cd4_test",
        **kwargs,
    ):
        """Forward dataset location to the base class and record split names.

        Args:
            data_path: Passed to the base class as ``data_path``.
            data_name: Passed to the base class as ``data_name``.
            train_split: Key of the split exposed as ``"train"`` by
                :meth:`preprocess`.
            test_split: Key of the split exposed as ``"test"`` by
                :meth:`preprocess`.
            **kwargs: Forwarded unchanged to ``BaseDataModule.__init__``.
        """
        super().__init__(data_path=data_path, data_name=data_name, **kwargs)
        self.train_split, self.test_split = train_split, test_split

    def prepare_data(self):
        """Call ``load_dataset`` for the configured path and name.

        The returned dataset is intentionally discarded.
        NOTE(review): presumably this only warms the local download cache;
        ``self.data_path`` / ``self.data_name`` are assumed to be set by
        ``BaseDataModule.__init__`` -- confirm against the base class.
        """
        path, name = self.data_path, self.data_name
        load_dataset(path, name)

    def preprocess(self, dataset):
        """Select the configured splits and apply ``cleanup_cd`` to each.

        Args:
            dataset: Mapping indexable by ``self.train_split`` and
                ``self.test_split``.

        Returns:
            A ``DatasetDict`` with keys ``"train"`` and ``"test"`` after
            mapping ``cleanup_cd`` over it.
        """
        train_part = dataset[self.train_split]
        test_part = dataset[self.test_split]
        selected = DatasetDict(train=train_part, test=test_part)
        return selected.map(cleanup_cd)
