import torch
import spacy
from torch.nn.functional import cross_entropy
from typing import List, Set
from evaluate import load
from nltk.util import ngrams
from collections import defaultdict
from transformers import AutoTokenizer
import spacy
import numpy as np

from paradetox import compute_bleu_paradetox, compute_j_score


def compute_metric(metric_name, predictions, references, **kwargs):
    if metric_name == "mauve":
        return compute_mauve(predictions=predictions, references=references)
    elif metric_name == "div":
        return compute_diversity(all_texts_list=predictions)['diversity']
    elif metric_name == "mem":
        return compute_memorization(all_texts_list=predictions, human_references=references)
    elif metric_name.startswith("rouge"):
        return compute_rouge(predictions=predictions, references=references)[metric_name]
    elif metric_name == "bert-score":
        return compute_bert_score(predictions=predictions, references=references)
    elif metric_name == "bleu":
        return compute_bleu(predictions=predictions, references=references)["bleu"]
    elif metric_name == "meteor":
        return compute_meteor(predictions=predictions, references=references)
    elif metric_name == "bleu-paradetox":
        return compute_bleu_paradetox(predictions=predictions, references=references)
    elif metric_name == "j-score":
        return compute_j_score(predictions=predictions, sources=kwargs["sources"])
    elif metric_name == "len":
        return {
            "predictions_len": compute_len(sequences=predictions),
            "references_len": compute_len(sequences=references),
        }
    elif metric_name == "sari":
        return compute_sari(predictions=predictions, references=references, sources=kwargs["sources"])
    elif metric_name == "ppl":
        return compute_ppl(predictions=predictions)
    else:
        raise Exception(f"Unknown metric: {metric_name}")


def filter_empty_texts(predictions, references):
    pred_list = []
    ref_list = []
    for i in range(len(predictions)):
        if predictions[i] and references[i]:
            pred_list.append(predictions[i])
            ref_list.append(references[i])
    return pred_list, ref_list


def compute_ppl(predictions, model_id='gpt2-large'):
    torch.cuda.empty_cache()

    predictions = [p for p in predictions if p]

    perplexity = load("perplexity", module_type="metric", model_id=model_id)
    ppl_list = perplexity.compute(
        predictions=predictions, 
        model_id=model_id, 
        device='cuda', 
        add_start_token=True,
    )["perplexities"]
    ppl_list = np.sort(ppl_list)
    quantile = 0.05
    a_min, a_max = int(quantile * len(ppl_list)), int((1 - quantile) * len(ppl_list))
    ppl_list = ppl_list[a_min: a_max]
    ppl = np.mean(ppl_list)
    return ppl


def compute_mauve(predictions, references, model_id='gpt2-large'):
    torch.cuda.empty_cache() 

    mauve = load("mauve")
    assert len(predictions) == len(references)

    predictions, references = filter_empty_texts(predictions, references)

    results = mauve.compute(
        predictions=predictions, references=references,
        featurize_model_name=model_id, device_id=0, verbose=False
    )

    return results.mauve


def compute_wordcount(all_texts_list):
    wordcount = load("word_count")
    wordcount = wordcount.compute(data=all_texts_list)
    return wordcount['unique_words']


def compute_diversity(all_texts_list):
    ngram_range = [2, 3, 4]

    tokenizer = spacy.load("en_core_web_sm").tokenizer
    token_list = []
    for sentence in all_texts_list:
        token_list.append([str(token) for token in tokenizer(sentence)])
    ngram_sets = {}
    ngram_counts = defaultdict(int)

    metrics = {}
    for n in ngram_range:
        ngram_sets[n] = set()
        for tokens in token_list:
            ngram_sets[n].update(ngrams(tokens, n))
            ngram_counts[n] += len(list(ngrams(tokens, n)))
        metrics[f'{n}gram_repitition'] = (1 - len(ngram_sets[n])/ngram_counts[n])
    diversity = 1
    for val in metrics.values():
        diversity *= (1 - val)
    metrics['diversity'] = diversity
    return metrics


def compute_memorization(all_texts_list, train_unique_four_grams: Set[tuple[str]], n=4):
    tokenizer = spacy.load("en_core_web_sm").tokenizer

    total = 0
    duplicate = 0
    for sentence in all_texts_list:
        four_grams = list(ngrams([str(token) for token in tokenizer(sentence)], n))
        total += len(four_grams)
        for four_gram in four_grams:
            if four_gram in train_unique_four_grams:
                duplicate += 1

    return duplicate / total


def compute_rouge(predictions, references):
    torch.cuda.empty_cache() 

    rouge = load('rouge')
    result = rouge.compute(predictions=predictions, references=references)
    return result


def compute_bert_score(predictions, references):
    torch.cuda.empty_cache()

    bertscore = load("bertscore", module_type="metric")
    results = bertscore.compute(predictions=predictions, references=references, model_type='microsoft/deberta-xlarge-mnli', lang='en', verbose=True)
    # https://github.com/Shark-NLP/DiffuSeq/blob/f78945d79de5783a4329695c0adb1e11adde31bf/scripts/eval_seq2seq.py#L128C48-L128C115
    return np.mean(results["f1"])


def compute_bleu(predictions, references, max_order=4, smooth=False):
    torch.cuda.empty_cache()

    bleu = load("bleu")

    results = bleu.compute(references=references, predictions=predictions, max_order=max_order, smooth=smooth)
    return results


def compute_meteor(predictions, references):
    torch.cuda.empty_cache()

    meteor = load('meteor')
    return meteor.compute(predictions=predictions, references=references)["meteor"]


def compute_len(sequences):
    if isinstance(sequences[0], list):
        sequences = [seq for seq_list in sequences for seq in seq_list]

    wordcount = load("word_count")
    wordcount = wordcount.compute(data=sequences)
    return wordcount["total_word_count"] / len(sequences)


def compute_sari(predictions, references, sources):
    # pip install sacremoses
    torch.cuda.empty_cache()

    sari = load('sari')
    references = [[r] for r in references]
    return sari.compute(sources=sources, predictions=predictions, references=references)["sari"]
