import numpy as np
import matplotlib.pyplot as plt

# Parameters (free-fall non-dimensionalisation)
Lx, Lz = 4.0, 1.0          # domain width (periodic in x) and wall-to-wall height
Nx, Nz = 128, 32           # grid points in x and z, including ghost rows / walls
# The solver keeps periodic ghost copies in rows 0 and Nx-1 (periodic_bc copies
# row -2 -> 0 and row 1 -> -1), so the Nx - 2 interior rows span one period and
# the period (Nx - 2) * dx must equal Lx.  In z the walls are columns 0 and
# Nz-1, so the wall-to-wall distance (Nz - 1) * dz must equal Lz.
dx, dz = Lx / (Nx - 2), Lz / (Nz - 1)
Ra, Pr = 2e6, 1.0
nu = (Ra / Pr) ** -0.5     # sqrt(Pr / Ra)
kappa = (Ra * Pr) ** -0.5  # 1 / sqrt(Ra * Pr)
dt = 0.001
T = 50.0
Nt = round(T / dt)         # round rather than truncate: T / dt may fall just below an integer

# Grids
x = (np.arange(Nx) - 1) * dx      # x[1] = 0 and x[Nx-2] = Lx - dx; rows 0 and Nx-1 are ghosts
z = np.linspace(0, Lz, Nz)        # includes both walls, z[0] = 0 and z[-1] = Lz
X, Z = np.meshgrid(x, z, indexing='ij')

# Initial conditions: fluid at rest, linear conductive buoyancy profile plus
# small noise to trigger convection.  Fixed seed for reproducible runs.
rng = np.random.default_rng(0)
u = np.zeros((Nx, Nz))
w = np.zeros((Nx, Nz))
b = Lz - Z + 0.01 * rng.random((Nx, Nz))
# Make the initial buoyancy satisfy the same conditions the time loop enforces:
# fixed wall values (Lz at the bottom, 0 at the top) and periodic ghost rows.
b[:, 0] = Lz
b[:, -1] = 0
b[0, :] = b[-2, :]
b[-1, :] = b[1, :]

# Finite-difference helpers: periodic ghost-row update and central-difference operators
def periodic_bc(arr):
    """Refresh the periodic ghost rows of ``arr`` along its first axis, in place.

    Row 0 receives a copy of row -2 (the last interior row) and row -1 a copy
    of row 1 (the first interior row).  The same array object is returned, so
    callers may write ``arr = periodic_bc(arr)``.
    """
    # The fancy-indexed right-hand side is a copy, so both ghost rows are
    # filled from the interior values in a single assignment.
    arr[[0, -1], :] = arr[[-2, 1], :]
    return arr

def laplacian(arr, dx, dz):
    """Five-point second-order Laplacian of a 2-D field.

    Interior points ``[1:-1, 1:-1]`` hold
    ``(f[i+1,j] - 2 f[i,j] + f[i-1,j]) / dx**2 + (f[i,j+1] - 2 f[i,j] + f[i,j-1]) / dz**2``;
    every edge entry of the returned array (same shape and dtype as ``arr``)
    is zero.
    """
    centre = arr[1:-1, 1:-1]
    d2_dx2 = (arr[2:, 1:-1] - 2 * centre + arr[:-2, 1:-1]) / dx**2
    d2_dz2 = (arr[1:-1, 2:] - 2 * centre + arr[1:-1, :-2]) / dz**2
    result = np.zeros_like(arr)
    result[1:-1, 1:-1] = d2_dx2 + d2_dz2
    return result

def gradient(arr, dx, dz):
    """Second-order central first derivatives of a 2-D field.

    Returns ``(d/dx, d/dz)``, both shaped like ``arr``.  The x-derivative is
    filled for rows ``1:-1`` (all columns) and the z-derivative for columns
    ``1:-1`` (all rows); the remaining edge entries are left at zero.
    """
    two_dx, two_dz = 2 * dx, 2 * dz

    d_dx = np.zeros_like(arr)
    forward_x, backward_x = arr[2:, :], arr[:-2, :]
    d_dx[1:-1, :] = (forward_x - backward_x) / two_dx

    d_dz = np.zeros_like(arr)
    forward_z, backward_z = arr[:, 2:], arr[:, :-2]
    d_dz[:, 1:-1] = (forward_z - backward_z) / two_dz

    return d_dx, d_dz

# Time-stepping loop.
#
# Forward Euler in time with second-order central differences in space,
# followed each step by a Chorin-style projection that removes the gradient
# part of the provisional velocity.  Without the projection there is no
# pressure: the buoyancy term accelerates w everywhere (including the
# hydrostatic part of b), so the flow is neither incompressible nor
# Rayleigh-Benard convection.
#
# Array layout used below: rows 0 and -1 are periodic ghost copies in x
# (refreshed by periodic_bc), columns 0 and -1 are the walls in z, and the
# solution proper lives in [1:-1, 1:-1].
#
# NOTE(review): central-difference advection with forward Euler is only
# conditionally stable (it relies on the diffusion terms); if the finite-value
# guard at the end of the step trips, reduce dt.


def _ddx_periodic(f, dx):
    """Central x-derivative of an interior field that is periodic along axis 0."""
    return (np.roll(f, -1, axis=0) - np.roll(f, 1, axis=0)) / (2 * dx)


