# -*- coding: utf-8 -*-
"""
Created on Thu Oct 24 23:40:23 2024

@author: shara
"""
from polar_modified import *
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import math
from scipy.optimize import minimize_scalar
import pandas as pd
import psutil, os
# Best-effort priority boost so the long sweeps below get more CPU time.
# psutil only defines HIGH_PRIORITY_CLASS on Windows; on other platforms the
# attribute is missing and the unguarded call raised AttributeError at import.
p = psutil.Process(os.getpid())
if hasattr(psutil, "HIGH_PRIORITY_CLASS"):
    try:
        p.nice(psutil.HIGH_PRIORITY_CLASS)
    except psutil.AccessDenied:
        print("Could not raise process priority; running at normal priority.")
# Set experiment seed (the np.random global state drives every trial below)
np.random.seed(101)

# HELPER FUNCTIONS

# Convert an integer n to a most-significant-bit-first array of at least num_bits bits
def binary(n, num_bits):
    '''
    Encode number in binary in np array

          Arguments:
                  n (int64): Non-negative number to encode
                  num_bits (int64): Minimum width; the bit string is
                                    left-padded with zeros up to this
                                    length (a wider n is NOT truncated)

          Returns:
                  x (int64[:]): Numpy array of bits, most significant first
    '''
    digits = format(n, 'b').zfill(num_bits)
    return np.fromiter(map(int, digits), dtype=np.int64, count=len(digits))

# Compute the binary entropy H(p), in bits
def binary_entropy(p):
    '''
    Binary entropy H(p) = -p log2(p) - (1 - p) log2(1 - p).

          Arguments:
                  p (float or array-like): Probability/probabilities in [0, 1]

          Returns:
                  H (float64 or float64[:]): Entropy in bits. A scalar input
                  gives a numpy scalar; an array input gives an array of the
                  same shape. Uses the convention 0 * log2(0) = 0, so
                  H(0) = H(1) = 0 instead of nan. Inputs outside [0, 1]
                  still produce nan.
    '''
    p = np.asarray(p, dtype=np.float64)
    q = 1.0 - p
    # Substitute 1 inside the log for exact zeros: this avoids the
    # divide-by-zero warning and the nan from 0 * -inf, and the outer where
    # then applies the limit value 0. Negative inputs keep their value, so
    # log2 still yields nan (with a warning) for out-of-range probabilities.
    p_log_p = np.where(p == 0, 0.0, p * np.log2(np.where(p == 0, 1.0, p)))
    q_log_q = np.where(q == 0, 0.0, q * np.log2(np.where(q == 0, 1.0, q)))
    # "0.0 - x" equals "-x" except that it returns +0.0 rather than -0.0 at
    # the endpoints.
    h = np.asarray(0.0 - (p_log_p + q_log_q))
    return h[()]

# Invert the binary entropy: for each target entropy, find p in (0, 0.5] with H(p) equal to it
def inverse_H(arr):
    '''
    Elementwise inverse of the binary entropy on its lower branch.

          Arguments:
                  arr (iterable of float): Target entropy values, expected
                                           in [0, 1]

          Returns:
                  vals (float64[:]): For each value v, the p in (0, 0.5]
                                     minimizing |H(p) - v|. Values outside
                                     [0, 1] have no exact preimage and end
                                     up at the corresponding end of the
                                     search interval.
    '''
    vals = []
    for val in arr:
        # |H(p) - val| has two symmetric zeros (p and 1 - p). Searching only
        # the monotone branch (0, 0.5] makes the objective unimodal, so the
        # bounded Brent search lands on the lower root directly and no fold
        # is needed afterwards. The default xatol (1e-5) is far too coarse
        # for small targets, whose preimages are themselves tiny.
        result = minimize_scalar(
            lambda q, v=val: abs(binary_entropy(q) - v),
            bounds=(1e-15, 0.5),
            method='bounded',
            options={'xatol': 1e-12},
        )
        vals.append(result.x)
    return np.array(vals, dtype=np.float64)

# Recursively compute the capacity of the BEC polar subchannels
def I_bec(N, i, eps):
    """Capacity of synthetic subchannel i (0-based) of a length-N polar
    transform applied to a BEC with erasure probability eps.

    A length-1 "transform" is the raw channel, with capacity 1 - eps. Otherwise
    the value is built from subchannel int(i/2) of the length int(N/2)
    transform. Even i (odd in 1-based counting) gives I**2; odd i gives
    2*I - I**2. N is presumably a power of two -- TODO confirm callers
    never pass anything else.
    """
    if N == 1:
        return 1.0 - eps
    take_square = (i + 1) % 2 == 1
    half = int(N / 2)
    if take_square:
        parent = I_bec(half, int((i + 1) / 2), eps)
        return parent ** 2
    parent = I_bec(half, int(i / 2), eps)
    return (2 * parent) - (parent ** 2)



