from xmlrpc.client import Boolean
from transformers import pipeline, AutoTokenizer
from transformers.pipelines.base import KeyDataset
from datasets import load_dataset, Dataset
from tqdm import tqdm
import argparse
from glob import glob
import pandas as pd
import os
import math
import json
import nltk


def get_token_length(article, tokenizer=None):
    """Return the number of tokens ``tokenizer`` produces for ``article``.

    The tokenizer is called with ``max_length=None`` and
    ``return_tensors='pt'``. The count is the size of the last dimension of
    the resulting ``input_ids``.

    Args:
        article: text to tokenize.
        tokenizer: a callable tokenizer (e.g. a Hugging Face tokenizer). It is
            required in practice: the ``None`` default is not handled and
            calling it would fail.

    Returns:
        The token count as reported by ``input_ids.shape[-1]``.
    """
    input_ids = tokenizer(article, max_length=None, return_tensors='pt').input_ids
    return input_ids.shape[-1]


def write_summaries(paths, summaries, original_text, outpath, i, num_batches):
    """Append one batch of summaries to a CSV file.

    The header row is written only when ``outpath`` does not exist yet, so
    repeated calls build up a single CSV.

    Args:
        paths: file names, one per summary (``file_name`` column).
        summaries: generated summaries (``summary`` column).
        original_text: source texts (``original`` column).
        outpath: CSV file to append to.
        i: zero-based index of this batch (used only for the progress message).
        num_batches: total number of batches (used only for the progress message).
    """
    print(f'writing {len(summaries)} summaries to {outpath}')
    rows = list(zip(paths, summaries, original_text))
    frame = pd.DataFrame(rows, columns=['file_name', 'summary', 'original'])
    needs_header = not os.path.exists(outpath)
    frame.to_csv(outpath, mode='a', header=needs_header)
    print(f'\rwrote batch {i+1}/{num_batches} to {outpath}')


def chunk_article(article, chunk_size, tokenizer=None):
    """Greedily pack consecutive sentences into chunks below ``chunk_size``.

    Args:
        article: iterable of sentence strings. NOTE(review): a plain ``str``
            would be iterated character by character; callers presumably
            sentence-split beforehand -- confirm.
        chunk_size: the summed length of a chunk's sentences must stay
            strictly below this value. The only exception is a chunk holding a
            single sentence, which may be up to ``chunk_size`` long.
        tokenizer: optional tokenizer used (via ``get_token_length``) to
            measure each sentence. When omitted, length is approximated by the
            number of whitespace-separated words. That can undercount subword
            tokens, so pass the model's tokenizer when the limit is a hard
            model constraint.

    Returns:
        List of chunk strings whose sentences are joined by single spaces. An
        empty ``article`` yields ``['']``.

    Raises:
        ValueError: if a single sentence is longer than ``chunk_size``.
    """

    def _length(text):
        # The original always called get_token_length without a tokenizer,
        # which crashed; fall back to a word count when none is supplied.
        if tokenizer is None:
            return len(text.split())
        return get_token_length(text, tokenizer=tokenizer)

    chunks = []
    chunk = []
    length = 0

    for sent in article:
        sent_length = _length(sent)
        if length + sent_length < chunk_size:
            chunk.append(sent)
            length += sent_length
            continue

        # Flush the current chunk. Skip the flush if it is empty, which
        # happens when the very first sentence already reaches the limit.
        if chunk:
            chunks.append(' '.join(chunk))

        if sent_length > chunk_size:
            raise ValueError(
                f'sentence of length {sent_length} exceeds chunk_size='
                f'{chunk_size}: {sent[:80]!r}')

        chunk = [sent]
        # Measure in the same unit as the accumulation above (the original
        # reset this to the character count, len(sent)).
        length = sent_length

    # Keep the original contract of returning [''] for an empty article.
    if chunk or not chunks:
        chunks.append(' '.join(chunk))
    return chunks


def bin_articles_by_length(articles, model_name, max_range=100):
    """Group articles into equal-width bins of token length.

    Args:
        articles: sequence of article strings.
        model_name: model id whose tokenizer (``AutoTokenizer``) measures
            each article's length.
        max_range: target width of a bin in tokens. The number of bins is
            ``ceil((longest - shortest) / max_range)``, with a minimum of 1.

    Returns:
        A list with one dict per non-empty bin:
        ``'text'`` is a ``datasets.Dataset`` with an ``'article'`` column and
        a ``'__index_level_0__'`` column holding each article's position in
        ``articles``; ``'min_length'`` / ``'max_length'`` are the shortest /
        longest token counts in that bin. An empty ``articles`` yields ``[]``.
    """
    df = pd.DataFrame(articles, columns=['article'])
    if df.empty:
        # lengths.max() would be NaN and math.ceil(NaN) raises.
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    lengths = df.article.apply(get_token_length, tokenizer=tokenizer)

    # If every article has the same length the span is 0, which gave 0 bins
    # and made pd.cut fail; always use at least one bin.
    num_bins = max(1, math.ceil((lengths.max() - lengths.min()) / max_range))
    labels = pd.cut(lengths, num_bins, labels=list(range(num_bins)))

    bins = []
    for label in range(num_bins):
        mask = labels == label
        if not mask.any():
            continue
        bins.append({
            # preserve_index=True forces the positional index to be stored as
            # the '__index_level_0__' column (a RangeIndex would otherwise be
            # kept only as metadata); summarize_articles relies on it.
            'text': Dataset.from_pandas(df[mask], preserve_index=True),
            'min_length': lengths[mask].min(),
            'max_length': lengths[mask].max(),
        })
    return bins


