import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve
from jax import jit
from functools import partial


def causal_convolution(u, K, nofft=False):
    """Causally convolve the sequence ``u`` with the kernel ``K``.

    Computes ``y[t] = sum_{s=0}^{t} K[t - s] * u[s]`` for ``t < len(u)``,
    i.e. the first ``len(u)`` samples of the full linear convolution.

    Args:
        u: 1-D input sequence of length L.
        K: 1-D convolution kernel. Its length need not equal L: in the FFT
            branch both signals are zero-padded to ``len(u) + len(K)``,
            which is long enough for the circular FFT convolution to contain
            the linear one without wrap-around.
        nofft: If True, use ``jax.scipy.signal.convolve`` directly instead
            of the FFT.

    Returns:
        1-D array of length L. For real ``u`` and ``K`` the FFT branch now
        returns a real array, matching the ``nofft`` branch; if either
        input is complex the result is complex.
    """
    if nofft:
        return convolve(u, K, mode="full")[: u.shape[0]]
    ud = jnp.fft.fft(jnp.pad(u, (0, K.shape[0])))
    Kd = jnp.fft.fft(jnp.pad(K, (0, u.shape[0])))
    out = jnp.fft.ifft(ud * Kd)[: u.shape[0]]
    if not (jnp.iscomplexobj(u) or jnp.iscomplexobj(K)):
        # ifft always yields a complex array; for real inputs the imaginary
        # part is only round-off, so drop it to match the ``nofft`` branch.
        out = out.real
    return out

def compute_measurement_block(Ker, action_emb, init_state):
    """Compute the predicted Koopman measurements for one dimension in time.

    Entry ``t`` of the result is
    ``sum_{s<=t} Ker[t - s] * action_emb[s] + Ker[t + 1] * init_state``.
    This is a causal convolution of ``action_emb`` with ``Ker[:-1]``, plus
    the ``init_state`` contribution scaled by ``Ker[1:]``. For the two terms
    to broadcast, ``Ker`` is expected to be one element longer than
    ``action_emb``.
    """
    forced = causal_convolution(action_emb, Ker[:-1])
    from_init = Ker[1:] * init_state
    return forced + from_init


@jit
def compute_measurement(Ker, action_emb, init_state):
    """Apply ``compute_measurement_block`` independently to every dimension.

    The block is vmapped over axis 0 of ``Ker``, axis 1 of ``action_emb``
    and axis 0 of ``init_state``. The per-dimension outputs are stacked along
    axis 1. Presumably ``Ker`` is (D, T+1), ``action_emb`` is (T, D) and
    ``init_state`` is (D,), giving a (T, D) result. These shapes are inferred
    from the vmap axes; TODO confirm against callers.
    """
    per_dimension = jax.vmap(
        compute_measurement_block, in_axes=(0, 1, 0), out_axes=1
    )
    return per_dimension(Ker, action_emb, init_state)


@partial(jit, static_argnames=("mode",))
def discretize(K, L, step, mode="zoh"):
    """Discretize elementwise continuous-time parameters with step ``step``.

    All operations are elementwise, so ``K`` and ``L`` are treated as
    diagonal (per-mode) parameters.

    Args:
        K: continuous-time state coefficients.
        L: continuous-time input coefficients.
        step: discretization step size.
        mode: ``"zoh"`` (zero-order hold, default) or ``"bilinear"``.
            Static under ``jit``.

    Returns:
        Tuple ``(K_bar, L_bar)`` of discrete-time parameters:
          * zoh:      ``K_bar = exp(step*K)``,
                      ``L_bar = (exp(step*K) - 1) / K * L``
                      (``K`` must be nonzero elementwise).
          * bilinear: ``K_bar = (1 + step*K/2) / (1 - step*K/2)``,
                      ``L_bar = step * L / (1 - step*K/2)``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values
            (previously an unknown mode silently returned ``None``).
    """
    if mode == "bilinear":
        num, denom = 1 + 0.5 * step * K, 1 - 0.5 * step * K
        return num / denom, step * L / denom
    if mode == "zoh":
        K_bar = jnp.exp(step * K)
        return K_bar, (K_bar - 1) / K * L
    raise ValueError(
        f"Unknown discretization mode {mode!r}; expected 'zoh' or 'bilinear'."
    )



