import logging
from pathlib import Path

from hydra_zen import zen, store, builds, load_from_yaml, to_yaml
from hydra.core.hydra_config import HydraConfig

from tqdm import tqdm
from time import time
import numpy as np
import xarray as xr

from dataclasses import dataclass

from phi.torch.flow import (  # SoftGeometryMask,; Sphere,; batch,; tensor,
    Box,
    CenteredGrid,
    Noise,
    StaggeredGrid,
    advect,
    diffuse,
    extrapolation,
    fluid,
    jit_compile,
)
from phi.math import reshaped_native

# initialize logger
log = logging.getLogger(__name__)

@dataclass
class NavierStokes2D:
    """Parameters for generating a dataset of 2D buoyancy-driven smoke simulations.

    Instances are created by hydra-zen from the ``ns_cfg`` config group and passed to
    ``generate_navier_stokes``. ``__post_init__`` validates the values so that a bad
    configuration fails immediately instead of midway through (or after) a long run.

    Raises:
        ValueError: if a count/grid size is < 1, ``tmax <= tmin``, or a domain size is <= 0.
    """

    num_samples: int  # number of independent trajectories to simulate
    tmin: float  # start time
    tmax: float  # end time; the step size is (tmax - tmin) / nt
    noise_scale: float  # ``scale`` of the phiflow Noise used for the initial smoke field
    noise_smoothness: float  # ``smoothness`` of that Noise
    nt: int  # number of timesteps
    nx: int  # grid size in x dimension
    ny: int  # grid size in y dimension
    Lx: float  # domain size in x dimension (real units)
    Ly: float  # domain size in y dimension (real units)
    buoyancy_factor: float  # buoyancy in the y dimension using the Boussinesq approximation
    diffusion_coef: float  # coefficient passed to the explicit velocity diffusion
    output_path: Path  # destination file of the generated dataset

    def __post_init__(self):
        # A config override may deliver a plain string; normalise so callers can rely on Path.
        self.output_path = Path(self.output_path)
        for field_name in ("num_samples", "nt", "nx", "ny"):
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value}")
        if not self.tmax > self.tmin:
            raise ValueError(f"tmax ({self.tmax}) must be greater than tmin ({self.tmin})")
        if self.Lx <= 0 or self.Ly <= 0:
            raise ValueError(f"domain sizes must be positive, got Lx={self.Lx}, Ly={self.Ly}")

@jit_compile
def step(smoke, velocity, pressure, dt, buoyancy_factor, diffusion_coef):
    """Advance the smoke and velocity fields by one time step of size ``dt``.

    Order of operations: advect the smoke along the current velocity, build a
    buoyancy force from the advected smoke, self-advect the velocity and add
    that force, apply explicit diffusion, then project onto a divergence-free
    field. The incoming ``pressure`` argument is not used; the returned pressure
    is the one produced by ``fluid.make_incompressible``.

    Returns:
        (smoke, velocity, pressure) after the step.
    """
    new_smoke = advect.semi_lagrangian(smoke, velocity, dt)
    # Scale smoke by the vector (0, buoyancy_factor) so the force acts only in the
    # second (y) component, then resample it onto the velocity's sample points.
    lift = (new_smoke * (0, buoyancy_factor)).at(velocity)
    moved = advect.semi_lagrangian(velocity, velocity, dt) + dt * lift
    smoothed = diffuse.explicit(moved, diffusion_coef, dt)
    projected, new_pressure = fluid.make_incompressible(smoothed)
    return new_smoke, projected, new_pressure

def _empty_field(coords):
    """Return a zero-filled ``xr.DataArray`` with dims ``("sim_id", "time", "y", "x")``.

    The shape is derived from the lengths of the coordinate arrays in *coords*, so the
    data layout always agrees with its coordinates (including non-square grids).
    """
    dims = ("sim_id", "time", "y", "x")
    shape = tuple(len(coords[dim]) for dim in dims)
    return xr.DataArray(np.zeros(shape), dims=dims, coords=coords)


def _simulate_sample(ns_cfg, dt):
    """Simulate one trajectory from a random initial smoke field.

    Returns:
        ``(smoke, velocity_x, velocity_y)`` numpy arrays, each of shape ``(nt, ny, nx)``.
    """
    bounds = Box['x,y', 0 : ns_cfg.Lx, 0 : ns_cfg.Ly]
    # abs() of smooth noise gives a non-negative initial smoke density, sampled at cell centers.
    smoke = abs(
        CenteredGrid(
            Noise(scale=ns_cfg.noise_scale, smoothness=ns_cfg.noise_smoothness),
            extrapolation.ZERO_GRADIENT,
            x=ns_cfg.nx,
            y=ns_cfg.ny,
            bounds=bounds,
        )
    )
    # Velocity starts at zero; sampled in staggered form at face centers.
    velocity = StaggeredGrid(0, extrapolation.ZERO, x=ns_cfg.nx, y=ns_cfg.ny, bounds=bounds)

    smoke_frames = []
    velocity_frames = []
    for _ in range(ns_cfg.nt):
        smoke, velocity, _pressure = step(
            smoke, velocity, None, dt, ns_cfg.buoyancy_factor, ns_cfg.diffusion_coef
        )
        smoke_frames.append(reshaped_native(smoke.values, groups=("x", "y", "vector"), to_numpy=True))
        velocity_frames.append(
            reshaped_native(velocity.staggered_tensor(), groups=("x", "y", "vector"), to_numpy=True)
        )

    # Stacked smoke is (nt, nx, ny, 1). Drop only the trailing singleton "vector" axis
    # (a bare .squeeze() would also drop time when nt == 1), then reorder to (nt, ny, nx).
    smoke_np = np.stack(smoke_frames).squeeze(axis=-1).transpose(0, 2, 1)
    # NOTE(review): assumes staggered_tensor() yields (nx+1, ny+1, 2) per frame, so dropping
    # the last entry along x and y leaves (nt, nx, ny, 2) — confirm against phiflow.
    velocity_np = np.stack(velocity_frames)[:, :-1, :-1, :]
    velocity_x = velocity_np[..., 0].transpose(0, 2, 1)
    velocity_y = velocity_np[..., 1].transpose(0, 2, 1)
    return smoke_np, velocity_x, velocity_y


