import torch
import torch.nn as nn


def get_beta(vec):
    """Return the squared L2 norm of ``vec`` as a 0-d tensor.

    ``Tensor.norm()`` with no ``dim`` reduces over every element, so any
    shape is accepted. Squaring ``(0 - ||vec||)`` equals squaring the norm
    itself, so the target length of zero is implicit. No extra scale factor
    is applied.
    """
    magnitude = vec.norm()
    return torch.square(magnitude)


def get_alpha(vec):
    """Return the mean squared deviation of each row's RMS from 1.

    ``vec`` is flattened to ``(vec.size(0), -1)``. Row ``i`` contributes
    ``(1 - sqrt(mean(row_i ** 2))) ** 2``, and the per-row values are
    averaged into a 0-d tensor.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. transposed or channels-last weights) are accepted rather than
    raising a RuntimeError.
    """
    rows = vec.reshape(vec.size(0), -1)
    # Sum of squares divided by row length = squared RMS of each row.
    length = (rows * rows).sum(-1) / rows.shape[-1]
    alpha = (1 - length.sqrt()).square()
    return alpha.mean()


def get_omega(mat):
    """Return the mean squared deviation of each row's L2 norm from 1.

    ``mat`` is flattened to ``(mat.size(0), -1)``. Row ``i`` contributes
    ``(1 - ||row_i||_2) ** 2`` (the norm is not divided by the row length),
    and the per-row values are averaged into a 0-d tensor.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs are
    accepted rather than raising a RuntimeError.
    """
    rows = mat.reshape(mat.size(0), -1)
    sq_norms = (rows * rows).sum(-1)
    omega = (1 - sq_norms.sqrt()).square()
    return omega.mean()


def get_coef_vec(args, vec):
    """Return a gradient-scaling coefficient derived from ``get_alpha(vec)``.

    The coefficient is ``|alpha|`` clamped to ``[args.vec_min, args.vec_max]``
    (if ``vec_min > vec_max``, every value becomes ``vec_max``, as with the
    former ``torch.max``/``torch.min`` pair). It is computed on a detached
    copy of ``vec`` with autograd disabled.

    Args:
        args: config object. Reads ``alpha_mean``, ``vec_min`` and
            ``vec_max``. ``args.device`` is not read: clamping against
            Python scalars keeps the result on ``vec``'s device and in
            ``alpha``'s dtype.
        vec: tensor handed to ``get_alpha``.

    Returns:
        The clamped coefficient tensor. When ``args.alpha_mean`` is false, a
        trailing size-1 dimension is appended via ``unsqueeze(-1)``.
    """
    with torch.no_grad():
        alpha = get_alpha(vec.detach())
        if args.alpha_mean:
            alpha = alpha.mean()
        # min(max(|alpha|, vec_min), vec_max) in one op; no temporary tensors.
        coef = alpha.abs().clamp(min=args.vec_min, max=args.vec_max)
        if not args.alpha_mean:
            coef = coef.unsqueeze(-1)
        return coef


def get_coef_mat(args, mat):
    """Return ``get_omega`` of a detached copy of ``mat``, with autograd off.

    ``args`` is not read here; it is kept so existing callers still work.
    """
    with torch.no_grad():
        # get_omega takes a single tensor; passing args.device as a second
        # positional argument raised TypeError on every call.
        return get_omega(mat.detach())


def wb_norm(args, model):
    """Rescale the weight gradients of Conv2d/Linear submodules in place.

    For every ``nn.Conv2d`` / ``nn.Linear`` submodule whose qualified name
    does not contain ``'decs'``, ``weight.grad`` is replaced by
    ``get_coef_vec(args, weight) * weight.grad``. Nothing happens unless
    ``args.w_norm`` is truthy.

    The function operates on ``weight.grad``, so it only has an effect once
    gradients exist. Weights whose ``.grad`` is ``None`` (e.g. a parameter
    that took no part in the backward pass) are skipped instead of raising
    ``AttributeError``.
    """
    if not args.w_norm:
        return
    with torch.no_grad():
        for name, m in model.named_modules():
            # Modules with 'decs' in their name are left untouched
            # (original note: "to include ff along with pc").
            if 'decs' in name:
                continue
            if not isinstance(m, (nn.Conv2d, nn.Linear)):
                continue
            w = m.weight
            if w.grad is None:
                continue
            omega = get_coef_vec(args, w)
            m.weight.grad = omega * w.grad.detach()
