# -*- coding: utf-8 -*-
"""
Created on Thu Aug  1 12:57:49 2024

"""

# Greedy Poisson Rejection Sampling (GPRS), Algorithm 3, applied to
# simulating N uses of a binary symmetric channel BSC(epsilon).

import scipy
import matplotlib.pyplot as plt
import cache_tools
import pandas as pd
from scipy.optimize import minimize_scalar
from scipy.integrate import quad
from scipy.stats import zipf
import pickle
# Input:  proposal distribution P (uniform over N-bit strings here),
#         density ratio r = dQ/dP,
#         stretch function sigma.
# Output: a sample Y ~ Q and its index in the proposal sequence.

import numpy as np
import scipy.integrate
import time

def gen_poisson(N, m):
    """Draw m Poisson-process arrival times and m uniformly random N-bit proposals.

    Parameters
    ----------
    N : int
        Number of bits per proposal.
    m : int
        Number of arrivals / proposals to draw.

    Returns
    -------
    (T_n, X_n)
        T_n: shape (m,), cumulative sums of unit-scale exponentials, i.e. the
        first m arrival times of a unit-rate Poisson process (increasing).
        X_n: shape (m, N), independent fair 0/1 bits.
    """
    inter_arrivals = np.random.exponential(scale=1, size=m)
    arrival_times = inter_arrivals.cumsum()

    # Drawn after the exponentials so the global RNG stream is consumed
    # in the same order as before.
    proposals = np.random.randint(2, size=(m, N))

    return arrival_times, proposals

def filter_repeats( T, X ):
    """Drop repeated rows of X, keeping the first occurrence and its time in T.

    Parameters
    ----------
    T : sequence
        Values paired index-wise with the rows of X (e.g. arrival times).
    X : sequence of rows (e.g. a 2-D NumPy array)
        Rows to de-duplicate.

    Returns
    -------
    (T_new, X_new) : tuple of np.ndarray
        The first occurrence of each distinct row of X, in original order,
        and the T entries at the same indices.
    """
    seen = set()
    keep = []
    for i, row in enumerate(X):
        # Exact tuple key: O(1) set lookups (a list made this O(n^2)), and unlike
        # str(row) it never merges float rows that print identically.
        key = tuple(np.ravel(row).tolist())
        if key not in seen:
            seen.add(key)
            keep.append(i)
    T_new = [T[i] for i in keep]
    X_new = [X[i] for i in keep]
    return np.asarray(T_new), np.asarray( X_new )

def W(X_n, Y_n, epsilon):
    """BSC(epsilon) transition probability of receiving Y_n when X_n is sent.

    Equals epsilon**d * (1 - epsilon)**(N - d), where d is the number of
    positions at which X_n and Y_n differ and N = len(Y_n).
    Positions are paired with zip, so any extra entries in the longer
    sequence are ignored when counting d.
    """
    length = len(Y_n)
    flips = sum([sent != received for sent, received in zip(X_n, Y_n)])
    agreements = length - flips
    return epsilon**flips * (1-epsilon)**agreements

# Sigma integral definition
def integrand(eta, N, epsilon):
    l = lambda x: (np.log2(x/(2**N)) - N * np.log2(1-epsilon)) / (np.log2(epsilon/(1-epsilon)))
    w_P = lambda x: scipy.stats.binom.cdf(l(x), N, 0.5)
    w_Q = lambda x: scipy.stats.binom.cdf(l(x), N, epsilon)
    return 1 / (w_Q(eta) - eta*w_P(eta))

@cache_tools.memoize
def calc_sigma(N, epsilon, h):
    """Evaluate the GPRS stretch function sigma(h) for N uses of a BSC(epsilon).

    sigma(h) = integral from 0 to h of 1 / (w_Q(eta) - eta * w_P(eta)) d(eta),
    with the integrand supplied by ``integrand``. Results are memoized by
    ``cache_tools.memoize``; callers reset it with ``calc_sigma.cache_clear()``.

    Parameters
    ----------
    N : int
        Number of bits.
    epsilon : float
        Crossover probability.
    h : float
        Upper integration limit (a density-ratio value).

    Returns
    -------
    float
        sigma(h). Returns ``np.inf`` when ``h`` is at or above
        ``2**N * (1 - epsilon)**N``, the ratio of a zero-flip output.
    """
    max_arg = 2**N * ( 1 - epsilon )**N
    if h >= max_arg:
        return np.inf
    # Both binomial CDFs in the integrand are step functions of eta. They jump
    # where the flip threshold crosses an integer k, i.e. at
    # eta_k = 2**N * epsilon**k * (1 - epsilon)**(N - k). Passing the jumps that
    # lie strictly inside (0, h) as break points lets quad integrate each smooth
    # piece separately instead of repeatedly bisecting around discontinuities.
    jumps = (2**N * epsilon**k * (1 - epsilon)**(N - k) for k in range(1, N + 1))
    breaks = sorted(b for b in jumps if 0 < b < h)
    # QUADPACK's break-point routine requires limit >= len(points) + 2.
    sigma, _abserr = quad(integrand, 0, h, args=(N, epsilon),
                          points=breaks or None, limit=50 + len(breaks))
    return sigma

