# def process_in_ds(ds):
#     ds = ds.rename_column("article", "input")
#     ds = ds.rename_column("highlights", "reference")
#     return ds


def get_in_ds():
    """Load the CNN/DailyMail 3.0.0 test split (currently disabled).

    This loader is switched off: it raises ``NotImplementedError`` on every
    call. The code after the ``raise`` is unreachable and is kept only as a
    reference for how the split used to be built:

    * shuffle the full dataset with seed 42 and take the ``test`` split;
    * when the ``EXP_DEBUG`` environment variable is ``"1"``, keep only the
      first two rows; when it is ``"2"``, keep shard 0 of 100;
    * sort the result by its ``id`` column.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
    # ---- unreachable reference implementation below ----
    import os
    from datasets import load_dataset

    test_split = load_dataset("cnn_dailymail", "3.0.0").shuffle(seed=42)["test"]

    debug_level = os.environ.get("EXP_DEBUG", None)
    if debug_level == "1":
        test_split = test_split.select(range(0, 2))
    elif debug_level == "2":
        test_split = test_split.shard(num_shards=100, index=0)

    return test_split.sort("id")

import tqdm
from datasets import Dataset
import pandas as pd
import json
def get_in_ds_undetectable_exp(prompt_num=2000, repeat_num=1, truncate_num=100,
                               dataset_path='dataset/c4_subset.json'):
    """Build a prompt/continuation dataset from a local C4 subset JSON file.

    The texts in ``dataset_path`` are shuffled with a fixed seed. The first
    ``prompt_num`` of them are each split after ``truncate_num``
    space-separated words: the prefix becomes ``input`` and the remainder
    becomes ``reference``. The whole block of prompts is emitted
    ``repeat_num`` times, so every prompt appears ``repeat_num`` times under
    distinct ``id`` values.

    Args:
        prompt_num: number of distinct texts to use.
        repeat_num: number of times the block of prompts is repeated.
        truncate_num: number of words kept in ``input``.
        dataset_path: JSON file whose top-level value is a list of texts
            (presumably raw C4 documents as strings; confirm against the
            producer of the file).

    Returns:
        A ``datasets.Dataset`` with columns ``input``, ``reference``, ``id``
        and ``reference_id``, sorted by ``id``. ``id`` equals
        ``repeat_idx * prompt_num + prompt_idx`` (0 .. prompt_num*repeat_num-1).
        ``reference_id`` is ``prompt_idx`` and is shared by all repeats of
        the same prompt.

    Raises:
        ValueError: if the file holds fewer than ``prompt_num`` texts, or a
            selected text does not have more than ``truncate_num`` words.

    Note:
        Seeds the *global* ``random`` module with 43 before shuffling. This
        is kept so that results stay reproducible with earlier runs.
    """
    import random

    def truncate_text(text, word_num):
        # Split on single spaces (not str.split()) so the original spacing in
        # each half is preserved when the words are re-joined.
        words = text.split(' ')
        # Strictly more words than word_num, so the reference is non-empty.
        if len(words) <= word_num:
            raise ValueError(
                f'text has {len(words)} words; need more than {word_num} '
                f'to split it into input and reference'
            )
        return ' '.join(words[:word_num]), ' '.join(words[word_num:])

    print('generating text generation dataset...')
    with open(dataset_path, 'r', encoding='utf-8') as f:
        c4_subset = json.load(f)

    random.seed(43)
    random.shuffle(c4_subset)

    if prompt_num > len(c4_subset):
        raise ValueError(
            f'prompt_num={prompt_num} exceeds the {len(c4_subset)} texts '
            f'available in {dataset_path}'
        )

    # Each prompt's split is identical across repeats, so compute it once.
    prompts = [truncate_text(c4_subset[prompt_idx], word_num=truncate_num)
               for prompt_idx in range(prompt_num)]

    rows = []
    for repeat_idx in tqdm.tqdm(range(repeat_num)):
        for prompt_idx, (new_input, new_reference) in enumerate(prompts):
            rows.append({
                'input': new_input,
                'reference': new_reference,
                'id': repeat_idx * prompt_num + prompt_idx,
                'reference_id': prompt_idx,
            })

    # Explicit columns keep the schema (and its order) stable even if rows is empty.
    frame = pd.DataFrame(rows, columns=['input', 'reference', 'id', 'reference_id'])
    ds_subset = Dataset.from_pandas(frame, preserve_index=False)

    return ds_subset.sort("id")

# def get_merged_ds(path):
#     in_ds = get_in_ds()

#     from datasets import load_dataset

#     out_ds = load_dataset("json", data_files={"test": path})["test"]
#     out_ds = out_ds.sort("id")

#     from experiments.common import add_reference, group_batch

#     ds = add_reference(in_ds, out_ds)
#     return ds


from . import get_output
from . import evaluate_ppl
from . import evaluate_beta_score