def summarize(summarizer, article, max_length, min_length, do_sample):
    """Summarize ``article`` in chunks and concatenate the chunk summaries.

    The article is split with ``chunk_article(article, 1024)``. The overall
    ``max_length`` / ``min_length`` budget is divided evenly (floor division)
    among the chunks.

    Returns:
        The chunk summaries joined by single spaces.
    """
    pieces = chunk_article(article, 1024)
    piece_count = len(pieces)
    outputs = summarizer(
        pieces,
        max_length=max_length // piece_count,
        min_length=min_length // piece_count,
        do_sample=do_sample,
    )
    return ' '.join([out['summary_text'] for out in outputs])


def summarize_articles(summarizer, model_name, dataset, save_path, length_percent=None, min_length=60, max_length=75, batch_size=5, do_sample=False):
    """Summarize every article in ``dataset`` and write the results to CSV.

    Args:
        summarizer: a summarization pipeline. It is iterated over a
            ``KeyDataset`` and each yielded item is indexed as
            ``item[0]['summary_text']``.
        model_name: model id, used only to tokenize articles when binning.
        dataset: dataset with ``'id'`` and ``'article'`` columns.
        save_path: CSV file written with ``id``, ``article`` and ``summary``
            columns (overwritten).
        length_percent: when given, articles are binned by token length
            (``bin_articles_by_length``). Each bin's summary bounds are this
            fraction of the bin's shortest / longest article, rounded up and
            capped at 1024, and ``min_length`` / ``max_length`` are ignored.
        min_length, max_length: summary length bounds used when
            ``length_percent`` is None.
        batch_size: pipeline batch size.
        do_sample: forwarded to the pipeline.
    """
    print('generating summaries')
    shared_kwargs = {'batch_size': batch_size, 'do_sample': do_sample, 'truncation': True}

    def _clean(result):
        # Each streamed result is a one-element list; '<n>' markers become spaces.
        return result[0]['summary_text'].replace('<n>', ' ')

    if length_percent is None:
        stream = summarizer(KeyDataset(dataset, 'article'), max_length=max_length, min_length=min_length, **shared_kwargs)
        summaries = [_clean(result) for result in tqdm(stream, total=len(dataset), unit='ba')]
    else:
        summaries = [''] * len(dataset)
        buckets = bin_articles_by_length(dataset['article'], model_name)
        print(f'-----Starting summarization over {len(buckets)} bins-----')
        for bucket in buckets:
            texts = bucket['text']
            upper = min(math.ceil(length_percent * bucket['max_length']), 1024)
            lower = min(math.ceil(length_percent * bucket['min_length']), 1024)

            # With fewer than 5 tokens allowed for even the longest article in
            # the bin, no useful summary can be produced; mark these empty.
            if upper < 5:
                for row in range(len(texts)):
                    summaries[texts[row]['__index_level_0__']] = ''
                continue

            stream = summarizer(KeyDataset(texts, 'article'), max_length=upper, min_length=lower, **shared_kwargs)
            # Outputs arrive in bin order; map each back to its original position.
            for row, result in enumerate(tqdm(stream, total=len(texts), unit='ba')):
                summaries[texts[row]['__index_level_0__']] = _clean(result)

    frame = pd.DataFrame(zip(dataset['id'], dataset['article'], summaries), columns=[
        'id', 'article', 'summary'])
    frame.to_csv(save_path)


def summarize_recursive(summarizer, article, max_length, min_length, do_sample):
    """Summarize ``article`` chunk-wise, then re-summarize until one chunk remains.

    Each pass splits the input with ``chunk_article(..., 1024)`` and
    summarizes every chunk. If there was more than one chunk, the list of
    chunk summaries is fed back in as the next "article".

    NOTE(review): termination relies on the summaries eventually fitting in a
    single chunk. If a pass does not shrink the text, this recurses until
    Python's recursion limit is hit.

    Returns:
        The final summary string.
    """
    chunks = chunk_article(article, 1024)
    outputs = summarizer(chunks, max_length=max_length,
                         min_length=min_length, do_sample=do_sample)
    summaries = [x['summary_text'] for x in outputs]

    if len(chunks) == 1:
        return summaries[0]
    # The original omitted this `return`, so any multi-chunk input yielded None.
    return summarize_recursive(summarizer, summaries, max_length, min_length, do_sample)

