import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

        
class CA(nn.Module):
    """Alignment loss between one set of centers and two classifier weights.

    Each classifier weight is compared against the same ``proto`` tensor via
    :meth:`ca_loss`, and the two resulting losses are averaged.
    """

    def __init__(self):
        super().__init__()

    def ca_loss(self, centers, classifier_weight):
        """Return the squared Frobenius distance between the normalized inputs.

        Both tensors are divided by their own Frobenius norm (a single scalar
        per tensor, not a per-row norm) before being subtracted, so the result
        does not depend on the overall scale of either argument.

        Args:
            centers: Tensor of centers/prototypes.
            classifier_weight: Tensor of classifier weights; it must be
                subtractable from ``centers`` (presumably the same shape,
                e.g. ``(num_classes, dim)`` -- confirm against the caller).

        Returns:
            A scalar tensor ``|| W / ||W||_F - C / ||C||_F ||_F ** 2``.

        Note:
            No epsilon is added to the norms, so an all-zero input divides by
            zero and yields non-finite values.
        """
        cw_norm = torch.norm(classifier_weight, p='fro')
        cp_norm = torch.norm(centers, p='fro')

        cw = classifier_weight / cw_norm
        cp = centers / cp_norm

        sub = cw - cp
        loss = torch.norm(sub, p='fro') ** 2

        return loss

    def forward(self, proto, cls_a, cls_v):
        """Return the mean alignment loss of ``cls_a`` and ``cls_v`` to ``proto``.

        Args:
            proto: Tensor of centers shared by both comparisons.
            cls_a: First classifier weight tensor.
            cls_v: Second classifier weight tensor.

        Returns:
            Scalar tensor: the average of the two ``ca_loss`` values.
        """
        # Keep the two terms in separate names: reusing one name for both
        # would overwrite the cls_a term and drop it from the average (and
        # from the gradient).
        loss_a = self.ca_loss(proto, cls_a)
        loss_v = self.ca_loss(proto, cls_v)

        loss = (loss_a + loss_v) / 2

        return loss
    
class CA_tri(nn.Module):
    """Three-way variant of the alignment loss.

    Three classifier weight tensors are each compared against the same
    ``proto`` tensor with :meth:`ca_loss`; the module output is the mean of
    the three losses.
    """

    def __init__(self):
        super().__init__()

    def ca_loss(self, centers, classifier_weight):
        """Squared Frobenius distance between the two normalized tensors.

        Each argument is scaled by the reciprocal of its own Frobenius norm
        (one scalar per tensor) before the difference is taken.

        Args:
            centers: Tensor of centers/prototypes.
            classifier_weight: Tensor of classifier weights; must be
                subtractable from ``centers`` (presumably the same shape --
                confirm against the caller).

        Returns:
            Scalar tensor holding the squared norm of the difference.
        """
        weight_scale = torch.norm(classifier_weight, p='fro')
        center_scale = torch.norm(centers, p='fro')

        gap = classifier_weight / weight_scale - centers / center_scale

        return torch.norm(gap, p='fro').pow(2)

    def forward(self, proto, cls_r, cls_o, cls_d):
        """Average ``ca_loss(proto, w)`` over ``cls_r``, ``cls_o`` and ``cls_d``.

        Args:
            proto: Tensor of centers shared by all three comparisons.
            cls_r: First classifier weight tensor.
            cls_o: Second classifier weight tensor.
            cls_d: Third classifier weight tensor.

        Returns:
            Scalar tensor: the mean of the three alignment losses.
        """
        term_r, term_o, term_d = [
            self.ca_loss(proto, weight) for weight in (cls_r, cls_o, cls_d)
        ]
        # Summed left to right in the r, o, d order.
        return (term_r + term_o + term_d) / 3
    


