import jax.numpy as jnp
from pi_lr.basis.delta_basis import DeltaBasis
from jax.scipy.integrate import trapezoid
from typing import List

class FourierExpBasis:
    """Separable space-time basis: Fourier modes in space times decaying exponentials in time.

    Inputs are points ``x = [t, x_1, ..., x_d]`` with ``d = dim_in - 1``.
    Besides a constant function, the basis contains, for every temporal index
    ``k`` in ``1..d_t`` and every integer wave vector ``w`` in ``{1..d_x}^d``:

        cos(pi / L * <w, x_space>) * exp(-alpha * (pi * k / T)**2 * t)
        sin(pi / L * <w, x_space>) * exp(-alpha * (pi * k / T)**2 * t)

    where ``L`` and ``T`` are the extents of ``domain[1]`` and ``domain[0]``.

    Flattened layout (``keepshape=False``): index 0 is the constant function;
    the remaining entries are ordered k-major, and within each ``k`` all cosine
    modes come first, followed by all sine modes.
    """

    def __init__(self,
        dim_in: int,
        dim_out: List[int],
        domain: List[List[float]],
        alpha: float,
    ):
        """Build the wave-vector and decay-rate tables.

        dim_in: int : dimension of input (1 time axis + dim_in-1 space axes)
        dim_out: List[int] : dim_out[0] is the number of temporal modes (d_t),
            dim_out[1] the number of frequencies per spatial axis (d_x).
            Entries past index 1 are not read; every spatial axis uses dim_out[1].
        domain: domain[i] is the interval of the i-th input axis, ordered
            [t, x_1, ..., x_d]; only its first and last entries are read.
        alpha: float : multiplier of (pi * k / T)**2 in the exponential decay rate.

        Raises ValueError if dim_in < 2 and AssertionError if the spatial
        intervals are not all equal.
        """
        if dim_in < 2:
            # With no spatial axis there is nothing to build wave vectors from.
            raise ValueError("dim_in must be at least 2 (one time axis plus at least one space axis)")

        n_space = dim_in - 1
        d_x = dim_out[1]
        base_w = jnp.arange(1, d_x + 1)
        # All integer wave vectors in {1..d_x}^n_space, one per row:
        # shape (d_x**n_space, n_space).
        grids = jnp.meshgrid(*[base_w] * n_space, indexing='ij')
        wave_vectors = jnp.concatenate(
            [g[..., None] for g in grids], axis=-1
        ).reshape(len(base_w) ** n_space, n_space)
        # Cosine and sine modes use the same wave vectors; JAX arrays are
        # immutable, so sharing one array between both attributes is safe.
        self.W_cos = wave_vectors
        self.W_sin = wave_vectors

        d_t = dim_out[0]
        base_k = jnp.arange(1, d_t + 1)
        self.k = base_k[:, None]  # (d_t, 1)

        self.dim_in = dim_in
        self.dim_out = dim_out
        # domain : [t, x_1, x_2, ..., x_d]
        self.domain = domain
        self.L = domain[1][-1] - domain[1][0]  # spatial extent (first space axis)
        self.T = domain[0][-1] - domain[0][0]  # temporal extent

        # A single L is used for every spatial axis, so they must all agree.
        for i in range(1, dim_in - 1):
            assert domain[i] == domain[i+1], "Only square domain is supported"

        self.alpha = alpha
        # NOTE(review): the original carried the remark
        # "theory: 1 + 2 * min(d_t, d_x)"; the `dim` property below returns
        # 1 + 2 * d_x**(dim_in-1) * d_t instead -- confirm which is intended.

    @property
    def dim(self):
        """Total number of basis functions: 1 constant + (cos + sin modes) * temporal modes."""
        return 1 + (len(self.W_cos) + len(self.W_sin)) * len(self.k)

    def __call__(self, x, sparse=False):
        """Evaluate the flattened basis at ``x``; returns shape (n, self.dim).

        ``sparse`` defaults to False so the basis can be called as ``basis(x)``.
        A truthy ``sparse`` raises UserWarning because no sparse form exists.
        """
        if sparse:
            raise UserWarning("Sparse FourierExp basis is not supported")
        return self.phi(x, keepshape=False)

    # ------------------------------------------------------------------
    # Per-factor building blocks. `x` is always (n, dim_in); column 0 is
    # time, columns 1: are space. `w` is the number of wave vectors and `k`
    # the number of temporal modes.
    # ------------------------------------------------------------------

    def _phi_cos(self, x):
        # (n, dim_in) -> (n, w)
        phase = jnp.pi/self.L * (self.W_cos @ x[:, 1:].T)  # (w, n)
        return jnp.cos(phase).T

    def _phi_sin(self, x):
        # (n, dim_in) -> (n, w)
        phase = jnp.pi/self.L * (self.W_sin @ x[:, 1:].T)  # (w, n)
        return jnp.sin(phase).T

    def _phi_exp(self, x):
        # (n, dim_in) -> (n, k)
        decay = self.alpha * (jnp.pi * self.k / self.T)**2  # (k, 1)
        return jnp.exp(-decay * x[:, 0:1].T).T

    def _dphi_cos(self, x):
        # Spatial gradient of the cosine modes: (n, dim_in) -> (n, w, dim_in-1)
        phi_sin = self._phi_sin(x)  # (n, w)
        return -jnp.pi/self.L * self.W_cos[None] * phi_sin[..., None]

    def _dphi_sin(self, x):
        # Spatial gradient of the sine modes: (n, dim_in) -> (n, w, dim_in-1)
        phi_cos = self._phi_cos(x)  # (n, w)
        return jnp.pi/self.L * self.W_sin[None] * phi_cos[..., None]

    def _dphi_exp(self, x):
        # Time derivative of the exponentials: (n, dim_in) -> (n, k, 1)
        phi_exp = self._phi_exp(x)  # (n, k)
        return - self.alpha * (jnp.pi * self.k[None] / self.T)**2 * phi_exp[..., None]

    def _ddphi_cos(self, x):
        # Spatial Hessian of the cosine modes: (n, dim_in) -> (n, w, dim_in-1, dim_in-1)
        phi_cos = self._phi_cos(x)  # (n, w)
        # (1, w, d, 1) * (1, w, 1, d) -> (1, w, d, d) outer product of each wave vector
        outer = self.W_cos[None, :, :, None] * self.W_cos[None, :, None, :]
        return -((jnp.pi/self.L)**2) * outer * phi_cos[..., None, None]

    def _ddphi_sin(self, x):
        # Spatial Hessian of the sine modes: (n, dim_in) -> (n, w, dim_in-1, dim_in-1)
        phi_sin = self._phi_sin(x)  # (n, w)
        outer = self.W_sin[None, :, :, None] * self.W_sin[None, :, None, :]
        return -((jnp.pi/self.L)**2) * outer * phi_sin[..., None, None]

    def _ddphi_exp(self, x):
        # Second time derivative of the exponentials: (n, dim_in) -> (n, k, 1, 1)
        dphi_exp = self._dphi_exp(x)  # (n, k, 1)
        return - self.alpha * (jnp.pi * self.k[None, :, :, None] / self.T)**2 * dphi_exp[..., None]

    # ------------------------------------------------------------------
    # Full basis and its derivatives.
    # ------------------------------------------------------------------

    def phi(self, x, keepshape):
        """Evaluate the basis functions.

        x is reshaped to (n, dim_in). With ``keepshape`` true, returns the pair
        ``(ones, phi)`` of shapes (n, 1) and (n, k, 2w); otherwise returns them
        flattened and concatenated into a single (n, self.dim) array.
        """
        x = x.reshape(-1, self.dim_in)
        n = x.shape[0]

        phi_cos = self._phi_cos(x)  # (n, w)
        phi_sin = self._phi_sin(x)  # (n, w)
        phi_exp = self._phi_exp(x)  # (n, k)

        # Separable products: (n, w, 1) * (n, 1, k) -> (n, w, k)
        phi_cos = phi_cos[..., None] * phi_exp[:, None]
        phi_sin = phi_sin[..., None] * phi_exp[:, None]

        # Stack cos then sin along the mode axis and move k in front: (n, k, 2w)
        phi = jnp.concatenate([phi_cos, phi_sin], axis=1).transpose((0, 2, 1))

        ones = jnp.ones((n, 1))  # the constant basis function
        if keepshape:
            return ones, phi
        else:
            return jnp.concatenate([ones, phi.reshape(n, -1)], axis=1)

    def dphi(self, x, keepshape):
        """Evaluate the gradient of every basis function w.r.t. ``[t, x_1, ..., x_d]``.

        With ``keepshape`` true, returns ``(zeros, dphi)`` of shapes
        (n, 1, dim_in) and (n, k, 2w, dim_in); otherwise a single
        (n, self.dim, dim_in) array whose first row is the zero gradient of
        the constant function.
        """
        x = x.reshape(-1, self.dim_in)
        n = x.shape[0]
        phi_cos = self._phi_cos(x)
        phi_sin = self._phi_sin(x)
        phi_exp = self._phi_exp(x)
        dphi_cos = self._dphi_cos(x)
        dphi_sin = self._dphi_sin(x)
        dphi_exp = self._dphi_exp(x)

        # Space derivatives: differentiate the trigonometric factor only.
        dphi_cos_x = dphi_cos[:, :, None, :] * phi_exp[:, None, :, None]
        dphi_sin_x = dphi_sin[:, :, None, :] * phi_exp[:, None, :, None]

        # Time derivative: differentiate the exponential factor only.
        dphi_cos_t = phi_cos[:, :, None, None] * dphi_exp[:, None, :, :]
        dphi_sin_t = phi_sin[:, :, None, None] * dphi_exp[:, None, :, :]

        dphi_t = jnp.concatenate([dphi_cos_t, dphi_sin_t], axis=1)  # (n, 2w, k, 1)
        dphi_x = jnp.concatenate([dphi_cos_x, dphi_sin_x], axis=1)  # (n, 2w, k, dim_in-1)

        # Time component first, matching the input ordering: (n, k, 2w, dim_in)
        dphi = jnp.concatenate([dphi_t, dphi_x], axis=-1).transpose((0, 2, 1, 3))

        zeros = jnp.zeros((n, 1, self.dim_in))
        if keepshape:
            return zeros, dphi
        else:
            return jnp.concatenate([zeros, dphi.reshape(n, -1, self.dim_in)], axis=1)

    def ddphi(self, x, keepshape):
        """Evaluate the Hessian of every basis function w.r.t. ``[t, x_1, ..., x_d]``.

        With ``keepshape`` true, returns ``(zeros, ddphi)`` of shapes
        (n, 1, dim_in, dim_in) and (n, k, 2w, dim_in, dim_in); otherwise a
        single (n, self.dim, dim_in, dim_in) array.
        """
        x = x.reshape(-1, self.dim_in)
        n = x.shape[0]
        phi_cos = self._phi_cos(x)
        phi_sin = self._phi_sin(x)
        phi_exp = self._phi_exp(x)
        dphi_cos = self._dphi_cos(x)
        dphi_sin = self._dphi_sin(x)
        dphi_exp = self._dphi_exp(x)
        ddphi_cos = self._ddphi_cos(x)
        ddphi_sin = self._ddphi_sin(x)
        ddphi_exp = self._ddphi_exp(x)

        # space-space block
        ddphi_cos_xx = ddphi_cos[:, :, None] * phi_exp[:, None, :, None, None]
        ddphi_sin_xx = ddphi_sin[:, :, None] * phi_exp[:, None, :, None, None]
        ddphi_xx = jnp.concatenate([ddphi_cos_xx, ddphi_sin_xx], axis=1)  # (n, 2w, k, dim_in-1, dim_in-1)

        # mixed space-time column
        ddphi_cos_xt = dphi_cos[:, :, None, :, None] * dphi_exp[:, None, :, None, :]
        ddphi_sin_xt = dphi_sin[:, :, None, :, None] * dphi_exp[:, None, :, None, :]
        ddphi_xt = jnp.concatenate([ddphi_cos_xt, ddphi_sin_xt], axis=1)  # (n, 2w, k, dim_in-1, 1)

        # time-time entry
        ddphi_cos_tt = phi_cos[:, :, None, None, None] * ddphi_exp[:, None, :]
        ddphi_sin_tt = phi_sin[:, :, None, None, None] * ddphi_exp[:, None, :]
        ddphi_tt = jnp.concatenate([ddphi_cos_tt, ddphi_sin_tt], axis=1)  # (n, 2w, k, 1, 1)

        # Assemble the symmetric block matrix [[tt, xt^T], [xt, xx]].
        ddphi_u = jnp.concatenate([ddphi_tt, ddphi_xt.transpose((0, 1, 2, 4, 3))], axis=-1)  # (n, 2w, k, 1, dim_in)
        ddphi_d = jnp.concatenate([ddphi_xt, ddphi_xx], axis=-1)  # (n, 2w, k, dim_in-1, dim_in)
        ddphi = jnp.concatenate([ddphi_u, ddphi_d], axis=-2).transpose((0, 2, 1, 3, 4))

        zeros = jnp.zeros((n, 1, self.dim_in, self.dim_in))

        if keepshape:
            return zeros, ddphi
        else:
            return jnp.concatenate([zeros, ddphi.reshape(n, -1, self.dim_in, self.dim_in)], axis=1)

    # ------------------------------------------------------------------
    # Pairings <basis, test_function>.
    # ------------------------------------------------------------------

    def _pair_with_delta(self, evaluate, test_function, keepshape):
        # Shared implementation of phipsi / dphipsi / ddphipsi: for a
        # DeltaBasis the pairing is `evaluate` applied at `test_function.nodes`.
        if isinstance(test_function, DeltaBasis):
            return evaluate(test_function.nodes, keepshape)
        # The previous fallback handed a lambda to `trapezoid` (which expects
        # sample arrays) over a time-only 1-D grid, so it always failed with a
        # TypeError. Keep the exception type but make the cause explicit.
        raise TypeError(
            "FourierExpBasis only supports DeltaBasis test functions; "
            f"got {type(test_function).__name__}"
        )

    def phipsi(self, test_function, keepshape=True):
        """<bases, test_function>: ``phi`` evaluated at the DeltaBasis nodes.

        Raises TypeError for any other kind of test function.
        """
        return self._pair_with_delta(self.phi, test_function, keepshape)

    def dphipsi(self, test_function, keepshape=True):
        """<d bases, test_function>: ``dphi`` evaluated at the DeltaBasis nodes.

        Raises TypeError for any other kind of test function.
        """
        return self._pair_with_delta(self.dphi, test_function, keepshape)

    def ddphipsi(self, test_function, keepshape=True):
        """<dd bases, test_function>: ``ddphi`` evaluated at the DeltaBasis nodes.

        Raises TypeError for any other kind of test function.
        """
        return self._pair_with_delta(self.ddphi, test_function, keepshape)

    def projection(self, y, Z):
        """Least-squares fit of basis weights to samples ``y`` taken at points ``Z``.

        ``y.squeeze()`` must be 2-D; it is transposed before the solve, so it
        is presumably laid out as (n_targets, n_points) -- confirm with callers.
        Prints the mean residual reported by ``jnp.linalg.lstsq`` and returns
        the squeezed weight array.
        """
        # Pass `sparse` explicitly: the dense evaluation is what lstsq needs.
        phi = self(Z, sparse=False)  # (n_points, dim)
        weight, resid, _, _ = jnp.linalg.lstsq(phi, y.squeeze().transpose((1, 0)))
        print(f"residual: {jnp.mean(resid)}")
        return weight.squeeze()