def log_step_init(dt_min=0.001, dt_max=0.1):
    """Build an initializer for log step sizes.

    The returned ``init(key, shape)`` draws values uniformly from
    ``[log(dt_min), log(dt_max))``. The step sizes themselves are therefore
    log-uniformly distributed between ``dt_min`` and ``dt_max``.
    """
    def _sample_log_step(key, shape):
        unit = jax.random.uniform(key, shape)  # samples in [0, 1)
        log_lo = jnp.log(dt_min)
        log_span = jnp.log(dt_max) - log_lo
        return unit * log_span + log_lo

    return _sample_log_step

def increasing_im_init():
    """Build a deterministic initializer returning ``pi * [0, 1, ..., n-1]``.

    The returned ``init(key, shape)`` ignores ``key``. It uses only
    ``shape[0]`` as ``n``, so the result is always 1-D.
    """
    def _linear_im(key, shape):
        count = shape[0]
        return jnp.arange(count) * jnp.pi

    return _linear_im


def random_im_init():
    """Build an initializer drawing i.i.d. samples from ``U[0, 1)``.

    The returned ``init(key, shape)`` forwards directly to
    ``jax.random.uniform`` with its default range.
    """
    def _uniform_im(key, shape):
        return jax.random.uniform(key, shape=shape)

    return _uniform_im

def scan_SSM(Ab, Bb, Cb, u, x0):
    """Run the discrete linear SSM ``x_k = Ab x_{k-1} + Bb u_k``, ``y_k = Cb x_k``.

    Scans over the leading axis of ``u`` starting from state ``x0``.
    Presumably ``Ab`` is (N, N), ``Bb`` is (N, M), ``Cb`` is (P, N),
    ``u`` is (T, M) and ``x0`` is (N,); confirm against callers.

    Returns:
        ``(x_final, ys)``, where ``ys`` stacks every ``y_k`` along axis 0.
    """
    def _transition(state, inp):
        nxt = Ab @ state + Bb @ inp
        return nxt, Cb @ nxt

    final_state, outputs = jax.lax.scan(_transition, x0, u)
    return final_state, outputs

def vandermonde(v, L, alpha):
    """Return ``v @ Vandermonde(alpha, L)``.

    Row ``i`` of the Vandermonde matrix is ``alpha[i] ** [0, 1, ..., L-1]``,
    so entry ``j`` of the result is ``sum_i v[i] * alpha[i] ** j``.

    Args:
        v: coefficients, shape (N,).
        L: number of powers, i.e. the output length.
        alpha: nodes, shape (N,).

    Returns:
        Array of shape (L,).
    """
    exponents = jnp.arange(L)
    powers = jnp.power(alpha[:, None], exponents)  # (N, L)
    return v @ powers

@partial(jax.jit, static_argnums=2)
def s4d_kernel_zoh(C, A, L, step):
    """Compute the S4D convolution kernel for B = 1 with ZOH discretization.

    For ``l = 0, ..., L-1`` the kernel tap is
    ``Re(sum(C * (exp(step*A) - 1) / A * exp(l*step*A)))``.
    ``L`` is static under ``jit`` (``static_argnums=2``).

    Returns:
        Real array of shape (L,).
    """
    # The ZOH input factor does not depend on l; compute it once, in the
    # same operation order as the per-tap expression.
    coeff = C * (jnp.exp(step * A) - 1) / A

    def _tap(l):
        return (coeff * jnp.exp(l * step * A)).sum()

    taps = jax.vmap(_tap)(jnp.arange(L))
    return taps.ravel().real

