# -*- coding: utf-8 -*-
"""
Created on Tue Oct 29 11:41:31 2024

@author: shara
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy import genfromtxt
from scipy.optimize import minimize_scalar
import montecarlo_biawgn_capacity
import scienceplots
plt.style.use(['science', 'grid'])
import psutil, os
import Levenshtein

# Typeset all figure text in a serif face, specifically Times New Roman.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
})

# Relative directory prepended to every data-file path built by get_filename.
filename_prefix = './../'

def legend_without_duplicate_labels(fig, bbox_to_anchor=(1.5, 0.4)):
    """Add a single figure-level legend in which each label appears once.

    Legend handles and labels are collected from every axes of *fig*. When
    the same label occurs more than once, the handle from its last
    occurrence is the one shown.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose axes provide the legend entries.
    bbox_to_anchor : tuple, optional
        Anchor point forwarded to ``fig.legend``; defaults to ``(1.5, 0.4)``.

    Notes
    -----
    If no axes carries a labelled artist (including a figure with no axes
    at all), nothing is added to the figure.
    """
    handles = []
    labels = []
    for ax in fig.axes:
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles.extend(ax_handles)
        labels.extend(ax_labels)
    if not labels:
        # Nothing to show; avoid creating an empty legend box.
        return
    # dict keeps insertion order of first occurrence but the value of the
    # last occurrence, de-duplicating labels across axes.
    by_label = dict(zip(labels, handles))
    fig.legend(by_label.values(), by_label.keys(), bbox_to_anchor=bbox_to_anchor)

def safe_xlogx(x):
    """Return ``x * log2(x)``, taking the limit value 0 at ``x == 0``.

    Evaluating the product directly at zero would give ``0 * -inf``, i.e.
    nan; the continuous extension of x*log2(x) at 0 is 0, which is what
    entropy formulas need.
    """
    return 0 if x == 0 else x * np.log2(x)

def binary_entropy(p):
    """Return the binary entropy H2(p) = -p*log2(p) - (1-p)*log2(1-p), in bits.

    Both terms go through ``safe_xlogx``, so the endpoints p = 0 and p = 1
    evaluate to 0 rather than nan.
    """
    complement = 1 - p
    return -safe_xlogx(p) - safe_xlogx(complement)

def experiment_size(N):
    """Return ``(num_trials, num_points)`` for simulations at block length *N*.

    Block lengths up to and including 2**17 use 200 trials on a 100-point
    parameter grid; larger block lengths use 100 trials on a 10-point grid
    (presumably to bound simulation cost).
    """
    small_block = N <= 2**17
    return (200, 100) if small_block else (100, 10)

def get_arg_list(chan, n_points):
    """Return the grid of channel-parameter values used for channel *chan*.

    ``plot_for_channel`` builds data-file names from these values, so the
    grids must match the ones the CSV files were written with.

    Parameters
    ----------
    chan : str
        ``'BSC'``, ``'BIAWGN'`` or ``'BEC'``.
    n_points : int
        Number of points on the main grid (see per-channel notes below).

    Returns
    -------
    numpy.ndarray
        * ``'BSC'``: ``n_points`` crossover probabilities evenly spaced on
          [0, 0.5].
        * ``'BIAWGN'``: ``n_points`` noise standard deviations on
          [0, sqrt(3)] followed by 10 more on [sqrt(3), 3]; sqrt(3) appears
          twice because both sub-grids include it.
        * ``'BEC'``: always 100 erasure probabilities on [0, 1];
          *n_points* is ignored.

    Raises
    ------
    ValueError
        If *chan* is not one of the supported channel names.
    """
    if chan == 'BSC':
        return np.linspace(0, 0.5, n_points)
    elif chan == 'BIAWGN':
        arg_list_1 = np.linspace(0, np.sqrt(3), n_points)
        arg_list_2 = np.linspace(np.sqrt(3), 3, 10)
        return np.concatenate([arg_list_1, arg_list_2])
    elif chan == 'BEC':
        # Fixed 100-point grid regardless of n_points (kept as-is so the
        # generated values keep matching the existing BEC data files).
        return np.linspace(0, 1, 100)
    raise ValueError(
        "unsupported channel " + repr(chan) + "; expected 'BSC', 'BIAWGN' or 'BEC'"
    )

def get_filename(chan, arg, N, prefix=None):
    """Return the path of the CSV file holding simulated rates for one point.

    The path has the form
    ``<prefix>PolarSim_Rate_Data/<chan>_trials_N_<N>_<tag>_<arg>.csv``, where
    *tag* is ``p`` (BSC), ``sig`` (BIAWGN) or ``eps`` (BEC) and *arg* is
    rounded to 4 decimal places.

    Parameters
    ----------
    chan : str
        ``'BSC'``, ``'BIAWGN'`` or ``'BEC'``.
    arg : float
        Channel parameter value.
    N : int
        Block length.
    prefix : str, optional
        Directory prefix prepended to the relative path. When omitted, the
        module-level ``filename_prefix`` is used.

    Raises
    ------
    ValueError
        If *chan* is not one of the supported channel names.
    """
    param_tags = {'BSC': 'p', 'BIAWGN': 'sig', 'BEC': 'eps'}
    if chan not in param_tags:
        raise ValueError(
            "unsupported channel " + repr(chan) + "; expected 'BSC', 'BIAWGN' or 'BEC'"
        )
    if prefix is None:
        prefix = filename_prefix
    filename = (
        "PolarSim_Rate_Data/" + chan + "_trials_N_" + str(N)
        + "_" + param_tags[chan] + "_" + str(round(arg, 4)) + ".csv"
    )
    return prefix + filename

def _rate_statistics(chan, N, arg_list):
    """Return (median, 5th-percentile, 95th-percentile) arrays of the CSV rates at each point of *arg_list*."""
    medians, lower, upper = [], [], []
    for arg in arg_list:
        # Presumably one simulated rate per trial in each file — TODO confirm.
        trial_rates = genfromtxt(get_filename(chan, arg, N), delimiter=",")
        medians.append(np.median(trial_rates))
        lower.append(np.percentile(trial_rates, 5))
        upper.append(np.percentile(trial_rates, 95))
    return np.array(medians), np.array(lower), np.array(upper)


def plot_for_channel(chan, N_range, colors, ax, y_label_visible=True, x_label_visible=True):
    """Plot simulated rates against the channel parameter for one channel.

    The rate is on the x-axis and the channel parameter on the y-axis. A
    reference curve labelled 'LB' is drawn first: ``1 - H2(p)`` for BSC,
    ``montecarlo_biawgn_capacity.biawgn_capacity(sigma)`` for BIAWGN and
    ``1 - epsilon`` for BEC. Then, for each block length in *N_range*, the
    median rate read from the CSV data files is drawn as a line, with a
    shaded band between the 5th and 95th percentiles.

    Parameters
    ----------
    chan : str
        ``'BSC'``, ``'BIAWGN'`` or ``'BEC'``.
    N_range : sequence of int
        Block lengths to plot.
    colors : sequence
        Matplotlib colours, paired element-wise with *N_range*.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    y_label_visible : bool
        Show the channel-parameter label on the y-axis; otherwise hide it.
    x_label_visible : bool
        When false, hide the x-axis label and place x ticks at 0, 0.2, ..., 1.0.

    Raises
    ------
    ValueError
        If *chan* is unsupported or *colors* is shorter than *N_range*.
    """
    if chan == 'BSC':
        plot_y_label = 'p'
        lb_calc = lambda x: 1 - binary_entropy(x)
    elif chan == 'BIAWGN':
        # Raw strings: '\s' / '\e' are invalid escapes in normal literals
        # (SyntaxWarning on Python 3.12+); the resulting text is unchanged.
        plot_y_label = r'$\sigma$'
        lb_calc = montecarlo_biawgn_capacity.biawgn_capacity
    elif chan == 'BEC':
        plot_y_label = r'$\epsilon$'
        lb_calc = lambda x: 1 - x
    else:
        raise ValueError(
            "unsupported channel " + repr(chan) + "; expected 'BSC', 'BIAWGN' or 'BEC'"
        )
    if len(colors) < len(N_range):
        raise ValueError(
            "need one colour per block length: got %d colours for %d block lengths"
            % (len(colors), len(N_range))
        )

    # Reference curve on a fixed 100-point grid.
    arg_list_lb = get_arg_list(chan, 100)
    rates_lb = [lb_calc(arg) for arg in arg_list_lb]
    ax.plot(rates_lb, arg_list_lb, label='LB', color='#2ca25f', linewidth=0.4)

    for N, color in zip(N_range, colors):
        # The grid size depends on N and must match the data files on disk.
        _, n_points = experiment_size(N)
        arg_list = get_arg_list(chan, n_points)
        medians, lower, upper = _rate_statistics(chan, N, arg_list)
        ax.plot(medians, arg_list, label='n=' + str(N), color=color, linewidth=0.4)
        ax.fill_betweenx(arg_list, lower, upper, where=(upper >= lower), color=color, alpha=0.3)

    if y_label_visible:
        ax.set_ylabel(plot_y_label, size=6)
    else:
        ax.yaxis.label.set_visible(False)
    if not x_label_visible:
        ax.xaxis.label.set_visible(False)
        ax.set_xticks(np.arange(0, 1.2, 0.2))
    ax.tick_params(axis='both', which='major', color='0', labelsize=4)
    ax.tick_params(axis='both', which='minor', color='0.3')
    # Pass the on/off flag positionally: the keyword was renamed from `b` to
    # `visible` in Matplotlib 3.5 and `b=` was later removed.
    ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)

def plot_legend(N_range, colors, ax):
    """Turn *ax* into a legend-only panel for the 'LB' curve and each block length.

    Throw-away lines are plotted only so ``ax.legend()`` has labelled
    handles. The lines, both axes and all spines are then hidden so only
    the legend box remains visible.

    Parameters
    ----------
    N_range : sequence of int
        Block lengths, shown as entries ``'n=<N>'``.
    colors : sequence
        Matplotlib colours, paired element-wise with *N_range*.
    ax : matplotlib.axes.Axes
        Axes to repurpose as the legend panel.

    Raises
    ------
    ValueError
        If *colors* is shorter than *N_range*.
    """
    if len(colors) < len(N_range):
        raise ValueError(
            "need one colour per block length: got %d colours for %d block lengths"
            % (len(colors), len(N_range))
        )
    dummy = np.linspace(0, 1, 10)
    ax.plot(dummy, dummy, label='LB', color='#2ca25f', linewidth=0.7)
    for N, color in zip(N_range, colors):
        ax.plot(dummy, dummy, label='n=' + str(N), color=color, linewidth=0.7)

    # Build the legend while the lines are still visible, then hide them.
    # Presumably the legend's proxy artists copy the lines' properties at
    # creation time, so hiding the lines afterwards leaves the legend intact.
    ax.legend()
    for line in ax.lines:
        line.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    for spine in ax.spines.values():
        spine.set_visible(False)

# Figure 3: a 3x2 grid of panels. Rows are channels (BSC, BIAWGN, BEC); the
# left column shows the shorter block length and the right column a longer one.
# Panels share x within a column and y within a row.
fig, axs = plt.subplots(3, 2, sharex='col', sharey='row')

n_low = [2**12]
n_high = [2**17]
color_low = ['#fb6a4a']
color_high = ['#cb181d']
# The BEC row uses its own longer block length and colour.
n_high_bec = [2**14]
color_high_bec = ['#67000d']

# (channel, block lengths, colours, row, column), drawn in this order.
panels = [
    ('BSC', n_low, color_low, 0, 0),
    ('BSC', n_high, color_high, 0, 1),
    ('BIAWGN', n_low, color_low, 1, 0),
    ('BIAWGN', n_high, color_high, 1, 1),
    ('BEC', n_low, color_low, 2, 0),
    ('BEC', n_high_bec, color_high_bec, 2, 1),
]
last_row = 2
for channel, block_lengths, panel_colors, row, col in panels:
    # y label only on the left column, x label only on the bottom row.
    plot_for_channel(channel, block_lengths, panel_colors, axs[row, col],
                     y_label_visible=(col == 0),
                     x_label_visible=(row == last_row))

# One shared x-axis label for the whole figure, nudged slightly right of centre.
fig.supxlabel('Rate', size=6, x=0.515)
# NOTE: no legend is drawn; legend_without_duplicate_labels(fig) or
# plot_legend(...) on a spare axes (and plt.tight_layout()) can be added here.
fig.savefig("Figure3.pdf")