from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian

from phijax.data import UniformSampler
from phijax.equations.base import IVP#, LocalTimeNormalizedMixin, TimeMarchingRefSliceMixin
from phijax.equations.registry import register_pde
from jax.experimental.jet import jet


def get_chaotic_ks(ref_path: str, nx: int = 251, fraction: float = 1.0):
    """Load the chaotic Kuramoto-Sivashinsky reference solution from a text file.

    The file is parsed with ``numpy.loadtxt``; lines starting with ``%`` are
    treated as comments. Column 2 holds the solution ``u``; columns 0 and 1
    hold the time and space coordinates, in an order detected from the data.

    NOTE(review): the layout is assumed to group ``nx`` consecutive rows into
    one time slice (the same assumption the original ``reshape(-1, 251)``
    made, with row 0 used as the initial condition). This is checked and a
    ``ValueError`` is raised if it does not hold — confirm against the exporter.

    Args:
        ref_path: Path to the reference data file.
        nx: Number of spatial points per time slice (default 251).
        fraction: Fraction in ``(0, 1]`` of the leading time slices to keep;
            at least one slice is always kept.

    Returns:
        ``(u_ref, t_star, x_star)`` with ``u_ref`` of shape ``(nt, nx)``,
        ``t_star`` of shape ``(nt,)`` and ``x_star`` of shape ``(nx,)``.

    Raises:
        ValueError: if ``nx`` or ``fraction`` is out of range, the file has
            fewer than three columns, its row count is not a positive multiple
            of ``nx``, or the rows are not grouped by time slice.
    """
    import numpy as np

    if nx < 1:
        raise ValueError(f"nx must be >= 1, got {nx!r}")
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")

    data = np.loadtxt(ref_path, comments="%", ndmin=2)
    if data.shape[1] < 3:
        raise ValueError(
            f"expected at least 3 columns in {ref_path!r}, got {data.shape[1]}"
        )
    if data.shape[0] == 0 or data.shape[0] % nx != 0:
        raise ValueError(
            f"row count {data.shape[0]} in {ref_path!r} is not a positive "
            f"multiple of nx={nx}"
        )

    # One row of u_ref per time slice; row 0 is the initial condition.
    u_ref = data[:, 2].reshape(-1, nx)
    coords = data[:, :2].reshape(-1, nx, 2)

    # Inside a single time slice the time coordinate is constant while the
    # spatial coordinate varies; use that to tell the two columns apart.
    first_slice = coords[0]
    t_col = 0 if np.all(first_slice[:, 0] == first_slice[0, 0]) else 1
    x_col = 1 - t_col

    t_grid = coords[:, :, t_col]
    if not np.all(t_grid == t_grid[:, :1]):
        raise ValueError(
            f"{ref_path!r} is not laid out as blocks of nx={nx} rows sharing "
            "one time value"
        )

    t_star = t_grid[:, 0]
    x_star = coords[0, :, x_col]

    n_keep = max(1, int(fraction * u_ref.shape[0]))
    return u_ref[:n_keep], t_star[:n_keep], x_star


