import pickle
import numpy as np



# Number of single-task benchmark runs per repeat: 4 problem families x 3 sizes.
_N_TASKS = 12


def _load_pickle(path):
    """Load and return the object pickled at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; this assumes the
    result files under ``./result`` were produced locally and are trusted.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _bm_aug_gap(res_paths, epochs, idx):
    """Average the benchmark ``aug_gap`` over repeats and pick one column per task.

    Args:
        res_paths: Result directories ordered repeat-major, ``_N_TASKS``
            directories per repeat (same order as the row-major flattening of
            ``epochs``).
        epochs: Array with ``_N_TASKS`` entries; entry ``k`` is the epoch whose
            ``result_gap_epoch{epoch}.pkl`` is read for task ``k``.
        idx: For each task, the column of the averaged (task, 3) table to keep.

    Returns:
        Array of length ``_N_TASKS`` with the repeat-averaged selected gaps.
    """
    epoch_per_path = epochs.flatten().tolist() * (len(res_paths) // _N_TASKS)
    gaps = [
        _load_pickle(res_file + '/result_gap_epoch{}.pkl'.format(epoch))['aug_gap']
        for epoch, res_file in zip(epoch_per_path, res_paths)
    ]
    # (repeats, tasks, 3) -> mean over repeats; presumably the 3 columns are the
    # three test sizes in the directory name -- TODO confirm.
    per_task = np.stack(gaps, axis=0).reshape(-1, _N_TASKS, 3).mean(0)
    return per_task[range(_N_TASKS), idx]


def _print_group(values):
    """Print one group of LaTeX cells (no newline), bolding the smallest value."""
    best = np.argmin(values)
    for k, value in enumerate(values):
        fmt = r'& $\mathbf{{{:.3f}}}\%$ ' if k == best else r'& ${:.3f}\%$ '
        print(fmt.format(value), end=' ')


def _print_gain_row(label, diffs, bold):
    """Print a gain row: per budget, two numeric cells followed by two '-' cells.

    ``diffs`` holds six values ordered (equal, weighted) per budget for the
    small, median and large budgets.
    """
    cell = r'$\mathbf{{{:.3f}}}\%$' if bold else r'${:.3f}\%$'
    cells = []
    for equal_gain, weighted_gain in zip(diffs[0::2], diffs[1::2]):
        cells += [cell.format(equal_gain), cell.format(weighted_gain), '-', '-']
    print('& {} & {} \\\\'.format(label, ' & '.join(cells)))


def compare_with_bm_same_budget():
    """Print LaTeX table rows comparing gaps across three training budgets.

    For each budget ('small', 'median', 'large') this loads:
      * the ``aug_gap`` of the unified model (``*-best.pkl`` under ``uni_dir``),
      * the ``aug_gap`` of the MTL baseline (``*-best.pkl`` under ``mtl_dir``),
      * the ``aug_gap`` of the 12 single-task benchmark runs (3 repeats each),
        once at the epochs of the ``*_equal`` table and once at the epochs of
        the ``*_weighted`` table, averaged over repeats.

    It then prints, per problem, the four gaps of every budget (smallest in
    bold), a "Total Gap" row with the gaps summed over problems, and two gain
    rows: summed ``mtl - bm`` ("Gain by MTL") and ``uni - bm`` ("Gain by Ours").

    Raises:
        FileNotFoundError: if any expected result pickle is missing.
    """
    res_files = [
        './result/_test-TSP[20, 50, 100]-bm_tsp20',
        './result/_test-TSP[20, 50, 100]-bm_tsp50',
        './result/_test-TSP[20, 50, 100]-bm_tsp100',
        './result/_test-CVRP[20, 50, 100]-bm_cvrp20',
        './result/_test-CVRP[20, 50, 100]-bm_cvrp50',
        './result/_test-CVRP[20, 50, 100]-bm_cvrp100',
        './result/_test-OP[20, 50, 100]-bm_op20',
        './result/_test-OP[20, 50, 100]-bm_op50',
        './result/_test-OP[20, 50, 100]-bm_op100',
        './result/_test-KP[50, 100, 200]-bm_kp50',
        './result/_test-KP[50, 100, 200]-bm_kp100',
        './result/_test-KP[50, 100, 200]-bm_kp200',
    ]
    # Repeat-major order: all 12 tasks of repeat 1, then repeat 2, then repeat 3.
    newpaths = [
        '{}_repeat{}_save'.format(res_file, repeat + 1)
        for repeat in range(3)
        for res_file in res_files
    ]

    problem_list = ['TSP-20', 'TSP-50', 'TSP-100', 'CVRP-20', 'CVRP-50', 'CVRP-100',
                    'OP-20', 'OP-50', 'OP-100', 'KP-50', 'KP-100', 'KP-200']

    epoch_small_weighted = np.array([[61, 61, 47],
                                     [44, 48, 39],
                                     [59, 58, 59],
                                     [35, 39, 32]])
    epoch_small_equal = np.array([[123, 61, 31],
                                  [89, 48, 26],
                                  [118, 58, 39],
                                  [70, 39, 21]])
    epoch_median_weighted = np.array([[127, 125, 98],
                                      [92, 99, 82],
                                      [122, 120, 122],
                                      [72, 81, 67]])
    epoch_median_equal = np.array([[254, 125, 65],
                                   [185, 99, 54],
                                   [244, 120, 81],
                                   [144, 81, 45]])
    epoch_large_weighted = np.array([[256, 254, 199],
                                     [187, 200, 165],
                                     [247, 243, 246],
                                     [146, 163, 136]])
    epoch_large_equal = np.array([[513, 254, 133],
                                  [374, 200, 110],
                                  [494, 243, 164],
                                  [292, 163, 90]])

    # budget -> (unified-model epoch, MTL-baseline epoch,
    #            bm epochs "weighted" table, bm epochs "equal" table)
    budgets = {
        'small': (500, 50, epoch_small_weighted, epoch_small_equal),
        'median': (1000, 100, epoch_median_weighted, epoch_median_equal),
        'large': (2000, 200, epoch_large_weighted, epoch_large_equal),
    }

    uni_dir = './result/_test-TSP[20, 50, 100]-CVRP[20, 50, 100]-OP[20, 50, 100]-KP[50, 100, 200]-train12task_exp3_freq12'
    mtl_dir = './result/_test-TSP[20, 50, 100]-CVRP[20, 50, 100]-OP[20, 50, 100]-KP[50, 100, 200]-MTL_baseline'

    # Column picked for each of the 12 bm tasks: 0, 1, 2 within every family.
    idx = [0, 1, 2] * 4

    def collect(budget):
        """Return (gaps, uni_diffs, mtl_diffs) for one budget.

        gaps is [bm_equal, bm_weighted, mtl, uni]; uni_diffs is
        [uni - bm_equal, uni - bm_weighted]; mtl_diffs likewise for mtl.
        """
        try:
            uni_epoch, mtl_epoch, epoch_w, epoch_e = budgets[budget]
        except KeyError:
            raise ValueError('unknown budget: {!r}'.format(budget)) from None
        uni = _load_pickle(uni_dir + '/result_gap_epoch{}-best.pkl'.format(uni_epoch))['aug_gap']
        mtl = _load_pickle(mtl_dir + '/result_gap_epoch{}-best.pkl'.format(mtl_epoch))['aug_gap']
        aug_w = _bm_aug_gap(newpaths, epoch_w, idx)
        aug_e = _bm_aug_gap(newpaths, epoch_e, idx)
        return ([aug_e, aug_w, mtl, uni],
                [uni - aug_e, uni - aug_w],
                [mtl - aug_e, mtl - aug_w])

    results = [collect(budget) for budget in ('small', 'median', 'large')]

    # Sum every column over the 12 problems (stack along axis 1, reduce axis 0).
    res_gap = np.stack([g for gaps, _, _ in results for g in gaps], 1).sum(0)
    res_diff = np.stack([d for _, diffs, _ in results for d in diffs], 1).sum(0)
    mtl_res_diff = np.stack([d for _, _, diffs in results for d in diffs], 1).sum(0)

    for i, problem in enumerate(problem_list):
        if i % 3 == 0:
            # Start of a new problem family (3 sizes each).
            print(r'\midrule')
            print(r'\midrule')
            print(r'\multirow{3}{*}{\rotatebox[origin = c]{90}{}}')
        print('& {}'.format(problem), end=' ')
        for gaps, _, _ in results:
            _print_group([g[i] for g in gaps])
        print('\\\\')

    print(r'\midrule')
    print(r'\midrule')
    print('& Total Gap', end=' ')
    for b in range(3):
        _print_group(res_gap[4 * b:4 * (b + 1)])
    print('\\\\')
    print(r'\midrule')
    _print_gain_row('Gain by MTL', mtl_res_diff, bold=False)
    _print_gain_row('Gain by Ours', res_diff, bold=True)
    print(r'\midrule')
    print(r'\midrule')


if __name__ == "__main__":
    # Generate the comparison table only when run as a script, not on import.
    compare_with_bm_same_budget()
