import argparse
import json

from .metrics import compute_metric

def postprocess_text(texts):
    """Return a new list holding the lowercased form of each string in ``texts``.

    The input iterable is not modified; order is preserved.
    """
    lowered = []
    for entry in texts:
        lowered.append(entry.lower())
    return lowered

def fix_tokenizer(texts):
    """Replace right single quotation marks (U+2019) with ASCII apostrophes.

    For every string in ``texts`` three substitutions are applied in order,
    each producing a plain ``'``:

    1. a U+2019 with a space on both sides,
    2. a U+2019 followed by a space,
    3. any remaining bare U+2019.

    A space that only precedes the mark (with none after it) is kept.
    Returns a new list; the input is left untouched.
    """
    # Longest variant first so surrounding spaces are consumed before the
    # bare character is handled.
    variants = (" \u2019 ", "\u2019 ", "\u2019")

    def _normalize(sentence):
        for variant in variants:
            sentence = sentence.replace(variant, "'")
        return sentence

    return [_normalize(sentence) for sentence in texts]

def main(args):
    """Score a predictions file and save its postprocessed texts.

    Steps:

    1. Load the JSON document at ``args.pred_file``. It is iterated as a
       sequence of records, each indexed by ``"GEN"``, ``"TRG"`` and ``"SRC"``.
    2. Lowercase and apostrophe-normalize all three fields.
    3. Compute the ``"bleu_paradetox"`` and ``"j_score"`` metrics through
       ``compute_metric`` and print each as ``name: value``.
    4. Write the cleaned ``pred``/``ref``/``src`` triples as JSON next to the
       input, in ``<input without .json>_postprocessed.json``.

    Args:
        args: Object exposing a ``pred_file`` attribute (path to the
            predictions JSON), e.g. the namespace returned by ``argparse``.
    """
    suffix = ".json"

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(args.pred_file, "r", encoding="utf-8") as f:
        pred_data = json.load(f)

    def _clean(field):
        """Extract ``field`` from every record and normalize the texts."""
        return fix_tokenizer(postprocess_text([pred[field] for pred in pred_data]))

    predictions = _clean("GEN")
    references = _clean("TRG")
    sources = _clean("SRC")

    metrics_dict = {}
    for metric_name in ["bleu_paradetox", "j_score"]:
        metrics_dict[metric_name] = compute_metric(
            metric_name,
            predictions=predictions,
            references=references,
            sources=sources,
        )

    for key, value in metrics_dict.items():
        print(f"{key}: {value}")

    records = [
        {"pred": pred, "ref": ref, "src": src}
        for pred, ref, src in zip(predictions, references, sources)
    ]

    # Only strip a trailing ".json". A blanket str.replace would rewrite every
    # occurrence in the path and, when the path has no ".json" at all, would
    # yield the input path itself and overwrite the predictions file.
    pred_path = args.pred_file
    stem = pred_path[: -len(suffix)] if pred_path.endswith(suffix) else pred_path
    output_path = f"{stem}_postprocessed{suffix}"

    # Context manager guarantees the file is flushed and closed.
    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(records, out_file, indent=4)

if __name__ == "__main__":
    # Command-line entry point: the only option is the mandatory path to the
    # predictions JSON file.
    # NOTE: this module uses a relative import (``from .metrics import ...``),
    # so it has to be launched as part of its package (``python -m ...``).
    cli = argparse.ArgumentParser()
    cli.add_argument("--pred_file", type=str, required=True)

    main(cli.parse_args())