import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# class Adequacy():
#
#   def __init__(self, model_tag='/data/data/hf_model_hub/parrot-adequacy-model'):
#     from transformers import AutoModelForSequenceClassification, AutoTokenizer
#     self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag)
#     self.tokenizer = AutoTokenizer.from_pretrained(model_tag)
#
#   def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
#       top_adequacy_phrases = []
#       for para_phrase in para_phrases:
#         x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True)
#         x = x.to(device)
#         self.adequacy_model = self.adequacy_model.to(device)
#         logits = self.adequacy_model(**x).logits
#         probs = logits.softmax(dim=1)
#         prob_label_is_true = probs[:,1]
#         adequacy_score = prob_label_is_true.item()
#         if adequacy_score >= adequacy_threshold:
#             top_adequacy_phrases.append(para_phrase)
#       return top_adequacy_phrases
#
#
#   def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
#       adequacy_scores = {}
#       for para_phrase in para_phrases:
#         x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True)
#         x = x.to(device)
#         self.adequacy_model = self.adequacy_model.to(device)
#         logits = self.adequacy_model(**x).logits
#         probs = logits.softmax(dim=1)
#         prob_label_is_true = probs[:,1]
#         adequacy_score = prob_label_is_true.item()
#         if adequacy_score >= adequacy_threshold:
#           adequacy_scores[para_phrase] = adequacy_score
#       return adequacy_scores

def _get_filter_score(tokenizer_, model_, device_, max_length_, input_text_list):
    """Score a batch of texts with a sequence-classification model.

    The batch is tokenized in a single call (padded, truncated to
    ``max_length_``), moved to ``device_`` and run through ``model_`` with
    gradient tracking disabled.

    Args:
        tokenizer_: Callable tokenizer; invoked with ``return_tensors='pt'``,
            ``truncation=True`` and ``padding=True``.
        model_: Callable model whose output exposes a ``logits`` attribute.
        device_: Device the encoded batch is moved to before inference.
        max_length_: Maximum sequence length handed to the tokenizer.
        input_text_list: Batch of inputs forwarded unchanged to the tokenizer
            (the callers in this file pass single strings or two-element
            ``[text_a, text_b]`` lists).

    Returns:
        NumPy array holding, for every row of the batch, the softmax
        probability of class index 1.
    """
    encoded_batch = tokenizer_(
        input_text_list,
        return_tensors='pt',
        max_length=max_length_,
        truncation=True,
        padding=True,
    ).to(device_)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model_(**encoded_batch).logits

    class_probs = logits.softmax(dim=-1)
    return class_probs.cpu().numpy()[:, 1]

