import numpy as np
import scipy.fftpack as fftpack

def periodic_bc(u):
    """Return a copy of ``u`` with its outermost rows/columns filled periodically.

    Index 0 and index -1 along each of the first two axes are treated as ghost
    layers: ghost layer 0 receives the last interior layer (index -2) and ghost
    layer -1 receives the first interior layer (index 1). The input array is
    not modified.

    Axis 0 is wrapped first, and axis 1 is then wrapped from the *updated*
    array. This way each corner receives the diagonally opposite interior value
    (e.g. ``[0, 0] <- u[-2, -2]``). Copying both axes from the untouched input
    would instead leave the corners holding stale ghost values of ``u``.
    """
    u_periodic = u.copy()
    # Wrap along axis 0 (rows).
    u_periodic[0, :] = u_periodic[-2, :]
    u_periodic[-1, :] = u_periodic[1, :]
    # Wrap along axis 1 (columns) from the row-updated array so corners are consistent.
    u_periodic[:, 0] = u_periodic[:, -2]
    u_periodic[:, -1] = u_periodic[:, 1]
    return u_periodic

def initial_condition_u(x, z):
    """Initial horizontal velocity: a double-tanh shear layer in ``z``.

    The profile depends on ``z`` only; ``x`` is accepted for a signature that
    matches ``initial_condition_w`` and is not used. The two tanh fronts sit at
    z = +0.5 and z = -0.5 with a transition scale of 0.1. Operates element-wise,
    so scalars or NumPy arrays may be passed.
    """
    scale = 0.1
    upper_front = np.tanh((z - 0.5) / scale)
    lower_front = np.tanh((z + 0.5) / scale)
    return 0.5 * (1 + upper_front - lower_front)

def initial_condition_w(x, z):
    """Initial vertical velocity: a small separable sinusoidal perturbation.

    Returns ``0.01 * sin(pi * z) * sin(pi * x)``, evaluated element-wise, so
    scalars or broadcast-compatible NumPy arrays may be passed.
    """
    amplitude = 0.01
    z_mode = np.sin(np.pi * z)
    x_mode = np.sin(np.pi * x)
    return amplitude * z_mode * x_mode

_DEFAULT_OUTPUT_DIR = '/opt/CFD-Benchmark/PDE_Benchmark/results/prediction/sonnet-35/prompts'


def _wavenumbers(nx, nz, dx, dz):
    """Return angular-wavenumber grids ``(kx, kz)``, each (nx, nz), in FFT order."""
    kx = 2 * np.pi * fftpack.fftfreq(nx, dx)
    kz = 2 * np.pi * fftpack.fftfreq(nz, dz)
    return np.meshgrid(kx, kz, indexing='ij')


def _dealias_mask(nx, nz):
    """Return a boolean (nx, nz) mask of Fourier modes kept by the 2/3 rule.

    A mode with integer index ``m`` along an axis of length ``n`` is kept when
    ``|m| < n / 3``. Quadratic products of kept modes then cannot alias back
    onto kept modes.
    """
    modes_x = np.abs(np.rint(fftpack.fftfreq(nx) * nx))
    modes_z = np.abs(np.rint(fftpack.fftfreq(nz) * nz))
    return (modes_x < nx / 3.0)[:, np.newaxis] & (modes_z < nz / 3.0)[np.newaxis, :]


def _project_divergence_free(u_hat, w_hat, kx, kz, inv_k2):
    """Return the divergence-free part of the spectral vector field (u_hat, w_hat).

    Subtracts ``k (k . v) / |k|^2`` from every mode. This is the Fourier-space
    equivalent of removing a pressure gradient. ``inv_k2`` holds ``1 / |k|^2``
    with the k = 0 entry set to 0, so the mean flow is left unchanged.
    """
    k_dot_v = kx * u_hat + kz * w_hat
    return u_hat - kx * k_dot_v * inv_k2, w_hat - kz * k_dot_v * inv_k2


def _advection_hat(f_hat, u, w, kx, kz, dealias):
    """Return the dealiased spectrum of ``u * df/dx + w * df/dz``.

    Derivatives of ``f`` are taken spectrally. The products with the physical
    velocity ``(u, w)`` are formed in physical space and transformed back.
    Multiplying spectra directly would compute a different quantity.
    """
    dfdx = np.real(fftpack.ifft2(1j * kx * f_hat))
    dfdz = np.real(fftpack.ifft2(1j * kz * f_hat))
    return dealias * fftpack.fft2(u * dfdx + w * dfdz)


def _rhs(state, ops, nu, kappa):
    """Return the time derivative of the spectral state ``(u_hat, w_hat, s_hat)``.

    ``ops`` is ``(kx, kz, k2, inv_k2, dealias)``. The velocity tendency is
    projected onto divergence-free fields so incompressibility is maintained.
    """
    kx, kz, k2, inv_k2, dealias = ops
    u_hat, w_hat, s_hat = state
    u = np.real(fftpack.ifft2(u_hat))
    w = np.real(fftpack.ifft2(w_hat))
    du = -_advection_hat(u_hat, u, w, kx, kz, dealias) - nu * k2 * u_hat
    dw = -_advection_hat(w_hat, u, w, kx, kz, dealias) - nu * k2 * w_hat
    ds = -_advection_hat(s_hat, u, w, kx, kz, dealias) - kappa * k2 * s_hat
    du, dw = _project_divergence_free(du, dw, kx, kz, inv_k2)
    return du, dw, ds


