from pickle import FALSE
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset
import pandas as pd
import numpy as np
from nltk import sent_tokenize, FreqDist
from utils import get_summary_indices, get_overlap_scores, visualize_summary
from tqdm import trange
from argparse import ArgumentParser


def paraphrase_sents(input_text, model, tokenizer, torch_device='cpu', repeats=1, do_sample=False, temperature=1.5, num_beams=10):
    """Generate paraphrases for one sentence or a batch of sentences.

    Args:
        input_text: Text handed straight to ``tokenizer``; the caller in this
            file passes a list of sentences.
        model: Object exposing ``generate`` (a Pegasus conditional-generation
            model in this file).
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        torch_device: Device the encoded batch is moved to before generation.
        repeats: Forwarded as ``num_return_sequences``.
        do_sample: Forwarded to ``generate``.
        temperature: Forwarded to ``generate``.
        num_beams: Forwarded to ``generate``.

    Returns:
        The decoded generations with special tokens stripped.
    """
    # Inputs and outputs are both capped at 60 tokens.
    encoded = tokenizer(
        input_text,
        truncation=True,
        padding='longest',
        max_length=60,
        return_tensors="pt",
    ).to(torch_device)
    generated = model.generate(
        **encoded,
        max_length=60,
        do_sample=do_sample,
        num_beams=num_beams,
        num_return_sequences=repeats,
        temperature=temperature,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)

def get_sent(sents):
    """Yield the items of ``sents`` one at a time, in order."""
    yield from sents

# sometimes when using beam search, the model outputs a bunch of identical junk tokens
# This checks if the sentence is one of these bad sentences and disregards it if so
def is_bad(x):
    """Return True when ``x`` looks like degenerate generator output.

    A sentence is rejected when either
      * it has more than 10 space-separated tokens and at most half of them
        are distinct, or
      * its final token is longer than 40 characters.
    """
    # Splitting on single spaces is enough to spot runs of identical words;
    # no real tokenizer is needed.
    words = x.split(' ')
    mostly_repeats = len(words) > 10 and len(set(words)) / len(words) <= 0.5
    overlong_tail = len(words[-1]) > 40
    return mostly_repeats or overlong_tail

