
import jax.numpy as jnp
from typing import List
from functools import reduce
import jax

class DeltaBasis:
    """Tensor-product basis of Dirac deltas placed on a regular grid of nodes.

    The deltas cannot be evaluated pointwise (``__call__`` raises); only their
    pairing with a test function via :meth:`phipsi` is supported, and
    currently only when the test function is a ``DeltaBasis`` with the same
    nodes.

    Attributes:
        dim_in: Stored as given; not used by the methods of this class.
        dim_out: Number of nodes along each axis.
        domain: Per-axis bounds; only the first and last entry of each axis
            (``domain[i][0]`` and ``domain[i][-1]``) are read.
        dim_out_all: Total number of nodes, i.e. the product of ``dim_out``.
        nodes: Node coordinates, shape ``(*dim_out, len(dim_out))``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: List[int],
        domain: List[List[float]],
    ):
        self.dim_in = dim_in
        self.domain = domain
        self.dim_out = dim_out
        self.dim_out_all = reduce(lambda x, y: x * y, dim_out)

        # Axis 0 includes its right endpoint; every other axis excludes it.
        # NOTE(review): presumably the trailing axes are periodic (so the
        # endpoint would duplicate the start node) -- confirm with the author.
        linspaces = [jnp.linspace(domain[0][0], domain[0][-1], dim_out[0])]
        linspaces.extend(
            jnp.linspace(domain[i][0], domain[i][-1], dim_out[i], endpoint=False)
            for i in range(1, len(dim_out))
        )
        self.nodes = jnp.stack(jnp.meshgrid(*linspaces, indexing="ij"), axis=-1)

    def __call__(self, x):
        """Pointwise evaluation is undefined for a Dirac delta; always raises."""
        raise NotImplementedError("Dirac delta function is a superfunction")

    @property
    def dim(self):
        """Total number of basis functions (product of ``dim_out``)."""
        return self.dim_out_all

    def phipsi(self, test_function, sparse=False):
        """Return the pairing of this basis with ``test_function``.

        For two ``DeltaBasis`` objects with identical nodes the pairing is the
        Kronecker delta, so an identity matrix is returned, reshaped to
        ``(*dim_out, *dim_out)``.

        Args:
            test_function: The basis to pair with.
            sparse: If True, build the identity with
                ``jax.experimental.sparse.eye`` instead of ``jnp.eye``.

        Returns:
            The (dense or sparse) identity reshaped to ``(*dim_out, *dim_out)``.

        Raises:
            NotImplementedError: For any other ``test_function``, including a
                ``DeltaBasis`` whose nodes differ in shape or value.
        """
        # array_equal checks shapes first, so mismatched grids compare unequal
        # instead of failing (or spuriously matching) via broadcasting.
        same_nodes = isinstance(test_function, DeltaBasis) and bool(
            jnp.array_equal(self.nodes, test_function.nodes)
        )
        if not same_nodes:
            raise NotImplementedError("Not implemented yet")

        # delta_{ij}
        if sparse:
            # `import jax` is not guaranteed to load the experimental sparse
            # submodule, so import it explicitly; aliased because the
            # `sparse` parameter shadows the module name.
            from jax.experimental import sparse as jsparse

            mat = jsparse.eye(self.dim_out_all)
        else:
            mat = jnp.eye(self.dim_out_all)
        # tuple() so both list and tuple `dim_out` produce a valid shape.
        shape = tuple(self.dim_out) * 2
        return mat.reshape(*shape)
    