import numpy as np; import pylab as plt

# Sweep a noise parameter and, for two toy tasks, compute information-theoretic
# quantities of a joint table P(s, theta, x) indexed [s, theta, x]:
#   MI   = I(S;Theta)
#   II   = I(S;Theta) - I(S;Theta|X)   (plotted below as "negative interaction information")
#   II_x = per-x contributions to I(S;Theta|X), one column per value of x
N = 1000  # number of theta values in the mouse task and number of swept noise levels
eps = 1e-5  # keeps noise levels strictly inside (0, 1) so every table entry is > 0 and logs stay finite
unsafe_noise = np.linspace(eps, 1-eps, N)
safe_noise_range = np.linspace(eps, 1-eps, N)
II = np.zeros(N)
MI = np.zeros(N)
II_x = np.zeros((N,2))
II_XOR = np.zeros(N)
MI_XOR = np.zeros(N)
II_x_XOR = np.zeros((N,2))


def information_terms(P_s_theta_x):
    """Return ``(I(S;Theta), I(S;Theta) - I(S;Theta|X), per-x I(S;Theta|X) terms)``.

    P_s_theta_x: 3-D joint probability table indexed ``[s, theta, x]``.
        Every entry must be strictly positive; a zero entry makes
        ``0 * log(0)`` evaluate to nan.

    Returns a tuple ``(mi, ii, cmi_x)``:
        mi    -- scalar I(S;Theta), in nats.
        ii    -- scalar ``mi - cmi_x.sum()``.
        cmi_x -- 1-D array with one entry per value of x. Each entry is the sum
                 over (s, theta) of the conditional-MI integrand, so
                 ``cmi_x.sum() == I(S;Theta|X)``.
    """
    # Pairwise marginals (keepdims so they broadcast against the 3-D table).
    P_s_x = P_s_theta_x.sum(1, keepdims=True)
    P_x_theta = P_s_theta_x.sum(0, keepdims=True)
    P_s_theta = P_s_theta_x.sum(-1, keepdims=True)
    # Single-variable marginals.
    P_x = P_s_x.sum(0, keepdims=True)
    P_s = P_s_theta.sum(1, keepdims=True)
    P_theta = P_s_theta.sum(0, keepdims=True)

    mi = (P_s_theta * np.log(P_s_theta / (P_s * P_theta))).sum()
    # I(S;Theta|X) = sum P(s,t,x) log[P(x) P(s,t,x) / (P(s,x) P(t,x))].
    # Sum over s and theta only, keeping the per-x split; the total is its sum.
    cmi_x = (P_s_theta_x * np.log(P_x * P_s_theta_x / (P_s_x * P_x_theta))).sum(axis=(0, 1))
    return mi, mi - cmi_x.sum(), cmi_x


for i, safe_noise in enumerate(safe_noise_range):
    # Mouse task. Uniform prior P(s)=1/2, P(theta)=1/N.
    # The s=1 row uses the fixed `safe_noise` for every theta; the s=0 row uses
    # `unsafe_noise`, which varies with theta.
    P_x__s_theta = np.zeros((2, N, 2))
    P_x__s_theta[1, :, 0] = safe_noise
    P_x__s_theta[1, :, 1] = 1 - safe_noise
    P_x__s_theta[0, :, 0] = unsafe_noise
    P_x__s_theta[0, :, 1] = 1 - unsafe_noise
    P_s_theta_x = P_x__s_theta * .5 * 1/N
    assert np.isclose(P_s_theta_x.sum(), 1)
    MI[i], II[i], II_x[i] = information_terms(P_s_theta_x)

    # XOR task. s, theta and x are binary with a uniform prior; x equals
    # s XOR theta, flipped with probability `safe_noise`.
    P_x__s_theta = np.zeros((2, 2, 2))
    P_x__s_theta[1, :, 0] = [safe_noise, 1-safe_noise]
    P_x__s_theta[1, :, 1] = [1 - safe_noise, safe_noise]
    P_x__s_theta[0, :, 0] = [1 - safe_noise, safe_noise]
    P_x__s_theta[0, :, 1] = [safe_noise, 1-safe_noise]
    P_s_theta_x = P_x__s_theta * .5 * 1/2
    assert np.isclose(P_s_theta_x.sum(), 1)
    # Previously II_XOR subtracted the mouse-task MI[i]; the helper always
    # pairs each task's table with its own MI.
    MI_XOR[i], II_XOR[i], II_x_XOR[i] = information_terms(P_s_theta_x)


def plot_task_information(noise, ii, ii_x, task, xlabel):
    """Show three figures for one task.

    The figures are the negative interaction information, the per-x
    information-gain curves, and their sum.

    noise  -- x-axis values.
    ii     -- one value per entry of ``noise``.
    ii_x   -- one row per entry of ``noise``, two columns (legend: NOGO, GO).
    task   -- label inserted in parentheses at the end of each title.
    xlabel -- x-axis label used on every figure.
    """
    plt.plot(noise, ii)
    plt.xlabel(xlabel)
    plt.title(f"negative interaction information ({task})")
    plt.show()

    plt.plot(noise, ii_x)
    plt.xlabel(xlabel)
    plt.legend(["NOGO", "GO"])
    plt.title(f"information gain from each stimuli ({task})")
    plt.show()

    plt.plot(noise, ii_x.sum(-1))
    plt.title(f"information gain from both stimuli ({task})")
    plt.xlabel(xlabel)
    # The original script omitted this call for the last XOR figure, so that
    # figure was never displayed.
    plt.show()


plot_task_information(safe_noise_range, II, II_x, "MOUSE TASK", "safe theta range")
plot_task_information(safe_noise_range, II_XOR, II_x_XOR, "XOR", "theta range")