import torch as t
from warnings import warn

def eye_like(A):
    """
    Identity matrix sized from the last dimension of `A`, sharing its dtype and device.

    The last two dimensions of `A` must be equal (checked with an assert).
    The result is always 2D, whatever leading dimensions `A` has.
    """
    n_rows, n_cols = A.shape[-2], A.shape[-1]
    assert n_cols == n_rows
    return t.eye(n_cols, dtype=A.dtype, device=A.device)

def add_jitter(ii, eps=1e-4):
    """
    Following: 'Exploiting Lower Precision Arithmetic in Solving Symmetric Positive Definite Linear Systems and Least Squares Problems'
    we add jitter equal to a small multiple of the diagonal

    Args:
        ii: a square matrix, or a batch of square matrices stacked along leading dimensions.
        eps: fraction of each diagonal entry that is added back onto it.

    Returns:
        A new tensor equal to `ii` with every diagonal entry scaled by `(1 + eps)`.
    """
    # Take the diagonal over the last two dims so batched input works too;
    # for a plain 2D matrix this is the same diagonal `ii.diag()` would give.
    diagonal = ii.diagonal(dim1=-2, dim2=-1)
    return ii + eps * t.diag_embed(diagonal)

def kernel_to_psd_param(K, jitter=0.):
    """
    Map a kernel matrix `K` to an unconstrained square parameter.

    `jitter` times the identity is added to `K` before the Cholesky factorisation.
    The result holds the strictly-lower part of the factor below the diagonal and
    the log of the factor's diagonal on the diagonal.
    """
    chol = t.linalg.cholesky(K + jitter * eye_like(K))
    below_diag = t.tril(chol, diagonal=-1)
    log_diag = chol.diag().log()
    return below_diag + t.diag(log_diag)
def chol_to_psd_param(L):
    """
    Map a 2D lower-triangular factor `L` to its unconstrained parameter.

    Entries strictly below the diagonal are kept, the diagonal is replaced by its
    log, and everything above the diagonal is zero in the result.
    """
    below_diag = t.tril(L, diagonal=-1)
    log_diag = t.log(t.diag(L))
    return below_diag + t.diag(log_diag)
def psd_param_to_chol(L):
    """
    Turn a 2D unconstrained parameter `L` into a lower-triangular factor.

    The strictly-lower entries are kept and the diagonal is exponentiated,
    so the returned factor has a strictly positive diagonal.
    """
    below_diag = t.tril(L, diagonal=-1)
    positive_diag = t.exp(t.diag(L))
    return below_diag + t.diag(positive_diag)
def psd_param_to_chol_log_diag(L):
    """
    Return the diagonal of the 2D parameter `L`.

    `psd_param_to_chol` exponentiates this diagonal, so these entries are the
    log of the diagonal of the factor it produces.
    """
    return t.diag(L)
def psd_param_to_kernel(L):
    """
    Build the matrix `V @ V.T` from a 2D unconstrained parameter `L`.

    `V` is the lower-triangular factor made of the strictly-lower entries of `L`
    plus the exponentiated diagonal of `L`.
    """
    below_diag = t.tril(L, diagonal=-1)
    positive_diag = t.exp(t.diag(L))
    factor = below_diag + t.diag(positive_diag)
    return factor @ factor.t()

def kernel_to_ldl_param(K, jitter=0.):
    """
    Convert a kernel matrix `K` into a (strictly-lower matrix, log-diagonal) pair.

    `jitter * trace(K)` is added to the diagonal of `K` before the Cholesky
    factorisation. Every row of the factor is then divided by that row's diagonal
    entry. The strictly-lower part of the scaled factor and the log of the
    (unscaled) diagonal are returned.

    NOTE(review): `ldl_param_to_chol` simply adds `exp(logD)` on the diagonal without
    undoing the row scaling, so it does not appear to recover the Cholesky factor
    of `K` from this output -- confirm which convention is intended.
    """
    jittered = K + jitter * K.trace() * eye_like(K)
    chol = t.linalg.cholesky(jittered)
    chol_diag = chol.diag()
    # unsqueeze(-1) makes a column vector, so row i is divided by chol[i, i].
    row_scaled = chol / chol_diag.unsqueeze(-1)
    return row_scaled.tril(diagonal=-1), chol_diag.log()
def ldl_param_to_chol(trilL, logD):
    """
    Assemble a lower-triangular matrix from its two parameter pieces.

    The strictly-lower part of `trilL` fills the entries below the diagonal
    (anything on or above its diagonal is discarded) and `exp(logD)` is placed
    on the diagonal.
    """
    diag_part = t.diag_embed(t.exp(logD))
    below_diag = t.tril(trilL, diagonal=-1)
    return below_diag + diag_part
