import distance
import os
import argparse


def code_match(code1, code2):
    """Return the Levenshtein distance between two snippets, ignoring whitespace.

    Both inputs have all whitespace removed (``str.split`` followed by
    ``''.join``) before comparison, so a result of ``0`` means the snippets
    are identical apart from whitespace.

    Args:
        code1: First code snippet as a string.
        code2: Second code snippet as a string.

    Returns:
        The value of ``distance.levenshtein`` on the whitespace-stripped
        strings, or ``0`` directly when they are equal.
    """
    clean_code1 = ''.join(code1.split())
    clean_code2 = ''.join(code2.split())

    # Fast path: identical strings have distance 0, so skip the full
    # edit-distance computation for exact matches.
    if clean_code1 == clean_code2:
        return 0

    return distance.levenshtein(clean_code1, clean_code2)


def get_type(part, types):
    """Look up the type label recorded for ``part``.

    Args:
        part: Part name to look up.
        types: Iterable of text lines of the form ``name:type``.

    Returns:
        The text after the first ``:`` on the first line whose name equals
        ``part`` (newline characters removed), or ``None`` when no line
        matches.
    """
    for line in types:
        line = line.replace('\n', '')
        part_name, sep, part_type = line.partition(':')
        if not sep:
            # No colon on this line (e.g. a blank trailing line): skip it
            # instead of raising ValueError as ``str.index`` would.
            continue
        if part == part_name:
            return part_type
    return None


def _read_text(path):
    """Return the full contents of the text file at ``path``."""
    with open(path, 'r') as f:
        return f.read()


def _matches_ground_truth(res_code, primary_gt, alternate_gt):
    """Return True if ``res_code`` has distance 0 to an accepted ground truth.

    ``primary_gt`` is always read (a missing file raises, as opening it
    would). ``alternate_gt`` is a second accepted answer that is consulted
    only when the primary one does not match and the file exists.
    """
    if code_match(res_code, _read_text(primary_gt)) == 0:
        return True
    if os.path.exists(alternate_gt):
        return code_match(res_code, _read_text(alternate_gt)) == 0
    return False


def _checked_type(part, types):
    """Return the 't1'/'t2' label of ``part``; raise ValueError for anything else."""
    part_type = get_type(part=part, types=types)
    if part_type not in ('t1', 't2'):
        raise ValueError(f'Unexpected type {part_type!r} for part {part!r}')
    return part_type


def _ratio(count, total):
    """Return ``count / total``, or 0.0 when ``total`` is zero."""
    return count / total if total else 0.0


