from functools import partial

import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap

import optax

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP
from phijax.equations.registry import register_pde



def get_convection_data(batch_size: int = 256, epsilon: float = 50.0):
    """Build the reference solution ``sin(x - epsilon * t)`` on a regular grid.

    Args:
        batch_size: Number of grid points used along both the time axis
            (``[0, 1]``) and the space axis (``[0, 2*pi]``).
        epsilon: Coefficient multiplying ``t`` inside the sine.

    Returns:
        A tuple ``(u_ref, t, x, u0)`` where ``u_ref`` has shape
        ``(batch_size, batch_size)`` indexed as ``[time, space]``, ``t`` and
        ``x`` are the 1-D grids, and ``u0 = sin(x)`` is the profile at ``t = 0``.
    """
    t_grid = jnp.linspace(0, 1.0, batch_size)
    x_grid = jnp.linspace(0, 2 * jnp.pi, batch_size)

    # Broadcasting a column of times against a row of positions gives the same
    # [time, space] layout as meshgrid(t, x, indexing="ij").
    phase = x_grid[None, :] - epsilon * t_grid[:, None]
    reference = jnp.sin(phase)

    initial_profile = jnp.sin(x_grid)
    return reference, t_grid, x_grid, initial_profile

@register_pde(
    "convection",
    aliases=["1d_convection", "conv"],
    defaults={"epsilon": 50.0, "num_points_per_dim": 256},
)
class Convection(IVP):
    """1-D convection problem ``u_t + epsilon * u_x = 0``.

    The problem is posed on ``t in [0, 1]``, ``x in [0, 2*pi]`` with the
    initial profile ``sin(x)``. The reference solution stored on the instance
    is ``sin(x - epsilon * t)``.

    Attributes set here:
        epsilon: Convection speed read from ``config.pde_config``.
        t_star, x_star: 1-D evaluation grids in time and space.
        t0, t1, x0, x1: End points of those grids.
        dom: ``[[t0, t1], [x0, x1]]`` array handed to the sampler.
        u0: Initial profile on ``x_star``.
        u_ref: Reference solution indexed as ``[time, space]``.
        sampler: Collocation-point sampler (only assigned here when the
            configured sampler name is ``"uniform"`` or ``"fixed"``).
        u_pred_fn, r_pred_fn: ``u_net`` / ``r_net`` vectorised over a
            ``(t, x)`` grid.
    """

    def __init__(self, config):
        """Set up grids, reference data, the sampler and grid predictors.

        Args:
            config: Experiment configuration. Reads
                ``pde_config.num_points_per_dim``, ``pde_config.epsilon``,
                ``training.sampler`` and ``training.batch_size``.
        """
        super().__init__(config)

        pcfg = self.config.pde_config

        # NOTE: `or` means a falsy configured value (None, 0) falls back to
        # the default.
        num_pts = pcfg.num_points_per_dim or 256
        # Kept on the instance so the PDE residual in `r_net` uses the same
        # speed as the reference solution below.
        self.epsilon = pcfg.epsilon or 50.0

        # Reference data on a regular (t, x) grid.
        u_ref, t, x, u0 = get_convection_data(
            batch_size=num_pts, epsilon=self.epsilon
        )
        self.u0 = u0
        self.u_ref = u_ref
        self.t_star = t
        self.x_star = x
        self.t0, self.t1 = self.t_star[0], self.t_star[-1]
        self.x0, self.x1 = self.x_star[0], self.x_star[-1]

        self.dom = jnp.array([[self.t0, self.t1], [self.x0, self.x1]])

        # Default to the uniform sampler when none is configured.
        if self.config.training.get('sampler', None) is None:
            self.config.training.sampler = "uniform"

        sampler_name = self.config.training.sampler
        if sampler_name == "uniform":
            print("Using random/uniform sampler")
            self.sampler = UniformSampler(
                self.dom, batch_size=self.config.training.batch_size
            )
        elif sampler_name == "fixed":
            print("Using fixed mesh sampler", num_pts)
            self.sampler = MeshSampler(
                self.dom, res=[51, 51], batch_size=self.config.training.batch_size
            )

        # Predictions over a grid: the inner vmap runs over x, the outer over
        # t, so outputs are indexed as [time, space].
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

        # Report the model variables and, if present, the rotational-layer flag.
        print(self.state.variables().keys())
        if hasattr(self.state, "use_rot"):
            print("Using Rotational Layers in the model", self.state.use_rot)

    def u_net(self, state, t, x):
        """Evaluate the network solution at a single point ``(t, x)``.

        Args:
            state: Object exposing ``variables()``; its result is passed to
                ``self.state.apply_fn``.
            t: Scalar time coordinate.
            x: Scalar space coordinate.

        Returns:
            The first component of the second output of ``apply_fn``.
        """
        z = jnp.stack([t, x])
        # apply_fn returns a pair; only the second entry is used here.
        _, u = self.state.apply_fn(state.variables(), z)
        return u[0]

    def r_net(self, state, t, x):
        """Evaluate the PDE residual ``u_t + epsilon * u_x`` at ``(t, x)``.

        Args:
            state: Same object accepted by ``u_net``.
            t: Scalar time coordinate.
            x: Scalar space coordinate.

        Returns:
            The scalar residual of the convection equation.
        """
        # One grad call yields both partial derivatives.
        u_t, u_x = grad(self.u_net, argnums=(1, 2))(state, t, x)
        # Use the configured speed rather than a literal so the residual stays
        # consistent with the reference solution sin(x - epsilon * t).
        return u_t + self.epsilon * u_x

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute the residual vectors for every loss term.

        Args:
            state: Same object accepted by ``u_net``.
            batch: 2-D array of collocation points; column 0 is used as ``t``
                and column 1 as ``x``.

        Returns:
            Dict with keys ``"ics"`` (prediction at ``t0`` minus ``u0`` on
            ``x_star``), ``"bcs"`` (prediction at ``x_star[0]`` minus
            prediction at ``x_star[-1]`` for every time in ``t_star``) and
            ``"res"`` (PDE residual at the batch points).
        """
        # Initial condition: network at t0 along the whole spatial grid.
        u_pred = vmap(self.u_net, (None, None, 0))(state, self.t0, self.x_star)
        # Boundary condition: values at the two spatial end points are compared.
        u_bc1_pred = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[0])
        u_bc2_pred = vmap(self.u_net, (None, 0, None))(state, self.t_star, self.x_star[-1])
        # PDE residuals at the sampled collocation points.
        r_pred = vmap(self.r_net, (None, 0, 0))(state, batch[:, 0], batch[:, 1])

        return {
            "ics": u_pred - self.u0,
            "bcs": u_bc1_pred - u_bc2_pred,
            "res": r_pred,
        }
    

@register_pde(
    "spectral_convection",
    aliases=["sconv", "spectral_conv"],
    defaults={"epsilon": 50.0, "num_points_per_dim": 256},
)
class SpectralConvection(Convection):
    """Convection problem with an additional ``"spec"`` residual term.

    The extra term is a scalar computed from the network's feature output at
    the collocation points (see ``spectral_reg``).
    """

    loss_keys = ("ics", "bcs", "res", "spec")

    def __init__(self, config):
        """Initialise exactly as ``Convection`` does."""
        super().__init__(config)
        # NOTE: an explicit per-term weight override is currently disabled:
        # self.state = self.state.replace(weights={"ics": 1.0, "bcs": 1.0, "res": 1.0, "spec": 0.01})

    def _features(self, state, z):
        """Return the first of the two outputs of ``apply_fn`` at input ``z``."""
        features, _unused = self.state.apply_fn(state.variables(), z)
        return features

    def spectral_reg(self, state, Z, eps=1e-12):
        """Compute ``trace(C @ C) / (trace(C)**2 + eps)`` for the feature matrix.

        ``C`` is ``F.T @ F / n`` where ``F`` stacks the row-normalised feature
        vectors of the ``n`` points in ``Z``.

        Args:
            state: Object exposing ``variables()`` for ``apply_fn``.
            Z: Batch of network inputs; mapped over its leading axis.
            eps: Small constant added to the row norms and to the denominator
                to avoid division by zero.

        Returns:
            A scalar ratio.
        """
        # assumes each feature vector is 1-D so the stacked result is (n, d)
        raw = vmap(self._features, in_axes=(None, 0))(state, Z)
        row_norms = jnp.linalg.norm(raw, axis=1, keepdims=True)
        unit = raw / (row_norms + eps)

        num_points = unit.shape[0]
        second_moment = (unit.T @ unit) / num_points

        trace_first = jnp.trace(second_moment)
        # Elementwise product with the transpose sums to trace(C @ C).
        trace_second = jnp.sum(second_moment * second_moment.T)

        denominator = trace_first * trace_first + eps
        return trace_second / denominator

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Return the parent residual dict extended with the ``"spec"`` entry.

        Args:
            state: Object exposing ``variables()`` for ``apply_fn``.
            batch: Collocation points, forwarded to the parent residuals and
                used as the inputs for ``spectral_reg``.

        Returns:
            Dict with the parent's keys plus ``"spec"``, a length-1 array
            holding the spectral term.
        """
        terms = super().residuals(state, batch)
        spec_value = self.spectral_reg(state, batch)
        # Wrapped in a length-1 array so it is an array like the other entries.
        terms["spec"] = jnp.array([spec_value])
        return terms

    