import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import numpy as np
from resnet_features import resnet18_features, resnet34_features, resnet50_features, resnet101_features, resnet152_features
from densenet_features import densenet121_features, densenet161_features, densenet169_features, densenet201_features
from vgg_features import vgg11_features, vgg11_bn_features, vgg13_features, vgg13_bn_features, vgg16_features, vgg16_bn_features,\
                         vgg19_features, vgg19_bn_features

from receptive_field import compute_proto_layer_rf_info_v2

# Backbone name -> factory that builds its convolutional feature extractor.
# Each factory is invoked as ``factory(pretrained=...)`` by construct_PPNet.
base_architecture_to_features = dict(
    resnet18=resnet18_features,
    resnet34=resnet34_features,
    resnet50=resnet50_features,
    resnet101=resnet101_features,
    resnet152=resnet152_features,
    densenet121=densenet121_features,
    densenet161=densenet161_features,
    densenet169=densenet169_features,
    densenet201=densenet201_features,
    vgg11=vgg11_features,
    vgg11_bn=vgg11_bn_features,
    vgg13=vgg13_features,
    vgg13_bn=vgg13_bn_features,
    vgg16=vgg16_features,
    vgg16_bn=vgg16_bn_features,
    vgg19=vgg19_features,
    vgg19_bn=vgg19_bn_features,
)
# Machine-specific absolute path of the tensor loaded with torch.load by
# Clip_ProtoPNet as its concept bank.
attribute_path = "/data/fengyi/wangjiaqi/code/mm-proto-pnet/preprocess_dataset/gpt_data/CUB_atrribute_GPT3_4k.pt"
class PPNet(nn.Module):
    """ProtoPNet-style network whose prototypes are allocated contiguously per class.

    Prototypes ``[c * K, (c + 1) * K)`` belong to class ``c``, where
    ``K = num_prototypes_per_class``. ``forward`` returns the ProtoPNet logits
    and min-distances plus an embedding of the image relative to the
    prototypes of the predicted class (see ``prototype_embedding``).
    """

    def __init__(self, features, img_size, prototype_shape,
                 proto_layer_rf_info, num_classes, init_weights=True,
                 prototype_activation_function='log',
                 add_on_layers_type='bottleneck'):
        """
        Args:
            features: backbone module; ``str(features).upper()`` must start with
                'VGG', 'RES' or 'DENSE'.
            img_size: input image side length (stored on the module).
            prototype_shape: (num_prototypes, channels, kh, kw), e.g. (2000, 512, 1, 1).
            proto_layer_rf_info: receptive-field info of the prototype layer (stored).
            num_classes: number of classes; must divide num_prototypes.
            init_weights: if True, initialise the add-on layers and the last layer.
            prototype_activation_function: 'log', 'linear', or a callable that
                maps distances to similarity scores.
            add_on_layers_type: 'bottleneck', or anything else for a plain
                conv-relu-conv-sigmoid head.
        """
        super(PPNet, self).__init__()
        self.img_size = img_size
        # e.g. (2000, 512, 1, 1): 2000 prototypes = 10 per class for 200 classes.
        self.prototype_shape = prototype_shape
        self.num_prototypes = prototype_shape[0]
        self.num_classes = num_classes
        self.epsilon = 1e-4

        # prototype_activation_function could be 'log', 'linear',
        # or a generic function that converts distance to similarity score
        self.prototype_activation_function = prototype_activation_function

        # Class identities of the prototypes: without domain-specific knowledge
        # every class gets the same number of prototypes.
        assert(self.num_prototypes % self.num_classes == 0)
        num_prototypes_per_class = self.num_prototypes // self.num_classes
        self.num_prototypes_per_class = num_prototypes_per_class
        # One-hot matrix: entry [j, c] is 1 iff prototype j belongs to class c.
        self.prototype_class_identity = torch.zeros(self.num_prototypes,
                                                    self.num_classes)
        for j in range(self.num_prototypes):
            self.prototype_class_identity[j, j // num_prototypes_per_class] = 1

        # Receptive-field info of the prototype layer.
        self.proto_layer_rf_info = proto_layer_rf_info

        # this has to be named features to allow the precise loading
        self.features = features

        features_name = str(self.features).upper()
        if features_name.startswith('VGG') or features_name.startswith('RES'):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.Conv2d)][-1].out_channels
        elif features_name.startswith('DENSE'):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.BatchNorm2d)][-1].num_features
        else:
            raise Exception('other base base_architecture NOT implemented')

        if add_on_layers_type == 'bottleneck':
            # Halve the channel count with 1x1 convs until it reaches
            # prototype_shape[1]; the final activation is a sigmoid.
            add_on_layers = []
            current_in_channels = first_add_on_layer_in_channels
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert(current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )

        self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape),
                                              requires_grad=True)
        # do not make this just a tensor,
        # since it will not be moved automatically to gpu
        self.ones = nn.Parameter(torch.ones(self.prototype_shape),
                                 requires_grad=False)

        self.last_layer = nn.Linear(self.num_prototypes, self.num_classes,
                                    bias=False) # do not use bias

        if init_weights:
            self._initialize_weights()

    def conv_features(self, x):
        '''
        the feature input to prototype layer
        '''
        x = self.features(x)
        x = self.add_on_layers(x)
        return x

    @staticmethod
    def _weighted_l2_convolution(input, filter, weights):
        '''
        input of shape N * c * h * w
        filter of shape P * c * h1 * w1
        weight of shape P * c * h1 * w1
        '''
        input2 = input ** 2
        input_patch_weighted_norm2 = F.conv2d(input=input2, weight=weights)

        filter2 = filter ** 2
        weighted_filter2 = filter2 * weights
        filter_weighted_norm2 = torch.sum(weighted_filter2, dim=(1, 2, 3))
        filter_weighted_norm2_reshape = filter_weighted_norm2.view(-1, 1, 1)

        weighted_filter = filter * weights
        weighted_inner_product = F.conv2d(input=input, weight=weighted_filter)

        # use broadcast
        intermediate_result = \
            - 2 * weighted_inner_product + filter_weighted_norm2_reshape
        # x2_patch_sum and intermediate_result are of the same shape
        distances = F.relu(input_patch_weighted_norm2 + intermediate_result)

        return distances

    @staticmethod
    def find_high_activation_crop(activation_map, percentile=95):
        """Bounding box of the entries at or above the given percentile.

        Entries below ``np.percentile(activation_map, percentile)`` are masked
        out. The method returns ``(lower_y, upper_y, lower_x, upper_x)``, where
        the upper bounds are exclusive (rows index axis 0, columns axis 1).
        It is a staticmethod, so it works both as ``PPNet.find_high_activation_crop(m)``
        and as ``model.find_high_activation_crop(m)``.
        """
        threshold = np.percentile(activation_map, percentile)
        mask = np.ones(activation_map.shape)
        mask[activation_map < threshold] = 0
        lower_y, upper_y, lower_x, upper_x = 0, 0, 0, 0
        for i in range(mask.shape[0]):
            if np.amax(mask[i]) > 0.5:
                lower_y = i
                break
        for i in reversed(range(mask.shape[0])):
            if np.amax(mask[i]) > 0.5:
                upper_y = i
                break
        for j in range(mask.shape[1]):
            if np.amax(mask[:,j]) > 0.5:
                lower_x = j
                break
        for j in reversed(range(mask.shape[1])):
            if np.amax(mask[:,j]) > 0.5:
                upper_x = j
                break
        return lower_y, upper_y+1, lower_x, upper_x+1

    def _l2_convolution(self, x):
        '''
        apply self.prototype_vectors as l2-convolution filters on input x;
        returns the squared L2 distance ||x_patch - p||^2 of every patch to
        every prototype, shape (N, num_prototypes, H', W').
        '''
        # sum over channels (and kernel window) of x^2 for every patch
        x2 = x ** 2
        x2_patch_sum = F.conv2d(input=x2, weight=self.ones)
        # ||p||^2 for every prototype, reshaped to (num_prototypes, 1, 1)
        p2 = self.prototype_vectors ** 2
        p2 = torch.sum(p2, dim=(1, 2, 3))
        p2_reshape = p2.view(-1, 1, 1)

        # <x_patch, p>
        xp = F.conv2d(input=x, weight=self.prototype_vectors)
        intermediate_result = - 2 * xp + p2_reshape  # use broadcast
        # relu clamps tiny negative values caused by floating-point error
        distances = F.relu(x2_patch_sum + intermediate_result)

        return distances

    def prototype_embedding(self, conv_features, distances, y_hat):
        """Encode each image relative to the prototypes of its predicted class.

        Args:
            conv_features: (B, C, H, W) output of ``conv_features``.
            distances: (B, num_prototypes, H, W) output of ``_l2_convolution``.
            y_hat: (B,) predicted class index of each sample.

        Returns:
            outputs_t: (B, C, K) L2-normalised residuals between the
                assignment-weighted mean feature and each of the K prototypes
                of the predicted class (K = num_prototypes_per_class).
            assign: (B, K, H, W) softmax over the K prototypes at each location.
        """
        batch_size, in_channels, input_h, input_w = conv_features.shape
        num_parts = self.num_prototypes_per_class
        device = distances.device

        # Prototype indices of each sample's predicted class; class c owns the
        # contiguous rows [c*K, (c+1)*K).
        y_hat = torch.as_tensor(y_hat, device=device).long().view(-1)
        part_idx = y_hat.unsqueeze(1) * num_parts + torch.arange(num_parts, device=device)  # (B, K)
        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)  # (B, 1)
        selected_values = distances[batch_idx, part_idx]  # (B, K, H, W)

        # NOTE(review): softmax over raw distances gives the *farthest* prototype
        # the largest weight; soft assignment usually uses softmax(-distances).
        # Kept unchanged to preserve trained behaviour - confirm intent.
        assign = F.softmax(selected_values, dim=1)

        # features as N * HW * C
        x = conv_features.contiguous().view(batch_size, in_channels, -1).permute(0, 2, 1)
        assign_flat = assign.contiguous().view(batch_size, num_parts, -1)  # (B, K, HW)
        qx = torch.bmm(assign_flat, x)  # (B, K, C): assignment-weighted feature sums

        # Prototype vectors of the predicted class: (B, K, C, kh, kw) -> (B, K, C*kh*kw).
        # flatten (rather than squeeze) keeps the batch/part dims when B == 1 or K == 1.
        c = self.prototype_vectors[part_idx].flatten(start_dim=2)

        # Total assignment mass per prototype, clamped to avoid division by zero.
        sum_ass = assign_flat.sum(dim=2, keepdim=True).clamp(min=1e-5)  # (B, K, 1)
        out = (qx / sum_ass) - c  # (B, K, C) residual coding

        outputs = F.normalize(out, dim=2)
        outputs_t = outputs.permute(0, 2, 1)  # (B, C, K)
        return outputs_t, assign

    def prototype_distances(self, x):
        '''
        x is the raw input; returns (conv_features, distances)
        '''
        conv_features = self.conv_features(x)
        distances = self._l2_convolution(conv_features)
        return conv_features, distances

    def distance_2_similarity(self, distances):
        if self.prototype_activation_function == 'log':
            return torch.log((distances + 1) / (distances + self.epsilon))
        elif self.prototype_activation_function == 'linear':
            return -distances
        else:
            return self.prototype_activation_function(distances)

    def forward(self, x):
        """Returns (logits, min_distances, outputs_t, assign, predicted_class)."""
        conv_features, distances = self.prototype_distances(x)
        # distances: (B, num_prototypes, H, W)
        '''
        we cannot refactor the lines below for similarity scores
        because we need to return min_distances
        '''
        # global min pooling: smallest patch distance to each prototype
        min_distances = -F.max_pool2d(-distances,
                                      kernel_size=(distances.size()[2],
                                                   distances.size()[3]))
        min_distances = min_distances.view(-1, self.num_prototypes)
        prototype_activations = self.distance_2_similarity(min_distances)
        logits = self.last_layer(prototype_activations)
        max_indices = torch.argmax(logits, dim=1)  # predicted class per sample
        outputs_t, assign = self.prototype_embedding(conv_features, distances, max_indices)
        return logits, min_distances, outputs_t, assign, max_indices

    def push_forward(self, x):
        '''this method is needed for the pushing operation'''
        conv_output = self.conv_features(x)
        distances = self._l2_convolution(conv_output)
        return conv_output, distances

    def prune_prototypes(self, prototypes_to_prune):
        '''
        prototypes_to_prune: a list of indices each in
        [0, current number of prototypes - 1] that indicates the prototypes to
        be removed

        NOTE(review): num_prototypes_per_class is not updated, while
        prototype_embedding assumes equal-size contiguous class blocks -
        confirm before pruning a model that uses forward().
        '''
        # sorted: keep surviving prototypes in their original order
        # (iteration order of a set is not guaranteed).
        prototypes_to_keep = sorted(set(range(self.num_prototypes)) - set(prototypes_to_prune))

        self.prototype_vectors = nn.Parameter(self.prototype_vectors.data[prototypes_to_keep, ...],
                                              requires_grad=True)

        self.prototype_shape = list(self.prototype_vectors.size())
        self.num_prototypes = self.prototype_shape[0]

        # changing self.last_layer in place
        # changing in_features and out_features make sure the numbers are consistent
        self.last_layer.in_features = self.num_prototypes
        self.last_layer.out_features = self.num_classes
        self.last_layer.weight.data = self.last_layer.weight.data[:, prototypes_to_keep]

        # self.ones is nn.Parameter
        self.ones = nn.Parameter(self.ones.data[prototypes_to_keep, ...],
                                 requires_grad=False)
        # self.prototype_class_identity is torch tensor
        # so it does not need .data access for value update
        self.prototype_class_identity = self.prototype_class_identity[prototypes_to_keep, :]

    def set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        positive_one_weights_locations = torch.t(self.prototype_class_identity)
        negative_one_weights_locations = 1 - positive_one_weights_locations

        correct_class_connection = 1
        incorrect_class_connection = incorrect_strength
        self.last_layer.weight.data.copy_(
            correct_class_connection * positive_one_weights_locations
            + incorrect_class_connection * negative_one_weights_locations)

    def _initialize_weights(self):
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)


