from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap
from jax.tree_util import tree_map, tree_reduce, tree_leaves

import optax

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP
from phijax.equations.registry import register_pde





@register_pde(
    "wave", aliases=["wave1d", "1d_wave"],
    defaults={"epsilon": 3.0, "num_points_per_dim": 256}
)
class Wave1D(IVP):
    """1D wave equation ``u_tt = 4 u_xx`` on ``(t, x) in [0, 1] x [0, 1]``.

    Problem setup, as encoded by :meth:`r_net` and :meth:`residuals`:

    * PDE residual: ``u_tt - 4 u_xx`` (wave speed 2).
    * Initial displacement: ``u(0, x) = sin(pi x) + 0.5 sin(epsilon pi x)``.
    * Initial velocity: ``u_t(0, x) = 0``.
    * Boundaries: ``u(t, 0) = u(t, 1) = 0`` (the boundary residuals are the raw
      predictions, i.e. the target is zero).

    ``self.u_ref`` holds the closed-form solution
    ``sin(pi x) cos(2 pi t) + 0.5 sin(epsilon pi x) cos(2 pi epsilon t)``
    on a ``num_points_per_dim x num_points_per_dim`` grid. Note that it only
    vanishes at ``x = 1`` when ``epsilon`` is an integer.

    Config read from ``config.pde_config``:
        epsilon: frequency multiplier of the second mode (``None`` -> 3.0).
        num_points_per_dim: grid size used for both ``t`` and ``x``
            (missing/0 -> 256).

    Config read from ``config.training``:
        sampler: ``"uniform"`` (also the default when unset) or ``"fixed"``.
        batch_size: forwarded to the collocation sampler.
    """

    loss_keys = ("ics_u", "ics_v", "bcs_r", "bcs_l", "res")

    def __init__(self, config):
        """Build the reference grid/solution, domain bounds and sampler.

        Raises:
            ValueError: if ``config.training.sampler`` is neither
                ``"uniform"`` nor ``"fixed"``.
        """
        super().__init__(config)

        pcfg = self.config.pde_config

        # A grid with 0 points is meaningless, so any falsy value falls back.
        num_pts = pcfg.num_points_per_dim or 256
        # epsilon = 0.0 is a valid (single-mode) problem; only fall back when
        # the value is actually missing, not merely falsy.
        self.epsilon = 3.0 if pcfg.epsilon is None else pcfg.epsilon

        # ---- Reference data on a uniform (t, x) grid ----
        x = jnp.linspace(0, 1.0, num_pts)
        t = jnp.linspace(0, 1.0, num_pts)
        # indexing="ij": tt and xx have shape (len(t), len(x)).
        tt, xx = jnp.meshgrid(t, x, indexing="ij")

        mode1 = jnp.sin(jnp.pi * xx) * jnp.cos(2.0 * jnp.pi * tt)
        mode2 = jnp.sin(self.epsilon * jnp.pi * xx) * jnp.cos(
            2.0 * jnp.pi * self.epsilon * tt
        )
        self.u_ref = mode1 + 0.5 * mode2
        # Initial displacement: u_ref evaluated at t = 0.
        self.u0 = jnp.sin(jnp.pi * x) + 0.5 * jnp.sin(self.epsilon * jnp.pi * x)

        self.t_star = t
        self.x_star = x
        self.t0, self.t1 = self.t_star[0], self.t_star[-1]
        self.x0, self.x1 = self.x_star[0], self.x_star[-1]

        # Domain bounds, one row per input dimension: [[t0, t1], [x0, x1]].
        self.dom = jnp.array([[self.t0, self.t1], [self.x0, self.x1]])

        # ---- Collocation sampler ----
        if self.config.training.get('sampler', None) is None:
            self.config.training.sampler = "uniform"
        sampler_name = self.config.training.sampler
        if sampler_name == "uniform":
            print("Using random/uniform sampler")
            self.sampler = UniformSampler(
                self.dom, batch_size=self.config.training.batch_size
            )
        elif sampler_name == "fixed":
            mesh_res = [51, 51]
            # Report the mesh actually used, not the reference-grid size.
            print("Using fixed mesh sampler", mesh_res)
            self.sampler = MeshSampler(
                self.dom, res=mesh_res, batch_size=self.config.training.batch_size
            )
        else:
            # Fail here instead of leaving self.sampler unset and crashing
            # later with an unrelated AttributeError.
            raise ValueError(
                f"Unknown sampler {sampler_name!r} for Wave1D; "
                "expected 'uniform' or 'fixed'."
            )

        # Grid predictors: given 1-D t and x arrays, return arrays of shape
        # (len(t), len(x)), matching the layout of self.u_ref.
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, state, t, x):
        """Evaluate the network at a single scalar point ``(t, x)``.

        The network input is ``[t, x]``; ``apply_fn`` returns a pair whose
        second element is the output, and its first component is returned.
        """
        z = jnp.stack([t, x])
        # NOTE(review): apply_fn is taken from self.state rather than from the
        # `state` argument; confirm that self.state is always set and shares
        # apply_fn with the state passed in.
        _, u = self.state.apply_fn(state.variables(), z)
        return u[0]

    def r_net(self, state, t, x):
        """Wave-equation residual ``u_tt - 4 u_xx`` at a single point."""
        u_xx = grad(grad(self.u_net, argnums=2), argnums=2)(state, t, x)
        u_tt = grad(grad(self.u_net, argnums=1), argnums=1)(state, t, x)
        return u_tt - 4.0 * u_xx

    def u_t(self, state, t, x):
        """Time derivative ``du/dt`` of the network output at ``(t, x)``."""
        return grad(self.u_net, argnums=1)(state, t, x)

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return the residual of every loss term, keyed like ``loss_keys``.

        ``self`` is a static jit argument, so the grids and ``u0`` it holds are
        captured as constants at trace time.

        Args:
            state: model state forwarded to :meth:`u_net`.
            batch: collocation points; column 0 is ``t``, column 1 is ``x``.

        Returns:
            dict with
              ``ics_u``: ``u(t0, x) - u0(x)`` over ``x_star``;
              ``ics_v``: ``u_t(t0, x)`` over ``x_star`` (target 0);
              ``bcs_r``: ``u(t, x1)`` over ``t_star`` (right boundary, target 0);
              ``bcs_l``: ``u(t, x0)`` over ``t_star`` (left boundary, target 0);
              ``res``: PDE residual at the batch points.
        """
        # Initial conditions at t = t0 along the whole x grid.
        u = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        v = vmap(self.u_t, (None, None, 0))(state, self.t0, self.x_star)
        # Boundary conditions along the whole t grid.
        u_left = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[0])
        u_right = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[-1])
        # PDE residuals at the sampled collocation points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        return {
            "ics_u": u - self.u0,
            "ics_v": v,
            # x_star[-1] (x = 1) is the right boundary, x_star[0] the left.
            "bcs_r": u_right,
            "bcs_l": u_left,
            "res": r_pred,
        }
    


    


    