import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
def orth_loss(param, reg=1e-6):
    """Return an orthogonality penalty for ``param``.

    ``param`` is flattened to a 2-D matrix ``W`` of shape
    ``(param.shape[0], -1)``. The penalty is ``reg * sum(|W @ W.T - I|)``,
    the entrywise L1 norm of how far the Gram matrix of ``W``'s rows is
    from the identity.

    Args:
        param: Tensor with at least one dimension. Its first dimension
            indexes the rows whose Gram matrix is compared to the identity.
        reg: Scalar weight applied to the penalty.

    Returns:
        A scalar tensor on ``param``'s device, differentiable w.r.t. ``param``.
    """
    # reshape rather than view: view raises on non-contiguous tensors
    # (e.g. permuted/transposed weights); reshape only copies when it must.
    flat = param.reshape(param.shape[0], -1)
    gram = torch.mm(flat, flat.t())
    # Create the identity directly on param's device and dtype. ``.cuda()``
    # would target the *current* CUDA device, which need not be param's.
    eye = torch.eye(flat.shape[0], device=param.device, dtype=param.dtype)
    return reg * (gram - eye).abs().sum()

def cos_sim_loss_2d(w1, w2, reg=1e-6):
    """Return ``reg`` times the mean cosine similarity of matching slices.

    Both tensors are reshaped to ``(w1.shape[0], w1.shape[1], -1)``, using
    ``w1``'s leading dimensions for both. Cosine similarity is then taken
    along the flattened trailing axis, and the resulting
    ``(shape[0], shape[1])`` grid of similarities is averaged. Presumably
    these are conv weights laid out as (out, in, kH, kW); confirm with the
    callers.

    Args:
        w1: Tensor with at least two dimensions.
        w2: Tensor with the same number of elements as ``w1``.
        reg: Scalar weight applied to the mean similarity.

    Returns:
        A scalar tensor.
    """
    d0, d1 = w1.shape[0], w1.shape[1]
    # reshape rather than view so non-contiguous inputs (permuted or
    # transposed weights) are accepted instead of raising.
    a = w1.reshape(d0, d1, -1)
    b = w2.reshape(d0, d1, -1)
    return reg * F.cosine_similarity(a, b, dim=2).mean()
             
             
def cos_sim_loss_1d(w1, w2, reg=1e-6):
    """Return ``reg`` times the mean cosine similarity of ``w1`` and ``w2``.

    Similarity is computed along dimension 1, so each index of dimension 0
    is compared separately (for 2-D inputs, row i of ``w1`` with row i of
    ``w2``). The resulting similarities are averaged.

    Args:
        w1: Tensor with at least two dimensions.
        w2: Tensor broadcast-compatible with ``w1``.
        reg: Scalar weight applied to the mean similarity.

    Returns:
        A scalar tensor.
    """
    per_index = F.cosine_similarity(w1, w2, dim=1)
    return reg * per_index.mean()