import numpy as np
import sympy as sp


def observability_check(solver):
    """Run both observability diagnostics on a solver's operator and mask.

    Prints the numerical rank of the operator, then runs the stacked-rank
    test (``is_observable``) and the eigenvector test (``is_observable_alt``
    with a fixed threshold of 1e-2). Their return values are discarded; the
    results are reported only through what those functions print.

    Args:
        solver: object exposing ``op`` (the operator matrix; must be square,
            since it is passed to ``np.linalg.eig``), ``mask`` (the
            observation matrix) and ``mask_dimension`` (forwarded as the ``c``
            argument of ``is_observable``).
    """
    operator_matrix, observation_mask = solver.op, solver.mask
    operator_rank = np.linalg.matrix_rank(operator_matrix)
    print(f'Operator rank: {operator_rank}')

    is_observable(operator_matrix, observation_mask, solver.mask_dimension)
    is_observable_alt(operator_matrix, observation_mask, 1e-2)


def is_observable(operator, mask, c):
    """Stacked-rank (Kalman-style) observability test for ``(operator, mask)``.

    Builds ``[mask; mask @ A; mask @ A**2; ...; mask @ A**(c**2 - 1)]`` with
    ``A = operator`` and checks whether it has full rank ``L``, the operator
    dimension. With ``q`` mask rows, ``c**2`` blocks give exactly
    ``q * c**2 == L`` rows, the minimum that can reach rank ``L``.

    This is a sufficient condition for observability. The full Kalman
    criterion stacks up to ``L`` blocks, so a ``False`` here does not by
    itself prove the pair is unobservable.

    Args:
        operator: square ``(L, L)`` state operator ``A``.
        mask: ``(q, L)`` observation matrix.
        c: compression factor; ``c**2`` blocks are stacked.

    Returns:
        ``True`` if the stacked matrix has numerical rank ``L``.

    Raises:
        ValueError: if ``mask.shape[0] * c**2 != operator.shape[0]``.
    """
    c_factor = c ** 2
    q = mask.shape[0]
    L = operator.shape[0]
    # Explicit check: an ``assert`` would be silently skipped under ``python -O``.
    if q * c_factor != L:
        raise ValueError(
            f'mask rows ({q}) * c**2 ({c_factor}) must equal the operator '
            f'dimension ({L})'
        )

    matrices = [mask]
    for i in range(c_factor - 1):
        matrices.append(np.matmul(matrices[i], operator))
        print(f'Current iteration {i + 1}')
    stacked = np.vstack(matrices)

    # One singular-value-only SVD serves both the rank and the diagnostic print.
    # The tolerance is numpy.linalg.matrix_rank's default, so the rank matches it.
    S = np.linalg.svd(stacked, compute_uv=False)
    tol = S.max() * max(stacked.shape) * np.finfo(S.dtype).eps
    rank = int(np.count_nonzero(S > tol))

    print(f'Shape: {stacked.shape}, Max: {S.max()}, Dtype: {stacked.dtype}')
    print(f'Rank: {rank}')
    return bool(rank == L)


def is_observable_alt(operator, mask, epsilon):
    """Eigenvector (PBH-style) observability diagnostic.

    An eigenvector ``v`` of ``operator`` with ``mask @ v ~= 0`` is a mode the
    mask cannot see. This flags every eigenvector returned by
    ``np.linalg.eig`` whose masked image has all entries below ``epsilon`` in
    magnitude. ``np.linalg.eig`` returns unit-norm eigenvectors, so
    ``epsilon`` is an absolute threshold relative to a unit vector.

    Caveat: with repeated eigenvalues, a combination of eigenvectors can be
    unobservable even if no individual returned eigenvector is flagged.

    Args:
        operator: square operator matrix.
        mask: observation matrix with as many columns as ``operator``.
        epsilon: magnitude threshold below which a masked entry counts as zero.

    Returns:
        The ``np.where`` index tuple; ``check[0]`` holds the column indices of
        the flagged eigenvectors. It is empty when none are found.
    """
    vals, vecs = np.linalg.eig(operator)
    # Column j of ``result`` is the mask applied to eigenvector j.
    result = np.matmul(mask, vecs)
    check = np.where(np.all(np.abs(result) < epsilon, axis=0))
    magnitudes = np.abs(vals)
    print(f'Max eigenvalue: {magnitudes.max()}, Min eigenvalue: {magnitudes.min()}')
    print(f'Result: {len(check[0])} errant eigenvectors')
    return check

