import equinox as eqx
import jax
import jax.numpy as jnp

class NeuralNetwork(eqx.Module):
    """Three-layer fully connected network with ReLU hidden activations.

    The model is a stack of three ``eqx.nn.Linear`` layers:
    ``input_dim -> hidden_dim -> hidden_dim -> output_dim``. ReLU is applied
    after every layer except the last, whose output is returned unchanged.
    """

    # Linear layers in application order; the last one is the output layer.
    layers: list

    def __init__(self, key, input_dim, output_dim, hidden_dim=16):
        """Build the three linear layers.

        Args:
            key: JAX PRNG key; split into one independent key per layer.
            input_dim: Input feature size of the first layer.
            output_dim: Output feature size of the last layer.
            hidden_dim: Width of the two hidden representations. Defaults to
                16, the value that was previously hard-coded.
        """
        key_in, key_hidden, key_out = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(input_dim, hidden_dim, key=key_in),
            eqx.nn.Linear(hidden_dim, hidden_dim, key=key_hidden),
            eqx.nn.Linear(hidden_dim, output_dim, key=key_out),
        ]

    def __call__(self, x):
        """Apply the network to ``x`` and return the last layer's output."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = jax.nn.relu(layer(x))
        # No activation on the output layer.
        return output_layer(x)

def pairwise_l2_distances(x, y):
    """Return the matrix of Euclidean distances between rows of ``x`` and ``y``.

    Args:
        x: Array indexed as ``(n, d)``.
        y: Array indexed as ``(m, d)``.

    Returns:
        Array of shape ``(n, m)`` whose ``[i, j]`` entry is
        ``||x[i] - y[j]||_2``.

    The square root is masked at exactly-zero squared distances so that the
    function stays differentiable there: the derivative of ``sqrt`` at 0 is
    infinite, and autodiff would otherwise produce NaN gradients whenever two
    rows coincide (always the case on the diagonal when ``x is y``). Forward
    values are unchanged; the gradient through a zero distance is 0.
    """
    diff = x[:, None, :] - y[None, :, :]
    sq_dist = jnp.sum(diff * diff, axis=-1)

    # "Double where" trick: feed sqrt a harmless value at the masked entries
    # so neither its output nor its gradient is evaluated at 0. Using
    # ``== 0`` (not ``<= 0``/``> 0``) keeps NaN inputs propagating as before.
    is_zero = sq_dist == 0
    safe_sq_dist = jnp.where(is_zero, 1.0, sq_dist)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq_dist))

def energy_distance(x, y):
    """Return the energy distance between two samples.

    Computes ``sqrt(max(2 * mean(d(x, y)) - mean(d(x, x)) - mean(d(y, y)), 0))``
    where ``d`` is the pairwise Euclidean distance matrix returned by
    ``pairwise_l2_distances``. The means are taken over the full matrices,
    so the zero self-distances on the diagonals of ``d(x, x)`` and
    ``d(y, y)`` are included.

    Args:
        x: First sample, passed to ``pairwise_l2_distances`` (rows are points).
        y: Second sample, same trailing dimension as ``x``.

    Returns:
        Scalar array: the non-negative energy distance.

    The final square root is masked where the squared value is clamped to
    zero. Forward values are identical to ``sqrt(maximum(v, 0))``, but the
    gradient there is 0 instead of the NaN/inf that autodiff produces from
    the infinite derivative of ``sqrt`` at 0.
    """
    cross = pairwise_l2_distances(x, y)
    within_x = pairwise_l2_distances(x, x)
    within_y = pairwise_l2_distances(y, y)

    ed_squared = 2.0 * jnp.mean(cross) - jnp.mean(within_x) - jnp.mean(within_y)

    # Floating-point rounding can push the estimate slightly below zero;
    # treat every non-positive value as exactly zero. ``<=`` is False for
    # NaN, so a NaN input still propagates through the sqrt branch.
    is_clamped = ed_squared <= 0.0
    safe_ed_squared = jnp.where(is_clamped, 1.0, ed_squared)
    return jnp.where(is_clamped, 0.0, jnp.sqrt(safe_ed_squared))

def gen_simulation_samples(
    key,
    theta,
    prop_sim_fn,
    simulator_fn,
    n_prop,
    n_sim_dst
):
    """Draw proposal parameters and run the simulator once per proposal.

    Args:
        key: JAX PRNG key. It is split three ways: a key handed back to the
            caller, one for the proposal draw and one for the simulations.
        theta: Value forwarded unchanged to ``prop_sim_fn``.
        prop_sim_fn: Callable invoked as ``prop_sim_fn(key, theta, n_prop)``.
            Its result is mapped over axis 0, so it presumably has ``n_prop``
            leading entries -- confirm against the proposal implementation.
        simulator_fn: Callable invoked (under ``jax.vmap``) as
            ``simulator_fn(key, theta_i, n_sim_dst)``.
        n_prop: Number of per-simulation keys to create.
        n_sim_dst: Passed unbatched to every ``simulator_fn`` call.

    Returns:
        Tuple ``(thetas_q, sims_q, key)``: the proposal output, the stacked
        simulator outputs (batched along axis 0), and the fresh PRNG key.
    """
    next_key, proposal_key, simulation_key = jax.random.split(key, 3)

    proposals = prop_sim_fn(proposal_key, theta, n_prop)

    # One independent key per proposal so the simulations do not share
    # randomness.
    per_proposal_keys = jax.random.split(simulation_key, n_prop)

    # Batch over keys and proposals; n_sim_dst is shared across the batch.
    batched_simulator = jax.vmap(
        simulator_fn, in_axes=(0, 0, None), out_axes=0
    )
    simulations = batched_simulator(per_proposal_keys, proposals, n_sim_dst)

    return proposals, simulations, next_key


def grad_log_shifted_exp(params, obs):
    """Return a one-element gradient that is ``n`` or 0 depending on ``theta``.

    With ``theta = params[0]`` and ``n = obs.shape[0]``, the result is ``n``
    when ``theta <= min(obs)`` and 0 otherwise.
    NOTE(review): by its name this is presumably the gradient of a shifted
    exponential log-likelihood with respect to the shift -- confirm the
    intended model.

    Args:
        params: Indexable whose first element is ``theta``.
        obs: Array of observations; its first-axis length is used as ``n``
            and its dtype as the output dtype.

    Returns:
        Array built as ``[gradient]`` with dtype ``obs.dtype``.
    """
    sample_count = obs.shape[0]
    smallest_obs = jnp.min(obs)
    location = params[0]

    # The gradient is non-zero only while theta does not exceed the
    # smallest observation.
    within_bound = location <= smallest_obs
    score = jnp.where(within_bound, sample_count, 0)

    return jnp.array([score], dtype=obs.dtype)

def grad_log_normal(
    x,
    mu,
    cov
):
    """Return ``-(x - mu) / cov`` or ``-inv(cov) @ (x - mu)``.

    This is the gradient with respect to ``x`` of a Gaussian log-density
    with mean ``mu`` and (co)variance ``cov``.

    Args:
        x: Point at which the gradient is evaluated.
        mu: Mean, broadcast-compatible with ``x``.
        cov: Either a scalar (Python/NumPy scalar or 0-d JAX array), used as
            a divisor, or an array that ``jnp.linalg.inv`` can invert.

    Returns:
        Array with the gradient.
    """
    residual = x - mu

    cov_is_scalar = jnp.isscalar(cov) or (
        isinstance(cov, jnp.ndarray) and cov.ndim == 0
    )
    if cov_is_scalar:
        # Scalar variance: plain elementwise division.
        return (-residual) / cov

    # Otherwise apply the explicit inverse of cov to the residual.
    precision = jnp.linalg.inv(cov)
    return (-precision) @ residual