import nltk
import sys
import re
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from argparse import ArgumentParser
from utils.json_reader import jsonl_loader, json_loader
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer

def rouge_evaluation(jsonl_reference):
    """Print the average ROUGE-1/2/3/L precision, recall and F1 of a dataset.

    The records are loaded with ``json_loader``. For every record the
    ``"original_answer"`` field is used as the scoring target and the
    ``"answer"`` field as the prediction. Before scoring, both texts are
    lower-cased and stripped of every character that is not an ASCII letter,
    a digit or a space.

    Args:
        jsonl_reference: Path handed to ``json_loader``; the loaded value
            must support ``len()`` and iteration over dict-like records
            containing the string fields ``"original_answer"`` and
            ``"answer"``.

    Raises:
        ValueError: If no records were loaded, since an average over zero
            samples is undefined.
        KeyError: If a record lacks one of the two required fields.
    """
    records = json_loader(jsonl_reference)
    num_samples = len(records)
    if num_samples == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        # when the averages are computed.
        raise ValueError(
            f"No records loaded from {jsonl_reference!r}; "
            "cannot compute average ROUGE scores."
        )

    # (rouge-score key, label used in the printed report), in report order.
    rouge_types = (
        ('rouge1', "ROUGE-1"),
        ('rouge2', "ROUGE-2"),
        ('rouge3', "ROUGE-3"),
        ('rougeL', "ROUGE-L"),
    )
    # (attribute on the score object, label used in the printed report).
    metrics = (
        ("precision", "Precision"),
        ("recall", "Recall"),
        ("fmeasure", "F1"),
    )

    # Both the pattern and the scorer are loop-invariant, so build them once.
    non_alnum = re.compile(r'[^A-Za-z0-9 ]+')
    scorer = rouge_scorer.RougeScorer(
        [rouge_type for rouge_type, _ in rouge_types], use_stemmer=True
    )

    # Running sums, e.g. totals['rouge1']['precision'].
    totals = {
        rouge_type: {metric: 0 for metric, _ in metrics}
        for rouge_type, _ in rouge_types
    }

    for record in records:
        original_answer = non_alnum.sub('', record["original_answer"]).lower()
        rewrite_answer = non_alnum.sub('', record["answer"]).lower()

        # rouge-score signature is score(target, prediction).
        scores = scorer.score(original_answer, rewrite_answer)

        for rouge_type, _ in rouge_types:
            score = scores[rouge_type]
            sums = totals[rouge_type]
            for metric, _ in metrics:
                sums[metric] += getattr(score, metric)

    # Report the averages, one group per ROUGE variant, with a blank
    # separator printed between (not after) the groups.
    for index, (rouge_type, rouge_label) in enumerate(rouge_types):
        if index:
            print("\n")
        for metric, metric_label in metrics:
            average = totals[rouge_type][metric] / num_samples
            print(f"Average {rouge_label} {metric_label}: {average:.2f}")

if __name__ == "__main__":
    # Command-line entry point: read the dataset path and print the report.
    parser = ArgumentParser()
    parser.add_argument(
        "--jsonl_input",
        type=str,
        # Without a path the loader would receive None and fail obscurely;
        # let argparse report the missing option instead.
        required=True,
        help="Please enter the path to the dataset_04.jsonl (dataset generated with module 4).",
    )
    args = parser.parse_args()

    rouge_evaluation(args.jsonl_input)
    