def construct_PPNet(base_architecture, pretrained=True, img_size=224,
                    prototype_shape=(2000, 512, 1, 1), num_classes=200,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck'):
    """Build a PPNet on top of the named backbone.

    The backbone factory is looked up in ``base_architecture_to_features``.
    Its ``conv_info()`` is used to compute the receptive-field info of the
    prototype layer, and everything is passed to ``PPNet`` with
    ``init_weights=True``.
    """
    make_backbone = base_architecture_to_features[base_architecture]
    backbone = make_backbone(pretrained=pretrained)

    filter_sizes, strides, paddings = backbone.conv_info()
    rf_info = compute_proto_layer_rf_info_v2(
        img_size=img_size,
        layer_filter_sizes=filter_sizes,
        layer_strides=strides,
        layer_paddings=paddings,
        prototype_kernel_size=prototype_shape[2],
    )

    ppnet_kwargs = dict(
        features=backbone,
        img_size=img_size,
        prototype_shape=prototype_shape,
        proto_layer_rf_info=rf_info,
        num_classes=num_classes,
        init_weights=True,
        prototype_activation_function=prototype_activation_function,
        add_on_layers_type=add_on_layers_type,
    )
    return PPNet(**ppnet_kwargs)
import pdb
class Clip_ProtoPNet(nn.Module):
    """ProtoPNet combined with a bank of pre-encoded (CLIP) concept vectors.

    The prototype-relative image embedding produced by ``protopnet`` is
    projected by an orthonormal matrix ``projector_vision``. Its cosine
    similarity to every concept is summed over prototypes and fed to a linear
    concept classifier. ``projector_vision`` is not trained by the optimizer;
    it is updated by ``update_rotation_matrix`` (Cayley transform) from
    gradients accumulated in ``forward`` when ``mode == 1``.
    """

    def __init__(self, protopnet, num_classes, mode = -1, momentum=0.1, concept_path=None):
        """
        Args:
            protopnet: a PPNet-like module returning
                (logits, min_distances, outputs_t, assign, max_indices).
            num_classes: number of classes of the concept classifier.
            mode: -1 = no manifold alignment; 1 = accumulate alignment gradients.
            momentum: momentum used when accumulating the gradient matrix.
            concept_path: path of the concept tensor (num_concepts, 512);
                defaults to the module-level ``attribute_path``.
        """
        super(Clip_ProtoPNet, self).__init__()
        self.protopnet = protopnet
        proj_size = 512
        if concept_path is None:
            concept_path = attribute_path
        # NOTE: torch.load unpickles the file - only load trusted files.
        # The concept bank is a non-persistent buffer, so .to()/.cuda() move it
        # together with the module without changing the state_dict keys.
        concept_bottleneck = torch.load(concept_path, map_location='cpu').float()
        self.register_buffer('concept_bottleneck', concept_bottleneck, persistent=False)
        # Kept for checkpoint compatibility; not used in forward.
        self.logit_scale = nn.Parameter(torch.Tensor([4.6052]))

        num_concepts = self.concept_bottleneck.size(0)
        # classifier over the concept-similarity scores: num_concepts -> num_classes
        self.classifier_c = nn.Linear(num_concepts, num_classes, bias=False)
        # One-hot concept -> class identity; concepts of a class are contiguous.
        # The per-class count is derived from the data (20 for 4000 concepts /
        # 200 classes) instead of being hard-coded.
        num_concept_per_class = max(1, num_concepts // num_classes)
        concept_ids = torch.arange(num_concepts)
        class_ids = (concept_ids // num_concept_per_class).clamp(max=num_classes - 1)
        self.concept_class_identity = torch.zeros(num_concepts, num_classes)
        self.concept_class_identity[concept_ids, class_ids] = 1

        # manifold alignment state
        self.momentum = momentum
        self.num_channels = proj_size
        # Buffers are saved with the model but not updated by optim.step().
        # Running orthonormal projection (rotation) matrix.
        self.register_buffer('projector_vision', torch.eye(self.num_channels))
        # Momentum-accumulated gradient matrix G.
        self.register_buffer('sum_G_v', torch.zeros(self.num_channels, self.num_channels))
        # Per-row counter used to average G.
        self.register_buffer("counter", torch.ones(self.num_channels)*0.001)

        self.mode = mode
        self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)

    def set_last_layer_incorrect_connection(self, incorrect_strength):
        '''
        the incorrect strength will be actual strength if -0.5 then input -0.5
        '''
        positive_one_weights_locations = torch.t(self.concept_class_identity)
        negative_one_weights_locations = 1 - positive_one_weights_locations

        correct_class_connection = 1
        incorrect_class_connection = incorrect_strength
        self.classifier_c.weight.data.copy_(
            correct_class_connection * positive_one_weights_locations
            + incorrect_class_connection * negative_one_weights_locations)

    def sim(self, A, B):
        """Cosine similarity between every row of A (N, D) and every vector of B (Bt, K, D).

        Returns a (Bt, K, N) tensor.
        """
        A_norm = A / (A.norm(dim=1, keepdim=True) + 1e-8)
        B_norm = B / (B.norm(dim=2, keepdim=True) + 1e-8)
        return torch.matmul(B_norm, A_norm.t())

    def change_mode(self, mode):
        """
        Change the training mode
        mode = -1: do not update the gradient matrix G nor compute the
                   affinity matrix A (faster);
             = 1 : manifold alignment - compute A and accumulate G.
        """
        self.mode = mode

    def cal_affinity(self, x, y, X_p, X_ep):
        """Identity affinity between prototype embeddings and encoded prototypes.

        Args:
            x: input batch (only its batch size is used).
            y: unused; kept for interface compatibility.
            X_p: (B, C, P) prototype-relative image embedding.
            X_ep: (B, C, E) encoded prototypes.

        Returns:
            A: (B, P, E) identity affinity matrix.
            U: A normalised by its total sum per sample.
        """
        batch_size = x.size(0)
        p_size = X_p.size(-1)
        ep_size = X_ep.size(-1)
        A = torch.eye(p_size, ep_size, device=X_p.device).unsqueeze(0).repeat(batch_size, 1, 1)
        n_cs = A.sum(dim=(1, 2), keepdim=True)
        U = A / n_cs
        return A, U

    def update_rotation_matrix(self):
        """
        Update the rotation matrix R using the accumulated gradient G.
        The update uses Cayley transform to make sure R is always orthonormal.
        """
        size_R = self.projector_vision.size()
        with torch.no_grad():
            # gradient of the loss averaged by the per-row counter
            G = self.sum_G_v / self.counter.reshape(-1, 1)
            # R is the current Q(t); Q below produces Q(t+1)
            R = self.projector_vision.clone()
            I = torch.eye(size_R[1], device=R.device, dtype=R.dtype).expand(*size_R)
            for _ in range(2):
                tau = 1000  # learning rate in Cayley transform
                alpha = 0
                beta = 100000000
                # constants of the Armijo-Wolfe conditions used to pick tau
                c1 = 1e-4
                c2 = 0.9
                A = torch.einsum('in,jn->ij', G, R) - torch.einsum('in,jn->ij', R, G) # GR^T - RG^T
                dF_0 = -0.5 * (A ** 2).sum()
                # binary search for appropriate learning rate
                cnt = 0
                while True:
                    Q = torch.mm((I + 0.5 * tau * A).inverse(), I - 0.5 * tau * A)
                    Y_tau = torch.mm(Q, R)  # curve Y(tau) = Q R
                    F_X = (G[:,:] * R[:,:]).sum()  # inner product <G, R>
                    F_Y_tau = (G[:,:] * Y_tau[:,:]).sum()
                    dF_tau = -torch.mm(torch.einsum('ni,nj->ij', G, (I + 0.5 * tau * A).inverse()), torch.mm(A,0.5*(R+Y_tau)))[:,:].trace()
                    if F_Y_tau > F_X + c1*tau*dF_0 + 1e-18:
                        beta = tau
                        tau = (beta+alpha)/2
                    elif dF_tau  + 1e-18 < c2*dF_0:
                        alpha = tau
                        tau = (beta+alpha)/2
                    else:
                        break
                    cnt += 1
                    if cnt > 500:
                        print("--------------------update fail------------------------")
                        print(F_Y_tau, F_X + c1*tau*dF_0)
                        print(dF_tau, c2*dF_0)
                        print("-------------------------------------------------------")
                        break
                print(tau, F_Y_tau)
                # apply the step with the selected tau
                Q = torch.mm((I + 0.5 * tau * A).inverse(), I - 0.5 * tau * A)
                R = torch.mm(Q, R)

            self.projector_vision = R
            # reset on the buffer's own device/dtype (not hard-coded CUDA)
            self.counter = torch.full_like(self.counter, 0.001)

    def forward(self, x, y=None, encoded_prototype=None):
        """
        Args:
            x: input images.
            y: labels (only forwarded to cal_affinity, which ignores them).
            encoded_prototype: (B, E, C) encoded prototypes; used only in
                training mode with ``self.mode == 1``.

        Returns:
            ([logits_ppnet, logits_c], min_distances)
        """
        logits_ppnet, min_distances, outputs_t, assign, max_indices = self.protopnet(x)
        X_p = outputs_t  # (B, C, K) prototype-relative image embedding
        if self.training and encoded_prototype is not None and self.mode == 1:
            # Manifold alignment: accumulate the gradient of
            # tr(2 * P^T X_v U X_l^T) w.r.t. P, i.e. -2 * X_v U X_l^T.
            with torch.no_grad():
                X_ep_t = encoded_prototype.float()  # (B, E, C)
                X_ep = X_ep_t.permute(0, 2, 1)      # (B, C, E)
                _, U_matrix = self.cal_affinity(x, y, X_p, X_ep)
                grad_v_temp1 = torch.einsum("bdv,bvl->bdl", X_p, U_matrix)
                grad_v_temp2 = torch.einsum("bdl,blq->bdq", grad_v_temp1, X_ep_t)  # (B, C, C)
                grad_v = -2 * grad_v_temp2.mean(0)
                self.sum_G_v = self.momentum * grad_v + (1. - self.momentum) * self.sum_G_v
                self.counter += 1
        # Project with the orthonormal matrix: Z_v = P^T X_p. A real matrix
        # product - einsum("dd,...") would only use the diagonal of P.
        Z_v = torch.matmul(self.projector_vision.t(), X_p).permute(0, 2, 1)  # (B, K, C)
        cos_sim = self.sim(self.concept_bottleneck, Z_v)  # (B, K, num_concepts)
        cos_sim = cos_sim.sum(dim=1)
        logits_c = self.classifier_c(cos_sim)  # (B, num_classes)
        return [logits_ppnet, logits_c], min_distances
        
if __name__ == "__main__":
    # Smoke run: build a VGG19 PPNet, wrap it with the concept head and run
    # one training-mode forward pass with random data.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ppnet = construct_PPNet(base_architecture="vgg19")
    ppnet.to(device)
    print(ppnet)

    images = torch.rand(32, 3, 224, 224, device=device)
    y = torch.randint(low=0, high=200, size=(32,), device=device)
    cppnet = Clip_ProtoPNet(ppnet, 200)
    cppnet.to(device)
    cppnet.mode = 1
    encoded_prototype = torch.rand(32, 10, 512, device=device)

    # forward returns ([logits_ppnet, logits_c], min_distances) - a tuple,
    # so report the size of each tensor instead of calling .size() on it.
    (logits_ppnet, logits_c), min_distances = cppnet(images, y, encoded_prototype)
    print(logits_ppnet.size(), logits_c.size(), min_distances.size())