import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from typing import List, Optional
import matplotlib.pyplot as plt
from jax import config
from jax.experimental.sparse import BCOO
from ..utils import NDArray, rank, roll, effective_rank, remove_zero_sparse, mul
from jax import jacrev
config.update("jax_enable_x64", True)

class DE:
    """Base class for a differential equation posed on a rectangular domain.

    ``domain`` holds one ``[min, max]`` row per coordinate: row 0 is time and
    an optional row 1 is space.  Subclasses implement the right-hand side
    (``__call__``), ``__repr__`` and ``init_dist``.
    """

    def __init__(self, domain: NDArray, params: dict):
        self.domain = domain
        self.params = params

    def __call__(self, y: NDArray, t: float) -> NDArray:
        raise NotImplementedError("The __call__ method should be implemented by subclasses")

    def __repr__(self) -> str:
        raise NotImplementedError("The __repr__ method should be implemented by subclasses")

    def init_dist(self, numbers: int, key: int):
        raise NotImplementedError("The init_dist method should be implemented by subclasses")

    def ts(self, nt: int, retstep=False):
        """Return ``nt`` evenly spaced time points on ``[t_min, t_max]`` (both ends included).

        With ``retstep=True`` the pair ``(ts, dt)`` is returned, as in ``jnp.linspace``.
        """
        T = self.domain[0]
        return jnp.linspace(T[0], T[1], nt, retstep=retstep)

    def xs(self, nx: int, retstep=False):
        """Return a periodic grid of ``nx`` points on ``[x_min, x_max)``.

        The right endpoint is excluded because, under periodic boundary
        conditions, it coincides with the left one.  Returns ``None`` when the
        domain has no spatial row; with ``retstep=True`` returns ``(xs, dx)``.
        """
        if len(self.domain) == 1:
            return None
        x_min, x_max = self.domain[1][0], self.domain[1][1]
        # Build the grid on [0, width) and shift it by x_min.  For symmetric
        # domains [-L, L] this reproduces ``linspace(0, 2L) - L`` exactly,
        # and unlike that form it is also correct for asymmetric domains.
        width = x_max - x_min
        ret = jnp.linspace(0, width, nx, endpoint=False, retstep=retstep)
        if retstep:
            grid, dx = ret
            return grid + x_min, dx
        return ret + x_min
    
class LinearDE(DE):
    """Linear differential equation discretised as an operator ``D`` on basis weights.

    ``get_matrix`` (implemented by subclasses) returns ``D`` and the test
    function Gram matrix ``G``; ``M = D.T @ G @ D`` is the quadratic form of
    the physics-informed loss.  All three are built lazily on first use and
    cached on the instance.
    """

    def __init__(self, domain: NDArray, params: dict):
        super().__init__(domain, params)
        self.D = None
        self.G = None
        self.M = None

    def get_matrix(self, basis_function, test_function, sparse):
        raise NotImplementedError("The get_matrix method should be implemented by subclasses")

    def set_matrix(self, D, G):
        """Store ``D`` and ``G``, compute ``M = D.T @ G @ D``, cache it and return it.

        When ``G`` is a BCOO sparse matrix the product goes through the sparse
        ``mul`` helper and explicit zeros are pruned from the result.
        """
        self.D = D
        self.G = G
        if isinstance(G, BCOO):
            M = remove_zero_sparse(mul(mul(D.T, G), D))
        else:
            M = D.T @ G @ D
        self.M = M
        return M

    def _ensure_matrix(self, basis_function, test_function, sparse):
        """Build and cache ``D``, ``G`` and ``M`` if they have not been computed yet."""
        # NOTE(review): the cache is not keyed on ``sparse``; a later call with
        # a different flag silently reuses the first result.
        if self.M is None:
            D, G = self.get_matrix(basis_function, test_function, sparse=sparse)
            self.set_matrix(D, G)

    def pi_loss(self, basis_function, test_function, weights, sparse=False):
        """Return the physics-informed loss ``w^T M w`` for the weight vector ``weights``."""
        self._ensure_matrix(basis_function, test_function, sparse)
        return jnp.dot(weights, self.M @ weights).squeeze()

    def affine_variety_dimension(self, basis_function, test_function, weights=None, sparse=False):
        """Return ``dim ker D`` (number of columns minus ``rank(D)``); ``weights`` is unused."""
        self._ensure_matrix(basis_function, test_function, sparse)
        return self.D.shape[-1] - rank(self.D)

    def effective_affine_variety_dimension(self, basis_function, test_function, weights=None, sparse=False):
        """Like ``affine_variety_dimension`` but using ``effective_rank(D)``; ``weights`` is unused."""
        self._ensure_matrix(basis_function, test_function, sparse)
        return self.D.shape[-1] - effective_rank(self.D)
    
