import torch
import spacy
from typing import Set
from evaluate import load
from nltk.util import ngrams
from collections import defaultdict
from .paradetox_metrics import compute_bleu_paradetox, compute_j_score
from rouge_score import rouge_scorer
import spacy
import numpy as np
import collections
import math
from transformers import AutoTokenizer
from torch.cuda import empty_cache


def compute_metric(metric_name, predictions, references, **kwargs):
    """Compute the metric named ``metric_name`` by dispatching to its helper.

    Args:
        metric_name: One of ``"mauve"``, ``"div"``, ``"mem"``, ``"ppl"``,
            ``"j_score"``, ``"bleu_paradetox"``, ``"bertscore"``, ``"rouge"``,
            ``"bleu"``.
        predictions: Generated texts to score.
        references: Reference data for the metric. Ignored by ``"div"``,
            ``"ppl"`` and ``"j_score"``. For ``"mem"`` it is used as the
            collection of training n-grams unless ``train_unique_four_grams``
            is passed in ``kwargs``.
        **kwargs: Metric-specific extras: ``sources`` (required by
            ``"j_score"``) and ``train_unique_four_grams`` (optional for
            ``"mem"``).

    Returns:
        Whatever the selected helper returns; ``"div"`` returns only the
        ``'diversity'`` entry of ``compute_diversity``'s result dict.

    Raises:
        ValueError: If ``metric_name`` is not a known metric.
        KeyError: If ``"j_score"`` is requested without ``sources``.
    """
    if metric_name == "mauve":
        return compute_mauve(predictions=predictions, references=references)
    elif metric_name == "div":
        return compute_diversity(all_texts_list=predictions)['diversity']
    elif metric_name == "mem":
        # compute_memorization names this argument ``train_unique_four_grams``;
        # passing it under any other keyword raises TypeError.
        train_ngrams = kwargs.get("train_unique_four_grams", references)
        return compute_memorization(all_texts_list=predictions, train_unique_four_grams=train_ngrams)
    elif metric_name == "ppl":
        return compute_ppl(predictions=predictions)
    elif metric_name == "j_score":
        return compute_j_score(predictions=predictions, sources=kwargs["sources"])
    elif metric_name == "bleu_paradetox":
        return compute_bleu_paradetox(predictions=predictions, references=references)
    elif metric_name == "bertscore":
        return compute_bert_score(predictions=predictions, references=references)
    elif metric_name == "rouge":
        return compute_rouge(predictions=predictions, references=references)
    elif metric_name == "bleu":
        return compute_bleu(predictions=predictions, references=references)
    else:
        # ValueError subclasses Exception, so callers catching Exception still work.
        raise ValueError(f"Unknown metric: {metric_name}")



def compute_bert_score(predictions, references):
    """Return the mean BERTScore F1 of ``predictions`` against ``references``.

    Uses the ``bertscore`` metric from ``evaluate`` with the
    ``microsoft/deberta-xlarge-mnli`` backbone and averages the per-pair F1
    values with ``np.mean``.
    """
    # Free unused cached CUDA memory before loading another model.
    torch.cuda.empty_cache()

    scorer = load("bertscore")
    f1_scores = scorer.compute(
        predictions=predictions,
        references=references,
        model_type="microsoft/deberta-xlarge-mnli",
    )["f1"]
    return np.mean(f1_scores)


def compute_rouge(predictions, references):
    """Score ``predictions`` against ``references`` with the ``evaluate`` ROUGE metric.

    Both lists must have the same length (checked with ``assert``). The result
    of ``rouge.compute`` is returned unchanged.
    """
    torch.cuda.empty_cache()

    rouge_metric = load('rouge')
    assert len(predictions) == len(references)

    return rouge_metric.compute(predictions=predictions, references=references)


def filter_empty_texts(predictions, references):
    """Drop prediction/reference pairs in which either side is empty.

    Args:
        predictions: Sequence of generated texts.
        references: Sequence of reference texts, aligned index-by-index with
            ``predictions``.

    Returns:
        A ``(pred_list, ref_list)`` tuple of lists holding only the pairs whose
        prediction and reference are both truthy (e.g. non-empty, non-None),
        in their original order.

    Raises:
        ValueError: If the two sequences differ in length; pairing them would
            otherwise silently drop or misalign texts.
    """
    if len(predictions) != len(references):
        raise ValueError(
            "predictions and references must have the same length "
            f"({len(predictions)} != {len(references)})"
        )
    kept = [(pred, ref) for pred, ref in zip(predictions, references) if pred and ref]
    pred_list = [pred for pred, _ in kept]
    ref_list = [ref for _, ref in kept]
    return pred_list, ref_list