@register_pde("ks_chaotic", aliases=["kuramoto_sivashinsky_chaotic", "kuramoto-sivashinsky-chaotic"])
class ChaoticKuramotoSivashinsky(IVP):
    """Chaotic Kuramoto-Sivashinsky initial-value problem in (t, x).

    Two loss terms are produced by ``residuals``:

    * ``"ics"``: network prediction minus the reference solution at ``t0``
      on the reference spatial grid ``x_star``.
    * ``"res"``: PDE residual
      ``u_t + (100/16) u u_x + (100/16**2) u_xx + (100/16**4) u_xxxx``
      at sampled collocation points.
    """

    loss_keys = ("ics", "res")

    def __init__(self, config):
        super().__init__(config)

        pcfg = self.config.pde_config
        # `fraction` is optional in the PDE config; when absent every
        # reference time slice is kept.
        fraction = getattr(pcfg, "fraction", 1.0)
        u_ref, t_star, x_star = get_chaotic_ks(pcfg.ref_path, fraction=fraction)

        self.x_star = x_star
        self.t_star = t_star
        self.u_ref = u_ref

        self.t0 = float(t_star[0])
        self.t1 = float(t_star[-1])
        # Row 0 of the reference grid is the initial condition u(t0, x).
        self.u0 = self.u_ref[0, ...]

        x0, x1 = float(self.x_star[0]), float(self.x_star[-1])
        # Sampling domain: row 0 is the t-range, row 1 is the x-range.
        self.dom = jnp.array([[self.t0, self.t1], [x0, x1]])

        self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size)

        # Grid evaluators: outer vmap over t, inner vmap over x, so the
        # output has shape (len(t), len(x)).
        self.u_pred_fn = vmap(vmap(self.neural_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def neural_net(self, state, t, x, y=None):
        """Evaluate the network at a single point and return its first output.

        Time is divided by the final reference time before being fed in.

        Args:
            state: Model state; ``state.variables()`` is passed to ``apply_fn``.
            t: Scalar time coordinate.
            x: Scalar spatial coordinate.
            y: Optional extra coordinate, kept only so the old 4-argument call
                form still works; every caller in this class omits it.

        Returns:
            ``out[0]`` of the network output for input ``[t_scaled, x(, y)]``.
        """
        t_scaled = t / self.t_star[-1]
        coords = [t_scaled, x] if y is None else [t_scaled, x, y]
        z = jnp.stack(coords)
        # NOTE(review): apply_fn comes from self.state while the variables come
        # from the `state` argument — confirm this split is intended.
        _, out = self.state.apply_fn(state.variables(), z)
        return out[0]

    def r_net(self, state, t, x):
        """Return the Kuramoto-Sivashinsky residual at a single (t, x) point.

        residual = u_t + (100/16) u u_x + (100/16**2) u_xx + (100/16**4) u_xxxx
        """
        u = self.neural_net(state, t, x)
        u_t = grad(self.neural_net, argnums=1)(state, t, x)

        def u_of_x(x_):
            return self.neural_net(state, t, x_)

        # Taylor-mode AD along x: the input series (1, 0, 0, 0) is pushed
        # through u_of_x and the four output terms are the x-derivatives of
        # order 1..4. The third derivative is not needed by this equation.
        _, (u_x, u_xx, _u_xxx, u_xxxx) = jet(u_of_x, (x,), [[1.0, 0.0, 0.0, 0.0]])
        return (
            u_t
            + 100.0 / 16.0 * u * u_x
            + 100.0 / 16.0**2 * u_xx
            + 100.0 / 16.0**4 * u_xxxx
        )

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return per-point loss residuals keyed by ``loss_keys``.

        Args:
            state: Model state forwarded to ``neural_net``/``r_net``.
            batch: Collocation points; column 0 is passed as t, column 1 as x.

        Returns:
            Dict with ``"ics"`` (prediction minus reference at ``t0`` over
            ``x_star``) and ``"res"`` (PDE residual at each batch point).
        """
        # Initial-condition mismatch on the reference spatial grid.
        u_pred = vmap(self.neural_net, (None, None, 0))(state, self.t0, self.x_star)
        # PDE residual at the sampled collocation points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])
        return {
            "ics": u_pred - self.u0,
            "res": r_pred,
        }

    def _log_stats(self, state, batch, *args):
        # No-op: this problem logs no extra per-step statistics.
        return

    @partial(jit, static_argnums=(0,))
    def compute_metrics(self, state):
        """Compare the prediction on the full reference grid with ``u_ref``.

        Returns:
            ``(rmse, rmae)``; despite the names these are the *relative* L2
            error ``||u_pred - u_ref|| / ||u_ref||`` and the relative L1 error
            ``sum|u_pred - u_ref| / sum|u_ref|``.
        """
        u_pred = self.u_pred_fn(state, self.t_star, self.x_star)
        diff = u_pred - self.u_ref
        rmse = jnp.linalg.norm(diff) / jnp.linalg.norm(self.u_ref)
        rmae = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.abs(self.u_ref))
        return rmse, rmae

    def log_errors(self, state):
        """Store the relative errors from ``compute_metrics`` in ``self.log_dict``."""
        u_rmse_error, u_rmae_error = self.compute_metrics(state)
        self.log_dict["rmse_error"] = u_rmse_error
        self.log_dict["rmae_error"] = u_rmae_error

    



