
import os

from phijax.equations.registry import get_pde
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from types import SimpleNamespace
import jax
import jax.numpy as jnp
from flax import linen as nn


from phijax.data import *
from phijax.equations import *
from phijax.models import *
from phijax.utils_ import Logger, Collection, tree_map
import time

import yaml




# Reference snapshot of the convection-experiment settings. It is NOT used at
# runtime: the YAML file loaded below is authoritative and replaces it wholesale.
# NOTE(review): keep this in sync with ./config/con.yaml, or delete it.
default_config = {
    "pde": "convection",
    "exp_name": "mlp-run1",
    "seed": 0,
    "input_dim": 2,
    "init_batch_size": 4,
    #"num_layers": 10,
    "epsilon": 50.0,
    "model_config": {
        "hidden_dim": 256,
        "num_layers": 3,
    },

    "optim": dict(
        optimizer="Adam",
        learning_rate=1e-3,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        decay_steps=5000,
        decay_rate=0.9,
        staircase=False,
        warmup_steps=0,
        schedule_free=False,
        #clip_norm=0.0,
        grad_accum_steps=0,
        scheduler=None,
    ),

    "weighting": dict(
        init_weights={"ics": 1.0, "bcs": 1.0, "res": 1.0},
        momentum=0.9,
        use_causal=False,
        scheme=None,
    ),
    "training": dict(
        batch_size=4096,
        dom=[[0.0, 1.0], [0.0, 2 * jnp.pi]],
        num_points_per_dim=256,
        num_epochs=300000,
    ),
    "logging": dict(
        log_losses=True,
        log_weights=False,
        log_grads=False,
        log_ntk=False,
        log_every=1000,
        use_wandb=False,
        wandb_online=False,
        save_every=5000,
        num_keep_ckpts=5,
    ),
}

# Relative path: resolved against the current working directory, not this file.
CONFIG_PATH = "./config/con.yaml"

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# safe_load yields None for an empty file (or a scalar/list for a malformed
# one); fail here with a clear message instead of deep inside Collection.
if not isinstance(config, dict):
    raise ValueError(
        f"{CONFIG_PATH} must contain a YAML mapping, got {type(config).__name__}"
    )

config = Collection.from_dict(config)


def get_dataset(*, beta=50.0, num_t=256, num_x=256, t_max=1.0, x_max=2 * jnp.pi):
    """Build the analytic 1D convection solution u(t, x) = sin(x - beta * t).

    The grid uses ``num_t`` evenly spaced times in ``[0, t_max]`` and ``num_x``
    evenly spaced positions in ``[0, x_max]`` (both endpoints included). The
    defaults reproduce the original hard-coded dataset (beta=50, 256x256 grid,
    t in [0, 1], x in [0, 2*pi]).

    Args:
        beta: Convection (advection) speed.
        num_t: Number of time samples.
        num_x: Number of spatial samples.
        t_max: Final time.
        x_max: Right end of the spatial domain.

    Returns:
        Tuple ``(u, t, x)``: ``u`` has shape ``(num_t, num_x)`` with rows
        indexed by time and columns by space; ``t`` has shape ``(num_t,)``
        and ``x`` has shape ``(num_x,)``.
    """
    x = jnp.linspace(0, x_max, num_x)
    t = jnp.linspace(0, t_max, num_t)
    # indexing='ij' puts time on axis 0 and space on axis 1.
    tt, xx = jnp.meshgrid(t, x, indexing='ij')
    u = jnp.sin(xx - beta * tt)

    return u, t, x

logger = Logger()


model = get_pde(config)
res_sampler = iter(model.sampler)
u_ref = model.u_ref

# NOTE(review): the built-in defaults define logging.log_every, but it is not
# consulted here; the interval stays at 1000 to preserve current behaviour.
log_every = 1000
# Only process 0 logs. Its index is fixed for the run, so query it once.
is_main_process = jax.process_index() == 0

train_time_start = time.time()

for step in range(config.training.num_epochs):
    start_time = time.time()
    batch = next(res_sampler)
    model.state = model.step(model.state, batch)

    if is_main_process and step % log_every == 0:
        # tree_map(x[0]) takes the first slice of every leaf; presumably this
        # strips a per-device (pmap-replicated) leading axis — TODO confirm.
        host_state = jax.device_get(tree_map(lambda x: x[0], model.state))
        host_batch = jax.device_get(tree_map(lambda x: x[0], batch))
        log_dict = model.log(host_state, host_batch, u_ref)
        end_time = time.time()
        logger.log_iter(step, start_time, end_time, log_dict)

# JAX dispatches device work asynchronously; wait for the final updates to
# finish so the reported wall-clock time includes them.
jax.block_until_ready(model.state)
train_time_end = time.time()
print(f"Training time: {train_time_end - train_time_start} seconds")
        

