from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential
import pandas as pd
import time
from tqdm import trange
from argparse import ArgumentParser
from datasets import load_dataset, Dataset
from multiprocessing import Pool
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics.aio import TextAnalyticsClient
from azure.ai.textanalytics import ExtractSummaryAction
from azure.core.exceptions import HttpResponseError
import asyncio
import os
import math
from nltk.tokenize import sent_tokenize


# Azure credentials are taken from the environment; both are None when the
# corresponding variable is not set.
key = os.environ.get('AZURE_KEY')
endpoint = os.environ.get('AZURE_ENDPOINT')
# Running count of summarize() calls, used there for client-side throttling.
requests = 0

# Example method for summarizing text


async def summarize(batch, ids, max_sents=3):
    """Extractively summarize one batch of articles with Azure Text Analytics.

    Args:
        batch: Sequence of article texts, submitted together in a single
            ``begin_analyze_actions`` call.
        ids: Sequence of identifiers positionally aligned with ``batch``.
        max_sents: Maximum number of sentences per summary, passed to
            ``ExtractSummaryAction(max_sentence_count=...)``.

    Returns:
        A list of ``(id, article, summary)`` tuples, one for every document
        whose result is not an error. Documents that the service reports as
        errors are omitted.

    Raises:
        SystemExit: With status -1 (after printing the error message) when
            the request fails with an HTTP error other than throttling.
        HttpResponseError: If the single retry after a throttling response
            fails as well.
    """
    global requests

    client = TextAnalyticsClient(
        endpoint=endpoint,
        credential=AzureKeyCredential(key),
    )

    async def _submit():
        # Start the extractive-summary job for the whole batch.
        return await client.begin_analyze_actions(
            batch,
            actions=[
                ExtractSummaryAction(max_sentence_count=max_sents)
            ],
        )

    async with client:
        requests += 1

        # Client-side throttle: pause for a minute on every 1000th call.
        # The counter is reset before sleeping so calls made by other
        # coroutines during the pause are counted towards the next window.
        if requests == 1000:
            requests = 0
            # asyncio.sleep (not time.sleep) so the event loop keeps serving
            # the other in-flight requests while this coroutine waits.
            await asyncio.sleep(60)

        try:
            poller = await _submit()
        except HttpResponseError as e:
            # Prefer the numeric status; the reason phrase is kept as a
            # fallback for responses where only it is populated.
            throttled = e.status_code == 429 or e.reason == 'Too Many Requests'
            if not throttled:
                print(e.message)
                raise SystemExit(-1)
            print('sleeping')
            await asyncio.sleep(60)
            print('awake')
            # Retry once; a second failure propagates to the caller.
            poller = await _submit()

        document_results = await poller.result()
        summaries = []

        async for action_results in document_results:
            # Each item holds the results of every requested action for one
            # document; only one action was requested, so take index 0.
            extract_summary_result = action_results[0]
            if extract_summary_result.is_error:
                continue
            # The result id is used as the document's position in `batch`.
            position = int(extract_summary_result.id)
            summary = ' '.join(
                sentence.text for sentence in extract_summary_result.sentences)
            summaries.append((ids[position], batch[position], summary))

    return summaries

def get_bins(dataset, max_range=10):
    """Group articles into equal-width bins by their number of sentences.

    Args:
        dataset: Rows convertible with ``pd.DataFrame(dataset)``; the
            resulting frame must have an ``article`` column of text.
        max_range: Target width of a bin, in sentences. The number of bins
            is ``ceil((longest - shortest) / max_range)``, but at least 1.

    Returns:
        A list of dicts, one per non-empty bin, with the keys ``'text'``
        (a ``Dataset`` built from the bin's rows), ``'min_length'`` and
        ``'max_length'`` (smallest / largest sentence count in the bin).
        An empty input yields an empty list.
    """
    df = pd.DataFrame(dataset)
    if df.empty:
        return []

    num_sentences = df.article.apply(sent_tokenize).apply(len)
    spread = num_sentences.max() - num_sentences.min()
    # When every article has the same sentence count the spread is 0, and
    # pd.cut rejects a bin count of 0 - always use at least one bin.
    num_bins = max(1, math.ceil(spread / max_range))
    labels = pd.cut(num_sentences, num_bins, labels=range(num_bins))

    bins = []
    for x in range(num_bins):
        in_bin = labels == x
        if not in_bin.any():
            continue
        lengths = num_sentences[in_bin]
        bins.append({
            'text': Dataset.from_dict(df[in_bin]),
            'min_length': lengths.min(),
            'max_length': lengths.max(),
        })

    return bins

