from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap

import optax

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP
from phijax.equations.registry import register_pde



def get_dataset(ref_path):
    """Load the Burgers reference solution from a MATLAB ``.mat`` file.

    Args:
        ref_path: Path to a ``.mat`` file that contains the variables
            ``usol``, ``t`` and ``x``.

    Returns:
        Tuple ``(u_ref, t_star, x_star)``: ``usol`` exactly as stored in the
        file, followed by the ``t`` and ``x`` arrays flattened to 1-D.
    """
    from scipy.io import loadmat

    mat = loadmat(ref_path)
    # ``loadmat`` always returns at least 2-D arrays, so flatten the grids.
    t_star, x_star = (mat[key].flatten() for key in ("t", "x"))
    return mat["usol"], t_star, x_star

@register_pde(
    "burgers", aliases=["burgers1d"],
)
class Burgers(IVP):
    """1-D viscous Burgers equation posed as an initial-value problem.

    The PDE residual is ``u_t + u * u_x - nu * u_xx``. Its domain is the
    rectangle spanned by the reference grids ``t_star`` and ``x_star``.
    ``nu`` defaults to ``0.01 / pi`` and can be overridden through
    ``pde_config.viscosity``.

    Loss terms (see ``residuals``):
        ics:   network at ``t0`` minus the first row of the reference solution.
        bcs_l: network value at ``x0`` for every ``t`` in ``t_star``, driven to 0.
        bcs_r: network value at ``x1`` for every ``t`` in ``t_star``, driven to 0.
        res:   PDE residual at the sampled collocation points.
    """

    loss_keys = ("ics", "bcs_l", "bcs_r", "res")
    # Default viscosity nu; __init__ overrides it when pde_config.viscosity is set.
    viscosity = 0.01 / jnp.pi

    def __init__(self, config):
        super().__init__(config)

        pcfg = self.config.pde_config
        viscosity = getattr(pcfg, "viscosity", None)
        if viscosity is not None:
            self.viscosity = float(viscosity)

        u_ref, t_star, x_star = get_dataset(pcfg.ref_path)
        self.u_ref = u_ref
        # Row 0 is compared against the network on x_star in the "ics" loss,
        # so it must line up with x_star (i.e. rows index time).
        self.u0 = u_ref[0, :]
        self.t_star = t_star
        self.x_star = x_star
        self.t0, self.t1 = self.t_star[0], self.t_star[-1]
        self.x0, self.x1 = self.x_star[0], self.x_star[-1]

        # Domain bounds: [[t_min, t_max], [x_min, x_max]].
        self.dom = jnp.array([[self.t0, self.t1], [self.x0, self.x1]])

        training = self.config.training
        if training.get("sampler", None) is None:
            # The default is written back into the config, not just used locally.
            training.sampler = "uniform"
        batch_size = training.batch_size
        if training.sampler == "uniform":
            print("Using random/uniform sampler")
            self.sampler = UniformSampler(self.dom, batch_size=batch_size)
        elif training.sampler == "fixed":
            mesh_res = [100, 200]
            print("Using fixed mesh sampler", mesh_res)
            self.sampler = MeshSampler(self.dom, res=mesh_res, batch_size=batch_size)
        else:
            # Fail fast: otherwise self.sampler is silently left unset.
            raise ValueError(
                f"Unknown sampler {training.sampler!r}; expected 'uniform' or 'fixed'"
            )

        # Predictions over a grid: outer vmap over t, inner vmap over x,
        # so results are indexed [t, x].
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, state, t, x):
        """Evaluate the network at a single point ``(t, x)`` and return a scalar."""
        # Under vmap, t and x arrive as scalars; stack them into one input vector.
        z = jnp.stack([t, x])
        # NOTE(review): the apply function is taken from self.state while the
        # variables come from the `state` argument. Confirm this is intended
        # rather than state.apply_fn.
        _, u = self.state.apply_fn(state.variables(), z)
        return u[0]

    def r_net(self, state, t, x):
        """Return the Burgers residual ``u_t + u * u_x - nu * u_xx`` at ``(t, x)``."""
        u = self.u_net(state, t, x)
        u_t = grad(self.u_net, argnums=1)(state, t, x)
        u_x = self.ux_net(state, t, x)
        u_xx = grad(self.ux_net, argnums=2)(state, t, x)
        return u_t + u * u_x - self.viscosity * u_xx

    def ux_net(self, state, t, x):
        """Return the spatial derivative ``du/dx`` at ``(t, x)``."""
        return grad(self.u_net, argnums=2)(state, t, x)

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute the per-point residuals for every loss term.

        Args:
            state: Model state passed through to ``u_net`` / ``r_net``.
            batch: Collocation points. Column 0 is ``t`` and column 1 is ``x``.

        Returns:
            Dict keyed by ``loss_keys`` with unreduced residual arrays.
        """
        # self is static under jit, so attributes such as self.viscosity are
        # baked in at trace time.
        u_pred = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])

        u_bc_l = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x0)
        u_bc_r = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x1)
        return {
            "ics": u_pred - self.u0,
            "bcs_l": u_bc_l,
            "bcs_r": u_bc_r,
            "res": r_pred,
        }