class NonlinearDE(DE):
    """Nonlinear differential equation whose discretisation is a polynomial system in the weights.

    Subclasses implement ``get_matrix`` to return the Jacobian of the
    discrete residual with respect to the weights, batched over a leading
    weight axis, together with the test-function Gram matrix.
    """

    def __init__(self, ts: NDArray, params: dict):
        # ``ts`` is the domain array forwarded to ``DE``; the name is kept for
        # backward compatibility with keyword callers.
        super().__init__(ts, params)

    def get_matrix(self, basis_function, test_function, weights, sparse=False):
        # Jacobian of the polynomial system with respect to the weights.
        raise NotImplementedError("The get_matrix method should be implemented by subclasses")

    def pi_loss(self, basis_function, test_function, weights):
        raise NotImplementedError("The pi_loss method should be implemented by subclasses")

    def affine_variety_dimension(self, basis_function, test_function, weights, sparse=False, iterate=False):
        """Return the maximum, over the weight batch, of ``columns(J) - rank(J)``.

        ``J`` is the Jacobian returned by ``get_matrix``.  With ``iterate=True``
        each weight row is processed separately (lower peak memory) and a
        Python scalar is returned; otherwise the whole batch is processed at
        once and a JAX array is returned.  ``sparse`` is forwarded to
        ``get_matrix`` in both modes.
        """
        if iterate:
            dims = []
            for i in range(weights.shape[0]):
                M, _ = self.get_matrix(basis_function, test_function, weights[i:i+1], sparse)
                d_V = M.shape[-1] - rank(M)
                dims.append(d_V.tolist()[0])
            return max(dims)
        M, _ = self.get_matrix(basis_function, test_function, weights, sparse)  # M : (n, rows, d)
        return jnp.max(M.shape[-1] - rank(M), axis=0)

    def effective_affine_variety_dimension(self, basis_function, test_function, weights, sparse=False):
        """Like ``affine_variety_dimension`` (batched mode) but using ``effective_rank``."""
        M, _ = self.get_matrix(basis_function, test_function, weights, sparse)
        return jnp.max(M.shape[-1] - effective_rank(M), axis=0)
    
class HarmonicOscillator(LinearDE):
    """Undamped harmonic oscillator ``m x'' + k x = 0`` on ``[t_min, t_max]``.

    The state is ``y = (x, v)``.  Initial states are drawn from an independent
    Gaussian with per-component mean ``init_mean`` and standard deviation
    ``init_std``.
    """

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 2 * jnp.pi,
        k: float = 1.0,
        m: float = 1.0,
        init_mean: List[float] = [0.0, 0.0],
        init_std: List[float] = [1.0, 1.0],
    ):
        # Copy the list arguments so the shared default lists are never
        # aliased into (and mutable through) ``self.params``.
        params = dict(m=m, k=k, init_mean=list(init_mean), init_std=list(init_std))
        domain = np.array([[t_min, t_max]])
        super().__init__(domain, params)
        self.m, self.k = m, k

    def __call__(self, y: NDArray, t: float) -> NDArray:
        """Return ``dy/dt = (v, -(k/m) x)`` for the state ``y = (x, v)``."""
        x, v = y
        return jnp.array([v, -(self.k / self.m) * x])

    def __repr__(self) -> str:
        return f"Oscillator-k:{self.k}-m:{self.m}"

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: Optional[NDArray] = None):
        """Draw ``numbers`` initial states ``(x0, v0)``; returns shape ``(numbers, 2)``.

        ``xs`` is unused and accepted only for interface compatibility.
        """
        mean = jnp.asarray(self.params['init_mean']) * jnp.ones(2)
        std = jnp.asarray(self.params['init_std']) * jnp.ones(2)
        # ``multivariate_normal`` expects a covariance, so the standard
        # deviations must be squared (passing ``diag(std)`` gave a sample
        # standard deviation of ``sqrt(init_std)``).
        cov = jnp.diag(std ** 2)
        return jax.random.multivariate_normal(key, mean, cov, shape=(numbers,))

    def simulate(self, y0: NDArray, ts: NDArray, xs: NDArray):
        """Return the exact position ``x(t)`` at times ``ts`` for the initial state ``y0 = (x0, v0)``."""
        x0, v0 = y0
        omega = jnp.sqrt(self.k / self.m)
        theta = x0 * jnp.cos(omega * ts) + v0 / omega * jnp.sin(omega * ts)
        print(f"ground truth: {x0} * cos({omega}t) + {v0/omega} * sin({omega}t)")
        return theta

    def get_matrix(self, basis_function, test_function, sparse=False):
        """Return the discretised operator ``D = <phi'', psi> + (k/m) <phi, psi>`` and the Gram matrix ``<psi, psi>``."""
        phipsi = basis_function.phipsi(test_function)
        ddphipsi = basis_function.ddphipsi(test_function)
        D = ddphipsi + (self.k / self.m) * phipsi
        psipsi = test_function.phipsi(test_function)
        return D, psipsi

    def visualize(self, X, y, save_path, n_samples=5):
        """Plot up to ``n_samples`` trajectories ``y[i]`` against ``X`` and save to ``save_path``.

        The figure is closed after saving so that repeated calls do not
        accumulate open matplotlib figures.
        """
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            for i in range(min(n_samples, y.shape[0])):
                ax.plot(X, y[i])
            ax.set_xlabel("t")
            ax.set_ylabel("y(t)")
            ax.set_title("Harmonic Oscillator")
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
    
