import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Collect the screenshot filename of every episode in the original benchmark
# whose action type is "wait"; episodes of the explore benchmark sharing one of
# these screenshots are skipped later in this script.
# NOTE(review): the input name is spelled "origianl" — confirm it matches the
# file on disk before "fixing" it.
# JSON text is UTF-8 by spec, so decode explicitly instead of relying on the
# platform's locale encoding.
with open("origianlBenchmark.json", "r", encoding="utf-8") as f:
    original_data = json.load(f)

# Kept as a list (not a set) so the printed count includes repeated filenames,
# exactly as before.
wait_img_filename = [
    ep_dict["img_filename"]
    for ep_dict in original_data
    if ep_dict["action"]["action_type"] == "wait"
]

print(f"The ep with action 'wait': {len(wait_img_filename)}")

# NOTE(review): `dup` is never read in this script; candidate for removal.
dup = set()

# Load the width-corrected explore benchmark that is filtered below.
# Decode as UTF-8 explicitly rather than with the platform's locale encoding.
with open("exploreBenchmark_width_correct.json", "r", encoding="utf-8") as f:
    wdith_data = json.load(f)

# Result list and counters filled in by the filtering pass.
new_wdith_data = []
wait_num = 0  # episodes skipped because their screenshot is a "wait" screenshot
repeat_num = 0  # episodes skipped as near-duplicates of an earlier instruction

# One instruction per episode, index-aligned with `wdith_data` (row i of the
# similarity matrix refers to wdith_data[i]).
instruction = [ep_dict["high_level_instruction"] for ep_dict in wdith_data]


# Turn every instruction into a TF-IDF vector so that near-identical wording
# can be detected via cosine similarity.
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(instruction)

# similarity_matrix[i, j] is the cosine similarity of instruction i and j.
similarity_matrix = cosine_similarity(tfidf_matrix)

# Pairs scoring strictly above this value are treated as repeats.
threshold = 0.9

# An instruction survives only when none of the instructions before it is
# more similar than `threshold`.
# NOTE(review): `filtered_texts` is not read anywhere else in this script.
filtered_texts = [
    text
    for i, text in enumerate(instruction)
    if not any(similarity_matrix[i, :i] > threshold)  # compare with earlier rows only
]

# Case-sensitive substrings that disqualify an instruction.
rule_str = ["emoji", 'GIF', 'voice']
# Build the lookup set once: membership tests against the list were a linear
# scan per episode.
wait_img_set = set(wait_img_filename)

for i, ep_dict in enumerate(wdith_data):
    # Skip episodes whose screenshot also belongs to a "wait" episode.
    if ep_dict["img_filename"] in wait_img_set:
        wait_num += 1
        continue

    # Skip instructions that mention any excluded keyword.
    if any(rule in ep_dict["high_level_instruction"] for rule in rule_str):
        continue

    # Keep the episode only if no earlier instruction is too similar to it.
    # NOTE(review): "earlier" includes episodes already skipped above, so an
    # episode can be counted as a repeat of one that was never kept — confirm
    # this is intended.
    if not any(similarity_matrix[i, :i] > threshold):
        new_wdith_data.append(ep_dict)
    else:
        repeat_num += 1


# Phrases stripped from scroll instructions, applied in this order: the
# " by scrolling <dir>" phrases first, then the "Scroll <dir> " phrases.
# str.replace removes every occurrence, wherever it appears in the text.
SCROLL_PHRASES = (
    " by scrolling up",
    " by scrolling down",
    " by scrolling left",
    " by scrolling right",
    "Scroll down ",
    "Scroll up ",
    "Scroll left ",
    "Scroll right ",
)
# Every scroll direction is replaced by its opposite.
# NOTE(review): the reason for inverting (e.g. swipe vs. content-scroll
# convention) is not visible here — confirm with the consumer of this file.
SCROLL_DIRECTION_FLIP = {"up": "down", "down": "up", "left": "right", "right": "left"}

for cnt, ep_dict in enumerate(new_wdith_data):
    # Re-index the kept episodes contiguously from 0.
    ep_dict["idx"] = cnt
    if ep_dict["action"]["action_type"] != "scroll":
        continue

    high_level_instruction = ep_dict["high_level_instruction"]
    for phrase in SCROLL_PHRASES:
        high_level_instruction = high_level_instruction.replace(phrase, "")
    # Stripping a leading "Scroll ... " can leave a lowercase first letter.
    if high_level_instruction and not high_level_instruction[0].isupper():
        high_level_instruction = high_level_instruction[0].upper() + high_level_instruction[1:]
    ep_dict["high_level_instruction"] = high_level_instruction

    # Bug fix: the former chain of independent `if`s turned "up" into "down"
    # and then straight back into "up" (likewise "left"), so only "down" and
    # "right" were ever inverted. One lookup flips each direction exactly once;
    # any other value is left untouched.
    direction = ep_dict["action"]["direction"]
    ep_dict["action"]["direction"] = SCROLL_DIRECTION_FLIP.get(direction, direction)

print(f"wait num: {wait_num}, repeat num: {repeat_num}")
print(f"Original ExploreBenchmark: {len(wdith_data)}, New ExploreBenchmark: {len(new_wdith_data)}")
# NOTE(review): the output name spells "wdith" while the input uses "width";
# kept unchanged because downstream tooling may expect this exact name.
with open("exploreBenchmark_wdith_w_high_quality.json", "w", encoding="utf-8") as f:
    json.dump(new_wdith_data, f, indent=4)