import numpy as np
import random
import math
import degraded_send
import cache_tools
import time




# Check whether the input is a power of 2 and, if so, return the exponent.
def isPowerOfTwo(n):
    """Test whether ``n`` is a positive integer power of two.

    An exact integer bit test is used. A float check such as
    ``2 ** log2(n) == n`` never verifies that the exponent is an integer
    and overflows for very large ``n``.

    Args:
        n: Value to test. Any integer-valued number (``int``, NumPy integer,
            or an integral float such as ``8.0``) is accepted.

    Returns:
        ``(True, k)`` with integer ``k`` such that ``2**k == n`` when ``n`` is
        a positive power of two; otherwise ``False``. The asymmetric return
        shape is kept for compatibility with existing callers.
    """
    try:
        value = int(n)
    except (TypeError, ValueError, OverflowError):
        return False
    # Reject non-integral inputs (e.g. 2.5) and non-positive values.
    if value != n or value <= 0:
        return False
    # A positive power of two has exactly one set bit, so n & (n - 1) == 0.
    if value & (value - 1):
        return False
    return True, value.bit_length() - 1

# Calculate and return the n-fold Kronecker power of the kernel matrix F.
def calc_F(n, F):
    """Return the n-fold Kronecker power ``F ⊗ F ⊗ ... ⊗ F`` (n factors).

    Args:
        n: Number of Kronecker factors; must be a non-negative integer
            (an integral float such as ``3.0`` is also accepted).
        F: Kernel matrix, e.g. ``np.array([[1, 0], [1, 1]])``.

    Returns:
        ``np.ones(1)`` when ``n == 0``; otherwise the Kronecker power as a
        floating-point array (the accumulator starts from ``np.ones(1)``).

    Raises:
        ValueError: if ``n`` is negative or not integer-valued.
    """
    steps = int(n)
    if steps != n or steps < 0:
        raise ValueError(f"n must be a non-negative integer, got {n!r}")

    # Iterative form of F_k = kron(F, F_{k-1}) with F_0 = ones(1); avoids
    # unbounded recursion on bad input.
    F_N = np.ones(1)
    for _ in range(steps):
        F_N = np.kron(F, F_N)
    return F_N

# Build the bit-reversal permutation matrix B_N.
def permutation_mat(L):
    """Return the L x L bit-reversal permutation matrix (float64 entries).

    Column ``j`` holds a single 1 in row ``r``, where ``r`` is ``j`` written
    in binary using as many digits as ``L - 1`` and read back to front.
    """
    n_bits = len(format(L - 1, 'b'))
    reversed_rows = np.array(
        [int(format(col, f'0{n_bits}b')[::-1], 2) for col in range(L)],
        dtype=np.intp,
    )
    perm = np.zeros((L, L))
    perm[reversed_rows, np.arange(L)] = 1
    return perm

# Helper for the BSC: flip a single bit with probability epsilon.
def flip(elem, epsilon):
    """Pass one bit through a binary symmetric channel.

    Draws one sample from ``random.random()`` (uniform on [0, 1)). If the
    draw is at least ``epsilon`` the input is returned unchanged; otherwise
    the inverted bit is returned as an ``int``.
    """
    draw = random.random()
    if draw >= epsilon:
        return elem
    return int(not elem)

# Encoder u -> x = u G_N (mod 2), with G_N = B_N F^(⊗n).
def encoding(u):
    """Polar-encode the bit string ``u``.

    Args:
        u: Sequence of '0'/'1' characters (or 0/1 ints) whose length N is a
            positive power of two.

    Returns:
        ``(G_N, x)``: the N x N generator matrix ``B_N @ F^(⊗n)`` (float
        array with 0/1 entries) and the codeword as a '0'/'1' string of
        length N.

    Raises:
        ValueError: if ``len(u)`` is not a positive power of two.
    """
    N = len(u)

    # Exact integer check: a positive power of two has a single set bit.
    if N <= 0 or N & (N - 1):
        raise ValueError(f"len(u) must be a positive power of two, got {N}")
    n = N.bit_length() - 1  # 2^n = N

    u_N = np.array([int(bit) for bit in u], dtype=int)

    # B_N = bit-reversal permutation matrix, F = polar kernel.
    B_N = permutation_mat(N)
    F = np.array([[1, 0], [1, 1]])

    # The n = 0 Kronecker power is 1-D (np.ones(1)); promote it to 2-D so the
    # N = 1 case still yields a 1 x 1 generator and a 1-bit codeword.
    F_N = np.atleast_2d(calc_F(n, F))

    G_N = B_N @ F_N  # G_N = B_N F^(⊗n)
    # Entries of u_N @ G_N are integer counts; reduce mod 2 for GF(2).
    x_N = u_N @ G_N
    x = ''.join(map(str, x_N.astype(int) % 2))

    return G_N, x

# BSC: flip every bit of x independently with probability epsilon to form y.
def bsc(x, epsilon):
    """Send the bit string ``x`` through a binary symmetric channel.

    All positions are converted to ints first, then each is passed to
    ``flip`` once, left to right.

    Returns:
        The received word as a string with one character per input bit.
    """
    bits = [int(bit) for bit in x]
    return ''.join(str(flip(bit, epsilon)) for bit in bits)

