import numpy as np
import matplotlib.pyplot as plt
from numba import njit

@njit
def pressure_poisson(p, u, v, dx, dz, dt, nit=100, tol=1e-6):
    """Solve the pressure Poisson equation with Jacobi iteration.

    Discretises ``d2p/dx2 + d2p/dz2 = div(u, v) / dt`` with second-order
    central differences and relaxes it with Jacobi sweeps. The stencil is
    valid for rectangular cells (``dx != dz``).

    Parameters
    ----------
    p : 2-D array
        Initial guess for the pressure. It is used as a warm start and is
        not modified.
    u, v : 2-D arrays
        Velocity components. Their central-difference divergence forms the
        source term. Axis 0 is differenced with ``dx``, axis 1 with ``dz``.
    dx, dz : float
        Grid spacing along axis 0 and axis 1.
    dt : float
        Time step dividing the divergence source term.
    nit : int, optional
        Maximum number of Jacobi sweeps (default 100).
    tol : float, optional
        Stop early once the largest absolute change between two sweeps is
        below this value (default 1e-6).

    Returns
    -------
    2-D array
        The new pressure. Rows 0 and -1 are periodic ghost copies of rows
        -2 and 1. Columns 0 and -1 carry a zero normal gradient.
    """
    nx, nz = p.shape
    p_new = p.copy()
    div = np.zeros_like(p)

    # Central-difference divergence on interior nodes.
    for i in range(1, nx-1):
        for j in range(1, nz-1):
            div[i,j] = ((u[i+1,j] - u[i-1,j])/(2*dx) +
                        (v[i,j+1] - v[i,j-1])/(2*dz))

    # Coefficients of the 5-point Laplacian for unequal spacing:
    #   p = [(pE + pW)*dz^2 + (pN + pS)*dx^2 - dx^2*dz^2*rhs] / (2*(dx^2 + dz^2))
    # For dx == dz this reduces to 0.25*(pE + pW + pN + pS - dx^2*rhs).
    dx2 = dx * dx
    dz2 = dz * dz
    denom = 2.0 * (dx2 + dz2)

    for _ in range(nit):
        p_old = p_new.copy()
        for i in range(1, nx-1):
            for j in range(1, nz-1):
                p_new[i,j] = (
                    (p_old[i+1,j] + p_old[i-1,j]) * dz2 +
                    (p_old[i,j+1] + p_old[i,j-1]) * dx2 -
                    dx2 * dz2 * div[i,j] / dt
                ) / denom

        # Periodic ghost rows along axis 0 ...
        p_new[0,:] = p_new[-2,:]
        p_new[-1,:] = p_new[1,:]
        # ... and zero normal gradient (Neumann) on the first/last columns.
        p_new[:,0] = p_new[:,1]
        p_new[:,-1] = p_new[:,-2]

        # Stop once the sweep no longer changes the solution appreciably.
        if np.max(np.abs(p_new - p_old)) < tol:
            break

    return p_new

@njit
def update_velocity(u, v, p, b, nu, dx, dz, dt, Ra, Pr):
    """Advance the velocity components by one explicit (forward Euler) step.

    The right-hand side at every interior node is built from:

    - central-difference advection (non-conservative form),
    - viscous diffusion with coefficient ``nu``,
    - the central-difference gradient of ``p``.

    The axis-1 component ``v`` additionally receives ``b`` as a body force.

    After the interior update:

    - rows 0 and -1 (axis 0) become periodic ghost copies of rows -2 and 1;
    - columns 0 and -1 (axis 1) are set to zero for both components
      (no-slip walls).

    ``Ra`` and ``Pr`` are accepted but not read here. The viscosity enters
    only through ``nu``.

    Returns
    -------
    (u_new, v_new) : tuple of 2-D arrays
        Fresh arrays. ``u`` and ``v`` themselves are left untouched.
    """
    nx, nz = u.shape
    u_next = u.copy()
    v_next = v.copy()

    # Loop-invariant stencil denominators.
    two_dx = 2*dx
    two_dz = 2*dz
    dx_sq = dx**2
    dz_sq = dz**2

    for i in range(1, nx - 1):
        for j in range(1, nz - 1):
            uc = u[i, j]
            vc = v[i, j]

            # Axis-0 momentum.
            conv_x = uc * (u[i + 1, j] - u[i - 1, j]) / two_dx
            conv_z = vc * (u[i, j + 1] - u[i, j - 1]) / two_dz
            visc_x = nu * (u[i + 1, j] - 2*uc + u[i - 1, j]) / dx_sq
            visc_z = nu * (u[i, j + 1] - 2*uc + u[i, j - 1]) / dz_sq
            grad_p = (p[i + 1, j] - p[i - 1, j]) / two_dx
            tendency = -conv_x - conv_z + visc_x + visc_z - grad_p
            u_next[i, j] = uc + dt * tendency

            # Axis-1 momentum, additionally forced by the buoyancy field.
            conv_x = uc * (v[i + 1, j] - v[i - 1, j]) / two_dx
            conv_z = vc * (v[i, j + 1] - v[i, j - 1]) / two_dz
            visc_x = nu * (v[i + 1, j] - 2*vc + v[i - 1, j]) / dx_sq
            visc_z = nu * (v[i, j + 1] - 2*vc + v[i, j - 1]) / dz_sq
            grad_p = (p[i, j + 1] - p[i, j - 1]) / two_dz
            tendency = -conv_x - conv_z + visc_x + visc_z - grad_p + b[i, j]
            v_next[i, j] = vc + dt * tendency

    # Periodic wrap along axis 0: each ghost row copies the opposite interior row.
    u_next[0, :] = u_next[-2, :]
    u_next[-1, :] = u_next[1, :]
    v_next[0, :] = v_next[-2, :]
    v_next[-1, :] = v_next[1, :]

    # No-slip, impermeable walls on the first and last columns.
    u_next[:, 0] = 0
    u_next[:, -1] = 0
    v_next[:, 0] = 0
    v_next[:, -1] = 0

    return u_next, v_next