def _ddz_walls(f, dz):
    """Central z-derivative of an interior field whose values at the walls are zero."""
    padded = np.pad(f, ((0, 0), (1, 1)), mode='constant')
    return (padded[:, 2:] - padded[:, :-2]) / (2 * dz)


def _build_projection_operator(n_periodic, n_interior, dx, dz):
    """Return pseudo-inverses of the discrete div(grad) operator, one per x-wavenumber.

    div and grad are the central differences of ``_ddx_periodic`` and
    ``_ddz_walls``; with them grad is minus the adjoint of div, so the
    projection below is an exact orthogonal projection onto the discretely
    divergence-free fields.  After a real FFT along the periodic x axis,
    div(grad) decouples into one real symmetric ``(n_interior, n_interior)``
    matrix per wavenumber.  Result shape:
    ``(n_periodic // 2 + 1, n_interior, n_interior)``.
    """
    theta = 2 * np.pi * np.arange(n_periodic // 2 + 1) / n_periodic
    # (shift @ f)[j] = f[j + 1] - f[j - 1], with zeros beyond the walls.
    shift = np.eye(n_interior, k=1) - np.eye(n_interior, k=-1)
    d2z = shift @ shift / (4 * dz**2)
    d2x = -(np.sin(theta) / dx) ** 2
    operator = d2x[:, None, None] * np.eye(n_interior) + d2z
    # pinv rather than inv: the wide central-difference stencil makes the
    # operator singular for some wavenumbers (checkerboard modes).
    return np.linalg.pinv(operator)


def _project_divergence_free(u, w, dx, dz, inv_div_grad):
    """Make the interior of (u, w) discretely divergence-free, in place.

    Solves div(grad(phi)) = div(u, w) and subtracts grad(phi); phi plays the
    role of dt * pressure.  Only ``[1:-1, 1:-1]`` of u and w is modified.
    """
    u_in = u[1:-1, 1:-1]
    w_in = w[1:-1, 1:-1]
    div = _ddx_periodic(u_in, dx) + _ddz_walls(w_in, dz)
    div_hat = np.fft.rfft(div, axis=0)
    phi_hat = np.einsum('kij,kj->ki', inv_div_grad, div_hat)
    phi = np.fft.irfft(phi_hat, n=u_in.shape[0], axis=0)
    u_in -= _ddx_periodic(phi, dx)
    w_in -= _ddz_walls(phi, dz)


# The operator depends only on the grid, so build it once.
_inv_div_grad = _build_projection_operator(Nx - 2, Nz - 2, dx, dz)

# Make the initial state satisfy the boundary conditions enforced every step,
# so the first derivative evaluation does not read stale walls/ghost rows.
b[:, 0] = Lz
b[:, -1] = 0
u = periodic_bc(u)
w = periodic_bc(w)
b = periodic_bc(b)

for n in range(Nt):
    # Compute derivatives
    u_x, u_z = gradient(u, dx, dz)
    w_x, w_z = gradient(w, dx, dz)
    b_x, b_z = gradient(b, dx, dz)

    # Nonlinear (advection) terms
    u_adv = u * u_x + w * u_z
    w_adv = u * w_x + w * w_z
    b_adv = u * b_x + w * b_z

    # Laplacians (diffusion)
    u_lap = laplacian(u, dx, dz)
    w_lap = laplacian(w, dx, dz)
    b_lap = laplacian(b, dx, dz)

    # Provisional forward-Euler update of the interior; every right-hand side
    # uses fields from the start of the step.
    u[1:-1, 1:-1] += dt * (-u_adv[1:-1, 1:-1] + nu * u_lap[1:-1, 1:-1])
    w[1:-1, 1:-1] += dt * (-w_adv[1:-1, 1:-1] + nu * w_lap[1:-1, 1:-1] + b[1:-1, 1:-1])
    b[1:-1, 1:-1] += dt * (-b_adv[1:-1, 1:-1] + kappa * b_lap[1:-1, 1:-1])

    # Walls: zero velocity, fixed buoyancy (Lz at the bottom, 0 at the top)
    u[:, 0] = 0
    u[:, -1] = 0
    w[:, 0] = 0
    w[:, -1] = 0
    b[:, 0] = Lz
    b[:, -1] = 0

    # Pressure projection: enforce incompressibility on the interior
    _project_divergence_free(u, w, dx, dz, _inv_div_grad)

    # Refresh periodic ghost rows from the (projected) interior
    u = periodic_bc(u)
    w = periodic_bc(w)
    b = periodic_bc(b)

    # Fail loudly instead of carrying NaN/inf through to the saved results.
    if not (np.isfinite(u).all() and np.isfinite(w).all() and np.isfinite(b).all()):
        raise FloatingPointError(f"Solution became non-finite at step {n + 1} of {Nt}")

# Save the final velocity components (u, w) and buoyancy (b) as .npy files
# in a single results directory.
_RESULTS_DIR = '/home/weichao/Downloads/Code_Generation_Benchmark/PDE_Benchmark/results/prediction/gpt-4o/prompts'
for _name, _field in (('u', u), ('w', w), ('b', b)):
    np.save(f'{_RESULTS_DIR}/{_name}_2D_Rayleigh_Benard_Convection.npy', _field)