class PeriodicHeatEquation(LinearDE):
    """Heat equation ``u_t = K u_xx`` on ``[t_min, t_max] x [-L, L]`` with periodic boundaries.

    Initial conditions are truncated Fourier series whose coefficients come
    from ``init_dist``; ``simulate`` evolves them with the exact solution.
    """

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 1.0,
        L: float = 1,
        K: float = 0.1,
        init_mean: List[float] = [0.0, 0.0],
        init_std: List[float] = [1.0, 1.0],
    ):
        domain = np.array([[t_min, t_max], [-L, L]])
        params = dict(t_min=t_min, t_max=t_max, L=L, K=K)
        super().__init__(domain, params)

        self.L = L
        self.T = t_max - t_min
        self.K = K
        # Copy so the shared default lists are never aliased into the instance.
        self.init_mean = list(init_mean)
        self.init_std = list(init_std)

    def __repr__(self) -> str:
        return f"PeriodicHeatEquation-K:{self.K}"

    def __call__(self, u: NDArray, t: float) -> NDArray:
        return self.velocity(u, t)

    def velocity(self, u_t: NDArray, t: float) -> NDArray:
        # NOTE(review): this is K times a first-order, unit-spacing gradient
        # along axis 0, not K * u_xx; confirm what callers expect before use.
        return self.K * jnp.gradient(u_t, axis=0)

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: int, max_k: int = 1):
        """Draw Fourier coefficients for ``numbers`` initial conditions.

        Returns shape ``(numbers, 2 * max_k + 1)``.  The first ``max_k + 1``
        entries are the cosine coefficients ``A_0..A_max_k``; the remaining
        ``max_k`` entries are the sine coefficients ``B_1..B_max_k``.
        ``xs`` is unused.
        """
        key1, key2 = jax.random.split(key)
        A_n = jax.random.normal(key1, shape=(numbers, max_k + 1)) * self.init_std[0] + self.init_mean[0]
        B_n = jax.random.normal(key2, shape=(numbers, max_k)) * self.init_std[1] + self.init_mean[1]
        return jnp.concatenate([A_n, B_n], axis=-1)

    def simulate(self, freq: NDArray, ts: NDArray, xs: NDArray):
        """Evaluate the exact solution for one coefficient vector ``freq`` on the ``ts x xs`` grid.

        Mode ``k`` (``cos``/``sin`` of ``k pi x / L``) decays as
        ``exp(-K (k pi / L)^2 t)``.  Returns shape ``(nt, nx)``.
        """
        max_k = freq.shape[-1] // 2
        A_n, B_n = freq[:max_k + 1], freq[max_k + 1:]

        k = jnp.arange(1, max_k + 1)
        phase = k[None, :] * jnp.pi * xs[:, None] / self.L  # (nx, max_k)
        # The eigenvalue of d^2/dx^2 for mode k on a period-2L domain is
        # -(k pi / L)^2.  The previous code divided by the time span T
        # instead of L, which was only correct when T == L.
        decay = jnp.exp(-self.K * (jnp.pi * k[None, :] / self.L) ** 2 * ts[:, None])  # (nt, max_k)

        modes = A_n[None, 1:] * jnp.cos(phase) + B_n[None, :] * jnp.sin(phase)  # (nx, max_k)
        return A_n[0] + jnp.sum(modes[None] * decay[:, None], axis=-1)

    def get_matrix(self, basis_function, test_function, sparse=False):
        """Return ``D = <d_t phi, psi> - K <d_xx phi, psi>`` and the Gram matrix, both flattened to 2-D."""
        dphipsi = basis_function.dphipsi(test_function, keepshape=False)
        ddphipsi = basis_function.ddphipsi(test_function, keepshape=False)

        D = dphipsi[..., 0] - self.K * ddphipsi[..., 1, 1]
        psipsi = test_function.phipsi(test_function)
        nt, nx, _, _ = psipsi.shape
        D = D.reshape(nt * nx, -1)
        psipsi = psipsi.reshape(nt * nx, -1)
        return D, psipsi

    def visualize(self, X, ut, save_path, n_samples=5):
        """Save heat maps of up to ``n_samples`` solutions ``ut[i]`` (time on y, space on x).

        ``X`` is unused.  The figure is closed after saving.
        """
        n = min(n_samples, ut.shape[0])
        extent = [self.domain[1][0], self.domain[1][1], self.domain[0][0], self.domain[0][1]]
        if n == 1:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        else:
            fig, axes = plt.subplots(1, n, figsize=(8 * n, 6))
        try:
            if n == 1:
                ax.imshow(ut[0], extent=extent, origin='lower', aspect='auto', cmap='hot')
                ax.set_title("Periodic Heat Equation")
            else:
                for i in range(n):
                    axes[i].imshow(ut[i], extent=extent, origin='lower', aspect='auto', cmap='hot')
                    axes[i].set_title(f"Periodic Heat Equation - {i}")
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
    