def load_data(dataset_file, dataset_name, dataset_version, dataset_split, col_name='article'):
    """Load articles from a local CSV or from the Hugging Face hub.

    Args:
        dataset_file: path to a CSV with an ``'id'`` column and ``col_name``.
            When given, the hub arguments are ignored.
        dataset_name: hub dataset id (used when ``dataset_file`` is None).
        dataset_version: hub dataset config name.
        dataset_split: split to load. If None, ``load_dataset`` returns a
            ``DatasetDict`` rather than a single ``Dataset``.
        col_name: CSV column holding the text. It is renamed to ``'article'``.

    Returns:
        For a CSV, a ``Dataset`` with ``'id'`` and ``'article'`` columns and
        rows with missing values dropped; otherwise whatever
        ``load_dataset`` returns.
    """
    if dataset_file is not None:
        frame = pd.read_csv(dataset_file)[['id', col_name]]
        frame = frame.rename(columns={col_name: 'article'}).dropna()
        # dropna leaves gaps in the index; don't store it as an extra column.
        return Dataset.from_pandas(frame, preserve_index=False)

    # `split` must be passed by keyword: load_dataset's third positional
    # parameter is `data_dir`, so the split was previously ignored.
    return load_dataset(dataset_name, dataset_version, split=dataset_split)

def main(args):
    """Run summarization as configured by the parsed command-line ``args``.

    Non-recursive mode summarizes the dataset once and writes to
    ``args.save_path``. Recursive mode runs ``args.num_steps`` passes, each
    summarizing the previous pass's summaries. Each pass is written to
    ``{save_path}/{step}_{length_percent}/{dataset file name}``.

    Raises:
        ValueError: in recursive mode, if ``dataset_file``,
            ``length_percent`` or ``num_steps`` is missing.
    """
    if args.recursive:
        # Validate before the (slow) model load; previously these failed later
        # with `range(None)` or `None.split`.
        if args.dataset_file is None:
            raise ValueError('--recursive requires --dataset_file')
        if args.length_percent is None or args.num_steps is None:
            raise ValueError('--recursive requires both --length_percent and --num_steps')

    summarizer = pipeline('summarization', model=args.model_name, tokenizer=args.model_name, device=args.device)

    if not args.recursive:
        dataset = load_data(args.dataset_file, args.dataset, args.data_version, args.data_split)
        summarize_articles(summarizer, args.model_name, dataset, args.save_path, args.length_percent, args.min_length, args.max_length, args.batch_size, args.do_sample)
        return

    dataset = load_data(args.dataset_file, None, None, None)
    file_name = os.path.basename(args.dataset_file)

    for i in range(args.num_steps):
        print(f'-----starting step {i+1}/{args.num_steps}-----')
        save_dir = f'{args.save_path}/{i}_{args.length_percent}'
        save_path = f'{save_dir}/{file_name}'

        # makedirs also creates a missing args.save_path parent directory.
        os.makedirs(save_dir, exist_ok=True)

        summarize_articles(summarizer, args.model_name, dataset, save_path, args.length_percent, None, None, args.batch_size, args.do_sample)
        # The next step summarizes this step's output summaries.
        dataset = load_data(save_path, None, None, None, 'summary')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # Model and data source.
    add_arg('--model_name', help='name of model on hugging face hub in org/model format')
    add_arg('--dataset_file', help='file path to dataset if using local version', type=str)
    add_arg('--dataset', help='name of dataset to use', type=str)
    add_arg('--data_version', help='version of dataset to use', type=str)
    add_arg('--data_split', help='split of data to use', type=str)
    add_arg('--suffix', help='pattern that article files end in', default='*')

    # Summary length and generation settings.
    add_arg('--max_length', help='max length for the summary', default=75, type=int)
    add_arg('--min_length', help='min length for the summary', default=60, type=int)
    add_arg('--length_percent', help='percent of the original article length that each generated summary should be', type=float)
    add_arg('--do_sample', help='do_sample value for summarizer', action='store_true')

    # Output, batching and run mode.
    add_arg('--save_path', help='path to csv to write summaries')
    add_arg('--batch_size', help='number of articles to process at once', default=5, type=int)
    add_arg('--recursive', help='flag to indicate whether final summaries should be generated by recursively summarizing intermediate summaries. If true, both length_percent and num_steps must be specified', action='store_true')
    add_arg('--num_steps', help='total number of summaries that should be generated in recursive summarization', type=int)
    add_arg('--truncate', help='whether long articles should be truncated', action='store_true')
    add_arg('--device', help='device to run on', default=-1, type=int)

    cli_args = parser.parse_args()
    print(cli_args)

    main(cli_args)
