import re

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from .filters import Adequacy
from .filters import Fluency
from .filters import Diversity

class Parrot():
    """Paraphrase generator backed by a seq2seq (T5) model with output filtering.

    Candidates are generated with ``model.generate`` and then passed through the
    ``Adequacy`` and ``Fluency`` filters imported from ``.filters`` (their scoring
    rules live in that module).
    """

    # Characters kept by ``augment`` before paraphrasing; everything else is dropped.
    # Raw string: the previous non-raw literal relied on invalid escape sequences.
    _INPUT_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9 \?'\-\/\:\.]")

    def __init__(self, device, model_tag='parrot-paraphraser-on-T5'):
        """Load tokenizer, model and filters.

        Args:
            device: device passed to ``.to()`` for the model and to the filter
                constructors; also used for the encoded inputs.
            model_tag: name or path given to ``from_pretrained``. It defaults to
                the tag this class has always used.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_tag, use_auth_token=False)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_tag, use_auth_token=False).to(device)
        self.adequacy_score = Adequacy(device)
        self.fluency_score = Fluency(device)
        # NOTE: the ``Diversity`` ranker from ``.filters`` is intentionally not built.

    @staticmethod
    def _num_beam_groups(num_beams):
        """Return a group count that evenly divides ``num_beams`` for diverse beam search.

        The smallest divisor in 2..8 is preferred (matching the historic choice).
        If there is none, each beam gets its own group (``num_beams`` groups).
        For ``num_beams == 1`` this yields 1, i.e. no grouping is possible.
        """
        for n in range(2, 9):
            if num_beams % n == 0:
                return n
        return num_beams

    def __call__(self, input_text, max_return_num, max_length=32, adequacy_threshold=0.90,
                 fluency_threshold=0.90, do_diverse=True):
        """Generate paraphrases of ``input_text`` and filter them.

        Args:
            input_text: sentence to paraphrase. It is prefixed with ``'paraphrase: '``.
            max_return_num: number of sequences requested from ``generate``. In
                diverse mode it is also the number of beams.
            max_length: ``max_length`` forwarded to ``generate``.
            adequacy_threshold: threshold forwarded to the adequacy filter.
            fluency_threshold: threshold forwarded to the fluency filter.
            do_diverse: if True, use (grouped) beam search with a diversity
                penalty; otherwise use sampling.

        Returns:
            The list returned by the fluency filter. It contains unique decoded
            candidates that differ from ``input_text`` and passed both filters.
        """
        input_ids = self.tokenizer.encode('paraphrase: ' + input_text, return_tensors='pt')
        input_ids = input_ids.to(self.device)

        if do_diverse:
            gen_kwargs = {'do_sample': False, 'num_beams': max_return_num}
            num_groups = self._num_beam_groups(max_return_num)
            # Group beam search requires num_beams % num_beam_groups == 0 and more than
            # one group. With a single group we fall back to plain beam/greedy search.
            if num_groups > 1:
                gen_kwargs.update(num_beam_groups=num_groups, diversity_penalty=2.0)
        else:
            gen_kwargs = {'do_sample': True}

        with torch.no_grad():
            preds = self.model.generate(
                input_ids,
                max_length=max_length,
                early_stopping=True,
                num_return_sequences=max_return_num,
                **gen_kwargs)

        decoded = (self.tokenizer.decode(pred, skip_special_tokens=True) for pred in preds)
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        paraphrase_list = [pp for pp in dict.fromkeys(decoded) if pp != input_text]

        return self.filter_output_list(input_text, paraphrase_list, adequacy_threshold, fluency_threshold)

    def filter_output_list(self, input_text, orig_paraphrase_list, adequacy_threshold, fluency_threshold):
        """Run candidates through the adequacy filter, then the fluency filter.

        ``orig_paraphrase_list`` is copied first, so the caller's list is passed
        to the filters only as a fresh list.
        """
        final_output_list = list(orig_paraphrase_list)
        final_output_list = self.adequacy_score.filter(input_text, final_output_list, adequacy_threshold)
        final_output_list = self.fluency_score.filter(final_output_list, fluency_threshold)
        return final_output_list

    def augment(self, input_phrase, use_gpu=False, diversity_ranker="levenshtein", do_diverse=False,
                max_return_phrases=10, max_length=32, adequacy_threshold=0.90, fluency_threshold=0.90):
        """Legacy entry point: clean ``input_phrase`` and paraphrase it via ``__call__``.

        Previously this method tokenized the input and then discarded it, always
        returning ``None``. It now returns the filtered paraphrases.

        Args:
            input_phrase: sentence to paraphrase. Characters outside
                ``_INPUT_CLEANUP_RE``'s allowed set are stripped first.
            use_gpu: ignored; the device is fixed when the instance is built.
            diversity_ranker: ignored; diversity ranking is not enabled.
            do_diverse: forwarded to ``__call__``.
            max_return_phrases: forwarded as ``max_return_num``.
            max_length: forwarded to ``generate``. It is raised by 32 when the raw
                phrase has at least ``max_length`` characters.
            adequacy_threshold: forwarded to the adequacy filter.
            fluency_threshold: forwarded to the fluency filter.

        Returns:
            The (possibly empty) list of paraphrases produced by ``__call__``.
        """
        # Character count is used as a rough proxy for token count (historic heuristic).
        if len(input_phrase) >= max_length:
            max_length += 32

        cleaned = self._INPUT_CLEANUP_RE.sub('', input_phrase)
        return self(cleaned, max_return_phrases, max_length=max_length,
                    adequacy_threshold=adequacy_threshold, fluency_threshold=fluency_threshold,
                    do_diverse=do_diverse)





# class Parrot():
#
#     def __init__(self, model_tag="prithivida/parrot_paraphraser_on_T5",):
#
#         if model_tag == "prithivida/parrot_paraphraser_on_T5":
#             model_tag = '/data/data/hf_model_hub/parrot-paraphraser-on-T5'
#
#         from transformers import AutoTokenizer
#         from transformers import AutoModelForSeq2SeqLM
#         import pandas as pd
#         from .filters import Adequacy
#         from .filters import Fluency
#         from .filters import Diversity
#         self.tokenizer = AutoTokenizer.from_pretrained(model_tag, use_auth_token=False)
#         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_tag, use_auth_token=False)
#         self.adequacy_score = Adequacy()
#         self.fluency_score = Fluency()
#         self.diversity_score = Diversity()
#
#     # def rephrase(self, input_phrase, use_gpu=False, diversity_ranker="levenshtein", do_diverse=False, style=1,
#     #              max_length=32, adequacy_threshold=0.90, fluency_threshold=0.90):
#     #     if use_gpu:
#     #         device = "cuda:0"
#     #     else:
#     #         device = "cpu"
#     #
#     #     self.model = self.model.to(device)
#     #     import re
#     #     save_phrase = input_phrase
#     #     if len(input_phrase) >= max_length:
#     #         max_length += 32
#     #     input_phrase = re.sub('[^a-zA-Z0-9 \?\'\-\/\:\.]', '', input_phrase)
#     #     input_phrase = "paraphrase: " + input_phrase
#     #     input_ids = self.tokenizer.encode(input_phrase, return_tensors='pt')
#     #     input_ids = input_ids.to(device)
#     #     max_return_phrases = 10
#     #     if do_diverse:
#     #         for n in range(2, 9):
#     #             if max_return_phrases % n == 0:
#     #                 break
#     #                 # print("max_return_phrases - ", max_return_phrases , " and beam groups -", n)
#     #         preds = self.model.generate(
#     #             input_ids,
#     #             do_sample=False,
#     #             max_length=max_length,
#     #             num_beams=max_return_phrases,
#     #             num_beam_groups=n,
#     #             diversity_penalty=2.0,
#     #             early_stopping=True,
#     #             num_return_sequences=max_return_phrases)
#     #     else:
#     #         preds = self.model.generate(
#     #             input_ids,
#     #             do_sample=True,
#     #             max_length=max_length,
#     #             top_k=50,
#     #             top_p=0.95,
#     #             early_stopping=True,
#     #             num_return_sequences=max_return_phrases)
#     #
#     #     paraphrases = set()
#     #
#     #     for pred in preds:
#     #         gen_pp = self.tokenizer.decode(pred, skip_special_tokens=True).lower()
#     #         gen_pp = re.sub('[^a-zA-Z0-9 \?\'\-]', '', gen_pp)
#     #         paraphrases.add(gen_pp)
#     #
#     #     adequacy_filtered_phrases = self.adequacy_score.filter(input_phrase, paraphrases, adequacy_threshold, device)
#     #     if len(adequacy_filtered_phrases) > 0:
#     #         fluency_filtered_phrases = self.fluency_score.filter(adequacy_filtered_phrases, fluency_threshold, device)
#     #         if len(fluency_filtered_phrases) > 0:
#     #             diversity_scored_phrases = self.diversity_score.rank(input_phrase, fluency_filtered_phrases,
#     #                                                                  diversity_ranker)
#     #             para_phrases = []
#     #             for para_phrase, diversity_score in diversity_scored_phrases.items():
#     #                 para_phrases.append((para_phrase, diversity_score))
#     #             para_phrases.sort(key=lambda x: x[1], reverse=True)
#     #             return para_phrases[0]
#     #         else:
#     #             return [(save_phrase, 0)]
#
#     def augment(self, input_phrase, use_gpu=False, diversity_ranker="levenshtein", do_diverse=False,
#                 max_return_phrases=10, max_length=32, adequacy_threshold=0.90, fluency_threshold=0.90):
#         if use_gpu:
#             device = "cuda:0"
#         else:
#             device = "cpu"
#
#         self.model = self.model.to(device)
#
#         import re
#
#         save_phrase = input_phrase
#         if len(input_phrase) >= max_length:
#             max_length += 32
#
#         input_phrase = re.sub('[^a-zA-Z0-9 \?\'\-\/\:\.]', '', input_phrase)
#         input_phrase = "paraphrase: " + input_phrase
#         input_ids = self.tokenizer.encode(input_phrase, return_tensors='pt')
#         input_ids = input_ids.to(device)
#
#         if do_diverse:
#             for n in range(2, 9):
#                 if max_return_phrases % n == 0:
#                     break
#                     # print("max_return_phrases - ", max_return_phrases , " and beam groups -", n)
#             preds = self.model.generate(
#                 input_ids,
#                 do_sample=False,
#                 max_length=max_length,
#                 num_beams=max_return_phrases,
#                 num_beam_groups=n,
#                 diversity_penalty=2.0,
#                 early_stopping=True,
#                 num_return_sequences=max_return_phrases)
#         else:
#             preds = self.model.generate(
#                 input_ids,
#                 do_sample=True,
#                 max_length=max_length,
#                 top_k=50,
#                 top_p=0.95,
#                 early_stopping=True,
#                 num_return_sequences=max_return_phrases)
#
#         paraphrases = set()
#
#         for pred in preds:
#             gen_pp = self.tokenizer.decode(pred, skip_special_tokens=True).lower()
#             gen_pp = re.sub('[^a-zA-Z0-9 \?\'\-]', '', gen_pp)
#             paraphrases.add(gen_pp)
#
#         adequacy_filtered_phrases = self.adequacy_score.filter(input_phrase, paraphrases, adequacy_threshold, device)
#         if len(adequacy_filtered_phrases) > 0:
#             fluency_filtered_phrases = self.fluency_score.filter(adequacy_filtered_phrases, fluency_threshold, device)
#             if len(fluency_filtered_phrases) > 0:
#                 diversity_scored_phrases = self.diversity_score.rank(input_phrase, fluency_filtered_phrases,
#                                                                      diversity_ranker)
#                 para_phrases = []
#                 for para_phrase, diversity_score in diversity_scored_phrases.items():
#                     para_phrases.append((para_phrase, diversity_score))
#                 para_phrases.sort(key=lambda x: x[1], reverse=True)
#                 return para_phrases
#             else:
#                 return [(save_phrase, 0)]
