import numpy as np
import torch
from torch.nn.functional import binary_cross_entropy, cross_entropy


def cosine_sim(u, v) -> torch.Tensor:
    """Compute pairwise cosine similarities between the columns of u and v.

    1-D inputs are treated as a single column vector. Each column is scaled
    to unit L2 norm and the matrix of all column-by-column dot products is
    returned.

    Args:
        u: 1-D tensor, or 2-D tensor whose columns are the vectors.
        v: 1-D tensor, or 2-D tensor whose columns are the vectors; its
            number of rows must match that of ``u``.

    Returns:
        2-D tensor whose entry ``[i, k]`` is the cosine similarity between
        column ``i`` of ``u`` and column ``k`` of ``v``.

    Note:
        There is no guard against zero-norm columns; the division is
        performed as-is for them.
    """
    if len(u.shape) == 1:
        u = u.reshape(*u.shape, 1)
    if len(v.shape) == 1:
        v = v.reshape(*v.shape, 1)

    # Normalise out-of-place: an in-place `/=` would overwrite the caller's
    # tensors (reshape may return a view of the original storage) and is not
    # allowed on leaf tensors that require grad.
    u = u / torch.norm(u, dim=0)
    v = v / torch.norm(v, dim=0)

    cosine_similarities = torch.einsum("ij,jk->ik", u.T, v)
    return cosine_similarities


def eigen_dissim(u, v, su=None, sv=None, metric="l2", reduce=True) -> torch.Tensor:
    """Give the dissimilarity of eigenvectors u, v, optionally weighted by
    their eigenvalue magnitudes su, sv.

    The matrix of absolute dot products between the columns of ``u`` and the
    columns of ``v`` is compared against the identity matrix. The inputs are
    not normalised here.

    Args:
        u: 1-D tensor, or 2-D tensor whose columns are the vectors.
        v: 1-D tensor, or 2-D tensor whose columns are the vectors.
        su: optional 1-D tensor of weights for the columns of ``u``.
        sv: optional 1-D tensor of weights for the columns of ``v``.
            The outer product of ``su`` and ``sv`` weights the deviations,
            and only when both are given. Weights are applied for the
            "l1" and "l2" metrics only.
        metric: one of "l1", "l2", "bce" or "ce".
        reduce: if True, sum the dissimilarity and divide it by ``N ** 2``
            where ``N`` is the number of columns of ``u``. If False, return
            the unreduced values (element-wise for "l1", "l2" and "bce";
            one loss per row for "ce").

    Returns:
        The dissimilarity. With ``reduce=True`` the "bce" and "ce" metrics
        return a Python float (via ``.item()``); all other cases return a
        tensor.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    if len(u.shape) == 1:
        u = u.reshape(*u.shape, 1)
    if len(v.shape) == 1:
        v = v.reshape(*v.shape, 1)
    N = u.shape[-1]
    if su is not None and sv is not None:
        assert len(su.shape) <= 1 and len(sv.shape) <= 1
        weights = torch.einsum("i,j->ij", su, sv)
    else:
        weights = 1

    cosine_similarities = torch.abs(torch.einsum("ij,jk->ik", u.T, v))
    kronecker_delta = torch.eye(N, device=u.device)
    if metric == "bce":
        y, t = cosine_similarities, kronecker_delta
        if reduce:
            dissim = binary_cross_entropy(y, t, reduction="sum").item()
        else:
            dissim = binary_cross_entropy(y, t, reduction="none")
    elif metric == "ce":
        y = cosine_similarities
        # Row i of the identity has its maximum at index i, so the target
        # class of row i is i.
        t = torch.argmax(kronecker_delta, dim=1)
        if reduce:
            dissim = cross_entropy(y, t, reduction="sum").item()
        else:
            # Unreduced means one loss per row, not the summed scalar.
            dissim = cross_entropy(y, t, reduction="none")
    elif metric == "l1":
        deviations = (kronecker_delta - cosine_similarities).abs()
        dissim = weights * deviations
        if reduce:
            dissim = torch.sum(dissim)
    elif metric == "l2":
        deviations = kronecker_delta - cosine_similarities
        dissim = weights * deviations ** 2
        if reduce:
            dissim = torch.sum(dissim)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of 'l1', 'l2', 'bce', 'ce'"
        )
    if reduce:
        dissim = dissim / N ** 2
    return dissim


def subspace_sim(V, W):
    """Compute the subspace similarity of V and W.

    The columns of V and W are assumed to form orthogonal bases of two
    subspaces of the same vector space. With P_V, P_W the orthogonal
    projectors onto those subspaces:

        Tr(P_V P_W) = Tr((V V^T)(W W^T)) = Tr((W^T V)(V^T W))
                    = ||flatten(V^T W)||_2^2

    The columns are scaled to unit norm first, and the trace is divided by
    d = min(dim V, dim W). When both subspaces have the same dimension d,
    this equals Tr(P_V P_W) / sqrt(Tr(P_V) Tr(P_W)).

    Args:
        V: 2-D tensor whose columns span the first subspace.
        W: 2-D tensor whose columns span the second subspace; must have the
            same number of rows as ``V``.

    Returns:
        0-dim tensor holding the similarity.
    """
    assert V.shape[0] == W.shape[0], "V, W must be subspaces of the same vector space"
    # Normalise out-of-place so the caller's matrices are not rescaled as a
    # side effect of computing the similarity.
    V = V / torch.linalg.norm(V, dim=0)
    W = W / torch.linalg.norm(W, dim=0)
    d = min(V.shape[1], W.shape[1])
    return ((V.T @ W).reshape(-1) ** 2).sum() / d

def random_baseline_subsim(V, W):
    """Return the expected overlap of two random subspaces V', W' with
    dim(V') = dim(V) and dim(W') = dim(W).

    Only the shapes of V and W are used: the result is the larger of their
    two column counts divided by their shared row count.
    """
    ambient_dim = V.shape[0]
    assert ambient_dim == W.shape[0], "V, W must be subspaces of the same vector space"

    larger_dim = max(V.shape[1], W.shape[1])
    return larger_dim / ambient_dim