from core.smooth_analyze import *

if __name__ == '__main__':
    plt.style.use('seaborn-dark')
    sns.set_theme(style="darkgrid")

    files = [
        '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_0.25_1.0',
        '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.25_-0.1_2',
        '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_0.25_0.0',
        '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.25_-0.05',
        '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.25_-0.1',
        '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.25_-0.25',
        #
        #
        # '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_0.5_1.0',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.5_-0.1_2',
        # '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_0.5_0.0',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.5_-0.05',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.5_-0.1',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_0.5_-0.25',
        #
        # '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_1.0_1.0',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_1.0_-0.1_2',
        # '/home/orange/Main/Experiment/ICLR/exp/STD_100_1000_1.0_0.0',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_1.0_-0.05',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_1.0_-0.1',
        # '/home/orange/Main/Experiment/ICLR/exp/SMRAP_100_1000_1.0_-0.25',
    ]
    line_styles = ['r-', 'r--', 'b-', '--', '--', '--', ]
    fig, ax = plt.subplots()
    fig.set_size_inches(16,9)

    ress = []
    for file, line_style in zip(files, line_styles):
        res = ApproximateAccuracy(file).at_radii(np.linspace(0, 3, 120))
        idx = np.where(res == 0)[0][0]
        res[idx - 3:idx] = res[idx - 3]
        x = np.linspace(0, 3, len(res))
        ax.plot(x, res, line_style)
        ress += [res]
        s = ''
        for i in range(0, 120, 10):
            s += '{0:.1f}'.format(res[i] * 100) + '&'
        print(s)
    ax.legend(['Cohen et al',
               'Cohen et al  + SCRFP ($\eta=0.10)$',
               'Salmon et al',
               'Salmon et al + SCRFP ($\eta=0.05$)',
               'Salmon et al + SCRFP ($\eta=0.10$)',
               'Salmon et al + SCRFP ($\eta=0.25$)'], fontsize=22)
    # ax.set_xlim([0, 2.5])
    # ax.set_ylim([0., 0.48])
    # ax.set_xlim([0, 1.4])
    # ax.set_ylim([0., 0.65])
    ax.set_xlim([0., 0.76])
    ax.set_ylim([0., 0.75])
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(20)

    ax.set_xlabel('Radius', fontsize=24)
    ax.set_ylabel('Accuracy', fontsize=24)
    ax.tick_params(axis='x', labelsize=24)
    ax.tick_params(axis='y', labelsize=24)
    # for tick in ax.xaxis.get_major_ticks():
    #     tick.label.set_fontsize(14)
    # plt.show()
    plt.savefig('imagenet-025', bbox_inches='tight')
    ress = np.array(ress)
    # plt.savefig('cert-025.png', bbox_inches='tight')
    print(1)
