import spacy
import os
from difflib import ndiff
import json

def levenshtein_distance(str1, str2, threshold=5):
    """Approximate the edit distance between ``str1`` and ``str2``.

    The strings are aligned character by character with ``difflib.ndiff``.
    Every run of inserted (``+``) / deleted (``-``) characters between two
    unchanged characters is charged ``max(#inserts, #deletes)``, i.e. paired
    insert/delete operations count as one substitution. Because this prices
    one particular (not necessarily optimal) alignment, the result is an
    upper bound on the true Levenshtein distance rather than the exact value.

    Args:
        str1: Source string.
        str2: Target string.
        threshold: Early-exit bound. The running total is checked each time
            an unchanged character is reached; as soon as it exceeds
            ``threshold`` that partial total is returned immediately, so any
            value greater than ``threshold`` only means "too far apart".

    Returns:
        The approximate distance as an int.
    """
    total = 0
    pending = dict.fromkeys("+-", 0)
    for diff_line in ndiff(str1, str2):
        op = diff_line[0]
        if op != " ":
            # Still inside a run of edits; remember how many of each kind.
            pending[op] += 1
            continue
        # An unchanged character closes the current run of edits.
        total += max(pending.values())
        if total > threshold:
            return total
        pending = dict.fromkeys("+-", 0)
    # Account for a trailing run of edits after the last unchanged character.
    return total + max(pending.values())

def find_same(string, string_ls, threshold=5):
    """Return the candidate in ``string_ls`` that best matches ``string``.

    ``string`` is stripped of surrounding whitespace first. An exact match is
    returned immediately; otherwise the candidate with the smallest
    ``levenshtein_distance`` wins (the first one on ties).

    Args:
        string: Query text.
        string_ls: Candidate strings to search (iterated, and tested with
            ``in``).
        threshold: Maximum accepted distance; also forwarded to
            ``levenshtein_distance`` as its early-exit bound.

    Returns:
        The best candidate, or ``""`` (after printing a failure message) when
        no candidate is within ``threshold``.
    """
    string = string.strip()
    if string in string_ls:
        return string

    best_dist, best_cand = 100000, ''
    for cand in string_ls:
        dist = levenshtein_distance(string, cand, threshold)
        if dist < best_dist:
            best_dist, best_cand = dist, cand

    if best_dist <= threshold:
        return best_cand
    print("Failed to find the correspondance.")
    return ""

def match_caps(vllama_pred, vllama_gpt_pred, vllama_mapping_path):
    """Map every original caption to its closest key in the GPT result cache.

    Any existing mapping at ``vllama_mapping_path`` is loaded first and then
    updated; entries already present are recomputed and overwritten, not
    skipped.

    Args:
        vllama_pred: Nested dict; each value is itself a dict mapping an
            original caption to the Video-LLaMA predicted caption (the outer
            keys are not used here).
        vllama_gpt_pred: GPT result cache; only its keys (captions) are used,
            as match candidates for ``find_same``.
        vllama_mapping_path: JSON file holding the caption mapping. It is read
            if present and rewritten after each outer entry of ``vllama_pred``.

    Returns:
        dict mapping original caption -> matched cache key, or ``""`` when
        ``find_same`` found no candidate within the threshold.
    """
    gpt_llama_keys = list(vllama_gpt_pred.keys())
    if os.path.exists(vllama_mapping_path):
        with open(vllama_mapping_path, 'r', encoding='utf-8') as f:
            orig2pred = json.load(f)
    else:
        orig2pred = {}

    for llama_pred_info in vllama_pred.values():
        for orig, llama_pred in llama_pred_info.items():
            orig2pred[orig] = find_same(llama_pred, gpt_llama_keys, threshold=5)
        # Rewritten after every outer entry, presumably as a progress
        # checkpoint; the `with` guarantees each write is flushed and closed.
        with open(vllama_mapping_path, 'w', encoding='utf-8') as f:
            json.dump(orig2pred, f)
    print("here")
    return orig2pred

def collect_orig_2_gpt_res(orig2pred, vllama_gpt_pred, orig_2_gpt_res_path):
    """Re-key cached GPT results by original caption and save them as JSON.

    Args:
        orig2pred: Mapping original caption -> matched key of
            ``vllama_gpt_pred``. Empty strings mark captions without a match
            and are skipped.
        vllama_gpt_pred: GPT result cache keyed by Video-LLaMA caption. It is
            only read, never modified.
        orig_2_gpt_res_path: Output JSON path (overwritten).

    Returns:
        The dict that was written to ``orig_2_gpt_res_path``.

    Raises:
        KeyError: If a matched caption is not a key of ``vllama_gpt_pred``
            (e.g. a mapping built against a different cache).
    """
    # NOTE(review): the previous version also did
    # `vllama_gpt_pred["vllama_caption"] = pred_cap`, which only clobbered one
    # top-level key of the caller's cache and never reached the output. If the
    # intent was to record the matched caption in each result, add it to a
    # copy of the entry instead.
    orig_2_gpt_res = {
        orig_cap: vllama_gpt_pred[pred_cap]
        for orig_cap, pred_cap in orig2pred.items()
        if pred_cap
    }

    with open(orig_2_gpt_res_path, 'w', encoding='utf-8') as f:
        json.dump(orig_2_gpt_res, f)

    print("here")
    return orig_2_gpt_res


if __name__ == "__main__":
    # abspath() normalizes the '..' components lexically, so this resolves to
    # <two levels above this script's directory>/data/open_pvsg/nl2spec.
    nl_data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../data/open_pvsg/nl2spec'))

    vllama_pred_filename = "videollamav2_caption.json"
    vllama_gpt_filename = "open_pvsg_vllamav2_gpt4_cache.json"
    vllama_mapping_filename = "origcap2vllama.json"
    orig_2_gpt_res_filename = "videollamav2_origcap_2_gpt_cache.json"

    vllama_pred_path = os.path.join(nl_data_dir, vllama_pred_filename)
    vllama_gpt_path = os.path.join(nl_data_dir, vllama_gpt_filename)
    vllama_mapping_path = os.path.join(nl_data_dir, vllama_mapping_filename)
    orig_2_gpt_res_path = os.path.join(nl_data_dir, orig_2_gpt_res_filename)

    with open(vllama_pred_path, 'r', encoding='utf-8') as f:
        vllama_pred = json.load(f)
    with open(vllama_gpt_path, 'r', encoding='utf-8') as f:
        vllama_gpt_pred = json.load(f)

    # The caption mapping is produced in a separate step: call
    # match_caps(vllama_pred, vllama_gpt_pred, vllama_mapping_path), which
    # writes vllama_mapping_path, before running this collection step.
    if os.path.exists(vllama_mapping_path):
        with open(vllama_mapping_path, 'r', encoding='utf-8') as f:
            orig2pred = json.load(f)
        collect_orig_2_gpt_res(orig2pred, vllama_gpt_pred, orig_2_gpt_res_path)
    else:
        print("Mapping file not found, skipping collection:", vllama_mapping_path)

    print("end")