from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap
from jax.tree_util import tree_map, tree_reduce, tree_leaves

import optax

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP
from phijax.equations.registry import register_pde





@register_pde(
    "reaction", aliases=["reaction1d"],
    defaults={"epsilon": 5.0, "num_points_per_dim": 256}
)
class Reaction1D(IVP):
    """1-D reaction (logistic) equation posed as an initial-value problem.

    PDE residual (see ``r_net``)::

        u_t - epsilon * u * (1 - u) = 0,   t in [0, 1],  x in [0, 2*pi]

    The initial condition is a Gaussian bump centred at ``pi`` with standard
    deviation ``pi / 4``. The boundary condition is periodic in ``x``,
    enforced as ``u(t, x0) - u(t, x1) = 0``.

    Loss terms returned by ``residuals`` (keys match ``loss_keys``):
        ics: network at ``t0`` minus the initial profile on ``x_star``.
        bcs: mismatch between the two spatial boundaries over ``t_star``.
        res: PDE residual at the sampled collocation points.
    """

    loss_keys = ("ics", "bcs", "res")

    def __init__(self, config):
        super().__init__(config)

        pcfg = self.config.pde_config

        num_pts = pcfg.num_points_per_dim or 101
        # Using `or` here would silently turn a legitimate epsilon of 0.0 into
        # 5.0, so only fall back when the value is actually missing.
        self.epsilon = 5.0 if pcfg.epsilon is None else pcfg.epsilon

        # Evaluation grid: num_pts points in each of t and x.
        x = jnp.linspace(0, 2 * jnp.pi, num_pts)
        t = jnp.linspace(0, 1.0, num_pts)
        # indexing="ij" -> arrays of shape (len(t), len(x)).
        tt, xx = jnp.meshgrid(t, x, indexing="ij")

        def h(s):
            """Gaussian initial profile centred at pi with std pi / 4."""
            return jnp.exp(-1.0 * (s - jnp.pi) ** 2 / (2 * jnp.square(jnp.pi / 4.0)))

        # Closed-form solution of the logistic ODE u_t = eps * u * (1 - u),
        # u(0) = h(x), applied pointwise in x (the PDE has no x-derivatives).
        h_xx = h(xx)
        growth = jnp.exp(self.epsilon * tt)
        u_ref = h_xx * growth / (h_xx * (growth - 1.0) + 1.0)

        self.u0 = h(x)
        self.t_star = t
        self.x_star = x
        self.t0, self.t1 = self.t_star[0], self.t_star[-1]
        self.x0, self.x1 = self.x_star[0], self.x_star[-1]

        # Rectangular domain [[t0, t1], [x0, x1]] handed to the samplers.
        self.dom = jnp.array([[self.t0, self.t1], [self.x0, self.x1]])
        self.u_ref = u_ref

        # Default to the uniform sampler when none is configured (this mutates
        # the config so downstream readers see the resolved choice).
        if self.config.training.get("sampler", None) is None:
            self.config.training.sampler = "uniform"
        sampler_name = self.config.training.sampler
        batch_size = self.config.training.batch_size
        if sampler_name == "uniform":
            print("Using random/uniform sampler")
            self.sampler = UniformSampler(self.dom, batch_size=batch_size)
        elif sampler_name == "fixed":
            print("Using fixed mesh sampler", num_pts)
            self.sampler = MeshSampler(self.dom, res=[num_pts, num_pts], batch_size=batch_size)
        # NOTE(review): any other sampler name leaves self.sampler unset here;
        # presumably the IVP base class covers that case -- confirm.

        # Predictions over a grid: outer vmap over t, inner vmap over x, so the
        # output has shape (len(t), len(x)), matching u_ref.
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, state, t, x):
        """Return the network's scalar prediction u(t, x) for scalar t, x."""
        z = jnp.stack([t, x])
        # NOTE(review): the forward function comes from ``self.state`` while the
        # parameters come from the ``state`` argument; confirm this split is
        # intentional (e.g. ``state`` may not expose ``apply_fn``).
        _, u = self.state.apply_fn(state.variables(), z)
        return u[0]

    def r_net(self, state, t, x):
        """Return the PDE residual u_t - epsilon * u * (1 - u) at (t, x)."""
        u = self.u_net(state, t, x)
        return self.u_t(state, t, x) - self.epsilon * u * (1 - u)

    def u_t(self, state, t, x):
        """Return du/dt at (t, x), differentiating ``u_net`` w.r.t. ``t``."""
        return grad(self.u_net, argnums=1)(state, t, x)

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute the per-point residuals for every loss term.

        Args:
            state: model state forwarded to ``u_net``.
            batch: collocation points; column 0 is read as t and column 1 as
                x (assumes shape (N, 2) -- TODO confirm against the sampler).

        Returns:
            dict with keys "ics", "bcs", "res" (matching ``loss_keys``).
        """
        # Initial condition: u(t0, x) on the spatial grid vs. the initial profile.
        u = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        # Periodic boundary condition: u(t, x0) should equal u(t, x1).
        u_bc1_pred = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[0])
        u_bc2_pred = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[-1])
        # PDE residual at the sampled interior points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        residuals = {
            "ics": u - self.u0,
            "bcs": u_bc1_pred - u_bc2_pred,
            "res": r_pred,
        }
        return residuals
    


    


    