def save_sents(articles, sents, sent_idxs, ids, num_repeats, repeat_indices, save_path, append=True):
    """Splice finished paraphrases into their articles and write them to CSV.

    Articles are handled in order. Processing stops at the first article whose
    paraphrases are not all present in ``sents`` yet, so the function can be
    called again later on the unfinished remainder.

    Args:
        articles: Articles as lists of sentences. Selected sentences are
            overwritten in place. When ``repeat_indices`` is true the list
            must start on a repeat-group boundary (each source article
            appears ``num_repeats`` times in a row).
        sents: Paraphrases that have not been consumed yet. When
            ``repeat_indices`` is true, paraphrase ``r`` of source sentence
            ``s`` is read from position ``s * num_repeats + r``.
        sent_idxs: Sentence positions to overwrite: one list per group of
            ``num_repeats`` articles when ``repeat_indices`` is true,
            otherwise one list per article.
        ids: Identifiers aligned with ``articles``.
        num_repeats: Number of consecutive copies of each source article.
        repeat_indices: Whether all copies of an article share one index list.
        save_path: Destination CSV file.
        append: Append without a header when true; otherwise overwrite the
            file and write a header row.

    Returns:
        ``(sent_idx, n_articles)``. ``n_articles`` is the number of articles
        written by this call. ``sent_idx`` is the number of entries consumed
        from ``sents`` when ``repeat_indices`` is false, and the number of
        source sentences consumed (``num_repeats`` paraphrases each) when it
        is true. ``(0, 0)`` is returned, and nothing is written, when any of
        the input sequences is empty.
    """
    if any(len(seq) == 0 for seq in (sent_idxs, articles, sents, ids)):
        return 0, 0

    new_articles = []
    sent_idx = 0
    sents_used = 0

    for i, article in enumerate(articles):
        if repeat_indices:
            targets = sent_idxs[i // num_repeats]
            # Only start a group once every paraphrase for it has arrived.
            if sent_idx * num_repeats + num_repeats * len(targets) > len(sents):
                break
            copy_no = i % num_repeats
            for j in targets:
                sent = sents[sent_idx * num_repeats + copy_no]
                # Skip degenerate generations (e.g. one token repeated over
                # and over) and keep the original sentence instead.
                if not is_bad(sent):
                    article[j] = sent
                sent_idx += 1

            sents_used += len(targets)
            # Decide by position within the group rather than by comparing
            # neighbouring ids, so adjacent articles that happen to share an
            # id are not mistaken for copies of each other.
            if i + 1 < len(articles) and (i + 1) % num_repeats != 0:
                # The next copy of this article re-reads the same sentences.
                sent_idx -= len(targets)
            else:
                sent_idx = sents_used // num_repeats
        else:
            targets = sent_idxs[i]
            if sent_idx + len(targets) > len(sents):
                break
            for j in targets:
                sent = sents[sent_idx]
                # Same protection against degenerate generations as above.
                if not is_bad(sent):
                    article[j] = sent
                sent_idx += 1
            sents_used += len(targets)

        new_articles.append(' '.join(article))

    # zip() stops at the shorter input, so only ids of finished articles are kept.
    df = pd.DataFrame(list(zip(ids, new_articles)), columns=['id', 'article'])
    df.to_csv(save_path, mode='a+' if append else 'w+', header=not append, index=False)
    # Report the number of articles written. The loop index would be one short
    # whenever the loop ran to completion without hitting ``break``.
    return sent_idx, len(new_articles)


def main(args):
    """Paraphrase a subset of each article's sentences and save the results.

    Reads a CSV with ``id``, ``article`` and ``summary`` columns from
    ``args.datapath``, chooses the sentences to rewrite, paraphrases them in
    batches of ``args.batch_size`` and writes ``args.num_repeats`` rewritten
    copies of every article to ``args.savepath``.

    Sentence selection:
        * ``args.targeted``: indices come from ``get_summary_indices``; one
          index list per article, shared by all of its copies.
        * ``args.repeat_indices``: one random index list per article, shared
          by all of its copies.
        * otherwise: an independent random index list for every copy.

    Args:
        args: Parsed command-line namespace providing ``datapath``,
            ``savepath``, ``targeted``, ``top_k``, ``tolerance``, ``percent``,
            ``num_repeats``, ``repeat_indices``, ``do_sample``, ``num_beams``,
            ``seed`` and ``batch_size``.
    """
    model_name = 'tuner007/pegasus_paraphrase'
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('loading tokenizer and model')
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
    rng = np.random.default_rng(args.seed)

    print('loading summaries and articles')
    df = pd.read_csv(args.datapath, index_col=0).dropna().reset_index(drop=True)

    print(f'starting paraphrasing with the following arguments: {args}')

    articles = df.article.apply(sent_tokenize)
    summaries = df.summary.apply(sent_tokenize)

    # In "grouped" mode all copies of an article share one index list and the
    # model returns num_repeats paraphrases per source sentence.
    grouped = args.repeat_indices or args.targeted

    def random_indices(art):
        # Sample ceil(percent * len(art)) sentence positions (with replacement).
        count = int(np.ceil(args.percent * len(art)))
        return list(rng.integers(0, high=len(art), size=count))

    sents = []
    sent_idxs = []
    if args.targeted:
        for art, summary in zip(articles, summaries):
            sent_idxs.append(list(get_summary_indices(art, summary, top_k=args.top_k, tolerance=args.tolerance)))
            sents += [art[j] for j in sent_idxs[-1]]
    elif args.repeat_indices:
        for art in articles:
            sent_idxs.append(random_indices(art))
            sents += [art[j] for j in sent_idxs[-1]]
    else:
        for art in articles:
            for _ in range(args.num_repeats):
                sent_idxs.append(random_indices(art))
                sents += [art[j] for j in sent_idxs[-1]]

    # One independent, mutable copy of every article per repeat.
    arts = []
    ids = []
    for i, art in enumerate(articles):
        arts += [list(art) for _ in range(args.num_repeats)]
        ids += [df.id[i]] * args.num_repeats

    new_sents = []
    sent_idx = 0
    article_idx = 0

    def save_pending(append):
        # Hand everything not yet written to save_sents. In grouped mode
        # sent_idx counts source sentences and sent_idxs has one entry per
        # group of num_repeats articles, so both offsets must be rescaled.
        if grouped:
            return save_sents(
                arts[article_idx:],
                new_sents[sent_idx * args.num_repeats:],
                sent_idxs[article_idx // args.num_repeats:],
                ids[article_idx:],
                args.num_repeats, True, args.savepath, append=append,
            )
        return save_sents(
            arts[article_idx:],
            new_sents[sent_idx:],
            sent_idxs[article_idx:],
            ids[article_idx:],
            args.num_repeats, False, args.savepath, append=append,
        )

    repeats = args.num_repeats if grouped else 1
    save_every = 100  # checkpoint to disk once per this many batches

    for i in trange(0, len(sents), args.batch_size):
        batch = sents[i:i + args.batch_size]
        new_sents += paraphrase_sents(batch, model, tokenizer, torch_device=torch_device, repeats=repeats, do_sample=args.do_sample, num_beams=args.num_beams)

        # The parentheses matter: `i % args.batch_size * 100` parses as
        # `(i % args.batch_size) * 100`, which is 0 for every batch start.
        if i % (args.batch_size * save_every) == 0:
            # The first checkpoint creates the file and writes the header.
            new_sent_idx, new_article_idx = save_pending(append=i != 0)
            article_idx += new_article_idx
            sent_idx += new_sent_idx

    # Flush whatever was produced after the last checkpoint.
    save_pending(append=True)

if __name__ == '__main__':
    # Command-line entry point: parse the options, normalise --percent and run main().
    parser = ArgumentParser()
    # Both paths are mandatory: without --savepath the CSV writer has no file
    # to write to and the generated paraphrases would be lost.
    parser.add_argument('--datapath', help='path to csv file containing summary and corresponding article', type=str, required=True)
    parser.add_argument('--savepath', help='Path to save csv with paraphrased articles', type=str, required=True)
    parser.add_argument('--targeted', help='Target the sentences from the summary to rephrase instead of picking random sentences (default: False)', action='store_true')
    parser.add_argument('--top_k', help='The number of similar sentences to consider for each sentence in the summary. Only meaningful if targeted=True. (Default: 2)', default=2, type=int)
    parser.add_argument('--tolerance', help='How far similarity scores can deviate from 1 while still considering two sentences identical. Only meaningful if targeted=True (Default: 0.1).', default=0.1, type=float)
    parser.add_argument('--percent', help='Percent of article to paraphrase if targeted=False (default:0.1)', default=0.1, type=float)
    parser.add_argument('--num_repeats', help='Number of paraphrases to generate (default:1)', default=1, type=int)
    parser.add_argument('--repeat_indices', help='If num_repeats > 1 and targeted=false, this controls whether different paraphrases are generated for the same indices or different indices (default: False).', action='store_true')
    parser.add_argument('--do_sample', help='Whether or not paraphrases should be generated through sampling. (default: False)', action='store_true')
    parser.add_argument('--num_beams', help='Number of beams to use in beamsearch if do_sample=False (default: 10)', default=10, type=int)
    parser.add_argument('--seed', help='seed for reproducibility in random indices', type=int)
    parser.add_argument('--batch_size', help='max number of sentences to paraphrase at once', default=15, type=int)
    parser.set_defaults(targeted=False, repeat_indices=False, do_sample=False)
    args = parser.parse_args()

    # Accept --percent either as a fraction (0.1) or as a percentage (10).
    if args.percent > 1:
        args.percent /= 100

    main(args)