import argparse
import os
import json
from torchmetrics.text.bert import BERTScore
import re


# Set TOKENIZERS_PARALLELISM=false in this process's environment at import
# time. NOTE(review): presumably this stops the HuggingFace tokenizers library
# from using parallel workers (and warning about forking) -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _read_records(path):
    """Read *path* and split its text into records on "\\n<digits>\\t"."""
    # Use a context manager so the file handle is closed deterministically.
    with open(path, "r") as handle:
        text = handle.read()
    # NOTE(review): the pattern requires a preceding newline, so an index
    # prefix at the very start of the file (e.g. "0\t") stays attached to the
    # first record -- confirm whether that is intended.
    return re.split(r"\n\d+\t", text)


def calculate_bertscore(
    preds_file,
    target_file,
    device,
    batch_size,
    lang = "en",
):
    """Compute BERTScore between the records of two text files.

    Both files are split into records on a newline followed by digits and a
    tab; the i-th prediction is scored against the i-th target.

    Args:
        preds_file: Path to the file holding the predictions.
        target_file: Path to the file holding the reference texts.
        device: Device forwarded to ``BERTScore``.
        batch_size: Batch size forwarded to ``BERTScore``.
        lang: Language code forwarded to ``BERTScore``.

    Returns:
        Whatever the ``BERTScore`` metric returns when called on the two
        record lists.

    Raises:
        ValueError: If the two files contain a different number of records.
    """
    preds = _read_records(preds_file)
    targets = _read_records(target_file)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(preds) != len(targets):
        raise ValueError(
            "Number of predictions ({}) does not match number of targets ({})".format(
                len(preds), len(targets)
            )
        )

    bert_score = BERTScore(
        max_length=512,
        device=device,
        verbose=True,
        batch_size=batch_size,
        lang=lang,
    )
    results = bert_score(preds, targets)
    return results


def _as_float_list(values):
    """Return per-example scores as a plain list of Python floats."""
    # The metric may hand back a tensor-like object instead of a list;
    # ``tolist()`` turns it into plain Python numbers (a bare number for a
    # zero-dimensional value).
    if hasattr(values, "tolist"):
        values = values.tolist()
    if not isinstance(values, (list, tuple)):
        values = [values]
    return [float(v) for v in values]


def run_bertscore(
    results_dir,
    device,
    batch_size,
    lang = "en",
):
    """Average BERTScore over every output/golden file pair in a directory.

    Files whose names contain "output" are counted as predictions and files
    whose names contain "golden" as references. With exactly one pair the
    files ``output.txt`` / ``golden.txt`` are scored; otherwise
    ``output_<i>.txt`` / ``golden_<i>.txt`` for ``i`` in ``0..count-1``.

    Args:
        results_dir: Directory containing the prediction and reference files.
        device: Device forwarded to ``calculate_bertscore``.
        batch_size: Batch size forwarded to ``calculate_bertscore``.
        lang: Language code forwarded to ``calculate_bertscore``.

    Returns:
        A dict mapping each score key returned by ``calculate_bertscore`` to
        the mean of its per-example values across all file pairs.

    Raises:
        ValueError: If ``results_dir`` does not exist, the number of output
            and golden files differ, or no such files are present.
    """
    if not os.path.exists(results_dir):
        raise ValueError("results_dir does not exist")

    all_files = os.listdir(results_dir)
    preds_files = [f for f in all_files if "output" in f]
    target_files = [f for f in all_files if "golden" in f]
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if len(preds_files) != len(target_files):
        raise ValueError(
            "Found {} output file(s) but {} golden file(s) in results_dir".format(
                len(preds_files), len(target_files)
            )
        )
    if not preds_files:
        # Without this guard the aggregation below would have nothing to average.
        raise ValueError("results_dir contains no output/golden files")

    if len(preds_files) == 1:
        pairs = [("output.txt", "golden.txt")]
    else:
        pairs = [
            ("output_{}.txt".format(i), "golden_{}.txt".format(i))
            for i in range(len(preds_files))
        ]

    scores = [
        calculate_bertscore(
            os.path.join(results_dir, preds_name),
            os.path.join(results_dir, target_name),
            device,
            batch_size,
            # Forward ``lang`` for every pair; the multi-file path used to drop it.
            lang=lang,
        )
        for preds_name, target_name in pairs
    ]

    # Pool the per-example values of all pairs under each score key.
    pooled = {}
    for score in scores:
        for key, values in score.items():
            pooled.setdefault(key, []).extend(_as_float_list(values))

    return {key: sum(vals) / len(vals) for key, vals in pooled.items()}


def main():
    """Parse CLI arguments, run BERTScore, and save/print the averaged result.

    The averaged scores are written as JSON to ``bertscore.json`` inside
    ``--results_dir`` and also printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="./results/Alpaca-13B_0S/MediaSum/Alpaca_Observed/2")

    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--batch_size", type=int, default=256)
    # Expose the language option that run_bertscore already accepts; the
    # default matches run_bertscore's own default.
    parser.add_argument("--lang", type=str, default="en")
    args = parser.parse_args()
    # Argument names mirror run_bertscore's parameters, so they unpack directly.
    result = run_bertscore(**vars(args))
    # Context manager ensures the JSON file is flushed and closed.
    with open(os.path.join(args.results_dir, "bertscore.json"), "w") as out_file:
        json.dump(result, out_file)
    print(result)
    

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
    