














"""
This script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.
It uses the same approach presented in https://arxiv.org/abs/2501.19393,
as found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py

Usage:

python scripts/decontaminate.py \
    --dataset open-r1/verifiable-coding-problems-python \
    --split train \
    --ngram_size 8 \
    --problem_column problem \
    --cleanup
"""

import collections 

from tqdm import tqdm 


def normalize_string(text: str) -> str:
    """Lowercase ``text`` and collapse every whitespace run into one space.

    Leading and trailing whitespace is removed as part of the collapse.
    """
    lowered = text.lower()
    # str.split() without a separator ignores leading/trailing whitespace
    # and treats any run of whitespace as a single delimiter.
    tokens = lowered.split()
    return " ".join(tokens)


def word_ngrams(text: str, n: int) -> list:
    """Return every run of ``n`` consecutive whitespace-separated words.

    Each n-gram is the words joined by a single space, listed in order of
    appearance. A text with fewer than ``n`` words yields an empty list.
    """
    tokens = text.split()
    window_count = len(tokens) - n + 1
    grams = []
    for start in range(window_count):
        window = tokens[start:start + n]
        grams.append(" ".join(window))
    return grams


def build_ngram_lookup(
    documents: list[str], ngram_size: int = 8
) -> dict[str, set[int]]:
    """Map each word n-gram to the indices of the documents containing it.

    Every document is normalized and cut into n-grams of ``ngram_size``
    words; the document's position in ``documents`` is recorded under each
    of its n-grams. Iteration is wrapped in a tqdm progress bar.
    """
    index = collections.defaultdict(set)
    progress = tqdm(documents)

    for position, raw_text in enumerate(progress):
        cleaned = normalize_string(raw_text)
        for gram in word_ngrams(cleaned, ngram_size):
            index[gram].add(position)

    return index


def build_ngram_single(document: str, ngram_size: int = 8) -> set[str]:
    """Return the distinct word n-grams of a single normalized document."""
    return set(word_ngrams(normalize_string(document), ngram_size))


if __name__ == "__main__":
    # Command-line entry point: flag every row of the target dataset whose
    # problem text shares at least one word n-gram with an evaluation
    # dataset, optionally drop the flagged rows, and save the result to disk.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Name of the dataset to check for contamination.",
    )
    parser.add_argument(
        "--config", type=str, default=None, help="Name of the dataset config to load."
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Split to check for contamination, defaults to `train`.",
    )
    parser.add_argument(
        "--ngram_size",
        type=int,
        default=8,
        help="Size of n-grams to build, defaults to 8.",
    )
    parser.add_argument(
        "--problem_column",
        type=str,
        default="problem",
        help="Name of the column containing the problem (prompt).",
    )
    parser.add_argument(
        "--cleanup",
        action="store_true",
        help="Whether to remove the contaminated rows before pushing the dataset.",
    )
    # NOTE(review): this option is parsed but never read in the visible
    # script; it is kept so existing command lines keep working.
    parser.add_argument(
        "--new_dataset_name",
        type=str,
        default=None,
        help="New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory to save the decontaminated dataset. If not provided, will save the dataset in the current directory.",
    )
    parser.add_argument(
        "--num_proc",
        type=int,
        default=8,
        help="Number of processes used for mapping and filtering, defaults to 8.",
    )
    args = parser.parse_args()

    # Imported after argument parsing, as in the original script.
    from datasets import load_dataset, Dataset

    # Dataset to be checked for contamination.
    ds = load_dataset(args.dataset, name=args.config, split=args.split)

    # Evaluation datasets to compare against, keyed by a short name. Each
    # value is a (dataset, name of the column holding the problem text) pair.
    eval_datasets = {
        "lcb": (
            # NOTE(review): trust_remote_code=True allows the dataset
            # repository's loading script to run locally - confirm that
            # this is acceptable for the environment running this script.
            load_dataset(
                "livecodebench/code_generation_lite",
                split="test",
                version_tag="v1",
                trust_remote_code=True,
            ),
            "question_content",
        ),
    }

    # One n-gram -> document-index lookup per evaluation dataset.
    ngram_lookups = {}
    for ds_name, (eval_dataset, problem_col) in eval_datasets.items():
        ngram_lookups[ds_name] = build_ngram_lookup(
            eval_dataset[problem_col], ngram_size=args.ngram_size
        )

    # Add one boolean `contaminated_<eval_name>` column per evaluation dataset.
    for eval_name, ngram_lookup in ngram_lookups.items():

        def find_contaminated(row):
            """Flag ``row`` if any of its n-grams occurs in the current lookup."""
            ngrams = build_ngram_single(
                row[args.problem_column], ngram_size=args.ngram_size
            )
            # A plain generator lets any() stop at the first shared n-gram
            # instead of first materializing a set of booleans.
            row[f"contaminated_{eval_name}"] = any(
                ngram in ngram_lookup for ngram in ngrams
            )
            return row

        # map() runs inside the loop body, so the function is applied while
        # eval_name / ngram_lookup still refer to the current iteration.
        ds = ds.map(find_contaminated, num_proc=args.num_proc)

    def cleanup(dataset: Dataset) -> Dataset:
        """Drop rows flagged in any ``contaminated_*`` column, then drop those columns."""
        initial_size = len(dataset)
        contamination_cols = [
            col for col in dataset.column_names if col.startswith("contaminated_")
        ]
        for col in contamination_cols:
            size_prior = len(dataset)
            # Keep only the rows that were not flagged for this column.
            dataset = dataset.filter(lambda x: not x[col], num_proc=args.num_proc)
            if len(dataset) < size_prior:
                print(
                    f"Removed {size_prior - len(dataset)} samples from '{col.replace('contaminated_', '')}'"
                )
        dataset = dataset.remove_columns(contamination_cols)
        print(f"Initial size: {initial_size}, Final size: {len(dataset)}")
        return dataset

    if args.cleanup:
        ds = cleanup(ds)

    # Fall back to the current directory when no output directory is given.
    output_dir = args.output_dir or "./"

    ds.save_to_disk(output_dir)

    print(f"Decontaminated dataset saved locally at: {output_dir}")