class FDMPeriodicHeatEquation(LinearDE):
    """Finite-difference discretisation of the periodic heat equation ``u_t = K u_xx``.

    Space is ``[-L, L]`` with periodic wrap-around (rolls along the last
    axis); ``simulate`` time-steps with explicit Euler.
    """

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 1.0,
        L: float = 1,
        K: float = 0.1,
        init_mean: List[float] = [0.0, 0.0],
        init_std: List[float] = [1.0, 1.0],
    ):
        domain = np.array([[t_min, t_max], [-L, L]])
        params = dict(t_min=t_min, t_max=t_max, L=L, K=K)
        super().__init__(domain, params)

        self.L = L
        self.K = K
        # Copy so the shared default lists are never aliased into the instance.
        self.init_mean = list(init_mean)
        self.init_std = list(init_std)
        # dt / dx are not known here; callers should check ``is_stable(dt, dx)``.

    def __repr__(self) -> str:
        return f"FDMPeriodicHeatEquation-K:{self.K}"

    def __call__(self, u: NDArray, t: float, dx: float) -> NDArray:
        """Return the periodic second difference of ``u`` along its last axis.

        NOTE(review): unlike ``simulate``, this is not multiplied by ``K``;
        confirm whether callers expect ``K * u_xx``.
        """
        return self.ddu_dx(u, t, dx)

    def is_stable(self, dt, dx):
        """Return whether explicit Euler satisfies ``K dt / dx^2 <= 1/2``."""
        alpha = self.K * dt / (dx ** 2)
        return alpha <= 0.5

    def dphi_dt(self, phi_t: NDArray, t: float, dt: float) -> NDArray:
        """Apply the discrete time-difference operator along the basis axis (second-to-last).

        Slice ``j`` becomes ``(phi[j-1] - phi[j]) / dt``, with the slice that
        would wrap around into position 0 zeroed (no periodicity in time).
        This assumes ``roll`` follows ``jnp.roll`` semantics. Dense and
        BCOO inputs are both supported.
        """
        phi_prev = roll(phi_t, shift=1, axis=-2)
        if isinstance(phi_t, BCOO):
            # Drop the entries that wrapped around into slice 0.
            keep = phi_prev.indices[:, -2] != 0
            phi_prev = BCOO((phi_prev.data[keep], phi_prev.indices[keep]), shape=phi_prev.shape)
        else:
            phi_prev = phi_prev.at[..., 0, :].set(jnp.zeros_like(phi_t[..., 0, :]))
        return (phi_prev - phi_t) / dt

    def ddphi_dx(self, phi_t: NDArray, t: float, dx: float) -> NDArray:
        return self.ddu_dx(phi_t, t, dx)

    def du_dx(self, u_t: NDArray, t: float, dx: float) -> NDArray:
        """Periodic central difference ``(u[i+1] - u[i-1]) / (2 dx)`` along the last axis."""
        u_t_minus_1 = roll(u_t, shift=1, axis=-1)
        u_t_plus_1 = roll(u_t, shift=-1, axis=-1)
        return (u_t_plus_1 - u_t_minus_1) / (2 * dx)

    def ddu_dx(self, u_t: NDArray, t: float, dx: float) -> NDArray:
        """Periodic second difference ``(u[i+1] - 2u[i] + u[i-1]) / dx^2`` along the last axis."""
        u_t_minus_1 = roll(u_t, shift=1, axis=-1)
        u_t_plus_1 = roll(u_t, shift=-1, axis=-1)
        return (u_t_plus_1 - 2 * u_t + u_t_minus_1) / (dx ** 2)

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: int, max_k: int = 5):
        """Sample ``numbers`` random truncated Fourier series evaluated on ``xs``; returns ``(numbers, nx)``."""
        key1, key2 = jax.random.split(key)
        A_n = jax.random.normal(key1, shape=(numbers, max_k + 1)) * self.init_std[0] + self.init_mean[0]
        B_n = jax.random.normal(key2, shape=(numbers, max_k)) * self.init_std[1] + self.init_mean[1]
        k = jnp.arange(1, max_k + 1)
        sin_term = jnp.sin(k[None, :] * jnp.pi * xs[:, None] / self.L)  # (nx, max_k)
        cos_term = jnp.cos(k[None, :] * jnp.pi * xs[:, None] / self.L)  # (nx, max_k)
        u_0 = A_n[:, None, 0] + jnp.sum(A_n[:, None, 1:] * cos_term[None] + B_n[:, None] * sin_term[None], axis=-1)  # (n, nx)
        return u_0

    def simulate(self, u0: NDArray, ts: NDArray, xs: NDArray):
        """Integrate from ``u0`` with explicit Euler on the uniform grids ``ts`` / ``xs``; returns ``(nt, nx)``."""
        nt = len(ts)
        dt, dx = ts[1] - ts[0], xs[1] - xs[0]
        u = jnp.zeros((nt, len(xs)))
        u = u.at[0, :].set(u0)

        def step(u_t, _):
            ddu = self.ddu_dx(u_t, 0., dx)
            u_next = u_t + self.K * ddu * dt
            return u_next, u_next

        us = jax.lax.scan(step, u0, None, length=nt - 1)[1]
        u = u.at[1:].set(us)
        return u

    def get_matrix(self, basis_function, test_function, sparse=False):
        """Return the flattened FDM operator ``D = d_t - K d_xx`` and the test Gram matrix.

        The last time slice of the test axis is dropped, because the forward
        time difference is not defined there.
        """
        dt, dx = basis_function.h[0], basis_function.h[1]
        phipsi = basis_function.phipsi(test_function, sparse=sparse)
        phipsi = phipsi[:-1, :]
        lhs = self.dphi_dt(phipsi, 0., dt)
        rhs = self.K * self.ddphi_dx(phipsi, 0., dx)
        D = remove_zero_sparse(lhs - rhs)
        psipsi = test_function.phipsi(test_function, sparse=sparse)[:-1, :, :-1, :]

        nt, nx, _, _ = D.shape
        D = D.reshape(nt * nx, -1)
        psipsi = remove_zero_sparse(psipsi.reshape(nt * nx, -1))
        return D, psipsi

    def visualize(self, X, ut, save_path, n_samples=5):
        """Save heat maps of up to ``n_samples`` solutions ``ut[i]``; ``X`` is unused. The figure is closed afterwards."""
        n = min(n_samples, ut.shape[0])
        extent = [self.domain[1][0], self.domain[1][1], self.domain[0][0], self.domain[0][1]]
        if n == 1:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        else:
            fig, axes = plt.subplots(1, n, figsize=(8 * n, 6))
        try:
            if n == 1:
                ax.imshow(ut[0], extent=extent, origin='lower', aspect='auto', cmap='hot')
                ax.set_title("FDM Periodic Heat Equation")
            else:
                for i in range(n):
                    axes[i].imshow(ut[i], extent=extent, origin='lower', aspect='auto', cmap='hot')
                    axes[i].set_title(f"FDM Periodic Heat Equation - {i}")
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
        
        
class FDMPeriodicNonlHeatEquation(NonlinearDE):
    """Finite-difference nonlinear heat equation ``u_t = (K(u) u_x)_x`` on ``[-L, L]`` with periodic boundaries.

    ``K(u) = 0.1 / (1 + u^2)``.  The flux uses the geometric mean of ``K`` at
    neighbouring nodes as the half-point conductivity; ``simulate``
    time-steps with explicit Euler.
    """

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 1.0,
        L: float = 1,
        init_mean: List[float] = [0.0, 0.0],
        init_std: List[float] = [1.0, 1.0],
    ):
        domain = np.array([[t_min, t_max], [-L, L]])
        params = dict(t_min=t_min, t_max=t_max, L=L)
        super().__init__(domain, params)

        self.L = L
        # Copy so the shared default lists are never aliased into the instance.
        self.init_mean = list(init_mean)
        self.init_std = list(init_std)

    def __repr__(self) -> str:
        return "FDMPeriodicNonlHeatEquation"

    def __call__(self, u: NDArray, t: float, dx: float, axis_x: int = -1) -> NDArray:
        """Return the semi-discrete right-hand side ``du/dt`` (see ``velocity``).

        ``dx`` is required because ``velocity`` needs the grid spacing; the
        former two-argument form always raised ``TypeError``.
        """
        return self.velocity(u, t, dx, axis_x)

    def is_stable(self, dt, dx, u_t: Optional[NDArray] = None):
        """Return whether explicit Euler satisfies ``max K * dt / dx^2 <= 1/2``.

        ``K`` is a method, not a constant, so the bound uses its maximum over
        ``u_t`` when given; otherwise it uses ``K(0)``, which is the global
        maximum of ``0.1 / (1 + u^2)``.  (Previously ``self.K * dt`` multiplied
        a bound method and raised ``TypeError``.)
        """
        if u_t is None:
            k_max = self.K(0.0, 0., dx)
        else:
            k_max = jnp.max(self.K(u_t, 0., dx))
        alpha = k_max * dt / (dx ** 2)
        return alpha <= 0.5

    def du_dt(self, u_t: NDArray, t: float, dt: float) -> NDArray:
        """Forward time difference ``(u[j+1] - u[j]) / dt`` along axis -2.

        The last slice has no successor and uses 0 in its place.
        """
        u_next = roll(u_t, shift=-1, axis=-2)
        if isinstance(u_t, BCOO):
            keep = u_next.indices[:, -2] != u_t.shape[-2] - 1
            u_next = BCOO((u_next.data[keep], u_next.indices[keep]), shape=u_next.shape)
        else:
            u_next = u_next.at[..., -1, :].set(jnp.zeros_like(u_t[..., -1, :]))
        return (u_next - u_t) / dt

    def du_dx(self, u_t: NDArray, t: float, dx: float) -> NDArray:
        """Periodic central difference ``(u[i+1] - u[i-1]) / (2 dx)`` along the last axis."""
        u_t_minus_1 = roll(u_t, shift=1, axis=-1)
        u_t_plus_1 = roll(u_t, shift=-1, axis=-1)
        # Sign fixed: the previous code returned u[i-1] - u[i+1], i.e. -u_x.
        return (u_t_plus_1 - u_t_minus_1) / (2 * dx)

    def ddu_dx(self, u_t: NDArray, t: float, dx: float) -> NDArray:
        """Periodic second difference ``(u[i-1] - 2u[i] + u[i+1]) / dx^2`` along the last axis."""
        u_t_minus_1 = roll(u_t, shift=1, axis=-1)
        u_t_plus_1 = roll(u_t, shift=-1, axis=-1)
        return (u_t_minus_1 - 2 * u_t + u_t_plus_1) / (dx ** 2)

    def K(self, u_t: NDArray, t: float, dx: float) -> NDArray:
        """Conductivity ``K(u) = 0.1 / (1 + u^2)`` evaluated elementwise."""
        return 0.1 / (1 + (u_t ** 2))

    def dK_half_dw(self, u_t: NDArray, phipsi: NDArray, t: float, dx: float, n_roll: int, axis_x: int) -> NDArray:
        """Derivative of ``K_half = sqrt(K(u) K(v))``, where ``v = roll(u, n_roll)``, through ``phipsi``.

        Uses ``d K_half / du = K'(u) K(v) / (2 K_half)``, and symmetrically for
        ``v``, with ``K'`` obtained by autodiff.  The previous formula
        substituted ``K(u)`` where ``u`` belongs.
        NOTE(review): the broadcasting of the coefficients (shape of ``u_t``)
        against ``phipsi`` is unchanged; confirm it matches the caller's layout.
        """
        u_t_roll = jnp.roll(u_t, n_roll, axis=axis_x)
        K_ut = self.K(u_t, t, dx)
        K_ut_roll = self.K(u_t_roll, t, dx)
        K_half = jnp.sqrt(K_ut * K_ut_roll)
        dK = jnp.vectorize(jax.grad(lambda x: self.K(x, t, dx)))

        coef_ut = dK(u_t) * K_ut_roll / (2 * K_half)
        coef_ut_roll = dK(u_t_roll) * K_ut / (2 * K_half)
        return coef_ut * phipsi + coef_ut_roll * jnp.roll(phipsi, n_roll, axis=axis_x)

    def velocity(self, u_t: NDArray, t: float, dx: float, axis_x: int) -> NDArray:
        """Conservative flux difference ``[K_{i+1/2}(u_{i+1}-u_i) - K_{i-1/2}(u_i-u_{i-1})] / dx^2``.

        Half-point conductivities are geometric means of ``K`` at the
        neighbouring nodes; neighbours wrap periodically along ``axis_x``.
        """
        u_t_xL = jnp.roll(u_t, 1, axis=axis_x)
        u_t_xR = jnp.roll(u_t, -1, axis=axis_x)
        K_ut = self.K(u_t, t, dx)
        K_halfR = jnp.sqrt(K_ut * self.K(u_t_xR, t, dx))
        K_halfL = jnp.sqrt(K_ut * self.K(u_t_xL, t, dx))
        return K_halfR * (u_t_xR - u_t) / (dx ** 2) - K_halfL * (u_t - u_t_xL) / (dx ** 2)

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: int, max_k: int = 1):
        """Sample ``numbers`` random truncated Fourier series evaluated on ``xs``; returns ``(numbers, nx)``."""
        key1, key2 = jax.random.split(key)
        A_n = jax.random.normal(key1, shape=(numbers, max_k + 1)) * self.init_std[0] + self.init_mean[0]
        B_n = jax.random.normal(key2, shape=(numbers, max_k)) * self.init_std[1] + self.init_mean[1]
        k = jnp.arange(1, max_k + 1)
        sin_term = jnp.sin(k[None, :] * jnp.pi * xs[:, None] / self.L)  # (nx, max_k)
        cos_term = jnp.cos(k[None, :] * jnp.pi * xs[:, None] / self.L)  # (nx, max_k)
        u_0 = A_n[:, None, 0] + jnp.sum(A_n[:, None, 1:] * cos_term[None] + B_n[:, None] * sin_term[None], axis=-1)  # (n, nx)
        return u_0

    def simulate(self, u0: NDArray, ts: NDArray, xs: NDArray):
        """Integrate from ``u0`` with explicit Euler on the uniform grids ``ts`` / ``xs``; returns ``(nt, nx)``."""
        nt = len(ts)
        dt, dx = ts[1] - ts[0], xs[1] - xs[0]
        u = jnp.zeros((nt, len(xs)))
        u = u.at[0, :].set(u0)

        def step(u_t, _):
            u_next = u_t + self.velocity(u_t, 0., dx, axis_x=-1) * dt
            return u_next, u_next

        us = jax.lax.scan(step, u0, None, length=nt - 1)[1]
        u = u.at[1:].set(us)
        return u

    def pi_loss(self, basis_function, test_function, weights):
        """RMS residual of the explicit scheme, treating ``weights`` as nodal values of shape ``(nt, nx)``."""
        dt, dx = basis_function.h[0], basis_function.h[1]
        nt, nx = basis_function.dim_out[0], basis_function.dim_out[1]
        u = weights.reshape(nt, nx)
        lhs = self.du_dt(u, 0., dt)[:-1]
        rhs = self.velocity(u[:-1], 0, dx, axis_x=-1)
        return jnp.sqrt(jnp.mean((lhs - rhs) ** 2))

    def get_matrix(self, basis_function, test_function, weight, sparse=False):
        """Return the batched Jacobian of the discrete residual and the test Gram matrix.

        ``weight`` has a leading batch axis ``n``; each row must hold
        ``nt * nx`` values (any trailing layout).  The Jacobian is flattened to
        ``(n, (_nt - 1) * _nx, -1)``.
        """
        dt, dx = basis_function.h[0], basis_function.h[1]
        phipsi = basis_function.phipsi(test_function, sparse=sparse)
        _nt, _nx, nt, nx = phipsi.shape

        def residual(beta):
            u = phipsi.reshape(_nt, _nx, nt * nx) @ beta.reshape(nt * nx)
            lhs = (u[1:, :] - u[:-1, :]) / dt
            rhs = self.velocity(u[:-1, :], 0., dx, axis_x=-1)
            return lhs - rhs

        jac = jax.vmap(jax.jit(jacrev(residual)))
        D = jac(weight)  # (n, _nt - 1, _nx, *weight.shape[1:])
        # Only the leading three axes are fixed; the trailing ones follow the
        # layout of ``weight`` (previously a 5-D shape was required).
        n, n_t, n_x = D.shape[:3]
        D = D.reshape(n, n_t * n_x, -1)
        psipsi = test_function.phipsi(test_function, sparse=sparse)[:-1, :, :-1, :]
        psipsi = remove_zero_sparse(psipsi.reshape(n_t * n_x, -1))
        return D, psipsi

    def visualize(self, X, ut, save_path, n_samples=5):
        """Save heat maps (with colorbars when n > 1) of up to ``n_samples`` solutions; ``X`` is unused."""
        n = min(n_samples, ut.shape[0])
        extent = [self.domain[1][0], self.domain[1][1], self.domain[0][0], self.domain[0][1]]
        if n == 1:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        else:
            fig, axes = plt.subplots(1, n, figsize=(8 * n, 6))
        try:
            if n == 1:
                ax.imshow(ut[0], extent=extent, origin='lower', aspect='auto', cmap='hot')
                ax.set_title("FDM Periodic Nonl Heat Equation")
            else:
                for i in range(n):
                    pos = axes[i].imshow(ut[i], extent=extent, origin='lower', aspect='auto', cmap='hot')
                    axes[i].set_title(f"FDM Periodic Nonl Heat Equation - {i}")
                    fig.colorbar(pos, ax=axes[i])
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)


