import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data
from sklearn.decomposition import PCA
from pytorch3d.ops import sample_farthest_points
from torch_sparse import SparseTensor
from pytorch3d.ops import ball_query

from tmp1.tmp_graph_visualize import visualize_superpoint_graph_with_labels
from model.GNN import SimpleGNN, GATNet, DualEdgeGATv2

from model.utils import JS2Weight
from model.partgeoze import PartGeoZe
from dataset.PartnetEpc import get_is_seen
from model.ImageEncoder import ImageEncoder
from model.SuperPointAggre import SPAttentionAggregation, SPAttentionDownPropagation
from loss.contrast_loss import proxy_contrastive_loss, construct_positive_negative_samples,ContrastiveLoss,triplet_loss


def value_to_color(tensor):
    """Map scalar values to BGR colours on a green-to-red ramp.

    Every value ``v`` becomes the triple ``(0, 1 - v, v)`` on a new trailing
    axis, i.e. blue is always zero, green fades out and red fades in as the
    value grows.

    Args:
        tensor: Values to colourise. No clamping or rescaling is done here;
            the caller in this file passes values already scaled to [0, 1].

    Returns:
        torch.Tensor: Tensor with the input's shape plus a trailing axis of
        size 3 holding the channels in (blue, green, red) order.
    """
    no_blue = torch.zeros_like(tensor)
    # Stacking on a fresh last axis yields the BGR image layout (..., 3).
    return torch.stack([no_blue, 1 - tensor, tensor], dim=-1)

def get_bbox(img_mask):
    """Return the tight bounding box of the non-zero entries of a 2-D mask.

    Args:
        img_mask: Mask whose first axis is treated as rows (y) and whose
            second axis is treated as columns (x).

    Returns:
        ``None`` when the mask has no non-zero entry, otherwise a tensor
        ``[x_min, y_min, x_max, y_max]`` with inclusive bounds.
    """
    hits = torch.nonzero(img_mask, as_tuple=True)
    rows = hits[0]
    if rows.numel() == 0:
        return None

    cols = hits[1]
    corners = [cols.min(), rows.min(), cols.max(), rows.max()]
    return torch.tensor(corners)

def knn_points(
    query: torch.Tensor,
    key: torch.Tensor,
    k: int,
    sorted: bool = False,
    transpose: bool = False,
):
    """Find the k nearest key points for every query point.

    Args:
        query: [B, N1, D] query points, or [B, D, N1] when @transpose is True.
        key: [B, N2, D] key points, or [B, D, N2] when @transpose is True.
        k: Number of neighbours to return per query point.
        sorted: Forwarded to ``torch.topk``; whether neighbours are ordered
            by distance. Has no effect when ``k == 1``.
        transpose: Swap the last two axes of both inputs before searching.

    Returns:
        torch.Tensor: [B, N1, K] distances to the selected neighbours.
        torch.Tensor: [B, N1, K] indices of those neighbours in the key set.
    """
    if transpose:
        query, key = query.transpose(1, 2), key.transpose(1, 2)

    # Dense pairwise distance matrix of shape [B, N1, N2].
    pairwise = torch.cdist(query, key)

    if k != 1:
        nearest = torch.topk(pairwise, k, dim=2, largest=False, sorted=sorted)
        return nearest.values, nearest.indices

    # A single neighbour only needs a min-reduction, not a top-k selection.
    closest = torch.min(pairwise, dim=2, keepdim=True)
    return closest.values, closest.indices

def compute_class_weights(labels, num_classes, eps=1.02):
    """Compute per-class weights ``1 / log(freq + eps)`` from label counts.

    Rare classes get larger weights; a class that never occurs gets
    ``1 / log(eps)``.

    Args:
        labels: Tensor of non-negative class ids of any shape. Floating
            point ids are cast to integers, because ``torch.bincount`` only
            accepts integral input.
        num_classes: Minimum length of the returned weight vector.
        eps: Offset added to the class frequency before the logarithm; it
            must be greater than 1 for the weights to be finite and positive.

    Returns:
        torch.Tensor: 1-D float tensor of class weights on ``labels``' device.
    """
    hist = torch.bincount(labels.flatten().long(), minlength=num_classes)
    # With no labels at all the total is zero; clamping keeps every frequency
    # at 0 instead of producing 0/0 = NaN weights.
    total = hist.sum().clamp(min=1)
    freq = hist.float() / total
    return 1.0 / torch.log(freq + eps)

