
import os
import sys

if not __package__:
    # Running as a plain script (e.g. ``python src/package``) rather than as
    # an imported package: prepend the parent of this file's directory to the
    # module search path so the imports below resolve from the source tree.
    package_source_path = os.path.dirname(os.path.dirname(__file__))
    sys.path[:0] = [package_source_path]

from pathlib import Path
import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig
from hydra.utils import instantiate
from jax import device_put
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import pi_lr

# Custom config resolver: ``${mul_pi:x}`` evaluates to ``pi * x``.
OmegaConf.register_new_resolver("mul_pi", lambda factor: jnp.pi * factor)
@hydra.main(version_base=None, config_path="../config", config_name="generate_data.yaml")
def main(cfg: DictConfig) -> None:
    """Simulate a batch of trajectories for the configured equation and save them.

    Config entries read: ``data.equation`` (instantiated through Hydra),
    ``random_seed``, ``nx``, ``nt``, ``numbers``, ``out_dir`` and
    ``visualize``.

    Files written into ``<out_dir>/<repr(equation)>-nt:<nt>-nx:<nx>``:
    ``X.npy`` (time points, or a time/space mesh when a spatial grid is
    used), ``y.npy`` (simulated trajectories), ``y0.npy`` (initial states),
    ``config.yaml`` (the config used) and, when ``cfg.visualize`` is truthy,
    ``vis.png``.

    A relative ``out_dir`` is resolved against
    ``hydra.utils.get_original_cwd()``; an absolute ``out_dir`` is used as
    given.
    """
    equation = instantiate(cfg.data.equation)

    # Fall back to seed 0 when no seed is configured.
    seed = cfg.random_seed if cfg.random_seed is not None else 0
    key = jax.random.PRNGKey(seed)

    if cfg.nx > 0:
        xs, dx = equation.xs(cfg.nx, retstep=True)
        y0 = equation.init_dist(numbers=cfg.numbers, key=key, xs=xs)
    else:
        # No spatial grid requested: the equation is sampled without ``xs``.
        xs, dx = None, None
        y0 = equation.init_dist(numbers=cfg.numbers, key=key)
    # Explicitly place the initial states on the default device
    # (NOTE(review): possibly redundant, as the original comment suspected).
    y0 = device_put(y0)

    ts, dt = equation.ts(cfg.nt, retstep=True)
    print(f"dt = {dt}, dx = {dx}")

    # vmap over the samples handled by one device, pmap over devices.
    vm_evolve = jax.pmap(
        jax.vmap(lambda state: equation.simulate(state, ts, xs), axis_name="j"),
        axis_name="i",
    )
    # Only a single device is used for now, so the pmap axis has length 1.
    local_devices = 1
    y = vm_evolve(y0.reshape([local_devices, cfg.numbers // local_devices, -1]))

    # Merge the (device, per-device sample) axes back into one sample axis
    # before saving.
    y = y.reshape((-1, *y.shape[2:]))
    print("nan check: ", jnp.isnan(y).any())

    print("data saving...")
    # Join with pathlib rather than string concatenation so that an absolute
    # ``cfg.out_dir`` is honoured instead of being nested under the original
    # working directory.
    out_dir = (
        Path(hydra.utils.get_original_cwd())
        / cfg.out_dir
        / f"{equation!r}-nt:{cfg.nt}-nx:{cfg.nx}"
    )
    out_dir.mkdir(parents=True, exist_ok=True)

    if xs is None:
        X = ts
    else:
        # One time axis followed by ``len(equation.domain) - 1`` copies of the
        # spatial grid (presumably one per spatial dimension -- TODO confirm);
        # coordinates are stacked along the last axis.
        axes = [ts] + [xs] * (len(equation.domain) - 1)
        X = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)
    jnp.save(out_dir / "X.npy", X)
    jnp.save(out_dir / "y.npy", y)
    jnp.save(out_dir / "y0.npy", y0)
    OmegaConf.save(cfg, out_dir / "config.yaml")

    if cfg.visualize:
        print("visualizing...")
        save_path = out_dir / "vis.png"
        equation.visualize(X, y, save_path, n_samples=5)
        
        
# Script entry point: ``main`` is wrapped by ``hydra.main``, so it is called
# without arguments here and receives its ``cfg`` from the decorator.
if __name__ == "__main__":
    main()