from rouge import Rouge
import json
import os
import numpy as np
import nltk
from nltk.corpus import stopwords
# Fetch the NLTK stop-word corpus read by remove_stopwords below. This is a
# module-level statement, so it executes every time this file is run.
# NOTE(review): remove_stopwords also calls nltk.word_tokenize, which presumably
# needs NLTK tokenizer data (e.g. 'punkt') that is not downloaded here -- confirm
# it is already installed on the machine running this script.
nltk.download('stopwords')

# Cache for the stop-word set, filled on first use by _english_stop_words().
_STOP_WORDS_CACHE = {}


def _english_stop_words():
    """Return the NLTK English stop-word set, reading the corpus only once."""
    if 'english' not in _STOP_WORDS_CACHE:
        _STOP_WORDS_CACHE['english'] = frozenset(stopwords.words('english'))
    return _STOP_WORDS_CACHE['english']


def remove_stopwords(text):
    """Return *text* with English stop words removed.

    The text is split with ``nltk.word_tokenize``; every token whose
    lower-cased form appears in NLTK's English stop-word list is dropped, and
    the surviving tokens are joined with single spaces. Original whitespace
    is therefore not preserved.

    Args:
        text: The string to filter.

    Returns:
        The filtered tokens joined by ``' '`` (an empty string when no token
        survives).
    """
    # The stop-word list never changes between calls, so it is loaded lazily
    # once instead of being rebuilt from the corpus on every call.
    stop_words = _english_stop_words()

    tokens = nltk.word_tokenize(text)
    kept_tokens = [word for word in tokens if word.lower() not in stop_words]
    return ' '.join(kept_tokens)



# Batch sizes to evaluate; each one has its own sub-directory of result files.
# batch = [1, 2, 4, 16, 32, 64, 128]
batch = [1, 2, 4, 8]

rouge = Rouge()

text1_base = 'wikitext-103/'

# text2_base holds the JSON result files that are read; res_base receives the
# score summaries that are written.
text2_base = 'wikitext-103-res-byte/'
res_base = 'wikitext-103-res-byte/score/'

# text2_base = 'wikitext-103-res-subword/'
# res_base = 'wikitext-103-res-subword/score/'
for b in batch:
    scores_l = []
    scores_1 = []
    scores_2 = []
    # Best F-scores seen so far for ROUGE-L / ROUGE-1 / ROUGE-2.
    # NOTE(review): these are initialised once per batch size, not once per
    # file, so the value appended after file i is the best score over files
    # 1..i (a running maximum). Confirm that this is intended rather than a
    # per-file maximum.
    s_l = float("-inf")
    s_1 = float("-inf")
    s_2 = float("-inf")
    # Reset per batch size so that a batch without any scoreable pair neither
    # raises NameError nor silently reuses the pair found for an earlier batch.
    ori = None
    res = None
    for i in range(1, 6):
        text2 = os.path.join(text2_base, str(b), str(i))

        # Load the contents of the file into a Python dictionary
        with open(text2, 'r') as f:
            data = json.load(f)

        # Stop-word removal depends on a single text only, so it is done once
        # per text here instead of once per (original, result) pair.
        originals = [(o, remove_stopwords(o)) for o in data["original"]]
        results = [(r, remove_stopwords(r["result"][0])) for r in data["results"]]

        for o, oo in originals:
            for r, rr in results:
                # The rouge package raises ValueError ("Hypothesis is empty.")
                # on empty input; a pair with an empty side cannot be scored.
                if not oo or not rr:
                    continue
                sco = rouge.get_scores(oo, rr)[0]
                f_l = sco["rouge-l"]["f"]
                f_1 = sco["rouge-1"]["f"]
                f_2 = sco["rouge-2"]["f"]

                # NOTE(review): ori/res are overwritten whenever ANY of the
                # three metrics improves, so the stored pair belongs to the
                # metric that improved most recently, not to one fixed metric.
                if f_l > s_l:
                    ori = o
                    res = r
                    s_l = f_l

                if f_1 > s_1:
                    ori = o
                    res = r
                    s_1 = f_1

                if f_2 > s_2:
                    ori = o
                    res = r
                    s_2 = f_2
        scores_l.append(s_l)
        scores_1.append(s_1)
        scores_2.append(s_2)

    out = {
        "ori": ori,
        "res": res,
        "scores_l": scores_l,
        "average rouge-L score": np.mean(scores_l),
        "scores_1": scores_1,
        "average rouge-1 score": np.mean(scores_1),
        "scores_2": scores_2,
        "average rouge-2 score": np.mean(scores_2),
    }

    # The summary for batch size b is written to <res_base>/<b>/<b>.
    output_path = os.path.join(res_base, str(b))
    json_path = os.path.join(output_path, str(b))

    # Create the directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # Write the JSON data to the file
    with open(json_path, "w") as json_file:
        json.dump(out, json_file, indent=4)