if __name__ == '__main__':
    # Score generated code parts against their ground truth files.
    #
    # For every part directory under the result path, the generated
    # ``part_new`` and ``part_new_patched`` files are compared (ignoring
    # whitespace) with the ground truth under the data path. Exact matches
    # are counted overall and per type ('t1'/'t2'); mismatching parts are
    # listed in the metric directory together with a summary.
    parser = argparse.ArgumentParser(description='Best match')
    parser.add_argument('--model', type=str, help='model', default='gpt-4')
    # --method has no sensible default, so require it up front rather than
    # failing later with a bare ValueError.
    parser.add_argument('--method', type=str, help='vanilla or miggpt', choices=['vanilla', 'miggpt'],
                        required=True)
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    args = parser.parse_args()
    model = args.model
    method = args.method
    gt_path = args.data

    # ``choices`` restricts method to 'vanilla' or 'miggpt', so both paths
    # can be derived directly from it.
    result_path = f'./results_{method}/{model}'
    save_dir = f'./metric/{method}/{model}'

    res_folders = sorted(
        item for item in os.listdir(result_path)
        if os.path.isdir(os.path.join(result_path, item))
    )

    cnt = 0
    best_match_cnt = 0
    best_match_cnt_patched = 0
    failed_new_list = []
    failed_new_patched_list = []
    type_hits = {'t1': 0, 't2': 0}
    type_hits_patched = {'t1': 0, 't2': 0}

    for ii, item in enumerate(res_folders):
        print(f'================================={ii}/{len(res_folders)-1}=====================================')
        res_folder_path = os.path.join(result_path, item)
        gt_folders_path = os.path.join(gt_path, item)
        print(res_folder_path)

        # The folder name suffix selects the file extension of its parts.
        if item.endswith('_h'):
            sa = '.h'
        elif item.endswith('_c'):
            sa = '.c'
        else:
            raise ValueError(f"Result folder {item!r} must end with '_h' or '_c'")

        res_parts = sorted(
            part for part in os.listdir(res_folder_path)
            if os.path.isdir(os.path.join(res_folder_path, part))
        )
        cnt += len(res_parts)

        with open(f'{gt_folders_path}/type.txt', 'r') as f:
            types = f.readlines()

        for jj, part in enumerate(res_parts):
            res_part_path = os.path.join(res_folder_path, part)
            gt_part_path = os.path.join(gt_folders_path, part)

            print('-----------------------------------')
            print(f'{jj}/{len(res_parts)-1}: {res_part_path}')

            res_new_file = f'{res_part_path}/part_new{sa}'
            if not os.path.exists(res_new_file):
                # A part without a generated part_new file is skipped
                # entirely (it still counts towards cnt, but is neither a
                # match nor listed as failed, and its patched file is not
                # evaluated).
                continue

            if _matches_ground_truth(
                    _read_text(res_new_file),
                    f'{gt_part_path}/part_new_gt{sa}',
                    f'{gt_part_path}/part_new_gt2{sa}'):
                best_match_cnt += 1
                type_hits[_checked_type(part, types)] += 1
            else:
                failed_new_list.append(res_part_path)

            res_patched_file = f'{res_part_path}/part_new_patched{sa}'
            if not os.path.exists(res_patched_file):
                continue

            if _matches_ground_truth(
                    _read_text(res_patched_file),
                    f'{gt_part_path}/part_new_patched_gt{sa}',
                    f'{gt_part_path}/part_new_patched_gt2{sa}'):
                best_match_cnt_patched += 1
                type_hits_patched[_checked_type(part, types)] += 1
            else:
                failed_new_patched_list.append(res_part_path)

            print(f'Best match part_new code: {best_match_cnt}/{cnt}')
            print(f'Best match part_new_patched code: {best_match_cnt_patched}/{cnt}')

    os.makedirs(save_dir, exist_ok=True)

    # Plain locals keep the f-strings below free of nested quotes; the
    # ratios are 0.0 (instead of a ZeroDivisionError) when nothing was found.
    t1_cnt, t2_cnt = type_hits['t1'], type_hits['t2']
    t1_cnt_p, t2_cnt_p = type_hits_patched['t1'], type_hits_patched['t2']
    t1_ratio, t2_ratio = _ratio(t1_cnt, cnt), _ratio(t2_cnt, cnt)
    t1_ratio_p, t2_ratio_p = _ratio(t1_cnt_p, cnt), _ratio(t2_cnt_p, cnt)

    print('============================================================================')
    print(f'Best match part_new code: {best_match_cnt}/{cnt}')
    print('t1: ', t1_cnt, t1_ratio)
    print('t2', t2_cnt, t2_ratio)
    print('============================================================================')
    print(f'Best match part_new_patched code: {best_match_cnt_patched}/{cnt}')
    print('t1: ', t1_cnt_p, t1_ratio_p)
    print('t2', t2_cnt_p, t2_ratio_p)

    with open(f'{save_dir}/failed_files.txt', 'w') as f:
        for name in failed_new_list:
            f.write(f'{name}\n')

    with open(f'{save_dir}/failed_files_patched.txt', 'w') as f:
        for name in failed_new_patched_list:
            f.write(f'{name}\n')

    with open(f'{save_dir}/result.txt', 'w') as f:
        f.write('============================================================================\n')
        f.write(f'Best match part_new code: {best_match_cnt}/{cnt}\n')
        f.write(f't1: {t1_cnt}, {t1_ratio}\n')
        f.write(f't2: {t2_cnt}, {t2_ratio}\n')
        f.write('============================================================================\n')
        f.write(f'Best match part_new_patched code: {best_match_cnt_patched}/{cnt}\n')
        f.write(f't1: {t1_cnt_p}, {t1_ratio_p}\n')
        f.write(f't2: {t2_cnt_p}, {t2_ratio_p}\n')