import torch

# Margin by which cosine values are pulled inside [-1, 1] before torch.acos
# is applied (see the clamp calls in thetas and f_single).
epsilon = 1e-6


def thetas(a, b):
    """Return the pairwise angles between the rows of ``a`` and ``b``.

    Entry ``[i, j]`` of the result is the arccosine of the cosine similarity
    between ``a[i]`` and ``b[j]``, so the result has one row per row of ``a``
    and one column per row of ``b``.
    """
    cosines = (a @ b.T) / norm_products(a, b)
    # The cosines are clamped to [-1 + epsilon, 1 - epsilon] rather than to
    # [-1, 1]; presumably this keeps acos (and code that divides by
    # quantities vanishing at angle 0 or pi) away from the singular endpoints.
    cosines = cosines.clamp(min=-1 + epsilon, max=1 - epsilon)
    return torch.acos(cosines)

def norm_products(a, b):
    """Return the matrix of products of row norms of ``a`` and ``b``.

    Entry ``[i, j]`` of the result is ``||a[i]|| * ||b[j]||``, where the
    norms are taken along dimension 1 of each input.
    """
    return torch.outer(
        torch.linalg.norm(a, dim=1),
        torch.linalg.norm(b, dim=1),
    )

def f(a, b):
    """Sum a closed-form pairwise kernel over all row pairs of ``a`` and ``b``.

    For the pair ``(a[i], b[j])`` with angle ``t`` (from ``thetas``) and norm
    product ``r`` (from ``norm_products``) the summand is
    ``r * (sin(t) + (pi - t) * cos(t)) / (2 * pi)``.

    Returns a scalar tensor: the sum of that expression over all ``i, j``.
    """
    angles = thetas(a, b)
    scale = norm_products(a, b)
    angular_part = torch.sin(angles) + (torch.pi - angles) * torch.cos(angles)
    return (angular_part * scale / (2 * torch.pi)).sum()

def f_single(a, b):
    """Evaluate the kernel used by ``f`` for a single pair ``(a, b)``.

    Computes ``r * (sin(t) + (pi - t) * cos(t)) / (2 * pi)`` where ``r`` is
    the product of the norms of ``a`` and ``b`` and ``t`` is the arccosine of
    ``a @ b / r`` clamped to ``[-1 + epsilon, 1 - epsilon]``.

    NOTE(review): assumes ``a`` and ``b`` are 1-D vectors, so that ``a @ b``
    is their dot product and the norms are Euclidean -- confirm with callers.
    """
    scale = torch.linalg.norm(a) * torch.linalg.norm(b)
    cosine = (a @ b / scale).clamp(min=-1 + epsilon, max=1 - epsilon)
    angle = torch.acos(cosine)
    angular_part = torch.sin(angle) + (torch.pi - angle) * torch.cos(angle)
    return torch.sum(angular_part * scale / (2 * torch.pi))

def F(Ws, Vs):
    """Return ``f(Ws, Ws) / 2 - f(Ws, Vs) + f(Vs, Vs) / 2``.

    The two halved terms pair each argument with itself; the middle term
    pairs the rows of ``Ws`` with the rows of ``Vs``.
    """
    ws_self_term = f(Ws, Ws) / 2
    cross_term = f(Ws, Vs)
    vs_self_term = f(Vs, Vs) / 2
    return ws_self_term - cross_term + vs_self_term

def g(a, b):
    """Return, for each row of ``a``, a weighted vector sum over the rows of ``b``.

    With ``t[i, j]`` the angle between ``a[i]`` and ``b[j]``, row ``i`` of the
    result is::

        (sum_j sin(t[i, j]) * ||b[j]||) * a[i] / ||a[i]||
            + sum_j (pi - t[i, j]) * b[j]

    divided by ``2 * pi``.  The result has the same shape as ``a``.

    NOTE(review): this appears to be the gradient of ``f(a, b)`` with respect
    to ``a`` (``b`` held fixed), which is how ``grad_F`` uses it -- confirm.
    """
    angles = thetas(a, b)
    a_unit = a / torch.linalg.norm(a, dim=1, keepdim=True)
    b_lengths = torch.linalg.norm(b, dim=1, keepdim=True)
    # Component along each row's own direction.
    along_a = (torch.sin(angles) @ b_lengths) * a_unit
    # Angle-weighted combination of the rows of b.
    along_b = (torch.pi - angles) @ b
    return (along_a + along_b) / (2 * torch.pi)

def grad_F(a, b):
    """Return ``g(a, a) - g(a, b)``, a tensor with the same shape as ``a``."""
    self_part = g(a, a)
    cross_part = g(a, b)
    return self_part - cross_part

def n(a, b):
    """Return ``a_hat[i] - cos(t[i, j]) * b_hat[j]`` for every row pair.

    ``a_hat`` and ``b_hat`` are the rows of ``a`` and ``b`` scaled to unit
    norm and ``t`` is the angle matrix from ``thetas(a, b)``.  The result is
    indexed ``[i, j, :]``: first by row of ``a``, then by row of ``b``, then
    by vector component.
    """
    a_unit = a / torch.linalg.norm(a, dim=1, keepdim=True)
    b_unit = b / torch.linalg.norm(b, dim=1, keepdim=True)
    cosines = torch.cos(thetas(a, b)).unsqueeze(-1)
    # Broadcast a over the second axis and b over the first.
    return a_unit.unsqueeze(1) - cosines * b_unit.unsqueeze(0)

