import string

from transformers import set_seed

from tqdm import tqdm

from detectors import RoBERTaAIDetector
from my_utils.data_utils import save_list_to_tsv
from baselines.parrot.parrot import Parrot




import numpy as np
from nltk.tokenize import sent_tokenize
from my_utils.test_utils import get_generated_text_saving_dir
from my_utils.my_dataloader import load_test_data



class ParaAttack:
    """Paraphrase attack that rewrites a document chunk by chunk with Parrot.

    The document is split into groups of sentences of bounded word count. Each
    group is sent to the paraphraser. The first candidate that differs from
    the source group (ignoring case) replaces it; otherwise the group is kept
    verbatim.
    """

    def __init__(self, device, adequacy_threshold=0.8, fluency_threshold=0.8):
        """Create the attacker.

        Args:
            device: Device handed to ``Parrot`` (e.g. ``'cuda:0'``).
            adequacy_threshold: Forwarded to the paraphraser on every call.
            fluency_threshold: Forwarded to the paraphraser on every call.
        """
        self.adequacy_threshold = adequacy_threshold
        self.fluency_threshold = fluency_threshold
        self.paraphraser = Parrot(device)

    def one_shot_attack(self, orig_text, max_length=128, max_tries=10,
                        num_candidates=5):
        """Return a paraphrased version of ``orig_text``.

        Args:
            orig_text: Document to paraphrase.
            max_length: Maximum number of whitespace-separated words per chunk.
                It is also the lower bound of the length budget passed to the
                paraphraser (see ``split_sentence``).
            max_tries: Number of paraphraser calls per chunk before giving up
                and keeping the chunk unchanged.
            num_candidates: Passed as the paraphraser's second positional
                argument (presumably the number of candidates to return;
                confirm against ``Parrot``).

        Returns:
            The (possibly paraphrased) chunks joined by single spaces.
        """
        # Parrot cannot preserve the meaning of long inputs, so paraphrase
        # bounded-length sentence groups instead of the whole document.
        chunks, length_budgets = self.split_sentence(orig_text, max_length)
        return ' '.join(
            self._paraphrase_chunk(chunk, budget, max_tries, num_candidates)
            for chunk, budget in zip(chunks, length_budgets)
        )

    def _paraphrase_chunk(self, chunk, length_budget, max_tries,
                          num_candidates):
        """Return the first usable paraphrase of ``chunk``, or ``chunk`` itself."""
        for _ in range(max_tries):
            candidates = self.paraphraser(
                chunk, num_candidates, max_length=length_budget,
                adequacy_threshold=self.adequacy_threshold,
                fluency_threshold=self.fluency_threshold,
                do_diverse=False,
            )
            # `or ()` also tolerates a None result (no paraphrase produced),
            # on which len() used to raise TypeError.
            for candidate in candidates or ():
                text = candidate.strip()
                if not text:
                    # An empty candidate used to crash on text[-1]; accepting
                    # it would silently delete the chunk, so skip it.
                    continue
                # Re-attach the source's closing punctuation if it was dropped.
                if (text[-1] not in string.punctuation
                        and chunk[-1] in string.punctuation):
                    text += chunk[-1]
                # Case-only changes do not count as a paraphrase.
                if text.lower() != chunk.lower():
                    return text
        return chunk

    def _get_length(self, text):
        """Return 1.5x the whitespace-separated word count of ``text``, truncated to int."""
        return int(len(text.split()) * 1.5)

    def split_sentence(self, doc_str, max_length):
        """Group the sentences of ``doc_str`` into chunks of bounded word count.

        Consecutive sentences (from ``nltk.sent_tokenize``) are joined with a
        space as long as the merged chunk has at most ``max_length``
        whitespace-separated words. A single sentence longer than
        ``max_length`` still forms its own chunk.

        Returns:
            A tuple ``(chunks, length_budgets)``. Each ``length_budgets[i]`` is
            ``max(max_length, int(1.5 * words(chunks[i])))`` and is used as the
            paraphraser's ``max_length`` for that chunk.
        """
        chunks = []
        for sent in sent_tokenize(doc_str):
            if chunks and len(chunks[-1].split()) + len(sent.split()) <= max_length:
                chunks[-1] = chunks[-1] + ' ' + sent
            else:
                chunks.append(sent)

        # NOTE(review): the 1.5x word count presumably leaves headroom for
        # subword tokens in the paraphraser's max_length; confirm.
        length_budgets = [max(max_length, self._get_length(chunk))
                          for chunk in chunks]
        return chunks, length_budgets



if __name__ == '__main__':

    set_seed(42)

    device = 'cuda:0'
    # NOTE(review): `detector` is never used in this script; consider removing
    # it (it presumably loads a model onto `device`).
    detector = RoBERTaAIDetector(device)
    baseline_name = 'parrot_paraphrase'
    # NOTE(review): empty dataset name; set it to the dataset to attack.
    dataset = ''
    task_type = 'paraphrase'

    # Load the test texts to paraphrase.
    test_data_list = load_test_data(dataset, task_type)

    para_attacker = ParaAttack(device)

    paraphrased_text_list = [para_attacker.one_shot_attack(ai_text)
                             for ai_text in tqdm(test_data_list)]

    # One row per input: (original text, paraphrased text), under a header row.
    final_text_list = list(zip(test_data_list, paraphrased_text_list))
    final_text_list.insert(0, ['ai', 'new'])
    # Use baseline_name rather than repeating the literal, so the two cannot drift apart.
    save_list_to_tsv(final_text_list,
                     get_generated_text_saving_dir(dataset, baseline_name))


