import numpy as np
import matplotlib.pyplot as plt
if __name__ == '__main__':
    # Plot the average runtime of two gradient estimators against the
    # dimension of θ, and save the figure as an image.
    #
    # Rows of the loaded array that are used (row 0 is loaded but not plotted):
    #   1 -> x values ("Dimension of θ")
    #   2 -> runtimes plotted as "Generator Gradient"
    #   3 -> runtimes plotted as "Pathwise Differentiation"
    # NOTE(review): the runtime units and the meaning of row 0 are not
    # visible from this script; confirm against the code that wrote the file.
    runtimes = np.load('avg_time_n_eqls_100.npy')
    dims = runtimes[1, :]
    gen_grad_times = runtimes[2, :]
    pathwise_times = runtimes[3, :]

    # Global style must be configured before the figure is created so that
    # the size and fonts are picked up by the new figure.
    plt.rcParams["figure.figsize"] = (6.5, 4.5)
    plt.rcParams.update({'font.size': 18})
    plt.rcParams['font.family'] = 'serif'

    fig, ax = plt.subplots()
    ax.tick_params(axis='both', which='major', labelsize=16)
    plt.rc('legend', fontsize=13)
    ax.set_xscale('log', base=10)
    ax.grid(True, which="both", linestyle='--', linewidth=0.5)

    # (y-data, colour, legend label) for each estimator.
    series = (
        (gen_grad_times, 'tab:blue', 'Generator Gradient'),
        (pathwise_times, 'tab:orange', 'Pathwise Differentiation'),
    )
    # Unlabelled point markers first, then the labelled connecting lines.
    # Only the lines carry labels, so only they appear in the legend.
    for times, colour, _ in series:
        ax.scatter(dims, times, c=colour)
    for times, colour, name in series:
        ax.plot(dims, times, c=colour, label=name)

    ax.set_xlabel("Dimension of θ")
    ax.set_ylabel("Avg runtime")
    ax.legend(loc='upper left')
    fig.tight_layout()
    # No extension given: matplotlib picks the format from rcParams["savefig.format"].
    fig.savefig('compare_runtime', dpi=800)