# -*- coding: utf-8 -*-
"""
Created on Wed Oct 30 01:12:01 2024

@author: shara
"""

import matplotlib.pyplot as plt
import numpy as np
import random
import time
import math
from scipy.optimize import minimize_scalar
import pandas as pd
import psutil, os



import scienceplots
# Global figure styling: the 'science' and 'grid' styles (from the scienceplots
# package imported above), with Times New Roman as the serif font.
plt.style.use(['science', 'grid'])
plt.rcParams.update({'font.family': 'serif', 'font.serif': ['Times New Roman']})

## 3 - Plotting polarization
from polar_modified import *

# Directory prefix for the cached error-probability arrays (.npy) that
# get_error_prob loads and saves; the trailing slash matters because file names
# are built by plain string concatenation.
# NOTE(review): this path is relative to the current working directory, not to
# this script's location — running from elsewhere will miss the cache.
filename_prefix = './../Misc_Data/'
def safe_xlogx(x):
    """Return ``x * log2(x)``, using its continuous limit 0 at ``x == 0``.

    Defining 0 * log2(0) as 0 keeps entropy sums finite when a probability
    is exactly zero. Negative inputs are not guarded; ``np.log2`` yields NaN.
    """
    if x != 0:
        return x * np.log2(x)
    return 0

def binary_entropy(p):
    """Return the binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p), in bits.

    Both terms go through ``safe_xlogx``, so H(0) and H(1) evaluate to 0
    instead of NaN.
    """
    complement = 1 - p
    return -safe_xlogx(p) - safe_xlogx(complement)

def get_error_prob(channel, N, epsilon):
    """Return per-subchannel error probabilities, using an on-disk .npy cache.

    The cache file name is built from ``filename_prefix``, ``channel``, ``N``
    and ``epsilon`` rounded to 4 decimals. On a cache miss the values are
    computed with ``polar_channel_mc`` and saved for next time.

    Args:
        channel: Channel name. Only ``'BSC'`` can be computed; other names
            work only if a matching cache file already exists.
        N: Block length. ``int(log2(N))`` is passed to ``polar_channel_mc``,
            so N is presumably a power of two.
        epsilon: Channel parameter forwarded to ``polar_channel_mc``
            (presumably the BSC crossover probability).

    Returns:
        numpy array of error probabilities. Freshly computed values are folded
        with ``min(a, 1 - a)``; cached arrays are returned as stored.

    Raises:
        ValueError: if the cache is missing and ``channel`` is unsupported.
    """
    filename = filename_prefix + channel + str(N) + '_' + str(round(epsilon, 4)) + 'MCErrorProbs.npy'
    try:
        return np.load(filename)
    except (OSError, ValueError, EOFError):
        # Missing cache file (FileNotFoundError is an OSError) or an
        # unreadable/truncated one: fall through and recompute.
        pass

    if channel == 'BSC':
        chan = bsc_p1
    else:
        raise ValueError(f"Unsupported channel {channel!r}; only 'BSC' can be computed")

    # The final argument (2001) is presumably the Monte Carlo trial count —
    # TODO confirm against polar_channel_mc. It is not part of the cache name.
    error_prob = polar_channel_mc(int(np.log2(N)), chan, epsilon, 2001)
    # Fold each estimate into [0, 0.5] (for estimates in [0, 1]); presumably
    # a and 1 - a describe equally reliable subchannels, since the decision
    # can be inverted.
    error_prob = np.array([min(a, 1 - a) for a in error_prob])

    # Create the cache directory on first use so np.save does not fail.
    cache_dir = os.path.dirname(filename)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    np.save(filename, error_prob)
    return error_prob

def _style_polarization_axis(ax):
    """Apply the tick and major-grid styling shared by both polarization panels."""
    ax.tick_params(axis='both', which='major', color='0', labelsize=4)
    ax.tick_params(axis='both', which='minor', color='0.3')
    # The grid-visibility flag is passed positionally. Its keyword was renamed
    # from ``b`` to ``visible`` in Matplotlib 3.5 and ``b`` was later removed,
    # so ``grid(b=True, ...)`` fails on current Matplotlib.
    ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)


