import jax.numpy as np

def _project_out(field, direction):
    """Return ``field`` with its component along ``direction`` removed.

    ``direction`` may have size-1 axes (e.g. shape ``(..., 1)``). It is
    broadcast against ``field`` *before* its squared norm is taken, so the
    result is orthogonal to the broadcast direction under the plain
    element-wise-sum inner product. Divides by zero if ``direction`` is
    identically zero.
    """
    direction_full = np.broadcast_to(
        direction, np.broadcast_shapes(np.shape(field), np.shape(direction))
    )
    coeff = np.sum(field * direction_full) / np.sum(direction_full**2)
    return field - coeff * direction_full


def time_derivative_euler(
    zeta,
    t,
    f_poisson_bracket,
    f_phi,
    denominator,
    f_forcing=None,
    f_diffusion=None,
    energyconserving=False,
    REDUC=0.0,
):
    """Compute the time derivative of ``zeta``.

    The result is ``(PB + forcing + diffusion) / denominator``. ``PB`` is
    ``f_poisson_bracket(zeta, f_phi(zeta, t))``, optionally corrected when
    ``energyconserving`` is true. ``denominator`` is indexed as
    ``denominator[None, None, :]``, i.e. broadcast along the last axis.
    Presumably the fields are 3-D with the last axis matching
    ``denominator`` — TODO confirm against callers.

    Args:
        zeta: Current state field.
        t: Time, forwarded to ``f_phi``.
        f_poisson_bracket: Callable ``(zeta, H) -> array``.
        f_phi: Callable ``(zeta, t) -> H``.
        denominator: 1-D array dividing every term along the last axis.
        f_forcing: Optional callable ``zeta -> array``; omitted (0.0) if None.
        f_diffusion: Optional callable ``zeta -> array``; omitted (0.0) if None.
        energyconserving: If true, replace the bracket term by a corrected
            version (see below). Used as a Python bool, so it must not be a
            traced JAX value.
        REDUC: Fraction by which the correction reduces
            ``sum(zeta * bracket_term)``; 0.0 keeps it unchanged.

    Energy-conserving correction:
        ``psi_bar`` is the last-axis mean of ``H``, with its own mean
        removed. The bracket ``M`` and the de-meaned state ``U`` both have
        their components along the broadcast ``psi_bar`` removed, giving
        ``P`` and ``W``. The returned bracket term is then:

        * orthogonal to the broadcast ``psi_bar``;
        * zero-sum whenever ``M`` sums to zero;
        * satisfies ``sum(term * zeta) == (1 - REDUC) * sum(M * zeta)``
          under that same zero-sum condition.

        A degenerate state (``psi_bar``, ``W`` or ``sum(W * P)`` identically
        zero) leads to division by zero.

    Returns:
        Array with the broadcast shape of the bracket term and the
        ``denominator``-scaled forcing/diffusion terms.
    """
    denom = denominator[None, None, :]

    H = f_phi(zeta, t)
    forcing_term = f_forcing(zeta) if f_forcing is not None else 0.0
    diffusion_term = f_diffusion(zeta) if f_diffusion is not None else 0.0

    pb_term = f_poisson_bracket(zeta, H) / denom

    if energyconserving:
        # The bracket is assumed to already conserve mass (sum to zero).
        M = pb_term
        U = zeta - np.mean(zeta)
        psi_bar = np.mean(H, axis=-1)[..., None]
        psi_bar = psi_bar - np.mean(psi_bar)
        # The projections must normalise by the norm of the BROADCAST
        # psi_bar; otherwise P and W are not orthogonal to it.
        P = _project_out(M, psi_bar)
        W = _project_out(U, psi_bar)
        # An alternative formulation would use dSdt = -sum(W * P) instead.
        dSdt = -np.sum(pb_term * zeta)
        wp = np.sum(W * P)
        # Remove P's component along W, then add back a P-direction term
        # sized so that sum(W * result) == -(1 - REDUC) * dSdt.
        pb_term = P - wp / np.sum(W**2) * W - (1.0 - REDUC) * dSdt * P / wp

    return pb_term + (forcing_term + diffusion_term) / denom