from functools import partial

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian

from phijax.data import UniformSampler
from phijax.equations.base import IVP#, LocalTimeNormalizedMixin, TimeMarchingRefSliceMixin
from phijax.equations.registry import register_pde


def get_dataset(ref_path: str, fraction=(0, 1)):
    """Load the reference solution from a ``.mat`` file and cut out a time window.

    Args:
        ref_path: Path of a MATLAB file providing the keys ``usol``, ``vsol``,
            ``t``, ``x``, ``y``, ``eps`` and ``k``.
        fraction: Two-element sequence ``(start, end)``. Each entry is
            multiplied by the number of time samples and truncated with
            ``int`` to obtain the slice bounds applied to the first axis of
            ``usol`` and ``vsol``. A list works as well as a tuple.

    Returns:
        The tuple ``(u_ref, v_ref, t_star, x_star, y_star, eps, k)``:
        ``u_ref`` / ``v_ref`` are the sliced fields, ``t_star`` / ``x_star`` /
        ``y_star`` are flattened coordinate vectors, and ``eps`` / ``k`` are
        the first elements of the stored coefficient arrays.
    """
    # Imported lazily so that importing this module does not require SciPy.
    import scipy.io

    mat = scipy.io.loadmat(ref_path)

    t_star = mat["t"].flatten()
    x_star = mat["x"].flatten()
    y_star = mat["y"].flatten()

    num_samples = len(t_star)
    start_time_step = int(fraction[0] * num_samples)
    end_time_step = int(fraction[1] * num_samples)
    num_time_steps = end_time_step - start_time_step

    u_ref = mat["usol"][start_time_step:end_time_step, :, :]
    v_ref = mat["vsol"][start_time_step:end_time_step, :, :]

    # NOTE(review): the time vector is taken from the *head* of the array
    # (same length as the window) instead of [start:end], so the returned
    # times always begin at t[0]. This looks like an intentional window-local
    # time axis; confirm before changing it.
    t_star = t_star[:num_time_steps]

    eps = mat["eps"].flatten()[0]
    k = mat["k"].flatten()[0]
    return u_ref, v_ref, t_star, x_star, y_star, eps, k


