import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
from vgg_features import vgg11_features, vgg11_bn_features, vgg13_features, vgg13_bn_features, vgg16_features, vgg16_bn_features,\
                         vgg19_features, vgg19_bn_features
from resnet_features import resnet18_features,resnet34_features,resnet50_features, resnet101_features,resnet152_features
from densenet_features import densenet121_features,densenet161_features,densenet169_features,densenet201_features

from receptive_field import compute_proto_layer_rf_info_v2
import pdb
# Lookup table: backbone name -> constructor of its convolutional feature extractor.
base_architecture_to_features = dict(
    resnet18=resnet18_features,
    resnet34=resnet34_features,
    resnet50=resnet50_features,
    resnet101=resnet101_features,
    resnet152=resnet152_features,
    densenet121=densenet121_features,
    densenet161=densenet161_features,
    densenet169=densenet169_features,
    densenet201=densenet201_features,
    vgg11=vgg11_features,
    vgg11_bn=vgg11_bn_features,
    vgg13=vgg13_features,
    vgg13_bn=vgg13_bn_features,
    vgg16=vgg16_features,
    vgg16_bn=vgg16_bn_features,
    vgg19=vgg19_features,
    vgg19_bn=vgg19_bn_features,
)
class Manifold_Projector(torch.nn.Module):
    """
    Pair of orthogonal projection matrices (vision / language) that are updated
    on the orthogonal manifold instead of by back-propagation.

    In mode 1, ``forward`` accumulates (with momentum) the gradients
    ``-2 * X_v U X_l^T P_l`` and ``-2 * X_l U^T X_v^T P_v`` into ``sum_G_v`` /
    ``sum_G_l``.  ``update_rotation_matrix`` then applies them with
    Cayley-transform steps, which keep both projectors orthonormal.
    """
    def __init__(self, num_channels, mode = -1, momentum=0.05):
        super(Manifold_Projector, self).__init__()
        # Weight of the newest gradient in the running (momentum) average.
        self.momentum = momentum
        self.num_channels = num_channels
        # 1: accumulate projector gradients in forward(); anything else: project only.
        self.mode = mode

        # Projectors start at the identity (orthogonal).  They are buffers, not
        # parameters, because only update_rotation_matrix() changes them.
        self.register_buffer('projector_vision', torch.eye(num_channels))
        self.register_buffer('projector_language', torch.eye(num_channels))
        # Momentum-averaged gradients w.r.t. each projector ("G" in the Cayley update).
        self.register_buffer('sum_G_v', torch.zeros(num_channels, num_channels))
        self.register_buffer('sum_G_l', torch.zeros(num_channels, num_channels))
        # Row normaliser for G; starts at 0.001 so the division never hits zero.
        self.register_buffer("counter", torch.ones(num_channels)*0.001)

    def update_rotation_matrix(self):
        """Apply the accumulated gradients to both projectors, then reset the counter."""
        self.update_rotation_matrix_v()
        self.update_rotation_matrix_l()
        # Reset on the buffer's own device/dtype rather than forcing CUDA.
        self.counter = torch.full_like(self.counter, 0.001)

    def _cayley_update(self, G, R, num_steps=2, tau_init=1000, tau_max=100000000,
                       c1=1e-4, c2=0.9, max_search=500):
        """
        Take ``num_steps`` Cayley-transform steps from the orthogonal matrix ``R``.

        Each step moves along the skew-symmetric direction A = G R^T - R G^T:
        Y(tau) = (I + tau/2 A)^-1 (I - tau/2 A) R.  The step size tau is found by
        bisection on a sufficient-decrease test (c1) and a curvature test (c2) for
        the linear objective <G, Y>.  The search gives up after ``max_search``
        adjustments and still applies the last tau.

        Returns:
            The updated matrix, which stays orthogonal because Q is a Cayley transform.
        """
        I = torch.eye(R.size(-1), dtype=R.dtype, device=R.device)
        for _ in range(num_steps):
            tau, alpha, beta = tau_init, 0, tau_max
            A = torch.einsum('in,jn->ij', G, R) - torch.einsum('in,jn->ij', R, G)  # GR^T - RG^T
            dF_0 = -0.5 * (A ** 2).sum()  # directional derivative at tau = 0
            F_X = (G * R).sum()           # objective at the current point (loop-invariant)
            cnt = 0
            while True:
                M_inv = (I + 0.5 * tau * A).inverse()
                Q = torch.mm(M_inv, I - 0.5 * tau * A)
                Y_tau = torch.mm(Q, R)  # point on the curve Y(tau) = Q R
                F_Y_tau = (G * Y_tau).sum()
                dF_tau = -torch.mm(torch.einsum('ni,nj->ij', G, M_inv),
                                   torch.mm(A, 0.5 * (R + Y_tau))).trace()
                if F_Y_tau > F_X + c1*tau*dF_0 + 1e-18:
                    # Not enough decrease: shrink the step.
                    beta = tau
                    tau = (beta + alpha) / 2
                elif dF_tau + 1e-18 < c2*dF_0:
                    # Slope still too steep: the step is too small, grow it.
                    alpha = tau
                    tau = (beta + alpha) / 2
                else:
                    break
                cnt += 1
                if cnt > max_search:
                    print("--------------------update fail------------------------")
                    print(F_Y_tau, F_X + c1*tau*dF_0)
                    print(dF_tau, c2*dF_0)
                    print("-------------------------------------------------------")
                    break
            print(tau, F_Y_tau)
            # Recompute Q with the final tau (it may have changed after the last test).
            Q = torch.mm((I + 0.5 * tau * A).inverse(), I - 0.5 * tau * A)
            R = torch.mm(Q, R)
        return R

    def update_rotation_matrix_v(self):
        """
        Update the vision projector from the accumulated gradient ``sum_G_v``.
        The Cayley transform keeps the projector orthonormal.
        """
        with torch.no_grad():
            G = self.sum_G_v / self.counter.reshape(-1, 1)
            self.projector_vision = self._cayley_update(G, self.projector_vision)

    def update_rotation_matrix_l(self):
        """
        Update the language projector from the accumulated gradient ``sum_G_l``.
        The Cayley transform keeps the projector orthonormal.
        """
        with torch.no_grad():
            G = self.sum_G_l / self.counter.reshape(-1, 1)
            self.projector_language = self._cayley_update(G, self.projector_language)

    def forward(self, X_v_t, X_l_t, U):
        """
        Project vision and language tokens with the orthogonal projectors.

        Args (example shapes from the original shape notes):
            X_v_t: vision tokens   [B, N_v, D]  (e.g. [B, 49, 128])
            X_l_t: language tokens [B, N_l, D]  (e.g. [B, 32, 128])
            U:     normalised affinity [B, N_v, N_l]
        Returns:
            Z_v = P_v^T X_v  [B, D, N_v]
            Z_l = P_l^T X_l  [B, D, N_l]
        In mode 1, also updates the momentum-averaged gradient buffers (no autograd).
        """
        X_v = X_v_t.transpose(1, 2)  # [B, D, N_v]
        X_l = X_l_t.transpose(1, 2)  # [B, D, N_l]

        with torch.no_grad():
            if self.mode == 1:
                # grad_v = -2 * X_v U X_l^T P_l, averaged over the batch -> [D, D]
                grad_v_temp = torch.einsum("bdv,bvl->bdl", X_v, U)
                grad_v_temp = torch.einsum("bdl,blq->bdq", grad_v_temp, X_l_t)
                grad_v = -2 * torch.einsum("bdq,qs->bds", grad_v_temp,
                                           self.projector_language).mean((0,))
                # grad_l = -2 * X_l U^T X_v^T P_v, averaged over the batch -> [D, D]
                U_t = U.transpose(1, 2)  # [B, N_l, N_v]
                grad_l_temp = torch.einsum("bdl,blv->bdv", X_l, U_t)
                grad_l_temp = torch.einsum("bdv,bvq->bdq", grad_l_temp, X_v_t)
                grad_l = -2 * torch.einsum("bdq,qs->bds", grad_l_temp,
                                           self.projector_vision).mean((0,))
                self.sum_G_v = self.momentum * grad_v + (1. - self.momentum) * self.sum_G_v
                self.sum_G_l = self.momentum * grad_l + (1. - self.momentum) * self.sum_G_l
                self.counter += 1

        # Z = P^T X as a real matrix product.  The former "dd,bdn->bdn" einsum
        # took the diagonal of P and only rescaled X.
        Z_v = torch.matmul(self.projector_vision.T, X_v)    # [B, D, N_v]
        Z_l = torch.matmul(self.projector_language.T, X_l)  # [B, D, N_l]
        return Z_v, Z_l