def kse_lie_check():
    """Numerically probe observability of a masked Kuramoto-Sivashinsky trajectory.

    Simulates the 1D KS equation via ``calculate_ks_equation`` (``N = 100``
    grid points, ``tmax = 100``). The trajectory is observed through a
    block-sum mask ``h`` that adds each run of ``c`` neighbouring grid values
    into one output. The masked time derivatives of orders ``0 .. c*N - 1``
    at the first time sample are stacked into a ``(c*N, q)`` matrix. The
    function prints that matrix and its numerical rank.

    NOTE(review): along a trajectory, d^k y/dt^k equals the k-th Lie
    derivative of ``h`` evaluated at the current state. The nonlinear
    observability rank condition, however, uses the *gradients* of those Lie
    derivatives, so the rank of this matrix of values is a heuristic rather
    than that test. Confirm intent.
    """
    N = 100  # Number of spatial grid points
    c = 2  # Compression factor: grid points summed into each output
    q = N // c  # Number of masked outputs

    # Block-sum mask, built once rather than on every call to h().
    D = np.zeros((q, N))
    for i in range(q):
        D[i, i * c:(i + 1) * c] = 1

    def h(u):  # Masked output
        return D @ u

    def time_derivatives(u, t, count):
        """Yield forward-difference estimates of d^k u/dt^k at t[0], k = 0..count-1.

        Order k is obtained by differencing order k-1 once more. Producing all
        orders therefore costs one ``np.diff`` per order, instead of
        re-differencing from ``u`` for every order.

        NOTE(review): each order divides by ``dt`` again, so high orders
        strongly amplify round-off. Entries near ``count`` are likely
        dominated by numerical noise.

        Raises:
            ValueError: if the samples in ``t`` are not uniformly spaced.
        """
        steps = np.diff(t)
        if not np.allclose(steps, steps[0]):
            raise ValueError("Time steps must be uniform.")
        dt = steps[0]

        deriv = u
        yield deriv[0]  # Zeroth derivative: the state itself at t[0]
        for _ in range(count - 1):
            deriv = np.diff(deriv, axis=0) / dt  # Forward difference along time
            yield deriv[0]

    def numerical_observability(uu, tt):
        n = c * N
        O = np.zeros((n, q))
        # h is linear, so masking each state derivative equals differentiating the output.
        for i, deriv_row in enumerate(time_derivatives(uu, tt, n)):
            O[i, :] = h(deriv_row)
        rank_O = np.linalg.matrix_rank(O)
        return O, rank_O

    x, tt, uu = calculate_ks_equation(N=N, tmax=100)

    O_matrix, rank = numerical_observability(uu, tt)
    print("Numerical Observability Matrix:\n", O_matrix)
    # NOTE(review): O has only q = N // c columns, so rank <= q < N == len(x)
    # and this always reports "NOT observable". Confirm the intended comparison.
    print("Rank:", rank, " (System is observable)" if rank == len(x) else " (System is NOT observable)")


