import os, argparse
from codebleu import calc_codebleu


def code_bleu_val(reference, prediction, th=0.9, lang='c'):
    """Tell whether *prediction* is close enough to *reference* under CodeBLEU.

    The pair is scored with ``calc_codebleu`` using four equal weights of
    0.25 and no custom tokenizer.

    Args:
        reference: Ground-truth source code.
        prediction: Generated source code to evaluate.
        th: Threshold; the score must be strictly greater than this value.
        lang: Language identifier forwarded to ``calc_codebleu``.

    Returns:
        True if the ``'codebleu'`` entry of the result exceeds ``th``,
        otherwise False.
    """
    scores = calc_codebleu(
        [reference],
        [prediction],
        lang=lang,
        weights=(0.25, 0.25, 0.25, 0.25),
        tokenizer=None,
    )
    return bool(scores['codebleu'] > th)


def get_type(part, types):
    """Look up the type label recorded for *part*.

    Each entry of *types* is expected to look like ``"<part_name>:<label>"``,
    optionally terminated by a newline. Everything after the first colon is
    the label.

    Args:
        part: Part name to search for (compared exactly, no stripping).
        types: Iterable of ``name:label`` lines, e.g. from ``readlines()``.

    Returns:
        The label of the first line whose name equals *part*, or ``None``
        when no line matches.
    """
    for raw_line in types:
        name, sep, label = raw_line.replace('\n', '').partition(':')
        if not sep:
            # Blank or malformed line (no colon): skip it instead of letting
            # a ValueError abort the whole evaluation.
            continue
        if name == part:
            return label
    return None


def _read_text(path):
    """Return the full contents of the text file at *path*."""
    with open(path, 'r') as f:
        return f.read()


def _ratio(count, total):
    """Return ``count / total``, or 0.0 when *total* is zero."""
    return count / total if total else 0.0


def _new_stats():
    """Create an empty accumulator for one evaluated variant."""
    return {'matched': 0, 't1': 0, 't2': 0, 'failed': []}


def _matches_ground_truth(prediction, gt_path, gt_alt_path, th):
    """Check *prediction* against the primary and optional alternative GT.

    The primary ground-truth file must exist. The alternative one is only
    consulted when the primary comparison fails and the file is present.
    """
    if code_bleu_val(_read_text(gt_path), prediction, th, lang='c'):
        return True
    if os.path.exists(gt_alt_path):
        return code_bleu_val(_read_text(gt_alt_path), prediction, th, lang='c')
    return False


def _record(stats, matched, part, types, res_part_path):
    """Update *stats* with the outcome for one part.

    Raises:
        ValueError: if a matched part has a type other than 't1' or 't2'.
    """
    if not matched:
        stats['failed'].append(res_part_path)
        return
    # The type is only looked up for matches, so an unknown type on a
    # failed part does not abort the run.
    part_type = get_type(part=part, types=types)
    if part_type not in ('t1', 't2'):
        raise ValueError(f'unexpected type {part_type!r} for part {part!r}')
    stats['matched'] += 1
    stats[part_type] += 1


