import numpy as np
from PDE1d import PDE1d
from ODE1d_two_moving_distributions import ODE1d
from simulation_plots import make_plots_both_moving
import tqdm
from scipy.special import xlogy

# use for all simulations
# dz and N are passed as the first two positional arguments of every PDE1d
# solver built in this file; dt is used to derive the number of time steps
# (nT = int(T/dt)) and is passed to ODE1d and to each PDE1d.update_RK call.
# NOTE(review): presumably dz is the grid spacing in z, N a grid size/extent
# parameter and dt the time step -- confirm against PDE1d / ODE1d.
dz = 0.1
N  = 60
dt = 0.01


def f1(x,z):
    '''Cost function f1 evaluated at x and z (scalars or broadcastable arrays).

    Computes 1 - s, where s = 1/(1 + exp(-3*(z-x))) is a logistic curve in
    (z-x) that is first rounded to 16 decimal places.
    '''
    logistic = (1. + np.exp(-3.*(z - x)))**-1
    rounded_logistic = np.round(logistic, decimals=16)
    return 1. - rounded_logistic


def grad_f1_x(x,z):
    '''Gradient of f1 with respect to x, evaluated at inputs x and z.

    With u = exp(-3*(z-x)), f1 = 1 - 1/(1+u) and so df1/dx = 3*u/(1+u)**2.
    x and z can be scalars or broadcastable arrays; the result has the
    broadcast shape.
    '''
    # u/(1+u)**2 is unchanged under u -> 1/u, so evaluate it with the
    # exponential that is always <= 1. Using exp(-3*(z-x)) directly overflows
    # to inf when x is far above z, and inf/inf**2 then yields NaN instead of
    # the correct limit 0.
    decay_ = np.exp(-3.*np.abs(z-x))
    return 3*decay_ / (decay_+1)**2


def f2(x,z):
    '''Cost function f2 evaluated at x and z (scalars or broadcastable arrays).

    Returns the logistic curve 1/(1 + exp(-3*(z-x))), rounded to 16 decimal
    places.
    '''
    logistic = (1. + np.exp(-3.*(z - x)))**-1
    return np.round(logistic, decimals=16)


def grad_f2_x(x,z):
    '''Gradient of f2 with respect to x, evaluated at inputs x and z.

    With u = exp(-3*(z-x)), f2 = 1/(1+u) and so df2/dx = -3*u/(1+u)**2.
    x and z can be scalars or broadcastable arrays; the result has the
    broadcast shape.
    '''
    # u/(1+u)**2 is unchanged under u -> 1/u, so evaluate it with the
    # exponential that is always <= 1. Using exp(-3*(z-x)) directly overflows
    # to inf when x is far above z, and inf/inf**2 then yields NaN instead of
    # the correct limit 0.
    decay_ = np.exp(-3.*np.abs(z-x))
    return -3*decay_ / (decay_+1)**2


def H1(rho_east_west,rho_bar,rho_tilde):
    '''KL-divergence-style term, scaled by 1/10.

    Evaluates rho_east_west*(log(rho_bar) - log(rho_tilde))/10 using xlogy,
    so each product is 0 wherever rho_east_west is 0.
    Per the original author's note, rho_tilde is the initial condition and
    rho_bar is the current distribution.
    '''
    current_term = xlogy(rho_east_west, rho_bar)
    reference_term = xlogy(rho_east_west, rho_tilde)
    return (current_term - reference_term) / 10.


def V1(x,z):
    '''Potential term handed to PDE1d as V; forwards directly to f1(x, z).'''
    potential = f1(x, z)
    return potential


def W_no_kernel(xj,xi,dx):
    '''Kernel that is identically zero: returns 0. for every input.

    All three arguments are accepted but ignored.
    '''
    no_interaction = 0.
    return no_interaction

def normal_dist(z_i,mu):
    '''Gaussian density centred at mu, evaluated at z_i (scalar or array).

    Evaluates exp(-(z_i-mu)**2 / (2*s)) / sqrt(2*pi*s) with s = sqrt(0.1),
    i.e. a unit-area normal density whose variance is s.
    NOTE(review): the original called s "sig"; if a standard deviation of
    sqrt(0.1) was intended the formula would need s**2 -- confirm.
    '''
    spread = np.sqrt(0.1)
    squared_offset = (z_i - mu)**2
    unnormalised = np.exp(-squared_offset / 2. / spread)
    norm_const = np.sqrt(2 * np.pi * spread)
    return unnormalised / norm_const


####################### select experiments ###############################
# Ids of the experiments to run; the `if 1 in experiments:` section below is
# executed only when 1 is listed here.
experiments = [1]

################ experiment 1: two distributions, no kernel, optimal IC for x #####################
# Couples two PDE1d solvers (pde0, pde1) with one ODE1d solver for x and steps
# them together, then plots the result.
if 1 in experiments:
    print("Experiment 1: two distributions, no kernel, optimal IC for x")
    # x0 is handed to ODE1d (presumably the initial value of x -- confirm there).
    x0            = 1.5
    # T and the module-level dt fix the step count nT.
    # NOTE(review): int() truncates, so a T/dt that lands just below an integer
    # in floating point would drop one step; with T=20., dt=0.01 it gives 2000.
    T             = 20.
    nT            = int(T/dt)
    # Centre of pde1's initial Gaussian; pde0's is centred at 0.
    mu            = 3.
    # Forwarded to ODE1d as x_speed.
    x_conv_rate   = "best response" # slow convergence

    # Both solvers are built with identical arguments (H1, V1 and the zero
    # kernel); they differ only in their initial distribution.
    pde0 = PDE1d(dz,N,nT,H_prime_rho=H1,V=V1,W=W_no_kernel,save_data=False)
    pde1 = PDE1d(dz,N,nT,H_prime_rho=H1,V=V1,W=W_no_kernel,save_data=False)
    pde0.set_initial_distribution(lambda z: normal_dist(z,0))
    pde1.set_initial_distribution(lambda z: normal_dist(z,mu))
    # The ODE solver receives pde0's z_i attribute and both cost functions with
    # their x-gradients.
    ode = ODE1d(pde0.z_i,dt,nT,x0,f1,f2,grad_f1_x,grad_f2_x,save_data=False,x_speed=x_conv_rate)
    # Start from each solver's rho0 attribute.
    rho0 = pde0.rho0
    rho1 = pde1.rho0
    # Steps t = 1 .. nT-1 with a tqdm progress bar. Order matters: x is updated
    # first from the current rho0/rho1, then both distributions are advanced
    # using that new x.
    for t in tqdm.tqdm(range(1,nT)):
        x    = ode.update_x(rho0,rho1,t)
        rho0 = pde0.update_RK(x,t,dt)
        rho1 = pde1.update_RK(x,t,dt)

    # Plot from the three solver objects; GIF generation is switched off.
    # NOTE(review): the string looks like an output path prefix -- confirm in
    # simulation_plots.
    make_plots_both_moving(pde0,pde1,ode,"plots/experiment1_both_moving",make_gif=False)



