from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian

from phijax.data import UniformSampler
from phijax.equations.base import IVP#, LocalTimeNormalizedMixin, TimeMarchingRefSliceMixin
from phijax.equations.registry import register_pde
from jax.experimental.jet import jet


def get_dataset(ref_path: str, fraction: float = 1.0):
    """Load the Kuramoto-Sivashinsky reference solution from a ``.mat`` file.

    Args:
        ref_path: Path to a MATLAB file containing the keys ``usol``, ``t``
            and ``x``.
        fraction: Fraction of the stored time steps to keep, counted from the
            first one. ``1.0`` keeps every step. At least one step is always
            kept so that ``t_star[0]`` / ``u_ref[0]`` remain valid.

    Returns:
        ``(u_ref, t_star, x_star)``: ``usol`` truncated along axis 0 (assumed
        to be the time axis, matching the truncation of ``t``), the truncated
        flattened time grid, and the flattened spatial grid.

    Raises:
        ValueError: If ``fraction`` is not positive.
        TypeError: If ``fraction`` cannot be converted to ``float``.
    """
    fraction = float(fraction)
    if not fraction > 0.0:
        raise ValueError(f"fraction must be positive, got {fraction!r}")

    import scipy.io

    data = scipy.io.loadmat(ref_path)
    u_ref = data["usol"]
    t_star = data["t"].flatten()
    x_star = data["x"].flatten()

    # Only use a leading fraction of the time steps (never fewer than one).
    num_time_steps = max(1, int(fraction * len(t_star)))
    t_star = t_star[:num_time_steps]
    u_ref = u_ref[:num_time_steps, :]

    return u_ref, t_star, x_star


