#!/usr/bin/env python3
import numpy as np

# Domain and simulation parameters
Lx = 4.0       # domain length in x (periodic direction)
Lz = 1.0       # domain height in z (bounded by walls at z = 0 and z = Lz)
Nx = 81        # number of grid points in x-direction
Nz = 21        # number of grid points in z-direction
# The x-direction is periodic and the stencils wrap around with np.roll, so
# x = Lx is the same physical point as x = 0 and must not be stored twice.
# Nx unique points therefore tile the period with spacing Lx / Nx.
dx = Lx / Nx
# The z-direction includes both walls as grid rows, hence Nz - 1 intervals.
dz = Lz / (Nz - 1)

T = 50.0       # final time
dt = 0.01      # time step
# round() guards against T/dt landing just below an integer in floating
# point, which a bare int() would truncate to one step too few.
nt = int(round(T / dt))

# Physical parameters
Ra = 2e6       # presumably the Rayleigh number
Pr = 1.0       # presumably the Prandtl number
nu = 1.0 / np.sqrt(Ra / Pr)       # kinematic viscosity
kappa = 1.0 / np.sqrt(Ra * Pr)      # thermal diffusivity

# Create grid (x excludes the duplicate periodic end point; z includes both walls)
x = np.linspace(0, Lx, Nx, endpoint=False)
z = np.linspace(0, Lz, Nz)

# Initialize variables (2D arrays with shape (Nz, Nx); row j <-> z[j], column i <-> x[i])
u = np.zeros((Nz, Nx))    # horizontal velocity
w = np.zeros((Nz, Nx))    # vertical velocity
p = np.zeros((Nz, Nx))    # pressure

# Initial conditions
# u and w are initially zero.
# b = Lz - z + a small random perturbation, i.e. b = Lz on the row z = 0 and
# b = 0 on the row z = Lz before the perturbation is added.
# Drawing the whole (Nz, Nx) block at once consumes the seeded stream in the
# same row-major order as drawing Nx values per row.
np.random.seed(42)
b = (Lz - z)[:, np.newaxis] + 1e-3 * (np.random.rand(Nz, Nx) - 0.5)    # buoyancy (temperature deviation)

def laplacian(f):
    """Return the second-order finite-difference Laplacian of a 2D field.

    Axis 1 (x) is treated as periodic: neighbours are obtained with np.roll,
    so the first and last columns are adjacent. Axis 0 (z) is not periodic:
    the z second difference is added on interior rows only, which leaves the
    first and last rows holding just the x contribution.

    Uses the module-level grid spacings dx and dz.

    Parameters:
        f: 2D array indexed as [z, x].

    Returns:
        A new array with the same shape as f; f is not modified.
    """
    east = np.roll(f, -1, axis=1)
    west = np.roll(f, 1, axis=1)
    result = (east - 2*f + west) / dx**2

    # Second difference in z from the rows directly above and below.
    d2_dz2 = (f[2:, :] - 2*f[1:-1, :] + f[:-2, :]) / dz**2
    result[1:-1, :] += d2_dz2
    return result

def poisson_solver(rhs, tol=1e-6, max_iter=1000):
    """Solve the discrete Poisson equation lap(phi) = rhs by Jacobi iteration.

    The five-point stencil is periodic along axis 1 (x) and uses homogeneous
    Dirichlet conditions along axis 0 (z): phi = 0 on the first and last rows.
    The grid is uniform; the module-level spacings dx and dz are used.

    Parameters:
        rhs: 2D array indexed as [z, x]; only its interior rows are read.
        tol: iteration stops once the element-wise L2 norm of the change
            between two successive iterates drops below this value.
        max_iter: upper bound on the number of Jacobi sweeps.

    Returns:
        A floating-point array with the same shape as rhs. If max_iter is
        reached first, the last iterate is returned without any error.
    """
    # Force a float iterate: inheriting an integer dtype from rhs would
    # truncate every sweep on assignment.
    phi = np.zeros_like(rhs, dtype=float)
    # Squared spacings of the uniform grid in x and z
    dx2 = dx**2
    dz2 = dz**2
    coeff = 2*(dx2 + dz2)
    for _ in range(max_iter):
        phi_old = phi.copy()
        # Jacobi sweep over interior rows (j=1 to Nz-2, all i with periodic
        # wrap in x); every neighbour value comes from the previous iterate.
        phi[1:-1, :] = (
            dz2 * (np.roll(phi_old, -1, axis=1)[1:-1, :] + np.roll(phi_old, 1, axis=1)[1:-1, :]) +
            dx2 * (phi_old[2:, :] + phi_old[:-2, :]) -
            dx2 * dz2 * rhs[1:-1, :]
        ) / coeff
        # Enforce Dirichlet conditions in z
        phi[0, :] = 0.0
        phi[-1, :] = 0.0
        # Check convergence on the flattened update. Passing the 2D array
        # with ord=2 would instead compute the matrix spectral norm, i.e. a
        # singular value decomposition on every sweep.
        if np.linalg.norm((phi - phi_old).ravel()) < tol:
            break
    return phi

