import pickle
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science', 'grid'])
import numpy as np
import os

# Typeset every figure in a serif font (Times New Roman), applied on top of
# the 'science' + 'grid' style selected above.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
})

# Size parameter used to pick the GPRS pickle file below.
# Note: N is rebound later in the script for the PolarSim runs.
N = 8

def GPRS_data_store_filename( N, id_string ):
    """Return the path of a pickled GPRS BSC result file.

    The path is relative to the current working directory, not to this
    script's location.

    Parameters
    ----------
    N : int
        Size parameter embedded in the file name.
    id_string : str
        Which stored quantity to locate: ``'MI'`` or ``'Rates'``.

    Returns
    -------
    str
        ``'./../Misc_Data/GPRS_BSC_New_<id_string>_<N>.pickle'``.

    Raises
    ------
    ValueError
        If ``id_string`` is not ``'MI'`` or ``'Rates'``.
    """
    # Fail loudly here. Silently returning None would only surface later as a
    # confusing TypeError inside open().
    if id_string not in ('MI', 'Rates'):
        raise ValueError(
            f"id_string must be 'MI' or 'Rates', got {id_string!r}"
        )
    return './../Misc_Data/GPRS_BSC_New_' + id_string + '_' + str(N) + '.pickle'



# Load the stored GPRS rate samples for the current N.
# NOTE: pickle.load can execute arbitrary code from a crafted file; only
# point this at trusted, locally generated data.
gprs_rates_path = GPRS_data_store_filename(N, 'Rates')
with open(gprs_rates_path, mode='rb') as f:
    gprs_rates = pickle.load(f)

def safe_xlogx(x):
    """Return ``x * log2(x)``, using its limit value 0 at ``x == 0``.

    Avoids evaluating ``log2(0)``, which would otherwise produce ``0 * -inf``.
    """
    return 0 if x == 0 else x * np.log2(x)

def binary_entropy(p):
    """Binary entropy ``H2(p) = -p log2 p - (1-p) log2 (1-p)`` in bits.

    Endpoints p = 0 and p = 1 give 0 via ``safe_xlogx``.
    """
    q = 1 - p
    return -safe_xlogx(p) - safe_xlogx(q)

# 15 BSC crossover probabilities strictly inside (0, 0.5), and the
# corresponding BSC mutual information 1 - H2(p) used as GPRS x-coordinates.
gprs_points_prange = np.linspace(1e-2, 0.5 - 1e-2, 15)
gprs_points = 1 - np.array([binary_entropy(p) for p in gprs_points_prange])


def experiment_size(N):
    """Return ``(num_trials, num_points)`` for a PolarSim run of size N.

    Sizes up to and including 2**17 use 200 trials over 100 points. Larger
    sizes fall back to 100 trials over 10 points.
    """
    if N <= 2**17:
        return 200, 100
    return 100, 10


def PolarSim_plotting(N):
    """Load PolarSim BSC trial CSVs for size N and summarise them per p.

    The crossover probabilities are ``np.linspace(0, 0.5, num_points)``, where
    ``num_points`` comes from ``experiment_size(N)``. For each p the file
    ``./../PolarSim_Rate_Data/BSC_trials_N_<N>_p_<round(p, 4)>.csv`` is read.

    Returns
    -------
    (MI, rates, rates_5, rates_95)
        ``MI`` is a numpy array of 1 - H2(p). The other three are lists holding
        the median, 5th percentile and 95th percentile of each file's values.
    """
    n_points = experiment_size(N)[1]
    channel_args = np.linspace(0, 0.5, n_points)
    MI = np.array([1 - binary_entropy(p) for p in channel_args])

    rates, rates_5, rates_95 = [], [], []
    for p in channel_args:
        # The file name must match the writer's rounding of p exactly.
        csv_path = ("./../PolarSim_Rate_Data/BSC_trials_N_" + str(N)
                    + "_p_" + str(round(p, 4)) + ".csv")
        samples = np.genfromtxt(csv_path, delimiter=",")
        rates.append(np.median(samples))
        rates_5.append(np.percentile(samples, 5))
        rates_95.append(np.percentile(samples, 95))
    return MI, rates, rates_5, rates_95

# Two PolarSim sizes to compare. N is rebound here; from this point on it
# names the smaller PolarSim size and no longer the GPRS file size.
N, N_high = 2**12, 2**17

MI_right, rate_right, rate_right_5, rate_right_95 = PolarSim_plotting(N)
(MI_right_highN, rate_right_highN,
 rate_right_highN_5, rate_right_highN_95) = PolarSim_plotting(N_high)


# Figure 5: PolarSim rate vs. I(X;Y) for two sizes, with shaded 5th-95th
# percentile bands, the median GPRS rate, and the y = x reference line.
fig = plt.figure()

new_plot = fig.add_subplot(111)

# Low N: percentile band plus median curve.
new_plot.fill_between(MI_right, rate_right_5, rate_right_95, color='#fb6a4a', alpha=0.35)
new_plot.plot(MI_right, rate_right, label=f'PolarSim, n={N}', color='#fb6a4a')

# High N: same layout in a darker red.
new_plot.fill_between(MI_right_highN, rate_right_highN_5, rate_right_highN_95, color='#cb181d', alpha=0.35)
new_plot.plot(MI_right_highN, rate_right_highN, label=f'PolarSim, n={N_high}', color='#cb181d')

# GPRS: median over axis 1 of the loaded rates. This assumes one row per
# entry of gprs_points (TODO: confirm against the pickle's layout).
new_plot.plot(gprs_points, np.median(gprs_rates, axis=1), color='black', label='GPRS')

new_plot.tick_params(axis='both', which='major', color='0', labelsize=7)
new_plot.tick_params(axis='both', which='minor', color='0.3')
# Pass the on/off flag positionally. The keyword was `b` before Matplotlib 3.5
# and is `visible` afterwards; `b=` was removed in 3.7 and now raises.
new_plot.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)

# Reference line y = x.
new_plot.plot([0, 1], [0, 1], color='#2CA25F', linewidth=0.5, label='Lower bound')
fig.supxlabel('$I(X;Y)$', size=9)
fig.supylabel('Rate', size=9)
# The legend is currently disabled; uncomment to display the labels above.
# plt.legend()
fig.savefig("Figure5.pdf")
#fig.savefig("UpdatedComparePolarSimGPRS.pdf")