def compute_ppl(predictions, model_id='gpt2-large', quantile=0.05, device='cuda'):
    """Return the trimmed-mean perplexity of ``predictions`` under ``model_id``.

    Empty/falsy predictions are skipped. Per-text perplexities from the
    ``evaluate`` ``perplexity`` metric are sorted. Roughly the lowest and
    highest ``quantile`` fraction are discarded before averaging, so a few
    degenerate outputs do not dominate the score.

    Args:
        predictions: Generated texts to score.
        model_id: Model name passed to the ``perplexity`` metric.
        quantile: Fraction trimmed from each tail; 0 keeps every value.
        device: Device string passed to ``perplexity.compute``.

    Returns:
        The mean of the retained per-text perplexities. If trimming would leave
        nothing (e.g. a single prediction), the untrimmed mean is returned.

    Raises:
        ValueError: If ``quantile`` is outside ``[0, 0.5)`` or no non-empty
            predictions remain.
    """
    if not 0 <= quantile < 0.5:
        raise ValueError(f"quantile must be in [0, 0.5), got {quantile}")

    torch.cuda.empty_cache()

    predictions = [p for p in predictions if p]
    if not predictions:
        raise ValueError("compute_ppl needs at least one non-empty prediction")

    perplexity = load("perplexity", module_type="metric", model_id=model_id)
    ppl_list = np.sort(perplexity.compute(
        predictions=predictions,
        model_id=model_id,
        device=device,
        add_start_token=True,
    )["perplexities"])

    a_min = int(quantile * len(ppl_list))
    a_max = int((1 - quantile) * len(ppl_list))
    trimmed = ppl_list[a_min:a_max]
    # With very few values (e.g. one: slice [0:int(0.95)] == [0:0]) the trim is
    # empty and np.mean would return NaN; fall back to averaging everything.
    if trimmed.size == 0:
        trimmed = ppl_list
    return np.mean(trimmed)


def compute_mauve(predictions, references, model_id='gpt2-large'):
    """Return the MAUVE score between ``predictions`` and ``references``.

    The two lists must be the same length (checked with ``assert``). Pairs where
    either text is empty are dropped via ``filter_empty_texts`` before scoring.
    ``model_id`` is passed as the featurization model, along with
    ``device_id=0``.
    """
    torch.cuda.empty_cache()

    mauve_metric = load("mauve")
    assert len(predictions) == len(references)

    kept_predictions, kept_references = filter_empty_texts(predictions, references)
    outcome = mauve_metric.compute(
        predictions=kept_predictions,
        references=kept_references,
        featurize_model_name=model_id,
        device_id=0,
        verbose=False,
    )
    return outcome.mauve


def compute_wordcount(all_texts_list):
    """Return the ``unique_words`` value reported by ``evaluate``'s ``word_count`` module."""
    counts = load("word_count").compute(data=all_texts_list)
    return counts['unique_words']


def compute_diversity(all_texts_list, ngram_range=(2, 3, 4), tokenizer=None):
    """Compute n-gram repetition rates and the combined diversity score.

    For each order ``n`` in ``ngram_range``, repetition is
    ``1 - unique_ngrams / total_ngrams`` pooled over all texts. Diversity is the
    product of ``(1 - repetition)`` over all orders.

    Args:
        all_texts_list: Texts to analyse.
        ngram_range: N-gram orders to measure.
        tokenizer: Callable mapping a string to an iterable of tokens; each
            token is converted with ``str``. Defaults to the spaCy
            ``en_core_web_sm`` tokenizer.

    Returns:
        Dict with one ``'{n}gram_repitition'`` entry per order (the key spelling
        is kept for compatibility) plus ``'diversity'``. An order with no
        n-grams at all (empty input, or every text shorter than ``n`` tokens)
        reports repetition 0.0 instead of dividing by zero.
    """
    if tokenizer is None:
        tokenizer = spacy.load("en_core_web_sm").tokenizer
    token_list = [[str(token) for token in tokenizer(sentence)] for sentence in all_texts_list]

    metrics = {}
    for n in ngram_range:
        unique_ngrams = set()
        total_ngrams = 0
        for tokens in token_list:
            # Same tuples nltk.util.ngrams(tokens, n) yields (no padding);
            # materialized once so it feeds both the set and the count.
            sentence_ngrams = list(zip(*(tokens[i:] for i in range(n))))
            unique_ngrams.update(sentence_ngrams)
            total_ngrams += len(sentence_ngrams)
        # Nothing can repeat when there are no n-grams of this order.
        repetition = 1 - len(unique_ngrams) / total_ngrams if total_ngrams else 0.0
        metrics[f'{n}gram_repitition'] = repetition

    diversity = 1
    for val in metrics.values():
        diversity *= (1 - val)
    metrics['diversity'] = diversity
    return metrics


