#### Example use case: KdV soliton superposition (pseudo-spectral RK4 in JAX) ######
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # force CPU usage for JAX
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # prevent JAX from preallocating all GPU memory
import jax
import jax.numpy as jnp
from jax import lax
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# -------------------------------
# Global constants
# -------------------------------
# Number of grid points on the periodic domain.
N = 512  # alternative value left by the author: 384
PI = jnp.pi
# Time step used by every RK4 step.
dtime = 1e-5  # alternative value left by the author: 2.5e-5
# Length of the periodic domain.
systemsize = 10.0 * PI
# Coefficient multiplying the quadratic (nonlinear) term of the PDE.
nonlinparameter = 1.0
# Uniform grid on [0, systemsize); the right end point is excluded because it
# coincides with x = 0 on a periodic domain.
x = jnp.linspace(0.0, systemsize, num=N, endpoint=False)
# Mode numbers of the one-sided (real-input) FFT: 0, 1, ..., N/2.
k_indices = jnp.arange(0, N // 2 + 1)
# Physical wavenumbers: mode number times the fundamental 2*pi / systemsize.
k_vec = (2.0 * PI / systemsize) * k_indices

# -------------------------------
# FFT helpers (using JAX)
# -------------------------------
def to_fourier(r):
    """Return the one-sided Fourier coefficients of the real field ``r``.

    The real-input FFT is divided by the global grid size ``N``; ``to_real``
    applies the matching factor of ``N`` on the way back.
    """
    spectrum = jnp.fft.rfft(r)
    return spectrum / N

def to_real(c):
    """Return the real-space field (length ``N``) for one-sided coefficients ``c``.

    Inverse of ``to_fourier``: the coefficients are multiplied by ``N`` to undo
    the forward scaling before the inverse real FFT is taken.
    """
    unscaled = c * N
    # n=N pins the output length to the grid size.
    return jnp.fft.irfft(unscaled, n=N)

# -------------------------------
# PDE RHS in Fourier space
# -------------------------------
def pde_rhs(c_input):
    """Evaluate the time derivative of the Fourier coefficients ``c_input``.

    The quadratic term is formed pseudo-spectrally: the field is taken to real
    space, squared point-wise, and transformed back. The result is

        i * (k^3 * c  -  0.5 * nonlinparameter * k * FT[r^2])

    with ``k`` the global wavenumber vector ``k_vec``.
    """
    field = to_real(c_input)
    # Fourier coefficients of the squared real-space field.
    c_squared = to_fourier(field * field)

    dispersion = (k_vec ** 3) * c_input
    advection = 0.5 * nonlinparameter * k_vec * c_squared
    return 1j * (dispersion - advection)

# -------------------------------
# Single RK4 step for one soliton
# -------------------------------
@jax.jit
def rk4_single_step(c):
    """Advance the Fourier coefficients ``c`` by one classical RK4 step of size ``dtime``."""
    half_dt = 0.5 * dtime

    # Four slope evaluations: start, two midpoint estimates, end of the step.
    slope_start = pde_rhs(c)
    slope_mid_a = pde_rhs(c + half_dt * slope_start)
    slope_mid_b = pde_rhs(c + half_dt * slope_mid_a)
    slope_end = pde_rhs(c + dtime * slope_mid_b)

    # Weighted average of the slopes with weights 1, 2, 2, 1 (divided by 6).
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return c + (dtime / 6.0) * weighted

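# Not jitted on its own: a jit wrapper would need nsteps marked static
# (e.g. static_argnums=(1,)) because it is used as the scan length.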
def rk4_step(c, nsteps):
    """Apply ``rk4_single_step`` to ``c`` ``nsteps`` times and return the final state.

    The repetition is expressed with ``lax.scan`` so the loop is not unrolled
    in Python. ``nsteps`` is passed as the scan length.
    """
    def advance_once(state, _unused):
        # No per-step output is collected; only the carried state matters.
        return rk4_single_step(state), None

    final_state, _ = lax.scan(advance_once, c, xs=None, length=nsteps)
    return final_state

# -------------------------------
# Soliton initializer
# -------------------------------
def initialize_soliton(velocity, position):
    """Return a sech^2 soliton profile sampled on the global grid ``x``.

    The profile has amplitude ``3 * velocity / nonlinparameter`` and inverse
    width ``0.5 * sqrt(|velocity|)``. Besides the pulse centred at
    ``position``, copies shifted by ``+systemsize`` and ``-systemsize`` are
    added so the profile wraps around the periodic domain.
    """
    offset = x - position
    inv_width = 0.5 * jnp.sqrt(jnp.abs(velocity))
    amplitude = 3.0 * velocity / nonlinparameter

    def sech_squared(arg):
        return 1.0 / jnp.cosh(arg)**2

    # Main pulse first, then the two neighbouring periodic images.
    pulses = sum(
        sech_squared(inv_width * (offset - shift))
        for shift in (0.0, systemsize, -systemsize)
    )
    return amplitude * pulses

# -------------------------------
# Main simulation
# -------------------------------
def run_simulation():
    """Evolve several solitons and their superposition, then animate the result.

    Each soliton is integrated on its own, and the sum of the initial profiles
    is integrated as one additional field. All fields are sampled every
    ``steps_per_frame`` RK4 steps and shown in a matplotlib animation: the
    individual solitons as coloured lines, the superposed field in black.

    Returns:
        The ``FuncAnimation`` object. The caller must keep a reference to it
        if it wants to save it. ``plt.show()`` is called before returning.
    """
    velocities = jnp.array([50.0, 20.0, 4.0, 0.5])
    # NOTE(review): the domain is [0, systemsize) = [0, 10*pi), so 12*pi and
    # 18*pi only enter through the periodic images added by
    # initialize_soliton (at 2*pi and 8*pi). The third soliton therefore
    # starts on top of the first one -- confirm this is intended.
    positions = jnp.array([2.0 * PI, 6.0 * PI, 12.0 * PI, 18.0 * PI])
    n_solitons = velocities.shape[0]

    # Real-space initial conditions: one row per soliton, plus their sum.
    r_components = jax.vmap(initialize_soliton)(velocities, positions)
    r_total = jnp.sum(r_components, axis=0)

    c_components = jax.vmap(to_fourier)(r_components)
    c_total = to_fourier(r_total)

    n_frames = 600
    steps_per_frame = 200

    def advance(c):
        # steps_per_frame is a Python int captured by closure, so it is a
        # valid static scan length inside rk4_step.
        return rk4_step(c, steps_per_frame)

    def frame_step(state, _):
        c_arr, c_tot = state
        # Record the fields BEFORE advancing, so stored frame i is the state
        # at t = i * steps_per_frame * dtime (frame 0 is the initial
        # condition). This keeps the frames in sync with the time label.
        snapshot = (jax.vmap(to_real)(c_arr), to_real(c_tot))
        # vmap only over the soliton axis; k_vec is shared by all of them.
        c_new = jax.vmap(advance)(c_arr)
        c_total_new = advance(c_tot)
        return (c_new, c_total_new), snapshot

    _, (frames_components, frames_total) = lax.scan(
        frame_step, (c_components, c_total), None, length=n_frames
    )

    # frames_components shape: (n_frames, n_solitons, N)
    # frames_total shape: (n_frames, N)
    # Move to host for animation (JAX -> numpy).
    frames_components_np = jax.device_get(frames_components)
    frames_total_np = jax.device_get(frames_total)
    x_np = jax.device_get(x)

    # --- Matplotlib animation ---
    fig, ax = plt.subplots(figsize=(10, 5))
    colors = ["tab:red", "tab:orange", "tab:green", "tab:blue"]
    lines_individual = [
        ax.plot(
            x_np,
            frames_components_np[0, i],
            color=colors[i % len(colors)],
            lw=1.2,
            alpha=0.8,
            label=f"Soliton {i+1}",
        )[0]
        for i in range(n_solitons)
    ]
    line_total, = ax.plot(x_np, frames_total_np[0], color="k", lw=2.0, label="Total field")

    ax.legend(loc="upper right", fontsize=9)
    ax.set_xlim(0, systemsize)
    ax.set_ylim(-1, 200)
    # The time label is placed INSIDE the axes: with blit=True only the axes
    # bounding box is redrawn, so a label above the axes would not refresh
    # in the interactive window.
    title = ax.text(0.5, 0.95, "", transform=ax.transAxes, ha="center", va="top")

    def init():
        for line, profile in zip(lines_individual, frames_components_np[0]):
            line.set_ydata(profile)
        line_total.set_ydata(frames_total_np[0])
        title.set_text("t = 0.0")
        return lines_individual + [line_total, title]

    def update(i):
        for line, profile in zip(lines_individual, frames_components_np[i]):
            line.set_ydata(profile)
        line_total.set_ydata(frames_total_np[i])
        title.set_text(f"t = {i * steps_per_frame * dtime:.2f}")
        return lines_individual + [line_total, title]

    anim = FuncAnimation(fig, update, frames=n_frames, init_func=init, blit=True, interval=40)
    plt.show()
    return anim

if __name__ == "__main__":
    # run_simulation() shows the interactive window first and then returns the
    # animation object, which is encoded to an MP4 file here.
    # NOTE(review): the output path is at the filesystem root ("/..."), which
    # is usually not writable for normal users -- confirm this is intended.
    movie = run_simulation()
    movie.save("/kdv_soliton_superposition_jax.mp4", dpi=150, fps=25)



