from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

def generate_dataset(N=100, seed=0):
    """Draw ``N`` synthetic comparison pairs with a fixed-seed RNG.

    Both members of every pair are sampled uniformly from [-2, 2). The first
    member is marked as preferred when it lies strictly closer to zero than
    the second.

    Args:
        N: number of pairs to draw.
        seed: seed for ``np.random.RandomState``; equal seeds give equal data.

    Returns:
        Tuple ``(y_vals, yp_vals, I_vals)`` of length-``N`` float arrays, where
        ``I_vals[i]`` is 1.0 if ``|y_vals[i]| < |yp_vals[i]|`` and 0.0 otherwise.
    """
    gen = np.random.RandomState(seed)
    # Draw order matters for reproducibility: first members, then second members.
    first = gen.uniform(-2, 2, size=N)
    second = gen.uniform(-2, 2, size=N)
    first_wins = np.less(np.abs(first), np.abs(second))
    return first, second, first_wins.astype(float)

# Fixed-seed synthetic comparison pairs used for training.
y_data, yp_data, I_data = generate_dataset(N=10_000, seed=0)

def log_ratio(y, mu, log_sigma):
    """Return log(N(y; mu, sigma**2) / N(y; 0, 2)) elementwise.

    This is the log density ratio of the learned Gaussian against the
    reference Gaussian with mean 0 and variance 2.

    Args:
        y: evaluation point(s); scalars and arrays broadcast as usual in jnp.
        mu: mean of the learned Gaussian.
        log_sigma: log standard deviation of the learned Gaussian. It is
            floored at -5.0, so sigma >= exp(-5) and 1 / sigma**2 stays bounded.

    Returns:
        The log ratio, with the broadcast shape of the inputs.
    """
    floored = jnp.maximum(log_sigma, -5.0)
    sigma = jnp.exp(floored)
    # The constant 0.5*log(2) and +y**2/4 come from the reference N(0, 2) density.
    normaliser = 0.5 * jnp.log(2.0) - jnp.log(sigma)
    learned_quad = ((y - mu) ** 2) / (2.0 * sigma ** 2)
    reference_quad = (y ** 2) / 4.0
    return normaliser - learned_quad + reference_quad

def ratio(y, mu, log_sigma):
    """Return the density ratio N(y; mu, sigma**2) / N(y; 0, 2).

    Computed as the exponential of ``log_ratio``, so it applies the same
    floor on ``log_sigma``.
    """
    log_weight = log_ratio(y, mu, log_sigma)
    return jnp.exp(log_weight)

def objective(params, y_data, yp_data, I_data, tau=0.01):
    """Regularised preference objective for ``params = (mu, log_sigma)``.

    For every sample, the learned-to-reference density ratio ``r`` is weighted
    by whether that sample was preferred: ``I`` for ``y_data`` and ``1 - I``
    for ``yp_data``. The penalty ``tau * r * log(r)`` is then subtracted.

    Args:
        params: length-2 sequence ``(mu, log_sigma)``.
        y_data, yp_data: the two members of each comparison pair.
        I_data: per-pair weights for ``y_data`` (``1 - I_data`` goes to ``yp_data``).
        tau: weight of the ``r * log(r)`` penalty.

    Returns:
        Scalar sum of the per-sample terms over both members of every pair.
    """
    mu, log_sigma = params
    indicator = jnp.array(I_data)

    def _penalised_score(samples, weight):
        # Compute the log-ratio once and reuse it for both r and r*log(r).
        log_w = log_ratio(samples, mu, log_sigma)
        w = jnp.exp(log_w)
        return w * weight - tau * w * log_w

    per_pair = (_penalised_score(y_data, indicator)
                + _penalised_score(yp_data, 1 - indicator))
    return jnp.sum(per_pair)

def train_adam(y_data, yp_data, I_data, tau=0.05, lr=1e-2, num_steps=3000,
               log_every=10000):
    """Fit ``[mu, log_sigma]`` by maximising ``objective`` with Adam.

    Args:
        y_data, yp_data: the two members of each comparison pair.
        I_data: per-pair preference weights, forwarded to ``objective``.
        tau: penalty weight forwarded to ``objective``.
        lr: Adam learning rate.
        num_steps: number of optimisation steps.
        log_every: print progress every ``log_every`` steps. The final step
            is always printed as well. Must be >= 1.

    Returns:
        Length-2 jax array ``[mu, log_sigma]`` after ``num_steps`` updates.

    Raises:
        ValueError: if ``log_every`` < 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    # NOTE(review): log_sigma starts at log(2), i.e. sigma=2 (variance 4).
    # The reference N(0, var=2) would be log_sigma = 0.5*log(2); confirm the
    # intended starting point.
    params = jnp.array([0.0, jnp.log(2.0)])
    opt = optax.adam(learning_rate=lr)
    opt_state = opt.init(params)

    def loss_fn(params):
        # Adam minimises, so minimise the negated objective.
        return -objective(params, y_data, yp_data, I_data, tau)

    @jax.jit
    def step_fn(params, opt_state):
        val, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, val

    for step in range(num_steps):
        # `val` is the loss at the params *entering* this step. Log those
        # params with it, not the freshly updated ones.
        prev_params = params
        params, opt_state, val = step_fn(params, opt_state)
        if step % log_every == 0 or step == num_steps - 1:
            obj_val = -val
            print(f"Step {step}, objective={obj_val:.3f}, "
                  f"mu={prev_params[0]:.3f}, log_sigma={prev_params[1]:.3f}")

    return params

final_params = train_adam(y_data, yp_data, I_data, tau=0.05, lr=1e-2, num_steps=15000)

mu_opt, log_sigma_opt = (float(p) for p in final_params)
# log_ratio floors log_sigma at -5.0, so the distribution the objective
# actually scored has sigma = exp(max(log_sigma, -5.0)). Apply the same floor
# here so the plotted density matches the trained one.
sigma_opt = float(np.exp(max(log_sigma_opt, -5.0)))


def pdf_ref(y):
    """Return the density of the reference Gaussian N(0, variance=2) at ``y``.

    Works elementwise on numpy arrays as well as on scalars.
    """
    variance = 2.0
    normaliser = 1.0 / np.sqrt(2.0 * np.pi * variance)
    return normaliser * np.exp(-y**2 / (2.0 * variance))

def pdf_new(y, mu, sigma):
    """Return the density of N(mu, sigma**2) at ``y`` (elementwise for arrays)."""
    coeff = 1.0 / (np.sqrt(2.0 * np.pi) * sigma)
    exponent = -(y - mu)**2 / (2.0 * sigma**2)
    return coeff * np.exp(exponent)


y_grid = np.linspace(-10, 10, 400)

# Both densities are plain numpy expressions, so evaluate them on the whole
# grid at once instead of point by point. Converting the fitted parameters to
# Python floats keeps the computation in numpy float64.
mu_plot = float(mu_opt)
sigma_plot = float(sigma_opt)
pdf_values_ref = pdf_ref(y_grid)
pdf_values_new = pdf_new(y_grid, mu_plot, sigma_plot)

plt.plot(y_grid, pdf_values_ref, label="Reference: N(0, var=2)")
plt.plot(y_grid, pdf_values_new, label=f"New Policy: N({mu_plot:.2f}, {sigma_plot**2:.2f})")

plt.title("Comparison of Reference vs Learned Distribution")
plt.legend()

output_path = Path("./images/first-test.png")
# savefig does not create missing directories. Without this, a fresh checkout
# with no ./images folder fails with FileNotFoundError.
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path)
