import jax.numpy as jnp
from pi_lr.basis.delta_basis import DeltaBasis
from jax.scipy.integrate import trapezoid
from typing import List

class FourierBasis:
    """Real tensor-product Fourier basis with period ``T``.

    With ``K = dim_out[0]`` and ``omega = 2*pi/T`` the basis functions are

    * ``sqrt(T/2) * cos(omega * w.x)`` for every integer vector ``w`` in
      ``{0, ..., K}**dim_in`` (the rows of ``W_cos``), followed by
    * ``sqrt(T/2) * sin(omega * w.x)`` for every ``w`` in ``{1, ..., K}**dim_in``
      (the rows of ``W_sin``),

    so ``dim == (K + 1)**dim_in + K**dim_in``. The period ``T`` is taken from
    ``domain[0]`` only.
    """

    # Number of equispaced nodes used by the trapezoid quadrature in ``_integrate``.
    _N_QUAD = 1000

    def __init__(self,
        dim_in: int,
        dim_out: List[int],
        domain: List[List[float]]
    ):
        """
        Args:
            dim_in: number of input coordinates.
            dim_out: only ``dim_out[0]`` is used, as the maximal frequency ``K``
                per coordinate.
            domain: per-coordinate intervals; ``domain[0]`` defines the period
                ``T = domain[0][-1] - domain[0][0]`` and the quadrature interval.
        """
        max_freq = dim_out[0]
        freqs = jnp.arange(0, max_freq + 1)
        # ((K+1)**dim_in, dim_in): all frequency vectors with entries in 0..K.
        self.W_cos = self._frequency_grid(freqs, dim_in)
        # (K**dim_in, dim_in): all frequency vectors with entries in 1..K.
        self.W_sin = self._frequency_grid(freqs[1:], dim_in)

        self.dim_in = dim_in
        self.dim_out = max_freq
        self.domain = domain
        self.T = domain[0][-1] - domain[0][0]

    @staticmethod
    def _frequency_grid(freqs, dim_in):
        """Return every ``dim_in``-tuple of ``freqs`` as the rows of a 2-D array ('ij' order)."""
        grids = jnp.meshgrid(*[freqs] * dim_in, indexing='ij')
        return jnp.stack(grids, axis=-1).reshape(-1, dim_in)

    @property
    def dim(self):
        """Total number of basis functions (cosine terms + sine terms)."""
        return len(self.W_cos) + len(self.W_sin)

    def __call__(self, x, sparse):
        """Evaluate the basis at ``x``; sparse evaluation is not supported.

        Raises:
            UserWarning: if ``sparse`` is truthy (raised as an exception, kept
                for compatibility with existing callers).
        """
        if sparse:
            raise UserWarning("Sparse Fourier basis is not supported")
        return self.phi(x)

    def _scaled_args(self, x):
        """Return the phases ``(2*pi/T * W_cos @ x.T, 2*pi/T * W_sin @ x.T)``.

        ``x`` is reshaped to ``(n, dim_in)``; the results have shapes
        ``(n_cos, n)`` and ``(n_sin, n)``.
        """
        x = x.reshape(-1, self.dim_in)
        arg_cos = 2 * jnp.pi / self.T * (self.W_cos @ x.T)
        arg_sin = 2 * jnp.pi / self.T * (self.W_sin @ x.T)
        return arg_cos, arg_sin

    def phi(self, x):
        """Evaluate all basis functions at ``x``; returns shape ``(n, dim)``."""
        arg_cos, arg_sin = self._scaled_args(x)
        # NOTE(review): sqrt(T/2) is the reciprocal of the usual orthonormal
        # factor sqrt(2/T); kept as-is, confirm it is intended.
        amp = jnp.sqrt(self.T / 2)
        phi = jnp.concatenate([amp * jnp.cos(arg_cos), amp * jnp.sin(arg_sin)], axis=0)
        return phi.transpose(1, 0)  # (n, dim)

    def dphi(self, x):
        """Gradient of every basis function at ``x``; returns ``(n, dim, dim_in)``."""
        arg_cos, arg_sin = self._scaled_args(x)
        amp = jnp.sqrt(self.T / 2)
        omega = 2 * jnp.pi / self.T
        # d/dx cos(omega w.x) = -omega w sin(omega w.x): the sine must be taken
        # at the *same* (cosine) frequencies, and vice versa for the sine terms.
        dphi_cos = -omega * self.W_cos[..., None] * (amp * jnp.sin(arg_cos))[:, None]
        dphi_sin = omega * self.W_sin[..., None] * (amp * jnp.cos(arg_sin))[:, None]
        dphi = jnp.concatenate([dphi_cos, dphi_sin], axis=0)  # (dim, dim_in, n)
        return dphi.transpose(2, 0, 1)  # (n, dim, dim_in)

    def ddphi(self, x):
        """Hessian of every basis function at ``x``; returns ``(n, dim, dim_in, dim_in)``."""
        arg_cos, arg_sin = self._scaled_args(x)
        amp = jnp.sqrt(self.T / 2)
        omega = 2 * jnp.pi / self.T
        # Per-frequency outer products w w^T: (n_cos|n_sin, dim_in, dim_in).
        outer_cos = self.W_cos[:, :, None] * self.W_cos[:, None, :]
        outer_sin = self.W_sin[:, :, None] * self.W_sin[:, None, :]
        ddphi_cos = -(omega * omega) * outer_cos[..., None] * (amp * jnp.cos(arg_cos))[:, None, None, :]
        ddphi_sin = -(omega * omega) * outer_sin[..., None] * (amp * jnp.sin(arg_sin))[:, None, None, :]
        ddphi = jnp.concatenate([ddphi_cos, ddphi_sin], axis=0)  # (dim, dim_in, dim_in, n)
        return ddphi.transpose(3, 0, 1, 2)  # (n, dim, dim_in, dim_in)

    def _integrate(self, values, test_function):
        """Trapezoid-integrate ``values(Z) * test_function(Z)`` over ``domain[0]``.

        Args:
            values: callable mapping grid points ``Z`` of shape ``(N,)`` to basis
                values of shape ``(N, dim)``.
            test_function: callable evaluated on ``Z``.

        Returns:
            Array of shape ``(m, dim)``: one row per test-function output column,
            matching the row layout of the ``DeltaBasis`` branches.

        Raises:
            NotImplementedError: if ``dim_in != 1`` (only a 1-D grid is built).
        """
        if self.dim_in != 1:
            raise NotImplementedError(
                "Quadrature against a non-delta test function is only implemented for dim_in == 1"
            )
        Z = jnp.linspace(self.domain[0][0], self.domain[0][-1], self._N_QUAD)
        basis_vals = values(Z)  # (N, dim)
        # NOTE(review): assumes test_function is vectorised and returns shape
        # (N,) or (N, m) on the grid -- confirm with callers.
        psi = jnp.asarray(test_function(Z)).reshape(Z.shape[0], -1)  # (N, m)
        integrand = psi[:, :, None] * basis_vals[:, None, :]  # (N, m, dim)
        return trapezoid(integrand, x=Z, axis=0)

    def phipsi(self, test_function):
        """Pair every basis function with ``test_function``: ``<phi, psi>``.

        A ``DeltaBasis`` gives point evaluation at its ``nodes``; any other test
        function is integrated by trapezoid quadrature over ``domain[0]``.
        """
        if isinstance(test_function, DeltaBasis):
            return self.phi(test_function.nodes)
        return self._integrate(self.phi, test_function)

    def dphipsi(self, test_function):
        """Like ``phipsi`` but for the derivative w.r.t. the first input coordinate."""
        if isinstance(test_function, DeltaBasis):
            return self.dphi(test_function.nodes)[..., 0]
        return self._integrate(lambda z: self.dphi(z)[..., 0], test_function)

    def ddphipsi(self, test_function):
        """Like ``phipsi`` but for the second derivative w.r.t. the first input coordinate."""
        if isinstance(test_function, DeltaBasis):
            return self.ddphi(test_function.nodes)[..., 0, 0]
        return self._integrate(lambda z: self.ddphi(z)[..., 0, 0], test_function)

    def projection(self, y, test_funcion):
        """Least-squares fit of basis weights to ``y`` under ``test_funcion``.

        Solves ``phipsi @ weight ~= y.squeeze().T``; ``y`` must be 2-D after
        squeezing. Prints the mean of the residual array returned by
        ``jnp.linalg.lstsq`` and returns the squeezed weights.
        """
        phipsi = self.phipsi(test_funcion)
        weight, resid, _, _ = jnp.linalg.lstsq(phipsi, y.squeeze().transpose((1, 0)))
        print(f"projection residual: {jnp.mean(resid)}")
        return weight.squeeze()
    
    