# Quick and dirty implementation of the Kuramoto-Sivashinsky equation in 1D. See the careful 2D derivation in KSSolver
def calculate_ks_equation(N=1024,
                          h=0.25,
                          M=16,
                          tt=0, tmax=150):
    """Integrate the 1D Kuramoto-Sivashinsky equation with ETDRK4.

    Solves ``u_t = -u u_x - u_xx - u_xxxx`` pseudo-spectrally: linear part
    ``L = k**2 - k**4`` in Fourier space, nonlinear part ``-(u**2)_x / 2``.
    Time stepping uses fourth-order exponential time differencing (ETDRK4),
    with the phi-function coefficients evaluated as means over ``M`` contour
    points. The structure mirrors Kassam & Trefethen's ``kursiv.m``.

    NOTE(review): the wavenumbers are scaled by 1/16, i.e. a 32*pi-periodic
    domain. However, ``x`` spans (0, 1] and the initial condition is
    ``cos(x) * (1 + sin(x))``, whereas ``kursiv.m`` uses
    ``x = 32*pi*(1:N)/N`` and ``cos(x/16) * (1 + sin(x/16))``. Confirm
    whether the rescaling is intended.

    Args:
        N: number of spatial grid points (the wavenumber layout assumes N even).
        h: time step.
        M: number of contour points for the ETDRK4 coefficients.
        tt: initial entry (or entries) of the returned time vector. The
            subsequent times are ``n * h`` and do not add this offset.
        tmax: final time; ``round(tmax / h)`` steps are taken.

    Returns:
        ``[x, tt, uu]``:
        - ``x``: grid of length ``N`` with values ``1/N .. 1``.
        - ``tt``: ``tt`` followed by ``h, 2h, ...``; unchanged if no steps are taken.
        - ``uu``: array of shape ``(steps + 1, N)``, one row per stored time.
    """
    x = np.arange(1, N + 1) / N
    u = np.cos(x) * (1 + np.sin(x))
    v = np.fft.fft(u)

    # scalars for ETDRK4
    # Wavenumbers in FFT order with the Nyquist mode zeroed; /16 sets a 32*pi period.
    k = np.concatenate((np.arange(0, N / 2), np.array([0]), np.arange(-N / 2 + 1, 0))) / 16
    L = k ** 2 - k ** 4
    E = np.exp(h * L)
    E_2 = np.exp(h * L / 2)
    # Contour points on the unit circle around each h*L, avoiding 0/0 cancellation.
    r = np.exp(1j * np.pi * (np.arange(1, M + 1) - 0.5) / M)
    LR = h * L[:, np.newaxis] + r[np.newaxis, :]  # shape (N, M)
    Q = h * np.real(np.mean((np.exp(LR / 2) - 1) / LR, axis=1))
    f1 = h * np.real(np.mean((-4 - LR + np.exp(LR) * (4 - 3 * LR + LR ** 2)) / LR ** 3, axis=1))
    f2 = h * np.real(np.mean((2 + LR + np.exp(LR) * (-2 + LR)) / LR ** 3, axis=1))
    f3 = h * np.real(np.mean((-4 - 3 * LR - LR ** 2 + np.exp(LR) * (4 - LR)) / LR ** 3, axis=1))

    # main loop: every step is stored. Rows and times are collected in lists
    # and stacked once, avoiding a quadratic np.append/np.hstack per step.
    rows = [u]
    times = []
    nmax = round(tmax / h)
    g = -0.5j * k
    for n in range(1, nmax + 1):
        t = n * h
        Nv = g * np.fft.fft(np.real(np.fft.ifft(v)) ** 2)
        a = E_2 * v + Q * Nv
        Na = g * np.fft.fft(np.real(np.fft.ifft(a)) ** 2)
        b = E_2 * v + Q * Na
        Nb = g * np.fft.fft(np.real(np.fft.ifft(b)) ** 2)
        c = E_2 * a + Q * (2 * Nb - Nv)
        Nc = g * np.fft.fft(np.real(np.fft.ifft(c)) ** 2)
        v = E * v + Nv * f1 + 2 * (Na + Nb) * f2 + Nc * f3
        u = np.real(np.fft.ifft(v))
        rows.append(u)
        times.append(t)

    uu = np.array(rows)
    if times:  # Leave tt untouched when no step was taken.
        tt = np.hstack((tt, times))
    return [x, tt, uu]

