from tqdm import tqdm
import json
import argparse
from nltk import word_tokenize
import os
from simcse import SimCSE
import numpy as np
import nltk
nltk.download('punkt')

def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``file_name`` (input path), ``output_path``
        (where the metrics are written), and the two flags ``bottleneck`` and
        ``augment``, which are always ``True`` because they are only set
        through parser defaults.
    """
    default_input = "/home/yewen001/CODE/zhiyi/New_score_Mistral_ICLR/Final_ICLR_2025/Experiment/LLama3_DPO/our_method/Decoding_results/overall_score_Sample_Prompt_num_2_Sample_Node_num_1_Sample_Original_Prompt_0_Max_Node_length_200.jsonl"
    default_output = "/home/yewen001/CODE/zhiyi/New_score_Mistral_ICLR/Final_ICLR_2025/Experiment/LLama3_DPO/our_method/metric/overall_score_Sample_Prompt_num_2_Sample_Node_num_1_Sample_Original_Prompt_0_Max_Node_length_200.jsonl"

    cli = argparse.ArgumentParser()
    cli.add_argument("--file_name", type=str, default=default_input)
    cli.add_argument("--output_path", type=str, default=default_output)
    # Not exposed as options; they only ever appear on the namespace as True.
    cli.set_defaults(bottleneck=True, augment=True)
    return cli.parse_args()


def compute_rep_n(text, n):
    """Return the rep-n score of ``text`` as a percentage.

    The text is tokenized with ``word_tokenize`` and every window of ``n``
    consecutive tokens is collected. The score is
    ``100 * (1 - unique_ngrams / (total_ngrams + 1))``; the ``+ 1`` keeps the
    denominator non-zero, so a text with fewer than ``n`` tokens scores 100.0.
    """
    words = word_tokenize(text)
    window_count = len(range(len(words) - n + 1))
    distinct = {tuple(words[start : start + n]) for start in range(window_count)}
    unique_ratio = len(distinct) / (window_count + 1)
    return 100 * (1.0 - unique_ratio)


def compute_diversity(text):
    """Return the diversity score of ``text``.

    Diversity is the product, for n = 2, 3 and 4, of
    ``1 - compute_rep_n(text, n) / 100``.
    """
    score = 1.0
    for order in (2, 3, 4):
        non_repeated = 1.0 - compute_rep_n(text, order) / 100
        score = score * non_repeated
    return score


def clean(text, sep="###"):
    """Return the part of ``text`` that precedes the first ``sep``.

    When ``sep`` does not occur, the whole text is returned.
    """
    # Only the first piece is needed, so a single split is enough.
    pieces = text.split(sep, 1)
    return pieces[0]


def average(entries):
    """Return the arithmetic mean of the numbers in ``entries``.

    Raises ZeroDivisionError when ``entries`` is empty.
    """
    total = sum(entries)
    count = len(entries)
    return total / count



if __name__ == "__main__":
    # Script entry point: score every generation in the input file and write
    # the per-entry and averaged repetition/diversity metrics as one JSON
    # document.
    args = get_args()

    # The input is parsed as a single JSON document whose iteration yields one
    # record per generation; each record is indexed by "prompt", "result" and
    # "response" below.
    with open(args.file_name, "r", encoding="utf-8") as in_file:
        output = json.load(in_file)

    entries = []
    skipped = 0
    for generation in tqdm(output):
        try:
            prompt = generation["prompt"]
            response = generation["result"]

            # An empty response is scored as a single space so that it still
            # produces an entry.
            if len(response) == 0:
                response = " "
            entry = {
                "prompt": prompt,
                "response": response,
                # Drop the leading len(prompt) characters of the raw response.
                "original_response": generation["response"][len(prompt) :],
                "rep_2": compute_rep_n(response, 2),
                "rep_3": compute_rep_n(response, 3),
                "rep_4": compute_rep_n(response, 4),
                "diversity": compute_diversity(response),
                "response_length": len(response),
            }
        except Exception as err:
            # Best effort: a record that cannot be scored is skipped, but the
            # skip is reported, and Ctrl-C / SystemExit are no longer swallowed
            # as they were by a bare ``except``.
            skipped += 1
            tqdm.write(f"Skipping record: {type(err).__name__}: {err}")
            continue
        entries.append(entry)

    if skipped:
        tqdm.write(f"Skipped {skipped} record(s) that could not be scored.")

    # Averaging an empty list would fail with an opaque ZeroDivisionError.
    if not entries:
        raise SystemExit(
            f"No usable records found in {args.file_name}; nothing to evaluate."
        )

    metric_names = ("rep_2", "rep_3", "rep_4", "diversity", "response_length")
    evaluations = {
        name: average([entry[name] for entry in entries]) for name in metric_names
    }
    evaluations["entries"] = entries

    # Create the destination directory so the run does not fail only at the
    # very end, after all metrics have been computed.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as out_file:
        json.dump(evaluations, out_file, indent=2)