@register_pde("ginzburg_landau", aliases=["gl", "ginzburg", "ginzburglandau"])
class GinzburgLandau(IVP):
    """Two-field (u, v) Ginzburg-Landau initial value problem on a (t, x, y) box.

    The residuals enforced by ``r_net`` are

        u_t - eps * (u_xx + u_yy) - k * (u - u*r2 + 1.5*v*r2)
        v_t - eps * (v_xx + v_yy) - k * (v - v*r2 - 1.5*u*r2)

    with ``r2 = u*u + v*v``. Reference data, coordinates and default
    coefficients come from the ``.mat`` file named by ``pde_config.ref_path``.
    """

    loss_keys = ("uics", "vics", "ru", "rv")

    def __init__(self, config):
        """Load reference data and build the sampler and grid evaluators.

        Args:
            config: Experiment configuration. Reads ``pde_config.ref_path``,
                the optional ``pde_config.data_fraction`` / ``eps`` / ``k``
                overrides, and ``training.batch_size``.
        """
        super().__init__(config)

        pde_cfg = self.config.pde_config
        window = getattr(pde_cfg, "data_fraction", [0.0, 1.0])
        (
            u_ref,
            v_ref,
            t_star,
            x_star,
            y_star,
            eps_file,
            k_file,
        ) = get_dataset(pde_cfg.ref_path, fraction=window)

        # Coordinates of the reference grid.
        self.x_star = x_star
        self.y_star = y_star

        # Coefficients: a config attribute wins over the value stored in the file.
        eps_fallback = 1e-2 if eps_file is None else eps_file
        k_fallback = 1.0 if k_file is None else k_file
        self.eps = float(getattr(pde_cfg, "eps", eps_fallback))
        self.k = float(getattr(pde_cfg, "k", k_fallback))

        # Reference fields and the time axis they live on.
        self.t_star = t_star
        self.u_ref = u_ref
        self.v_ref = v_ref
        self.t0 = float(t_star[0])
        self.t1 = float(t_star[-1])

        # Initial condition = first time slice of the reference fields.
        self.u0 = self.u_ref[0, ...]
        self.v0 = self.v_ref[0, ...]

        x_lo, x_hi = float(self.x_star[0]), float(self.x_star[-1])
        y_lo, y_hi = float(self.y_star[0]), float(self.y_star[-1])

        # Sampling box, one [low, high] row per coordinate in (t, x, y) order.
        self.dom = jnp.array([[self.t0, self.t1], [x_lo, x_hi], [y_lo, y_hi]])
        self.sampler = UniformSampler(self.dom, batch_size=self.config.training.batch_size)

        def over_xy(point_fn):
            # Inner map runs over y, outer map over x; t stays a scalar.
            along_y = vmap(point_fn, (None, None, None, 0))
            return vmap(along_y, (None, None, 0, None))

        def over_txy(point_fn):
            # Adds an outermost map over t on top of the (x, y) grid map.
            return vmap(over_xy(point_fn), (None, 0, None, None))

        self.u0_pred_fn = over_xy(self.u_net)
        self.v0_pred_fn = over_xy(self.v_net)
        self.u_pred_fn = over_txy(self.u_net)
        self.v_pred_fn = over_txy(self.v_net)

    def neural_net(self, state, t, x, y):
        """Evaluate the network at one (t, x, y) point and return ``(u, v)``.

        The time input is divided by the last entry of ``self.t_star`` before
        being stacked with ``x`` and ``y``.
        """
        features = jnp.stack([t / self.t_star[-1], x, y])
        # apply_fn returns a pair; only its second entry is used here.
        _, prediction = self.state.apply_fn(state.variables(), features)
        return prediction[0], prediction[1]

    def u_net(self, state, t, x, y):
        """Return only the first network output, ``u``."""
        return self.neural_net(state, t, x, y)[0]

    def v_net(self, state, t, x, y):
        """Return only the second network output, ``v``."""
        return self.neural_net(state, t, x, y)[1]

    def r_net(self, state, t, x, y):
        """Return the PDE residual pair ``(ru, rv)`` at one (t, x, y) point."""
        point = (state, t, x, y)
        u, v = self.neural_net(*point)

        du_dt = grad(self.u_net, argnums=1)(*point)
        dv_dt = grad(self.v_net, argnums=1)(*point)

        # One nested tuple per output, indexed [i][j] over the (x, y) arguments.
        hess_u, hess_v = hessian(self.neural_net, argnums=(2, 3))(*point)
        lap_u = hess_u[0][0] + hess_u[1][1]
        lap_v = hess_v[0][0] + hess_v[1][1]

        amp_sq = u * u + v * v
        react_u = self.k * (u - u * amp_sq + 1.5 * v * amp_sq)
        react_v = self.k * (v - v * amp_sq - 1.5 * u * amp_sq)

        ru = du_dt - self.eps * lap_u - react_u
        rv = dv_dt - self.eps * lap_v - react_v
        return ru, rv

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return the residual arrays for every entry of ``loss_keys``.

        Args:
            state: Training state handed through to the network.
            batch: 2-D array whose columns 0, 1, 2 are read as t, x, y.

        Returns:
            Dict with the initial-condition mismatches on the full (x, y)
            grid (``uics``, ``vics``) and the PDE residuals on the batch
            (``ru``, ``rv``).
        """
        grid_x, grid_y = jnp.meshgrid(self.x_star, self.y_star, indexing="ij")
        x_flat = grid_x.reshape(-1)
        y_flat = grid_y.reshape(-1)
        t_init = jnp.full_like(x_flat, self.t0)

        per_point = (None, 0, 0, 0)
        u_init = vmap(self.u_net, per_point)(state, t_init, x_flat, y_flat)
        v_init = vmap(self.v_net, per_point)(state, t_init, x_flat, y_flat)
        u_init = u_init.reshape(self.u0.shape)
        v_init = v_init.reshape(self.v0.shape)

        pde_u, pde_v = vmap(self.r_net, per_point)(
            state, batch[:, 0], batch[:, 1], batch[:, 2]
        )

        out = {}
        out["uics"] = u_init - self.u0
        out["vics"] = v_init - self.v0
        out["ru"] = pde_u
        out["rv"] = pde_v
        return out

    def _log_stats(self, state, batch, *args):
        """Intentionally does nothing and returns ``None``."""
        return None

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, state):
        """Return the relative L2 errors ``(u_error, v_error)`` on the reference grid."""

        def relative_l2(grid_fn, reference):
            predicted = grid_fn(state, self.t_star, self.x_star, self.y_star)
            return jnp.linalg.norm(predicted - reference) / jnp.linalg.norm(reference)

        u_error = relative_l2(self.u_pred_fn, self.u_ref)
        v_error = relative_l2(self.v_pred_fn, self.v_ref)
        return u_error, v_error

    def log_errors(self, state):
        """Store the values from ``compute_l2_error`` in ``self.log_dict``.

        The keys say "rmse", but the stored numbers are the relative L2
        errors returned by ``compute_l2_error``.
        """
        u_err, v_err = self.compute_l2_error(state)
        for key, value in (("u_rmse_error", u_err), ("v_rmse_error", v_err)):
            self.log_dict[key] = value
    


@register_pde("ginzburg_landau_tm", aliases=["gl_tm", "tmgl", "curriculum_ginzburg", "cgl", "cginzburg", "cginzburglandau"])
class GinzburgLandauTM(GinzburgLandau):
    """Ginzburg-Landau problem evaluated on a replaceable time window.

    Compared with ``GinzburgLandau``, the network's time input is normalised
    by the last entry of ``wt_star`` and the error metric is measured against
    ``u_ref_window`` / ``v_ref_window``. All three are replaced by
    ``set_initial_condition``.

    ``u_net``, ``v_net`` and ``r_net`` are inherited unchanged: they reach the
    network through ``self.neural_net`` and therefore use the override
    defined here.
    """

    def __init__(self, config):
        """Build the parent problem and seed the window with the loaded data.

        Seeding keeps ``neural_net`` and ``compute_l2_error`` usable before
        the first ``set_initial_condition`` call; without it those methods
        fail with ``AttributeError`` because the window attributes do not
        exist yet.
        """
        super().__init__(config)
        self.wt_star = self.t_star
        self.u_ref_window = self.u_ref
        self.v_ref_window = self.v_ref

    def neural_net(self, state, t, x, y):
        """Evaluate the network at one (t, x, y) point and return ``(u, v)``.

        The time input is divided by the last entry of the current window's
        time vector ``self.wt_star``.
        """
        t = t / self.wt_star[-1]
        inputs = jnp.stack([t, x, y])
        # apply_fn returns a pair; only its second entry is used here.
        _, outputs = self.state.apply_fn(state.variables(), inputs)

        u = outputs[0]
        v = outputs[1]
        return u, v

    def ru_net(self, state, t, x, y):
        """Return only the ``u`` component of the PDE residual."""
        ru, _ = self.r_net(state, t, x, y)
        return ru

    def rv_net(self, state, t, x, y):
        """Return only the ``v`` component of the PDE residual."""
        _, rv = self.r_net(state, t, x, y)
        return rv

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, state):
        """Return the relative L2 errors ``(u_error, v_error)`` on the current window."""
        u_pred = self.u_pred_fn(state, self.wt_star, self.x_star, self.y_star)
        v_pred = self.v_pred_fn(state, self.wt_star, self.x_star, self.y_star)

        u_error = jnp.linalg.norm(u_pred - self.u_ref_window) / jnp.linalg.norm(self.u_ref_window)
        v_error = jnp.linalg.norm(v_pred - self.v_ref_window) / jnp.linalg.norm(self.v_ref_window)
        return u_error, v_error

    def set_initial_condition(self, u0, v0, u_star, v_star, window_t_star, *args):
        """Install the initial condition and reference data of a new window.

        Args:
            u0: Initial ``u`` field compared against in the ``uics`` residual.
            v0: Initial ``v`` field compared against in the ``vics`` residual.
            u_star: Reference ``u`` values used by ``compute_l2_error``.
            v_star: Reference ``v`` values used by ``compute_l2_error``.
            window_t_star: Time vector of the window; its last entry
                normalises the network's time input.
            *args: Accepted and ignored.
        """
        # NOTE(review): ``residuals`` (inherited) and ``compute_l2_error`` are
        # jitted with ``self`` as a static argument, so arrays read from
        # ``self`` are captured when the function is traced. Unless the
        # instance's hash/equality changes or a fresh instance is built per
        # window, an already-compiled call may keep using the previous
        # window's data. Confirm how the training loop handles this.
        self.u0 = u0
        self.v0 = v0
        self.u_ref_window = u_star
        self.v_ref_window = v_star
        self.wt_star = window_t_star