def compute_memorization(all_texts_list, train_unique_four_grams: Set[tuple[str, ...]], n=4, tokenizer=None):
    """Return the fraction of n-grams in ``all_texts_list`` found in training data.

    Args:
        all_texts_list: Generated texts to check.
        train_unique_four_grams: Token n-gram tuples seen in training, tested
            with ``in`` (a set keeps lookups O(1)). Despite the name, the tuples
            must have length ``n`` to ever match.
        n: N-gram order.
        tokenizer: Callable mapping a string to an iterable of tokens; each
            token is converted with ``str``. Defaults to the spaCy
            ``en_core_web_sm`` tokenizer.

    Returns:
        ``matched / total`` over every n-gram occurrence in every text,
        repeats included. Returns 0.0 when the texts contain no n-grams at all
        (empty input, or every text shorter than ``n`` tokens).
    """
    if tokenizer is None:
        tokenizer = spacy.load("en_core_web_sm").tokenizer

    total = 0
    matched = 0
    for sentence in all_texts_list:
        tokens = [str(token) for token in tokenizer(sentence)]
        # Same tuples nltk.util.ngrams(tokens, n) yields (no padding).
        sentence_ngrams = list(zip(*(tokens[i:] for i in range(n))))
        total += len(sentence_ngrams)
        matched += sum(1 for gram in sentence_ngrams if gram in train_unique_four_grams)

    # With no n-grams there is nothing memorized; avoid ZeroDivisionError.
    return matched / total if total else 0.0

def _get_ngrams(segment, max_order):
    """Extracts all n-grams up to a given maximum order from an input segment.

    Args:
      segment: text segment from which n-grams will be extracted.
      max_order: maximum length in tokens of the n-grams returned by this
          method.

    Returns:
      The Counter containing all n-grams up to max_order in segment
      with a count of how many times each n-gram occurred.
    """
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram = tuple(segment[i:i+order])
            ngram_counts[ngram] += 1
    return ngram_counts


def bleu(reference_corpus,
         translation_corpus,
         max_order=4,
         smooth=False):
    """Computes corpus-level BLEU of translated segments against one or more references.

    Args:
      reference_corpus: list of lists of references for each translation. Each
          reference should be tokenized into a list of tokens.
      translation_corpus: list of translations to score. Each translation
          should be tokenized into a list of tokens.
      max_order: Maximum n-gram order to use when computing BLEU score.
      smooth: Whether or not to apply Lin et al. 2004 smoothing.

    Returns:
      6-tuple ``(bleu, precisions, bp, ratio, translation_length,
      reference_length)``:
      - the BLEU score;
      - the list of n-gram precisions for orders 1..max_order;
      - the brevity penalty;
      - the translation/reference length ratio;
      - the summed translation length;
      - the summed shortest-reference length.

      Empty inputs do not raise. With no translation tokens the brevity
      penalty and the score are 0.0. With no reference tokens but some
      translation tokens the ratio is ``inf`` and no brevity penalty applies.
    """
    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0
    for (references, translation) in zip(reference_corpus,
                                         translation_corpus):
        # Brevity penalty is measured against the shortest reference.
        reference_length += min(len(r) for r in references)
        translation_length += len(translation)

        # N-gram statistics are accumulated per segment, inside this loop, so
        # every segment contributes to the corpus-level precisions.
        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            # Counter union keeps the max count per n-gram across references;
            # the intersection below then clips translation counts to it.
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        translation_ngram_counts = _get_ngrams(translation, max_order)
        overlap = translation_ngram_counts & merged_ref_ngram_counts
        for ngram, count in overlap.items():
            matches_by_order[len(ngram) - 1] += count
        for order in range(1, max_order + 1):
            possible_matches = len(translation) - order + 1
            if possible_matches > 0:
                possible_matches_by_order[order - 1] += possible_matches

    precisions = [0] * max_order
    for i in range(max_order):
        if smooth:
            precisions[i] = ((matches_by_order[i] + 1.) /
                             (possible_matches_by_order[i] + 1.))
        elif possible_matches_by_order[i] > 0:
            precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
        else:
            precisions[i] = 0.0

    if min(precisions) > 0:
        p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
        geo_mean = math.exp(p_log_sum)
    else:
        geo_mean = 0

    if reference_length > 0:
        ratio = translation_length / reference_length
    else:
        ratio = math.inf if translation_length > 0 else 0.0

    if ratio > 1.0:
        bp = 1.
    elif ratio > 0:
        bp = math.exp(1 - 1. / ratio)
    else:
        # No translation tokens: exp(1 - 1/ratio) would divide by zero; its limit is 0.
        bp = 0.0

    bleu_score = geo_mean * bp

    return (bleu_score, precisions, bp, ratio, translation_length, reference_length)

def compute_bleu(predictions, references, max_order=4, smooth=False):
    """Return corpus BLEU of ``predictions`` vs. ``references`` on mBERT tokens.

    Both sides are tokenized with the ``bert-base-multilingual-cased`` tokenizer.
    Each prediction is scored against its single aligned reference. Only the
    score (first element of ``bleu``'s result tuple) is returned.
    """
    empty_cache()

    mbert = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
    tokenized_refs = [[mbert.tokenize(ref)] for ref in references]
    tokenized_preds = [mbert.tokenize(pred) for pred in predictions]

    score, *_ = bleu(
        reference_corpus=tokenized_refs,
        translation_corpus=tokenized_preds,
        max_order=max_order,
        smooth=smooth,
    )
    return score