import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data
from sklearn.decomposition import PCA
from pytorch3d.ops import sample_farthest_points

from tmp1.tmp_graph_visualize import visualize_superpoint_graph_with_labels
from model.GNN import SimpleGNN, GATNet
from model.utils import JS2Weight
from model.partgeoze import PartGeoZe
from dataset.PartnetEpc import get_is_seen
from model.ImageEncoder import ImageEncoder
from model.SuperPointAggre import SPAttentionAggregation, SPAttentionDownPropagation
from loss.contrast_loss import proxy_contrastive_loss, construct_positive_negative_samples,ContrastiveLoss


def value_to_color(tensor):
    """Map scalar values to a green-to-red ramp in BGR channel order.

    A value of 0 maps to pure green, 1 maps to pure red, and the blue channel
    is always zero. A trailing channel axis of size 3 ordered (B, G, R) is
    appended to the input shape, which is the layout ``cv2.imwrite`` expects.
    """
    expanded = tensor[..., None]
    return torch.cat(
        [torch.zeros_like(expanded), 1 - expanded, expanded],
        dim=-1,
    )

def get_bbox(img_mask):
    """Return the tight bounding box of the nonzero entries of a mask.

    Rows (dim 0) are treated as y and columns (dim 1) as x.

    Returns:
        ``torch.tensor([x_min, y_min, x_max, y_max])``, or ``None`` when the
        mask has no nonzero entry.
    """
    hits = torch.nonzero(img_mask, as_tuple=True)
    rows, cols = hits[0], hits[1]
    if not len(rows):
        return None
    return torch.tensor([cols.min(), rows.min(), cols.max(), rows.max()])

def knn_points(
    query: torch.Tensor,
    key: torch.Tensor,
    k: int,
    sorted: bool = False,
    transpose: bool = False,
):
    """Find the k nearest key points (Euclidean) for every query point.

    Args:
        query: [B, N1, D] query points ([B, D, N1] when ``transpose`` is True).
        key: [B, N2, D] key points ([B, D, N2] when ``transpose`` is True).
        k: number of neighbours to return per query point.
        sorted: whether the k results are returned in ascending distance order
            (irrelevant for k == 1).
        transpose: swap the last two dims of both inputs before matching.

    Returns:
        (distances, indices), both [B, N1, k]; indices refer to points in ``key``.
    """
    if transpose:
        query, key = query.transpose(1, 2), key.transpose(1, 2)

    pairwise = torch.cdist(query, key)  # [B, N1, N2]

    if k != 1:
        nearest_dist, nearest_idx = torch.topk(
            pairwise, k, dim=2, largest=False, sorted=sorted
        )
    else:
        # min is cheaper than topk for a single neighbour; keepdim preserves the k axis.
        nearest_dist, nearest_idx = torch.min(pairwise, dim=2, keepdim=True)
    return nearest_dist, nearest_idx

def compute_class_weights(labels, num_classes, eps=1.02):
    """Compute ENet-style inverse-log-frequency class weights.

    ``weight_c = 1 / log(freq_c + eps)``, where ``freq_c`` is the fraction of
    labels equal to ``c``. Rare (or absent) classes therefore get large weights,
    bounded by ``1 / log(eps)``.

    Args:
        labels: tensor of non-negative class ids of any shape. Integral-valued
            float tensors are accepted and cast to ``long``, since
            ``torch.bincount`` only accepts integer input.
        num_classes: minimum length of the returned weight vector.
        eps: offset inside the log; must be > 1 so every weight is positive.

    Returns:
        1-D float tensor of per-class weights (length ``max(num_classes,
        labels.max() + 1)``).
    """
    hist = torch.bincount(labels.flatten().long(), minlength=num_classes).float()
    # clamp_min(1) keeps an empty label set from producing 0/0 = NaN; in that
    # case every frequency is 0 and every class gets the maximal weight.
    freq = hist / hist.sum().clamp_min(1)
    return 1.0 / torch.log(freq + eps)

