from transformers import AutoTokenizer
from tiktoken import encoding_for_model
import numpy as np


def edit_distance(s1, s2):
    """Return the Levenshtein distance between two sequences.

    Insertions, deletions and substitutions each cost 1. The result is
    read out of a numpy integer table, so it is a numpy integer scalar.

    Args:
        s1: First sequence (anything supporting ``len`` and iteration,
            typically a ``str``).
        s2: Second sequence.

    Returns:
        The minimum number of single-element edits turning ``s1`` into ``s2``.
    """
    rows, cols = len(s1), len(s2)

    # The table is over-allocated by 5 in each dimension and pre-filled with a
    # large sentinel; only the [0..rows] x [0..cols] region is ever written/read.
    table = np.full((rows + 5, cols + 5), 10000000, dtype=int)

    # Borders: converting to/from an empty prefix costs the prefix length.
    table[:rows + 1, 0] = np.arange(rows + 1)
    table[0, :cols + 1] = np.arange(cols + 1)

    for r, left_item in enumerate(s1, start=1):
        for c, right_item in enumerate(s2, start=1):
            substitute = table[r - 1, c - 1] + int(left_item != right_item)
            drop_left = table[r - 1, c] + 1
            drop_right = table[r, c - 1] + 1
            table[r, c] = min(substitute, drop_left, drop_right)

    return table[rows, cols]

def tokens_match(tokens1, tokens2):
    """Align two token sequences with a token-level edit-distance DP.

    Costs used by the DP:
      * pairing ``tokens1[i - 1]`` with ``tokens2[j - 1]`` costs the
        character-level ``edit_distance`` between the two tokens;
      * leaving a token of either sequence unpaired costs ``len(token)``.

    The total cost, the full DP table and the recovered alignment are printed
    (in that order), and the alignment is also returned.

    Args:
        tokens1: First sequence of tokens (each token must support ``len`` and
            be accepted by ``edit_distance``).
        tokens2: Second sequence of tokens.

    Returns:
        A list of ``(i, j, op)`` tuples ordered from the start of the
        sequences to the end. ``i`` and ``j`` are the 1-based DP coordinates
        (prefix lengths) at which the step was taken and ``op`` is:
          * ``1`` - ``tokens1[i - 1]`` paired with ``tokens2[j - 1]``;
          * ``2`` - ``tokens1[i - 1]`` left unpaired;
          * ``3`` - ``tokens2[j - 1]`` left unpaired.

    Raises:
        RuntimeError: If the backtrace reaches a cell with no recorded move
            (indicates a corrupted DP table).
    """
    n1, n2 = len(tokens1), len(tokens2)

    # Tables are over-allocated and pre-filled with a sentinel; this layout is
    # kept as-is because the whole ``dp`` array is printed below.
    dp = np.full((n1 + 5, n2 + 5), 10000000, dtype=int)
    record = np.full((n1 + 5, n2 + 5), 0, dtype=int)

    # Borders: every token of the non-empty prefix is left unpaired.
    dp[0][0] = 0
    for i in range(1, n1 + 1):
        dp[i][0] = dp[i - 1][0] + len(tokens1[i - 1])
        record[i][0] = 2
    for j in range(1, n2 + 1):
        dp[0][j] = dp[0][j - 1] + len(tokens2[j - 1])
        record[0][j] = 3

    for i in range(1, n1 + 1):
        for j in range(1, n2 + 1):
            pair_cost = dp[i - 1][j - 1] + edit_distance(tokens1[i - 1], tokens2[j - 1])
            skip1_cost = dp[i - 1][j] + len(tokens1[i - 1])
            skip2_cost = dp[i][j - 1] + len(tokens2[j - 1])

            # Ties are broken in the order: pair, skip tokens1, skip tokens2.
            best = min(pair_cost, skip1_cost, skip2_cost)
            if pair_cost == best:
                dp[i][j] = pair_cost
                record[i][j] = 1
            elif skip1_cost == best:
                dp[i][j] = skip1_cost
                record[i][j] = 2
            else:
                dp[i][j] = skip2_cost
                record[i][j] = 3

    # Walk the recorded moves back from the bottom-right corner. This is done
    # iteratively: a recursive walk needs one frame per step and overflows the
    # interpreter's recursion limit on long token sequences.
    matches = []
    i, j = n1, n2
    while i != 0 or j != 0:
        move = record[i][j]
        if move == 1:
            matches.append((i, j, 1))
            i -= 1
            j -= 1
        elif move == 2:
            matches.append((i, j, 2))
            i -= 1
        elif move == 3:
            matches.append((i, j, 3))
            j -= 1
        else:
            raise RuntimeError(
                "no backtrace move recorded at cell ({}, {})".format(i, j)
            )
    matches.reverse()

    print(dp[n1][n2])
    print(dp)
    print(matches)
    return matches



if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Disabled experiment (kept for reference): tokenize the same SQL
    # string with a T5 tokenizer and a GPT tokenizer, then rebuild each
    # string from its tokens while recording cumulative end offsets.
    # ------------------------------------------------------------------
    # T5_tokenizer = AutoTokenizer.from_pretrained('/data1/cyr/resources/t5-base')
    # GPT_tokenizer = encoding_for_model('text-davinci-003')
    #
    # s1 = "select Name FROM member WHERE Country  =  \"United States\" OR Country  =  \"Canada\""
    # s2 = "SELECT T2.balance FROM accounts AS T1 JOIN checking AS T2 ON T1.custid > T2.custid WHERE T1.name LIKE '%ee%'"
    # s3 = "`1213234234543254325"
    #
    # tokens_1 = T5_tokenizer.tokenize(s2)
    # tokens_2 = [GPT_tokenizer.decode_single_token_bytes(tok) for tok in GPT_tokenizer.encode(s2)]
    #
    # pad = '▁'
    #
    # print(tokens_1)
    # print(tokens_2)
    #
    # pos = 0
    # string_1 = ''
    # pos_list_1 = []
    # if tokens_1[0] != pad:
    #     tokens_1[0] = tokens_1[0].replace(pad, '')
    # else:
    #     tokens_1 = tokens_1[1:]
    #
    # for i in range(0, len(tokens_1)):
    #     # print(tokens_1[i])
    #     string_1 += tokens_1[i].replace(pad, ' ')
    #     pos_list_1 += [len(string_1)]
    #
    # string_2 = ''
    # pos_list_2 = []
    # for i in range(0, len(tokens_2)):
    #     # print(tokens_1[i])
    #     string_2 += str(tokens_2[i], 'utf-8')
    #     pos_list_2 += [len(string_2)]
    #
    # print(string_1)
    # print(pos_list_1)
    #
    # print(string_2)
    # print(pos_list_2)

    # Disabled sanity check for edit_distance:
    # str1 = 'abc'
    # str2 = 'efasgbsdcsb'
    # print(edit_distance(str1, str2))

    # Live demo: two hand-written segmentations that concatenate to the same
    # text ('▁custid'), aligned against each other.
    demo_tokens_a = ['▁cu', 'sti', 'd']
    # Alternative second segmentation: ['c', 'ust', 'id']
    demo_tokens_b = ['▁', 'c', 'u', 'st', 'id']

    tokens_match(tokens1=demo_tokens_a, tokens2=demo_tokens_b)




