# General notes:
#   - Matrix (pseudo-)inverses are best computed in double precision.
#   - Matrix inverses are best computed on the CPU, for both memory and precision reasons.

import scipy
from scipy.linalg import eigh, pinvh, lstsq, solve
import numpy as np


def mean_projector_matrix(n):
    """Return the ``n x n`` centering matrix ``I - (1/n) * 1 1^T``.

    Left-multiplying a vector by this matrix subtracts the vector's mean
    from each of its entries.
    """
    identity = np.eye(n)
    # Every entry equals 1/n; this is the "take the mean" operator.
    averaging = np.ones((n, n)) / n
    return identity - averaging


def kernel_matrix_VICREG(K, A, beta=1.0, dimension=None):
    """Build the VICReg-style kernel matrix ``pinvh(K) @ M_psd @ pinvh(K)``.

    ``M`` is the centering matrix minus ``beta / 2`` times the graph
    Laplacian ``L = diag(A.sum(axis=1)) - A``. ``M`` is then replaced by its
    positive-semidefinite part (negative eigenvalues clipped to zero),
    optionally restricted to its ``dimension`` largest eigenpairs.

    Args:
        K: Square ``(n, n)`` kernel matrix. It is pseudo-inverted with
            ``pinvh``, which treats it as symmetric/Hermitian.
        A: ``(n, n)`` adjacency/weight matrix. ``eigh`` reads only one
            triangle of ``M``, so ``A`` is presumably expected to be
            symmetric (NOTE(review): confirm with callers).
        beta: Weight of the Laplacian term.
        dimension: If given, keep only the ``dimension`` largest eigenpairs
            of ``M``. If ``None``, use the full spectrum.

    Returns:
        ``(n, n)`` array ``pinvh(K) @ M_psd @ pinvh(K)``.

    Raises:
        ValueError: If ``dimension`` is given but not in ``[1, n]``.
    """
    # It might be more stable to use matrix-vector products throughout
    # rather than materializing the dense matrices below.
    n = K.shape[0]
    if dimension is not None and not 1 <= dimension <= n:
        raise ValueError(f"dimension must be in [1, {n}], got {dimension}")

    Kinv = pinvh(K)

    # Graph Laplacian: degree matrix minus adjacency.
    L = np.diag(np.sum(A, axis=1)) - A
    M = mean_projector_matrix(n) - beta / 2 * L

    # subset_by_index=None makes eigh return the full spectrum, so a single
    # call covers both the truncated and the full case.
    subset = None if dimension is None else [n - dimension, n - 1]
    e, C = eigh(M, subset_by_index=subset)
    # Keep only the positive-semidefinite part of M.
    e = np.maximum(e, 0)
    # C @ diag(e) @ C.T without forming diag(e). It would be faster still to
    # apply this without materializing the dense product.
    M = (e * C) @ C.T
    return Kinv @ M @ Kinv


def kernel_matrix_contrastive(K, A, dimension=None, is_SPD=False):
    """Build the contrastive kernel matrix ``pinvh(K) @ M' @ pinvh(K)``.

    ``M = I + A``. When an eigendecomposition is performed, ``M'`` is the
    positive-semidefinite part of ``M`` (negative eigenvalues clipped to
    zero), optionally restricted to its ``dimension`` largest eigenpairs.
    Otherwise ``M' = M``.

    Args:
        K: Square ``(n, n)`` kernel matrix. It is pseudo-inverted with
            ``pinvh``, which treats it as symmetric/Hermitian.
        A: ``(n, n)`` adjacency/weight matrix. ``eigh`` reads only one
            triangle of ``M``, so ``A`` is presumably expected to be
            symmetric (NOTE(review): confirm with callers).
        dimension: If given, keep only the ``dimension`` largest eigenpairs
            of ``M``.
        is_SPD: If True and ``dimension`` is None, the caller asserts that
            ``I + A`` is already positive (semi-)definite. It is then used
            as-is and the eigendecomposition is skipped.

    Returns:
        ``(n, n)`` array ``pinvh(K) @ M' @ pinvh(K)``.

    Raises:
        ValueError: If ``dimension`` is given but not in ``[1, n]``.
    """
    # It might be more stable to use matrix-vector products throughout
    # rather than materializing the dense matrices below.
    n = K.shape[0]
    if dimension is not None and not 1 <= dimension <= n:
        raise ValueError(f"dimension must be in [1, {n}], got {dimension}")

    Kinv = pinvh(K)

    M = np.eye(n) + A
    # An eigendecomposition is needed to truncate to the top eigenpairs, or
    # to clip negative eigenvalues when M is not known to be PSD.
    if dimension is not None or not is_SPD:
        # subset_by_index=None makes eigh return the full spectrum.
        subset = None if dimension is None else [n - dimension, n - 1]
        e, C = eigh(M, subset_by_index=subset)
        e = np.maximum(e, 0)
        # C @ diag(e) @ C.T without forming diag(e).
        M = (e * C) @ C.T

    return Kinv @ M @ Kinv


def calc_supervised_kernel_matrix(K_xs, K_sx, K_inv):
    """Return the sandwich product ``K_xs @ K_inv @ K_sx``.

    The product is evaluated left to right, i.e. ``(K_xs @ K_inv) @ K_sx``.
    """
    left_factor = K_xs @ K_inv
    return left_factor @ K_sx
