from metric_method.bart_score import BARTScorer
import bert_score
import json
import os
from typing import List, Optional, Union, Dict, Tuple
import torch 
from metric_method.huggingface_bleu.bleu import Bleu
from rouge import Rouge
import jieba
import sys
from metric_method.perplexity import Perplexity_compute
sys.setrecursionlimit(1000000)


# Device handed to BARTScorer below. Prefer the GPU, but fall back to the CPU
# so the script can still run on hosts where CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



def eval_bertscore(
    hypotheses: List[str],
    references: Union[List[str], List[List[str]]],
    model_type="bert-base-multilingual-cased",
    lang="en",
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate the hypotheses against the references using BERTScore.

    The checkpoint is chosen from ``lang``: ``"en"`` uses the local
    ``./bert-base-multilingual-cased`` directory with 9 layers, any other
    value uses ``./bert-base-chinese`` with 10 layers.

    Args:
        hypotheses: the generated texts, one entry per sample
        references: the reference texts, aligned with ``hypotheses``
        model_type: accepted for backward compatibility only; it is NOT
            forwarded to bert_score, the checkpoint is selected by ``lang``
        lang: language code passed to ``bert_score.score``

    Returns:
        A ``(P, R, F1)`` tuple holding the mean precision, recall and F1
        over all samples (the result of ``.mean()`` on each score tensor).

    Raises:
        ValueError: if ``hypotheses`` and ``references`` differ in length.
    """
    print("Evaluating bertscore")
    # Raise explicitly rather than assert: asserts disappear under ``python -O``.
    if len(hypotheses) != len(references):
        raise ValueError(
            "hypotheses and references must have the same length, got "
            f"{len(hypotheses)} and {len(references)}"
        )

    if lang == "en":
        checkpoint, num_layers = "./bert-base-multilingual-cased", 9
    else:
        checkpoint, num_layers = "./bert-base-chinese", 10

    P, R, F1 = bert_score.score(
        hypotheses,
        references,
        lang=lang,
        verbose=True,
        model_type=checkpoint,
        num_layers=num_layers,
        batch_size=128,
    )

    return P.mean(), R.mean(), F1.mean()


def calculate_score(hypotheses,references):
    """
    Compute BARTScore, BERTScore, BLEU and ROUGE for English outputs.

    The metrics are printed (as before) and also returned so callers can
    use them programmatically.

    Args:
        hypotheses: the generated texts, one string per sample
        references: the reference texts, aligned with ``hypotheses``

    Returns:
        A dict with the keys ``bartscore``, ``bertscore P/R/F``,
        ``scarebleu``, ``bleu-1`` .. ``bleu-4`` and ``rouge-1/2/l``.
        Everything except ``bartscore`` is scaled by 100; all values are
        rounded to 4 decimals.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    # Validate up front: otherwise an empty input only fails with a
    # ZeroDivisionError after the BART checkpoint has been loaded.
    if len(hypotheses) == 0 or len(references) == 0:
        raise ValueError("hypotheses and references must not be empty")
    if len(hypotheses) != len(references):
        raise ValueError(
            "hypotheses and references must have the same length, got "
            f"{len(hypotheses)} and {len(references)}"
        )

    metric = {}

    # BARTScore: mean of the per-sample scores.
    bart_scorer = BARTScorer(device=device, max_length=512, checkpoint='./bart-large-cnn')
    bart_scorer.load(path="./bart_score.pth")
    bart_scores = bart_scorer.score(hypotheses, references, batch_size=128)
    metric["bartscore"] = round(sum(bart_scores) / len(bart_scores), 4)

    # BERTScore: mean precision / recall / F1.
    P, R, F1 = eval_bertscore(hypotheses, references, lang="en")
    metric["bertscore P"] = round(P.item() * 100, 4)
    metric["bertscore R"] = round(R.item() * 100, 4)
    metric["bertscore F"] = round(F1.item() * 100, 4)

    # BLEU: each prediction gets a single-element list of references.
    bleu_reference = [[ref] for ref in references]
    bleu = Bleu()
    bleuscores = bleu.compute(predictions=hypotheses, references=bleu_reference)

    # NOTE: the "scarebleu" key spelling is kept as-is so existing consumers
    # of the output keep working. The "bleu-n" entries are taken from the
    # "precisions" list of the BLEU result.
    metric["scarebleu"] = round(bleuscores["bleu"] * 100, 4)
    for n, precision in enumerate(bleuscores["precisions"][:4], start=1):
        metric[f"bleu-{n}"] = round(precision * 100, 4)

    # ROUGE: averaged F-scores.
    rouge = Rouge()
    rougescores = rouge.get_scores(hypotheses, references, avg=True)
    for name in ("rouge-1", "rouge-2", "rouge-l"):
        metric[name] = round(rougescores[name]["f"] * 100, 4)

    print(metric)
    return metric



def calculate_score_cn(hypotheses,references,use_jieba=False):
    """
    Compute BARTScore, BERTScore, BLEU and ROUGE for Chinese outputs.

    The metrics are printed (as before) and also returned so callers can
    use them programmatically.

    Args:
        hypotheses: the generated texts, one string per sample
        references: the reference texts, aligned with ``hypotheses``
        use_jieba: when true, texts are segmented with ``jieba.cut`` before
            BLEU/ROUGE; otherwise every character becomes its own token

    Returns:
        A dict with the keys ``bartscore``, ``bertscore P/R/F``,
        ``scarebleu``, ``bleu-1`` .. ``bleu-4`` and ``rouge-1/2/l``.
        Everything except ``bartscore`` is scaled by 100; all values are
        rounded to 4 decimals.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    print(len(hypotheses))

    # Validate up front: otherwise an empty input only fails with a
    # ZeroDivisionError after the BART checkpoint has been loaded.
    if len(hypotheses) == 0 or len(references) == 0:
        raise ValueError("hypotheses and references must not be empty")
    if len(hypotheses) != len(references):
        raise ValueError(
            "hypotheses and references must have the same length, got "
            f"{len(hypotheses)} and {len(references)}"
        )

    metric = {}

    # BARTScore: mean of the per-sample scores.
    bart_scorer = BARTScorer(device=device, max_length=512, checkpoint='./bart-base-chinese', use_cn=True)
    bart_scores = bart_scorer.score(hypotheses, references, batch_size=128)
    metric["bartscore"] = round(sum(bart_scores) / len(bart_scores), 4)

    # BERTScore: mean precision / recall / F1.
    P, R, F1 = eval_bertscore(hypotheses, references, lang="zh")
    metric["bertscore P"] = round(P.item() * 100, 4)
    metric["bertscore R"] = round(R.item() * 100, 4)
    metric["bertscore F"] = round(F1.item() * 100, 4)

    # BLEU and ROUGE work on space-separated tokens, so the raw text is
    # turned into a token string first (words via jieba, or single characters).
    if use_jieba:
        hypotheses_data = [" ".join(jieba.cut(text)) for text in hypotheses]
        references_data = [" ".join(jieba.cut(text)) for text in references]
    else:
        hypotheses_data = [" ".join(text) for text in hypotheses]
        references_data = [" ".join(text) for text in references]
    # Reuse the tokenized references instead of segmenting them a second time.
    bleu_reference = [[ref] for ref in references_data]

    bleu = Bleu()
    bleuscores = bleu.compute(predictions=hypotheses_data, references=bleu_reference)

    # NOTE: the "scarebleu" key spelling is kept as-is so existing consumers
    # of the output keep working. The "bleu-n" entries are taken from the
    # "precisions" list of the BLEU result.
    metric["scarebleu"] = round(bleuscores["bleu"] * 100, 4)
    for n, precision in enumerate(bleuscores["precisions"][:4], start=1):
        metric[f"bleu-{n}"] = round(precision * 100, 4)

    # ROUGE: averaged F-scores on the tokenized texts.
    rouge = Rouge()
    rougescores = rouge.get_scores(hypotheses_data, references_data, avg=True)
    for name in ("rouge-1", "rouge-2", "rouge-l"):
        metric[name] = round(rougescores[name]["f"] * 100, 4)

    print(metric)
    return metric


def read_file(input_file) -> Tuple[list, list]:
    """
    Load predictions and ground truths from a JSON file.

    The file must contain a JSON array of objects, each providing the keys
    ``"predict_answer"`` and ``"truth_answer"``.

    Args:
        input_file: path of the JSON file to read

    Returns:
        ``(hypotheses, references)``: two lists in file order, holding the
        ``predict_answer`` and ``truth_answer`` values respectively.

    Raises:
        KeyError: if a record lacks one of the two keys.
    """
    # Decode as UTF-8 explicitly; the platform default encoding may not be
    # able to read non-ASCII (e.g. Chinese) answers.
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    hypotheses = [record["predict_answer"] for record in data]
    references = [record["truth_answer"] for record in data]
    return hypotheses, references

if __name__ == '__main__':
    # The input path used to be a hard-coded empty string, which always
    # failed when opened; take it from the command line instead.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <predictions.json>")

    hypotheses3, references3 = read_file(sys.argv[1])
    calculate_score(hypotheses3, references3)