class Adequacy():
    """Filter paraphrases using a sequence-classification "adequacy" model.

    Each candidate is paired with the input phrase, the pairs are scored in
    one batch, and the class-1 probability of the classifier is used as the
    adequacy score.
    """

    # Maximum token length of an encoded (input, paraphrase) pair; longer
    # pairs are truncated by the tokenizer.
    MAX_LENGTH = 128

    def __init__(self, device, model_tag='parrot-adequacy-model'):
        """Load the tokenizer and classifier and place the model on ``device``.

        Args:
            device: Device the model is moved to and on which batches are scored.
            model_tag: Path or identifier passed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.adequacy_tokenizer = AutoTokenizer.from_pretrained(model_tag)
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag).to(device)
        self.device = device

    def filter(self, input_phrase, para_phrases, adequacy_threshold):
        """Return the paraphrases whose adequacy score reaches the threshold.

        Args:
            input_phrase: The phrase the paraphrases were generated from.
            para_phrases: Sequence of candidate paraphrases.
            adequacy_threshold: Minimum score (inclusive) a paraphrase needs
                in order to be kept.

        Returns:
            ``para_phrases`` itself when it is empty; otherwise a new list with
            the accepted paraphrases in their original order.
        """
        if len(para_phrases) == 0:
            return para_phrases

        input_text_pairs = [[input_phrase, para_p] for para_p in para_phrases]
        adequacy_score_list = _get_filter_score(
            self.adequacy_tokenizer,
            self.adequacy_model,
            self.device,
            self.MAX_LENGTH,
            input_text_pairs,
        )

        # The threshold is inclusive, consistent with Fluency.filter.
        return [
            para_phrase
            for para_phrase, adequacy_score in zip(para_phrases, adequacy_score_list)
            if adequacy_score >= adequacy_threshold
        ]





class Fluency():
    """Filter paraphrases using a two-label sequence-classification "fluency" model.

    Every candidate is prefixed with ``'Sentence: '``, the batch is scored in
    one pass, and the class-1 probability of the classifier is used as the
    fluency score.
    """

    def __init__(self, device, model_tag='parrot-fluency-model'):
        """Load the classifier and tokenizer and place the model on ``device``.

        Args:
            device: Device the model is moved to and on which batches are scored.
            model_tag: Path or identifier passed to ``from_pretrained`` for
                both the model and the tokenizer.
        """
        self.fluency_model = AutoModelForSequenceClassification.from_pretrained(model_tag, num_labels=2).to(device)
        self.fluency_tokenizer = AutoTokenizer.from_pretrained(model_tag)
        self.device = device

    def filter(self, para_phrases, fluency_threshold):
        """Return the paraphrases whose fluency score reaches the threshold.

        Args:
            para_phrases: Sequence of candidate paraphrases.
            fluency_threshold: Minimum score (inclusive) a paraphrase needs in
                order to be kept.

        Returns:
            ``para_phrases`` itself when it is empty; otherwise a new list with
            the accepted paraphrases in their original order.
        """
        if len(para_phrases) == 0:
            return para_phrases

        batch_text_input = ['Sentence: ' + para_p for para_p in para_phrases]
        fluency_scores = _get_filter_score(
            self.fluency_tokenizer,
            self.fluency_model,
            self.device,
            # Truncate at whatever maximum length the tokenizer itself reports.
            self.fluency_tokenizer.model_max_length,
            batch_text_input,
        )

        return [
            para_phrase
            for para_phrase, fluency_s in zip(para_phrases, fluency_scores)
            if fluency_s >= fluency_threshold
        ]
      


class Diversity():
    """Score how much each paraphrase differs from the input phrase.

    Three rankers are available; each returns a dict mapping a paraphrase to
    a score, where a larger score means the paraphrase differs more from the
    input phrase under that ranker's metric.
    """

    def __init__(self, model_tag='paraphrase-distilroberta-base-v2'):
        """Load the sentence-embedding model used by ``euclidean_ranker``.

        Args:
            model_tag: Name or path handed to ``SentenceTransformer``.
        """
        from sentence_transformers import SentenceTransformer
        self.diversity_model = SentenceTransformer(model_tag)

    def rank(self, input_phrase, para_phrases, diversity_ranker='levenshtein'):
        """Dispatch to the ranker selected by ``diversity_ranker``.

        Args:
            input_phrase: The phrase the paraphrases were generated from.
            para_phrases: Iterable of candidate paraphrases.
            diversity_ranker: One of ``"levenshtein"``, ``"euclidean"`` or
                ``"diff"``.

        Returns:
            The dict produced by the selected ranker, or ``None`` when
            ``diversity_ranker`` is not one of the supported names.
        """
        rankers = {
            "levenshtein": self.levenshtein_ranker,
            "euclidean": self.euclidean_ranker,
            "diff": self.diff_ranker,
        }
        ranker = rankers.get(diversity_ranker)
        if ranker is None:
            # Unknown ranker names have always yielded None; kept for
            # backward compatibility with existing callers.
            return None
        return ranker(input_phrase, para_phrases)

    def euclidean_ranker(self, input_phrase, para_phrases):
        """Score paraphrases by min-max scaled embedding distance to the input.

        Both the input phrase and each paraphrase are lowercased before being
        encoded. The raw Euclidean distances are then rescaled with
        ``MinMaxScaler`` across the given paraphrases.

        Args:
            input_phrase: The phrase the paraphrases were generated from.
            para_phrases: Sequence of candidate paraphrases.

        Returns:
            Dict mapping each paraphrase to its scaled distance; an empty dict
            when there are no paraphrases.
        """
        # MinMaxScaler cannot be fitted on zero samples, so answer directly.
        if len(para_phrases) == 0:
            return {}

        import pandas as pd
        from sklearn_pandas import DataFrameMapper
        from sklearn.preprocessing import MinMaxScaler
        from scipy import spatial

        input_enc = self.diversity_model.encode(input_phrase.lower())
        outputs = []
        for para_phrase in para_phrases:
            paraphrase_enc = self.diversity_model.encode(para_phrase.lower())
            euclidean_distance = spatial.distance.euclidean(input_enc, paraphrase_enc)
            outputs.append((para_phrase, euclidean_distance))

        df = pd.DataFrame(outputs, columns=['paraphrase', 'scores'])
        # Only the numeric column is rescaled; the text column passes through.
        fields = [
            ([col], MinMaxScaler() if col == "scores" else None)
            for col in df.columns
        ]
        mapper = DataFrameMapper(fields, df_out=True)

        diversity_scores = {}
        for _, row in mapper.fit_transform(df.copy()).iterrows():
            diversity_scores[row['paraphrase']] = row['scores']
        return diversity_scores

    def levenshtein_ranker(self, input_phrase, para_phrases):
        """Score paraphrases by case-insensitive Levenshtein distance to the input.

        Args:
            input_phrase: The phrase the paraphrases were generated from.
            para_phrases: Iterable of candidate paraphrases.

        Returns:
            Dict mapping each paraphrase (original casing) to its edit
            distance from the input phrase.
        """
        import Levenshtein

        # Lowercase both sides so that casing alone never counts as an edit;
        # this mirrors the normalisation done in euclidean_ranker.
        reference = input_phrase.lower()
        return {
            para_phrase: Levenshtein.distance(reference, para_phrase.lower())
            for para_phrase in para_phrases
        }

    def diff_ranker(self, input_phrase, para_phrases):
        """Score paraphrases by the number of word-level insertions and deletions.

        Both phrases are split on whitespace and compared with
        ``difflib.Differ``; every token that is only in the input or only in
        the paraphrase adds one to the score.

        Args:
            input_phrase: The phrase the paraphrases were generated from.
            para_phrases: Iterable of candidate paraphrases.

        Returns:
            Dict mapping each paraphrase to its count of added plus removed
            tokens.
        """
        import difflib

        differ = difflib.Differ()
        input_tokens = input_phrase.split()
        diversity_scores = {}
        for para_phrase in para_phrases:
            diff = differ.compare(input_tokens, para_phrase.split())
            # Inspect only the two-character line prefix: unchanged tokens
            # that contain '+' or '-' (e.g. "well-known") and "? " hint lines
            # must not be counted as edits.
            diversity_scores[para_phrase] = sum(
                1 for d in diff if d.startswith(("+ ", "- "))
            )
        return diversity_scores
