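
"""Trial data generation for 2D Navier-Stokes (NS) with turbulence.

Integrates the 2D Navier-Stokes vorticity equation with jax_cfd's spectral
solver and stores vorticity snapshots and their coordinate grid as ``.npy``
files; when run as a script, the generating config is also written as YAML.
"""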
from __future__ import annotations

import os
from pathlib import Path
import yaml
import numpy as np

import jax
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp

import logging
from typing import Any, Union, Sequence
from flax.core import FrozenDict 

import jax_cfd.base as cfd
import jax_cfd.base.grids as grids
import jax_cfd.spectral as spectral

# Convenience type aliases.
PyTree = Any
IntOrSequence = Union[int, Sequence[int]]
Array = Union[np.ndarray, jax.Array]

# Root logger: INFO level with a timestamped "[time] LEVEL: message" layout.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
)


def generate_training_and_validation_grid(
    config: Union[dict, FrozenDict],
) -> None:
    """
    Generate a 2D Navier–Stokes vorticity dataset for training and validation.

    A random, spectrally filtered initial velocity field is converted to
    vorticity, integrated in Fourier space with jax_cfd's spectral
    ``NavierStokes2D`` equation, sub-sampled in time, and written to
    ``data_save_dir`` as ``coord.npy`` and ``vorticity_trajectory.npy``.

    Args:
        config: Dictionary or FrozenDict containing simulation parameters:
            - ns: Grid resolution per axis (int, required, must be > 0).
            - domain: Domain bounds ((float, float), (float, float)).
              Default ((0, 2π), (0, 2π)).
            - max_velocity: Maximum velocity magnitude (float, default 10.0).
            - viscosity: Fluid viscosity (float, default 0.0).
            - anti_aliasing: Apply dealiasing filter (bool, default True).
            - tf: Target final simulation time (float, default 25.0). The
              simulated time is rounded down to a whole number of solver
              steps per saved frame.
            - outer_steps: Number of frames produced by the integrator
              (int, default 10).
            - key: Optional PRNGKey for the initial condition.
            - seed: Integer seed used when ``key`` is absent (default 0).
            - temporal_snapshots: Approximate number of frames to keep when
              ``vorticity_temporal_slice`` is not given (int, default 10).
            - vorticity_temporal_slice: Slice applied to the time axis.
              Default ``slice(0, outer_steps,
              max(1, outer_steps // temporal_snapshots))``.
            - return_real_space: If True (default), save real-space vorticity
              with shape (T, ns, ns) and physical coordinates with shape
              (ns, ns, 2). Otherwise save the rfft coefficients with shape
              (T, ns, ns//2 + 1) and the wavenumber mesh (ns, ns//2 + 1, 2).
            - mol: Time-stepping method (callable, default
              ``spectral.time_stepping.crank_nicolson_rk4``).
            - precision: Real dtype for storage (default jnp.float32).
              Spectral vorticity is stored in the matching complex dtype so
              the imaginary part is kept.
            - data_save_dir: Directory for the output files (default ".").

    Returns:
        None. Saves the coordinate grid and vorticity snapshots to disk.

    Raises:
        ValueError: If ``ns`` or ``outer_steps`` is not positive, or if ``tf``
            is too short for at least one solver step per saved frame.
    """
    # --- Configuration --------------------------------------------------
    ns = int(config.get("ns", 0))
    if ns <= 0:
        raise ValueError(f"config['ns'] must be a positive grid size, got {ns}.")
    domain = config.get(
        "domain", ((0.0, 2.0 * jnp.pi), (0.0, 2.0 * jnp.pi))
    )
    max_velocity = config.get("max_velocity", 10.0)
    viscosity = config.get("viscosity", 0.0)
    anti_aliasing = config.get("anti_aliasing", True)
    tf = config.get("tf", 25.0)
    outer_steps = int(config.get("outer_steps", 10))
    if outer_steps <= 0:
        raise ValueError(
            f"config['outer_steps'] must be positive, got {outer_steps}."
        )

    # ``temporal_snapshots`` is a frame *count*, as the __main__ driver
    # uses it. The default of 10 gives stride max(1, outer_steps // 10),
    # which is the historical default slice.
    snapshots = config.get("temporal_snapshots")
    snapshots = 10 if snapshots is None else max(1, int(snapshots))
    vorticity_temporal_slice = config.get("vorticity_temporal_slice")
    if vorticity_temporal_slice is None:
        vorticity_temporal_slice = slice(
            0, outer_steps, max(1, outer_steps // snapshots)
        )

    return_real_space = config.get("return_real_space", True)
    mol = config.get("mol", spectral.time_stepping.crank_nicolson_rk4)
    precision = config.get("precision", jnp.float32)
    data_save_dir = Path(config.get("data_save_dir", "."))

    # Accept an explicit PRNGKey (as documented) and fall back to an
    # integer seed. The original read only "seed", with no default.
    key = config.get("key")
    if key is None:
        seed = config.get("seed")
        key = jax.random.PRNGKey(0 if seed is None else seed)

    # Create the output directory up front so permission problems surface
    # before the (expensive) simulation runs.
    data_save_dir.mkdir(parents=True, exist_ok=True)

    # --- Solver setup ---------------------------------------------------
    grid = grids.Grid((ns, ns), domain=domain)

    # Stable time step (CFL condition, Courant number 0.5).
    dt = cfd.equations.stable_time_step(max_velocity, 0.5, viscosity, grid)
    step_fn = mol(
        spectral.equations.NavierStokes2D(viscosity, grid, smooth=anti_aliasing), dt
    )

    # Number of solver steps between saved frames. Zero would silently
    # yield a trajectory frozen at the initial condition.
    inner_steps = int((tf // dt) // outer_steps)
    if inner_steps < 1:
        raise ValueError(
            f"tf={tf} allows only {int(tf // dt)} steps of dt={float(dt):.3e}, "
            f"fewer than outer_steps={outer_steps}; increase tf or reduce "
            "outer_steps."
        )
    t_final = inner_steps * outer_steps * float(dt)
    trajectory_fn = cfd.funcutils.trajectory(
        cfd.funcutils.repeated(step_fn, inner_steps), outer_steps
    )

    # --- Initial condition and time integration ---------------------------
    v0 = cfd.initial_conditions.filtered_velocity_field(key, grid, max_velocity, 4)
    vorticity0 = cfd.finite_differences.curl_2d(v0)
    vorticity_hat0 = jnp.fft.rfftn(vorticity0.data)

    _, trajectory = trajectory_fn(vorticity_hat0)  # (outer_steps, ns, ns//2+1)
    # Apply the temporal slice in BOTH output modes. Previously the spectral
    # branch silently saved every frame.
    frames = trajectory[vorticity_temporal_slice]

    # --- Output representation --------------------------------------------
    real_dtype = np.dtype(precision)
    if return_real_space:
        # Physical sample positions of the vorticity field, at the offset
        # produced by curl_2d. (Inverse-FFT-ing the wavenumber mesh, as done
        # before, does not yield coordinates.)
        coords = jnp.stack(grid.mesh(vorticity0.offset), axis=-1)  # (ns, ns, 2)
        frames = jnp.fft.irfftn(frames, s=(ns, ns))  # (T, ns, ns)
        field_dtype = real_dtype
        space = "real"
    else:
        coords = jnp.stack(grid.rfft_mesh(), axis=-1)  # (ns, ns//2 + 1, 2)
        # Keep complex coefficients. Casting to a real dtype would discard
        # the imaginary part.
        field_dtype = np.promote_types(real_dtype, np.complex64)
        space = "spectral"

    # Cast on the host with NumPy so requested 64-bit dtypes are honoured
    # even when JAX runs with x64 disabled.
    np.save(data_save_dir / "coord.npy", np.asarray(coords).astype(real_dtype))
    np.save(
        data_save_dir / "vorticity_trajectory.npy",
        np.asarray(frames).astype(field_dtype),
    )

    logging.info(
        "Saved 2D NS vorticity grid (%s) in %s space to %s: %d of %d frames "
        "(slice %s), simulated time %.4g.",
        "x".join(str(n) for n in frames.shape[1:]),
        space,
        data_save_dir,
        frames.shape[0],
        outer_steps,
        vorticity_temporal_slice,
        t_final,
    )

if __name__ == "__main__":
    # JAX reads the JAX_PLATFORMS environment variable when ``jax`` is
    # imported, and ``jax`` is already imported at the top of this file, so
    # setting only the variable here comes too late. Update the config flag
    # directly; it takes effect as long as no backend has been initialized
    # yet. The variable is kept so child processes inherit the choice.
    os.environ["JAX_PLATFORMS"] = "gpu"
    jax.config.update("jax_platforms", "gpu")
    jax.config.update("jax_enable_x64", False)

    ns = 256
    domain = ((0, 2 * jnp.pi), (0, 2 * jnp.pi))
    eta = 1e-3  # viscosity
    v = 5  # maximum velocity
    tf = 25.0
    out_steps = 100
    snapshots = 4  # saved frames: indices 0, 25, 50, 75

    config = {
        "ns": ns,
        "domain": domain,
        "viscosity": eta,
        "max_velocity": v,
        "tf": tf,
        "outer_steps": out_steps,
        "seed": 42,
        'temporal_snapshots': snapshots,
        "vorticity_temporal_slice": slice(0, out_steps, out_steps // snapshots),
        "anti_aliasing": True,
        "return_real_space": True,
        "mol": spectral.time_stepping.crank_nicolson_rk4,
        "precision": jnp.float32,
        "data_save_dir": "..",
    }

    generate_training_and_validation_grid(config)

    # Save config YAML next to the data.
    # NOTE: the config holds tuples, a slice, a function and a dtype, so
    # yaml.dump writes Python-specific tags (e.g. !!python/tuple,
    # !!python/object/apply). Reading it back needs an unsafe YAML loader;
    # only do that on trusted files.
    config_path = Path(config["data_save_dir"]) / "config.yml"
    with open(config_path, "w", encoding="utf-8") as outfile:
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

    logging.info(f"Stored config file at {config_path}")
    logging.info("Simulation data and config storage complete.")



