import argparse
from pathlib import PurePath

import pandas as pd
from datasets import load_dataset, Dataset
from nltk.tokenize import sent_tokenize
from tqdm import tqdm

def main(args):
    if args.dataset_file is not None:
        if '.json' in args.dataset_file:
            df = pd.read_json(args.dataset_file, lines=True)
        else:
            df = pd.read_csv(args.dataset_file)
        dataset = Dataset.from_dict(df)
    else:
        dataset = load_dataset(args.dataset, args.dataset_version, split=args.data_splits)

    with open(args.save_path, 'w+') as f_save:
        for obj in tqdm(dataset):
            text = obj[args.column_name]
        
            if len(text) == 0:
                continue
            
            if args.summarization_mode == 'ext':
                sents = sent_tokenize(text)
                text = ' [CLS] [SEP] '.join(sents)
            
            f_save.write(text + '\n')

def _build_arg_parser():
    """Build the command-line parser for this script.

    Every option takes a string value. All options except ``--dataset_file``
    have a default, and options are registered in the order listed below
    (which is also their order in ``--help``).
    """
    parser = argparse.ArgumentParser()
    # (flag, help text, default); a default of None means "no default".
    option_table = (
        ('--dataset_file',
         'Path to local file to use as dataset',
         None),
        ('--dataset',
         'Name of dataset on huggingface hub (Default: cnn_dailymail)',
         'cnn_dailymail'),
        ('--dataset_version',
         'Version of dataset in huggingface hub (Default: 3.0.0)',
         '3.0.0'),
        ('--data_splits',
         'splits of dataset to use (Default: train+validation+test).',
         'train+validation+test'),
        ('--save_path',
         'Path to output file (Default: ../../user/presumm_cnn_dailymail_ext.txt)',
         '../../user/presumm_cnn_dailymail_ext.txt'),
        ('--summarization_mode',
         'mode for summarizer (defines output file format). Default: ext',
         'ext'),
        ('--column_name',
         'column to use in dataset',
         'article'),
    )
    for flag, help_text, default_value in option_table:
        parser.add_argument(flag, help=help_text, default=default_value, type=str)
    return parser


if __name__ == '__main__':
    cli_args = _build_arg_parser().parse_args()
    main(cli_args)