def GPRS( X_n, epsilon, m = 10000 ):
    """Draw Y given X_n through a BSC(epsilon) with Greedy Poisson Rejection Sampling.

    Proposals are uniformly random N-bit strings paired with the arrival times
    of a unit-rate Poisson process (``gen_poisson``). Repeated strings are
    dropped (``filter_repeats``). Proposal i is accepted at the first index
    where its arrival time is below sigma(r_i), where
    r_i = 2**N * W(X_n, Y_i, epsilon).
    NOTE(review): the de-duplication departs from i.i.d. proposals; confirm
    this is the intended variant.

    Parameters
    ----------
    X_n : sequence of 0/1
        Channel input; its length sets the block length N.
    epsilon : float
        Crossover probability.
    m : int, optional
        Proposals drawn per attempt (default 10000). A batch is redrawn until it
        contains all 2**N distinct strings, so ``m`` must be at least 2**N.
        Values close to 2**N can need many redraws.

    Returns
    -------
    (Y, i)
        The accepted proposal and its 0-based index in the de-duplicated
        proposal list.

    Raises
    ------
    ValueError
        If ``m < 2**N``, where the batch could never hold every string and the
        redraw loop would never terminate.
    RuntimeError
        If no proposal is accepted.
    """
    N = len(X_n)
    n_strings = 2**N
    if m < n_strings:
        raise ValueError(
            'GPRS needs m >= 2**N proposals per batch; got m=%d for N=%d' % (m, N))
    Y_new = []
    while len(Y_new) < n_strings:
        T, Y = gen_poisson(N, m)
        T_new, Y_new = filter_repeats( T, Y )
    for i, (t_i, y_i) in enumerate(zip(T_new, Y_new)):
        h_i = W(X_n, y_i, epsilon) * n_strings
        if t_i < calc_sigma(N, epsilon, h_i):
            return y_i, i
    # The batch contains X_n itself, whose ratio should map to sigma = inf.
    # Reaching this point therefore indicates a floating-point edge case.
    raise RuntimeError('GPRS: no proposal accepted')

def histogram( N = 8, epsilon = 0.1, n_test = 100 ):
    """Run GPRS n_test times on one random N-bit input and collect statistics.

    Returns
    -------
    (distances, rates)
        distances: np.ndarray of Hamming distances between the input and
        each GPRS sample.
        rates: list of -log2(Zipf(ind + 1; s)) / N for each returned index
        ind, with s = 1 + 1 / (N * (1 - H(epsilon)) + 2 * log2(e)).
    """
    distances = []
    source = np.random.randint(2, size=N )
    rate_samples = []

    # The Zipf exponent depends only on N and epsilon, not on the trial.
    info = N * (1-binary_entropy(epsilon))
    zipf_exponent = 1 + 1/(info + 2*np.log2(np.exp(1)))

    for trial in range(n_test):
        print(trial)
        sample, index = GPRS( source, epsilon )
        zipf_prob = zipf.pmf(index + 1, zipf_exponent)
        distances.append(sum(a != b for a, b in zip(source, sample)))
        rate_samples.append(-np.log2(zipf_prob)/N)
    return np.array(distances), rate_samples

def binary_entropy(p):
    """Binary entropy H(p) = -p log2 p - (1 - p) log2(1 - p), in bits.

    Accepts a scalar or an array-like of probabilities. Uses the convention
    0 * log2(0) = 0, so H(0) = H(1) = 0 instead of NaN.

    Parameters
    ----------
    p : float or array-like of float
        Probability or probabilities in [0, 1].

    Returns
    -------
    numpy.float64 or numpy.ndarray
        H(p): a scalar for scalar input, otherwise an array of the same shape.
    """
    p = np.asarray(p, dtype=float)
    # At the endpoints log2(0) = -inf and 0 * -inf = nan. Silence those warnings
    # here and replace the endpoint values with 0 below.
    with np.errstate(divide='ignore', invalid='ignore'):
        h = -p * np.log2(p) - (1 - p) * np.log2(1 - p)
    h = np.where((p == 0) | (p == 1), 0.0, h)
    # [()] turns a 0-d array back into a scalar and leaves n-d arrays unchanged.
    return h[()]