# Main time-stepping loop
#
# Each step advances (u, w, b) by dt with explicit Euler in three stages:
#   1. predict velocities from convection, diffusion and buoyancy,
#   2. project them with the gradient of phi from the pressure Poisson solve,
#   3. advect and diffuse the buoyancy with the velocities from the start of
#      the step.
# Arrays are indexed [z, x]; row 0 is the wall at z = 0 and row -1 the wall
# at z = Lz.


def _ddx(f):
    """Central x-derivative of f on every row, periodic wrap in x."""
    return (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2*dx)


def _ddz_interior(f):
    """Central z-derivative of f on interior rows (j=1 to Nz-2), no wrap in z."""
    return (f[2:, :] - f[:-2, :]) / (2*dz)


for n in range(nt):
    # --- Step 1: Compute intermediate velocities (u*, w*) using explicit Euler for convection and diffusion ---
    # Interior views of the velocities at the start of the step; they are
    # reused as the advecting velocity for the buoyancy in Step 3.
    u_int = u[1:-1, :]
    w_int = w[1:-1, :]

    # Convective terms (u d/dx + w d/dz) applied to u and to w
    conv_u = u_int * _ddx(u)[1:-1, :] + w_int * _ddz_interior(u)
    conv_w = u_int * _ddx(w)[1:-1, :] + w_int * _ddz_interior(w)

    # Diffusion terms for u and w
    lap_u = laplacian(u)[1:-1, :]
    lap_w = laplacian(w)[1:-1, :]

    # Predicted velocities: interior rows are updated, wall rows stay at the
    # Dirichlet value u = w = 0 they are created with.
    u_star = np.zeros_like(u)
    w_star = np.zeros_like(w)
    u_star[1:-1, :] = u_int + dt * (-conv_u + nu * lap_u)
    w_star[1:-1, :] = w_int + dt * (-conv_w + nu * lap_w + b[1:-1, :])

    # --- Step 2: Pressure projection to enforce incompressibility ---
    # Divergence of (u*, w*) at interior points using central differences.
    div = np.zeros((Nz, Nx))
    div[1:-1, :] = _ddx(u_star)[1:-1, :] + _ddz_interior(w_star)
    # Solve the Poisson equation lap(phi) = div / dt with phi=0 at z=0 and
    # z=Lz, periodic in x.
    phi = poisson_solver(div / dt, tol=1e-6, max_iter=500)

    # Correct velocities with the gradient of phi (central differences); the
    # z-gradient is left at zero on the wall rows.
    grad_phi_z = np.zeros_like(phi)
    grad_phi_z[1:-1, :] = _ddz_interior(phi)
    u_new = u_star - dt * _ddx(phi)
    w_new = w_star - dt * grad_phi_z

    # Enforce no-slip BC for velocity at top and bottom
    u_new[0, :] = 0.0
    u_new[-1, :] = 0.0
    w_new[0, :] = 0.0
    w_new[-1, :] = 0.0

    # Update pressure (here we simply set p = phi, as correction); phi is a
    # fresh array from the solver, so no copy is needed.
    p = phi

    # --- Step 3: Update buoyancy b using explicit Euler, convection and diffusion ---
    conv_b = u_int * _ddx(b)[1:-1, :] + w_int * _ddz_interior(b)
    lap_b = laplacian(b)[1:-1, :]
    b_new = b.copy()
    b_new[1:-1, :] = b[1:-1, :] + dt * (-conv_b + kappa * lap_b)
    # Enforce buoyancy BC: bottom (z = 0, row 0): b = Lz; top (z = Lz, row -1): b = 0.
    # These match the initial profile b = Lz - z. Assigning them the other way
    # round would flip the wall values on the first step and contradict the
    # initial condition.
    b_new[0, :] = Lz
    b_new[-1, :] = 0.0

    # Set updated fields for next iteration (the *_new arrays are not reused,
    # so rebinding is enough)
    u, w, b = u_new, w_new, b_new

# Write the fields from the last completed time step to .npy files in the
# current working directory, one 2D array per file.
for filename, field in (("u.npy", u), ("w.npy", w), ("p.npy", p), ("b.npy", b)):
    np.save(filename, field)