import sys
from collections.abc import Mapping
from functools import partial
from numbers import Number

import datasets
from datasets import load_from_disk

def usage():
    """Print the command-line synopsis and an example invocation to stdout."""
    help_lines = (
        "Usage: python reduce_seq.py <dataset_path> <output_path> <new_len>",
        "Example: python reduce_seq.py ./dataset ./dataset_reduced 256",
    )
    for help_line in help_lines:
        print(help_line)

def truncate_input_ids(example, new_len):
    """Truncate every sequence-valued column of one row to ``new_len`` items.

    Despite the name, this is applied to *all* columns, not only
    ``input_ids``, so equal-length parallel columns (e.g. ``attention_mask``)
    end up the same length as ``input_ids``.

    Scalar values (``int``/``float``/``bool``/``None``) and mapping values
    (nested struct columns) have no length to cut and are passed through
    unchanged instead of raising.

    Args:
        example: one row as a mapping of column name -> value, as passed by
            ``Dataset.map`` without ``batched=True``.
        new_len: maximum number of items to keep in each sequence column.

    Returns:
        A new dict with the truncated values; ``example`` is not modified.
    """
    truncated = {}
    for key, value in example.items():
        if value is None or isinstance(value, (Number, Mapping)):
            # Slicing these would raise (or, for dicts on 3.12+, do a key lookup).
            truncated[key] = value
        else:
            # Lists, tuples, strings, arrays: keep at most new_len leading items.
            truncated[key] = value[:new_len]
    return truncated

def reduce_seq(dataset, new_len):
    """Truncate every row of ``dataset`` to at most ``new_len`` items per column.

    Args:
        dataset: a ``datasets.Dataset`` (any object with a compatible ``map``).
        new_len: maximum number of items to keep per sequence column. Must be
            non-negative; ``None`` leaves sequences uncut.

    Returns:
        Whatever ``dataset.map`` returns for the per-row truncation.

    Raises:
        ValueError: if ``new_len`` is negative.
    """
    # A negative bound would silently slice from the end (``seq[:-k]``) and
    # drop the last k items instead of capping the length, so reject it
    # before doing any (potentially long) mapping work.
    if new_len is not None and new_len < 0:
        raise ValueError(f"new_len must be non-negative, got {new_len}")
    return dataset.map(partial(truncate_input_ids, new_len=new_len))

def main(argv=None):
    """Truncate a dataset saved on disk and write the result to a new path.

    Command line: ``<dataset_path> <output_path> <new_len>``.

    If ``dataset_path`` holds a single ``Dataset``, the whole dataset is
    truncated. If it holds a ``DatasetDict``, the ``train`` and
    ``validation`` splits are truncated when present. Any other splits are
    saved unchanged.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 on invalid arguments or when
        none of the expected splits exist.
    """
    import os

    args = sys.argv[1:] if argv is None else argv
    if len(args) != 3:
        usage()
        return 1

    dataset_path, output_path, new_len_arg = args
    try:
        new_len = int(new_len_arg)
    except ValueError:
        print(f"error: <new_len> must be an integer, got {new_len_arg!r}", file=sys.stderr)
        usage()
        return 1
    if new_len < 0:
        print(f"error: <new_len> must be non-negative, got {new_len}", file=sys.stderr)
        usage()
        return 1

    # Refuse to write over the input directory: the loaded dataset is still
    # backed by the files there, and failing here avoids a wasted map pass.
    if os.path.realpath(dataset_path) == os.path.realpath(output_path):
        print("error: <output_path> must differ from <dataset_path>", file=sys.stderr)
        return 1

    loaded = load_from_disk(dataset_path)

    if isinstance(loaded, datasets.Dataset):
        # A single saved Dataset has no splits; indexing it by split name
        # would look up a column instead.
        reduced = reduce_seq(loaded, new_len)
    else:
        # NOTE(review): only train/validation are truncated (as before);
        # other splits such as "test" are copied as-is. Confirm intent.
        present = [split for split in ("train", "validation") if split in loaded]
        if not present:
            print(
                f"error: no 'train' or 'validation' split in {dataset_path!r}; "
                f"found {sorted(loaded.keys())}",
                file=sys.stderr,
            )
            return 1
        for split in ("train", "validation"):
            if split not in present:
                print(f"warning: split {split!r} not found; skipping", file=sys.stderr)
        for split in present:
            loaded[split] = reduce_seq(loaded[split], new_len)
        reduced = loaded

    reduced.save_to_disk(output_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())