def main(argv=None):
    """Score generated code parts against ground truth and write a summary.

    For every ``<folder>/<part>`` directory under the result path, the files
    ``part_new`` and ``part_new_patched`` (with a ``.c`` or ``.h`` suffix
    chosen from the folder name) are compared with their ground-truth
    counterparts under the data path. Match counts, per-type counts and the
    lists of failed parts are printed and written to the save directory.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if a result folder name ends in neither ``_h`` nor
            ``_c``, or a matched part has an unexpected type.
    """
    parser = argparse.ArgumentParser(description='Semantic match with CodeBLEU')
    parser.add_argument('--model', type=str, help='model', default='gpt-4')
    parser.add_argument('--method', type=str, help='vanilla or miggpt',
                        choices=['vanilla', 'miggpt'], required=True)
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    parser.add_argument('--th', type=float, default=0.9,
                        help='CodeBLEU score a prediction must exceed to count as a match')
    args = parser.parse_args(argv)
    model = args.model
    method = args.method
    gt_path = args.data
    th = args.th

    result_roots = {'vanilla': './results_gpt', 'miggpt': './results_miggpt'}
    result_path = f'{result_roots[method]}/{model}'
    save_dir = f'./metric_codebleu/{method}/{model}'

    res_folders = sorted(
        item for item in os.listdir(result_path)
        if os.path.isdir(os.path.join(result_path, item))
    )

    cnt = 0
    new_stats = _new_stats()
    patched_stats = _new_stats()

    for ii, item in enumerate(res_folders):
        print(f'================================={ii}/{len(res_folders)-1}=====================================')
        res_folder_path = os.path.join(result_path, item)
        gt_folders_path = os.path.join(gt_path, item)
        print(res_folder_path)
        if item.endswith('_h'):
            sa = '.h'
        elif item.endswith('_c'):
            sa = '.c'
        else:
            raise ValueError(f"folder name must end with '_h' or '_c': {item!r}")

        res_parts = sorted(
            part for part in os.listdir(res_folder_path)
            if os.path.isdir(os.path.join(res_folder_path, part))
        )
        # Every part directory counts towards the total, even if it is
        # skipped below because no prediction file was produced.
        cnt += len(res_parts)

        with open(f'{gt_folders_path}/type.txt', 'r') as f:
            types = f.readlines()

        for jj, part in enumerate(res_parts):
            res_part_path = os.path.join(res_folder_path, part)
            gt_part_path = os.path.join(gt_folders_path, part)

            print('-----------------------------------')
            print(f'{jj}/{len(res_parts)-1}: {res_part_path}')

            new_path = f'{res_part_path}/part_new{sa}'
            if not os.path.exists(new_path):
                # Without a part_new prediction the whole part is skipped,
                # including its patched variant.
                continue
            matched = _matches_ground_truth(
                _read_text(new_path),
                f'{gt_part_path}/part_new_gt{sa}',
                f'{gt_part_path}/part_new_gt2{sa}',
                th,
            )
            _record(new_stats, matched, part, types, res_part_path)

            patched_path = f'{res_part_path}/part_new_patched{sa}'
            if not os.path.exists(patched_path):
                continue
            matched = _matches_ground_truth(
                _read_text(patched_path),
                f'{gt_part_path}/part_new_patched_gt{sa}',
                f'{gt_part_path}/part_new_patched_gt2{sa}',
                th,
            )
            _record(patched_stats, matched, part, types, res_part_path)

            print(f"Best match part_new code: {new_stats['matched']}/{cnt}")
            print(f"Best match part_new_patched code: {patched_stats['matched']}/{cnt}")

    os.makedirs(save_dir, exist_ok=True)

    t1_cnt, t2_cnt = new_stats['t1'], new_stats['t2']
    t1_cnt_p, t2_cnt_p = patched_stats['t1'], patched_stats['t2']
    best_match_cnt = new_stats['matched']
    best_match_cnt_patched = patched_stats['matched']

    print('============================================================================')
    print(f'{model} {method}')
    print(f'Codebleu match part_new code: {best_match_cnt}/{cnt}')
    print('t1: ', t1_cnt, _ratio(t1_cnt, cnt))
    print('t2', t2_cnt, _ratio(t2_cnt, cnt))
    print('============================================================================')
    print(f'Codebleu match part_new_patched code: {best_match_cnt_patched}/{cnt}')
    print('t1: ', t1_cnt_p, _ratio(t1_cnt_p, cnt))
    print('t2', t2_cnt_p, _ratio(t2_cnt_p, cnt))

    with open(f'{save_dir}/failed_files.txt', 'w') as f:
        f.writelines(f'{name}\n' for name in new_stats['failed'])

    with open(f'{save_dir}/failed_files_patched.txt', 'w') as f:
        f.writelines(f'{name}\n' for name in patched_stats['failed'])

    with open(f'{save_dir}/result.txt', 'w') as f:
        f.write('============================================================================\n')
        f.write(f'Codebleu match part_new code: {best_match_cnt}/{cnt}\n')
        f.write(f't1: {t1_cnt}, {_ratio(t1_cnt, cnt)}\n')
        f.write(f't2: {t2_cnt}, {_ratio(t2_cnt, cnt)}\n')
        f.write('============================================================================\n')
        f.write(f'Codebleu match part_new_patched code: {best_match_cnt_patched}/{cnt}\n')
        f.write(f't1: {t1_cnt_p}, {_ratio(t1_cnt_p, cnt)}\n')
        f.write(f't2: {t2_cnt_p}, {_ratio(t2_cnt_p, cnt)}\n')


if __name__ == '__main__':
    main()