class Embedder:
    """Sinusoidal feature embedding built from a keyword configuration.

    Expected keyword arguments (all read in ``create_embedding_fn``):
        input_dims: Width of the last axis of the inputs to embed.
        include_input: Whether the raw input is kept as the first feature.
        max_freq_log2: log2 of the highest frequency.
        num_freqs: Number of frequency bands.
        log_sampling: If true, bands are ``2 ** linspace(0, max_freq_log2)``;
            otherwise they are spaced linearly between ``1`` and
            ``2 ** max_freq_log2``.
        periodic_fns: Callables applied to ``x * freq`` for every band.

    Attributes set: ``kwargs``, ``embed_fns`` (list of callables) and
    ``out_dim`` (``input_dims`` times the number of callables).
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of embedding callables and the output width."""
        cfg = self.kwargs
        in_dims = cfg['input_dims']

        fns = []
        if cfg['include_input']:
            fns.append(lambda x: x)

        top = cfg['max_freq_log2']
        n_bands = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2.**torch.linspace(0., top, steps=n_bands)
        else:
            bands = torch.linspace(2.**0., 2.**top, steps=n_bands)

        # Default arguments pin the current band / function for each lambda;
        # a plain closure would see only the last loop values.
        fns.extend(
            (lambda x, wave=wave, band=band: wave(x * band))
            for band in bands
            for wave in cfg['periodic_fns']
        )

        self.embed_fns = fns
        # Every callable keeps the input width, so widths simply add up.
        self.out_dim = in_dims * len(fns)

    def embed(self, inputs):
        """Apply every embedding callable and concatenate on the last axis."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)

def get_embedder(multires, input_dims, i=0):
    """Create a sinusoidal embedding function and report its output width.

    Args:
        multires: Number of frequency bands; the highest frequency is
            ``2 ** (multires - 1)``.
        input_dims: Width of the last axis of the tensors to embed.
        i: Pass ``-1`` to disable the embedding and get an identity mapping.

    Returns:
        tuple: ``(embed, out_dim)`` where ``embed`` is a callable mapping a
        tensor to its embedding and ``out_dim`` is the width of the
        embedding's last axis.
    """
    if i == -1:
        # An identity mapping keeps the input width, so report ``input_dims``
        # instead of assuming 3-D coordinates.
        return torch.nn.Identity(), input_dims

    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    # Bind the embedder as a default argument so the returned callable does
    # not depend on this function's local scope.
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim


class CN_layer(torch.nn.Module):
    """Linear layer followed by normalisation across rows and a ReLU.

    After the linear projection, every channel is standardised with the mean
    and (population) standard deviation taken over the first axis of the
    input, i.e. over all rows passed in one call.
    """

    def __init__(
        self,
        in_channel=256,
        out_channel=1,
        num_block=3,
        he_init=False,
    ):
        """Create the layer.

        Args:
            in_channel: Input and output width of the linear projection.
            out_channel: Accepted for signature compatibility; not used.
            num_block: Accepted for signature compatibility; not used.
            he_init: Accepted for signature compatibility; not used.
        """
        super().__init__()

        self.layer_in = torch.nn.Linear(in_channel, in_channel)
        self.relu = torch.nn.ReLU()

    def forward(self, feature):
        """Project, standardise over rows, and rectify.

        Args:
            feature: [N, in_channel] tensor.

        Returns:
            torch.Tensor: [N, in_channel] non-negative features.
        """
        feature_in = self.layer_in(feature)
        feature_mean = torch.mean(feature_in, dim=0, keepdim=True)  # [1, in_channel]
        feature_std = torch.std(feature_in, dim=0, keepdim=True, correction=0)  # [1, in_channel]
        # A single row or a constant channel has zero std; flooring the
        # divisor turns the resulting 0/0 = NaN into 0 without changing any
        # channel that has a meaningful spread.
        cn_feature = (feature_in - feature_mean) / feature_std.clamp(min=1e-8)
        cn_feature = self.relu(cn_feature)

        return cn_feature


class WeightPredNetworkCNe(torch.nn.Module):
    """Stack of ``CN_layer`` blocks followed by a linear output head."""

    def __init__(
        self,
        in_channel=256,
        out_channel=1,
        num_cn_layer=1,
        he_init=False,
        skip_connection=True,
    ):
        """Create the network.

        Args:
            in_channel: Width of the input features and of every CN layer.
            out_channel: Width of the output head.
            num_cn_layer: Number of stacked ``CN_layer`` blocks.
            he_init: Use Kaiming-normal initialisation for all linear layers
                instead of a near-zero normal initialisation.
            skip_connection: Add each CN layer's output to its input instead
                of replacing the input with it.
        """
        super().__init__()
        self.skip_connection = skip_connection

        self.CN_layers = torch.nn.ModuleList([CN_layer(in_channel) for i in range(num_cn_layer)])

        self.layer_out = torch.nn.Linear(in_channel, out_channel)

        if he_init:
            self.apply(self._init_weights_he)
        else:
            self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise linear layers with tiny normal weights and zero bias."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.0001)
            # Same guard as the He initialiser: bias-free layers have no bias.
            if module.bias is not None:
                module.bias.data.zero_()

    def _init_weights_he(self, module):
        """Initialise linear layers with Kaiming-normal weights and zero bias."""
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.kaiming_normal_(module.weight.data)

            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, feature):
        """Run the CN layers and the output head.

        Args:
            feature: [N, in_channel] tensor.

        Returns:
            torch.Tensor: [N, out_channel] tensor.
        """
        feature_in = feature

        for layer in self.CN_layers:
            cn_feature = layer(feature_in)
            if self.skip_connection:
                feature_in = feature_in + cn_feature
            else:
                # Without a skip connection the layer output must replace the
                # running features; otherwise the CN layers are computed and
                # then silently discarded.
                feature_in = cn_feature

        out = self.layer_out(feature_in)

        return out

def check(mask_weight, mask_label, nview, pc_idx, shape_idx):
    """Debug helper: write per-view images colouring each mask by its weight.

    Weights are min-max normalised, painted onto the pixels of the mask they
    belong to, and saved as ``./output/mask_weight2/check{shape_idx}/{view}.png``.

    Args:
        mask_weight: 1-D tensor of weights, consumed in order: one entry for
            every (view, label) pair that has at least one pixel.
        mask_label: Per-view label images, indexed as ``mask_label[view]``.
        nview: Number of views to render.
        pc_idx: Per-view index images; pixels equal to -1 are painted with
            weight 0.
        shape_idx: Identifier used in the output directory name.
    """
    tmp_weight = mask_weight.clone()
    mn = torch.min(tmp_weight)
    mx = torch.max(tmp_weight)
    span = mx - mn
    tmp_weight = tmp_weight - mn
    # When every weight is equal the range is zero; skip the division so the
    # image shows all-zero weights instead of NaN colours.
    if span > 0:
        tmp_weight = tmp_weight / span
    print(f"mx = {mx}, mn = {mn}, mean = {tmp_weight.mean()}")
    iidd = 0  # running position in tmp_weight
    save_dir = f"./output/mask_weight2/check{shape_idx}"
    os.makedirs(save_dir, exist_ok=True)
    for view in range(nview):
        mask_label_tmp = torch.zeros_like(mask_label[0]).float()
        for i in range(mask_label.max()+1):
            img_ind = mask_label[view] == i
            # Only labels present in this view consume a weight entry.
            if img_ind.sum() != 0:
                mask_label_tmp[img_ind] = tmp_weight[iidd].item()
                iidd += 1
        mask_label_tmp[pc_idx[view] == -1] = 0
        rgb = value_to_color(mask_label_tmp).detach().cpu().numpy()
        rgb = (rgb*255).astype(np.uint8)
        cv2.imwrite(f"{save_dir}/{view}.png", rgb)
    print("good")
                    
def pairwise_js_divergence(A, eps=1e-10):
    """Compute the Jensen-Shannon divergence between every pair of rows.

    Args:
        A: torch.Tensor of shape (N, M); each row is treated as a
            distribution. Rows are offset by ``eps`` and renormalised first.
        eps: Small constant added to every entry to avoid ``log(0)``.

    Returns:
        torch.Tensor of shape (N, N) whose entry (i, j) is JS(row_i || row_j).
    """
    # Smooth and renormalise so every row is a strictly positive distribution.
    probs = A + eps
    probs = probs / probs.sum(dim=1, keepdim=True)
    log_probs = torch.log(probs)

    # Broadcast rows against each other: (N, 1, M) versus (1, N, M).
    left, right = probs.unsqueeze(1), probs.unsqueeze(0)
    log_left, log_right = log_probs.unsqueeze(1), log_probs.unsqueeze(0)

    # Mixture distribution for every pair, shape (N, N, M).
    log_mix = torch.log(0.5 * (left + right))

    # KL(P || M) and KL(Q || M), each reduced over the M axis.
    kl_left = torch.sum(left * (log_left - log_mix), dim=2)
    kl_right = torch.sum(right * (log_right - log_mix), dim=2)

    return 0.5 * (kl_left + kl_right)
    
def graph_visualize(mask_num, mask_label, id2mask, edges):
    """Debug helper: render the mask graph for each of the first 10 views.

    Pixels of every mask are relabelled with the mask's global node id, and
    each view is drawn together with the smallest and largest node id that
    belongs to it. Images go to ``./output/graph/graph_visualize_{view}.png``.

    Args:
        mask_num: Number of graph nodes (masks).
        mask_label: Per-view label images, indexed as ``mask_label[view]``.
        id2mask: Maps a node id to a ``(view, label)`` pair of tensors.
        edges: Graph edges, passed through to the plotting helper unchanged.
    """
    view_total = 10  # the helper always renders exactly ten views
    relabeled = torch.zeros_like(mask_label)
    lowest = (torch.ones(view_total)*10000).long()
    highest = torch.zeros(view_total).long()

    for node_id in range(mask_num):
        view_t, label_t = id2mask[node_id]
        v, lab = view_t.item(), label_t.item()
        relabeled[v][mask_label[v] == lab] = node_id
        # Track the node-id range covered by each view.
        lowest[v] = min(lowest[v], node_id)
        highest[v] = max(highest[v], node_id)

    os.makedirs("./output/graph", exist_ok=True)
    for v in range(view_total):
        visualize_superpoint_graph_with_labels(
            relabeled[v],
            edges,
            lowest[v],
            highest[v],
            save_path=f"./output/graph/graph_visualize_{v}.png",
        )

def get_ball_index(pc, k=20, radius=0.01):
    """Find, for every point, its neighbours inside a fixed-radius ball.

    Args:
        pc: [N, D] point cloud; it is queried against itself.
        k: Maximum number of neighbours returned per point (``K`` of
            ``ball_query``). Defaults to the previously hard-coded 20.
        radius: Ball radius. Defaults to the previously hard-coded 0.01.

    Returns:
        torch.Tensor: [N, k] neighbour indices, i.e. the second element of
        ``ball_query``'s result with the batch axis removed.
    """
    batched = pc.unsqueeze(dim=0)  # ball_query expects a leading batch axis
    result = ball_query(batched, batched, K=k, radius=radius)
    return result[1].squeeze(dim=0)

class SegmentorNew(torch.nn.Module):
    def __init__(self, num_labels, args):
        """Build the image encoder, optional graph modules and the classifier.

        Args:
            num_labels: Number of output classes of the per-point classifier.
            args: Configuration namespace. The ``use_*`` flags and the other
                options read below select which sub-modules are created.

        Raises:
            NotImplementedError: If ``args.use_3d_feat`` is set, because the
                3D point encoder is not constructed anywhere in this class.
        """
        super().__init__()

        self.num_labels = num_labels
        self.use_2d_feat = args.use_2d_feat
        self.use_3d_feat = args.use_3d_feat
        if self.use_3d_feat:
            # The point-cloud encoder construction is disabled (it used to be
            # `self.pc_encoder = Point_M2AE_ReductionD()`), so reading
            # `self.pc_encoder.out_dim` could only fail with an opaque
            # AttributeError. Fail early and explicitly, before any file I/O.
            raise NotImplementedError(
                "use_3d_feat is not supported: no point-cloud encoder "
                "(self.pc_encoder) is constructed in SegmentorNew.__init__"
            )

        self.args = args
        self.use_propagate = args.use_propagate
        self.ave_per_mask = args.ave_per_mask
        self.eliminate_sparseness = args.eliminate_sparseness
        self.use_slow_start = args.use_slow_start
        self.use_new_classifier = args.use_new_classifier
        self.use_js2weight = args.use_js2weight
        self.use_attn_ave = args.use_attn_ave
        self.use_gnn = args.use_gnn

        self.use_contrast_loss2 = args.use_contrast_loss2
        self.use_proxy_contrast_loss = args.use_proxy_contrast_loss
        self.use_mask_consist_loss = args.use_mask_consist_loss
        self.use_ref_loss = args.use_ref_loss

        # Read from the current working directory at construction time.
        self.cam_pos = np.load("view.npy")

        # Width of the per-point feature; grows with every enabled encoder.
        self.pc_feat_dim = 0
        if self.use_2d_feat:
            self.img_encoder = ImageEncoder(args.img_encoder, args.use_cache)
            self.pc_feat_dim += self.img_encoder.out_dim
        if self.use_contrast_loss2:
            self.contrastive_loss = ContrastiveLoss(temperature=0.5)
        if self.use_attn_ave:
            self.sp_aggre = SPAttentionAggregation(self.pc_feat_dim)
            self.sp_down = SPAttentionDownPropagation(self.pc_feat_dim)
            self.self_attn = torch.nn.MultiheadAttention(embed_dim=self.pc_feat_dim, num_heads=1, batch_first=True)
        if self.use_gnn:
            # NOTE(review): only gnn_up / gnn_prop / gnn_down are created; no
            # `self.gnn` attribute exists, although propagate_eli's use_gnn
            # branch refers to one.
            # Hidden width is pc_feat_dim // 8 with 8 heads.
            self.gnn_up = GATNet(self.pc_feat_dim, self.pc_feat_dim//8, self.pc_feat_dim, heads=8)
            if not self.args.graph4:
                self.gnn_prop = GATNet(self.pc_feat_dim, self.pc_feat_dim//8, self.pc_feat_dim, heads=8)
            else:
                self.gnn_prop = DualEdgeGATv2(self.pc_feat_dim, self.pc_feat_dim//8, self.pc_feat_dim, heads=8)
            self.gnn_down = GATNet(self.pc_feat_dim, self.pc_feat_dim//8, self.pc_feat_dim, heads=8)

        if self.use_js2weight:
            self.js2weight = JS2Weight(in_dim=num_labels, hidden_dim=128, layers_num=2)

        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        # Two-layer MLP head mapping point features to class logits.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.pc_feat_dim, self.pc_feat_dim//2),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.pc_feat_dim//2, num_labels),
        )
        
    def encode_pc(self, pc, pc_label,img, mask_label, pc_idx, coords, pc_fpfh, pc_sampled_idx, args, epoch):
        num_view = pc_idx.shape[0]
        n_point = pc.shape[0]
        img_feat = self.img_encoder(img)
        pc_feat = self.aggregate(n_point, img_feat, pc_idx, coords, mask_label, epoch)
        return pc_feat
    
    def aggregate(self, npoint, img_feat, pc_idx, coords, mask_label, epoch):
        """Average per-view image features onto the points that see them.

        For every view, the feature at each point's projected pixel is read
        from ``img_feat`` and the features are averaged over the views in
        which ``get_is_seen`` marks the point as visible.

        Args:
            npoint: Number of points.
            img_feat: Feature maps indexed as ``[view, channel, y, x]``.
            pc_idx: Per-view index images; the first axis is the view axis.
            coords: Pixel coordinates, reshaped to ``(nview, -1, 2)`` with
                x in slot 0 and y in slot 1.
            mask_label: Unused.
            epoch: Unused.

        Returns:
            torch.Tensor: ``(npoint, C)`` visibility-weighted mean features.
        """
        view_count = pc_idx.shape[0]
        target_device = img_feat.device

        # View id of every (view, point) sample, flattened view-major.
        view_of_sample = torch.arange(0, view_count).repeat_interleave(npoint).long()

        pixel_xy = coords.reshape(view_count, -1, 2)
        cols = pixel_xy[:, :, 0].long().reshape(-1)
        rows = pixel_xy[:, :, 1].long().reshape(-1)
        sampled = img_feat[view_of_sample, :, rows, cols].view(view_count, npoint, -1)

        visible = get_is_seen(pc_idx, npoint).to(target_device)

        weighted_sum = torch.sum(sampled * visible[:, :, None], dim=0)
        # The small offset keeps points that no view sees from dividing by zero.
        seen_views = torch.sum(visible, dim=0)[:, None] + 1e-6
        return weighted_sum / seen_views
    
    def aggregate_ball(self, pc, img_feat, pc_idx, coords, mask_label, ball_index,  epoch):
        """Pool image features onto points, smoothing over ball neighbours.

        Per view, each point's feature (zeroed where the point is not seen)
        is replaced by the mean over its ball neighbours; the per-view
        results are then summed and divided by the number of views that see
        the point.

        Args:
            pc: ``(N, D)`` point cloud; only ``N`` is used.
            img_feat: Feature maps indexed as ``[view, channel, y, x]``.
            pc_idx: Per-view index images; the first axis is the view axis.
            coords: Pixel coordinates, reshaped to ``(nview, -1, 2)`` with
                x in slot 0 and y in slot 1.
            mask_label: Unused.
            ball_index: ``(N, K)`` neighbour indices with -1 marking unused
                slots (the layout produced by ``get_ball_index``).
            epoch: Unused.

        Returns:
            torch.Tensor: ``(N, C)`` pooled point features.
        """
        npoint = pc.shape[0]
        nview = pc_idx.shape[0]
        device = img_feat.device

        is_seen = get_is_seen(pc_idx, npoint).to(device)
        nbatch = torch.arange(0, nview).repeat_interleave(npoint).long()
        point_loc = coords.reshape(nview, -1, 2)
        xx = point_loc[:, :, 0].long().reshape(-1)
        yy = point_loc[:, :, 1].long().reshape(-1)
        point_feats = img_feat[nbatch, :, yy, xx].view(nview, npoint, -1) * is_seen[:, :, None]

        # Append one all-zero feature row per view: the -1 padding entries of
        # ball_index wrap around to this last row and so contribute nothing.
        # Its shape follows the actual view count and feature width instead of
        # the previously hard-coded 10 x 96.
        zero_row = point_feats.new_zeros(nview, 1, point_feats.shape[-1])
        padded = torch.cat([point_feats, zero_row], dim=1)

        ball_index = ball_index.to(device).long()
        # Number of real neighbours per point; floored at 1 so a point without
        # neighbours yields zeros rather than 0/0.
        count = (ball_index != -1).sum(dim=1).clamp(min=1)

        # One view at a time keeps the (N, K, C) gather from being
        # materialised for all views at once.
        pooled = torch.stack(
            [padded[v][ball_index].sum(dim=1) / count[:, None] for v in range(nview)],
            dim=0,
        )

        return torch.sum(pooled, dim=0) / (torch.sum(is_seen, dim=0)[:, None] + 1e-6)

    def propagate_eli(self, pc_feat, graph):
        device = pc_feat.device
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
        edge_index = graph["edge_index"].squeeze(dim=0).to(device)
        n_point = pc_feat.shape[0]
        
        ave_feat = []
        mask_num = max(mask2id.values())+1
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0: 
                ave_feat.append(pc_feat[pc_ind].mean(dim=0))
            else:
                ave_feat.append(torch.zeros_like(pc_feat[0]))
        ave_feat = torch.stack(ave_feat, dim=0)
        
        if self.ave_per_mask:
            prop_feat = ave_feat
        elif self.use_gnn:
            # print("useGNNNNNNNNN")
            prop_feat = self.gnn(ave_feat, edge_index) +ave_feat
        elif self.args.ave_inter_mask:
            # print("<<<<<<<<<<<<<<<<<ave inter mask>>>>>>>>>>>>>>>>>")
            prop_feat = (ave_feat.T@edges).T/(edges.sum(dim=0)[:,None]+1e-6)
        else:
            prop_feat = (ave_feat.T@edges).T/(edges.sum(dim=0)[:,None]+1e-6)
        
        # prop_feat = (ave_feat.T@torch.linalg.matrix_power(edges, 5)).T
        
        cnt = torch.ones(n_point).float().to(pc_feat.device)
        for i in range(mask_num):
            view, label = id2mask[i]
            view, label = view.item(), label.item()
            pc_ind = mask_pc_ind[i]
            if pc_ind.numel()>0:
                pc_feat[pc_ind]+=prop_feat[i]
                cnt[pc_ind]+=1
        pc_feat /= cnt[:,None]
        return pc_feat        

    def compute_region_avg_features(self, features, nearest_index, num_regions):
        if not isinstance(features, torch.Tensor):
            features = torch.tensor(features)
        if not isinstance(nearest_index, torch.Tensor):
            nearest_index = torch.tensor(nearest_index, dtype=torch.long)

        avg_features = torch.zeros(num_regions, features.shape[1], device=features.device)
        counts = torch.zeros(num_regions, dtype=torch.float32, device=features.device)

        avg_features = avg_features.scatter_add(0, nearest_index.unsqueeze(1).expand(-1, features.shape[1]), features)
        counts = counts.scatter_add(0, nearest_index, torch.ones_like(nearest_index, dtype=torch.float32))
        counts = counts.clamp(min=1e-6)
        avg_features = avg_features / counts.unsqueeze(1)

        return avg_features
    def compute_avg_features_from_grouped(self, features, grouped_indices):

        K = len(grouped_indices)
        C = features.shape[1]
        avg_features = torch.zeros(K, C, dtype=features.dtype, device=features.device)

        for i, indices in enumerate(grouped_indices):
            if len(indices) > 0:
                avg_features[i] = features[indices].mean(dim=0)
        
        return avg_features


    
    def propagate_All_graph(self, img_feat, pc_feat, graph, mask_label):
        device = pc_feat.device
        edges = graph["edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
        mask2id = graph["mask2id"]
        id2mask = graph["id2mask"]
        centers = graph["centers"].squeeze(dim=0).long().to(device)
        mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
        mask_group_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_group_ind"]]
        nearest_index = graph["nearest_index"].to(device).squeeze(dim=0) 
        # grouped_indices = [x.squeeze(dim=0).to(device) for x in graph["grouped_indices"]] 
        
        edge_index = graph["edge_index"].squeeze(dim=0).to(device)
        edge_index_maskNode = graph["edge_index_maskNode"].squeeze(dim=0).to(device)
        if self.args.conf_label_edge:
            edge_index_maskNode_weak = graph["edge_index_maskNode_weak"].squeeze(dim=0).to(device)
        else:
            edge_index_maskNode_weak = None
        n_point = pc_feat.shape[0]
        mask_num = max(mask2id.values())+1
        group_num = centers.shape[0]
        # with torch.no_grad():
            # group_feat = self.compute_avg_features_from_grouped(pc_feat, grouped_indices)
        group_feat = pc_feat[centers]
        if not self.args.img_feat_on_mask:
            mask_feat = self.compute_avg_features_from_grouped(group_feat, mask_group_ind)
        else:
            # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            img_feat = img_feat.permute(0,2,3,1)
            mask_feat = []
            for i in range(mask_num):
                view, label = id2mask[i]
                view, label = view.item(), label.item()
                img_ind = mask_label[view]==label
                mask_feat.append(img_feat[view][img_ind].mean(dim=0))
            mask_feat = torch.stack(mask_feat, dim=0)
        
        input_feat = torch.cat([group_feat, mask_feat], dim=0)

        # gat_feat = self.gnn(input_feat, edge_index)
        
        
        
        aggre_feat = self.gnn_up(input_feat, edge_index)
        aggre_feat_prop = self.gnn_prop(aggre_feat[group_num:].clone(), edge_index_maskNode, edge_index_maskNode_weak)
        aggre_feat = torch.cat([aggre_feat[:group_num], aggre_feat_prop], dim=0)
        gat_feat = self.gnn_down(aggre_feat, edge_index)
        
        # cnt = torch.ones(group_num).float().to(pc_feat.device)
        # for i in range(mask_num):
        #     group_ind = mask_group_ind[i]
        #     if group_ind.numel()>0:
        #         group_feat[group_ind]+=mask_feat[i]
        #         cnt[group_ind]+=1
        # gat_feat2 = group_feat / cnt[:,None]
        
        pc_feat = pc_feat+gat_feat[:group_num][nearest_index] # +gat_feat2[:group_num][nearest_index]
        # pc_feat = torch.cat([pc_feat, gat_feat[:group_num][nearest_index]], dim=1)
        return pc_feat
        
        
    def forward(self, pc_id, pc, pc_label,img, mask_label, pc_idx, coords, graph, pc_norm, args, epoch, mode="train"):
        device = pc.device
        num_view = pc_idx.shape[0]
        n_point = pc.shape[0]
        img_feat, loss_ref = self.img_encoder(pc_id, img, pc_norm[0])
                
        if not self.args.use_ball_propagate:
            pc_feat = self.aggregate(n_point, img_feat, pc_idx, coords, mask_label, epoch)
        else:
            pc_feat = self.aggregate_ball(pc, img_feat, pc_idx, coords, mask_label, pc_norm, epoch)
            
        if self.use_propagate == 1:
            if not self.args.All_graph:
                pc_feat = self.propagate_eli(pc_feat, graph)
            else :
                pc_feat = self.propagate_All_graph(img_feat, pc_feat, graph, mask_label)
            logits = self.classifier(pc_feat)
        elif self.use_propagate == 0:
            logits = self.classifier(pc_feat)
        else:
            print("args.propagate valuse error")
            exit(0)
        
        n_label = pc_label
        if self.args.use_pseudo_label:
            pseudo_label = pc_norm.squeeze(dim=0).long().to(pc_label.device)
            
        pc_feat = pc_feat / torch.norm(pc_feat, dim=-1, keepdim=True)
        
        # ave_feat = torch.load("./output/tmp.pt")
        # ave_feat[3,:]=0
        # ave_feat /= (torch.norm(ave_feat, dim=-1, keepdim=True)+1e-6)
        # logits = pc_feat @ ave_feat.T
        loss = 0
        if pc_label is not None:
            logits_v = logits
            
            if self.args.use_pseudo_label:
                valid_ind = pseudo_label!=-1
                pc_label = pc_label[valid_ind]
                logits_v = logits[valid_ind]
            
            if not self.args.pretrain and mode!="self":
                weight = compute_class_weights(pc_label, self.num_labels)
                loss_ce = torch.nn.functional.cross_entropy(logits_v, pc_label.long(), weight=weight, reduction="none")

                loss_ce = loss_ce.mean()
                loss = loss_ce
            
            if self.use_proxy_contrast_loss:
                loss_contrast = proxy_contrastive_loss(pc_feat, pc_label, self.classifier)
                loss += loss_contrast
            if self.use_contrast_loss2:
                positive_pairs, negative_pairs = construct_positive_negative_samples(pc_feat.unsqueeze(dim=0), pc_label.unsqueeze(dim=0), sample_num=1024*16)
                if positive_pairs.shape[0] > 0 and negative_pairs.shape[0] > 0:
                    loss_contrast = self.contrastive_loss(positive_pairs, negative_pairs)
                    loss += loss_contrast
                    # print("contrastive loss  :", loss_contrast)
            losses = []
            if self.args.use_triplet_loss:
                PN_tri = graph["PN_tri"].squeeze(dim=0).long().to(device)
                loss_tri = triplet_loss(pc_feat, PN_tri)
                loss+=loss_tri
            if self.use_mask_consist_loss:
                mask2id = graph["mask2id"]
                mask_pc_ind = [x.squeeze(dim=0).to(device) for x in graph["mask_pc_ind"]]
                mask_num = max(mask2id.values())+1
                loss_mask_consist = []
                ave_feat = []
                for i in range(mask_num):
                    pc_ind = mask_pc_ind[i]
                    if pc_ind.numel()>0:
                        tmp_ave_feat = pc_feat[pc_ind].mean(dim=0)
                        ave_feat.append(tmp_ave_feat)
                        loss_tmp = torch.norm(pc_feat[pc_ind]-tmp_ave_feat[None,:], dim=-1, p=2)
                        loss_mask_consist.append(loss_tmp.mean())
                    else:
                        ave_feat.append(torch.zeros_like(pc_feat[0]))
                        
                ave_feat = torch.stack(ave_feat, dim=0)
                loss_ins = sum(loss_mask_consist)/len(loss_mask_consist)
                loss += loss_ins
                
                # flyd_edges = graph["flyd_edges"].squeeze(dim=0).to(pc_feat.dtype).to(device)
                # sim_matrix = ave_feat @ ave_feat.T  # [N, N]
                # N = sim_matrix.size(0)
                # mask = torch.eye(N, device=sim_matrix.device).bool()
                # sim_matrix = sim_matrix.masked_fill(mask, 0.0)
                # unlink = 1-flyd_edges
                # loss_inter = (sim_matrix*unlink).sum() / unlink.sum()
                # # loss_inter = sim_matrix.sum() / (N * (N - 1))
                # # loss += loss_inter
                # losses.append([loss_ins, loss_inter])
                
            if self.use_ref_loss:
                loss += loss_ref
        return logits, loss, n_label, losses
    
    def freeze(self):
        for name, param in self.named_parameters():
            if not name.startswith("weight_pred"):
                param.requires_grad = False

        