def _rk4_step(state, dt, ops, nu, kappa):
    """Advance the spectral state by one classical fourth-order Runge-Kutta step.

    RK4 is used because forward Euler is unconditionally unstable for the
    purely imaginary eigenvalues of spectral advection.
    """
    def shifted(base, slope, factor):
        return tuple(f + factor * d for f, d in zip(base, slope))

    r1 = _rhs(state, ops, nu, kappa)
    r2 = _rhs(shifted(state, r1, 0.5 * dt), ops, nu, kappa)
    r3 = _rhs(shifted(state, r2, 0.5 * dt), ops, nu, kappa)
    r4 = _rhs(shifted(state, r3, dt), ops, nu, kappa)
    return tuple(f + (dt / 6.0) * (a + 2 * b + 2 * c + d)
                 for f, a, b, c, d in zip(state, r1, r2, r3, r4))


def solve_navier_stokes_2d(*, nx=128, nz=256, t_end=20.0, dt=0.01, nu=1 / 5e4,
                           diffusivity=None, output_dir=_DEFAULT_OUTPUT_DIR,
                           save=True):
    """Integrate 2-D incompressible Navier-Stokes with a passive tracer.

    Solves, on the doubly periodic box [0, Lx) x [-Lz/2, Lz/2) with Lx=1, Lz=2::

        dv/dt + (v . grad) v = -grad p + nu * lap v,    div v = 0
        ds/dt + (v . grad) s = kappa * lap s

    where v = (u, w). The method is Fourier pseudo-spectral: derivatives are
    spectral, nonlinear products are formed in physical space and dealiased
    with the 2/3 rule, and pressure is removed by divergence-free projection.
    Time stepping is classical RK4 with a fixed step.

    Parameters (keyword-only; the defaults reproduce the original run)
    ----------
    nx, nz : int
        Grid points along x and z.
    t_end : float
        Final time; integration starts at t = 0.
    dt : float
        Time step. Must be positive.
    nu : float
        Kinematic viscosity.
    diffusivity : float or None
        Tracer diffusivity; ``None`` means "same as ``nu``".
    output_dir : str or path-like
        Directory that receives the three ``.npy`` result files.
    save : bool
        If False, skip writing files and only return the fields.

    Returns
    -------
    (u, w, s) : tuple of ndarray, each shaped (nx, nz) with axis 0 = x.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    from pathlib import Path

    if dt <= 0:
        raise ValueError("dt must be positive")

    # Domain: periodic in both directions. endpoint=False because the point at
    # x = Lx coincides with x = 0 on a Fourier grid. Including it would make the
    # FFT see a period of nx*dx != Lx.
    Lx, Lz = 1.0, 2.0
    dx, dz = Lx / nx, Lz / nz
    x = np.linspace(0.0, Lx, nx, endpoint=False)
    z = np.linspace(-Lz / 2, Lz / 2, nz, endpoint=False)

    # round() guards against a float quotient such as 1999.9999 truncating a step.
    t_start = 0.0
    nt = int(round((t_end - t_start) / dt))
    kappa = nu if diffusivity is None else diffusivity

    # Spectral operators are time-invariant, so they are built once.
    kx, kz = _wavenumbers(nx, nz, dx, dz)
    k2 = kx**2 + kz**2
    inv_k2 = np.zeros_like(k2)
    np.divide(1.0, k2, out=inv_k2, where=k2 > 0)
    dealias = _dealias_mask(nx, nz)
    ops = (kx, kz, k2, inv_k2, dealias)

    # Sample the initial conditions on the whole grid at once (axis 0 = x, axis 1 = z).
    # NOTE(review): sin(pi*x) in initial_condition_w has a slope jump across the
    # periodic seam x = 0 == Lx; confirm the intended perturbation wavenumber.
    X, Z = np.meshgrid(x, z, indexing='ij')
    u0 = np.ascontiguousarray(np.broadcast_to(initial_condition_u(X, Z), X.shape), dtype=float)
    w0 = np.ascontiguousarray(np.broadcast_to(initial_condition_w(X, Z), X.shape), dtype=float)

    u_hat = dealias * fftpack.fft2(u0)
    w_hat = dealias * fftpack.fft2(w0)
    s_hat = u_hat.copy()  # tracer starts equal to the initial horizontal velocity
    u_hat, w_hat = _project_divergence_free(u_hat, w_hat, kx, kz, inv_k2)

    # No ghost-cell boundary update (e.g. periodic_bc) is applied. The Fourier
    # basis is periodic by construction, and overwriting edge rows/columns
    # would corrupt valid grid data.
    state = (u_hat, w_hat, s_hat)
    for _ in range(nt):
        state = _rk4_step(state, dt, ops, nu, kappa)

    u, w, s = (np.real(fftpack.ifft2(f_hat)) for f_hat in state)

    if save:
        out_dir = Path(output_dir)
        np.save(out_dir / 'u_2D_Shear_Flow_With_Tracer.npy', u)
        np.save(out_dir / 'w_2D_Shear_Flow_With_Tracer.npy', w)
        np.save(out_dir / 's_2D_Shear_Flow_With_Tracer.npy', s)

    return u, w, s

# Script entry: runs the simulation and writes the .npy result files whenever
# this module is executed or imported.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing this
# file also triggers the full run; consider adding one if it is ever imported.
solve_navier_stokes_2d()