import numpy as np
import math
import itertools
from scipy.special import lambertw as W
from sampling import U
import warnings
# Silence every warning process-wide (e.g. numpy RuntimeWarning /
# ComplexWarning that the numerical helpers below may trigger).
# NOTE(review): this filter is global, so it also hides warnings raised by any
# other module in the same process; consider narrowing it to specific
# categories or wrapping the relevant calls in warnings.catch_warnings().
warnings.filterwarnings("ignore")

# def Wbar(x):
#     return float(-W(-np.exp(-x),k=-1))
# def x1(delta,t):
#     return np.max(np.array([1+np.log(4/delta),1+np.log((2+2*np.log(t))/delta),Wbar(1+np.log(4*np.log(t)/delta))]))


# def x(delta,n,m): # Threshold function for KL concentration from Johansson et al 2020
#     return np.log(1/delta) + (m-1) + (m-1)*np.log(1+n/(m-1))


#useful numerical functions 
def hInv(x):
    """Inverse of h(u) = u - ln(u) on the branch u >= 1.

    Computed as h^{-1}(x) = -W_{-1}(-e^{-x}), where W_{-1} is the lower branch
    of the Lambert W function (see Proposition 15 in (Kaufmann and Koolen 2018)).
    The result is real only for x >= 1. For smaller x the Lambert W value is
    complex, and only its real part is returned, as before.

    Args:
        x: scalar argument (meaningful for x >= 1).

    Returns:
        Python float u >= 1 with u - ln(u) == x (for x >= 1).
    """
    # scipy.special.lambertw always returns a complex value. Take the real part
    # explicitly instead of relying on float(complex), which goes through an
    # implicit complex->real cast that emits a ComplexWarning.
    return float(-W(-np.exp(-x), k=-1).real)
def hTilde(x):
    """Piecewise function h~ used by phi (Kaufmann and Koolen 2018).

    Above the switch point hInv(1/ln(3/2)):
        h~(x) = hInv(x) * exp(1 / hInv(x))
    Below it (linear branch):
        h~(x) = 1.5 * (x - ln(ln(1.5)))
    """
    switch_point = hInv(1/np.log(1.5))
    if x >= switch_point:
        root = hInv(x)
        return root*np.exp(1/root)
    # Linear branch for small arguments.
    return 1.5*(x-np.log(np.log(1.5)))
def phi(x):
    """Threshold T(x) for KL concentration of rewards.

    Taken from Theorem 14 in (Kaufmann and Koolen 2018). beta_r applies it as
    Nsa * phi(ln(1/delta) / Nsa).

        phi(x) = 2 * hTilde((hInv(1 + x) + 2*ln(pi^2/3)) / 2)

    pi^2/3 equals 2*zeta(2).
    NOTE(review): confirm the factor 2 in front of ln(pi^2/3) against the exact
    statement of the theorem before relying on the bound's tightness.
    """
    offset = 2*np.log((math.pi**2)/3)
    half_arg = (hInv(1+x) + offset)/2
    return 2*hTilde(half_arg)


def y(delta,n):
    """Approximate threshold for KL concentration of rewards.

    Follows Theorem 14 in (Kaufmann and Koolen 2018). tau(x) is replaced by the
    approximation x + ln(1 + x + sqrt(2x)), where x = ln(1/delta):

        y(delta, n) = 3 * ln(4 + ln(n)) + x + ln(1 + x + sqrt(2x))

    ``n`` may be a scalar or a numpy array; the expression is elementwise.
    """
    log_inv_delta = np.log(1/delta)
    # c * ln(d + ln n) with c = 3 and d = 4
    slack = 3*np.log(4 + np.log(n))
    return slack + log_inv_delta + np.log(1 + log_inv_delta + np.sqrt(2*log_inv_delta))

#Threshold for transitions
def beta_p(visits, delta):
    """Confidence threshold for the transition estimates.

    With m = visits.shape[0] (presumably the number of states):

        ln(1/delta) + (m - 1) + (m - 1) * sum(ln(1 + visits / (m - 1)))

    The sum runs over every entry of ``visits``.
    NOTE(review): when visits.shape[0] == 1 this divides by zero and the
    result is nan.
    """
    dof = visits.shape[0] - 1
    log_terms = np.log(1 + (visits/dof))
    return np.log(1/delta) + dof + dof*np.sum(log_terms)
#Threshold for rewards
def beta_r(visits, delta):
    """Confidence threshold for the reward estimates.

    Mirrors the deviation bound of Theorem 14 in (Kaufmann and Koolen 2018),
    with one "arm" per entry of ``visits``:

        Nsa * phi(ln(1/delta) / Nsa) + 3 * sum(ln(1 + ln(1 + visits)))

    Args:
        visits: numpy array of visit counts, one entry per state-action pair
            (presumably shape (Ns, Na) -- confirm against callers).
        delta: confidence parameter (presumably in (0, 1)).

    Returns:
        The scalar threshold.
    """
    # Number of state-action pairs = number of entries of ``visits``
    # (Ns * Na for a 2-D array). Taking Na from visits.shape[0] would reuse the
    # state axis and give Ns**2 whenever Ns != Na.
    Nsa = visits.size
    # ln(1 + ln(1 + N)) stays finite for never-visited pairs (N = 0).
    return Nsa*phi(np.log(1/delta)/Nsa) + 3*np.sum(np.log(1+np.log(1+visits)))

def relaxed_beta(visits,delta,t):
    """Relaxed stopping threshold: ln(1/delta) + 4 * ln(1 + ln(1 + t)).

    Args:
        visits: unused; kept so the signature matches the other thresholds.
        delta: confidence parameter.
        t: current iteration / sample count (scalar or numpy array).

    Returns:
        The threshold value (elementwise if ``t`` is an array).
    """
    # The threshold depends only on delta and t. The shape-derived locals the
    # previous version computed were never used (and read Na from the wrong axis).
    return np.log(1/delta) + 4*np.log(1+np.log(1+t))

def stop(pi, H, H_star, visits, delta, t):
    """Evaluate the stopping rule at iteration ``t``.

    Computes the statistic Z = t / U_hat. U_hat is the first element returned by
    ``U(pi, H, H_star, visits / t)``. Z is compared against the
    ``relaxed_beta`` threshold, and a progress line is printed on every call.

    Args:
        pi, H, H_star: forwarded unchanged to ``U`` (meaning defined by
            ``sampling.U``).
        visits: visit-count array; divided by ``t`` before being passed to ``U``
            (presumably as empirical frequencies -- confirm in ``sampling.U``).
        delta: confidence parameter passed to the threshold.
        t: current iteration / sample count; also the numerator of Z.

    Returns:
        Tuple ``(should_stop, U_hat, beta)``, where ``should_stop`` is
        ``Z > beta``.
    """
    # Alternative threshold kept for reference:
    #     beta_p(visits, delta) + beta_r(visits, delta)
    beta = relaxed_beta(visits, delta, t)
    U_hat = U(pi, H, H_star, visits/t)[0]
    Z = t/U_hat
    print("Iteration = {}, Z = {} , threshold = {}".format(t, Z, beta))
    # Compare against the threshold computed above rather than recomputing it.
    return Z > beta, U_hat, beta

             
    