from summa import summarizer
from argparse import ArgumentParser
from datasets import load_dataset
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm
import pandas as pd
import os


def summarize_article(data, max_words=None, max_percent=None):
    """Summarize the ``'article'`` field of a single dataset record.

    Args:
        data: Mapping that holds the text to summarize under ``'article'``.
        max_words: Length limit forwarded as ``words=`` to
            ``summarizer.summarize``. Takes precedence over ``max_percent``
            when both are given.
        max_percent: Length limit forwarded as ``ratio=`` to
            ``summarizer.summarize``. Only used when ``max_words`` is None.

    Returns:
        The summary produced by ``summarizer.summarize``, or ``''`` when the
        article is ``None``, NaN or empty.

    Raises:
        ValueError: If the article is non-empty and neither ``max_words`` nor
            ``max_percent`` was provided.
    """
    article = data['article']

    # Missing text shows up as None or as NaN (e.g. an empty CSV cell read
    # back by pandas); there is nothing to summarize in either case.
    if article is None or pd.isna(article) or len(article) == 0:
        return ''

    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would silently pass ratio=None to the summarizer.
    if max_words is None and max_percent is None:
        raise ValueError('either max_words or max_percent must be provided')

    if max_words is not None:
        return summarizer.summarize(article, words=max_words)
    return summarizer.summarize(article, ratio=max_percent)


def summarize(dataset, save_path, nworkers, summary_length_percent, summary_length_tokens):
    """Summarize every record of ``dataset`` in parallel and save a CSV.

    Args:
        dataset: Sized iterable of records exposing ``'id'`` and
            ``'article'`` keys.
        save_path: Path of the CSV file to write. Its parent directory is
            created when it does not exist yet.
        nworkers: Number of worker processes in the pool.
        summary_length_percent: Summary length passed to
            ``summarize_article`` as ``max_percent``. Takes precedence over
            ``summary_length_tokens`` when both are given.
        summary_length_tokens: Summary length passed to
            ``summarize_article`` as ``max_words``.

    Raises:
        ValueError: If both length arguments are None.
    """
    # Fail fast in the parent process instead of deep inside a pool worker.
    if summary_length_percent is None and summary_length_tokens is None:
        raise ValueError(
            'either summary_length_percent or summary_length_tokens must be provided')

    print(
        f'generating summaries of {len(dataset)} and saving to {save_path}')

    # Create the destination directory before doing any work so a missing
    # folder cannot throw away the whole summarization run at save time.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if summary_length_percent is not None:
        summarize_art = partial(summarize_article, max_percent=summary_length_percent)
    else:
        summarize_art = partial(summarize_article, max_words=summary_length_tokens)

    with Pool(nworkers) as pool:
        # Wrap the result iterator (not the input) so the progress bar
        # advances as summaries are returned rather than as tasks are queued.
        results = pool.imap(summarize_art, dataset, chunksize=10)
        summaries = list(tqdm(results, total=len(dataset)))

    df = pd.DataFrame(
        zip([x['id'] for x in dataset], [x['article'] for x in dataset], summaries),
        columns=['id', 'article', 'summary'],
    )
    df.to_csv(save_path)


def load_data(dataset_file, dataset_name, dataset_version, dataset_split, col_name='article'):
    """Load the records to summarize from a CSV file or via ``load_dataset``.

    When ``dataset_file`` is given, the CSV is read, reduced to the ``'id'``
    and ``col_name`` columns, and ``col_name`` is renamed to ``'article'``;
    the result is a list of per-row dicts. Otherwise the dataset is obtained
    from ``load_dataset(dataset_name, dataset_version, split=dataset_split)``.
    """
    if dataset_file is not None:
        frame = pd.read_csv(dataset_file)
        selected = frame[['id', col_name]]
        renamed = selected.rename(columns={col_name: 'article'})
        return renamed.to_dict(orient='records')

    return load_dataset(dataset_name, dataset_version, split=dataset_split)


def main(args):
    """Run single-pass or recursive summarization from parsed CLI arguments.

    Without ``args.recursive`` the dataset is loaded once and summarized into
    ``args.save_path``. With ``args.recursive`` the CSV at
    ``args.dataset_file`` is summarized ``args.num_steps`` times; each step
    writes to ``{save_path}/{step}_{summary_length_percent}/{file name}`` and
    the summaries of one step become the articles of the next.

    Raises:
        ValueError: In recursive mode, if ``dataset_file``, ``num_steps`` or
            ``summary_length_percent`` is missing.
    """
    if not args.recursive:
        dataset = load_data(args.dataset_file, args.dataset, args.dataset_version, args.data_splits)
        summarize(dataset, args.save_path, args.nworkers, args.summary_length_percent, args.summary_length_tokens)
        return

    # Validate up front so a missing flag gives a clear message instead of an
    # AttributeError/TypeError part-way through the run.
    if args.dataset_file is None:
        raise ValueError('--recursive requires --dataset_file')
    if args.num_steps is None:
        raise ValueError('--recursive requires --num_steps')
    if args.summary_length_percent is None:
        raise ValueError('--recursive requires --summary_length_percent')

    dataset = load_data(args.dataset_file, None, None, None)
    file_name = os.path.basename(args.dataset_file)

    for i in range(args.num_steps):
        print(f'-----starting step {i+1}/{args.num_steps}-----')
        save_dir = f'{args.save_path}/{i}_{args.summary_length_percent}'
        save_path = f'{save_dir}/{file_name}'

        # makedirs also creates args.save_path itself when it is missing and
        # does not fail if the step directory already exists.
        os.makedirs(save_dir, exist_ok=True)

        summarize(dataset, save_path, args.nworkers, args.summary_length_percent, None)
        # Feed this step's summaries back in as the next step's articles.
        dataset = load_data(save_path, None, None, None, col_name='summary')

if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = ArgumentParser()

    # Where the input data comes from.
    parser.add_argument(
        '--dataset_file',
        type=str,
        help='path to csv file containing dataset',
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='cnn_dailymail',
        help='name of dataset on huggingface hub (Default: cnn_dailymail).',
    )
    parser.add_argument(
        '--dataset_version',
        type=str,
        default='3.0.0',
        help='version of dataset (default: 3.0.0)',
    )
    parser.add_argument(
        '--data_splits',
        type=str,
        default='train+validation+test',
        help='splits of data to summarize (Default: train+validation+test)',
    )

    # Where the output goes and how much parallelism to use.
    parser.add_argument(
        '--save_path',
        type=str,
        default='./summaries/cnn_dailymail.csv',
        help='directory to save summaries dataframe (default: ./summaries/cnn_dailymail.csv)',
    )
    parser.add_argument(
        '--nworkers',
        type=int,
        default=10,
        help='number of workers (default: 10)',
    )

    # Summary length, either as a token count or as a fraction.
    parser.add_argument(
        '--summary_length_tokens',
        type=int,
        help='max summary length in tokens',
    )
    parser.add_argument(
        '--summary_length_percent',
        type=float,
        help='Summary length as a percent length of the original article',
    )

    # Recursive summarization controls.
    parser.add_argument(
        '--recursive',
        action='store_true',
        help='flag for recursive summarization. If true, save_path should be a directory',
    )
    parser.add_argument(
        '--num_steps',
        type=int,
        help='Number of steps to use in recursive summarization',
    )

    args = parser.parse_args()

    main(args)
