import spacy, random
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
# Code from https://github.com/HKUNLP/ProGen/blob/main/scripts/self_bleu.py
def bleu_i(weights, all_sentences, smoothing_function, i):
    """Score sentence ``i`` with BLEU against every other sentence.

    Args:
        weights: N-gram weights forwarded to ``sentence_bleu``.
        all_sentences: Sliceable sequence of sentences; entry ``i`` is the
            hypothesis and the remaining entries are the references.
        smoothing_function: Smoothing callable forwarded to ``sentence_bleu``.
        i: Index of the hypothesis sentence in ``all_sentences``.

    Returns:
        Whatever ``sentence_bleu`` returns for this hypothesis/reference split.
    """
    # Leave-one-out split: everything before and after position i.
    others = all_sentences[:i] + all_sentences[i + 1:]
    candidate = all_sentences[i]
    # noinspection PyTypeChecker
    score = sentence_bleu(
        references=others,
        hypothesis=candidate,
        weights=weights,
        smoothing_function=smoothing_function,
    )
    return score

# Loaded spaCy pipelines keyed by model name, so repeated calls do not pay the
# model-loading cost again.
_SPACY_PIPELINES = {}


def _get_spacy_pipeline(model_name):
    """Return the spaCy pipeline for ``model_name``, loading it at most once."""
    pipeline = _SPACY_PIPELINES.get(model_name)
    if pipeline is None:
        # Parser, tagger and NER are disabled at load time; callers here only
        # read the token texts of the resulting docs.
        pipeline = spacy.load(model_name, disable=['parser', 'tagger', 'ner'])
        _SPACY_PIPELINES[model_name] = pipeline
    return pipeline


def sample_sentences(all_sentences, n_sample):
    """Optionally subsample sentences, then tokenize each one with spaCy.

    Args:
        all_sentences: Sequence of sentence strings.
        n_sample: Number of sentences to draw with ``random.sample`` (module
            level RNG, without replacement). If ``None`` or not smaller than
            ``len(all_sentences)``, every sentence is kept in its given order.

    Returns:
        A list with one entry per kept sentence, each entry being the list of
        token texts produced by the ``en_core_web_sm`` pipeline.
    """
    nlp = _get_spacy_pipeline('en_core_web_sm')
    if n_sample is not None and n_sample < len(all_sentences):
        # Sample before tokenizing so only the kept sentences are processed.
        all_sentences = random.sample(all_sentences, n_sample)
    return [[tok.text for tok in nlp(s)] for s in all_sentences]

# BLEU weight vectors keyed by the highest n-gram order to score. Orders
# below 4 are zero-padded to length 4; order 5 uses five equal weights.
_SELF_BLEU_WEIGHTS = {
    1: (1.0, 0, 0, 0),
    2: (0.5, 0.5, 0, 0),
    3: (1.0 / 3, 1.0 / 3, 1.0 / 3, 0),
    4: (0.25, 0.25, 0.25, 0.25),
    5: (0.2, 0.2, 0.2, 0.2, 0.2),
}


def cal_self_bleu(all_sentences, n_sample, n_gram=3):
    """Compute the mean leave-one-out BLEU over a (sub)sample of sentences.

    Each sampled sentence is scored with ``bleu_i`` against all the other
    sampled sentences, and the scores are averaged with ``np.mean``.

    Note: this reseeds the module-level ``random`` generator with ``0`` so the
    subsample is reproducible; that affects later users of the global RNG.

    Args:
        all_sentences: Sequence of sentence strings.
        n_sample: Sample size passed to ``sample_sentences`` (``None`` keeps
            all sentences).
        n_gram: Highest n-gram order to score; must be 1, 2, 3, 4 or 5.

    Returns:
        A tuple ``(mean_bleu, n_sample)``. The second element is the
        ``n_sample`` argument passed through unchanged, not the number of
        sentences actually scored.

    Raises:
        ValueError: If ``n_gram`` is not one of the supported orders.
    """
    # Validate first so a bad n_gram fails before any model loading,
    # tokenization or RNG reseeding happens.
    try:
        weights = _SELF_BLEU_WEIGHTS[n_gram]
    except (KeyError, TypeError):
        raise ValueError(
            "n_gram must be one of {}, got {!r}".format(
                sorted(_SELF_BLEU_WEIGHTS), n_gram)
        ) from None

    random.seed(0)
    all_sentences = sample_sentences(all_sentences, n_sample)
    smoothing_function = SmoothingFunction().method1

    bleu_scores = [
        bleu_i(weights, all_sentences, smoothing_function, i)
        for i in range(len(all_sentences))
    ]

    return np.mean(bleu_scores), n_sample