#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt

# Domain and grid parameters (doubly periodic box; endpoint=False so the last
# grid point does not duplicate the first one across the periodic boundary)
Nx = 128
Nz = 128
Lx = 1.0
Lz = 2.0  # z spans [-Lz/2, Lz/2), i.e. [-1, 1)
x = np.linspace(0, Lx, Nx, endpoint=False)
# Derive the z interval from Lz (it used to be hard-coded as -1..1) so that
# z, dz and the FFT wavenumbers stay consistent if Lz is ever changed.
z = np.linspace(-Lz / 2, Lz / 2, Nz, endpoint=False)
dx = Lx / Nx
dz = Lz / Nz
X, Z = np.meshgrid(x, z)  # shape (Nz, Nx): axis 0 is z, axis 1 is x

# Physical parameters (non-dimensional: the initial velocity jump across each
# shear layer and the box width Lx are both 1, so nu = 1 / Reynolds)
Reynolds = 5e4     # Reynolds number
Schmidt = 1.0      # Schmidt number, Sc = nu / D
nu = 1 / Reynolds  # kinematic viscosity
D = nu / Schmidt   # tracer diffusivity

# Time parameters
T = 20.0    # final simulation time
dt = 0.005  # time step
# round() rather than bare int(): T/dt is a float quotient that can land just
# below an integer (e.g. 0.3/0.1 == 2.9999999999999996), in which case int()
# would silently drop the last step.
nt = int(round(T / dt))

# Initial conditions
# Streamwise velocity: two tanh shear layers of thickness 0.1 centred at
# z = +0.5 and z = -0.5, giving u ~ +0.5 for |z| > 0.5 and u ~ -0.5 for |z| < 0.5.
u = 0.5 * (1.0 + np.tanh((Z - 0.5)/0.1) - np.tanh((Z + 0.5)/0.1))
# Vertical velocity: a single-wavelength sin(2*pi*x) seed of amplitude 1e-3,
# confined by a Gaussian envelope (width 0.1) around each of the two layers.
w = 1e-3 * np.sin(2*np.pi*X) * sum(np.exp(-((Z - z0)/0.1)**2) for z0 in (0.5, -0.5))
# Passive tracer starts as an independent copy of the streamwise velocity.
s = np.copy(u)
# Pressure starts at zero everywhere.
p = np.zeros(u.shape, dtype=u.dtype)

# Angular wavenumbers for FFT-based solves on the doubly periodic grid,
# arranged as (Nz, Nx) arrays so they line up with the field arrays.
kx, kz = np.meshgrid(
    2 * np.pi * np.fft.fftfreq(Nx, d=dx),
    2 * np.pi * np.fft.fftfreq(Nz, d=dz),
)
k2 = kx**2 + kz**2
# The k = 0 entry would be zero; store 1.0 so dividing by k2 never produces
# inf/nan. Any solve that divides by k2 must zero the mean-mode coefficient itself.
k2[0, 0] = 1.0

def ddx(f):
    """Return the periodic second-order central difference of f along x (axis 1)."""
    ahead = np.roll(f, -1, axis=1)   # value at column j+1, wrapping around
    behind = np.roll(f, 1, axis=1)   # value at column j-1, wrapping around
    return (ahead - behind) / (2 * dx)

def ddz(f):
    """Return the periodic second-order central difference of f along z (axis 0)."""
    above = np.roll(f, -1, axis=0)   # value at row i+1, wrapping around
    below = np.roll(f, 1, axis=0)    # value at row i-1, wrapping around
    return (above - below) / (2 * dz)

def laplacian(f):
    """Return the periodic five-point Laplacian of f (axis 1 is x, axis 0 is z)."""
    def second_difference(axis, h):
        # Three-point second derivative along one axis with wrap-around.
        return (np.roll(f, -1, axis=axis) - 2*f + np.roll(f, 1, axis=axis)) / h**2

    return second_difference(1, dx) + second_difference(0, dz)

# Time stepping loop
#
# Space: the periodic central differences ddx / ddz / laplacian.
# Time: three-stage strong-stability-preserving Runge-Kutta (SSP-RK3, Shu-Osher
# form), with a velocity projection after every stage.
#
# Why not forward Euler: central-difference advection has purely imaginary
# eigenvalues, and forward Euler amplifies those by sqrt(1 + C**2) per step.
# With |u| ~ 0.5, dt = 0.005 and dx = 1/128 the Courant number C is ~0.32, so
# the worst modes grow by ~1.05 per step. Viscous damping at this nu is only
# ~3e-3 per step, so noise grows without bound over nt steps.
# SSP-RK3 is stable on the imaginary axis up to |lambda*dt| = sqrt(3).