def inverse_H(arr):
    """Numerically invert the binary entropy, returning solutions in (0, 0.5].

    For each target value, a bounded scalar minimization searches
    (1e-15, 1 - 1e-15) for the p minimizing |binary_entropy(p) - target|.
    Because H(p) = H(1 - p), any solution above 0.5 is mirrored to 1 - p.

    Returns
    -------
    np.ndarray
        One p per entry of ``arr``.
    """
    def _closest_p(target):
        # Distance between H(p) and the requested entropy value.
        objective = lambda p: abs(binary_entropy(p) - target)
        fit = minimize_scalar(objective, bounds=(1e-15, 1-1e-15), method='bounded')
        return fit.x

    vals = np.array([_closest_p(target) for target in arr])
    upper_half = vals > 0.5
    vals[upper_half] = 1 - vals[upper_half]
    return vals

def data_store_filename( N, id_string ):
    """Return the pickle filename for a stored GPRS/BSC result set.

    Parameters
    ----------
    N : int
        Block length, embedded in the filename.
    id_string : {'MI', 'Rates'}
        Which result set the file holds.

    Returns
    -------
    str
        'GPRS_BSC_New_<id_string>_<N>.pickle'.

    Raises
    ------
    ValueError
        If ``id_string`` is not one of the known ids. Previously this silently
        returned None, which surfaced later as an obscure ``open`` error.
    """
    if id_string not in ('MI', 'Rates'):
        raise ValueError("id_string must be 'MI' or 'Rates', got %r" % (id_string,))
    return 'GPRS_BSC_New' + '_' + id_string + '_' + str(N) + '.pickle'

def rate_sweep(N, n_test = 800, epsilons = None, out_dir = 'GPRS_Calcs'):
    """Sweep the BSC crossover probability, recording GPRS index rates and run times.

    For each epsilon, run ``n_test`` GPRS trials, each on a fresh uniformly
    random N-bit input. Each trial records -log2(Zipf(ind + 1; s)) / N for the
    returned index ``ind`` and the trial's elapsed time. Per-epsilon results
    are pickled into ``out_dir`` as soon as that epsilon finishes.

    Parameters
    ----------
    N : int
        Block length in bits.
    n_test : int
        Trials per epsilon.
    epsilons : array-like of float, optional
        Crossover probabilities to sweep. Defaults to 15 evenly spaced values
        on [0.01, 0.49].
    out_dir : str
        Directory for the per-epsilon pickles; created if it does not exist.

    Returns
    -------
    (rates, exec_times) : tuple of np.ndarray
        Both have shape (len(epsilons), n_test).
    """
    import os

    if epsilons is None:
        epsilons = np.linspace(1e-2, 0.5-1e-2, 15)
    # Create the output directory up front. Otherwise a missing folder aborts
    # the sweep only after the first epsilon's trials have been computed.
    os.makedirs(out_dir, exist_ok=True)

    rates = []
    exec_times = []
    calc_sigma.cache_clear()
    for epsilon in epsilons:
        print(epsilon)
        # Zipf exponent used to code the GPRS index; depends only on N and epsilon.
        I = N * (1-binary_entropy(epsilon))
        s = 1 + 1/(I + 2*np.log2(np.exp(1)))
        rate_GPRS_eps = []
        eps_exec_times = []
        for _ in range( n_test ):
            # Clear memoized sigma values so every timed trial recomputes its
            # quadratures. NOTE(review): presumably deliberate so that timings
            # are comparable across trials.
            calc_sigma.cache_clear()
            t0 = time.perf_counter()
            X = np.random.randint(2, size=N )
            _, ind = GPRS( X, epsilon )
            Z = zipf.pmf(ind + 1,s)
            rate_GPRS_eps.append( -np.log2(Z)/N )
            eps_exec_times.append( time.perf_counter() - t0 )
        suffix = str(epsilon) + '.pickle'
        with open(os.path.join(out_dir, 'Rate_GPRS_BSC_' + suffix), 'wb') as f:
            pickle.dump(rate_GPRS_eps, f, pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(out_dir, 'Time_GPRS_BSC_' + suffix), 'wb') as f:
            pickle.dump(eps_exec_times, f, pickle.HIGHEST_PROTOCOL)
        rates.append( rate_GPRS_eps )
        exec_times.append( eps_exec_times )
    return np.array( rates ), np.array(exec_times)

if __name__ == "__main__":
    # Block length in bits and number of GPRS trials per epsilon.
    N = 8
    N_TEST = 800
    rates, exec_times = rate_sweep(N, n_test=N_TEST)

    # Pickle the rates array returned by rate_sweep (one row per epsilon).
    with open(data_store_filename(N, 'Rates'), 'wb') as f:
        pickle.dump(rates, f, pickle.HIGHEST_PROTOCOL)

    # Label the timing file with the actual N and trial count, so runs with
    # different settings neither overwrite nor mislabel each other.
    with open('GPRS_exec_times_N=' + str(N) + '_' + str(N_TEST) + 'runs.pickle', 'wb') as f:
        pickle.dump(exec_times, f, pickle.HIGHEST_PROTOCOL)

    # hist, rate = histogram()

    # plt.hist( hist )
    # print('Rate Is:', rate)