def process_in_ds(ds):
    """Rename the raw summarization columns to the experiment's field names.

    ``article`` becomes ``input`` and ``highlights`` becomes ``reference``.

    Args:
        ds: Object exposing ``rename_column(old, new)`` that returns the
            renamed dataset (e.g. a ``datasets.Dataset``).

    Returns:
        The dataset produced after both renames have been applied.
    """
    column_renames = (
        ("article", "input"),
        ("highlights", "reference"),
    )
    for old_name, new_name in column_renames:
        ds = ds.rename_column(old_name, new_name)
    return ds


def get_in_ds():
    """Disabled loader for the shuffled CNN/DailyMail test split.

    This function currently always raises ``NotImplementedError`` as its
    first statement; everything after the ``raise`` is unreachable.

    The unreachable body loads ``cnn_dailymail`` 3.0.0, shuffles it with
    seed 42, takes the ``test`` split, optionally shrinks it according to
    the ``EXP_DEBUG`` environment variable ("1": first two rows, "2": shard
    0 of 100), renames columns via ``process_in_ds`` and sorts by ``id``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
    from datasets import load_dataset

    test_split = load_dataset("cnn_dailymail", "3.0.0").shuffle(seed=42)["test"]

    import os

    # Read the debug switch once; the two modes are mutually exclusive.
    debug_mode = os.environ.get("EXP_DEBUG", None)
    if debug_mode == "1":
        test_split = test_split.select(range(0, 2))
    elif debug_mode == "2":
        test_split = test_split.shard(num_shards=100, index=0)

    return process_in_ds(test_split).sort("id")

import tqdm
from datasets import Dataset
import pandas as pd
def get_in_ds_undetectable_exp(prompt_num=10, repeat_num=10000):
    """Build a dataset that repeats the first ``prompt_num`` test prompts.

    Loads ``cnn_dailymail`` 3.0.0, shuffles it with seed 42 and takes the
    ``test`` split. The first ``prompt_num`` rows are then replicated
    ``repeat_num`` times. Every replicated row gets a unique integer ``id``
    (``repeat_idx * prompt_num + prompt_idx``) and a ``reference_id`` equal
    to the index of the prompt it was copied from.

    Args:
        prompt_num: Number of distinct prompts to take from the start of the
            shuffled test split. ``-1`` means "use every row of the split".
        repeat_num: How many times each prompt is replicated.

    Returns:
        A ``datasets.Dataset`` sorted by ``id`` whose columns have been
        passed through ``process_in_ds`` (``article`` -> ``input``,
        ``highlights`` -> ``reference``), plus ``id`` and ``reference_id``.

    Raises:
        ValueError: If ``prompt_num`` is negative and not ``-1``.
        IndexError: If ``prompt_num`` exceeds the size of the test split.
    """
    from datasets import load_dataset

    cnn_daily = load_dataset("cnn_dailymail", "3.0.0").shuffle(seed=42)
    ds = cnn_daily["test"]

    if prompt_num == -1:
        prompt_num = len(ds)
    if prompt_num < 0:
        # A negative count would otherwise silently slice from the end.
        raise ValueError(
            f"prompt_num must be -1 or non-negative, got {prompt_num}"
        )
    if prompt_num > len(ds):
        raise IndexError(
            f"prompt_num={prompt_num} exceeds dataset size {len(ds)}"
        )

    # The prompt rows are the same in every repeat, so read them from the
    # dataset once instead of doing two row lookups per generated item.
    prompt_rows = ds[:prompt_num]
    articles = prompt_rows["article"]
    highlights = prompt_rows["highlights"]

    rows = []
    for repeat_idx in tqdm.tqdm(range(repeat_num)):
        id_offset = repeat_idx * prompt_num
        rows.extend(
            {
                "article": article,
                "highlights": summary,
                "id": id_offset + prompt_idx,
                "reference_id": prompt_idx,
            }
            for prompt_idx, (article, summary) in enumerate(
                zip(articles, highlights)
            )
        )

    ds_subset = Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)

    ds_subset = process_in_ds(ds_subset)
    ds_subset = ds_subset.sort("id")

    return ds_subset

def get_merged_ds(path):
    """Disabled helper that would merge generated outputs with references.

    This function currently always raises ``NotImplementedError`` as its
    first statement; everything after the ``raise`` is unreachable.

    The unreachable body obtains the input dataset from ``get_in_ds``,
    loads the JSON file at ``path`` as a ``test`` split sorted by ``id``,
    and combines the two with ``experiments.common.add_reference``.

    Args:
        path: Location of the JSON file passed to ``load_dataset``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
    reference_ds = get_in_ds()

    from datasets import load_dataset

    generated = load_dataset("json", data_files={"test": path})
    generated_test = generated["test"].sort("id")

    from experiments.common import add_reference, group_batch

    return add_reference(reference_ds, generated_test)


from . import get_output
from . import evaluate
from . import evaluate_ppl