class LinearBernoulli(LinearDE):
    """Linear decay ODE ``y' = -P y`` on ``[t_min, t_max]``, discretised with finite differences."""

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 1.0,
        P: float = 1.0,
        init_mean: List[float] = [0.0],
        init_std: List[float] = [1.0],
    ):
        # Copy the list arguments so the shared default lists are never
        # aliased into (and mutable through) ``self.params``.
        params = dict(P=P, init_mean=list(init_mean), init_std=list(init_std))
        domain = np.array([[t_min, t_max]])
        super().__init__(domain, params)
        self.P = P

    def __call__(self, y: NDArray, t: float) -> NDArray:
        return -self.P * y

    def __repr__(self) -> str:
        return f"Bernoulli-P:{self.P}"

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: Optional[NDArray] = None):
        """Draw ``numbers`` scalar initial values ``y0 ~ N(init_mean, init_std^2)``; ``xs`` is unused."""
        mean = jnp.asarray(self.params['init_mean'])
        std = jnp.asarray(self.params['init_std'])
        return jax.random.normal(key, shape=(numbers,)) * std + mean

    def simulate(self, y0: NDArray, ts: NDArray, xs: NDArray):
        """Integrate from ``y0`` with explicit Euler on the uniform grid ``ts``; returns shape ``(nt, 1)``."""
        nt = len(ts)
        dt = ts[1] - ts[0]
        y = jnp.zeros((nt, 1))
        y = y.at[0].set(y0)

        def step(y_t, _):
            y_next = y_t + self(y_t, 0.) * dt
            return y_next, y_next

        ys = jax.lax.scan(step, y0, None, length=nt - 1)[1]
        y = y.at[1:].set(ys)
        return y

    def dphi_dt(self, phi_t: NDArray, t: float, dt: float) -> NDArray:
        """Discrete time-difference operator along the basis axis (last axis).

        Column ``j`` becomes ``(phi[:, j-1] - phi[:, j]) / dt``, and the column
        that would wrap around into position 0 is zeroed.  This assumes
        ``roll`` follows ``jnp.roll`` semantics.
        """
        # shift=1 moves column j-1 into slot j, i.e. the *previous* column.
        phi_prev = roll(phi_t, shift=1, axis=-1)
        if isinstance(phi_t, BCOO):
            keep = phi_prev.indices[:, -1] != 0
            phi_prev = BCOO((phi_prev.data[keep], phi_prev.indices[keep]), shape=phi_prev.shape)
        else:
            phi_prev = phi_prev.at[:, 0].set(jnp.zeros_like(phi_t[:, 0]))
        return (phi_prev - phi_t) / dt

    def get_matrix(self, basis_function, test_function, sparse=False):
        """Return the FDM operator ``D = d_t + P (masked)`` and the test Gram matrix.

        The ``sparse`` argument is deliberately overridden: the operator is
        always built densely, so the BCOO branch below is currently
        unreachable (kept so sparse can be re-enabled).
        """
        sparse = False
        dt = basis_function.h[0]
        phipsi = basis_function.phipsi(test_function, sparse=sparse)
        phipsi = phipsi[:-1, :]
        lhs = self.dphi_dt(phipsi, 0., dt)

        # The reaction term -P * phi excludes the last basis column.
        if isinstance(phipsi, BCOO):
            keep = phipsi.indices[:, 1] != (phipsi.shape[-1] - 1)
            rhs = -self.P * BCOO((phipsi.data[keep], phipsi.indices[keep]), shape=phipsi.shape)
        else:
            rhs = -self.P * phipsi.at[:, -1].set(jnp.zeros_like(phipsi[:, -1]))

        D = remove_zero_sparse(lhs - rhs)
        psipsi = test_function.phipsi(test_function, sparse=sparse)[:-1, :-1]
        nt, _ = D.shape
        D = D.reshape(nt, -1)
        psipsi = remove_zero_sparse(psipsi.reshape(nt, -1))
        return D, psipsi

    def visualize(self, X, y, save_path, n_samples=5):
        """Plot up to ``n_samples`` trajectories ``y[i]`` against ``X`` and save to ``save_path``; the figure is closed afterwards."""
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            for i in range(min(n_samples, y.shape[0])):
                ax.plot(X, y[i])
            ax.set_xlabel("t")
            ax.set_ylabel("y(t)")
            ax.set_title("Linear Bernoulli")
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
        
