"""Solve 2D incompressible Navier-Stokes (vorticity form) with Crank-Nicolson,
save HDF5 in a self-checkable format, and provide a constraint checker.

Adapted from:
https://github.com/zongyi-li/fourier_neural_operator/blob/master/data_generation/navier_stokes/ns_2d.py
"""

import os
import math
import argparse

import h5py
import numpy as np
import torch
from einops import rearrange, repeat
from tqdm import tqdm

from random_fields import GaussianRF


# =========================
#   Fourier-space solver
# =========================

def _spectral_derivative(u_h, two_pi_k):
    """Return the real-space derivative of a field given its Fourier coefficients.

    Multiplies ``u_h`` by ``i * two_pi_k`` (written out on the real and
    imaginary parts) and inverse-transforms over axes 1 and 2.
    """
    du_h = torch.complex(-two_pi_k * u_h.imag, two_pi_k * u_h.real)
    return torch.fft.ifftn(du_h, dim=[1, 2], norm='backward').real


def solve_navier_stokes_2d(w0, f, visc=1e-3, T=49, delta_t=1e-3, record_steps=50):
    """Solve the 2D vorticity equation with a pseudo-spectral Crank-Nicolson scheme.

    Diffusion is treated with Crank-Nicolson; advection and forcing are
    explicit. Derivatives use the multiplier ``2*pi*k``, i.e. a periodic
    domain of unit length in each direction. The nonlinear term is
    de-aliased with the 2/3 rule.

    Parameters
    ----------
    w0 : torch.Tensor, shape (B, N, N)
        Initial vorticity field batch.

    f : torch.Tensor, shape (B, N, N) or (N, N)
        Forcing term. A field of lower rank than ``w0`` is shared by the
        whole batch.

    visc : float or np.ndarray
        Viscosity (1/Re). If array of shape (B,), uses per-sample viscosity.

    T : float
        Final physical time.

    delta_t : float
        Internal time step.

    record_steps : int
        Number of snapshots to record (including t=0).

    Returns
    -------
    sol : np.ndarray, shape (B, N, N, record_steps)
        Recorded vorticity snapshots; ``sol[..., 0]`` is ``w0``.

    sol_t : np.ndarray, shape (record_steps,)
        Physical time of each snapshot.

    Raises
    ------
    ValueError
        If ``record_steps < 1``, if more snapshots are requested than there
        are solver steps, or if a recorded snapshot contains NaN.

    Notes
    -----
    Snapshots are taken every ``steps // (record_steps - 1)`` solver steps,
    where ``steps = ceil(T / delta_t)``. When that division is not exact the
    last snapshot is taken before ``T`` and integration stops there.
    """
    if record_steps < 1:
        raise ValueError('record_steps must be at least 1.')

    N = w0.shape[-1]
    k_max = math.floor(N / 2)
    steps = math.ceil(T / delta_t)

    # outputs; slot 0 always holds the initial condition at t = 0
    sol = torch.zeros(*w0.size(), record_steps, device=w0.device)
    sol_t = torch.zeros(record_steps, device=w0.device)
    sol[..., 0] = w0

    intervals = record_steps - 1
    if intervals == 0:
        # only the initial condition was requested: nothing to integrate
        return sol.cpu().numpy(), sol_t.cpu().numpy()
    if steps < intervals:
        raise ValueError(
            f'Cannot record {record_steps} snapshots from only {steps} solver steps; '
            'reduce record_steps or delta_t, or increase T.'
        )
    # recording stride (>= 1 thanks to the check above)
    record_time = steps // intervals

    # FFT of initial vorticity and forcing
    w_h = torch.fft.fftn(w0, dim=[1, 2], norm='backward')
    f_h = torch.fft.fftn(f, dim=[-2, -1], norm='backward')
    if f_h.dim() < w_h.dim():
        f_h = f_h.unsqueeze(0)  # share one forcing field across the batch

    # wave numbers in FFT order: 0..ceil(N/2)-1 followed by -floor(N/2)..-1
    # (identical to 0..k_max-1, -k_max..-1 for even N; also valid for odd N)
    k_y = torch.cat((
        torch.arange(start=0, end=N - k_max, step=1, device=w0.device),
        torch.arange(start=-k_max, end=0, step=1, device=w0.device)), 0
    ).repeat(N, 1)
    k_x = k_y.transpose(0, 1)
    two_pi_kx = 2 * math.pi * k_x
    two_pi_ky = 2 * math.pi * k_y

    # lap     : symbol of the negative Laplacian (zero mode = 0), used in time stepping
    # lap_psi : same, with the zero mode set to 1 so the Poisson inversion never divides by 0
    lap = 4 * (math.pi ** 2) * (k_x ** 2 + k_y ** 2)
    lap_psi = lap.clone()
    lap_psi[0, 0] = 1.0

    if isinstance(visc, np.ndarray):
        # (B,) -> (B, 1, 1) so per-sample viscosity broadcasts against (N, N)
        visc = torch.from_numpy(visc).to(w0.device).reshape(-1, 1, 1)

    # 2/3 dealias mask, broadcast over the batch axis
    dealias = torch.unsqueeze(
        torch.logical_and(
            torch.abs(k_y) <= (2.0 / 3.0) * k_max,
            torch.abs(k_x) <= (2.0 / 3.0) * k_max
        ).float(), 0
    )

    # Loop-invariant Crank-Nicolson coefficients and forcing contribution.
    factor = 0.5 * delta_t * visc * lap
    factor[..., 0, 0] = 0.0  # safety: never diffuse the zero (mean) mode
    cn_explicit = 1.0 - factor
    cn_implicit = 1.0 + factor
    forcing_h = delta_t * f_h

    c = 1
    for j in tqdm(range(steps)):
        # stream function: solve the Poisson problem in Fourier space
        psi_h = w_h / lap_psi

        # velocity: q = psi_y, v = -psi_x
        q = _spectral_derivative(psi_h, two_pi_ky)
        v = -_spectral_derivative(psi_h, two_pi_kx)

        # vorticity gradient
        w_x = _spectral_derivative(w_h, two_pi_kx)
        w_y = _spectral_derivative(w_h, two_pi_ky)

        # nonlinear term in physical space, back to Fourier, then de-aliased
        F_h = torch.fft.fftn(q * w_x + v * w_y, dim=[1, 2], norm='backward')
        F_h *= dealias

        # Crank-Nicolson for diffusion; explicit for advection + forcing
        w_h = (-delta_t * F_h + forcing_h + cn_explicit * w_h) / cn_implicit

        if (j + 1) % record_time == 0:
            w = torch.fft.ifftn(w_h, dim=[1, 2], norm='backward').real
            if w.isnan().any().item():
                raise ValueError('NaN values found.')
            sol[..., c] = w
            # derive the time from the step index to avoid accumulated round-off
            sol_t[c] = (j + 1) * delta_t
            c += 1
            if c == record_steps:
                # every slot is filled; further steps could only overflow the buffer
                break

    return sol.cpu().numpy(), sol_t.cpu().numpy()


