import numpy as np
import numpy.typing as npt
from scipy import linalg


def time_based_sliding_window_agg(arr, timestamps, N, func, ind):
    """Fold ``func`` over the sums of all time-bounded sliding windows.

    For every right end ``r`` (visited in order, including the last element),
    the window starts at the smallest reachable ``left`` such that
    ``timestamps[r] - timestamps[left] <= N``; the window's sum is folded into
    the accumulator with ``ret = func(ret, window_sum)``.

    Args:
        arr: Sequence of values to sum (numbers or equally shaped numpy arrays).
        timestamps: Timestamps aligned with ``arr``; assumed non-decreasing.
        N: Maximum time span (last minus first timestamp) of a window;
            assumed non-negative.
        func: Binary aggregation ``func(acc, window_sum) -> acc``.
        ind: Initial accumulator value (e.g. the identity element of ``func``).

    Returns:
        The final accumulator; ``ind`` when ``arr`` is empty.
    """
    ret = ind
    left = 0
    window_sum = None
    for right, value in enumerate(arr):
        # Non in-place arithmetic: when elements are numpy arrays, ``arr[0]``
        # is a view, and ``+=`` would silently modify the caller's data as well
        # as sums already handed to ``func``.
        window_sum = value if right == 0 else window_sum + value
        # Evict elements from the left until the window fits the time span.
        # ``left < right`` keeps at least the current element in the window.
        while left < right and timestamps[right] - timestamps[left] > N:
            window_sum = window_sum - arr[left]
            left += 1
        ret = func(ret, window_sum)
    return ret


