from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap

import optax

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP
from phijax.equations.registry import register_pde
from jax.experimental.jet import jet



def get_dataset(ref_path):
    """Load the reference solution stored in a MATLAB ``.mat`` file.

    Args:
        ref_path: path to a ``.mat`` file containing the variables
            ``usol``, ``t`` and ``x``.

    Returns:
        A tuple ``(u_ref, t_star, x_star)``. ``u_ref`` is the ``usol`` array
        exactly as loaded; ``t_star`` and ``x_star`` are ``t`` and ``x``
        flattened to 1-D (``loadmat`` returns them as 2-D matrices).
    """
    from scipy.io import loadmat

    mat = loadmat(ref_path)
    solution = mat["usol"]
    times, coords = (mat[key].flatten() for key in ("t", "x"))
    return solution, times, coords


@register_pde(
    "kdv", aliases=["kortewegdevries"],
    #defaults={ "epsilon": 50.0 , "num_points_per_dim": 256}
)
class KdV(IVP):
    """Korteweg-de Vries initial-value problem.

    The PDE residual enforced at collocation points is

        u_t + eta * u * u_x + mu**2 * u_xxx

    with ``eta = 1.0`` and ``mu = 0.022`` by default (override the class
    attributes in a subclass to change them). The initial condition is taken
    from the first row of the reference solution loaded from
    ``config.pde_config.ref_path``.
    """

    loss_keys = ("ics", "res")

    # PDE coefficients; kept as class attributes so variants can override
    # them without rewriting ``r_net``.
    eta = 1.0
    mu = 0.022

    def __init__(self, config):
        super().__init__(config)
        pcfg = self.config.pde_config
        num_pts = pcfg.num_points_per_dim or 256
        u_ref, t_star, x_star = get_dataset(pcfg.ref_path)
        self.u_ref = u_ref
        # Row 0 is used as the initial condition; this assumes ``usol`` is
        # laid out (time, space) -- TODO confirm for any new reference file.
        self.u0 = u_ref[0, :]
        self.t_star = t_star
        self.x_star = x_star
        self.t0, self.t1 = self.t_star[0], self.t_star[-1]
        self.x0, self.x1 = self.x_star[0], self.x_star[-1]
        # Domain bounds, one row per input dimension: [[t0, t1], [x0, x1]].
        self.dom = jnp.array([[self.t0, self.t1], [self.x0, self.x1]])

        if self.config.training.get('sampler', None) is None:
            self.config.training.sampler = "uniform"
        sampler = self.config.training.sampler
        if sampler == "uniform":
            print("Using random/uniform sampler")
            self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size)
        elif sampler == "fixed":
            print("Using fixed mesh sampler", num_pts)
            self.sampler = MeshSampler(self.dom, res=[100, 200], batch_size=self.config.training.batch_size)
        elif not hasattr(self, "sampler"):
            # Previously an unknown name silently left ``self.sampler`` unset,
            # surfacing later as an opaque AttributeError during training.
            raise ValueError(
                f"Unknown training.sampler {sampler!r} for KdV; "
                "expected 'uniform' or 'fixed'"
            )

        # Predictions over a (t, x) grid: the outer vmap maps over t and the
        # inner over x, so outputs have shape (len(t), len(x)).
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, state, t, x):
        """Evaluate the network at a single point (t, x) and return scalar u."""
        z = jnp.stack([t, x])
        # NOTE(review): apply_fn comes from ``self.state`` while the variables
        # come from the ``state`` argument -- confirm this split is intended.
        _, u = self.state.apply_fn(state.variables(), z)
        return u[0]

    def r_net(self, state, t, x):
        """Return the KdV residual u_t + eta*u*u_x + mu**2*u_xxx at (t, x)."""
        u_t = grad(self.u_net, argnums=1)(state, t, x)
        u_fn = lambda x_: self.u_net(state, t, x_)
        # Taylor-mode AD: pushing the input series (1, 0, 0) through u_fn yields
        # the first three x-derivatives in a single pass. The primal output is
        # u itself, so no separate forward evaluation is needed.
        u, (u_x, _u_xx, u_xxx) = jet(u_fn, (x,), [[1.0, 0.0, 0.0]])
        return u_t + self.eta * u * u_x + self.mu ** 2 * u_xxx

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute per-point residuals for each entry of ``loss_keys``.

        Args:
            state: model state forwarded to ``u_net`` / ``r_net``.
            batch: collocation points; column 0 holds t and column 1 holds x.

        Returns:
            dict with ``"ics"`` (prediction minus ``u0`` at ``t0`` on the
            ``x_star`` grid) and ``"res"`` (PDE residual at each batch point).
        """
        # ``self`` is a static jit argument, so its attributes (t0, x_star, u0)
        # are treated as trace-time constants.
        u_pred = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        # No boundary-condition ("bcs") term is included for this problem.
        return {
            "ics": u_pred - self.u0,
            "res": r_pred,
        }
        
    



    