def plotting_polarization(ax_scattered, ax_sorted, N, epsilon, color, show_x=False,
                          thin_fraction=0.1, rng=None):
    """Plot per-subchannel capacity of a length-N polar code over a BSC.

    Error probabilities p come from ``get_error_prob('BSC', N, epsilon)``.
    For each subchannel the capacity is ``1 - H(p)`` and the "PolarSim" rate
    is ``H(0.5 - p)``, where H is ``binary_entropy``.

    Args:
        ax_scattered: Axes for capacity vs. subchannel index. Only a random
            ``thin_fraction`` of the subchannels is drawn.
        ax_sorted: Axes for all subchannels sorted by ascending capacity. The
            PolarSim rate is overlaid, the gap between the two curves is
            shaded, and a dotted vertical line is drawn at ``N * H(epsilon)``.
        N: Block length (number of subchannels plotted).
        epsilon: BSC parameter passed to ``get_error_prob``; presumably the
            crossover probability.
        color: Color of the PolarSim points on ``ax_sorted``.
        show_x: If truthy, set x-axis labels on both panels.
        thin_fraction: Fraction of subchannels drawn on ``ax_scattered``
            (default 0.1). Must lie in [0, 1].
        rng: Object with a NumPy-style ``choice`` method used for thinning,
            e.g. ``np.random.default_rng(seed)`` for a reproducible figure.
            Defaults to the global ``np.random`` state.
    """
    error_prob = get_error_prob('BSC', N, epsilon)
    capacity = np.array([1 - binary_entropy(p) for p in error_prob])
    polarsim_rate = np.array([binary_entropy(0.5 - p) for p in error_prob])
    i_vector = np.arange(N)

    # Left panel: capacity in natural index order. Only a random subset of the
    # subchannels is drawn, to keep the scatter plot (and the PDF) light.
    rng = np.random if rng is None else rng
    keep_indices = rng.choice(N, size=int(thin_fraction * N), replace=False)
    ax_scattered.scatter(i_vector[keep_indices], capacity[keep_indices], s=0.025, color='#2CA25F')
    if show_x:
        ax_scattered.set_xlabel('Index', size=6)
    _style_polarization_axis(ax_scattered)

    # Right panel: every subchannel, sorted by ascending capacity, with the
    # PolarSim rate on the same ordering and the gap between them shaded.
    sorted_inds = np.argsort(capacity)
    sorted_capacity = capacity[sorted_inds]
    sorted_polarsim = polarsim_rate[sorted_inds]
    ax_sorted.fill_between(i_vector, sorted_capacity, sorted_polarsim, color='yellow', alpha=0.3)
    ax_sorted.scatter(i_vector, sorted_capacity, s=0.025, label="Capacity", color='#2CA25F')
    ax_sorted.scatter(i_vector, sorted_polarsim, s=0.025, label="PolarSim", color=color)
    if show_x:
        ax_sorted.set_xlabel('Sorted Indices', size=5)

    # N * H(epsilon) = N minus the total BSC capacity N * (1 - H(epsilon)).
    # This is the expected count of unreliable subchannels, i.e. roughly where
    # the sorted capacities jump from ~0 to ~1. binary_entropy keeps the value
    # finite at epsilon in {0, 1}, where an inline -x*log2(x) gives NaN.
    ax_sorted.vlines(x=N * binary_entropy(epsilon), ymin=0, ymax=1, colors='black',
                     linestyles='dotted', label="Mutual Information")
    _style_polarization_axis(ax_sorted)

# Build Figure 1: the top row uses the shorter block length and the bottom row
# the longer one, both at epsilon = 0.2. The left column shows capacity by
# index and the right column shows sorted capacity.
N_low, N_high = 4096, 2 ** 15
fig, axs = plt.subplots(nrows=2, ncols=2, sharey='row')
plotting_polarization(ax_scattered=axs[0, 0], ax_sorted=axs[0, 1], N=N_low,
                      epsilon=0.2, color='#fb6a4a', show_x=False)
plotting_polarization(ax_scattered=axs[1, 0], ax_sorted=axs[1, 1], N=N_high,
                      epsilon=0.2, color='#67000d', show_x=True)
fig.supylabel('Bits', size=5)
fig.savefig("Figure1.pdf")