import numpy as np
import matplotlib.pyplot as plt

def compute_visit_prob(P,
                       prob_policy,
                       init_dist,
                       gamma):
    """Return the normalized discounted visitation vector under a policy.

    Computes ``(1 - gamma) * init_dist @ inv(I - gamma * P_pi)`` and then
    rescales the result so that it sums to one, where
    ``P_pi[s, t] = sum_a prob_policy[s, a] * P[a, s, t]``.

    Args:
        P: Transition array indexed as ``P[a, s, t]``
            (shape ``[A_size, S_size, S_size]``).
        prob_policy: Array indexed as ``prob_policy[s, a]``
            (shape ``[S_size, A_size]``).
        init_dist: Initial weights over states, shape ``[S_size]`` (a stack
            of row vectors ``[k, S_size]`` is also accepted).
        gamma: Discount factor multiplying ``P_pi``.

    Returns:
        Array with the same shape as ``init_dist`` whose entries sum to one.

    Raises:
        numpy.linalg.LinAlgError: If ``I - gamma * P_pi`` is singular.
    """
    P = np.asarray(P)
    prob_policy = np.asarray(prob_policy)
    init_dist = np.asarray(init_dist)
    S_size = P.shape[1]

    # Policy-averaged transition matrix, size [S_size, S_size].  einsum lets
    # the result dtype follow the operands, so an integer-typed P is no longer
    # truncated the way writing into zeros_like(P[0]) would truncate it.
    P_pi = np.einsum('sa,ast->st', prob_policy, P)

    # Solve d @ (I - gamma * P_pi) = init_dist via the transposed system
    # instead of forming the explicit inverse.
    system = np.eye(S_size) - gamma * P_pi
    d = np.linalg.solve(system.T, init_dist.T).T

    # np.real is a no-op for real inputs; kept for backward compatibility.
    d = np.real(d) * (1 - gamma)

    # Renormalize so the result sums to one.
    d = d / d.sum()

    return d

def add_noise(x: np.ndarray,
              noise=.1):
    """Return ``x`` plus zero-mean Gaussian noise drawn via ``np.random.normal``.

    Args:
        x: Array to perturb; its ``shape`` determines the noise shape.
        noise: Standard deviation (``scale``) of the noise. When ``None``,
            ``x`` is returned unchanged (the same object).

    Returns:
        ``x`` itself when ``noise`` is ``None``, otherwise a new array
        ``x + perturbation``.
    """
    if noise is not None:
        perturbation = np.random.normal(scale=noise, size=x.shape)
        return x + perturbation
    return x
        
        
def phi_exp_factory(p, q=1):
    """Build the map ``x -> exp(sign(x) * |x| ** (p / q))``.

    Args:
        p: Odd numerator of the exponent.
        q: Odd denominator of the exponent.

    Returns:
        A callable taking a scalar or array-like ``x`` and returning the
        element-wise result as a NumPy value.

    Raises:
        AssertionError: If ``p`` or ``q`` is not odd.
    """
    # `and`, not `&`: `&` binds tighter than `==`, so the old expression was a
    # chained comparison that only worked for ints and raised TypeError for
    # float-valued odd numbers such as 3.0.
    assert p % 2 == 1 and q % 2 == 1, "p and q must both be odd"
    exponent = p / q

    def phi_exp(x):
        x = np.array(x)
        # +1 for x > 0, -1 otherwise; the sign at x == 0 is irrelevant because
        # the magnitude is 0 there.
        sign = (x > 0).astype(np.float32) * 2 - 1
        return np.exp(np.power(np.abs(x), exponent) * sign)

    return phi_exp

def phi_exp_refined_factory(p, q=1, delta=1):
    """Build ``phi_exp`` with a plain ``exp(k * x)`` segment on ``(-delta, delta)``.

    Outside the open interval ``(-delta, delta)`` the returned map equals the
    one produced by ``phi_exp_factory(p, q)``; inside it equals
    ``exp(k * x)`` with ``k = delta ** (p / q - 1)``.

    Args:
        p: Odd numerator of the exponent.
        q: Odd denominator of the exponent.
        delta: Half-width of the interval on which ``exp(k * x)`` is used.

    Returns:
        A callable taking a scalar or array-like ``x``.

    Raises:
        AssertionError: If ``p`` or ``q`` is not odd.
    """
    # `and`, not `&`: `&` binds tighter than `==` and fails for float inputs.
    assert p % 2 == 1 and q % 2 == 1, "p and q must both be odd"
    phi = phi_exp_factory(p, q)
    # Slope chosen so that k * delta == delta ** (p / q), i.e. exp(k * x)
    # meets phi at x = +delta and x = -delta.
    k = np.power(delta, p/q - 1)

    def exp_combined(x):
        x = np.array(x)
        inside = np.logical_and(x < delta, x > -delta)
        # Select with np.where rather than blending with 0/1 masks: when the
        # unused branch overflows to inf, `inf * 0` would poison the result
        # with NaN.  Both branches are still evaluated, so NumPy may emit an
        # overflow warning for the discarded values.
        # The trailing [()] turns a 0-d result back into a NumPy scalar and
        # leaves n-d arrays untouched.
        return np.where(inside, np.exp(k * x), phi(x))[()]
    return exp_combined