def n_bar(a, b):
    """Return the vectors of ``n(a, b)`` scaled to unit norm.

    The result is indexed ``[i, j, :]`` like ``n(a, b)``; each vector along
    the last axis is divided by its own Euclidean norm.
    """
    ns = n(a, b)
    # Divide out of place: an in-place ``ns /= ...`` would overwrite the
    # tensor that ``linalg.norm`` saved for its backward pass, which makes
    # autograd raise when differentiating through this function.
    return ns / torch.linalg.norm(ns, dim=2, keepdim=True)

def n_bar_same(a):
    """Return ``n(a, a)`` scaled to unit norm, with zeroed diagonal entries.

    The result is indexed ``[i, j, :]``.  For ``i != j`` the vector is
    ``n(a, a)[i, j]`` divided by its Euclidean norm.  For ``i == j`` the
    divisor is replaced by infinity, so those vectors come out as zeros.
    """
    ns = n(a, a)
    norms = torch.linalg.norm(ns, dim=2)
    on_diagonal = torch.eye(
        norms.shape[0], dtype=torch.bool, device=norms.device
    )
    # Out-of-place masking instead of ``fill_diagonal_``: modifying the norm
    # result in place would invalidate the value autograd saved for backward.
    norms = norms.masked_fill(on_diagonal, torch.inf)
    return ns / norms[:, :, None]

def h1(a, b):
    """Return the block-diagonal term used by ``hessian_F``.

    Let ``a`` have ``k`` rows of length ``d``.  For each row ``a[i]`` a
    ``d x d`` matrix is formed as the sum over rows ``b[j]`` of::

        sin(t[i, j]) * ||b[j]|| / (2 * pi * ||a[i]||)
            * (I - a_hat[i] a_hat[i]^T + m[i, j] m[i, j]^T)

    where ``t = thetas(a, b)``, ``a_hat[i]`` is ``a[i]`` scaled to unit norm
    and ``m[i, j] = n_bar(b, a)[j, i]``.

    Returns a tensor of shape ``(k, d, k, d)`` whose ``[i, :, l, :]`` slice is
    that matrix when ``i == l`` and zero otherwise.
    """
    num_rows, dim = a.shape
    angles = thetas(a, b)[:, :, None, None]
    # n_bar(b, a)[j, i] is built from b[j] relative to a[i]; swapping the
    # first two axes indexes it as [i, j] like every other factor here.
    n_bars = n_bar(b, a).transpose(0, 1)
    a_norms = torch.linalg.norm(a, dim=1)
    b_norms = torch.linalg.norm(b, dim=1)
    a_normed = a / a_norms[:, None]

    a_normed_outer = a_normed[:, None, :, None] * a_normed[:, None, None, :]
    n_bar_outer = n_bars[:, :, :, None] * n_bars[:, :, None, :]
    # Create the identity on the inputs' device/dtype; a default-constructed
    # torch.eye lives on the CPU and fails against non-CPU inputs.
    identities = torch.eye(dim, dtype=a.dtype, device=a.device)[None, None, :, :]
    paren_expr = identities - a_normed_outer + n_bar_outer

    prefactor = (
        torch.sin(angles)
        * b_norms[None, :, None, None]
        / (a_norms[:, None, None, None] * 2 * torch.pi)
    )

    # Sum over the rows of b, leaving one d x d matrix per row of a.
    h1_values = torch.sum(prefactor * paren_expr, dim=1)[:, None, :, :]

    # Place each matrix on the block diagonal: (k, k, d, d) -> (k, d, k, d).
    block_selector = torch.eye(num_rows, dtype=a.dtype, device=a.device)
    return (h1_values * block_selector[:, :, None, None]).transpose(1, 2)

def h2(a, b):
    """Return the pairwise-block term used by ``hessian_F``.

    With ``t = thetas(a, b)``, ``m = n_bar_same(a)`` and ``a_hat`` / ``b_hat``
    the unit-norm rows of ``a`` / ``b``, the ``d x d`` block for the row pair
    ``(i, j)`` is::

        ((pi - t[i, j]) * I + m[i, j] b_hat[j]^T + m[j, i] a_hat[i]^T)
            / (2 * pi)

    The blocks are returned with axes ordered ``[i, :, j, :]``.

    NOTE(review): the ``m`` factors are built from ``a`` alone, so the
    expression presumably only makes sense for ``b`` equal to ``a``, which is
    how ``hessian_F`` calls it -- confirm before using other arguments.
    """
    angles = thetas(a, b)[:, :, None, None]
    # n_bar_same(a) is needed in both index orders; compute it once.
    n_bars = n_bar_same(a)[:, :, :, None]
    n_bars_swapped = n_bars.transpose(0, 1)
    a_normed = (a / torch.linalg.norm(a, dim=1, keepdim=True))[:, None, None, :]
    b_normed = (b / torch.linalg.norm(b, dim=1, keepdim=True))[None, :, None, :]
    # Create the identity on the inputs' device/dtype; a default-constructed
    # torch.eye lives on the CPU and fails against non-CPU inputs.
    identities = torch.eye(
        a.shape[1], dtype=a.dtype, device=a.device
    )[None, None, :, :]
    blocks = (
        (torch.pi - angles) * identities
        + n_bars * b_normed
        + n_bars_swapped * a_normed
    ) / (2 * torch.pi)
    return blocks.transpose(1, 2)

def hessian_F(Ws, Vs):
    """Return ``h1(Ws, Ws) - h1(Ws, Vs) + h2(Ws, Ws)``.

    The two ``h1`` terms are non-zero only on the block diagonal (see ``h1``);
    the ``h2`` term contributes a block for every pair of rows of ``Ws``.
    """
    block_diagonal_part = h1(Ws, Ws) - h1(Ws, Vs)
    return block_diagonal_part + h2(Ws, Ws)