def init_eye_ldl(Pi, lmbda=1.0):
    """
    Draw initial LDL-style parameters of size `Pi`.

    A `Pi x Pi` factor is formed as `lmbda * I + (1 - lmbda) * noise`, with standard
    normal noise. Its Gram matrix `factor @ factor.T` is then passed through
    `kernel_to_ldl_param`. With the default `lmbda=1.0` the Gram matrix is the identity.

    Returns the `(trilL, logD)` pair produced by `kernel_to_ldl_param`.
    """
    identity_part = t.eye(Pi) * lmbda
    noise_part = (1. - lmbda) * t.randn(Pi, Pi)
    factor = identity_part + noise_part
    gram = factor @ factor.t()
    return kernel_to_ldl_param(gram)

def retrying_cholesky(ii: t.Tensor, init_jitter: float = 1e-6, max_jitter: float = 1e-2):
    """decompose `ii` into `cholesky(ii)` but retry with increasing jitter if it fails.
       gives up if jitter >= `max_jitter`

    Each attempt factorises `add_jitter(ii, eps=jitter)`. An attempt counts as failed
    when `t.linalg.cholesky` raises a `RuntimeError` or returns a factor containing
    NaNs. After a failed attempt a warning is emitted and the jitter is multiplied
    by 10 (a jitter of exactly zero is first bumped to 1e-6).

    Args:
        ii: the matrix to decompose.
        init_jitter: jitter used for the first attempt.
        max_jitter: exclusive upper bound; no attempt is made with jitter >= this.

    Returns:
        The Cholesky factor from the first successful attempt.

    Raises:
        ValueError: if no attempt with jitter < `max_jitter` succeeds.
    """
    jitter = init_jitter
    last_error = None
    while jitter < max_jitter:
        try:
            chol = t.linalg.cholesky(add_jitter(ii, eps=jitter))
        except RuntimeError as err:
            # Only catch RuntimeError (which torch's linear-algebra errors derive from),
            # so KeyboardInterrupt / SystemExit are no longer swallowed by the retry loop.
            last_error = err
            chol = None
        if chol is not None and not chol.isnan().any():
            return chol
        warn(f"retrying_cholesky: cholesky decomposition failed with jitter={jitter}")
        # Multiplying a zero jitter by 10 would never make progress, so both failure
        # modes (exception and NaN output) share this escalation rule.
        jitter = 1e-6 if jitter == 0. else jitter * 10
    raise ValueError(f"retrying_cholesky: could not cholesky decompose ii, even with jitter={jitter}") from last_error

def sparse_eye_like(x):
    """
    Sparse COO identity matrix of shape `(n, n)` with `n = x.size(0)`.

    The stored values take the dtype and device of `x`.

    Args:
        x: tensor whose first dimension gives the size of the identity.

    Returns:
        A sparse COO tensor with ones at positions `(i, i)` for `i` in `range(n)`.
    """
    n = x.size(0)
    # Indices are positions, not data: build them as int64 instead of x.dtype, so that
    # low-precision dtypes (e.g. float16, which cannot represent 2049) do not round
    # them and corrupt the diagonal.
    diag_idx = t.arange(n, device=x.device, dtype=t.long)
    indices = t.stack((diag_idx, diag_idx))
    values = t.ones(n, device=x.device, dtype=x.dtype)
    return t.sparse_coo_tensor(indices, values, (n, n))

def cka(K1, K2):
    """
    Centered kernel alignment between two 2D kernel matrices.

    Both matrices are double-centered (row means and column means subtracted,
    overall mean added back). The result is
        sum(K1_c * K2_c) / sqrt(sum(K1_c * K1_c) * sum(K2_c * K2_c)),
    using elementwise products. For symmetric matrices these sums equal the traces
    trace(K1_c @ K2_c), trace(K1_c @ K1_c) and trace(K2_c @ K2_c).
    """
    def _double_center(K):
        # Subtract row means, then column means, then add the overall mean back.
        row_mean = K.mean(1, keepdim=True)
        col_mean = K.mean(0, keepdim=True)
        return K - row_mean - col_mean + K.mean()

    A = _double_center(K1)
    B = _double_center(K2)
    cross = t.sum(A * B)
    self_A = t.sum(A * A)
    self_B = t.sum(B * B)
    return cross * t.rsqrt(self_A * self_B)


if __name__ == '__main__':
    # Smoke test: the elementwise `cka` must agree with a reference implementation
    # that centers with an explicit centering matrix and uses matrix traces.
    def cka_ineff(K1, K2):
        """Reference CKA: center as H @ K @ H with H = I - 1/n, then use traces."""
        n = K1.size(0)
        H = eye_like(K1) - 1./n
        K1_c = H @ K1 @ H
        K2_c = H @ K2 @ H
        numerator = t.trace(K1_c @ K2_c)
        denominator = t.sqrt(t.trace(K1_c @ K1_c) * t.trace(K2_c @ K2_c))
        return numerator / denominator

    # Two random 3x3 Gram matrices.
    V = t.randn(3, 3)
    U = t.randn(3, 3)
    G1 = V @ V.T / 3
    G2 = U @ U.T / 3

    fast = cka(G1, G2)
    reference = cka_ineff(G1, G2)
    print(fast, reference)
    assert t.allclose(fast, reference)
    assert 0 <= fast <= 1
