import numpy as np
from PDE1d import PDE1d
from ODE1d import ODE1d
from simulation_plots import make_plots,plot_four_windows
import tqdm
from scipy.special import xlogy

# Discretisation parameters shared by every simulation in this script.
dz = 0.1   # z-grid spacing handed to PDE1d
N = 60     # grid-size argument handed to PDE1d
dt = 0.01  # time step used by both the ODE and the PDE updates


def H1(rho_east_west,rho_bar):
    """Return ``rho_east_west * log(rho_bar) / 20``.

    Uses ``scipy.special.xlogy``, so the product is defined as 0 wherever
    ``rho_east_west == 0`` (even if ``rho_bar == 0``). Works element-wise on
    arrays. Passed to PDE1d as ``H_prime_rho``.
    """
    weighted_log = xlogy(rho_east_west, rho_bar)
    return weighted_log / 20.


def V1(x,z,Z_samp):
    """Evaluate the potential V1 at every grid point in ``z``.

    For each entry ``z[i]`` the value is::

        10 * sum_k (Z_samp[k] - z[i])**2 / len(Z_samp) / len(z)
        + 0.5 * (1 - round(sigmoid(3 * (z[i] - x)), 16))

    ``z`` and ``Z_samp`` must be NumPy arrays (``reshape`` is called on
    both); the result has one entry per element of ``z``.

    NOTE(review): the extra division by ``len(z)`` makes the sample term
    shrink as the grid is refined — confirm this is intended.
    """
    n_grid = len(z)
    n_samples = len(Z_samp)
    weight = 0.5

    # Rows index grid points, columns index samples.
    diffs = Z_samp.reshape((1, n_samples)) - z.reshape((n_grid, 1))
    sample_term = 10. * np.sum(diffs**2, axis=1) / n_samples / n_grid

    # Logistic curve in (z - x), rounded to 16 decimals before complementing.
    logistic = (1. + np.exp(-3. * (z - x)))**-1
    label_term = 1. - np.around(logistic, decimals=16)

    return sample_term + weight * label_term


def W_no_kernel(xj,xi,dx):
    """Interaction kernel that is identically zero.

    Takes the same ``(xj, xi, dx)`` arguments as ``W_consensus`` so it can
    be passed as ``W=`` instead; every argument is ignored.
    """
    return 0.0


def W_consensus(xj,xi,dx):
    """Return the consensus interaction weight ``-0.05 * dx / (1 + |xj - xi|**2)``.

    The weight is negative and its magnitude decays with the squared
    distance between ``xj`` and ``xi``; it is scaled by ``dx``. Works
    element-wise on arrays.
    """
    gap = np.abs(xj - xi)
    kernel = (1. + gap**2)**-1
    return -0.05 * kernel * dx


def initial_dist(z_i,mu):
    """Normalised Gaussian density centred at ``mu``, evaluated at ``z_i``.

    The experiments call it with ``mu=0`` and ``mu=mu`` to build the two
    initial distributions passed to ``set_initial_distribution`` (originally
    documented as the initial condition for the 0 labels).

    NOTE(review): ``sqrt(0.1)`` is used as the *variance* here (standard
    deviation ~0.56). If a standard deviation of ``sqrt(0.1)`` was intended,
    the formula should use its square — confirm before changing.
    """
    variance = np.sqrt(0.1)
    normaliser = np.sqrt(2 * np.pi * variance)
    return np.exp(-(z_i - mu)**2 / 2. / variance) / normaliser




####################### select experiments ###############################
# Ids of the experiments to run (previously also listed here: 2, 3, 4, 7).
experiments = [1]

################ experiment 1: sampled gradient for x, n = 4 and n = 40 #####################
# NOTE(review): this header used to read "slow x with initial condition at
# optimal location", but x_conv_rate = 1.0 is labelled "fast convergence" /
# "equal rate" below, and x0 = 1.5 being optimal is not shown here — confirm.
if 1 in experiments:
    print("Experiment 1")
    x0            = 1.5   # initial value of x handed to ODE1d
    T             = 30.   # simulated time horizon
    # round() guards against T/dt landing just below an integer (e.g. 0.3/0.1)
    nT            = int(round(T/dt))
    mu            = 3.    # centre of the second initial Gaussian
    x_conv_rate   = 1.e0  # x_speed for ODE1d (a "best response" variant was also tried)

    # Identical runs that differ only in the number of samples used for
    # computing the gradient of x; titles/filenames are kept as literals.
    runs = (
        (4,  "sampled gradient under equal rate, $n=4$",  "sampled_gradient_4"),
        (40, "sampled gradient under equal rate, $n=40$", "sampled_gradient_40"),
    )
    for nSamples, title, fname in runs:
        pde = PDE1d(dz,N,nT,H_prime_rho=H1,V=V1,W=W_consensus,save_data=False)
        pde.set_initial_distribution(lambda z: initial_dist(z,0),lambda z: initial_dist(z,mu))
        ode = ODE1d(pde.z_i,dt,nT,pde.g0,x0,nSamples,save_data=False,x_speed=x_conv_rate)
        rho = pde.rho0
        # step 0 is the initial state (pde.rho0); alternate x and rho updates
        for t in tqdm.tqdm(range(1,nT)):
            x   = ode.update_x(rho,t)
            rho = pde.update_RK(x,t,dt,ode.negative_labels)

        plot_four_windows(ode,pde,title,fname)