def generate_navier_stokes(ns_cfg: NavierStokes2D):
    """Simulate ``ns_cfg.num_samples`` trajectories and save them as a single dataset.

    Side effects: logs the hydra-zen store and the resolved job config (read from
    ``<output_dir>/.hydra/config.yaml``), creates the parent directory of
    ``ns_cfg.output_path`` if needed, and writes an xarray Dataset with variables
    ``smoke``, ``velocity_x`` and ``velocity_y`` (dims ``sim_id, time, y, x``) there
    via ``to_netcdf``.

    Args:
        ns_cfg: simulation and output parameters.
    """
    log.info(store)
    # runtime.output_dir is the actual job directory in both single-run and multirun
    # mode (run.dir is not used for multirun jobs).
    output_dir = Path(HydraConfig.get().runtime.output_dir)
    cfg = load_from_yaml(output_dir / ".hydra" / "config.yaml")
    log.info(to_yaml(cfg))

    dt = (ns_cfg.tmax - ns_cfg.tmin) / ns_cfg.nt
    # Snapshots are stored after each step, so the first timestamp is tmin + dt.
    # NOTE(review): linspace(0, L, n) has spacing L/(n-1), whereas the simulation grid
    # spacing is L/n (cell centers at (i + 0.5) * L/n); velocities are additionally
    # face-staggered. Coordinates are kept as-is for compatibility with existing data.
    coords = {
        "sim_id": np.arange(1, ns_cfg.num_samples + 1),
        "time": np.linspace(ns_cfg.tmin + dt, ns_cfg.tmax, ns_cfg.nt),
        "y": np.linspace(0.0, ns_cfg.Ly, ns_cfg.ny),
        "x": np.linspace(0.0, ns_cfg.Lx, ns_cfg.nx),
    }
    smoke_da = _empty_field(coords)
    velocity_x_da = _empty_field(coords)
    velocity_y_da = _empty_field(coords)

    total_sim_time = 0.0
    for i in tqdm(range(ns_cfg.num_samples)):
        log.info(f"Generating sample {i}")
        start_time = time()
        smoke_np, velocity_x, velocity_y = _simulate_sample(ns_cfg, dt)
        smoke_da[i] = smoke_np
        velocity_x_da[i] = velocity_x
        velocity_y_da[i] = velocity_y
        total_sim_time += time() - start_time
    # Guard against num_samples == 0 (nothing simulated).
    avg_sim_time = total_sim_time / max(ns_cfg.num_samples, 1)
    log.info(f"Simulations complete, average sim time: {avg_sim_time}")

    ds = xr.Dataset(dict(smoke=smoke_da, velocity_x=velocity_x_da, velocity_y=velocity_y_da))
    output_path = Path(ns_cfg.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Saving dataset to hdf5 file {ns_cfg.output_path}")
    ds.to_netcdf(output_path)

ns_store = store(group="ns_cfg")


def _register_ns_config(name, num_samples, output_path, nx, ny):
    """Register a ``NavierStokes2D`` config named *name* in the ``ns_cfg`` group.

    Every registered config shares: 64 timesteps over t in [0, 96] (dt = 1.5),
    a 32 x 32 domain, noise scale 11 / smoothness 6, buoyancy 0.5 and diffusion
    0.01. Only the sample count, output file and grid resolution differ.
    """
    ns_store(
        NavierStokes2D,
        num_samples=num_samples,
        output_path=output_path,
        tmin=0.0,
        tmax=96.0,
        noise_scale=11.0,
        noise_smoothness=6.0,
        nt=64,
        nx=nx,
        ny=ny,
        Lx=32.0,
        Ly=32.0,
        buoyancy_factor=0.5,
        diffusion_coef=0.01,
        name=name,
    )


# 128x128 grid, 1000 trajectories
_register_ns_config("default_128", 1000, Path("../data/ns_data_128_1000.h5"), 128, 128)
# 64x64 grid, 1000 trajectories (selected by default below)
_register_ns_config("default_64", 1000, Path("../data/ns_data_64_1000_test.h5"), 64, 64)
# small test/debug configuration
_register_ns_config("test_128", 5, Path("../data/ns_data_128_5_test.h5"), 128, 128)

# Top-level task config: uses the "default_64" entry of the ns_cfg group unless overridden.
task_defaults = ["_self_", {"ns_cfg": "default_64"}]
store(generate_navier_stokes, hydra_defaults=task_defaults)

if __name__ == "__main__":
    # Copy every config registered with the hydra-zen store into Hydra's ConfigStore
    # so that hydra_main can resolve "generate_navier_stokes" by name.
    store.add_to_hydra_store()
    # zen() adapts the task function to accept a Hydra config; hydra_main builds the CLI.
    task = zen(generate_navier_stokes)
    task.hydra_main(config_name="generate_navier_stokes", config_path=None, version_base="1.3")