class VL_Protopnet(nn.Module):
    '''
    Vision-language ProtoPNet.

    Image feature-map tokens and projected language tokens go through a pair of
    orthogonal projectors (``self.projector``).  The language side is aggregated
    onto the image tokens with a k-NN affinity matrix, and the two are mixed with
    weight ``args.alpha``.  The fused map is compared with the learned prototypes
    by an L2 convolution, and the prototype similarities feed a bias-free linear
    classifier.
    '''
    def __init__(self, args, image_model, language_model, img_size, prototype_shape,
                 proto_layer_rf_info, num_classes, init_weights=True,
                 prototype_activation_function='log',
                 add_on_layers_type='bottleneck',  mode = -1):
        """
        Args:
            args: config object; this class reads ``args.momentum_G`` and ``args.alpha``.
            image_model: convolutional backbone; ``str(image_model)`` must start with
                VGG / RES / DENSE so the add-on input width can be inferred.
            language_model: called as ``language_model(ids, token_type, mask)``; its
                last dimension must be 768 (input width of the projection head).
            img_size: input image size (stored only).
            prototype_shape: (num_prototypes, D, kh, kw); D is the shared embedding width.
            proto_layer_rf_info: receptive-field info (stored only).
            num_classes: number of classes; must divide num_prototypes.
            init_weights: initialise the add-on layers and the last layer.
            prototype_activation_function: 'log', 'linear' or a callable.
            add_on_layers_type: 'bottleneck' or anything else for the two-conv head.
            mode: projector mode; 1 accumulates projector gradients in forward passes.
        """
        super(VL_Protopnet, self).__init__()
        self.args = args
        # Model-level mode mirrors the projector's mode (change_mode keeps them in sync).
        self.mode = mode
        self.k_cross = 5  # neighbours per row / column when building the affinity matrix

        self.image_model = image_model
        self.language_model = language_model

        # Maps language-model features (width 768) to the prototype embedding width D.
        embedding_dim = 768
        self.language_projection_head = nn.Sequential(
            nn.Linear(embedding_dim, prototype_shape[1]))

        # prototypes
        self.img_size = img_size
        self.prototype_shape = prototype_shape
        self.num_prototypes = prototype_shape[0]
        self.num_classes = num_classes
        self.epsilon = 1e-4
        self.prototype_activation_function = prototype_activation_function
        assert(self.num_prototypes % self.num_classes == 0)
        # One-hot [num_prototypes, num_classes] matrix: prototype j belongs to
        # class j // num_prototypes_per_class.
        self.num_prototypes_per_class = self.num_prototypes // self.num_classes
        self.prototype_class_identity = torch.zeros(self.num_prototypes,
                                                    self.num_classes)
        for j in range(self.num_prototypes):
            self.prototype_class_identity[j, j // self.num_prototypes_per_class] = 1

        self.proto_layer_rf_info = proto_layer_rf_info

        # Infer the backbone's output width from its last conv / batch-norm layer.
        features = image_model
        features_name = str(self.image_model).upper()
        if features_name.startswith(('VGG', 'RES')):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.Conv2d)][-1].out_channels
        elif features_name.startswith('DENSE'):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.BatchNorm2d)][-1].num_features
        else:
            raise Exception('other base base_architecture NOT implemented')

        if add_on_layers_type == 'bottleneck':
            # Halve the channel count with 1x1 conv pairs until D is reached;
            # the final activation is a sigmoid.
            add_on_layers = []
            current_in_channels = first_add_on_layer_in_channels
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert(current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )

        self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape),
                                              requires_grad=True)
        self.projector = Manifold_Projector(num_channels=self.prototype_shape[1], mode=mode,
                                            momentum=args.momentum_G)
        # Registered as a (frozen) Parameter rather than a plain tensor so that it
        # follows the module across devices.
        self.ones = nn.Parameter(torch.ones(self.prototype_shape),
                                 requires_grad=False)

        self.last_layer = nn.Linear(self.num_prototypes, self.num_classes,
                                    bias=False)  # do not use bias

        if init_weights:
            self._initialize_weights()

    def conv_features(self, x):
        '''
        Return the feature map fed to the prototype layer: backbone then add-on layers.
        '''
        x = self.image_model(x)
        x = self.add_on_layers(x)
        return x

    def _l2_convolution(self, x):
        '''
        Squared L2 distance between every prototype and every spatial patch of x.

        Uses ||x - p||^2 = ||x||^2 - 2 x.p + ||p||^2, computed with convolutions;
        a ReLU clamps small negative round-off to 0.
        Returns [B, num_prototypes, H', W'].
        '''
        x2_patch_sum = F.conv2d(input=x ** 2, weight=self.ones)  # ||x||^2 per patch
        p2 = torch.sum(self.prototype_vectors ** 2, dim=(1, 2, 3))  # ||p||^2, [num_prototypes]
        p2_reshape = p2.view(-1, 1, 1)  # broadcast over spatial positions
        xp = F.conv2d(input=x, weight=self.prototype_vectors)  # x.p per patch
        distances = F.relu(x2_patch_sum - 2 * xp + p2_reshape)
        return distances

    def _l2_convolution_2(self, x):
        '''
        Same computation as ``_l2_convolution`` (kept for callers that use this name).
        '''
        return self._l2_convolution(x)

    def _cosine_convolution(self, x):
        '''
        Cosine similarity (higher = closer) between each prototype and each spatial
        position of x, after L2-normalising along the channel dimension.
        '''
        x = F.normalize(x, p=2, dim=1)
        now_prototype_vectors = F.normalize(self.prototype_vectors, p=2, dim=1)
        return F.conv2d(input=x, weight=now_prototype_vectors)

    def prototype_distances(self, x):
        '''
        Return (projection distances, cosine similarities) of x's features to the prototypes.

        NOTE(review): ``self._project2basis`` is not defined in the code visible
        here, so this raises AttributeError unless it is provided elsewhere.
        '''
        conv_features = self.conv_features(x)
        cosine_distances = self._cosine_convolution(conv_features)
        project_distances = self._project2basis(conv_features)
        return project_distances, cosine_distances

    def distance_2_similarity(self, distances):
        '''
        Map distances to similarities.

        'log': log((d + 1) / (d + epsilon)); 'linear': -d; otherwise
        ``prototype_activation_function`` is called on the distances.
        '''
        if self.prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + self.epsilon))
        elif self.prototype_activation_function == 'linear':
            return -distances
        else:
            return self.prototype_activation_function(distances)

    def global_min_pooling(self, distances):
        '''Minimum over all spatial positions: [B, P, H, W] -> [B, num_prototypes].'''
        min_distances = -F.max_pool2d(-distances,
                                      kernel_size=(distances.size()[2],
                                                   distances.size()[3]))
        return min_distances.view(-1, self.num_prototypes)

    def global_max_pooling(self, distances):
        '''Maximum over all spatial positions: [B, P, H, W] -> [B, num_prototypes].'''
        max_distances = F.max_pool2d(distances,
                                     kernel_size=(distances.size()[2],
                                                  distances.size()[3]))
        return max_distances.view(-1, self.num_prototypes)

    def cal_affinity(self, X_v, X_l):
        """
        Build a k-nearest-neighbour affinity between vision and language tokens.

        Args:
            X_v: vision tokens   [B, N_v, D] (e.g. [B, 49, 128])
            X_l: language tokens [B, N_l, D] (e.g. [B, 32, 128])
        Returns:
            A: binary affinity [B, N_v, N_l].  Entry (v, l) is 1 if v is among the
               k most cosine-similar vision tokens for l, or l is among the k most
               similar language tokens for v (k = k_cross, clamped to the axis length).
            U: A divided by its per-sample sum.
        """
        X_v = F.normalize(X_v, 2, 2)
        X_l = F.normalize(X_l, 2, 2)
        dist = torch.bmm(X_v, X_l.permute(0, 2, 1))  # cosine similarity [B, N_v, N_l]

        # Allocate on dist's device/dtype instead of forcing CUDA.
        A = torch.zeros_like(dist)
        # Clamp k so sequences shorter than k_cross do not make topk fail.
        k_v = min(self.k_cross, dist.size(1))
        k_l = min(self.k_cross, dist.size(2))
        A.scatter_(1, torch.topk(dist, k_v, 1)[1], 1.0)  # best vision tokens per language token
        A.scatter_(2, torch.topk(dist, k_l, 2)[1], 1.0)  # best language tokens per vision token

        U = A / A.sum(dim=(1, 2), keepdim=True)
        return A, U

    def change_mode(self, mode):
        '''Set the mode on both the model and its projector.'''
        self.projector.mode = mode
        self.mode = mode

    def update_rotation_matrix(self):
        """
        update the rotation R using accumulated gradient G
        """
        self.projector.update_rotation_matrix()

    def forward_vision(self, image_input):
        '''
        Encode images into L2-normalised tokens.

        Returns:
            tokens [B, H*W, D] and the feature-map shape (B, D, H, W).
        '''
        image_output = self.image_model(image_input)      # e.g. [B, 512, 7, 7]
        image_output = self.add_on_layers(image_output)   # e.g. [B, 128, 7, 7]
        B, dim, height, width = (image_output.size(0), image_output.size(1),
                                 image_output.size(2), image_output.size(3))
        tokens = image_output.view(B, dim, -1).transpose(1, 2)  # [B, H*W, D]
        tokens = F.normalize(tokens, p=2, dim=2)
        return tokens, (B, dim, height, width)

    def forward_language(self, language_input, token_type, input_mask, nwords):
        '''
        Encode text into L2-normalised tokens in the prototype embedding space.

        ``nwords`` is accepted for interface compatibility but not used.
        Returns:
            tokens [B, L, D] and the tuple (B, L, D).
        '''
        language_output = self.language_model(language_input, token_type, input_mask)  # e.g. [B, 32, 768]
        language_output = self.language_projection_head(language_output)  # e.g. [B, 32, 128]
        B = language_output.size(0)
        seq_len = language_output.size(1)
        emb_dim = language_output.size(2)
        language_output = F.normalize(language_output, p=2, dim=2)
        return language_output, (B, seq_len, emb_dim)

    def calculate_prototype_activation(self, output):
        '''
        Return (prototype similarities, min distances), both [B, num_prototypes].
        '''
        distances = self._l2_convolution(output)
        min_distances = self.global_min_pooling(distances)
        prototype_activations = self.distance_2_similarity(min_distances)
        return prototype_activations, min_distances

    def _fuse(self, Z_v, Z_l, affinity_matrix):
        '''
        Aggregate projected language tokens onto the vision tokens through the
        affinity matrix and mix them: alpha * Z_v + (1 - alpha) * Z_l A^T.

        Z_v: [B, D, N_v], Z_l: [B, D, N_l], affinity_matrix: [B, N_v, N_l].
        Returns [B, D, N_v].
        '''
        Z_l = torch.bmm(Z_l, affinity_matrix.transpose(1, 2))  # [B, D, N_v]
        return self.args.alpha * Z_v + (1 - self.args.alpha) * Z_l

    def _project_inputs(self, image_input, language_input, token_type, input_mask, nwords):
        '''
        Encode both modalities, build the affinity and project through self.projector.

        Note: in projector mode 1 this also accumulates projector gradients.
        Returns (Z_v [B, D, N_v], Z_l [B, D, N_l], affinity [B, N_v, N_l], (B, D, H, W)).
        '''
        image_output, shape_v = self.forward_vision(image_input)
        language_output, _ = self.forward_language(language_input, token_type, input_mask, nwords)
        affinity_matrix, U_matrix = self.cal_affinity(image_output, language_output)
        Z_v, Z_l = self.projector(image_output, language_output, U_matrix)
        return Z_v, Z_l, affinity_matrix, shape_v

    def forward_joint_manifold(self, image_output, language_output, affinity_matrix):
        """
        Fuse already-encoded tokens with the current projectors, without updating
        the projector's gradient state.

        image_output: [B, N_v, D]; language_output: [B, N_l, D];
        affinity_matrix: [B, N_v, N_l].  Returns [B, D, N_v].
        """
        Z_v = torch.matmul(image_output, self.projector.projector_vision).transpose(1, 2)
        Z_l = torch.matmul(language_output, self.projector.projector_language).transpose(1, 2)
        return self._fuse(Z_v, Z_l, affinity_matrix)

    def forward(self, image_input, language_input, token_type, input_mask, nwords):
        '''
        Classify an image-text pair.

        Returns:
            logits [B, num_classes] and min prototype distances [B, num_prototypes].
        '''
        Z_v, Z_l, affinity_matrix, shape_v = self._project_inputs(
            image_input, language_input, token_type, input_mask, nwords)
        Z_vl = self._fuse(Z_v, Z_l, affinity_matrix).reshape(shape_v)  # [B, D, H, W]

        vl_distances = self._l2_convolution(Z_vl)
        min_vl_distances = self.global_min_pooling(vl_distances)
        prototype_activations = self.distance_2_similarity(min_vl_distances)
        logits = self.last_layer(prototype_activations)
        return logits, min_vl_distances

    def push_forward_vision(self, image_input, language_input, token_type, input_mask, nwords):
        '''Push-step variant that uses only the projected vision features.

        Returns (Z_v as [B, D, H, W], L2 distances to the prototypes).
        '''
        Z_v, _, _, shape_v = self._project_inputs(
            image_input, language_input, token_type, input_mask, nwords)
        Z_vl = Z_v.reshape(shape_v)
        return Z_vl, self._l2_convolution(Z_vl)

    def push_forward(self, image_input, language_input, token_type, input_mask, nwords):
        '''Push-step variant using the fused vision-language features.

        Returns (fused features [B, D, H, W], L2 distances to the prototypes).
        '''
        Z_v, Z_l, affinity_matrix, shape_v = self._project_inputs(
            image_input, language_input, token_type, input_mask, nwords)
        Z_vl = self._fuse(Z_v, Z_l, affinity_matrix).reshape(shape_v)
        return Z_vl, self._l2_convolution(Z_vl)

    def set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        Set last-layer weights to 1 for a prototype's own class and to
        ``incorrect_strength`` (e.g. -0.5) for every other class.
        '''
        positive_one_weights_locations = torch.t(self.prototype_class_identity)
        negative_one_weights_locations = 1 - positive_one_weights_locations

        correct_class_connection = 1
        incorrect_class_connection = incorrect_strength
        self.last_layer.weight.data.copy_(
            correct_class_connection * positive_one_weights_locations
            + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        '''Kaiming-init the add-on convs, reset BatchNorms, and set last-layer weights.'''
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)

def construct_VLNet(args,image_model, language_model,  img_size=224,
                    prototype_shape=(2000, 256, 1, 1), num_classes=200,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck',mode = 1):
    """
    Build a VL_Protopnet around the given image and language backbones.

    The prototype layer's receptive-field info is computed from
    ``image_model.conv_info()`` for the requested ``img_size``.  ``img_size`` and
    ``mode`` are forwarded as given; previously they were hard-coded to 224 and 1,
    which are still the defaults.

    Returns:
        A VL_Protopnet instance with ``init_weights=True``.
    """
    layer_filter_sizes, layer_strides, layer_paddings = image_model.conv_info()
    proto_layer_rf_info = compute_proto_layer_rf_info_v2(img_size=img_size,
                                                         layer_filter_sizes=layer_filter_sizes,
                                                         layer_strides=layer_strides,
                                                         layer_paddings=layer_paddings,
                                                         prototype_kernel_size=prototype_shape[2])

    return VL_Protopnet(args=args, image_model=image_model,
                        language_model=language_model,
                        img_size=img_size,
                        prototype_shape=prototype_shape,
                        proto_layer_rf_info=proto_layer_rf_info,
                        num_classes=num_classes,
                        init_weights=True,
                        prototype_activation_function=prototype_activation_function,
                        add_on_layers_type=add_on_layers_type,
                        mode=mode)
