import numpy as np
import matplotlib.pyplot as plt
def select_method_rows(data, flag):
    """Return the rows of *data* whose column 1 equals *flag*.

    In the script below, flag 1 selects the rows plotted as "[J&S21]" and
    flag 0 the rows plotted as "This work". A boolean mask keeps the result
    2-D even when no row matches.
    """
    return data[data[:, 1] == flag]


def average_per_eps(rows, eps_list, n_trials=300):
    """Return a ``(len(eps_list), 2)`` array with one ``[x, y]`` point per eps.

    For each eps, the rows whose column 0 matches eps are selected. Columns 2
    and 3 are each dotted with column 4 and divided by *n_trials*.
    NOTE(review): column 4 presumably holds per-row weights/counts that sum
    to *n_trials*; confirm against whatever writes compare_data.npy.

    Matching uses ``np.isclose`` instead of ``==``. The stored eps values and
    the ``np.linspace`` grid are floats that need not be bit-identical.

    Raises:
        ValueError: if no row matches one of the eps values.
    """
    points = []
    for eps in eps_list:
        selected = rows[np.isclose(rows[:, 0], eps)]
        if selected.shape[0] == 0:
            raise ValueError(f"no rows found for eps={eps!r}")
        weights = selected[:, 4]
        points.append([selected[:, 2] @ weights / n_trials,
                       selected[:, 3] @ weights / n_trials])
    return np.array(points)


def loglog_fit(x, y):
    """Fit ``log10(y) = slope * log10(x) + intercept`` by least squares.

    Returns ``(slope, fitted_y)``, where ``fitted_y`` is the regression line
    evaluated at *x* and mapped back to linear scale. Both *x* and *y* must
    be strictly positive for the logarithms to be finite.
    """
    log_x = np.log10(x)
    slope, intercept = np.polyfit(log_x, np.log10(y), 1)
    return slope, 10 ** (slope * log_x + intercept)


if __name__ == '__main__':
    data = np.load('compare_data.npy')
    jin_data = select_method_rows(data, 1)
    wang_data = select_method_rows(data, 0)
    wang_eps_list = np.linspace(0.02, 0.005, 8)
    jin_eps_list = np.linspace(0.1, 0.04, 8)

    jin_plot = average_per_eps(jin_data, jin_eps_list)
    wang_plot = average_per_eps(wang_data, wang_eps_list)
    # Subtract 0.5 from the second coordinate before it is plotted as "average error".
    jin_plot[:, 1] -= 0.5
    wang_plot[:, 1] -= 0.5

    plt.rcParams["figure.figsize"] = (6.5, 4.5)
    plt.rcParams.update({'font.size': 18})
    plt.rcParams['font.family'] = 'serif'
    plt.tick_params(axis='both', which='major', labelsize=16)
    plt.rc('legend', fontsize=13)
    plt.ylim(0.0003, 0.01)
    plt.xscale('log', base=10)
    plt.yscale('log', base=10)
    plt.grid(True, which="both", linestyle='--', linewidth=0.5)

    # Scatter each method's points and overlay its log-log regression line.
    for points, colour, name in (
        (jin_plot, 'tab:blue', '[J&S21]'),
        (wang_plot, 'tab:red', 'This work'),
    ):
        slope, regression_line = loglog_fit(points[:, 0], points[:, 1])
        plt.scatter(points[:, 0], points[:, 1], c=colour,
                    label=f'{name}, slope = {slope:.2f}')
        plt.plot(points[:, 0], regression_line, linestyle='--', c=colour)

    plt.xlabel("number of samples")
    plt.ylabel("average error")
    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig('compare_jin_wang_plot', dpi=800)