import os

import numpy as np

_DEFAULT_OUTPUT_DIR = '/opt/CFD-Benchmark/PDE_Benchmark/results/prediction/gemini/prompts'


def _ddx(f, dx):
    """Return the periodic second-order central difference of ``f`` along x (axis 1)."""
    return (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2.0 * dx)


def _ddz(f, dz):
    """Return the central difference of ``f`` along z (axis 0); wall rows are set to zero.

    The wall rows are overwritten with boundary values by the caller after every update,
    so no one-sided stencil is needed there.
    """
    out = np.zeros_like(f)
    out[1:-1] = (f[2:] - f[:-2]) / (2.0 * dz)
    return out


def _laplacian(f, dx, dz):
    """Return the five-point Laplacian of ``f`` (periodic in x); wall rows are set to zero."""
    out = np.zeros_like(f)
    d2x = (np.roll(f, -1, axis=1) - 2.0 * f + np.roll(f, 1, axis=1)) / dx**2
    out[1:-1] = d2x[1:-1] + (f[2:] - 2.0 * f[1:-1] + f[:-2]) / dz**2
    return out


def _build_projector(nx, nz, dx, dz):
    """Precompute per-x-wavenumber operators for the discrete divergence-free projection.

    The discrete divergence is D = [D_x, D_z], with periodic central differences in x and
    central differences in z on the interior rows (w = 0 on the walls). After an rfft in x,
    D_x becomes multiplication by ``1j * sx`` and each wavenumber decouples.

    Returns:
        sx: shape (nx//2 + 1,), the Fourier symbol of the periodic central x-difference.
        dz_op: shape (nz-2, nz-2), the central z-difference on interior rows (skew-symmetric).
        inv_normal: shape (nx//2 + 1, nz-2, nz-2), the pseudo-inverse of D D^H for each
            wavenumber. A pseudo-inverse is used because D D^H is singular for some modes
            (e.g. sx == 0 with an odd number of interior rows).
    """
    n_int = nz - 2
    sx = np.sin(2.0 * np.pi * np.arange(nx // 2 + 1) / nx) / dx
    dz_op = (np.eye(n_int, k=1) - np.eye(n_int, k=-1)) / (2.0 * dz)
    normal = sx[:, None, None] ** 2 * np.eye(n_int) + dz_op.T @ dz_op
    return sx, dz_op, np.linalg.pinv(normal)


def _project(u, w, sx, dz_op, inv_normal):
    """Make (u, w) discretely divergence-free on the interior rows, modifying them in place.

    Applies the orthogonal projection v <- v - D^H (D D^H)^+ D v one x-wavenumber at a time.
    This plays the role of the pressure-gradient correction in a projection method; the
    intermediate field phi equals -dt * p up to an additive constant. Wall rows are not
    touched, so the caller's no-penetration / no-slip values are kept.
    """
    nx = u.shape[1]
    u_hat = np.fft.rfft(u[1:-1], axis=1)
    w_hat = np.fft.rfft(w[1:-1], axis=1)
    div_hat = 1j * sx * u_hat + dz_op @ w_hat
    phi = np.einsum("kij,jk->ik", inv_normal, div_hat)
    # D^H = [-1j*sx, dz_op.T] and dz_op.T == -dz_op, hence the "+" signs below.
    u[1:-1] = np.fft.irfft(u_hat + 1j * sx * phi, n=nx, axis=1)
    w[1:-1] = np.fft.irfft(w_hat + dz_op @ phi, n=nx, axis=1)


def solve_cfd(
    *,
    Lx=4.0,
    Lz=1.0,
    Ra=2e6,
    Pr=1.0,
    nx=64,
    nz=64,
    dt=0.001,
    t_final=50.0,
    noise=0.01,
    seed=None,
    output_dir=_DEFAULT_OUTPUT_DIR,
):
    """Simulate 2D Rayleigh-Benard convection and return/save the final (u, w, b) fields.

    Solves, with explicit Euler time stepping and second-order central differences,
        du/dt + (u . grad) u = -grad p + nu lap u + b z_hat,   div u = 0,
        db/dt + (u . grad) b = kappa lap b,
    with nu = (Ra/Pr)**-0.5 and kappa = (Ra*Pr)**-0.5.

    Geometry: the domain is periodic in x on [0, Lx). Its walls at z = 0 and z = Lz have
    u = w = 0, b = Lz at the bottom and b = 0 at the top. Arrays have shape (nz, nx): axis 0
    is z (row 0 is the bottom wall) and axis 1 is x.

    Incompressibility is enforced each step by an exact discrete projection, so the central
    difference divergence of the returned velocity vanishes on the interior rows.

    Args:
        Lx, Lz: Domain width and height.
        Ra, Pr: Rayleigh and Prandtl numbers.
        nx, nz: Grid points in x (periodic, endpoint excluded) and in z (walls included).
        dt: Time step. The scheme is explicit, so dt must respect advective/diffusive limits.
        t_final: Simulated time; the solver takes round(t_final / dt) steps.
        noise: Amplitude of the uniform random perturbation added to the interior buoyancy.
        seed: Seed for ``numpy.random.default_rng``; None gives a non-reproducible run.
        output_dir: Directory for the ``.npy`` results (created if missing). None skips saving.

    Returns:
        Tuple ``(u, w, b)`` of the final fields, each of shape (nz, nx).

    Raises:
        ValueError: If nx or nz is smaller than 3, or dt is not positive.
    """
    if nx < 3 or nz < 3:
        raise ValueError("nx and nz must both be at least 3")
    if dt <= 0:
        raise ValueError("dt must be positive")

    nu = (Ra / Pr) ** (-0.5)
    kappa = (Ra * Pr) ** (-0.5)

    # Periodic in x: x = Lx is the same point as x = 0, so it is not stored twice.
    dx = Lx / nx
    z = np.linspace(0.0, Lz, nz)
    dz = Lz / (nz - 1)

    rng = np.random.default_rng(seed)
    u = np.zeros((nz, nx))
    w = np.zeros((nz, nx))
    b = (Lz - z)[:, None] + noise * rng.random((nz, nx))
    # Keep the initial state consistent with the wall conditions.
    b[0, :] = Lz
    b[-1, :] = 0.0

    sx, dz_op, inv_normal = _build_projector(nx, nz, dx, dz)

    # An integer step count avoids the extra step a float accumulator (t += dt) can take.
    n_steps = int(round(t_final / dt))
    for _ in range(n_steps):
        u_star = u + dt * (nu * _laplacian(u, dx, dz) - (u * _ddx(u, dx) + w * _ddz(u, dz)))
        w_star = w + dt * (
            nu * _laplacian(w, dx, dz) - (u * _ddx(w, dx) + w * _ddz(w, dz)) + b
        )
        b_new = b + dt * (
            kappa * _laplacian(b, dx, dz) - (u * _ddx(b, dx) + w * _ddz(b, dz))
        )

        u_star[[0, -1], :] = 0.0
        w_star[[0, -1], :] = 0.0
        # Pressure step: removes the divergent part, including the hydrostatic response to b.
        _project(u_star, w_star, sx, dz_op, inv_normal)

        b_new[0, :] = Lz
        b_new[-1, :] = 0.0

        u, w, b = u_star, w_star, b_new

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        np.save(os.path.join(output_dir, 'u_2D_Rayleigh_Benard_Convection.npy'), u)
        np.save(os.path.join(output_dir, 'w_2D_Rayleigh_Benard_Convection.npy'), w)
        np.save(os.path.join(output_dir, 'b_2D_Rayleigh_Benard_Convection.npy'), b)

    return u, w, b

if __name__ == "__main__":
    # Executed as a script: run one full simulation with solve_cfd's default
    # settings; solve_cfd itself writes the result files.
    solve_cfd()