@register_pde("kstm", aliases=["kuramoto_sivashinsky_tm", "kuramoto-sivashinsky_tm"])
class KuramotoSivashinskyTM(IVP):
    """Kuramoto-Sivashinsky IVP on a leading time window of the reference data.

    The reference solution is loaded with ``get_dataset`` and truncated to
    ``pde_config.data_fraction`` (default ``0.1``) of its time steps. Network
    inputs are ``(t / t_star[-1], x)``.
    """

    loss_keys = ("ics", "res")

    def __init__(self, config):
        super().__init__(config)

        pcfg = self.config.pde_config
        fraction = getattr(pcfg, "data_fraction", 0.1)
        u_ref, t_star, x_star = get_dataset(pcfg.ref_path, fraction=fraction)

        # Reference grids and solution (u_ref rows follow t_star, see get_dataset).
        self.x_star = x_star
        self.t_star = t_star
        self.u_ref = u_ref

        self.t0 = float(t_star[0])
        self.t1 = float(t_star[-1])
        # Initial condition: the reference solution at the first time step.
        self.u0 = self.u_ref[0, ...]

        x0, x1 = float(self.x_star[0]), float(self.x_star[-1])
        # Sampling domain: first row is the time interval, second the space interval.
        self.dom = jnp.array([[self.t0, self.t1], [x0, x1]])

        self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size)

        # Grid evaluators: outer vmap over t, inner vmap over x, so results are
        # indexed [t_index, x_index] like u_ref.
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def neural_net(self, state, t, x, y=None):
        """Evaluate the network at a single point and return its first output.

        Time is divided by the last retained reference time before being fed
        to the network. ``y`` is optional and only kept for compatibility with
        the former 4-argument signature; this problem is 1-D in space, so the
        methods of this class call it with ``(t, x)`` only.
        """
        t_scaled = t / self.t_star[-1]
        coords = [t_scaled, x] if y is None else [t_scaled, x, y]
        z = jnp.stack(coords)
        # NOTE(review): the apply function is taken from ``self.state`` while
        # the variables come from the ``state`` argument (kept as before);
        # confirm this is intended.
        _, out = self.state.apply_fn(state.variables(), z)
        return out[0]

    def u_net(self, state, t, x):
        """Scalar network prediction ``u(t, x)``; used by all grid/residual evaluators."""
        return self.neural_net(state, t, x)

    def r_net(self, state, t, x):
        """Return the scaled KS residual at a single point ``(t, x)``.

        residual = u_t + (100/16) u u_x + (100/16^2) u_xx + (100/16^4) u_xxxx
        """
        u = self.u_net(state, t, x)

        # argnums=1 differentiates w.r.t. t (argument 0 is the state).
        u_t = grad(self.u_net, argnums=1)(state, t, x)

        def u_of_x(x_):
            return self.u_net(state, t, x_)

        # Feeding the input series [1, 0, 0, 0] (x -> x + s) to jet makes the
        # output series the 1st..4th derivatives of u with respect to x.
        _, (u_x, u_xx, _u_xxx, u_xxxx) = jet(u_of_x, (x,), [[1.0, 0.0, 0.0, 0.0]])
        return (
            u_t
            + 100.0 / 16.0 * u * u_x
            + 100.0 / 16.0**2 * u_xx
            + 100.0 / 16.0**4 * u_xxxx
        )

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return per-point residuals keyed like ``loss_keys``.

        ``"ics"``: prediction at ``t0`` on the reference x-grid minus ``u0``.
        ``"res"``: PDE residual at the sampled points; column 0 of ``batch``
        is t and column 1 is x.
        """
        # Initial condition residual.
        u_pred = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        # PDE residual at the collocation points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        return {
            "ics": u_pred - self.u0,
            "res": r_pred,
        }

    def _log_stats(self, state, batch, *args):
        """No-op override of the base-class statistics hook."""
        return

    @partial(jit, static_argnums=(0,))
    def compute_metrics(self, state):
        """Return ``(relative L2 error, relative L1 error)`` on the reference grid.

        Note: the first value is named "rmse" elsewhere but is the relative
        L2 norm ``||u_pred - u_ref|| / ||u_ref||``, not a root-mean-square error.
        """
        u_pred = self.u_pred_fn(state, self.t_star, self.x_star)

        rmse = jnp.linalg.norm(u_pred - self.u_ref) / jnp.linalg.norm(self.u_ref)
        rmae = jnp.sum(jnp.abs(u_pred - self.u_ref)) / jnp.sum(jnp.abs(self.u_ref))
        return rmse, rmae

    def log_errors(self, state):
        """Store the reference-grid error metrics in ``self.log_dict``."""
        u_rmse_error, u_rmae_error = self.compute_metrics(state)
        self.log_dict["rmse_error"] = u_rmse_error
        self.log_dict["rmae_error"] = u_rmae_error


@register_pde("ks", aliases=["kuramoto_sivashinsky", "kuramoto-sivashinsky"])
class KuramotoSivashinsky(IVP):
    """Kuramoto-Sivashinsky IVP over the full reference dataset.

    Network inputs are the raw ``(t, x)`` coordinates (time is not rescaled).
    """

    loss_keys = ("ics", "res")

    def __init__(self, config):
        super().__init__(config)

        pcfg = self.config.pde_config
        # Load every stored time step through the shared loader.
        u_ref, t_star, x_star = get_dataset(pcfg.ref_path, fraction=1.0)

        # Reference grids and solution (u_ref rows follow t_star, see get_dataset).
        self.x_star = x_star
        self.t_star = t_star
        self.u_ref = u_ref

        self.t0 = float(t_star[0])
        self.t1 = float(t_star[-1])
        # Initial condition: the reference solution at the first time step.
        self.u0 = self.u_ref[0, ...]

        x0, x1 = float(self.x_star[0]), float(self.x_star[-1])
        # Sampling domain: first row is the time interval, second the space interval.
        self.dom = jnp.array([[self.t0, self.t1], [x0, x1]])

        self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size)

        # Grid evaluators: outer vmap over t, inner vmap over x.
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, state, t, x):
        """Scalar network prediction ``u(t, x)`` using unscaled coordinates."""
        z = jnp.stack([t, x])
        # NOTE(review): the apply function is taken from ``self.state`` while
        # the variables come from the ``state`` argument (kept as before);
        # confirm this is intended.
        _, out = self.state.apply_fn(state.variables(), z)
        return out[0]

    def r_net(self, state, t, x):
        """Return the scaled KS residual at a single point ``(t, x)``.

        residual = u_t + (100/16) u u_x + (100/16^2) u_xx + (100/16^4) u_xxxx
        """
        u = self.u_net(state, t, x)

        # argnums=1 differentiates w.r.t. t (argument 0 is the state).
        u_t = grad(self.u_net, argnums=1)(state, t, x)

        def u_of_x(x_):
            return self.u_net(state, t, x_)

        # Feeding the input series [1, 0, 0, 0] (x -> x + s) to jet makes the
        # output series the 1st..4th derivatives of u with respect to x.
        _, (u_x, u_xx, _u_xxx, u_xxxx) = jet(u_of_x, (x,), [[1.0, 0.0, 0.0, 0.0]])
        return (
            u_t
            + 100.0 / 16.0 * u * u_x
            + 100.0 / 16.0**2 * u_xx
            + 100.0 / 16.0**4 * u_xxxx
        )

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return per-point residuals keyed like ``loss_keys``.

        ``"ics"``: prediction at ``t0`` on the reference x-grid minus ``u0``.
        ``"res"``: PDE residual at the sampled points; column 0 of ``batch``
        is t and column 1 is x.
        """
        # Initial condition residual.
        u_pred = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        # PDE residual at the collocation points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        return {
            "ics": u_pred - self.u0,
            "res": r_pred,
        }

    # NOTE: compute_metrics / log_errors are not overridden here; a previous
    # implementation (relative L2 / L1 error of u_pred_fn against u_ref over
    # the full grid) was commented out. KuramotoSivashinskyTM implements the
    # same metrics if they need to be re-enabled.

    