def get_batch(dataset, max_requests, use_bins=False, max_bin_range=10):
    """Yield the dataset in consecutive groups of at most ``max_requests`` rows.

    Args:
        dataset: Object supporting ``len()`` and column access through
            ``dataset['article']`` / ``dataset['id']``.
        max_requests: Maximum number of rows per yielded group.
        use_bins: When true, rows are first grouped with ``get_bins`` so
            every yielded group comes from a single sentence-count bin.
        max_bin_range: Bin width forwarded to ``get_bins``.

    Yields:
        ``(articles, ids, min_length, max_length)`` tuples. The two lengths
        are the bin's sentence-count bounds, or ``None`` without binning.
    """
    if use_bins:
        sources = [
            (group['text'], group['min_length'], group['max_length'])
            for group in get_bins(dataset, max_bin_range)
        ]
    else:
        sources = [(dataset, None, None)]

    for rows, min_length, max_length in sources:
        if len(rows) == 0:
            continue
        # Fetch each column once per source instead of once per group;
        # column access can materialise the entire column.
        articles = rows['article']
        ids = rows['id']
        for start in trange(0, len(rows), max_requests):
            stop = start + max_requests
            yield articles[start:stop], ids[start:stop], min_length, max_length


async def main(args):
    """Summarize every article of the selected dataset and save the results.

    The dataset is read from the CSV at ``args.dataset_file`` when that is
    given, otherwise loaded with ``load_dataset`` from ``args.dataset``,
    ``args.dataset_version`` and ``args.dataset_split``. Articles are handled
    in groups of up to 1000; each group is split into concurrent
    ``summarize`` calls of ``args.batch_size`` articles.

    When ``args.percent_length`` is set, groups are formed per sentence-count
    bin and the summary length is ``percent_length`` times the midpoint of
    the bin's shortest and longest article, capped at 20 sentences. A bin
    whose target is 0 sentences gets empty summaries without any service
    call.

    The accumulated id/article/summary table is rewritten to
    ``args.save_path`` after every group.
    """
    print('loading dataset')
    if args.dataset_file is not None:
        frame = pd.read_csv(args.dataset_file, lineterminator='\n').dropna()
        dataset = Dataset.from_dict(frame)
    else:
        dataset = load_dataset(
            args.dataset, args.dataset_version, split=args.dataset_split)

    assert dataset is not None

    collected = []

    print(f'starting generation of {len(dataset)} summaries')
    length_aware = args.percent_length is not None
    groups = get_batch(dataset, 1000, use_bins=length_aware, max_bin_range=10)
    for batch, ids, min_length, max_length in groups:
        call_service = True
        options = {}
        if length_aware:
            max_sents = min(
                math.floor(args.percent_length * (max_length + min_length)/2),
                20)
            call_service = max_sents > 0
            options['max_sents'] = max_sents

        if call_service:
            # One summarize() coroutine per slice of args.batch_size articles,
            # all awaited together.
            jobs = [
                summarize(
                    batch[lo:lo+args.batch_size],
                    ids[lo:lo+args.batch_size],
                    **options)
                for lo in range(0, len(batch), args.batch_size)
            ]
            results = await asyncio.gather(*jobs)
        else:
            # Nothing to extract for this bin: pair each article with ''.
            results = [[(ids[pos], text, '') for pos, text in enumerate(batch)]]

        for rows in results:
            collected.extend(rows)

        # Rewrite the full table after each group so partial progress is kept.
        summary_df = pd.DataFrame(
            collected, columns=['id', 'article', 'summary'])
        summary_df.to_csv(args.save_path)


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the async pipeline.
    parser = ArgumentParser()
    parser.add_argument(
        '--dataset_file', help='path to dataset csv file', type=str)
    parser.add_argument(
        '--dataset', help='name of dataset on huggingface hub', type=str)
    parser.add_argument('--dataset_version',
                        help='version of dataset on huggingface hub', type=str)
    parser.add_argument(
        '--dataset_split', help='dataset split to use from huggingface hub', type=str)
    parser.add_argument(
        '--batch_size', help='batch size (default: 20)', default=20, type=int)
    # Required: without a path DataFrame.to_csv() only returns the CSV text,
    # so every generated summary would be silently discarded.
    parser.add_argument(
        '--save_path', help='path to store csv with summaries', type=str,
        required=True)
    # The value multiplies the article's sentence count directly, so it is a
    # fraction (0.2 keeps about 20% of the sentences), not a percentage.
    parser.add_argument(
        '--percent_length',
        help='summary length as a fraction of the article sentence count '
             '(e.g. 0.2), capped at 20 sentences',
        type=float)
    args = parser.parse_args()
    # Fail fast instead of erroring inside main() when no source is given.
    if args.dataset_file is None and args.dataset is None:
        parser.error('one of --dataset_file or --dataset is required')
    asyncio.run(main(args))