class Embedder:
    """NeRF-style positional encoding built from keyword configuration.

    Required kwargs: ``input_dims``, ``include_input``, ``max_freq_log2``,
    ``num_freqs``, ``log_sampling`` and ``periodic_fns`` (a list of callables
    such as ``[torch.sin, torch.cos]``).

    ``embed(x)`` concatenates, along the last axis: ``x`` itself (when
    ``include_input``), then ``fn(x * f)`` for every frequency ``f`` and, inside
    each frequency, every ``fn`` in ``periodic_fns``. ``out_dim`` is the
    resulting width.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        cfg = self.kwargs
        width = cfg['input_dims']
        top = cfg['max_freq_log2']
        count = cfg['num_freqs']

        fns = [(lambda x: x)] if cfg['include_input'] else []

        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top, steps=count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top, steps=count)

        # Default arguments bind the current fn/frequency to each lambda,
        # avoiding Python's late-binding closure pitfall.
        fns.extend(
            (lambda x, fn=fn, f=f: fn(x * f))
            for f in bands
            for fn in cfg['periodic_fns']
        )

        self.embed_fns = fns
        # Every component keeps the input width.
        self.out_dim = width * len(fns)

    def embed(self, inputs):
        return torch.cat(list(map(lambda fn: fn(inputs), self.embed_fns)), -1)

def get_embedder(multires, input_dims, i=0):
    """Build a sin/cos positional-encoding function.

    Args:
        multires: number of frequency octaves; frequencies are 2^0 .. 2^(multires-1)
            (log-spaced), each encoded with ``sin`` and ``cos``.
        input_dims: size of the last dimension of the inputs to encode.
        i: pass ``-1`` to disable the encoding and get an identity mapping.

    Returns:
        ``(embed_fn, out_dim)`` where ``embed_fn`` maps ``[..., input_dims]`` to
        ``[..., out_dim]``.
    """
    if i == -1:
        # The identity leaves the input untouched, so its width is input_dims
        # (it used to be hard-coded to 3, which is only right for xyz inputs).
        return torch.nn.Identity(), input_dims

    embedder_obj = Embedder(
        include_input=True,
        input_dims=input_dims,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim


class CN_layer(torch.nn.Module):
    """Linear projection, context normalization and ReLU.

    Context normalization standardizes every channel over the set dimension
    (dim 0): ``(h - mean) / (std + eps)`` with the population std
    (``correction=0``). ``eps`` keeps single-row inputs and constant channels
    from producing 0/0 = NaN.

    ``out_channel``, ``num_block`` and ``he_init`` are accepted for signature
    compatibility but are not used.
    """

    def __init__(
        self,
        in_channel=256,
        out_channel=1,
        num_block=3,
        he_init=False,
        eps=1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.layer_in = torch.nn.Linear(in_channel, in_channel)
        self.relu = torch.nn.ReLU()

    def forward(self, feature):
        """feature: [N, in_channel] -> [N, in_channel]."""
        feature_in = self.layer_in(feature)
        feature_mean = torch.mean(feature_in, dim=0, keepdim=True)  # [1, in_channel]
        feature_std = torch.std(feature_in, dim=0, keepdim=True, correction=0)  # [1, in_channel]
        cn_feature = (feature_in - feature_mean) / (feature_std + self.eps)
        return self.relu(cn_feature)


class WeightPredNetworkCNe(torch.nn.Module):
    """Stack of context-normalization layers followed by a linear head.

    With ``skip_connection=True`` each CN layer's output is added to its input
    (residual). With ``skip_connection=False`` the layers are chained directly.
    Previously their output was discarded in that mode, which made the stack a
    no-op.
    """

    def __init__(
        self,
        in_channel=256,
        out_channel=1,
        num_cn_layer=1,
        he_init=False,
        skip_connection=True,
    ):
        super().__init__()
        self.skip_connection = skip_connection

        self.CN_layers = torch.nn.ModuleList([CN_layer(in_channel) for _ in range(num_cn_layer)])

        self.layer_out = torch.nn.Linear(in_channel, out_channel)

        # Applied recursively to every Linear, including those inside the CN layers.
        self.apply(self._init_weights_he if he_init else self._init_weights)

    def _init_weights(self, module):
        """Tiny-normal weights (std 1e-4) and zero bias, so the head starts near 0."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.0001)
            if module.bias is not None:
                module.bias.data.zero_()

    def _init_weights_he(self, module):
        """Kaiming-normal weights and zero bias."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.kaiming_normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, feature):
        """
            feature : [N, in_channel] -> [N, out_channel]
        """
        feature_in = feature
        for layer in self.CN_layers:
            cn_feature = layer(feature_in)
            feature_in = feature_in + cn_feature if self.skip_connection else cn_feature
        return self.layer_out(feature_in)

def check(mask_weight, mask_label, nview, pc_idx, shape_idx):
    """Debug helper: save per-view heat-maps of per-mask weights as PNG files.

    Weights are min-max normalised to [0, 1]. They are then painted onto the
    pixels of each non-empty mask in view-major, label-ascending order, which
    is the order ``SegmentorNew.get_weighted_mask`` produces them in. Pixels
    with ``pc_idx == -1`` are zeroed. One image per view is written to
    ``./output/mask_weight2/check{shape_idx}/{view}.png``.
    """
    tmp_weight = mask_weight.clone()
    mn = torch.min(tmp_weight)
    mx = torch.max(tmp_weight)
    span = mx - mn
    if span > 0:
        tmp_weight = (tmp_weight - mn) / span
    else:
        # All weights equal: avoid 0/0 = NaN and render them as 0.
        tmp_weight = torch.zeros_like(tmp_weight)
    print(f"mx = {mx}, mn = {mn}, mean = {tmp_weight.mean()}")
    iidd = 0  # running index into tmp_weight across all views
    save_dir = f"./output/mask_weight2/check{shape_idx}"
    os.makedirs(save_dir, exist_ok=True)
    for view in range(nview):
        mask_label_tmp = torch.zeros_like(mask_label[0]).float()
        for i in range(mask_label.max() + 1):
            img_ind = mask_label[view] == i
            if img_ind.sum() != 0:
                mask_label_tmp[img_ind] = tmp_weight[iidd].item()
                iidd += 1
        mask_label_tmp[pc_idx[view] == -1] = 0
        rgb = value_to_color(mask_label_tmp).detach().cpu().numpy()
        rgb = (rgb * 255).astype(np.uint8)
        cv2.imwrite(f"{save_dir}/{view}.png", rgb)
    print("good")
                    
def pairwise_js_divergence(A, eps=1e-10):
    """Jensen-Shannon divergence between every pair of rows of ``A``.

    Each row is shifted by ``eps`` (so no log(0) occurs) and renormalised to
    sum to 1 before comparison. Natural logarithms are used, so values lie in
    ``[0, log 2]``.

    Args:
        A: [N, M] tensor, one (possibly unnormalised) distribution per row.

    Returns:
        [N, N] symmetric tensor with ``JS(A_i, A_j)`` at ``[i, j]``.
    """
    probs = A + eps
    probs = probs / probs.sum(dim=1, keepdim=True)
    log_probs = probs.log()

    # Broadcast rows against each other: p is (N, 1, M), q is (1, N, M).
    p, q = probs[:, None, :], probs[None, :, :]
    log_p, log_q = log_probs[:, None, :], log_probs[None, :, :]

    mix = (p + q) * 0.5
    log_mix = mix.log()

    kl_p_mix = (p * (log_p - log_mix)).sum(dim=2)
    kl_q_mix = (q * (log_q - log_mix)).sum(dim=2)
    return (kl_p_mix + kl_q_mix) * 0.5
    
def graph_visualize(mask_num, mask_label, id2mask, edges):
    """Debug helper: render the mask graph for every view to PNG files.

    Each mask node ``i`` (``id2mask[i] == (view, label)``) is painted with its
    node id into a per-view id image. The smallest and largest node id seen in
    each view are passed to the renderer alongside that image. One file per view
    is written to ``./output/graph/graph_visualize_{view}.png``.

    The number of views is taken from ``mask_label.shape[0]``. It used to be
    hard-coded to 10, which crashed or skipped views for other view counts.
    """
    n_view = mask_label.shape[0]
    label_new = torch.zeros_like(mask_label)
    min_lb = (torch.ones(n_view) * 10000).long()
    max_lb = torch.zeros(n_view).long()
    for i in range(mask_num):
        view, label = id2mask[i]
        view, label = view.item(), label.item()
        ind = mask_label[view] == label
        label_new[view][ind] = i
        min_lb[view] = min(min_lb[view], i)
        max_lb[view] = max(max_lb[view], i)
    os.makedirs("./output/graph", exist_ok=True)
    for view in range(n_view):
        visualize_superpoint_graph_with_labels(
            label_new[view],
            edges,
            min_lb[view],
            max_lb[view],
            save_path=f"./output/graph/graph_visualize_{view}.png",
        )
        
    

class SegmentorNew(torch.nn.Module):
    def __init__(self, num_labels, args):
        """Build the segmentation head.

        Args:
            num_labels: number of part classes predicted per point.
            args: config namespace. Each ``args.<flag>`` read below is copied
                onto ``self`` and selects the feature / aggregation /
                propagation / loss variants used in ``forward``.

        Raises:
            NotImplementedError: if ``args.use_3d_feat`` is set. The point-cloud
                encoder it needs (``self.pc_encoder``) is not constructed.

        NOTE(review): several attributes used by other methods are never
        created here: ``embed_fn`` / ``weight_pred`` (``get_weighted_*``),
        ``propagater`` (first ``propagate_ori``) and ``calc_weight``
        (``propagate`` with ``use_gnn``). Those code paths fail unless the
        attributes are set elsewhere.
        NOTE(review): ``use_fpfh`` adds 33 to ``pc_feat_dim`` (and therefore to
        the classifier input), while ``forward`` concatenates FPFH only after
        classification. Confirm the intended feature width.
        """
        super().__init__()

        self.num_labels = num_labels
        self.use_attn_map = args.use_attn_map
        self.use_front_attn_map = args.use_front_attn_map
        self.ave_per_mask = args.ave_per_mask
        self.use_2d_feat = args.use_2d_feat
        self.use_3d_feat = args.use_3d_feat
        self.sam_mask_weight = args.sam_mask_weight
        self.weight_ave_per_mask = args.weight_ave_per_mask
        self.use_propagate = args.use_propagate
        self.use_contrast_loss2 = args.use_contrast_loss2
        self.use_proxy_contrast_loss = args.use_proxy_contrast_loss
        self.use_mask_consist_loss = args.use_mask_consist_loss
        self.eliminate_sparseness = args.eliminate_sparseness
        self.use_attn_ave = args.use_attn_ave
        self.use_slow_start = args.use_slow_start
        self.use_gnn = args.use_gnn
        self.use_new_classifier = args.use_new_classifier
        self.use_js2weight = args.use_js2weight
        self.random_mask = args.random_mask

        self.use_fpfh = args.use_fpfh
        self.use_3d_mask = args.use_3d_mask
        self.use_ref_loss = args.use_ref_loss

        # Read from the working directory; indexed as self.cam_pos[view] by the
        # get_weighted_* methods (presumably one camera position per view).
        self.cam_pos = np.load("view.npy")

        self.pc_feat_dim = 0
        if self.use_2d_feat:
            self.img_encoder = ImageEncoder(args.img_encoder, args.use_cache)
            self.pc_feat_dim += self.img_encoder.out_dim
        if self.use_3d_feat:
            # The point-cloud encoder (Point_M2AE_ReductionD) is disabled; fail
            # with an explicit message instead of an opaque AttributeError.
            pc_encoder = getattr(self, "pc_encoder", None)
            if pc_encoder is None:
                raise NotImplementedError(
                    "use_3d_feat requires a point-cloud encoder, but "
                    "self.pc_encoder is not constructed"
                )
            self.pc_feat_dim += pc_encoder.out_dim
        if self.use_fpfh:
            self.pc_feat_dim += 33  # width of the FPFH descriptor
        if self.use_contrast_loss2:
            self.contrastive_loss = ContrastiveLoss(temperature=0.5)
        if self.use_attn_ave:
            self.sp_aggre = SPAttentionAggregation(self.pc_feat_dim)
            self.sp_down = SPAttentionDownPropagation(self.pc_feat_dim)
            self.self_attn = torch.nn.MultiheadAttention(embed_dim=self.pc_feat_dim, num_heads=1, batch_first=True)
        if self.use_gnn:
            self.gnn = GATNet(self.pc_feat_dim, self.pc_feat_dim*2, self.pc_feat_dim, heads=8)

        # Maps a pair of per-mask class distributions to an edge weight
        # (used by propagate_pseudolabel when use_js2weight is set).
        self.js2weight = JS2Weight(in_dim=num_labels, hidden_dim=128, layers_num=2)

        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.classifier = torch.nn.Linear(self.pc_feat_dim, num_labels)
        # Copy of `classifier` that produces pseudo labels after the slow start.
        self.classifier_copy = torch.nn.Linear(self.pc_feat_dim, num_labels)
        self.classifier_pseudo = torch.nn.Linear(self.pc_feat_dim, num_labels)
        # Bottle Cart Refrigerator Laptop

        
    def get_weighted_imgfeat(self, nview, device, dtype, mask_label, img_feat):
        """Scale, in place, the image features of every (view, mask) region by a learned weight.

        Each non-empty region is described by the channel-wise mean and max of
        its pixel features and by ``embed_fn`` of ``[bbox / 800, cam_pos[view]]``.
        ``relu(10 + weight_pred(descriptor))`` then multiplies that region's
        pixels in ``img_feat``. Nothing is returned; ``img_feat`` is modified.
        """
        n_labels = int(mask_label.max()) + 1
        descriptors, regions = [], []
        for v in range(nview):
            cam = torch.from_numpy(self.cam_pos[v]).to(device, dtype)
            for lbl in range(n_labels):
                region = mask_label[v] == lbl
                if not region.any():
                    continue
                # bbox normalised by 800 (presumably the image size — TODO confirm).
                box = get_bbox(region).to(device, dtype) / 800
                picked = img_feat[v, :, region]
                descriptors.append(torch.cat([
                    picked.mean(dim=1),
                    picked.max(dim=1).values,
                    self.embed_fn(torch.cat([box, cam])),
                ]))
                regions.append((v, lbl))

        weights = self.relu(10 + self.weight_pred(torch.stack(descriptors, dim=0)))
        for w, (v, lbl) in zip(weights, regions):
            img_feat[v, :, mask_label[v] == lbl] *= w
    
    def get_weighted_imgfeat2(self, nview, device, dtype, mask_label, img_feat):
        """Scale each (view, mask) region of ``img_feat`` by a learned weight and return it.

        Same weighting as ``get_weighted_imgfeat``: the descriptor is the mean and
        max of the region's pixel features plus ``embed_fn([bbox / 800,
        cam_pos[view]])``, and the weight is ``relu(10 + weight_pred(...))``.
        The camera table is converted once up front. ``img_feat`` is modified in
        place and also returned.
        """
        cam_all = torch.from_numpy(self.cam_pos).to(device, dtype)
        label_range = torch.arange(mask_label.max() + 1, device=device)

        descriptors, regions = [], []
        for v in range(nview):
            labels_v, feats_v = mask_label[v], img_feat[v]
            for lbl in label_range:
                region = labels_v == lbl
                if not region.any():
                    continue
                box = get_bbox(region).to(device, dtype) / 800
                embedded = self.embed_fn(torch.cat([box, cam_all[v]]))
                picked = feats_v[:, region]
                descriptors.append(
                    torch.cat([picked.mean(dim=1), picked.max(dim=1)[0], embedded])
                )
                regions.append((v, lbl))

        weights = self.relu(10 + self.weight_pred(torch.stack(descriptors, dim=0)))

        index = torch.tensor(regions, dtype=torch.long, device=device)
        view_ids, label_ids = index[:, 0], index[:, 1]
        for k in range(len(index)):
            img_feat[view_ids[k], :, mask_label[view_ids[k]] == label_ids[k]] *= weights[k]
        return img_feat
    
    def get_weighted_mask(self, nview, device, dtype, mask_label, img_feat, pc_idx, pc_norm):
        masks = []
        mask_feats = []
        cam_pos = torch.from_numpy(self.cam_pos).to(device, dtype)
        unique_labels = torch.arange(mask_label.max() + 1, device=device)
        
        for view in range(nview):
            img_label_view = mask_label[view]  
            img_feat_view = img_feat[view]

            for i in unique_labels:
                img_ind = img_label_view == i  

                if img_ind.any(): 
                    bbox = get_bbox(img_ind).to(device, dtype) / 800
                    pc_ind = pc_idx[view, img_ind]
                    pc_norm_ave = pc_norm[pc_ind].mean(dim=0)
                    pos = torch.cat([pc_norm_ave, cam_pos[view]]) 
                    pos_embed = self.embed_fn(pos)

                    img_feat_mask = img_feat_view[:, img_ind]
                    ave_feat = img_feat_mask.mean(dim=1)
                    max_feat = img_feat_mask.max(dim=1)[0]

                    mask_feat = torch.cat([ave_feat, max_feat, pos_embed])
                    mask_feats.append(mask_feat)
                    masks.append((view, i))

        mask_feats = torch.stack(mask_feats, dim=0)
        mask_weight = self.weight_pred(mask_feats)
        mask_weight = self.relu(10 + mask_weight)
        # mask_weight = self.sigmoid(mask_weight)
        masks = torch.tensor(masks, dtype=torch.long, device=device)
        return mask_weight, masks
             
    def aggregate(self, npoint, img_feat, pc_idx, coords, mask_label, epoch):
        """Lift per-pixel image features onto points, averaging over seeing views.

        Each point's feature is sampled in every view at its pixel
        ``coords`` (x, y). The samples are averaged over the views in which
        ``get_is_seen`` marks the point visible (``+1e-6`` avoids 0/0 for unseen
        points). With ``random_mask`` only about 10% of the visible samples are
        kept. ``epoch`` is unused.

        Returns:
            [npoint, C] per-point features.
        """
        nview = pc_idx.shape[0]
        device, dtype = img_feat.device, img_feat.dtype

        if self.sam_mask_weight:
            img_feat = self.get_weighted_imgfeat2(nview, device, dtype, mask_label, img_feat)

        # View id of every (view, point) sample: [0]*npoint + [1]*npoint + ...
        batch_ids = torch.arange(0, nview)[:, None].repeat_interleave(npoint).reshape(-1).long()
        xy = coords.reshape(nview, -1, 2)
        cols = xy[..., 0].long().reshape(-1)
        rows = xy[..., 1].long().reshape(-1)
        sampled = img_feat[batch_ids, :, rows, cols].view(nview, npoint, -1)

        visible = get_is_seen(pc_idx, npoint).to(device)
        if self.random_mask:
            hit = visible == 1
            keep = torch.rand(hit.sum()).to(device) > 0.9
            visible[hit] = visible[hit] * keep

        numer = (sampled * visible[:, :, None]).sum(dim=0)
        denom = visible.sum(dim=0)[:, None] + 1e-6
        return numer / denom
    
    def propagate_ori(self, pc, pc_feat, pc_norm, pc_fpfh, text_feat=None):
        """Propagate the features of a single point cloud (no batch) with ``self.propagater``.

        NOTE(review): this definition is shadowed by the later
        ``propagate_ori(self, pc_feat, mask_label, pc_idx, graph)`` in this
        class, so it is never reachable as ``self.propagate_ori``. It also relies
        on ``self.propagater``, whose construction is commented out in
        ``__init__``.

        Args:
            pc: [N, 3] point coordinates.
            pc_feat: per-point features (assumed [N, D]); L2-normalised along
                the last dim before being passed on.
            pc_norm, pc_fpfh: per-point tensors (presumably normals and FPFH
                descriptors — TODO confirm); given a leading batch dim of 1.
            text_feat: optional, forwarded unchanged to ``self.propagater``.

        Returns:
            Batch element 0 of the first output of ``self.propagater``.
        """
        input_pc_feat = pc_feat / torch.norm(pc_feat, dim=-1, keepdim=True)
        feats, node_feats, idx = self.propagater(pc[None,:], pc_norm[None,:], pc_fpfh[None, :], input_pc_feat[None,:], text_feat)
        # Disabled alternative: FPS centroids + kNN mean pooling and un-pooling.
        # N, D = pc_feat.shape
        # num_groups = 2048
        # pc = pc.unsqueeze(dim=0)
        # centroids, fps_ids = sample_farthest_points(pc, K=num_groups)
        # _, knn_idx = knn_points(centroids, pc, N//num_groups)
        # knn_feat = pc_feat[knn_idx[0]]
        # knn_mean_feat = knn_feat.mean(dim=1)

        # _,knn_idx = knn_points(pc, centroids, 5)
        # knn_feat = knn_mean_feat[knn_idx[0]]
        # pc_feat = knn_feat.mean(dim=1)

        return feats[0]

    def get_pc_ind(self, view, label, mask_label, pc_idx, coords, nearest_index, grouped_indices):
            n_point = pc_idx.max()+3
            device = mask_label.device
            img_ind = mask_label[view]==label
            pc_ind = pc_idx[view, img_ind]
            pc_ind = pc_ind[pc_ind!=-1]
            pc_seen = torch.zeros(n_point).long().to(device)
            pc_seen[pc_ind]=1
            if pc_ind.numel() ==0:
                return torch.empty((0))
            centers = nearest_index[pc_ind]
            centers = torch.unique(centers)
            ret_ind = []
            for i in range(centers.shape[0]):
                g_ind = grouped_indices[i]
                cnt_seen = pc_seen[g_ind].sum()
                if cnt_seen/g_ind.shape[0]<0.4:
                    ret_ind.append(g_ind[pc_seen[g_ind]==1])
                else:
                    ret_coords = coords[view, g_ind].long()
                    in_mask = img_ind[ret_coords[:,1], ret_coords[:,0]]
                    ret_ind.append(g_ind[in_mask])
            # ret_ind = torch.cat(ret_ind)
            return torch.cat(ret_ind)

    def propagate_ori(self, pc_feat,mask_label,pc_idx, graph):
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(pc_feat.device)
        if self.ave_per_mask:
            edges = torch.eye(edges.shape[0]).to(edges.dtype).to(pc_feat.device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        # w_ind = (edges!=0)&(edges!=1)
        # edges[w_ind] = 1/(edges[w_ind]+1e-6)
        # edges[(edges!=0)&(edges!=1)] *= 4 
        n_point = pc_feat.shape[0]
        
        ave_feat = []
        mask_num = max(mask2id.values())+1
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            img_ind = mask_label[view]==label
            pc_ind = pc_idx[view, img_ind]
            pc_ind = pc_ind[pc_ind!=-1]
            if pc_ind.numel()>0: 
                ave_feat.append(pc_feat[pc_ind].mean(dim=0))
            else:
                ave_feat.append(torch.zeros_like(pc_feat[0]))
        ave_feat = torch.stack(ave_feat, dim=0)
        
        prop_feat = (ave_feat.T@edges).T/edges.sum(dim=0)[:,None]
        
        # prop_feat = (ave_feat.T@torch.linalg.matrix_power(edges, 5)).T
        
        cnt = torch.ones(n_point).float().to(pc_feat.device)
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            img_ind = mask_label[view]==label
            pc_ind = pc_idx[view, img_ind]
            pc_ind = pc_ind[pc_ind!=-1]
            if pc_ind.numel()>0:
                pc_feat[pc_ind]+=prop_feat[i]
                cnt[pc_ind]+=1
        pc_feat /= cnt[:,None]
        return pc_feat

    def propagate_eli(self, pc_feat, graph):
        device = pc_feat.device
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
        edge_index = graph["edge_index"].squeeze(dim=0).to(device)
        n_point = pc_feat.shape[0]
        
        ave_feat = []
        mask_num = max(mask2id.values())+1
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0: 
                ave_feat.append(pc_feat[pc_ind].mean(dim=0))
            else:
                ave_feat.append(torch.zeros_like(pc_feat[0]))
        ave_feat = torch.stack(ave_feat, dim=0)
        
        if self.ave_per_mask:
            prop_feat = ave_feat
        elif self.use_gnn:
            # print("useGNNNNNNNNN")
            prop_feat = self.gnn(ave_feat, edge_index)+ave_feat
        else:
            prop_feat = (ave_feat.T@edges).T/(edges.sum(dim=0)[:,None]+1e-6)
        
        # prop_feat = (ave_feat.T@torch.linalg.matrix_power(edges, 5)).T
        
        cnt = torch.ones(n_point).float().to(pc_feat.device)
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0:
                pc_feat[pc_ind]+=prop_feat[i]
                cnt[pc_ind]+=1
        pc_feat /= cnt[:,None]
        return pc_feat
    
    def propagate(self, pc_feat, graph):
        """Smooth point features through the mask graph (optionally attention-pooled).

        Node ``i`` covers the points ``graph["mask_pc_ind"][i]``. Its feature is
        computed in one of two ways:

        - default: the mean of those points' features;
        - ``use_attn_ave``: the mean of ``self.self_attn`` over at most
          ``2048*4`` randomly sampled points of the node.

        Node features are mixed as follows:

        - ``ave_per_mask``: left unchanged;
        - ``use_gnn``: along learned edge weights ``self.calc_weight`` on
          concatenated endpoint features, plus unit self-loops;
        - otherwise: along the column-normalised ``edges``.

        Each point becomes the average of its own feature and the mixed features
        of every node covering it. ``pc_feat`` is updated in place and returned.

        NOTE(review): ``self.calc_weight`` is commented out in ``__init__``, so the
        ``use_gnn`` branch raises AttributeError unless it is created elsewhere.
        NOTE(review): the column normalisations here have no epsilon, unlike
        ``propagate_eli``; a zero column sum yields inf/NaN.
        """
        device = pc_feat.device
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
        n_point = pc_feat.shape[0]
        
        # if self.ave_per_mask:
        #     edges = torch.eye(edges.shape[0]).to(edges.dtype).to(device)
            
        ave_feat = []
        # Requires mask2id values to be tensors (.item()).
        mask_num = max(mask2id.values()).item()+1
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0: 
                if not self.use_attn_ave:
                    tmp_feat = pc_feat[pc_ind].mean(dim=0)
                else:
                    # Cap the attention sequence length to bound memory; larger
                    # masks are randomly subsampled (consumes the global RNG).
                    num_limit = 2048*4
                    if pc_ind.numel()<=num_limit:
                        tmp_feat = pc_feat[pc_ind].unsqueeze(dim=0)
                    else:
                        pc_ind = pc_ind[torch.randperm(pc_ind.numel())[:num_limit]]
                        tmp_feat = pc_feat[pc_ind].unsqueeze(dim=0)
                    attn_feat, _ = self.self_attn(tmp_feat,tmp_feat,tmp_feat)
                    # tmp_ave_feat = pc_feat[pc_ind].mean(dim=0)
                    # tmp_feat = self.sp_aggre(pc_feat[pc_ind], tmp_ave_feat)
                    tmp_feat = attn_feat.squeeze(dim=0).mean(dim=0)
                ave_feat.append(tmp_feat)
            else:
                ave_feat.append(torch.zeros_like(pc_feat[0]))
        ave_feat = torch.stack(ave_feat, dim=0)
        
        if self.ave_per_mask:
            prop_feat = ave_feat
        elif self.use_gnn:
            edge_index = graph["edge_index"].squeeze(dim=0).to(device)
            # new_edges = torch.zeros_like(edges)
            # for i in range(edge_index.shape[1]):
            #     u,v = edge_index[:,i]
            #     new_edges[u,v] = self.calc_weight(torch.cat([ave_feat[u], ave_feat[v]], dim=-1))
            # for i in range(mask_num):
            #     new_edges[i,i]=1
                
            u, v = edge_index
            edge_features = torch.cat([ave_feat[u], ave_feat[v]], dim=-1)  # concatenate the two endpoint features
            weights = self.calc_weight(edge_features)  # one learned weight per edge

            # Build the dense weighted adjacency new_edges.
            new_edges = torch.zeros((ave_feat.shape[0], ave_feat.shape[0]), device=device)
            new_edges[u, v] = weights.squeeze(dim=-1)  # scatter the edge weights

            # Add unit self-loops on the diagonal.
            new_edges[:mask_num, :mask_num] += torch.eye(mask_num, device=device)
            prop_feat = (ave_feat.T@new_edges).T/new_edges.sum(dim=0)[:,None]
                
            # prop_feat = self.gnn(ave_feat, edge_index)
        else:
            prop_feat = (ave_feat.T@edges).T/edges.sum(dim=0)[:,None]
            
        # Every point starts with weight 1 for its own feature.
        cnt = torch.ones(n_point).float().to(pc_feat.device)
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0:
                # if not self.use_attn_ave:
                pc_feat[pc_ind]+=prop_feat[i]
                cnt[pc_ind]+=1
                # else:
                #     _, tmp_feat = self.sp_down(pc_feat[pc_ind], prop_feat[i])
                #     pc_feat[pc_ind]+=tmp_feat
                #     cnt[pc_ind]+=1
                    
        pc_feat /= cnt[:,None]
        return pc_feat

    def propagate_pseudolabel(self, pc_feat, graph, pred, mask_label):
        device = pc_feat.device
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
        n_point = pc_feat.shape[0]
        mask_num = max(mask2id.values()).item()+1
        edge_index = graph["edge_index"].squeeze(dim=0).to(device)
        
        p_pred = torch.zeros((mask_num, self.num_labels)).to(device).to(pc_feat.dtype)
        for i in range(mask_num):
            pc_ind = mask_pc_ind[i]
            pc_cnt = pc_ind.numel()
            if pc_cnt>0:
                pred_mask = pred[pc_ind]
                for j in range(self.num_labels):
                    p_pred[i,j] = (pred_mask==j).sum()
        p_pred = p_pred  / (p_pred.sum(dim=-1, keepdim=True)+1e-6)
        p_pred_ori = p_pred.clone()
        # p_pred = torch.softmax(p_pred, dim=-1)
        

        ave_feat = []
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0: 
                tmp_feat = pc_feat[pc_ind].mean(dim=0)
                ave_feat.append(tmp_feat.squeeze(dim=0))
            else:
                ave_feat.append(torch.zeros_like(pc_feat[0]))
        ave_feat = torch.stack(ave_feat, dim=0)
        
        if not self.use_js2weight:
            js = pairwise_js_divergence(p_pred)
            js_upbd = torch.log(torch.tensor(2))
            js = js/js_upbd
            edges = edges * (1-js) 
            graph_visualize(mask_num, mask_label, id2mask, edges)
            prop_feat = (ave_feat.T@edges).T/edges.sum(dim=0)[:,None]
        else:
            u, v = edge_index
            pred_distribution_a = p_pred[u]
            pred_distribution_b = p_pred[v]
            weights = self.js2weight(pred_distribution_a, pred_distribution_b)
            new_edges = torch.zeros((ave_feat.shape[0], ave_feat.shape[0]), device=device)
            new_edges[u, v] = weights.squeeze(dim=-1)
            prop_feat = (ave_feat.T@new_edges).T/new_edges.sum(dim=0)[:,None]
            
        
        cnt = torch.ones(n_point).float().to(pc_feat.device)
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0:
                pc_feat[pc_ind]+=prop_feat[i]
                cnt[pc_ind]+=1
                    
        pc_feat /= cnt[:,None]
        return pc_feat
        

    def encode_pc(self, pc, pc_label,img, mask_label, pc_idx, coords, pc_fpfh, pc_sampled_idx, args, epoch):
        """Lift image features to per-point features without running the head.

        Returns the output of ``aggregate``: per-point image features averaged
        over the views in which each point is seen.

        NOTE(review): ``forward`` calls ``self.img_encoder(pc_id, img)`` and
        unpacks ``(img_feat, loss_ref)``, while this method calls
        ``self.img_encoder(img)`` and uses the result directly as the feature
        map. One of the two call sites looks stale; confirm against
        ``ImageEncoder.forward``. ``pc_label``, ``pc_fpfh``, ``pc_sampled_idx``
        and ``args`` are unused here.
        """
        num_view = pc_idx.shape[0]
        n_point = pc.shape[0]
        img_feat = self.img_encoder(img)
        pc_feat = self.aggregate(n_point, img_feat, pc_idx, coords, mask_label, epoch)
        return pc_feat
        
    def forward(self, pc_id, pc, pc_label,img, mask_label, pc_idx, coords, pc_fpfh, pc_norm, args, epoch, shape_idx=0):
        """Predict per-point part logits and, when labels are given, the loss.

        Pipeline:

        1. Encode the images.
        2. Lift pixel features to points (``aggregate``).
        3. Optionally average or weight features per mask.
        4. Optionally propagate along the mask graph, selected by
           ``self.use_propagate``: 0 = none, 1 = ``propagate_eli``,
           2 = ``propagate`` / pseudo-label propagation after ``use_slow_start``
           epochs.
        5. Apply the linear classifier.

        NOTE(review): several arguments are used differently from what their
        names suggest; confirm against the caller.

        - ``pc_fpfh`` is passed to the ``propagate*`` methods as the mask
          ``graph`` dict, yet it is concatenated as a feature when ``use_fpfh``
          is set.
        - ``pc_norm`` serves as per-point pseudo labels (``-1`` = ignore) for the
          cross-entropy loss, yet ``get_weighted_mask`` treats it as normals.

        Returns:
            ``(logits, loss, pc_label)``: logits for all points, the loss
            (``None`` when ``pc_norm`` is ``None``) and ``pc_label`` unchanged.

        Raises:
            ValueError: if ``self.use_propagate`` is not 0, 1 or 2. This used to
                print a message and call ``exit(0)``, reporting success.
        """
        num_view = pc_idx.shape[0]
        n_point = pc.shape[0]
        img_feat, loss_ref = self.img_encoder(pc_id, img)
        pc_feat = self.aggregate(n_point, img_feat, pc_idx, coords, mask_label, epoch)
        
        if self.ave_per_mask and (not self.use_propagate):
            if not self.use_3d_mask:
                # Blend each 2-D mask's mean point feature into every covered
                # point; each point also counts its own feature once.
                ave_cnt = torch.ones(n_point).float().to(pc_feat.device)
                pc_feat_ave = pc_feat.clone()
                for view in range(num_view):
                    for i in range(mask_label.max()+1):
                        img_ind = mask_label[view]==i
                        pc_ind = pc_idx[view,img_ind]
                        pc_ind = pc_ind[pc_ind!=-1]
                        if pc_ind.numel() > 0:
                            ave_feat = pc_feat[pc_ind].mean(dim=0)
                            pc_feat_ave[pc_ind] += ave_feat
                            ave_cnt[pc_ind] += 1
                pc_feat = pc_feat_ave/ ave_cnt[:, None]
            else:
                # Here mask_label is iterated as a sequence of point-index masks
                # (presumably 3-D masks) — TODO confirm its format with the caller.
                ave_cnt = torch.ones(n_point).float().to(pc_feat.device)
                pc_feat_ave = pc_feat.clone()
                for mask in mask_label:
                    ave_feat = pc_feat[mask].mean(dim=0)
                    pc_feat_ave[mask] += ave_feat
                    ave_cnt[mask] += 1
                pc_feat = pc_feat_ave/ ave_cnt[:, None]
                
        if self.weight_ave_per_mask:
            # Same blending as above, but each mask contributes with a learned weight.
            pc_feat_ave = pc_feat.clone()
            mask_weight, masks = self.get_weighted_mask(num_view, pc_feat.device, pc_feat.dtype, mask_label, img_feat, pc_idx, pc_norm)
            ave_cnt = torch.ones(n_point).float().to(pc_feat.device)
            for idx in range(len(masks)):
                view = masks[idx,0]
                i = masks[idx,1]
                img_ind = mask_label[view]==i
                pc_ind = pc_idx[view,img_ind]
                pc_ind = pc_ind[pc_ind!=-1]
                if pc_ind.numel() > 0:
                    ave_feat = pc_feat[pc_ind].mean(dim=0)
                    pc_feat_ave[pc_ind] += ave_feat*mask_weight[idx]
                    ave_cnt[pc_ind] += mask_weight[idx]
            pc_feat = pc_feat_ave/ ave_cnt[:, None]
        
        if self.use_propagate == 1:
            pc_feat = self.propagate_eli(pc_feat, pc_fpfh)
            logits = self.classifier(pc_feat)
        elif self.use_propagate == 2:
            if epoch > self.use_slow_start:
                # On the first epoch after the warm-up, copy the classifier;
                # the copy then produces the pseudo labels without gradients.
                if epoch == self.use_slow_start+1:
                    self.classifier_copy.load_state_dict(self.classifier.state_dict())
                    
                with torch.no_grad():
                    pc_feat_tmp = self.propagate(pc_feat.clone(), pc_fpfh)
                    logits = self.classifier_copy(pc_feat_tmp)
                    pred = logits.argmax(dim=1)
                    
                pc_feat = self.propagate_pseudolabel(pc_feat,pc_fpfh, pred, mask_label)
                
                if self.use_new_classifier:
                    logits = self.classifier_pseudo(pc_feat)
                else:
                    logits = self.classifier(pc_feat)    
            else:
                pc_feat = self.propagate(pc_feat, pc_fpfh)
                logits = self.classifier(pc_feat)
        elif self.use_propagate == 0:
            logits = self.classifier(pc_feat)
        else:
            raise ValueError(
                f"invalid use_propagate value: {self.use_propagate!r} (expected 0, 1 or 2)"
            )
            
        if self.use_fpfh:
            pc_feat = torch.cat([pc_feat, pc_fpfh], dim=-1)
        
        n_label = pc_label
        n_logits = logits
        loss = None
        pc_label_pseudo = pc_norm
        if pc_label_pseudo is not None:
            # Points labelled -1 are excluded from the supervised loss.
            valid_ind = pc_label_pseudo!=-1
            pc_label_pseudo = pc_label_pseudo[valid_ind]
            logits = logits[valid_ind]
            weight = compute_class_weights(pc_label_pseudo, self.num_labels)
            loss_ce = torch.nn.functional.cross_entropy(logits, pc_label_pseudo.long(), weight=weight, reduction="none")

            loss_ce = loss_ce.mean()
            loss = loss_ce
            
            if self.use_proxy_contrast_loss:
                loss_contrast = proxy_contrastive_loss(pc_feat, pc_label, self.classifier)
                loss += loss_contrast
            if self.use_contrast_loss2:
                positive_pairs, negative_pairs = construct_positive_negative_samples(pc_feat.unsqueeze(dim=0), pc_label.unsqueeze(dim=0), sample_num=1024*16)
                if positive_pairs.shape[0] > 0 and negative_pairs.shape[0] > 0:
                    loss_contrast = self.contrastive_loss(positive_pairs, negative_pairs)
                    loss += loss_contrast
            if self.use_mask_consist_loss:
                # Mean L2 distance of each mask's point features to their average.
                loss_mask_consist = []
                for view in range(num_view):
                    for i in range(mask_label.max()+1):
                        img_ind = mask_label[view]==i
                        pc_ind = pc_idx[view,img_ind]
                        pc_ind = pc_ind[pc_ind!=-1]
                        if pc_ind.numel() > 0:
                            ave_feat = pc_feat[pc_ind].mean(dim=0)
                            loss_tmp = torch.norm(pc_feat[pc_ind]-ave_feat[None,:], dim=-1, p=2)
                            loss_mask_consist.append(loss_tmp.mean())
                # Skip when no mask covers any point (avoids division by zero).
                if loss_mask_consist:
                    loss += sum(loss_mask_consist)/len(loss_mask_consist)
            if self.use_ref_loss:
                loss += loss_ref
        return n_logits, loss, n_label
    
    def freeze(self):
        for name, param in self.named_parameters():
            if not name.startswith("weight_pred"):
                param.requires_grad = False
            