def phi_exp_inv_factory(p, q=1):
    """Build the map ``x -> exp(sign(x) * |x| ** (q / p))``.

    This is the same form as ``exp(sign(x) * |x| ** (p / q))`` but with the
    reciprocal exponent ``q / p``.

    Args:
        p: Odd denominator of the exponent.
        q: Odd numerator of the exponent.

    Returns:
        A callable taking a scalar or array-like ``x``.

    Raises:
        AssertionError: If ``p`` or ``q`` is not odd.
    """
    # `and`, not `&`: `&` binds tighter than `==`, so the old expression was a
    # chained comparison that only worked for ints and raised TypeError for
    # float-valued odd numbers such as 3.0.
    assert p % 2 == 1 and q % 2 == 1, "p and q must both be odd"
    exponent = q / p

    def phi_exp_inv(x: np.ndarray):
        x = np.array(x)
        # +1 for x > 0, -1 otherwise; the sign at x == 0 is irrelevant because
        # the magnitude is 0 there.
        sign = (x > 0).astype(np.float32) * 2 - 1
        magnitude = np.abs(x)
        return np.exp(sign * np.power(magnitude, exponent))
    return phi_exp_inv

def phi_exp_combined_factory(p_1, q_1=1, p_2=1, q_2=1):
    """Build a piecewise map switching at ``|x| = 1``.

    For ``|x| > 1`` the returned map uses ``phi_exp_factory(p_1, q_1)``;
    otherwise it uses ``phi_exp_inv_factory(p_2, q_2)``.

    Args:
        p_1: Odd ``p`` passed to ``phi_exp_factory``.
        q_1: Odd ``q`` passed to ``phi_exp_factory``.
        p_2: Odd ``p`` passed to ``phi_exp_inv_factory``.
        q_2: Odd ``q`` passed to ``phi_exp_inv_factory``.

    Returns:
        A callable taking a scalar or array-like ``x``.

    Raises:
        AssertionError: If any of the four parameters is not odd.
    """
    # `and`, not `&`: `&` binds tighter than `==` and fails for float inputs.
    assert p_1 % 2 == 1 and q_1 % 2 == 1, "p_1 and q_1 must both be odd"
    assert p_2 % 2 == 1 and q_2 % 2 == 1, "p_2 and q_2 must both be odd"
    phi_exp = phi_exp_factory(p_1, q_1)
    phi_exp_inv = phi_exp_inv_factory(p_2, q_2)

    def phi_exp_combined(x: np.ndarray):
        x = np.array(x)
        outer = np.abs(x) > 1
        # Select with np.where rather than blending with 0/1 masks: when the
        # unused branch overflows to inf, `inf * 0` would poison the result
        # with NaN.  Both branches are still evaluated, so NumPy may emit an
        # overflow warning for the discarded values.
        # The trailing [()] turns a 0-d result back into a NumPy scalar and
        # leaves n-d arrays untouched.
        return np.where(outer, phi_exp(x), phi_exp_inv(x))[()]

    return phi_exp_combined

def phi_poly_factory(p=1):
    """Build the polynomial map ``x -> (1 + p * x) ** p``.

    Args:
        p: Used both as the slope inside the base and as the power.

    Returns:
        A callable taking a scalar or array-like ``x``.
    """
    def phi_poly(x: np.ndarray):
        values = np.asarray(x)
        base = 1 + p * values
        return np.power(base, p)
    return phi_poly

def phi_sigmoid_factory():
    """Build the logistic map ``x -> 1 / (1 + exp(-x))``.

    Returns:
        A callable taking a scalar or array-like ``x``.
    """
    def phi_sigmoid(x: np.ndarray):
        values = np.asarray(x)
        denominator = 1 + np.exp(-values)
        return 1 / denominator
    return phi_sigmoid

def phi_tan_factory():
    """Build the map ``x -> tan(x) + 1``.

    Returns:
        A callable taking a scalar or array-like ``x``.
    """
    def phi_tan(x: np.ndarray):
        angles = np.asarray(x)
        tangent = np.tan(angles)
        return tangent + 1
    return phi_tan

def phi_log_factory():
    """Build the map ``x -> log(x + 2) + 3`` (natural logarithm).

    Returns:
        A callable taking a scalar or array-like ``x``.
    """
    def phi_log(x: np.ndarray):
        values = np.asarray(x)
        shifted = values + 2
        return np.log(shifted) + 3
    return phi_log


def visualize_func(phi, x_min=-2, x_max=2, num_points=1000):
    """Plot ``phi`` over an evenly spaced grid and show the figure.

    Args:
        phi: Callable applied once to the whole grid array.
        x_min: Left end of the plotted interval (default ``-2``).
        x_max: Right end of the plotted interval (default ``2``).
        num_points: Number of grid samples (default ``1000``).
    """
    # The interval and resolution are parameters so that maps with a
    # restricted domain can be plotted on a range that suits them; the
    # defaults reproduce the previous fixed grid.
    x = np.linspace(x_min, x_max, num_points)
    y = phi(x)

    plt.plot(x, y)
    plt.show()
    
    
if __name__ == '__main__':
    # Demo: plot the refined exponential map for p=3, q=5, delta=1.
    phi = phi_exp_refined_factory(p=3, q=5, delta=1)
    visualize_func(phi)