# Successive-cancellation decoder: estimate u_hat from the received word y.
# Good-channel (information bit) indices come from
# degraded_send.get_good_ind, based on the Tal & Vardy construction.
def decoding(u, y, epsilon, mu, thresh):
    """Decode ``y`` bit by bit, using ``u`` for the frozen positions.

    Args:
        u: Transmitted bit string; its values are copied into frozen
            positions (the decoder is assumed to know the frozen bits).
        y: Received bit string of length N.
        epsilon: BSC crossover probability.
        mu, thresh: Forwarded to ``degraded_send.get_good_ind``.

    Returns:
        ``(u_hat, num_good)``: the decoded bit string and
        ``good_indices.sum()``.
    """
    N = len(y)
    frozen_values = np.array([int(bit) for bit in u], dtype=int)
    received = np.array([int(bit) for bit in y], dtype=int)
    estimate = np.zeros(N)

    # NOTE(review): good_indices is indexed per position and .sum()'d, so it is
    # presumably a 0/1 or boolean NumPy array — confirm in degraded_send.
    good_indices = degraded_send.get_good_ind(N, epsilon, mu, thresh)

    for idx in range(N):
        if not good_indices[idx]:
            # Frozen bit: take the known value from u.
            estimate[idx] = frozen_values[idx]
            continue
        # Information bit: choose 0 when the log-likelihood ratio is >= 0.
        llr = likelihood(received, estimate[:idx], epsilon)
        estimate[idx] = 0 if llr >= 0 else 1

    u_hat = ''.join(map(str, estimate.astype(int)))
    # Debug output of the memoization stats, then reset the cache for the
    # next call (cache_info/cache_clear are provided by cache_tools.memoize).
    print(likelihood.cache_info())
    likelihood.cache_clear()
    return u_hat, good_indices.sum()