# Simulate num_trials decoding runs of a length-N polar code over channel `chan`
# with parameter p. Returns an array of size num_trials containing rates.
def simulate_trials_for_arg(chan, N, p, num_trials, mc_samples=2001):
    '''
    Monte-Carlo rate estimate for the common-randomness polar decoder.

          Arguments:
                  chan (callable): Channel simulator, called as chan(x, p)
                                   on a length-N float64 array of input bits
                  N (int): Block length; int(log2(N)) is passed to
                           polar_channel_mc, so presumably a power of two
                  p (float): Channel parameter forwarded to chan and to
                             polar_channel_mc (its meaning depends on chan)
                  num_trials (int): Number of independent trials
                  mc_samples (int): Sample count forwarded to
                                    polar_channel_mc (default 2001, the
                                    value previously hard-coded)

          Returns:
                  results_arr (float64[:]): Rate obtained in each trial
    '''
    results_arr = np.zeros(num_trials, dtype=np.float64)
    # Per-subchannel table for size N, folded onto [0, 0.5]. Only the first
    # N entries are used, as in the original per-index loop.
    p_n = np.asarray(polar_channel_mc(int(np.log2(N)), chan, p, mc_samples),
                     dtype=np.float64)[:N]
    p_n = np.minimum(p_n, 1 - p_n)
    # A bit's rate contribution depends only on p_n and on the decoder's
    # delta value, so both candidate costs are computed once, outside the
    # trial loop. errstate silences log2(0) warnings for p_n == 0.5 entries;
    # a selected entry of that kind still contributes inf, as before.
    with np.errstate(divide='ignore'):
        cost_if_one = -np.log2(0.5 - p_n)
        cost_if_zero = -np.log2(0.5 + p_n)
    for t in range(num_trials):
        # Generate P1 domain input realizations for channel simulation
        x_n_true = np.random.randint(2, size=N).astype(np.float64)
        y_n = chan(x_n_true, p)
        # Generate common randomness
        zn = np.random.uniform(0, 1, N)
        # Run Henry Pfister decoder with modified random choice
        _uhat, _xhat, delta = polar_decode_with_cr(
            y_n, zn, np.full(N, 0.5, dtype=np.float64))
        is_one = np.asarray(delta)[:N] == 1
        rate_sum = np.where(is_one, cost_if_one, cost_if_zero).sum()
        results_arr[t] = (rate_sum + 1) / N
    avg_rate = results_arr.mean() if num_trials > 0 else 0.0
    print('Average Rate Is:', avg_rate)
    return results_arr

def arg_list_biawgn():
    """Placeholder for a BIAWGN argument-list helper.

    NOTE(review): the original line was an incomplete ``def`` with no colon
    and no body, which made the whole module fail to parse (SyntaxError).
    The intended behaviour cannot be recovered from this file. Presumably it
    should build the sigma grid used by the BIAWGN sweep; TODO confirm and
    implement.
    """
    raise NotImplementedError("arg_list_biawgn has not been implemented")

# For very large N, the rate calculation uses a much sparser grid because of compute-time limits
def experiment_size(N):
    """Return the (num_trials, num_points) budget for block length N.

    Block lengths up to 2**17 get 200 trials on a 100-point parameter grid.
    Larger N get 100 trials on a 20-point grid.
    """
    dense_budget = (200, 100)
    sparse_budget = (100, 20)
    return dense_budget if N <= 2**17 else sparse_budget

#BSC: sweep p over [0, 0.5] for every block length N = 2, 4, ..., 2**17
# np.savetxt cannot create the output folder, so make sure it exists first.
os.makedirs("./Data_Old", exist_ok=True)
for N in [2**i for i in np.arange(1, 18)]:
    num_trials, n_points = experiment_size(N)
    for p in np.linspace(0, 0.5, n_points):
        print("p:", p)
        filename = "./Data_Old/BSC_trials_N_" + str(N) + "_p_" + str(round(p, 4)) + ".csv"
        # Grid points whose CSV already exists are skipped, so a rerun only
        # computes the missing ones.
        if not os.path.isfile(filename):
            results_for_p = simulate_trials_for_arg(bsc_p1, N, p, num_trials)
            np.savetxt(filename, results_for_p, delimiter=",")

#BIAWGN: sweep sig over [0, sqrt(3)] for every block length N = 2, 4, ..., 2**17
# np.savetxt cannot create the output folder, so make sure it exists first.
os.makedirs("./Data_Old", exist_ok=True)
for N in [2**i for i in np.arange(1, 18)]:
    num_trials, n_points = experiment_size(N)
    for sig in np.linspace(0, np.sqrt(3), n_points):
        print("sig:", sig)
        filename = "./Data_Old/BIAWGN_trials_N_" + str(N) + "_sig_" + str(round(sig, 4)) + ".csv"
        # Grid points whose CSV already exists are skipped, so a rerun only
        # computes the missing ones.
        if not os.path.isfile(filename):
            results_for_sig = simulate_trials_for_arg(awgn_p1_new, N, sig, num_trials)
            np.savetxt(filename, results_for_sig, delimiter=",")

#BEC: sweep eps over [0, 1] for every block length N = 2, 4, ..., 2**17
# np.savetxt cannot create the output folder, so make sure it exists first.
os.makedirs("./Data_Old", exist_ok=True)
for N in [2**i for i in np.arange(1, 18)]:
    # Use the same budget helper as the other sweeps. For every N in this
    # range it returns (200, 100), i.e. the values previously hard-coded here.
    num_trials, n_points = experiment_size(N)
    for eps in np.linspace(0, 1, n_points):
        print("eps:", eps)
        filename = "./Data_Old/BEC_trials_N_" + str(N) + "_eps_" + str(round(eps, 4)) + ".csv"
        # Grid points whose CSV already exists are skipped, so a rerun only
        # computes the missing ones.
        if not os.path.isfile(filename):
            results_for_eps = simulate_trials_for_arg(bec_p1, N, eps, num_trials)
            np.savetxt(filename, results_for_eps, delimiter=",")