# =========================
#   Data generator + I/O
# =========================

@torch.no_grad()
def navier_stokes(
        root: str,
        nw: int = 100,
        nf: int = 100,
        s: int = 64,
        t: float = 49.0,
        steps: int = 50,
        mu: float = 1e-3,
        batch_size: int = 1024,
        seed: int = 42,
        delta: float = 1e-3,
):
    """Generate a 2D Navier-Stokes dataset and save it as one HDF5 file.

    Every one of the ``nw`` initial vorticity fields is paired with every one
    of the ``nf`` forcing fields, giving ``nw * nf`` simulations that are
    solved in batches with ``solve_navier_stokes_2d``.

    Parameters
    ----------
    root : str
        Output directory; created if missing. The file name is derived from
        ``nw``, ``nf``, ``s``, ``steps`` and ``mu``.
    nw : int
        Number of initial vorticity fields sampled from ``GaussianRF``.
    nf : int
        Number of forcing phases, evenly spaced on [0, pi/2].
    s : int
        Grid size (s x s) on the unit square.
    t : float
        Final physical time.
    steps : int
        Number of recorded snapshots (including t=0).
    mu : float
        Viscosity, stored per simulation in the ``nu`` dataset.
    batch_size : int
        Number of simulations solved at once.
    seed : int
        Seed for torch; numpy is seeded with ``seed + 1234``.
    delta : float
        Internal solver time step.

    Returns
    -------
    str
        Path of the written HDF5 file. It holds the datasets ``solution``,
        ``initial_condition``, ``force``, ``phi``, ``nu``, ``time`` and the
        scalars ``dx``, ``dy``, ``Lx``, ``Ly``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(seed)
    np.random.seed(seed + 1234)

    # `root` is the directory the file is written into, so it is `root`
    # itself (not its parent) that has to exist.
    if root:
        os.makedirs(root, exist_ok=True)
    path = os.path.join(root, f'ns2d_nw{nw}_nf{nf}_s{s}_steps{steps}_mu{mu}.h5')

    # domain and spacings for checks
    Lx = 1.0
    Ly = 1.0
    dx = Lx / s
    dy = Ly / s

    # Gaussian random ICs (vorticity)
    GRF = GaussianRF(2, s, alpha=2.5, tau=7, device=device)

    # prepare IC and forcing template on grid
    w0 = GRF.sample(nw)                         # (nw, s, s)
    ft = torch.linspace(0, 1, s + 1, device=device)[:-1]
    X, Y = torch.meshgrid(ft, ft, indexing='ij')

    # forcing phases are deterministic and evenly spaced on [0, pi/2];
    # forcing amplitude = 0.1 / sqrt(2)
    phi = np.pi / 2 * torch.linspace(0, 1, nf, dtype=torch.float, device=device)
    fs = (0.1 / math.sqrt(2)) * torch.sin(2 * math.pi * (X + Y).unsqueeze(0) + phi.view(-1, 1, 1))  # (nf, s, s)

    # expand to nsim = nw*nf (IC index varies slowest, phase index fastest)
    w0_sim = w0.unsqueeze(1).repeat(1, nf, 1, 1).view(-1, s, s)             # (nsim, s, s)
    f_sim = fs.unsqueeze(0).repeat(nw, 1, 1, 1).view(-1, s, s)              # (nsim, s, s)
    phi_sim = phi.unsqueeze(0).repeat(nw, 1).reshape(-1)                    # (nsim,)
    nsim = w0_sim.shape[0]

    # allocate HDF5 with self-checkable structure
    with h5py.File(path, 'w') as h5:
        h5.create_dataset('solution',           shape=(nsim, s, s, steps), dtype=np.float32)
        h5.create_dataset('initial_condition',  shape=(nsim, s, s),       dtype=np.float32)
        h5.create_dataset('force',              shape=(nsim, s, s),       dtype=np.float32)
        h5.create_dataset('phi',                shape=(nsim,),            dtype=np.float32)
        h5.create_dataset('nu',                 shape=(nsim,),            dtype=np.float32)
        h5.create_dataset('time',               shape=(steps,),           dtype=np.float32)

        # also store spacings for integral checks
        h5.create_dataset('dx', data=np.float32(dx))
        h5.create_dataset('dy', data=np.float32(dy))
        h5.create_dataset('Lx', data=np.float32(Lx))
        h5.create_dataset('Ly', data=np.float32(Ly))

        # write IC/force/nu first
        h5['initial_condition'][:] = w0_sim.cpu().numpy()
        h5['force'][:] = f_sim.cpu().numpy()
        h5['phi'][:] = phi_sim.cpu().numpy()
        h5['nu'][:] = np.full((nsim,), np.float32(mu), dtype=np.float32)

        # Batched solve. Each batch goes straight into its slice of the
        # dataset so the full (nsim, s, s, steps) array never sits in RAM.
        time_ref = None
        n_batch = math.ceil(nsim / batch_size)
        batches = zip(torch.split(w0_sim, batch_size, dim=0),
                      torch.split(f_sim, batch_size, dim=0))
        start = 0
        for i, (wb, fb) in enumerate(batches, start=1):
            print(f'Batch {i} / {n_batch}')
            sol_b, time_b = solve_navier_stokes_2d(wb.to(device), fb.to(device),
                                                   visc=mu, T=t, delta_t=delta, record_steps=steps)
            stop = start + sol_b.shape[0]
            h5['solution'][start:stop] = sol_b.astype(np.float32, copy=False)
            start = stop
            if time_ref is None:
                # every batch uses the same time grid; keep the first one
                time_ref = time_b.astype(np.float32)

        if time_ref is not None:
            h5['time'][:] = time_ref

    print(f'Done. Dataset saved to {path}')
    return path


# =========================
#   Constraint checker
# =========================

def check_navier_stokes(h5_path, atol_ic=1e-7, atol_mass=1e-9):
    """Check IC consistency and global mass (total vorticity) conservation.

    Parameters
    ----------
    h5_path : str
        HDF5 file with datasets ``solution`` (N, s, s, steps),
        ``initial_condition`` (N, s, s) and ``time`` (steps,). The scalars
        ``dx`` / ``dy`` are used when present; otherwise a unit domain is
        assumed and the spacing is 1 / grid size.
    atol_ic : float
        Absolute tolerance for ``solution[..., 0] == initial_condition``.
    atol_mass : float
        Absolute tolerance on the drift of the spatial integral of the
        solution relative to its value at the first snapshot.

    Returns
    -------
    dict
        ``ic_ok`` (bool), ``max_ic_diff`` (float), ``mass_ok`` (bool),
        ``max_mass_drift`` (float), ``steps`` (int), ``times`` (list).
    """
    with h5py.File(h5_path, 'r') as f:
        sol = f['solution'][:]              # (N, s, s, steps)
        ic = f['initial_condition'][:]      # (N, s, s)
        time = f['time'][:]                 # (steps,)
        # read spacings; fallback if absent
        dx = float(f['dx'][()]) if 'dx' in f else 1.0 / sol.shape[1]
        dy = float(f['dy'][()]) if 'dy' in f else 1.0 / sol.shape[2]

    N, s1, s2, steps = sol.shape
    assert s1 == s2, "Only square grids are assumed."

    # 1) initial condition check: sol[..., 0] == ic
    first = sol[..., 0]
    ic_ok = bool(np.allclose(first, ic, atol=atol_ic, rtol=0.0))
    max_ic_diff = float(np.max(np.abs(first - ic))) if first.size else 0.0

    # 2) global mass (total vorticity) conservation over space for all time.
    #    Accumulate in float64: a float32 sum over s*s points carries round-off
    #    comparable to the tolerances this check is meant to resolve.
    area_elem = dx * dy
    mass_t = sol.sum(axis=(1, 2), dtype=np.float64) * area_elem   # (N, steps)
    if mass_t.size:
        # max over all samples and times, relative to the first snapshot
        drift = float(np.max(np.abs(mass_t - mass_t[:, [0]])))
    else:
        drift = 0.0
    # a NaN drift compares False here, so it is reported as a failure
    mass_ok = bool(drift <= atol_mass)

    return dict(
        ic_ok=ic_ok,
        max_ic_diff=max_ic_diff,
        mass_ok=mass_ok,
        max_mass_drift=drift,
        steps=int(steps),
        times=time.tolist(),
    )


# =========================
#   CLI
# =========================

if __name__ == "__main__":
    # Command-line entry point: either generate a dataset and check it, or
    # (with --check_only) check an existing HDF5 file given via --h5.
    parser = argparse.ArgumentParser(
        description="2D incompressible Navier-Stokes data generation and constraint check."
    )
    parser.add_argument('--root', type=str, default='.',
                        help='root to save the data (HDF5 filename is auto-generated)')
    parser.add_argument('--nw', type=int, default=90,
                        help='number of initial vorticity fields')
    parser.add_argument('--nf', type=int, default=90,
                        help='number of forcing phases')
    parser.add_argument('--s', type=int, default=64,
                        help='grid size (s x s)')
    parser.add_argument('--t', type=float, default=49.0,
                        help='final physical time')
    parser.add_argument('--steps', type=int, default=50,
                        help='number of recorded snapshots (including t=0)')
    parser.add_argument('--mu', type=float, default=1e-3,
                        help='viscosity (1/Re)')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='batch size for solver')
    parser.add_argument('--seed', type=int, default=3407,
                        help='seed value for reproducibility')
    parser.add_argument('--delta', type=float, default=1e-3,
                        help='internal time step for solver')
    parser.add_argument('--check_only', action='store_true',
                        help='only run the checker on an existing file (path via --h5)')
    parser.add_argument('--h5', type=str, default='',
                        help='path to an existing HDF5 for --check_only')

    args = parser.parse_args()

    if args.check_only:
        if not args.h5:
            parser.error("Please provide --h5 to check.")
        target = args.h5
        mass_tol = 1e-9
    else:
        target = navier_stokes(
            args.root, args.nw, args.nf, args.s, args.t,
            args.steps, args.mu, args.batch_size, args.seed, args.delta
        )
        # freshly generated float32 data is checked with a looser mass tolerance
        mass_tol = 1e-6

    report = check_navier_stokes(target, atol_ic=1e-7, atol_mass=mass_tol)

    # Report the outcome; the pass/fail on mass is derived here from the
    # returned drift so the summary does not depend on optional report keys.
    mass_ok = report['max_mass_drift'] <= mass_tol
    print(f"Checked {target} ({report['steps']} snapshots)")
    print(f"  initial condition match : {'OK' if report['ic_ok'] else 'FAILED'}")
    print(f"  total-vorticity drift   : {report['max_mass_drift']:.3e} "
          f"(tolerance {mass_tol:.1e}) {'OK' if mass_ok else 'FAILED'}")