@njit
def update_buoyancy(b, u, v, kappa, dx, dz, dt, b_bottom=1.0, b_top=0.0):
    """Advance the buoyancy field by one explicit (forward Euler) step.

    Interior nodes are updated with central-difference advection by
    ``(u, v)`` (non-conservative form) and diffusion with diffusivity
    ``kappa``. Afterwards:

    - rows 0 and -1 along axis 0 become periodic ghost copies;
    - columns 0 and -1 along axis 1 are held at fixed values.

    Parameters
    ----------
    b, u, v : 2-D arrays
        Current buoyancy and velocity components on the same grid.
    kappa : float
        Diffusivity of ``b``.
    dx, dz : float
        Grid spacing along axis 0 and axis 1.
    dt : float
        Time step.
    b_bottom : float, optional
        Value imposed on column 0. That column is the z = 0 plate when
        axis 1 is built from ``linspace(0, Lz, nz)``, as ``main`` does.
        Default 1.0.
    b_top : float, optional
        Value imposed on column -1 (the z = Lz plate). Default 0.0.

    Returns
    -------
    2-D array
        The new buoyancy field. ``b`` itself is not modified.
    """
    nx, nz = b.shape
    b_new = b.copy()

    for i in range(1, nx-1):
        for j in range(1, nz-1):
            # Advection terms
            b_adv_x = u[i,j] * (b[i+1,j] - b[i-1,j])/(2*dx)
            b_adv_z = v[i,j] * (b[i,j+1] - b[i,j-1])/(2*dz)

            # Diffusion terms
            b_diff_x = kappa * (b[i+1,j] - 2*b[i,j] + b[i-1,j]) / (dx**2)
            b_diff_z = kappa * (b[i,j+1] - 2*b[i,j] + b[i,j-1]) / (dz**2)

            b_new[i,j] = b[i,j] + dt * (
                -b_adv_x - b_adv_z +
                b_diff_x + b_diff_z
            )

    # Periodic ghost rows along axis 0
    b_new[0,:] = b_new[-2,:]
    b_new[-1,:] = b_new[1,:]

    # Fixed-value plates. Column 0 is the lower plate (z = 0) and is kept
    # buoyant; column -1 is the upper plate. This matches the initial
    # profile b = Lz - z (largest at z = 0), i.e. heating from below.
    b_new[:,0] = b_bottom   # Bottom boundary (z = 0)
    b_new[:,-1] = b_top     # Top boundary (z = Lz)

    return b_new

def main():
    """Run the 2-D Rayleigh-Benard simulation and save the final fields.

    Uses nu = (Ra/Pr)**-0.5 and kappa = (Ra*Pr)**-0.5. It advances u, v, b
    and p for a fixed number of explicit steps of size ``dt`` up to
    ``t_end``, then writes each field to a ``.npy`` file.
    """
    from pathlib import Path

    # Parameters
    Lx, Lz = 4, 1
    nx, nz = 128, 32
    Ra, Pr = 2e6, 1

    # Derived parameters
    nu = (Ra/Pr)**(-0.5)
    kappa = (Ra*Pr)**(-0.5)

    dx, dz = Lx/(nx-1), Lz/(nz-1)
    dt = 0.001
    t_end = 50
    # Use an integer step count. Accumulating ``t += dt`` in floating point
    # drifts and can run one step too many (or too few).
    n_steps = int(round(t_end / dt))

    # Axis 1 runs from z = 0 (index 0) to z = Lz (index nz-1).
    z = np.linspace(0, Lz, nz)

    # Initial condition: linear profile b = Lz - z plus a small random
    # perturbation. rand(nx, nz) fills in row-major (i-outer, j-inner)
    # order, so it draws the same numbers as one rand() call per (i, j).
    np.random.seed(0)
    b = (Lz - z)[np.newaxis, :] + 0.01 * np.random.rand(nx, nz)

    u = np.zeros((nx, nz))
    v = np.zeros((nx, nz))
    p = np.zeros((nx, nz))

    # Time stepping.
    # NOTE(review): pressure is solved from the divergence of the *current*
    # velocity before the momentum update, not from a predicted velocity as
    # in a projection (fractional-step) scheme. The updated field is
    # therefore only approximately divergence-free; confirm this is intended.
    for _ in range(n_steps):
        p = pressure_poisson(p, u, v, dx, dz, dt)
        u, v = update_velocity(u, v, p, b, nu, dx, dz, dt, Ra, Pr)
        b = update_buoyancy(b, u, v, kappa, dx, dz, dt)

    # Save final solutions; create the output directory if it is missing so
    # the (long) run is not lost on a FileNotFoundError at the very end.
    out_dir = Path('/opt/CFD-Benchmark/PDE_Benchmark/results/prediction/sonnet-35/prompts')
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, field in (('u', u), ('v', v), ('b', b), ('p', p)):
        np.save(out_dir / f'{name}_2D_Rayleigh_Benard_Convection.npy', field)

if __name__ == "__main__":
    main()