# Recursive helper for the decoder: natural-log likelihood ratio (bit 0 over
# bit 1) of the next bit u_i, where i = len(u_hat), given y and prior decisions.
@cache_tools.memoize
def likelihood(y_N, u_hat, epsilon):
    """Return the log-likelihood ratio ln(P(.|0) / P(.|1)) for bit len(u_hat).

    The channel output ``y_N`` is split in half at each level, and the
    previous decisions ``u_hat`` are split into their XOR-combined and
    even-position parts.

    Args:
        y_N: NumPy array of received bits (0/1) for this sub-block. It is split
            with N//2, so presumably len(y_N) is a power of two — not checked here.
        u_hat: NumPy array of already-decided bits u_0..u_{i-1}; its length i
            selects which synthetic channel is evaluated.
        epsilon: BSC crossover probability; expected in (0, 1), since the base
            case divides by both epsilon and 1 - epsilon.

    Returns:
        Log-likelihood ratio; a value >= 0 favours deciding 0.

    NOTE(review): memoized via cache_tools.memoize, whose keying of NumPy array
    arguments is not visible here — confirm it hashes arrays by content.
    """
    N = len(y_N)
    if N == 1:
        if y_N == 0: # y_N is a length-1 array here; y = 0 -> ln((1-eps)/eps)
            val = (1-epsilon) / epsilon
            return np.log(val)
        else: # y_1 = 1 -> ln(eps/(1-eps))
            val = epsilon / (1-epsilon)
            return np.log(val)
    else:
        # Split the received word into halves for the two recursive sub-channels
        firsthalf_y = y_N[:N//2]
        lasthalf_y = y_N[N//2:]
        
        u_hato = (u_hat[::2]).astype(int) # 0-based positions 0,2,4,... (1-based odd)
        u_hate = (u_hat[1::2]).astype(int) # 0-based positions 1,3,5,... (1-based even)

        # For odd len(u_hat) the newest bit is left out of the XOR; it enters
        # through `power` in the odd branch below.
        if len(u_hato) != len(u_hate):
            u_hato = u_hato[:len(u_hato)-1]

        # Componentwise mod-2 sum (XOR) of the odd- and even-position decisions
        new_uhat = np.bitwise_xor(u_hato, u_hate)

        like1 = likelihood(firsthalf_y, new_uhat, epsilon)
        like2 = likelihood(lasthalf_y, u_hate, epsilon)
        # like1 and like2 are natural-log likelihood ratios of the two sub-channels

        if len(u_hat) % 2 == 0: # even len(u_hat): 1-based index of this bit is odd (Equation 75 of the referenced paper)
            return safe_compute_even(like1, like2)
        else: # odd len(u_hat): 1-based index of this bit is even (Equation 76)
            power = 1 - 2*u_hat[len(u_hat)-1] # last decided bit: 0 -> +1, 1 -> -1
            return safe_compute_odd(like1, like2, power)
        


# Calculate the base-10 order of magnitude of a number.
# Note: nothing in the code visible here calls this; the likelihood base case
# uses np.log instead.
def order_mag(value):
    """Return ``log10(value)`` as a float (ValueError for value <= 0)."""
    return math.log10(value)
def log_one_plus_x(log_x):
    """Return ln(1 + exp(log_x)) without overflow or cancellation.

    Used to add likelihoods stored as natural logs: if ``log_x = ln(x)`` the
    result is ``ln(1 + x)``.

    Piecewise evaluation:
      * log_x < -745: exp(log_x) underflows in double precision; return 0.
      * log_x < -37:  ln(1 + e^t) ~= e^t, since e^t is below double epsilon.
      * log_x > 37:   ln(1 + e^t) = t + ln(1 + e^-t) ~= t + e^-t.
      * otherwise:    log1p(exp(t)). Writing log(1 + exp(t)) instead would
                      lose most significant digits when exp(t) is tiny
                      relative to 1.
    """
    if log_x < -745:
        return 0
    if log_x < -37:
        return np.exp(log_x)
    if log_x > 37:
        return log_x + np.exp(-log_x)
    return np.log1p(np.exp(log_x))
def safe_compute_even(log_like_1, log_like_2):
    """Combine two log-likelihood ratios; used when len(u_hat) is even.

    In the linear domain this is (L1 * L2 + 1) / (L1 + L2). Both inputs and
    the result are natural logs, so the value computed is
    ln(1 + L1*L2) - ln(L1 + L2), with each log-sum from log_one_plus_x.
    """
    # ln(1 + L1*L2) from ln L1 + ln L2
    log_numerator = log_one_plus_x(log_like_1 + log_like_2)

    # ln(L1 + L2) = ln(max) + ln(1 + min/max): factor out the larger term so
    # the exponent passed to log_one_plus_x is never positive.
    larger = max(log_like_1, log_like_2)
    smaller = min(log_like_1, log_like_2)
    log_denominator = larger + log_one_plus_x(smaller - larger)

    return log_numerator - log_denominator
# Likelihood ratio for odd-length u_hat prefixes, used by the likelihood
# function. Works in the log domain to avoid overflow.
def safe_compute_odd(like1, like2, power):
    """Return ln(L1**power * L2) given ln(L1) and ln(L2).

    Args:
        like1: ln(L1), log-likelihood ratio from the first half.
        like2: ln(L2), log-likelihood ratio from the second half.
        power: Exponent applied to L1. The decoder passes +1 or -1
            (``1 - 2 * last_bit``), but any real exponent is handled.

    Returns:
        ``power * like1 + like2``.
    """
    # ln(L1**p * L2) = p * ln(L1) + ln(L2). Multiplying by -1 is an exact sign
    # flip in IEEE arithmetic, so p = +/-1 gives exactly +/-like1 + like2.
    return power * like1 + like2

# End-to-end pipeline: input -> encoder -> BSC -> decoder -> output.
def full_block_in(u, epsilon, mu, thresh):
    """Encode ``u``, pass the codeword through a BSC, then decode it.

    Returns:
        ``(u_hat, G_N, x, y, num_ind)``: decoded bit string, generator matrix,
        codeword, received word, and the good-index count from the decoder.
    """
    generator, codeword = encoding(u)
    received = bsc(codeword, epsilon)
    decoded, n_good = decoding(u, received, epsilon, mu, thresh)
    return decoded, generator, codeword, received, n_good

# Simulation driver: encode random inputs, send them over the BSC, decode.
def execute(N, thresh, mu, times=2, epsilon=0.2):
    """Run ``times`` random trials and report the mean number of bit errors.

    Args:
        N: Block length; the encoder expects a power of two.
        thresh: Threshold passed through full_block_in to the decoder.
        mu: Parameter passed through full_block_in to the decoder.
        times: Number of independent trials; must be >= 1.
        epsilon: BSC crossover probability (default 0.2).

    Returns:
        ``(mean_errors, mean_good_indices)`` averaged over all trials.

    Raises:
        ValueError: if ``times < 1``.
    """
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times}")

    n_ind = []
    n_err = []
    for _trial in range(times):
        u_in = ''.join(random.choice('01') for _ in range(N))

        u_hat, _, _, _, num_ind = full_block_in(u_in, epsilon, mu, thresh)

        u_N = np.array(list(map(int, u_in)), dtype=int)  # string -> array
        u_N_hat = np.array(list(map(int, u_hat)), dtype=int)
        n_ind.append(num_ind)
        n_err.append(int(np.count_nonzero(u_N != u_N_hat)))

    # Average over all trials, not just the last one.
    mean_err = np.mean(n_err)
    print("# of different bits between u and u_hat = " + str(mean_err))
    return mean_err, np.mean(n_ind)


# Example run of the full block. Guarded so that importing this module does
# not launch the N = 2048*4 simulation.
if __name__ == "__main__":
    t0 = time.perf_counter()

    thresholds = [-1]  # log10 of the decoder threshold values to sweep
    new_num_good_indices = []
    new_num_errors = []
    for log_thresh in thresholds:
        err, ind = execute(N=2048 * 4, thresh=10**log_thresh, mu=4)
        new_num_good_indices.append(ind)
        new_num_errors.append(err)

    t1 = time.perf_counter()

    total = t1 - t0
    print("Total Execution time is " + str(total) + " seconds")