import torch

# Default offset that the functions below add to dot products before dividing.
epsilon = 1e-6

# Pi as a Python float; this is the nearest double to the true value.
pi = 3.141592653589793

def better_params(W,V,epsilon=epsilon):
    """Describe every (row of W, row of V) pair by a norm and projections.

    Assumes W and V are 2-D with one vector per row and the same number of
    columns (norms are taken over dim=1 and ``W @ V.T`` is formed) --
    TODO confirm against callers.

    Args:
        W: tensor whose rows serve as reference directions.
        V: tensor whose rows are measured against each row of W.
        epsilon: constant added to every dot product; presumably there to
            keep the projections away from exactly zero -- verify.

    Returns:
        A tuple ``(W_mag, V1s, V2s_squared)`` where
        ``W_mag`` holds the row norms of W as a column,
        ``V1s`` is ``(W @ V.T + epsilon) / W_mag`` (the projection of each
        V row onto each W row, shifted by ``epsilon / W_mag``), and
        ``V2s_squared`` is the squared row norm of V minus ``V1s**2``.
    """
    row_norms_W = torch.linalg.norm(W, dim=1).unsqueeze(1)
    row_norms_V = torch.linalg.norm(V, dim=1).unsqueeze(0)

    shifted_dots = torch.matmul(W, V.T) + epsilon
    projections = shifted_dots / row_norms_W
    remainder_squared = row_norms_V**2 - projections**2

    return row_norms_W, projections, remainder_squared

def f(W,V):
    """Sum a closed-form arctan expression over all (row of W, row of V) pairs.

    The pairwise quantities come from ``better_params(W, V)``; this is the
    batched counterpart of accumulating ``f_single`` over every pair.

    Args:
        W: tensor of row vectors (assumed 2-D -- TODO confirm).
        V: tensor of row vectors with the same number of columns as W.

    Returns:
        A 0-dim tensor: the sum over all pairs of
        ``(first_term - second_term) * 2 / pi``.
    """
    W_mag, V1s, V2s_squared = better_params(W,V)

    # Use a Python float rather than torch.sqrt(torch.tensor(2)): the latter
    # is a 0-dim tensor in the default float dtype (float32 unless changed),
    # so float64 inputs were combined with a sqrt(2) that was only accurate
    # to single precision. A Python scalar is cast to the operand's dtype.
    root2 = 2.0 ** 0.5

    # Sub-expressions shared by both arctan arguments.
    one_plus_2v2 = 1 + 2*V2s_squared
    sqrt_one_plus_2v2 = torch.sqrt(one_plus_2v2)
    one_plus_2w2 = 1 + 2*W_mag**2

    first_term = torch.arctan(sqrt_one_plus_2v2/(root2*V1s))

    second_term_numerator_first_part = root2+2*root2*(W_mag**2)+2*root2*V2s_squared*one_plus_2w2
    second_term_numerator_second_part = 2*W_mag*torch.sqrt(one_plus_2v2*(2*V1s**2+one_plus_2v2*one_plus_2w2))
    second_term_denominator = 2*V1s*sqrt_one_plus_2v2
    second_term = torch.arctan((second_term_numerator_first_part-second_term_numerator_second_part)/second_term_denominator)

    return torch.sum(first_term-second_term)*2/pi

def f_single(w,v,epsilon=epsilon):
    """Evaluate the pairwise arctan expression for one vector pair.

    Args:
        w: 1-D tensor, the reference vector.
        v: 1-D tensor of the same length as w.
        epsilon: constant added to the dot product ``w @ v`` before it is
            divided by the norm of w.

    Returns:
        A 0-dim tensor equal to ``(first_term - second_term) * 2 / pi``.
    """
    dot = w@v+epsilon
    # Keep the norm in its own name instead of rebinding ``w``.
    w_mag = torch.linalg.norm(w)
    v1 = dot/w_mag
    v2_sqrd = v@v-v1**2

    # Use a Python float rather than torch.sqrt(torch.tensor(2)): the latter
    # is a 0-dim tensor in the default float dtype (float32 unless changed),
    # so float64 inputs were combined with a sqrt(2) that was only accurate
    # to single precision. A Python scalar is cast to the operand's dtype.
    root2 = 2.0 ** 0.5

    # Sub-expressions shared by both arctan arguments.
    one_plus_2v2 = 1 + 2*v2_sqrd
    sqrt_one_plus_2v2 = torch.sqrt(one_plus_2v2)
    one_plus_2w2 = 1 + 2*w_mag**2

    first_term = torch.arctan(sqrt_one_plus_2v2/(root2*v1))

    second_num_first = root2+2*root2*w_mag**2+2*root2*v2_sqrd*one_plus_2w2
    second_num_second = 2*w_mag*torch.sqrt(one_plus_2v2*(2*v1**2+one_plus_2v2*one_plus_2w2))
    second_denom = 2*v1*sqrt_one_plus_2v2
    second_term = torch.arctan((second_num_first-second_num_second)/second_denom)

    return (first_term-second_term)*2/pi

def f_naive(W,V):
    """Accumulate ``f_single`` over every (row of W, row of V) pair.

    Pairs are visited with W as the outer loop and V as the inner loop, and
    the running total starts from the integer 0, so empty input yields 0.

    Args:
        W: iterable of vectors (iterated row by row).
        V: iterable of vectors (iterated row by row).

    Returns:
        The sum of ``f_single(w, v)`` over all pairs.
    """
    return sum(f_single(row_w, row_v) for row_w in W for row_v in V)

def F(Ws,Vs):
    """Combine three evaluations of ``f`` into one scalar.

    Returns ``f(Ws, Ws)/2 - f(Ws, Vs) + f(Vs, Vs)/2``.
    """
    ws_with_ws = f(Ws, Ws)
    ws_with_vs = f(Ws, Vs)
    vs_with_vs = f(Vs, Vs)
    return ws_with_ws/2 - ws_with_vs + vs_with_vs/2

def grad_F(Ws,Vs):
    """Return the gradient of ``F(Ws, Vs)`` with respect to ``Ws``.

    Uses ``torch.autograd.grad`` and returns the first entry of its result.
    """
    objective = F(Ws, Vs)
    gradients = torch.autograd.grad(objective, Ws)
    return gradients[0]

def hessian_F(Ws,Vs):
    """Return the Hessian of ``F`` in its first argument, evaluated at ``Ws``.

    ``Vs`` is held fixed; differentiation is delegated to
    ``torch.autograd.functional.hessian``.
    """
    def objective_of_ws(x):
        return F(x, Vs)

    return torch.autograd.functional.hessian(objective_of_ws, Ws)