import pandas as pd
import numpy as np
from utils import get_summary_indices
from tqdm import trange
from nltk import sent_tokenize
import seaborn as sns
from paraphrase import is_bad
from tqdm import tqdm

# Names of the summarizers to process; each name is interpolated into the
# input and output CSV paths used by the main loop of this script.
summarizers = ['textrank', 'matchsum', 'presumm_ext', 'presumm_abs', 'azure', 'pegasus', 'bart']

def get_bad_sent(x):
    """Report whether any sentence of the text *x* is flagged by ``is_bad``.

    The text is split into sentences with ``sent_tokenize``; scanning stops
    at the first flagged sentence. Despite the name, the result is a bool
    (``True`` if a flagged sentence exists, ``False`` otherwise), not the
    sentence itself.
    """
    return any(is_bad(sentence) for sentence in sent_tokenize(x))

def get_bad_sent_idx(x):
    """Return the position of the first item of *x* flagged by ``is_bad``.

    *x* is an iterable of already-split sentences. If no item is flagged
    (including when *x* is empty) the sentinel ``-1`` is returned.
    """
    flagged_positions = (pos for pos, sentence in enumerate(x) if is_bad(sentence))
    return next(flagged_positions, -1)

# For every summarizer: find the paraphrased test articles that contain a
# sentence flagged by `is_bad`, restore the first flagged sentence from the
# original article with the same id, and write the repaired articles to
# `fixed_articles.csv` next to the paraphrased input.
for summarizer in summarizers:
    original = pd.read_csv(f'~/user/{summarizer}_test.csv', index_col=0).dropna().reset_index(drop=True)
    paraphrased = pd.read_csv(f'~/user/paraphrasing/targeted/{summarizer}/articles_test_targeted_no_sample.csv', index_col=0)

    print(f'finding bad articles for {summarizer}')
    # Use an explicit boolean array as the row mask: an empty Python list
    # would be treated by pandas as an (empty) column selection instead.
    bad = np.array([get_bad_sent(a) for a in tqdm(paraphrased.article)], dtype=bool)

    # Filter once; the mask does not change, so there is no reason to
    # rebuild the filtered frame on every loop iteration.
    bad_rows = paraphrased[bad]
    original_arts = original[original.id.isin(bad_rows.id.unique())]

    fixed_arts = []

    print(f'fixing {len(bad_rows)} articles for {summarizer}')
    for art_id, art_text in tqdm(zip(bad_rows.id, bad_rows.article), total=len(bad_rows)):
        # First original article carrying the same id as the paraphrased one.
        # Raises IndexError if the id is missing from the original test set
        # (e.g. the row was removed by `dropna()` above).
        orig_text = original_arts[original_arts.id == art_id].article.iloc[0]
        orig_art = sent_tokenize(orig_text)
        bad_art = sent_tokenize(art_text)

        # Only the first flagged sentence is repaired.
        bad_idx = get_bad_sent_idx(bad_art)
        # NOTE(review): relies on the original having a sentence at the same
        # position; raises IndexError when it has fewer sentences -- confirm
        # that this cannot happen for the targeted paraphrases.
        bad_art[bad_idx] = orig_art[bad_idx]

        # The original has more sentences than the paraphrase: also restore
        # the sentence that follows the repaired one (presumably it was merged
        # into the flagged sentence during paraphrasing -- TODO confirm).
        if len(orig_art) > len(bad_art):
            bad_art.insert(bad_idx + 1, orig_art[bad_idx + 1])

        fixed_arts.append(' '.join(bad_art))

    # `reset_index` returns a new frame, so assigning the column below does
    # not touch `paraphrased`.
    df = bad_rows.reset_index(drop=True)
    df['article'] = fixed_arts
    print(f'writing {len(df)} articles to ~/user/paraphrasing/targeted/{summarizer}/fixed_articles.csv')
    df.to_csv(f'~/user/paraphrasing/targeted/{summarizer}/fixed_articles.csv')