def sliding_window_agg(arr, N, func, ind):
    """Fold ``func`` over the sums of all windows of ``N + 1`` consecutive elements.

    ``N`` is the index span of a window (last index minus first index), so each
    window ``arr[k..k+N]`` holds ``N + 1`` elements. Windows are visited in order
    of their starting index, up to and including the one ending at the last
    element, and each sum is folded in with ``ret = func(ret, window_sum)``.

    Args:
        arr: Indexable sequence of values to sum (numbers or equally shaped
            numpy arrays).
        N: Non-negative integer index span of a window.
        func: Binary aggregation ``func(acc, window_sum) -> acc``.
        ind: Initial accumulator value (e.g. the identity element of ``func``).

    Returns:
        The final accumulator; ``ind`` when ``arr`` has fewer than ``N + 1``
        elements (including when it is empty).

    Raises:
        ValueError: If ``N`` is negative.
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")

    size = N + 1
    ret = ind
    window_sum = None
    for right, value in enumerate(arr):
        # Non in-place arithmetic: when elements are numpy arrays, ``arr[0]``
        # is a view, and ``+=`` would silently modify the caller's data.
        window_sum = value if right == 0 else window_sum + value
        if right >= size:
            # Element ``right - size`` just fell out of the window.
            window_sum = window_sum - arr[right - size]
        if right >= N:
            # The window ``arr[right - N .. right]`` is complete.
            ret = func(ret, window_sum)
    return ret


def generate_random_subspace_matrix(m, d):
    """Return a matrix whose rows form an orthonormal basis of a random subspace of R^d.

    A ``d x m`` standard-normal matrix is orthonormalized with a reduced QR
    decomposition, and the transposed Q factor is returned. For ``m <= d`` the
    result is ``m x d`` with orthonormal rows spanning the Gaussian columns.
    For ``m > d`` the reduced Q factor has only ``d`` columns, so the result is
    ``d x d``.

    Args:
        m: Target subspace dimension.
        d: Ambient dimension.

    Returns:
        ``min(m, d) x d`` array with orthonormal rows.
    """
    gaussian = np.random.randn(d, m)
    orthonormal_columns, _ = np.linalg.qr(gaussian, mode="reduced")
    return orthonormal_columns.T


def generate_rankm_matrix(n, m, d):
    """Generate a random ``n x d`` matrix ``A = S D U`` of rank at most ``m``.

    ``S`` is an ``n x m`` standard-normal matrix. ``D`` is diagonal with
    linearly decaying entries ``1, 1 - 1/m, ..., 1/m``. ``U`` is a random
    matrix with orthonormal rows from ``generate_random_subspace_matrix``.

    Args:
        n: Number of rows of the result.
        m: Target rank; must not exceed ``d``, otherwise ``U`` has fewer than
            ``m`` rows and the product fails.
        d: Number of columns of the result.

    Returns:
        ``n x d`` array.
    """
    S = np.random.randn(n, m)
    # D_ii = 1 - (i - 1) / m for 1-based i, i.e. 1 - i / m for 0-based i,
    # giving diagonal 1, 1 - 1/m, ..., 1/m.
    D = np.diag(1 - np.arange(m) / m)
    U = generate_random_subspace_matrix(m, d)
    return S @ D @ U


def generate_synthetic_matrix(n, m, d, U, zeta=10):
    """Generate a noisy low-rank matrix ``A = S D U + G / zeta``.

    ``S`` is an ``n x m`` standard-normal "signal" matrix. ``D`` is diagonal
    with linearly decaying entries ``1, 1 - 1/m, ..., 1/m``. ``G`` is ``n x d``
    standard-normal noise scaled down by ``zeta``. ``S`` is drawn before ``G``.

    Args:
        n: Number of rows of the result.
        m: Signal dimension (number of rows of ``U``).
        d: Number of columns of the result.
        U: ``m x d`` matrix; presumably the output of
            ``generate_random_subspace_matrix`` (orthonormal rows) — TODO confirm.
        zeta: Noise attenuation; larger values mean less noise (default 10).

    Returns:
        ``n x d`` array.
    """
    S = np.random.randn(n, m)
    # D_ii = 1 - (i - 1) / m for 1-based i, i.e. 1 - i / m for 0-based i.
    D = np.diag(1 - np.arange(m) / m)
    noise = np.random.randn(n, d)
    return S @ D @ U + noise / zeta


def matrix_induced_norm(x: npt.NDArray, A: npt.NDArray) -> float:
    """Return the A-weighted norm ||x||_A = sqrt(x^T A x).

    Args:
        x (npt.NDArray): Vector of length ``d``. If it is passed as a 2-D
            ``d x 1`` column, the result is a ``1 x 1`` array, not a scalar.
        A (npt.NDArray): ``d x d`` matrix; expected to be positive
            semi-definite. Otherwise the quadratic form can be negative and
            ``np.sqrt`` returns NaN.

    Returns:
        float: sqrt(x^T A x).
    """
    quadratic_form = x.T @ A @ x
    return np.sqrt(quadratic_form)


def woodbury(A_inv: npt.NDArray, U: npt.NDArray, V: npt.NDArray) -> npt.NDArray:
    """Computes (A + U V^T)^{-1} from A^{-1} by the Woodbury formula.

        (A + U V^T)^{-1} = A^{-1} - A^{-1} U (I + V^T A^{-1} U)^{-1} V^T A^{-1}

    Only a small ``k x k`` linear system is solved, where ``k`` is the number
    of columns of ``U``/``V``. This makes a low-rank update much cheaper than
    re-inverting the ``d x d`` matrix.

    Args:
        A_inv (npt.NDArray): ``d x d`` inverse of A.
        U (npt.NDArray): ``d x k`` left factor; a 1-D vector is treated as a
            single column (rank-one update).
        V (npt.NDArray): ``d x k`` right factor; a 1-D vector is treated as a
            single column.

    Returns:
        npt.NDArray: ``d x d`` array (A + U V^T)^{-1}.

    Raises:
        scipy.linalg.LinAlgError: If the capacitance matrix
            ``I + V^T A^{-1} U`` is singular, i.e. A + U V^T is not invertible.
    """
    if U.ndim == 1:
        U = U[:, np.newaxis]
    if V.ndim == 1:
        V = V[:, np.newaxis]

    k = U.shape[1]
    A_inv_U = A_inv @ U  # d x k, computed once
    Vt_A_inv = V.T @ A_inv  # k x d, computed once
    capacitance = np.eye(k) + Vt_A_inv @ U  # k x k
    # Solve the k x k system rather than forming its explicit inverse. The
    # product order (d x k) @ (k x d) avoids any d x d @ d x d multiplication.
    return A_inv - A_inv_U @ linalg.solve(capacitance, Vt_A_inv)


# def generate_normalized_synthetic_matrix(n, m, d, U):
#     A = generate_synthetic_matrix(n, m, d, U)
#     epochs, d = A.shape

#     Rs = np.linalg.norm(A, axis=1)**2
#     r = np.min(Rs)
#     R = np.max(Rs)

#     A = A / np.sqrt(r)