# Discrete projection consistent with ddx/ddz. The Fourier symbol of ddx is
# i*sin(kx*dx)/dx, so ddx(ddx(.)) + ddz(ddz(.)) has symbol -k2_fd. The Poisson
# equation is solved with k2_fd and the *same* central-difference gradient is
# subtracted, which drives ddx(u) + ddz(w) to zero up to round-off. Mixing a
# central-difference divergence with a spectral solve and spectral gradient
# does not remove the divergence that was measured.
_kx_fd = np.sin(kx * dx) / dx
_kz_fd = np.sin(kz * dz) / dz
k2_fd = _kx_fd**2 + _kz_fd**2
# k2_fd is ~0 only on the mean and Nyquist-combination modes. The
# central-difference divergence has (to round-off) no content there, so phi is
# left at zero on those modes instead of dividing by ~0.
_solvable = k2_fd > 1e-10 * k2_fd.max()


def _project(u_star, w_star, step):
    """Remove the central-difference divergence from (u_star, w_star).

    Solves ddx(ddx(phi)) + ddz(ddz(phi)) = (ddx(u_star) + ddz(w_star)) / step
    with FFTs, then returns (u_star - step*ddx(phi), w_star - step*ddz(phi), phi).
    When `step` equals the weight carried by the stage's tendency, phi plays
    the role of the pressure for that stage, as in the plain Euler projection.
    """
    div_hat = np.fft.fftn(ddx(u_star) + ddz(w_star))
    phi_hat = np.zeros_like(div_hat)
    phi_hat[_solvable] = -div_hat[_solvable] / (step * k2_fd[_solvable])
    phi = np.real(np.fft.ifftn(phi_hat))
    return u_star - step * ddx(phi), w_star - step * ddz(phi), phi


def _tendency(f, u_adv, w_adv, kappa):
    """Return -(u_adv*ddx(f) + w_adv*ddz(f)) + kappa*laplacian(f)."""
    return -(u_adv * ddx(f) + w_adv * ddz(f)) + kappa * laplacian(f)


def _euler_substep(u_, w_, s_):
    """Advance (u, w, s) by one unprojected forward-Euler substep of length dt."""
    return (u_ + dt * _tendency(u_, u_, w_, nu),
            w_ + dt * _tendency(w_, u_, w_, nu),
            s_ + dt * _tendency(s_, u_, w_, D))


# The seeded w perturbation is not divergence-free. Project the initial state
# once so every Runge-Kutta stage starts from a solenoidal velocity field.
u, w, _ = _project(u, w, dt)

for it in range(nt):
    # Stage 1: full Euler step from t_n; the tendency carries weight dt.
    u1, w1, s1 = _euler_substep(u, w, s)
    u1, w1, _ = _project(u1, w1, dt)

    # Stage 2: 3/4 * state(t_n) + 1/4 * Euler step from stage 1 (weight dt/4).
    ue, we, se = _euler_substep(u1, w1, s1)
    u2, w2, _ = _project(0.75 * u + 0.25 * ue, 0.75 * w + 0.25 * we, 0.25 * dt)
    s2 = 0.75 * s + 0.25 * se

    # Stage 3: 1/3 * state(t_n) + 2/3 * Euler step from stage 2 (weight 2dt/3).
    # p keeps the pressure from this final projection, evaluated at the
    # stage-2 state of the last step.
    ue, we, se = _euler_substep(u2, w2, s2)
    u, w, p = _project(u / 3 + 2 / 3 * ue, w / 3 + 2 / 3 * we, 2 / 3 * dt)
    s = s / 3 + 2 / 3 * se

    # Progress report every 500 steps.
    if (it+1) % 500 == 0:
        print(f"Time step {it+1}/{nt}")

# Write each final 2D field to its own .npy file in the working directory,
# in the order u, w, p, s.
for _path, _field in (('u.npy', u), ('w.npy', w), ('p.npy', p), ('s.npy', s)):
    np.save(_path, _field)