class NonlBernoulli(NonlinearDE):
    """Bernoulli ODE ``y' = -P y + Q y^2`` on ``[t_min, t_max]``, discretised with finite differences."""

    def __init__(
        self,
        t_min: float = 0.,
        t_max: float = 1.0,
        P: float = 1.0,
        Q: float = 0.1,
        init_mean: List[float] = [0.0],
        init_std: List[float] = [1.0],
    ):
        # Copy the list arguments so the shared default lists are never
        # aliased into (and mutable through) ``self.params``.
        params = dict(P=P, Q=Q, init_mean=list(init_mean), init_std=list(init_std))
        domain = np.array([[t_min, t_max]])
        super().__init__(domain, params)
        self.P = P
        self.Q = Q

    def __call__(self, y: NDArray, t: float) -> NDArray:
        return -self.P * y + self.Q * (y ** 2)

    def __repr__(self) -> str:
        return f"NonlinearBernoulli-P:{self.P}-Q:{self.Q}"

    def init_dist(self, numbers: int, key: jax.random.PRNGKey, xs: Optional[NDArray] = None):
        """Draw ``numbers`` scalar initial values ``y0 ~ N(init_mean, init_std^2)``; ``xs`` is unused."""
        mean = jnp.asarray(self.params['init_mean'])
        std = jnp.asarray(self.params['init_std'])
        return jax.random.normal(key, shape=(numbers,)) * std + mean

    def simulate(self, y0: NDArray, ts: NDArray, xs: NDArray):
        """Integrate from ``y0`` with explicit Euler on the uniform grid ``ts``; returns shape ``(nt, 1)``."""
        nt = len(ts)
        dt = ts[1] - ts[0]
        y = jnp.zeros((nt, 1))
        y = y.at[0].set(y0)

        def step(y_t, _):
            y_next = y_t + self(y_t, 0.) * dt
            return y_next, y_next

        ys = jax.lax.scan(step, y0, None, length=nt - 1)[1]
        y = y.at[1:].set(ys)
        return y

    def du_dt(self, y_t: NDArray, t: float, dt: float) -> NDArray:
        """Forward difference ``(y[j+1] - y[j]) / dt`` along the last axis; the last entry uses 0 as its successor."""
        y_next = roll(y_t, shift=-1, axis=-1)
        if isinstance(y_t, BCOO):
            keep = y_next.indices[:, -1] != y_t.shape[-1] - 1
            y_next = BCOO((y_next.data[keep], y_next.indices[keep]), shape=y_next.shape)
        else:
            y_next = y_next.at[..., -1].set(jnp.zeros_like(y_t[..., -1]))
        return (y_next - y_t) / dt

    def pi_loss(self, basis_function, test_function, weights):
        """RMS residual of the explicit scheme, treating ``weights`` as nodal values of length ``nt``."""
        dt = basis_function.h[0]
        nt = basis_function.dim_out[0]
        u = weights.reshape(nt)
        lhs = self.du_dt(u, 0., dt)[:-1]
        rhs = -self.P * u[:-1] + self.Q * (u[:-1] ** 2)
        return jnp.sqrt(jnp.mean((lhs - rhs) ** 2))

    def get_matrix(self, basis_function, test_function, weight, sparse=False):
        """Return the batched Jacobian of the discrete residual and the test Gram matrix.

        ``weight`` has a leading batch axis ``n``.  The Jacobian has shape
        ``(n, _nt - 1, *trailing)``, where the trailing axes depend on the
        layout of ``weight``.  It is flattened to ``(n, _nt - 1, -1)``.
        """
        dt = basis_function.h[0]
        phipsi = basis_function.phipsi(test_function, sparse=sparse)
        _nt, nt = phipsi.shape

        def residual(beta):
            u = phipsi.reshape(_nt, nt) @ beta
            lhs = (u[1:] - u[:-1]) / dt
            rhs = -self.P * u[:-1] + self.Q * (u[:-1] ** 2)
            return lhs - rhs

        jac = jax.vmap(jax.jit(jacrev(residual)))
        D = jac(weight)
        # For 1-D weight rows the Jacobian is 3-D; the old 5-way unpack
        # raised ValueError there.  Only the leading two axes are needed.
        n, n_rows = D.shape[0], D.shape[1]
        D = D.reshape(n, n_rows, -1)
        psipsi = test_function.phipsi(test_function, sparse=sparse)[:-1, :-1]
        psipsi = remove_zero_sparse(psipsi.reshape(n_rows, -1))
        return D, psipsi

    def visualize(self, X, y, save_path, n_samples=5):
        """Plot up to ``n_samples`` trajectories ``y[i]`` against ``X`` and save to ``save_path``; the figure is closed afterwards."""
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            for i in range(min(n_samples, y.shape[0])):
                ax.plot(X, y[i])
            ax.set_xlabel("t")
            ax.set_ylabel("y(t)")
            ax.set_title("Nonl Linear Bernoulli")
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
