
"""
Evaluate model correctness
"""
import numpy as np
import argparse
import tqdm
import os
import evaluate
import pickle
from sentence_transformers import CrossEncoder

class CrossEncoderSimilarity:
    """Pairwise text similarity backed by a sentence-transformers ``CrossEncoder``.

    Calling an instance returns one model score per (source, summary) pair.
    """

    def __init__(
        self,
        model_id="cross-encoder/stsb-distilroberta-base",
        device="cuda",
        weight=1,
    ):
        """Load ``model_id`` as a ``CrossEncoder`` on ``device``.

        Args:
            model_id: Model id or path handed to ``CrossEncoder``.
            device: Device string handed to ``CrossEncoder``.
            weight: Only stored as ``self.weight``; nothing in this class reads it.
        """
        self.model = CrossEncoder(model_id, device=device)
        self.weight = weight

    def __call__(self, sources=None, summaries=None):
        """Return one cross-encoder score per (source, summary) pair.

        Args:
            sources: A sequence of strings, or a single string (treated as a
                one-element list rather than iterated character by character).
            summaries: Same as ``sources``; must have the same length.

        Returns:
            list: ``self.model.predict(pairs).tolist()``, or ``[]`` when there
            are no pairs (the model is not called in that case).

        Raises:
            TypeError: if either argument is None.
            ValueError: if ``sources`` and ``summaries`` differ in length.
        """
        sources = self._as_text_list(sources, 'sources')
        summaries = self._as_text_list(summaries, 'summaries')
        if len(sources) != len(summaries):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f'sources and summaries must have the same length '
                f'({len(sources)} != {len(summaries)})'
            )
        if not sources:
            return []
        pairs = [[source, summary] for source, summary in zip(sources, summaries)]
        return self.model.predict(pairs).tolist()

    @staticmethod
    def _as_text_list(texts, name):
        """Return ``texts`` as a list, wrapping a bare string as ``[texts]``."""
        if texts is None:
            raise TypeError(f'{name} must be a string or a sequence of strings, not None')
        if isinstance(texts, str):
            return [texts]
        return list(texts)

def _load_generations(run_name):
    """Return the object unpickled from the file at ``run_name``."""
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # generation files produced by a trusted run.
    with open(run_name, 'rb') as f:
        return pickle.load(f)


def _prediction_text(gen):
    """Return ``gen['most_likeli_generated_text']`` left-stripped, or None if it is not a str."""
    prediction = gen['most_likeli_generated_text']
    return prediction.lstrip() if isinstance(prediction, str) else None


def _as_list(value):
    """Wrap a bare string as ``[value]``; otherwise return ``list(value)``."""
    return [value] if isinstance(value, str) else list(value)


def _reference_answers(gen):
    """Return the non-empty string reference answers of one generation record.

    Combines ``gen['answer']`` with ``gen['additional_answers']`` when that key is
    present and not None. A bare string in either field counts as a single answer
    instead of being iterated character by character. Non-string and empty entries
    are dropped; they previously contributed a score of 0 and so never won the max.
    """
    answers = _as_list(gen['answer'])
    additional = gen.get('additional_answers')
    if additional is not None:
        answers += _as_list(additional)
    return [answer for answer in answers if isinstance(answer, str) and answer]


def _report_and_save(scores, run_name, suffix):
    """Print the fraction of scores above each threshold, then pickle the scores.

    Thresholds are ``np.arange(0.1, 1.1, 0.1)`` (0.1 .. 1.0) with a strict ``>``.
    The scores are written to ``run_name`` with its final extension replaced by
    ``suffix``.
    """
    if scores:
        score_array = np.asarray(scores)
        for threshold in np.arange(0.1, 1.1, 0.1):
            print(f'{threshold:.1f}', (score_array > threshold).sum() / len(scores))
    else:
        # Avoid a 0/0 division (nan plus a RuntimeWarning) on an empty run.
        print('No generations found; nothing to report.')

    # splitext strips only the final extension; split('.')[0] truncated at the
    # first dot anywhere in the path (e.g. './runs/x.pkl' -> '').
    output_path = os.path.splitext(run_name)[0] + suffix
    with open(output_path, 'wb') as f:
        pickle.dump(scores, f)


def eval_sentsim(run_name):
    """Score each generation against its reference answers with a cross-encoder.

    For every record in the pickle at ``run_name``, the score is the highest
    ``CrossEncoderSimilarity`` score between the most-likely generation and any
    reference answer. It is 0.0 when the prediction or all answers are missing
    or empty. Scores are printed at thresholds 0.1..1.0 and pickled to
    ``<run_name without extension>_sentsim_for_correctness.pkl``.
    """
    criterion = CrossEncoderSimilarity()
    scores = []
    for gen in tqdm.tqdm(_load_generations(run_name)):
        prediction = _prediction_text(gen)
        answers = _reference_answers(gen)
        max_score = 0.0
        if prediction and answers:
            # Parallel lists of equal length: one (answer, prediction) pair per
            # reference, all scored in a single batch.
            results = criterion(sources=answers, summaries=[prediction] * len(answers))
            max_score = max(max_score, *results)
        scores.append(max_score)

    _report_and_save(scores, run_name, '_sentsim_for_correctness.pkl')


def eval_rouge_L(run_name):
    """Score each generation against its reference answers with ROUGE-L.

    For every record in the pickle at ``run_name``, the score is the highest
    ROUGE-L value between the most-likely generation and any reference answer.
    It is 0.0 when the prediction or all answers are missing or empty. Scores
    are printed at thresholds 0.1..1.0 and pickled to
    ``<run_name without extension>_rougel_for_correctness.pkl``.
    """
    criterion = evaluate.load('rouge')
    scores = []
    for gen in tqdm.tqdm(_load_generations(run_name)):
        prediction = _prediction_text(gen)
        answers = _reference_answers(gen)
        max_score = 0.0
        if prediction and answers:
            # use_aggregator=False yields one rougeL value per (reference,
            # prediction) pair, so all references are scored in one call.
            results = criterion.compute(
                references=answers,
                predictions=[prediction] * len(answers),
                use_aggregator=False,
            )
            max_score = max(max_score, *results['rougeL'])
        scores.append(max_score)

    _report_and_save(scores, run_name, '_rougel_for_correctness.pkl')

def cmdline_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with a ``generation_path`` attribute.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    # Required: the previous default '' could never be opened and only failed
    # later with a less helpful FileNotFoundError.
    parser.add_argument('--generation-path', required=True,
                        help='Path to the pickled generations file to evaluate.')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Run both correctness metrics, in order, on the same generations file.
    generation_path = cmdline_args().generation_path
    for scorer in (eval_sentsim, eval_rouge_L):
        scorer(generation_path)