import numpy as np
import jax.numpy as jnp
from typing import List
from functools import reduce
from pi_lr.basis.delta_basis import DeltaBasis
from jax.scipy.integrate import trapezoid
from jax.experimental.sparse import BCOO
from jax import config
# Enable 64-bit (double precision) arrays in JAX. This is a global, process-wide
# JAX setting, so it also affects any other JAX code running in this process.
config.update("jax_enable_x64", True)

class PiecewiseConstantBasis:
    """Piecewise-constant (indicator) basis on a tensor-product grid.

    Each input point is mapped to a one-hot vector that selects the grid node
    nearest to it along every axis.

    Axis 0 (labelled ``t`` in the index comments) has ``dim_out[0]`` nodes that
    include both ends of ``domain[0]``, i.e. spacing
    ``(hi - lo) / (dim_out[0] - 1)``.  Every other axis ``k`` uses spacing
    ``(hi - lo) / dim_out[k]``, so its nodes are ``lo + j * h`` for
    ``j = 0 .. dim_out[k] - 1`` and ``hi`` itself is not a node.

    Args:
        dim_in: Number of coordinates per input point.
        dim_out: Number of nodes along each input axis (one entry per axis).
        domain: ``dim_in`` pairs ``[lo, hi]`` giving the bounds of each axis.

    Raises:
        ValueError: If ``len(dim_out) != dim_in``.
    """

    def __init__(self,
        dim_in: int,
        dim_out: List[int],
        domain: List[List[float]]
    ):
        # The spacing vector ``h`` below is built per axis from ``dim_out``;
        # a length mismatch would silently mis-broadcast it.
        if len(dim_out) != dim_in:
            raise ValueError(
                f"dim_out must have one entry per input dimension "
                f"(expected {dim_in}, got {len(dim_out)})"
            )
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_out_all = reduce(lambda a, b: a * b, dim_out)
        self.domain = jnp.asarray(domain)  # (dim_in, 2): [lo, hi] per axis
        # Axis 0: both endpoints are nodes -> (n - 1) intervals.
        ht = (self.domain[0:1, 1] - self.domain[0:1, 0]) / (jnp.asarray(dim_out)[0:1] - 1)
        # Remaining axes: n equal intervals, hi is not a node.
        hx = (self.domain[1:, 1] - self.domain[1:, 0]) / jnp.asarray(dim_out[1:])
        self.h = jnp.concatenate((ht, hx), axis=0)

    @property
    def dim(self):
        """Total number of basis functions (product of ``dim_out``)."""
        return self.dim_out_all

    def __call__(self, x, sparse=False):
        """Evaluate the basis at ``x``; the result has shape ``(num_points, dim)``."""
        return self.phi(x, sparse, keepshape=False)

    def phi(self, x, sparse, keepshape):
        """Evaluate the basis at the points ``x``.

        Args:
            x: Array whose last axis holds the ``dim_in`` coordinates of each
                point; all leading axes are treated as a batch.
            sparse: If True, return a ``BCOO`` array; otherwise a dense array.
            keepshape: If True, the result has shape
                ``x.shape[:-1] + tuple(dim_out)``. Otherwise it is flattened
                to ``(num_points, dim)``.
        """
        batch_shape = tuple(x.shape[:-1])
        points = x.reshape(-1, self.dim_in)
        out = self._phi_sparse(points)  # (n, *dim_out)

        if not sparse:
            out = out.todense()

        if keepshape:
            # dim_out may be a list; tuple + list raises TypeError, so convert.
            return out.reshape(batch_shape + tuple(self.dim_out))
        return out.reshape((-1, self.dim_out_all))

    def _phi_sparse(self, x):
        """Build the one-hot basis matrix for a batch of points.

        Args:
            x: Array of shape ``(n, dim_in)``.

        Returns:
            ``BCOO`` of shape ``(n, *dim_out)``. It stores one entry of value 1
            per point, at the node index computed for that point.
        """
        assert x.ndim == 2
        n = x.shape[0]
        # Nearest node per axis: floor((x - lo) / h + 1/2).
        node_idx = ((x - self.domain[:, 0]) + self.h / 2) // self.h  # (idx_t, idx_x)
        # NOTE(review): node_idx is neither clipped nor wrapped. On axes >= 1,
        # a point within h/2 of hi gets index dim_out[k] (one past the end),
        # and points below lo - h/2 get negative indices. Confirm the intended
        # behaviour (e.g. wrap for a periodic grid, or clip).
        row_idx = jnp.arange(n).reshape(-1, 1)
        indices = jnp.concatenate([row_idx, node_idx], axis=-1).astype(jnp.int32)
        data = jnp.ones(n)
        # Use each axis' own node count so the shape agrees with dim_out_all,
        # which phi() reshapes to.
        shape = (n, *self.dim_out)
        return BCOO((data, indices), shape=shape)

    def phipsi(self, test_function, sparse=False):
        """Return <basis, test_function>, i.e. the basis evaluated at the test nodes.

        Only ``DeltaBasis`` test functions are supported. For them, the result
        is ``phi(test_function.nodes)`` with the node batch shape preserved.
        """
        if isinstance(test_function, DeltaBasis):
            return self.phi(test_function.nodes, sparse, keepshape=True)
        else:
            raise NotImplementedError("Not implemented yet")

    def projection(self, y, test_function):
        """Project ``y`` against ``test_function``; ``DeltaBasis`` returns ``y`` unchanged."""
        if isinstance(test_function, DeltaBasis):
            return y
        else:
            raise NotImplementedError("Not implemented yet")
    