import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset
import pandas as pd
import numpy as np
from nltk import sent_tokenize
from utils import get_summary_indices, get_overlap_scores, visualize_summary
from summa.summarizer import summarize
from tqdm import trange

# Hugging Face Hub id of the Pegasus checkpoint used for sentence paraphrasing.
model_name = 'tuner007/pegasus_paraphrase'
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'
# Loaded once at import time; get_response() reads these module globals.
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(torch_device)


def get_response(input_text, num_return_sequences, num_beams, max_length=60):
    """Paraphrase text with the module-level Pegasus model using beam search.

    Args:
        input_text: A single string or a sequence of strings to paraphrase.
            Each input is truncated to ``max_length`` tokens before generation.
        num_return_sequences: Number of paraphrases to return per input; per
            the ``generate`` docs it should not exceed ``num_beams``.
        num_beams: Beam width passed to ``model.generate``.
        max_length: Token limit for both the truncated input and the generated
            output. Defaults to 60, the previously hard-coded value.

    Returns:
        A flat list of decoded strings, ``num_return_sequences`` per input,
        grouped per input in input order. An empty sequence of inputs yields
        ``[]`` without touching the tokenizer or the model.
    """
    # An empty batch has nothing to paraphrase; skip the model entirely rather
    # than handing the tokenizer and generate() an empty input. A bare string
    # (even "") is a single input and still goes to the model as before.
    if not isinstance(input_text, str) and len(input_text) == 0:
        return []
    batch = tokenizer(
        input_text,
        truncation=True,
        padding='longest',
        max_length=max_length,
        return_tensors="pt",
    ).to(torch_device)
    # NOTE(review): temperature only affects sampling-based decoding. With
    # do_sample left at its default this looks like plain beam search, so 1.5
    # is probably a no-op. It is kept so generation is unchanged in case the
    # model's generation config enables sampling; confirm that, then drop it.
    translated = model.generate(
        **batch,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return tokenizer.batch_decode(translated, skip_special_tokens=True)


def _paraphrase_selected(sentences, idx):
    """Return a copy of *sentences* with the entries at positions *idx* paraphrased."""
    # Work on a copy so the caller's tokenized article is never mutated in place.
    result = list(sentences)
    selected = [sentences[k] for k in idx]
    # Nothing selected (e.g. empty article): nothing to send to the model.
    if not selected:
        return result
    # One paraphrase per input (num_return_sequences=1), so outputs line up
    # one-to-one with the selected positions.
    paraphrased = get_response(selected, 1, 10)
    for k, new_sentence in zip(idx, paraphrased):
        result[k] = new_sentence
    return result


def main(input_path='~/user/textrank_test.csv',
         output_path='~/user/textrank_test_paraphrase.csv'):
    """Paraphrase summary-aligned article sentences, then re-summarize each article.

    For every row of the input CSV (rows with any missing value are dropped),
    the article sentences that ``get_summary_indices`` aligns with the reference
    summary are replaced by Pegasus paraphrases. ``summa``'s extractive
    summarizer then produces a new ~75-word summary of the altered article.
    The altered article and its new summary are written to ``output_path``.

    Args:
        input_path: CSV with at least ``article`` and ``summary`` columns.
        output_path: Destination CSV with ``article`` and ``summary`` columns
            (the DataFrame index is written too, as before).

    Returns:
        ``(paraphrase_df, idxs)``: the DataFrame that was written and, per row,
        the pair ``(idx, new_idx)`` of summary-aligned sentence indices before
        and after paraphrasing. Previously these indices were computed but
        discarded.
    """
    df = pd.read_csv(input_path).dropna().reset_index(drop=True)

    paraphrases = []
    idxs = []

    for i in trange(len(df)):
        summ = sent_tokenize(df.summary[i])
        article = sent_tokenize(df.article[i])
        idx = get_summary_indices(article, summ, top_k=2, tolerance=0.1)
        paraphrased_article = _paraphrase_selected(article, idx)
        paraphrased_text = ' '.join(paraphrased_article)

        new_summ = sent_tokenize(summarize(paraphrased_text, words=75))
        new_idx = get_summary_indices(paraphrased_article, new_summ, top_k=2, tolerance=0.1)

        paraphrases.append((paraphrased_text, ' '.join(new_summ)))
        idxs.append((idx, new_idx))

    paraphrase_df = pd.DataFrame(paraphrases, columns=['article', 'summary'])
    paraphrase_df.to_csv(output_path)
    return paraphrase_df, idxs

if __name__ == '__main__':
    # Run the paraphrase-and-resummarize pipeline only when executed as a
    # script. Importing this module still loads the tokenizer and model,
    # because that happens at module level.
    main()

