import torch
import torch.nn as nn
from torch import distributed as dist
import torch_scatter
from openpoints.cpp.pointops.functions import pointops
from ..build import MODELS
from .ptnet import TransitionDown, TransitionUp, PointTransformerBlock
from .ptnetv2 import GVAPatchEmbed, Encoder, Decoder, PointBatchNorm 
from ..layers import concat_all_gather_diff
from hefm import HEFM
from typing import Dict, Optional
from PointTransformerV3.model import Point
import os

@MODELS.register_module()
class PTSeg_Balance_Prior(nn.Module):
    """Prior branch: encode per-class point groups and track class prototypes.

    For every sample of a batch the labelled points are grouped by class. The
    concatenated groups are pushed through five encoder stages and a
    projection MLP; the L2-normalised outputs are tagged with their class
    index and used to update ``prior_ema``, a per-class running average that
    is registered as a buffer.
    """

    # Groups smaller than this are resampled (with replacement) up to this
    # size. 256 equals the cumulative stride 4 * 4 * 4 * 4 of the encoder.
    _MIN_GROUP_POINTS = 256

    def __init__(self,
                 block=PointTransformerBlock,
                 blocks=[2, 3, 4, 6, 3],    # depth, default: blocks=[2, 3, 4, 6, 3]
                 width=32,
                 nsample=[8, 16, 16, 16, 16],
                 in_channels=6,
                 num_classes=13,
                 mid_res=False,
                 beta=0.999,
                 adp_pointnetv2=False,
                 **kwargs):
        """Build the five encoder stages, the projection MLP and the EMA buffer.

        Args:
            block: block class (or its name as a string) used inside each stage.
            blocks: number of layers per stage (first one is the transition).
            width: channel width of the first stage; doubled at every stage.
            nsample: neighbour count per stage.
            in_channels: input feature channels; values below 6 are increased
                by 3 because ``forward`` prepends the coordinates.
            num_classes: number of semantic classes (rows of ``prior_ema``).
            mid_res: forwarded to ``block``.
            beta: EMA momentum used by ``_ema``.
            adp_pointnetv2: if True use the BatchNorm-free projection that
                ends at ``planes[2]`` instead of ``planes[0]``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_classes = num_classes
        self.beta = beta
        if in_channels < 6:
            in_channels = in_channels + 3  # rgb(z) -> xyz + rgb(z)
        planes = [width * 2**i for i in range(len(blocks))]
        self.in_planes = in_channels
        share_planes = 8
        stride = [1, 4, 4, 4, 4]

        if isinstance(block, str):
            # NOTE(security): eval() executes arbitrary code; only pass block
            # names that come from trusted configuration files.
            block = eval(block)
        self.mid_res = mid_res

        # prior model; the stages must be created in order because
        # `_make_enc` advances `self.in_planes`.
        self.enc1_prior = self._make_enc(block, planes[0], blocks[0], share_planes, stride=stride[0],
                                         nsample=nsample[0])  # N/1
        self.enc2_prior = self._make_enc(block, planes[1], blocks[1], share_planes, stride=stride[1],
                                         nsample=nsample[1])  # N/4
        self.enc3_prior = self._make_enc(block, planes[2], blocks[2], share_planes, stride=stride[2],
                                         nsample=nsample[2])  # N/16
        self.enc4_prior = self._make_enc(block, planes[3], blocks[3], share_planes, stride=stride[3],
                                         nsample=nsample[3])  # N/64
        self.enc5_prior = self._make_enc(block, planes[4], blocks[4], share_planes, stride=stride[4],
                                         nsample=nsample[4])  # N/256
        if not adp_pointnetv2:
            self.projection = nn.Sequential(
                nn.Linear(planes[4], planes[3]), nn.BatchNorm1d(planes[3]), nn.ReLU(inplace=True),
                nn.Linear(planes[3], planes[2]), nn.BatchNorm1d(planes[2]), nn.ReLU(inplace=True),
                nn.Linear(planes[2], planes[0]), nn.BatchNorm1d(planes[0]), nn.ReLU(inplace=True))
            self.register_buffer('prior_ema', torch.rand(num_classes, planes[0]))
        else:
            self.projection = nn.Sequential(
                nn.Linear(planes[4], planes[3]), nn.ReLU(inplace=True),
                nn.Linear(planes[3], planes[2]), nn.ReLU(inplace=True))
            self.register_buffer('prior_ema', torch.rand(num_classes, planes[2]))
        # Start from unit-length random prototypes.
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)

    def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16):
        """Return one encoder stage: a ``TransitionDown`` plus ``blocks - 1`` blocks.

        Side effect: ``self.in_planes`` is set to the stage's output width.
        """
        out_planes = planes * block.expansion
        layers = [TransitionDown(self.in_planes, out_planes, stride, nsample)]
        self.in_planes = out_planes
        layers.extend(
            block(self.in_planes, self.in_planes, share_planes, nsample=nsample, mid_res=self.mid_res)
            for _ in range(1, blocks))
        return nn.Sequential(*layers)

    @torch.no_grad()
    def _ema(self, prior):
        """Update ``prior_ema`` from ``prior`` of shape (n, dim + 1).

        The last column of ``prior`` holds the class index, the others the
        feature. Classes without any row keep their current prototype.
        """
        if dist.is_initialized():
            # gather prior from all ranks before updating self.prior_ema
            prior = concat_all_gather_diff(prior)
        cur_status = self.prior_ema.clone()
        row_labels = prior[:, -1]
        for label in range(self.num_classes):
            mask_c = row_labels == label
            if mask_c.any():
                cur_status[label, :] = prior[mask_c, :-1].mean(0)
        self.prior_ema = self.beta * self.prior_ema + (1 - self.beta) * cur_status

    def _prepare_inputs(self, p0, x0, o0, is_train):
        """Normalise the inputs to flat ``(p0, x0, o0, labels)`` tensors.

        When ``x0`` is None, ``p0`` is treated as a dict with key ``'pos'`` and
        optional keys ``'x'``, ``'offset'`` and ``'y'``. ``labels`` is None when
        the inputs are given as tensors or the dict has no ``'y'`` entry.
        """
        labels = None
        if x0 is None:  # this means p0 is a dict.
            data = p0
            p0 = data['pos']
            x0 = data.get('x', None)
            o0 = data.get('offset', None)
            labels = data.get('y', None)
            feats_are_coords = x0 is None
            if feats_are_coords:
                x0 = p0
            if o0 is None:
                # No offsets supplied: `pos` is taken to be (B, N, 3) with N
                # points per sample, so the offsets are N, 2N, ..., B*N.
                num_samples, pts_per_sample = p0.size(0), p0.size(1)
                o0 = torch.tensor([pts_per_sample * (i + 1) for i in range(num_samples)],
                                  dtype=torch.int32, device=p0.device)
            if x0.dim() == 3:
                if feats_are_coords:
                    # Coordinates are already channel-last (B, N, 3); they
                    # must not be transposed like (B, C, N) features.
                    x0 = x0.reshape(-1, x0.size(-1))
                else:
                    channels = x0.size(1)  # x0: (B, C, N)
                    x0 = x0.transpose(1, 2).contiguous().view(-1, channels)  # (B*N, C)
                if p0.dim() == 3:
                    p0 = p0.reshape(-1, p0.size(-1))  # (B*N, 3)
                if is_train and (labels is not None) and labels.dim() > 1:
                    labels = labels.reshape(-1)
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)  # x0(n, c=in_channels+3)
        return p0, x0, o0, labels

    def _group_points_by_class(self, p0, x0, o0, labels):
        """Split every sample into per-class groups and concatenate them.

        Returns:
            ``(coord, feat, offset, group_labels)`` where ``offset`` holds the
            cumulative group sizes (int32) and ``group_labels[g]`` is the class
            index of the g-th group. Groups are ordered by sample, then class.
        """
        min_points = self._MIN_GROUP_POINTS
        # The per-class point indices do not depend on the sample, so compute
        # them once instead of once per (sample, class) pair.
        per_class_index = [(labels == c).nonzero().squeeze(1) for c in range(self.num_classes)]
        feats, coords, counts, group_labels = [], [], [], []
        start = 0
        for end in o0.tolist():
            for c, index in enumerate(per_class_index):
                selected = index[(index >= start) & (index < end)]
                point_num = selected.numel()
                if point_num == 0:
                    continue
                if point_num < min_points:
                    # NOTE: when the group has fewer points than the sampling
                    # rate, fill it up to the sampling rate by drawing with
                    # replacement.
                    pick = torch.randint(0, point_num, (min_points,), device=selected.device)
                    selected = selected[pick]
                feats.append(x0[selected, :])
                coords.append(p0[selected, :])
                counts.append(selected.numel())
                group_labels.append(c)
            start = end
        if not feats:
            raise RuntimeError(
                'PTSeg_Balance_Prior: no point carries a label in [0, num_classes).')
        feat = torch.cat(feats, dim=0)
        coord = torch.cat(coords, dim=0)
        offset = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0,
                              dtype=torch.int32).to(p0.device)
        return coord, feat, offset, group_labels

    def forward(self, p0, x0=None, o0=None, is_train=False, mask=None, ignore_index=None):
        """Encode the per-class groups and refresh the class prototypes.

        Args:
            p0: data dict (see ``_prepare_inputs``) or coordinate tensor.
            x0: features, or None when ``p0`` is a dict.
            o0: cumulative per-sample point counts.
            is_train: only controls whether multi-dimensional labels are
                flattened while preparing the inputs.
            mask: optional boolean point mask applied to points and labels.
            ignore_index: if 0, the masked labels are shifted down by one.

        Returns:
            ``(current_prior, prior_ema)``: ``current_prior`` has one row per
            encoded point, holding the normalised feature followed by the class
            index; ``prior_ema`` is the normalised per-class running average.

        Raises:
            ValueError: if no labels are available (the prior needs them).
        """
        # p, x, o: points, features, batches
        p0, x0, o0, labels = self._prepare_inputs(p0, x0, o0, is_train)
        if labels is None:
            raise ValueError(
                "PTSeg_Balance_Prior.forward needs labels: pass a data dict with a 'y' entry.")

        if mask is not None:
            p0, x0, labels = p0[mask], x0[mask], labels[mask]
            # New offsets = number of kept points before each old offset.
            kept = [int(mask[:end].sum()) for end in o0.tolist()]
            o0 = torch.tensor(kept, dtype=o0.dtype, device=o0.device)
            if ignore_index == 0:
                labels = labels - 1

        # aggregate the same class feature in an example of a batch
        coord, feat, offset, group_labels = self._group_points_by_class(p0, x0, o0, labels)

        # prior information process
        p1_prior, x1_prior, o1_prior = self.enc1_prior([coord, feat, offset])
        p2_prior, x2_prior, o2_prior = self.enc2_prior([p1_prior, x1_prior, o1_prior])
        p3_prior, x3_prior, o3_prior = self.enc3_prior([p2_prior, x2_prior, o2_prior])
        p4_prior, x4_prior, o4_prior = self.enc4_prior([p3_prior, x3_prior, o3_prior])
        _, feat, offset = self.enc5_prior([p4_prior, x4_prior, o4_prior])
        feat = self.projection(feat)

        feat = nn.functional.normalize(feat, dim=1)
        current_prior = torch.cat([feat, torch.zeros([feat.size(0), 1], device=feat.device)], dim=1)
        # `offset` now holds the (downsampled) cumulative size of each group;
        # write the group's class index into the last column of its rows.
        bounds = [0] + offset.tolist()
        for g, c_index in enumerate(group_labels):
            current_prior[bounds[g]:bounds[g + 1], -1] = c_index
        self._ema(current_prior.detach())
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)
        return current_prior, self.prior_ema


@MODELS.register_module()
class PTSeg_Balance_Main(nn.Module):
    """Main segmentation branch: five-stage encoder/decoder with class prototypes.

    ``forward`` returns the per-point logits. In training mode it additionally
    returns the L2-normalised decoder features and ``prior_ema``, a per-class
    running average of the features of correctly classified points.
    """

    def __init__(self,
                 block=PointTransformerBlock,
                 blocks=[2, 3, 4, 6, 3],    # depth
                 width=32,
                 nsample=[8, 16, 16, 16, 16],
                 in_channels=6,
                 num_classes=13,
                 dec_local_aggr=True,
                 mid_res=False,
                 beta=0.999,
                 **kwargs):
        """Build encoder, decoder, classifier head and the EMA buffer.

        Args:
            block: block class (or its name as a string) used in each stage.
            blocks: number of layers per encoder stage.
            width: channel width of the first stage; doubled at every stage.
            nsample: neighbour count per stage.
            in_channels: input feature channels; values below 6 are increased
                by 3 because ``forward`` prepends the coordinates.
            num_classes: number of semantic classes.
            dec_local_aggr: if True the decoder stages contain a ``block``
                after the ``TransitionUp``.
            mid_res: forwarded to ``block``.
            beta: EMA momentum used by ``_ema``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_classes = num_classes
        if in_channels < 6:
            in_channels = in_channels + 3  # rgb(z) -> xyz + rgb(z)
        planes = [width * 2**i for i in range(len(blocks))]
        self.in_planes = in_channels
        share_planes = 8
        stride = [1, 4, 4, 4, 4]
        self.beta = beta

        if isinstance(block, str):
            # NOTE(security): eval() executes arbitrary code; only pass block
            # names that come from trusted configuration files.
            block = eval(block)
        self.mid_res = mid_res
        self.dec_local_aggr = dec_local_aggr

        # main model enc 5; creation order matters because `_make_enc` and
        # `_make_dec` advance `self.in_planes`.
        self.enc1 = self._make_enc(block, planes[0], blocks[0], share_planes, stride=stride[0],
                                   nsample=nsample[0])  # N/1
        self.enc2 = self._make_enc(block, planes[1], blocks[1], share_planes, stride=stride[1],
                                   nsample=nsample[1])  # N/4
        self.enc3 = self._make_enc(block, planes[2], blocks[2], share_planes, stride=stride[2],
                                   nsample=nsample[2])  # N/16
        self.enc4 = self._make_enc(block, planes[3], blocks[3], share_planes, stride=stride[3],
                                   nsample=nsample[3])  # N/64
        self.enc5 = self._make_enc(block, planes[4], blocks[4], share_planes, stride=stride[4],
                                   nsample=nsample[4])  # N/256
        # main model dec 5, no interpolation
        self.dec5 = self._make_dec(block, planes[4], 2, share_planes, nsample[4], True)  # transform p5
        self.dec4 = self._make_dec(block, planes[3], 2, share_planes, nsample[3])  # fusion p5 and p4
        self.dec3 = self._make_dec(block, planes[2], 2, share_planes, nsample[2])  # fusion p4 and p3
        self.dec2 = self._make_dec(block, planes[1], 2, share_planes, nsample[1])  # fusion p3 and p2
        self.dec1 = self._make_dec(block, planes[0], 2, share_planes, nsample[0])  # fusion p2 and p1
        self.cls = nn.Sequential(nn.Linear(planes[0], planes[0]), nn.BatchNorm1d(planes[0]), nn.ReLU(inplace=True),
                                 nn.Linear(planes[0], self.num_classes))

        self.register_buffer('prior_ema', torch.rand(num_classes, planes[0]))
        # Start from unit-length random prototypes.
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)

    def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16):
        """Return one encoder stage: a ``TransitionDown`` plus ``blocks - 1`` blocks.

        Side effect: ``self.in_planes`` is set to the stage's output width.
        """
        out_planes = planes * block.expansion
        layers = [TransitionDown(self.in_planes, out_planes, stride, nsample)]
        self.in_planes = out_planes
        layers.extend(
            block(self.in_planes, self.in_planes, share_planes, nsample=nsample, mid_res=self.mid_res)
            for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _make_dec(self, block, planes, blocks, share_planes=8, nsample=16, is_head=False):
        """Return one decoder stage: a ``TransitionUp`` plus optional blocks.

        The head stage (``is_head=True``) passes ``None`` as the output width
        of ``TransitionUp``. Blocks are only appended when
        ``self.dec_local_aggr`` is set. Side effect: ``self.in_planes`` is set
        to the stage's output width.
        """
        out_planes = planes * block.expansion
        layers = [TransitionUp(self.in_planes, None if is_head else out_planes)]
        self.in_planes = out_planes

        if self.dec_local_aggr:
            layers.extend(
                block(self.in_planes, self.in_planes, share_planes, nsample=nsample, mid_res=self.mid_res)
                for _ in range(1, blocks))
        return nn.Sequential(*layers)

    @torch.no_grad()
    def _ema(self, prior):
        """Update ``prior_ema`` from ``prior`` of shape (n, dim + 1).

        The last column of ``prior`` holds the class index, the others the
        feature. Classes without any row keep their current prototype.
        """
        if dist.is_initialized():
            # gather prior from all ranks before updating self.prior_ema
            prior = concat_all_gather_diff(prior)
        cur_status = self.prior_ema.clone()
        row_labels = prior[:, -1]
        for label in range(self.num_classes):
            mask_c = row_labels == label
            if mask_c.any():
                cur_status[label, :] = prior[mask_c, :-1].mean(0)
        self.prior_ema = self.beta * self.prior_ema + (1 - self.beta) * cur_status

    def _prepare_inputs(self, p0, x0, o0, is_train):
        """Normalise the inputs to flat ``(p0, x0, o0, labels)`` tensors.

        When ``x0`` is None, ``p0`` is treated as a dict with key ``'pos'`` and
        optional keys ``'x'``, ``'offset'`` and ``'y'``. ``labels`` is None when
        the inputs are given as tensors or the dict has no ``'y'`` entry.
        """
        labels = None
        if x0 is None:  # this means p0 is a dict.
            data = p0
            p0 = data['pos']
            x0 = data.get('x', None)
            o0 = data.get('offset', None)
            labels = data.get('y', None)
            feats_are_coords = x0 is None
            if feats_are_coords:
                x0 = p0
            if o0 is None:
                # No offsets supplied: `pos` is taken to be (B, N, 3) with N
                # points per sample, so the offsets are N, 2N, ..., B*N.
                num_samples, pts_per_sample = p0.size(0), p0.size(1)
                o0 = torch.tensor([pts_per_sample * (i + 1) for i in range(num_samples)],
                                  dtype=torch.int32, device=p0.device)
            if x0.dim() == 3:
                if feats_are_coords:
                    # Coordinates are already channel-last (B, N, 3); they
                    # must not be transposed like (B, C, N) features.
                    x0 = x0.reshape(-1, x0.size(-1))
                else:
                    channels = x0.size(1)  # x0: (B, C, N)
                    x0 = x0.transpose(1, 2).contiguous().view(-1, channels)  # (B*N, C)
                if p0.dim() == 3:
                    p0 = p0.reshape(-1, p0.size(-1))  # (B*N, 3)
                if is_train and (labels is not None) and labels.dim() > 1:
                    labels = labels.reshape(-1)
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)  # x0(n, c=in_channels+3)
        return p0, x0, o0, labels

    def forward(self, p0, x0=None, o0=None, is_train=False, mask=None, ignore_index=None):
        """Run the segmentation network.

        Args:
            p0: data dict (see ``_prepare_inputs``) or coordinate tensor.
            x0: features, or None when ``p0`` is a dict.
            o0: cumulative per-sample point counts.
            is_train: if False only the logits are returned.
            mask: optional boolean point mask selecting the supervised points.
            ignore_index: if 0, the masked labels are shifted down by one.

        Returns:
            ``logits`` when ``is_train`` is False, otherwise
            ``(logits, feas_norm, prior_ema)``. ``logits`` always covers all
            points, while ``feas_norm`` is restricted to ``mask`` when given.

        Raises:
            ValueError: if ``is_train`` is True but no labels are available.
        """
        # p, x, o: points, features, batches
        p0, x0, o0, labels = self._prepare_inputs(p0, x0, o0, is_train)

        # main model encoder and decoder
        p1, x1, o1 = self.enc1([p0, x0, o0])
        p2, x2, o2 = self.enc2([p1, x1, o1])
        p3, x3, o3 = self.enc3([p2, x2, o2])
        p4, x4, o4 = self.enc4([p3, x3, o3])
        p5, x5, o5 = self.enc5([p4, x4, o4])
        # Each decoder stage: element 0 is the TransitionUp, the rest refine.
        x5 = self.dec5[1:]([p5, self.dec5[0]([p5, x5, o5]), o5])[1]
        x4 = self.dec4[1:]([p4, self.dec4[0]([p4, x4, o4], [p5, x5, o5]), o4])[1]
        x3 = self.dec3[1:]([p3, self.dec3[0]([p3, x3, o3], [p4, x4, o4]), o3])[1]
        x2 = self.dec2[1:]([p2, self.dec2[0]([p2, x2, o2], [p3, x3, o3]), o2])[1]
        x1 = self.dec1[1:]([p1, self.dec1[0]([p1, x1, o1], [p2, x2, o2]), o1])[1]
        logits = self.cls(x1)
        if not is_train:
            return logits

        if labels is None:
            raise ValueError(
                "PTSeg_Balance_Main.forward needs labels when is_train=True: "
                "pass a data dict with a 'y' entry.")

        feas_norm = nn.functional.normalize(x1, dim=1)
        if mask is not None:
            labels = labels[mask]
            if ignore_index == 0:
                labels = labels - 1
            seg_logits, feas_norm = logits[mask, :], feas_norm[mask, :]
        else:
            seg_logits = logits
        # softmax is monotonic, so the arg-max can be taken on the logits.
        preds = seg_logits.argmax(dim=1)
        mask_true = (preds == labels)
        # all scenes prototype: only correctly classified points feed the EMA.
        true_feats = feas_norm.detach()[mask_true, :]
        true_labels = labels[mask_true].unsqueeze(1).to(true_feats.dtype)
        self._ema(torch.cat([true_feats, true_labels], dim=1))
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)
        return logits, feas_norm, self.prior_ema


@MODELS.register_module()
class PTSegV2_Balance_Prior(nn.Module):
    """Prior branch built from the Point Transformer V2 patch embed and encoders.

    For every sample of a batch the labelled points are grouped by class. The
    concatenated groups are pushed through the patch embedding, the encoder
    stages and a projection MLP; the L2-normalised outputs are tagged with
    their class index and used to update ``prior_ema``, a per-class running
    average that is registered as a buffer.
    """

    # Groups smaller than this are resampled (with replacement) up to this size.
    _MIN_GROUP_POINTS = 256

    def __init__(self,
                 in_channels,
                 num_classes=13,
                 patch_embed_depth=2,
                 patch_embed_channels=48,
                 patch_embed_groups=6,
                 patch_embed_neighbours=16,
                 enc_depths=(2, 6, 2),
                 enc_channels=(96, 192, 384),
                 enc_groups=(12, 24, 48),
                 enc_neighbours=(16, 16, 16),
                 grid_sizes=(0.1, 0.2, 0.4),
                 attn_qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.3,
                 enable_checkpoint=False,
                 beta=0.999,
                 **kwargs):
        """Build patch embedding, encoder stages, projection and EMA buffer.

        The per-stage tuples (``enc_depths``, ``enc_channels``, ``enc_groups``,
        ``enc_neighbours``, ``grid_sizes``) must all have the same length.
        ``beta`` is the EMA momentum used by ``_ema``; ``**kwargs`` is accepted
        and ignored. The remaining arguments are forwarded to
        ``GVAPatchEmbed`` / ``Encoder``.
        """
        super(PTSegV2_Balance_Prior, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_stages = len(enc_depths)
        self.beta = beta
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_groups)
        assert self.num_stages == len(enc_neighbours)
        assert self.num_stages == len(grid_sizes)
        self.patch_embed = GVAPatchEmbed(
            in_channels=in_channels,
            embed_channels=patch_embed_channels,
            groups=patch_embed_groups,
            depth=patch_embed_depth,
            neighbours=patch_embed_neighbours,
            qkv_bias=attn_qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            enable_checkpoint=enable_checkpoint
        )

        enc_dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(enc_depths))]
        enc_channels = [patch_embed_channels] + list(enc_channels)
        self.enc_stages = nn.ModuleList()
        # Kept (empty) so the attribute set matches the original module.
        self.dec_stages = nn.ModuleList()
        for i in range(self.num_stages):
            enc = Encoder(
                depth=enc_depths[i],
                in_channels=enc_channels[i],
                embed_channels=enc_channels[i + 1],
                groups=enc_groups[i],
                grid_size=grid_sizes[i],
                neighbours=enc_neighbours[i],
                qkv_bias=attn_qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=enc_dp_rates[sum(enc_depths[:i]):sum(enc_depths[:i + 1])],
                enable_checkpoint=enable_checkpoint
            )
            self.enc_stages.append(enc)
        # The projection output and the prototype buffer must share one width.
        # It was previously spelled `enc_channels[-4]` for the projection and a
        # literal 48 for the buffer, which only agree for the default
        # configuration; both now use the patch-embedding width.
        prior_dim = enc_channels[0]
        self.projection = nn.Sequential(
            nn.Linear(enc_channels[-1], enc_channels[-2]), nn.BatchNorm1d(enc_channels[-2]), nn.ReLU(inplace=True),
            nn.Linear(enc_channels[-2], prior_dim), nn.BatchNorm1d(prior_dim), nn.ReLU(inplace=True))
        # ema
        self.register_buffer('prior_ema', torch.rand(num_classes, prior_dim))
        # Start from unit-length random prototypes.
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)

    @torch.no_grad()
    def _ema(self, prior):
        """Update ``prior_ema`` from ``prior`` of shape (n, dim + 1).

        The last column of ``prior`` holds the class index, the others the
        feature. Classes without any row keep their current prototype.
        """
        if dist.is_initialized():
            # gather prior from all ranks before updating self.prior_ema
            prior = concat_all_gather_diff(prior)
        cur_status = self.prior_ema.clone()
        row_labels = prior[:, -1]
        for label in range(self.num_classes):
            mask_c = row_labels == label
            if mask_c.any():
                cur_status[label, :] = prior[mask_c, :-1].mean(0)
        self.prior_ema = self.beta * self.prior_ema + (1 - self.beta) * cur_status

    def _select_label_level(self, labels):
        """Pick one column of multi-level ``labels`` (last axis = levels).

        Dynamic level selection: the class count of a level is estimated as
        ``max label + 1``. Levels with at least ``self.num_classes`` classes
        are preferred; among the candidates the one whose class count is
        closest to ``self.num_classes`` is returned. If every level has fewer
        classes, the closest level overall is used.
        """
        lvl_num = labels.shape[-1]
        class_cnts = [(labels[..., i].max().item() + 1) for i in range(lvl_num)]

        def distance(level):
            return abs(class_cnts[level] - self.num_classes)

        candidates = [i for i, cnt in enumerate(class_cnts) if cnt >= self.num_classes]
        sel_idx = min(candidates or range(lvl_num), key=distance)
        return labels[..., sel_idx]

    def _prepare_inputs(self, p0, x0, o0, is_train):
        """Normalise the inputs to flat ``(p0, x0, o0, labels)`` tensors.

        When ``x0`` is None, ``p0`` is treated as a dict with key ``'pos'`` and
        optional keys ``'x'``, ``'offset'`` and ``'y'``. ``labels`` is None when
        the inputs are given as tensors or the dict has no ``'y'`` entry.
        """
        labels = None  # default
        if x0 is None:  # this means p0 is a dict.
            data = p0
            p0 = data['pos']
            x0 = data.get('x', None)
            o0 = data.get('offset', None)
            labels = data.get('y', None)
            if o0 is not None:
                o0 = o0.to(torch.int32)

            if is_train and (labels is not None) and labels.dim() > 1 and labels.shape[-1] > 1:
                labels = self._select_label_level(labels)

            feats_are_coords = x0 is None
            if feats_are_coords:
                x0 = p0
            if o0 is None:
                # No offsets supplied: `pos` is taken to be (B, N, 3) with N
                # points per sample, so the offsets are N, 2N, ..., B*N.
                num_samples, pts_per_sample = p0.size(0), p0.size(1)
                o0 = torch.tensor([pts_per_sample * (i + 1) for i in range(num_samples)],
                                  dtype=torch.int32, device=p0.device)
            # ---- robust flatten ------------------------------------------------
            if x0.dim() == 3:
                if feats_are_coords:
                    # Coordinates are already channel-last (B, N, 3); they
                    # must not be transposed like (B, C, N) features.
                    x0 = x0.reshape(-1, x0.size(-1))
                else:
                    channels = x0.size(1)  # x0: (B, C, N)
                    x0 = x0.transpose(1, 2).contiguous().view(-1, channels)  # (B*N, C)
                if p0.dim() == 3:
                    p0 = p0.reshape(-1, p0.size(-1))  # (B*N, 3)
                if is_train and (labels is not None) and labels.dim() > 1:
                    labels = labels.reshape(-1)
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)  # x0(n, c=in_channels+3)
        return p0, x0, o0, labels

    def _group_points_by_class(self, p0, x0, o0, labels):
        """Split every sample into per-class groups and concatenate them.

        Returns:
            ``(coord, feat, offset, group_labels)`` where ``offset`` holds the
            cumulative group sizes (int32) and ``group_labels[g]`` is the class
            index of the g-th group. Groups are ordered by sample, then class.
        """
        min_points = self._MIN_GROUP_POINTS
        # The per-class point indices do not depend on the sample, so compute
        # them once instead of once per (sample, class) pair.
        per_class_index = [(labels == c).nonzero().squeeze(1) for c in range(self.num_classes)]
        feats, coords, counts, group_labels = [], [], [], []
        start = 0
        for end in o0.tolist():
            for c, index in enumerate(per_class_index):
                selected = index[(index >= start) & (index < end)]
                point_num = selected.numel()
                if point_num == 0:
                    continue
                if point_num < min_points:
                    # NOTE: when the group has fewer points than the sampling
                    # rate, fill it up to the sampling rate by drawing with
                    # replacement.
                    pick = torch.randint(0, point_num, (min_points,), device=selected.device)
                    selected = selected[pick]
                feats.append(x0[selected, :])
                coords.append(p0[selected, :])
                counts.append(selected.numel())
                group_labels.append(c)
            start = end
        if not feats:
            raise RuntimeError(
                'PTSegV2_Balance_Prior: no point carries a label in [0, num_classes).')
        feat = torch.cat(feats, dim=0)
        coord = torch.cat(coords, dim=0)
        offset = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0,
                              dtype=torch.int32).to(p0.device)
        return coord, feat, offset, group_labels

    def forward(self, p0, x0=None, o0=None, is_train=False, mask=None, ignore_index=None):
        """Encode the per-class groups and refresh the class prototypes.

        Args:
            p0: data dict (see ``_prepare_inputs``) or coordinate tensor.
            x0: features, or None when ``p0`` is a dict.
            o0: cumulative per-sample point counts.
            is_train: enables label-level selection and label flattening
                while preparing the inputs.
            mask: optional boolean point mask applied to points and labels.
            ignore_index: if 0, the masked labels are shifted down by one.

        Returns:
            ``(current_prior, prior_ema)``: ``current_prior`` has one row per
            encoded point, holding the normalised feature followed by the class
            index; ``prior_ema`` is the normalised per-class running average.

        Raises:
            ValueError: if no labels are available (the prior needs them).
        """
        # p, x, o: points, features, batches
        p0, x0, o0, labels = self._prepare_inputs(p0, x0, o0, is_train)
        if labels is None:
            raise ValueError(
                "PTSegV2_Balance_Prior.forward needs labels: pass a data dict with a 'y' entry.")

        if mask is not None:
            p0, x0, labels = p0[mask], x0[mask], labels[mask]
            # New offsets = number of kept points before each old offset.
            kept = [int(mask[:end].sum()) for end in o0.tolist()]
            o0 = torch.tensor(kept, dtype=o0.dtype, device=o0.device)
            if ignore_index == 0:
                labels = labels - 1

        # aggregate the same class feature in an example of a batch
        coord, feat, offset, group_labels = self._group_points_by_class(p0, x0, o0, labels)

        # a batch of point cloud is a list of coord, feat and offset
        points = self.patch_embed([coord, feat, offset])
        for stage in self.enc_stages:
            # The pooling cluster returned by each stage is not needed here.
            points, _ = stage(points)
        feat, offset = points[1], points[2]  # feature info of the last enc stage
        feat = self.projection(feat)

        feat = nn.functional.normalize(feat, dim=1)
        current_prior = torch.cat([feat, torch.zeros([feat.size(0), 1], device=feat.device)], dim=1)
        # `offset` now holds the (pooled) cumulative size of each group; write
        # the group's class index into the last column of its rows.
        bounds = [0] + offset.tolist()
        for g, c_index in enumerate(group_labels):
            current_prior[bounds[g]:bounds[g + 1], -1] = c_index
        self._ema(current_prior.detach())
        self.prior_ema = nn.functional.normalize(self.prior_ema, dim=1)
        return current_prior, self.prior_ema


@MODELS.register_module()
class PTSegV2_Balance_Main(nn.Module):
    """Encoder-decoder segmentation network with one head per label level.

    A ``GVAPatchEmbed`` stem followed by ``Encoder`` / ``Decoder`` stages yields
    one feature vector per point.  Every entry of ``num_classes_per_level`` gets
    its own segmentation head; from the second level on, the softmax of the
    previous level's logits can be projected and fused into the point features
    before classification.  One EMA prototype matrix per level is kept as a
    registered buffer (``prior_ema_<lvl>``) and updated during training from the
    points whose prediction matches the label.
    """

    def __init__(self,
                 in_channels,
                 num_classes_per_level: Optional[list] = None,
                 # softmax-guided conditioning args
                 use_prev_softmax_guidance: bool = True,
                 cond_temperature: float = 1.0,
                 cond_alpha: float = 1.0,
                 detach_prev_guidance: bool = True,
                 use_projection: bool = False,
                 k_neighbors: int = 40,
                 patch_embed_depth=2,
                 patch_embed_channels=48,
                 patch_embed_groups=6,
                 patch_embed_neighbours=16,
                 enc_depths=(2, 6, 2),
                 enc_channels=(96, 192, 384),
                 enc_groups=(12, 24, 48),
                 enc_neighbours=(16, 16, 16),
                 dec_depths=(1, 1, 1),
                 dec_channels=(48, 96, 192),
                 dec_groups=(6, 12, 24),
                 dec_neighbours=(16, 16, 16),
                 grid_sizes=(0.1, 0.2, 0.4),
                 attn_qkv_bias=True,
                 pe_multiplier=False,
                 pe_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.3,
                 enable_checkpoint=False,
                 unpool_backend="interp",
                 beta=0.999,
                 **kwargs):
        """Build the backbone, the per-level heads and the EMA prototype buffers.

        Args:
            in_channels: Channel count of the per-point input features.
            num_classes_per_level: Number of classes for each label level;
                defaults to ``[13]`` (a single level).
            use_prev_softmax_guidance: Condition level ``l > 0`` on the softmax
                of level ``l - 1``.
            cond_temperature: Divisor applied to the previous logits before the
                softmax.
            cond_alpha: Scale applied to the projected guidance (skipped when
                exactly ``1.0``).
            detach_prev_guidance: Detach the previous logits before they are
                used as guidance.
            use_projection: Give every level its own Linear-BN-ReLU projection
                of the backbone feature instead of an identity.
            k_neighbors: Stored on the module; not read by the current forward
                pass.
            beta: EMA momentum of the prototype buffers.
            **kwargs: Accepted and ignored.

        The remaining arguments are forwarded unchanged to ``GVAPatchEmbed``,
        ``Encoder`` and ``Decoder``.
        """
        super().__init__()
        self.in_channels = in_channels
        if num_classes_per_level is None:
            num_classes_per_level = [13]
        self.num_classes_per_level = num_classes_per_level
        self.num_levels = len(num_classes_per_level)
        self.num_stages = len(enc_depths)
        self.beta = beta
        self.use_projection = use_projection
        self.k_neighbors = k_neighbors
        # conditioning config
        self.use_prev_softmax_guidance = use_prev_softmax_guidance
        self.cond_temperature = float(cond_temperature)
        self.cond_alpha = float(cond_alpha)
        self.detach_prev_guidance = bool(detach_prev_guidance)
        assert self.num_stages == len(dec_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(dec_channels)
        assert self.num_stages == len(enc_groups)
        assert self.num_stages == len(dec_groups)
        assert self.num_stages == len(enc_neighbours)
        assert self.num_stages == len(dec_neighbours)
        assert self.num_stages == len(grid_sizes)
        self.patch_embed = GVAPatchEmbed(
            in_channels=in_channels,
            embed_channels=patch_embed_channels,
            groups=patch_embed_groups,
            depth=patch_embed_depth,
            neighbours=patch_embed_neighbours,
            qkv_bias=attn_qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            enable_checkpoint=enable_checkpoint
        )

        # Drop-path rates grow linearly over all blocks; each stage takes its slice.
        enc_dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(enc_depths))]
        dec_dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(dec_depths))]
        enc_channels = [patch_embed_channels] + list(enc_channels)
        dec_channels = list(dec_channels) + [enc_channels[-1]]
        self.enc_stages = nn.ModuleList()
        self.dec_stages = nn.ModuleList()
        for i in range(self.num_stages):
            enc = Encoder(
                depth=enc_depths[i],
                in_channels=enc_channels[i],
                embed_channels=enc_channels[i + 1],
                groups=enc_groups[i],
                grid_size=grid_sizes[i],
                neighbours=enc_neighbours[i],
                qkv_bias=attn_qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=enc_dp_rates[sum(enc_depths[:i]):sum(enc_depths[:i + 1])],
                enable_checkpoint=enable_checkpoint
            )
            dec = Decoder(
                depth=dec_depths[i],
                in_channels=dec_channels[i + 1],
                skip_channels=enc_channels[i],
                embed_channels=dec_channels[i],
                groups=dec_groups[i],
                neighbours=dec_neighbours[i],
                qkv_bias=attn_qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dec_dp_rates[sum(dec_depths[:i]):sum(dec_depths[:i + 1])],
                enable_checkpoint=enable_checkpoint,
                unpool_backend=unpool_backend
            )
            self.enc_stages.append(enc)
            self.dec_stages.append(dec)

        self.feat_dim = dec_channels[0]
        # ---------------- Segmentation heads per level ----------------
        self.seg_heads = nn.ModuleList()
        for nc in self.num_classes_per_level:
            self.seg_heads.append(nn.Sequential(
                nn.Linear(self.feat_dim, self.feat_dim),
                PointBatchNorm(self.feat_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.feat_dim, nc)
            ))

        # ---------------- Per-level projection (optional) -------------
        self.proj_mlps = nn.ModuleList()
        for _ in range(self.num_levels):
            if self.use_projection:
                self.proj_mlps.append(nn.Sequential(
                    nn.Linear(self.feat_dim, self.feat_dim, bias=False),
                    PointBatchNorm(self.feat_dim),
                    nn.ReLU(inplace=True),
                ))
            else:
                self.proj_mlps.append(nn.Identity())

        # ---------------- Conditioning blocks for levels > 0 ----------
        # Each block: proj previous level probs -> channels, then fuse with current features
        self.cond_blocks = nn.ModuleList()
        for lvl in range(1, self.num_levels):
            prev_nc = self.num_classes_per_level[lvl - 1]
            block = nn.ModuleDict({
                'proj': nn.Linear(prev_nc, self.feat_dim, bias=True),
                'fuse': nn.Sequential(
                    nn.Linear(self.feat_dim * 2, self.feat_dim, bias=False),
                    PointBatchNorm(self.feat_dim),
                    nn.ReLU(inplace=True),
                )
            })
            self.cond_blocks.append(block)

        # ---------------- HEFM between levels (L-1 pairs) -------------
        # The forward pass below does not call these modules; they are still
        # constructed so their parameters stay part of the module's state dict.
        self.hefms = nn.ModuleList()
        for _ in range(1, self.num_levels):
            self.hefms.append(HEFM(
                dim_top=self.feat_dim,
                dim_bottom=self.feat_dim,
                alpha_init=0.9,
                tau=1.0,
            ))

        # ---------------- EMA prototypes per level -------------------
        # Only the registered buffers hold the prototypes; `prior_ema_list`
        # (property below) resolves them on access so it can never go stale.
        for lvl, nc in enumerate(self.num_classes_per_level):
            buf = torch.rand(nc, self.feat_dim)
            self.register_buffer(f"prior_ema_{lvl}", F.normalize(buf, dim=1))

    @property
    def prior_ema_list(self):
        """Live EMA prototype buffers, one ``(num_classes, feat_dim)`` tensor per level.

        The buffers are looked up by name on every access.  Caching the tensor
        objects at construction time would break after ``.to()`` / ``.cuda()``,
        because ``nn.Module`` replaces registered buffers when it moves them.
        """
        return [getattr(self, f"prior_ema_{lvl}") for lvl in range(self.num_levels)]

    @torch.no_grad()
    def _ema_update(self, lvl: int, feats: torch.Tensor, labels: torch.Tensor):
        """Blend the per-class mean of ``feats`` into the level-``lvl`` prototypes.

        Args:
            lvl: Index of the label level whose buffer is updated.
            feats: ``(M, feat_dim)`` features of the selected points.
            labels: ``(M,)`` class ids aligned with ``feats``.

        Classes without any sample keep their previous prototype.  The buffer
        is updated in place and re-normalised along dim 1.
        """
        buf = getattr(self, f"prior_ema_{lvl}")
        if dist.is_initialized():
            # Ship labels as an extra float column so one gather moves both.
            combined = torch.cat([feats, labels.unsqueeze(1).float()], dim=1)
            combined = concat_all_gather_diff(combined)
            labels = combined[:, -1].long()
            feats = combined[:, :-1]
        cur = buf.clone()
        for c in range(buf.size(0)):
            m = labels == c
            if m.any():
                cur[c] = feats[m].mean(0)
        buf.copy_(F.normalize(self.beta * buf + (1 - self.beta) * cur, dim=1))

    def _encode_decode(self, points):
        """Run patch embedding, all encoder stages and all decoder stages."""
        points = self.patch_embed(points)
        skips = [[points]]
        for i in range(self.num_stages):
            points, cluster = self.enc_stages[i](points)
            skips[-1].append(cluster)  # record grid cluster of pooling
            skips.append([points])  # record points info of current stage

        points = skips.pop(-1)[0]  # unpooling points info in the last enc stage
        for i in reversed(range(self.num_stages)):
            skip_points, cluster = skips.pop(-1)
            points = self.dec_stages[i](points, skip_points, cluster)
        return points

    def _predict_levels(self, feat):
        """Classify every level in order, conditioning on the previous level.

        Returns:
            ``(logits_list, conditioned_feats)`` with one entry per level.
        """
        feats_list = [proj(feat) for proj in self.proj_mlps]
        logits_list = []
        conditioned_feats = []
        for lvl in range(self.num_levels):
            feat_lvl = feats_list[lvl]
            if lvl > 0 and self.use_prev_softmax_guidance:
                prev_logits = logits_list[lvl - 1]
                if self.detach_prev_guidance:
                    prev_logits = prev_logits.detach()
                probs = F.softmax(prev_logits / self.cond_temperature, dim=1)
                block = self.cond_blocks[lvl - 1]
                # feats are [N, C]; probs are [N, C_prev]
                cond = block['proj'](probs)
                if self.cond_alpha != 1.0:
                    cond = cond * self.cond_alpha
                feat_lvl = block['fuse'](torch.cat([feat_lvl, cond], dim=1))
            conditioned_feats.append(feat_lvl)
            logits_list.append(self.seg_heads[lvl](feat_lvl))
        return logits_list, conditioned_feats

    def forward(self, p0, x0=None, o0=None, is_train=False, mask=None, ignore_index=None):
        """Segment a batch of point clouds.

        Args:
            p0: Point coordinates, or (when ``x0`` is ``None``) a dict with key
                ``'pos'`` and optional keys ``'x'``, ``'offset'`` and ``'y'``.
            x0: Per-point features; ``None`` means ``p0`` is the dict form.
            o0: Cumulative point counts, one entry per sample.
            is_train: When true, also update and return the EMA prototypes.
            mask: Optional per-point mask; points, features, labels and offsets
                are restricted to the kept points.
            ignore_index: When ``0`` and a mask is given, labels are shifted
                down by one.

        Returns:
            ``logits_list`` when ``is_train`` is false, otherwise
            ``(logits_list, feats_norm_list, prior_buffers)``; every list has
            one entry per label level.
        """
        labels = None
        if x0 is None:  # this means p0 is a dict.
            data = p0
            p0 = data['pos']
            x0 = data.get('x', None)
            o0 = data.get('offset', None)
            labels = data.get('y', None)
            if x0 is None:
                x0 = p0
            if o0 is None:
                # No offsets supplied: every sample contributes p0.size(1) points.
                per_sample = p0.size(1)
                ends = [per_sample * (b + 1) for b in range(p0.size(0))]
                o0 = torch.tensor(ends, dtype=torch.int32, device=p0.device)
            else:
                o0 = o0.to(torch.int32)
            # ---- robust flatten ------------------------------------------------
            if x0.dim() == 3:  # x0: (B, C, N)
                c_in = x0.size(1)
                x0 = x0.transpose(1, 2).contiguous().view(-1, c_in)
                if p0.dim() == 3:
                    p0 = p0.reshape(-1, p0.size(-1))
                if is_train and (labels is not None) and labels.dim() > 1:
                    # this will be (N, L)
                    labels = labels.reshape(-1, labels.shape[-1])
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)  # x0(n, c=in_channels+3)

        if mask is not None:
            p0, x0 = p0[mask], x0[mask]
            if labels is not None:
                labels = labels[mask]
                if ignore_index == 0:
                    labels = labels - 1
            # New cumulative offsets = number of kept points up to each old offset.
            kept = [int(mask[:int(end)].sum()) for end in o0]
            o0 = torch.tensor(kept, dtype=torch.int32, device=p0.device)

        # a batch of point cloud is a list of coord, feat and offset
        _, feat, _ = self._encode_decode([p0, x0, o0])

        logits_list, conditioned_feats = self._predict_levels(feat)
        if not is_train:
            return logits_list

        # --------------------------------------------------
        # Build prototype inputs (losses are computed by the train loop)
        # --------------------------------------------------
        feats_norm_list = [F.normalize(ft, dim=1) for ft in conditioned_feats]

        num_pts = feats_norm_list[0].size(0)
        if labels is not None:
            # Labels may arrive with a leading shape that does not match the
            # flattened points; fix it up when the element count proves that a
            # plain reshape is what is needed.
            total_elems = labels.numel()
            if self.num_levels > 1:
                expected = num_pts * self.num_levels
                if total_elems == expected and labels.shape[0] != num_pts:
                    labels = labels.reshape(num_pts, self.num_levels)
            else:
                if total_elems == num_pts and labels.shape[0] != num_pts:
                    labels = labels.reshape(num_pts)

        # ---------------- EMA updates per level ----------------
        if labels is not None and labels.numel() > 0:
            # ensure labels shape (..., L)
            if labels.dim() == 1:
                labels = labels.unsqueeze(-1)
            for lvl in range(self.num_levels):
                # Fall back to the last label column when fewer columns than levels exist.
                lbl_lvl = labels[..., lvl] if labels.size(-1) > lvl else labels[..., -1]
                preds = logits_list[lvl].argmax(dim=1)
                mask_true = preds == lbl_lvl
                # NOTE(review): under torch.distributed this guard is evaluated
                # per rank while _ema_update performs a gather -- confirm that
                # ranks cannot disagree here.
                if mask_true.any():
                    self._ema_update(lvl, feats_norm_list[lvl].detach()[mask_true], lbl_lvl[mask_true])

        prior_buffers = self.prior_ema_list

        return logits_list, feats_norm_list, prior_buffers

# =============================================================
#  PointTransformer-V3  Prior branch (encoder only)
# =============================================================
from PointTransformerV3.model import PointTransformerV3  # noqa: E402
import torch.nn.functional as F  # noqa: E402


@MODELS.register_module()
class PTSegV3_Balance_Prior(nn.Module):
    """Prior branch built on the PointTransformer-V3 encoder.

    For every sample and every class present in it, a fixed-size group of
    points is drawn, encoded, projected and L2-normalised.  The resulting rows
    (feature + class id in the last column) are returned and also folded into
    a per-class EMA prototype buffer ``prior_ema``.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int = 13,
        beta: float = 0.999,
        grid_size: float = 0.02,
        encoder_kwargs: Optional[Dict] = None,
        points_per_class: int = 1024,
    ) -> None:
        """
        Args:
            in_channels: Input feature channels passed to the encoder.
            num_classes: Number of classes (rows of ``prior_ema``).
            beta: EMA momentum of the prototypes.
            grid_size: Value stored under ``"grid_size"`` in the encoder input.
            encoder_kwargs: Extra keyword arguments for ``PointTransformerV3``.
            points_per_class: Size of every per-sample, per-class point group
                fed to the encoder.

        Raises:
            ValueError: If ``points_per_class`` is not positive.
        """
        super().__init__()
        if points_per_class <= 0:
            raise ValueError(f"points_per_class must be positive, got {points_per_class}")
        self.num_classes = num_classes
        self.beta = beta
        self.grid_size = grid_size
        self.points_per_class = int(points_per_class)
        encoder_kwargs = encoder_kwargs or {}

        # 1) Point-Transformer-V3 encoder (decoder disabled via cls_mode=True)
        self.encoder = PointTransformerV3(
            in_channels=in_channels,
            cls_mode=True,
            **encoder_kwargs,
        )
        # Channel count of the last block of the last encoder stage.
        enc_out_dim = self.encoder.enc[-1][-1].channels

        # Prototype width: first entry of `dec_channels` when configured, else 64.
        dec_chs = encoder_kwargs.get("dec_channels", None)
        if dec_chs is None:
            proto_dim = 64
        elif isinstance(dec_chs, (list, tuple)):
            proto_dim = dec_chs[0]
        else:
            proto_dim = int(dec_chs)

        mid_dim = enc_out_dim // 2  # projection hidden dim

        # 2) Projection to the low-dimensional prototype space
        self.projection = nn.Sequential(
            nn.Linear(enc_out_dim, mid_dim),
            nn.BatchNorm1d(mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, proto_dim),
            nn.BatchNorm1d(proto_dim),
            nn.ReLU(inplace=True),
        )

        # 3) EMA prototypes (buffer), initialised to random unit rows
        self.register_buffer(
            "prior_ema", F.normalize(torch.rand(num_classes, proto_dim), dim=1)
        )

    # ---------------------------------------------------------------------
    # helpers
    # ---------------------------------------------------------------------
    @torch.no_grad()
    def _ema(self, prior: torch.Tensor):
        """Update EMA prototypes.

        ``prior`` has shape ``[N, D + 1]``; the last column stores the class id.
        Classes without any row keep their previous prototype.
        """
        if dist.is_initialized():
            prior = concat_all_gather_diff(prior)
        cur_status = self.prior_ema.clone()
        for cls in range(self.num_classes):
            mask = prior[:, -1] == cls
            if mask.any():
                cur_status[cls] = prior[mask, :-1].mean(0)
        self.prior_ema = self.beta * self.prior_ema + (1 - self.beta) * cur_status

    def _row_labels(self, point_out, group_classes: torch.Tensor, num_rows: int) -> torch.Tensor:
        """Return the class id of every encoder output row.

        The encoder output need not have ``points_per_class`` rows per input
        group, so rows are mapped to their group through the output's ``batch``
        index when it is available and consistent.  Otherwise the input layout
        (consecutive blocks of ``points_per_class`` rows) is assumed.
        """
        if isinstance(point_out, dict):
            batch_idx = point_out.get("batch")
        else:
            batch_idx = getattr(point_out, "batch", None)
        if (
            torch.is_tensor(batch_idx)
            and batch_idx.numel() == num_rows
            and num_rows > 0
            and int(batch_idx.max()) < group_classes.numel()
        ):
            return group_classes[batch_idx.reshape(-1).long()]

        # Fallback: fixed-size blocks in input order.
        row_labels = group_classes.new_zeros(num_rows)
        size = self.points_per_class
        for group, cls in enumerate(group_classes):
            row_labels[group * size:(group + 1) * size] = cls
        return row_labels

    # ---------------------------------------------------------------------
    # forward
    # ---------------------------------------------------------------------
    def forward(
        self,
        p0,
        x0: Optional[torch.Tensor] = None,
        o0: Optional[torch.Tensor] = None,
        *,
        is_train: bool = False,
        mask: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = None,
    ):
        """Compute the current batch priors and update the EMA prototypes.

        Args:
            p0: Point coordinates, or (when ``x0`` is ``None``) a dict with key
                ``"pos"`` and optional keys ``"x"``, ``"offset"`` and ``"y"``.
            x0: Per-point features; ``None`` means ``p0`` is the dict form.
            o0: Cumulative point counts, one entry per sample.
            is_train: Enables label-level selection and label flattening.
            mask: Optional per-point mask applied to points, features, labels.
            ignore_index: When ``0`` and a mask is given, labels are shifted
                down by one.

        Returns:
            ``(current_prior, prior_ema)``: ``current_prior`` holds one row per
            encoder output point (normalised feature plus class id in the last
            column); ``prior_ema`` is the normalised prototype buffer.

        Raises:
            ValueError: If no labels are available.
            RuntimeError: If no point carries a label in ``[0, num_classes)``.
        """
        # --------------------------------------------------
        # 0) Unpack dataloader dict variant
        # --------------------------------------------------
        labels = None
        if x0 is None:  # p0 is dict
            data = p0
            p0 = data["pos"]
            x0 = data.get("x", None)
            o0 = data.get("offset", None)
            labels = data.get("y", None)

            # Dynamic level selection: pick the label level whose class count
            # best matches self.num_classes.  Done before any reshaping so
            # both (B, N, L) and (N, L) layouts are handled.
            if is_train and (labels is not None) and labels.dim() > 1 and labels.shape[-1] > 1:
                lvl_num = labels.shape[-1]
                class_cnts = [(labels[..., i].max().item() + 1) for i in range(lvl_num)]
                # Prefer levels with at least num_classes classes; else consider all.
                candidates = [i for i, cnt in enumerate(class_cnts) if cnt >= self.num_classes]
                pool = candidates if candidates else range(lvl_num)
                sel_idx = min(pool, key=lambda i: abs(class_cnts[i] - self.num_classes))
                labels = labels[..., sel_idx]

            if x0 is None:
                x0 = p0
            if o0 is None:
                # No offsets supplied: every sample contributes p0.size(1) points.
                per_sample = p0.size(1)
                ends = [per_sample * (b + 1) for b in range(p0.size(0))]
                o0 = torch.tensor(ends, dtype=torch.int32, device=p0.device)
            # flatten B×N to N
            if x0.dim() == 3:  # (B,C,N)
                c_in = x0.size(1)
                x0 = x0.transpose(1, 2).contiguous().view(-1, c_in)
                p0 = p0.reshape(-1, 3)
                if is_train and (labels is not None) and labels.dim() > 1:
                    labels = labels.reshape(-1)
        if labels is None:
            raise ValueError(
                "PTSegV3_Balance_Prior.forward needs labels: pass a dict input with key 'y'"
            )
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)

        # --------------------------------------------------
        # 1) Optional masking
        # --------------------------------------------------
        if mask is not None:
            p0, x0, labels = p0[mask], x0[mask], labels[mask]
            if ignore_index == 0:
                labels = labels - 1
            # New cumulative offsets = number of kept points up to each old offset.
            kept = [int(mask[:int(off)].sum()) for off in o0]
            o0 = torch.tensor(kept, dtype=torch.int32, device=p0.device)

        # --------------------------------------------------
        # 2) Per-sample per-class aggregation
        # --------------------------------------------------
        group_size = self.points_per_class
        feat_list, coord_list, off_list, class_idx_batches = [], [], [], []
        start = 0
        for end in o0.tolist():
            class_indices = []
            sample_labels = labels[start:end]
            for cls in range(self.num_classes):
                sel = (sample_labels == cls).nonzero().squeeze(1)
                if sel.numel() == 0:
                    continue
                # Resample with replacement, or truncate, to exactly group_size points.
                if sel.numel() < group_size:
                    sel = sel[torch.randint(0, sel.numel(), (group_size,), device=sel.device)]
                elif sel.numel() > group_size:
                    sel = sel[:group_size]
                global_sel = sel + start
                feat_list.append(x0[global_sel])
                coord_list.append(p0[global_sel])
                off_list.append(group_size)
                class_indices.append(cls)
            class_idx_batches.append(class_indices)
            start = end
        if not feat_list:
            raise RuntimeError(
                f"no point carries a label in [0, {self.num_classes}); cannot build priors"
            )
        feat = torch.cat(feat_list, 0)
        coord = torch.cat(coord_list, 0)
        offset = torch.cumsum(torch.tensor(off_list, device=p0.device, dtype=torch.int32), 0)

        # --------------------------------------------------
        # 3) V3 encoder forward
        # --------------------------------------------------
        data_dict = {
            "coord": coord,
            "feat": feat,
            "offset": offset,
            "grid_size": self.grid_size,
        }
        point_out = self.encoder(data_dict)
        feat_enc = point_out.feat  # [N, C_enc]
        feat_proj = self.projection(feat_enc)
        feat_proj = F.normalize(feat_proj, dim=1)

        # --------------------------------------------------
        # 4) Build current prior tensor with labels
        # --------------------------------------------------
        group_classes = torch.tensor(
            [cls for batch_classes in class_idx_batches for cls in batch_classes],
            dtype=torch.get_default_dtype(),
            device=feat_proj.device,
        )
        row_labels = self._row_labels(point_out, group_classes, feat_proj.size(0))
        current_prior = torch.cat([feat_proj, row_labels.unsqueeze(1)], dim=1)

        # --------------------------------------------------
        # 5) EMA update & return
        # --------------------------------------------------
        self._ema(current_prior.detach())
        self.prior_ema = F.normalize(self.prior_ema, dim=1)
        return current_prior, self.prior_ema


# =============================================================
#  PointTransformer-V3  Main branch (encoder + decoder)
# =============================================================
from PointTransformerV3.model import PointTransformerV3  # noqa: E402

@MODELS.register_module()
class PTSegV3_Balance_Main(nn.Module):
    """Main segmentation branch based on PointTransformer-V3 encoder+decoder.
    The forward / loss logic mimics PTSegV2_Balance_Main so that
    `train_one_epoch()` can work unmodified (expects 3-tuple output).
    """

    def __init__(
        self,
        in_channels: int,
        num_classes_per_level: Optional[list] = None,
        # softmax-guided conditioning args
        use_prev_softmax_guidance: bool = True,
        cond_temperature: float = 1.0,
        cond_alpha: float = 1.0,
        detach_prev_guidance: bool = True,
        use_projection: bool = False,
        k_neighbors: int = 40,
        enable_hefm: bool = False,
        beta: float = 0.999,
        grid_size: float = 0.02,
        encoder_kwargs: Optional[Dict] = None,
    ) -> None:
        """Assemble the PT-V3 backbone, per-level heads and EMA prototype buffers.

        Args:
            in_channels: Input feature channels passed to the backbone.
            num_classes_per_level: Class count of every label level; ``None``
                means a single level with 13 classes.
            use_prev_softmax_guidance: Stored flag for softmax conditioning.
            cond_temperature: Stored as ``float``.
            cond_alpha: Stored as ``float``.
            detach_prev_guidance: Stored as ``bool``.
            use_projection: Use a Linear-BN-ReLU projection per level instead
                of an identity.
            k_neighbors: Stored on the module.
            enable_hefm: Build one ``HEFM`` module per adjacent level pair.
            beta: EMA momentum of the prototype buffers.
            grid_size: Stored on the module.
            encoder_kwargs: Extra keyword arguments for ``PointTransformerV3``;
                its ``"dec_channels"`` entry, if any, sets the feature width.
        """
        super().__init__()
        level_sizes = [13] if num_classes_per_level is None else num_classes_per_level
        backbone_cfg = encoder_kwargs or {}

        # ---- plain configuration attributes ----------------------------
        self.num_classes_per_level = level_sizes
        self.num_levels = len(level_sizes)
        self.beta = beta
        self.grid_size = grid_size

        # ---- backbone: PT-V3 with its decoder enabled (cls_mode=False) ---
        self.backbone = PointTransformerV3(
            in_channels=in_channels,
            cls_mode=False,
            **backbone_cfg,
        )

        # Feature width: the first configured decoder channel, a scalar
        # `dec_channels` cast to int, or 64 when nothing is configured.
        dec_cfg = backbone_cfg.get("dec_channels", None)
        if isinstance(dec_cfg, (list, tuple)):
            width = dec_cfg[0]
        elif dec_cfg is None:
            width = 64
        else:
            width = int(dec_cfg)
        self.feat_dim = width

        # ---- one linear classifier per level ---------------------------
        heads = nn.ModuleList()
        for n_cls in level_sizes:
            heads.append(nn.Linear(width, n_cls))
        self.seg_heads = heads

        # ---- per-level projection (identity unless requested) ----------
        self.use_projection = use_projection
        self.k_neighbors = k_neighbors

        def _make_projection():
            # Local factory so every level gets its own, freshly built module.
            if not self.use_projection:
                return nn.Identity()
            return nn.Sequential(
                nn.Linear(self.feat_dim, self.feat_dim, bias=False),
                PointBatchNorm(self.feat_dim),
                nn.ReLU(inplace=True),
            )

        self.proj_mlps = nn.ModuleList([_make_projection() for _ in range(self.num_levels)])

        # ---- conditioning blocks for every level after the first -------
        self.use_prev_softmax_guidance = use_prev_softmax_guidance
        self.cond_temperature = float(cond_temperature)
        self.cond_alpha = float(cond_alpha)
        self.detach_prev_guidance = bool(detach_prev_guidance)

        self.cond_blocks = nn.ModuleList()
        for prev_n_cls in level_sizes[:-1]:
            # 'proj' is built before 'fuse', as dict-literal order would do.
            guidance_proj = nn.Linear(prev_n_cls, self.feat_dim, bias=True)
            fuse = nn.Sequential(
                nn.Linear(self.feat_dim * 2, self.feat_dim, bias=False),
                PointBatchNorm(self.feat_dim),
                nn.ReLU(inplace=True),
            )
            self.cond_blocks.append(nn.ModuleDict({'proj': guidance_proj, 'fuse': fuse}))

        # ---- optional HEFM modules, one per adjacent level pair --------
        self.enable_hefm = bool(enable_hefm)
        self.hefms = nn.ModuleList()
        pair_count = self.num_levels - 1 if self.enable_hefm else 0
        for _ in range(pair_count):
            self.hefms.append(HEFM(
                dim_top=self.feat_dim,
                dim_bottom=self.feat_dim,
                alpha_init=0.9,
                tau=1.0,
            ))

        # ---- EMA prototypes: one unit-normalised random buffer per level
        # NOTE(review): this list stores the tensor objects that exist right
        # now; confirm it stays in sync with the registered buffers after
        # the module is moved with .to()/.cuda().
        self.prior_ema_list = []
        for lvl, n_cls in enumerate(level_sizes):
            buffer_name = f"prior_ema_{lvl}"
            self.register_buffer(buffer_name, F.normalize(torch.rand(n_cls, width), dim=1))
            self.prior_ema_list.append(getattr(self, buffer_name))

    # ------------------------------------------------------------------
    # EMA update helper (generic level)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def _ema_update(self, lvl: int, feats: torch.Tensor, labels: torch.Tensor):
        """Blend the per-class mean of ``feats`` into the level-``lvl`` prototypes.

        Args:
            lvl: Index of the label level whose prototype buffer is updated.
            feats: ``(M, feat_dim)`` features of the selected points.
            labels: ``(M,)`` class ids aligned with ``feats``.

        Classes without any sample keep their previous prototype.  The buffer
        is updated in place and re-normalised along dim 1.
        """
        # Resolve the registered buffer by name instead of trusting the tensor
        # cached in `prior_ema_list`: nn.Module replaces buffers when it is
        # moved with .to()/.cuda(), which would leave the cached one orphaned.
        buf = getattr(self, f"prior_ema_{lvl}")
        # Re-sync the cached entry so readers of the list see the live buffer.
        self.prior_ema_list[lvl] = buf
        if dist.is_initialized():
            # Ship labels as an extra float column so one gather moves both.
            gathered = concat_all_gather_diff(
                torch.cat([feats, labels.unsqueeze(1).float()], dim=1)
            )
            labels = gathered[:, -1].long()
            feats = gathered[:, :-1]
        cur = buf.clone()
        for c in range(buf.size(0)):
            m = labels == c
            if m.any():
                cur[c] = feats[m].mean(0)
        buf.copy_(F.normalize(self.beta * buf + (1 - self.beta) * cur, dim=1))

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------
    def forward(
        self,
        p0,
        x0: Optional[torch.Tensor] = None,
        o0: Optional[torch.Tensor] = None,
        *,
        is_train: bool = False,
        mask: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = None,
    ):
        """Run the backbone, optional HEFM fusion and the per-level heads.

        Signature is kept reusable with the previous main branch.

        Args:
            p0: Either the point coordinates, or (when ``x0`` is None) a dict
                with key ``"pos"`` and optional keys ``"x"``, ``"offset"``
                and ``"y"`` (labels).
            x0: Point features. A 3-D tensor is unpacked as (B, C, N) and
                flattened to (B*N, C) on the dict path.
            o0: Cumulative per-cloud point counts; built on the dict path when
                absent.
            is_train: When True, prototypes are updated and extra outputs are
                returned.
            mask: Optional index/boolean selector applied to coordinates,
                features and labels before the backbone.
            ignore_index: If equal to 0 (and a mask is given), labels are
                shifted down by one.

        Returns:
            ``logits_list`` (one logits tensor per level) when ``is_train`` is
            False; otherwise the tuple
            ``(logits_list, feats_norm_list, prior_buffers)``.
        """
        # --------------------------------------------------
        # 0) Unpack dataloader variant dict
        # --------------------------------------------------
        labels = None  # default, may be filled later
        # NOTE(review): coarse_labels is extracted, masked and shifted below but
        # never read afterwards in this method - confirm whether level 0 of the
        # EMA update was meant to use it.
        coarse_labels = None  # for hierarchical
        if x0 is None:  # p0 is dict
            p0, x0, o0, labels = (
                p0["pos"],
                p0.get("x", None),
                p0.get("offset", None),
                p0.get("y", None),
            )
            if x0 is None:
                x0 = p0
            if o0 is None:
                # Build the offset lazily.
                if p0.dim() == 3:
                    # Batched (B, N, 3) input: every cloud holds N points, so
                    # the cumulative offsets are N, 2N, ..., B*N.
                    n_clouds, n_per_cloud = p0.size(0), p0.size(1)
                    o0 = torch.arange(
                        1, n_clouds + 1, dtype=torch.int32, device=p0.device
                    ) * n_per_cloud
                else:
                    # Already flat input without an offset: treat it as a
                    # single cloud instead of emitting one bogus offset per
                    # point.
                    o0 = torch.tensor([p0.size(0)], dtype=torch.int32, device=p0.device)
            # flatten B×N to N if needed
            if x0.dim() == 3:
                B, C_in, N_pts = x0.size()
                x0 = x0.transpose(1, 2).contiguous().view(-1, C_in)
                if p0.dim() == 3:
                    p0 = p0.reshape(-1, 3)
                # Handle hierarchical labels ([N,2]) – keep fine label only
                if labels is not None:
                    # Case (N,2)
                    if labels.dim() == 2 and labels.size(1) == 2:
                        coarse_labels = labels[:, 0].clone()
                        labels = labels[:, 1].clone()
                    # Case (B,N,2)
                    elif labels.dim() > 2 and labels.size(-1) == 2:
                        coarse_labels = labels[..., 0].reshape(-1).clone()
                        labels = labels[..., 1].reshape(-1).clone()
                # flatten if still has extra dims (e.g. B×N); labels may be
                # absent from the dict, so guard against None before .dim()
                if is_train and labels is not None and labels.dim() > 1:
                    labels = labels.reshape(-1)
        # concat xyz if feature dim <6
        if x0.size(1) < 6:
            x0 = torch.cat((p0, x0), 1)

        # --------------------------------------------------
        # 1) Optional masking
        # --------------------------------------------------
        if mask is not None:
            p0, x0, labels = p0[mask], x0[mask], labels[mask] if labels is not None else labels
            if coarse_labels is not None:
                coarse_labels = coarse_labels[mask]
            if ignore_index == 0 and labels is not None:
                labels = labels - 1
                if coarse_labels is not None:
                    coarse_labels = coarse_labels - 1
            # Recompute cumulative offsets: the new boundary of each cloud is
            # the number of kept points up to its old boundary.
            o0 = torch.tensor(
                [int(mask[: int(end)].sum()) for end in o0],
                dtype=torch.int32,
                device=p0.device,
            )

        # --------------------------------------------------
        # 2) Backbone forward
        # --------------------------------------------------
        data_dict = {
            "coord": p0,
            "feat": x0,
            "offset": o0,
            "grid_size": self.grid_size,
        }
        point_out = self.backbone(data_dict)
        if isinstance(point_out, Point):
            # Walk back up the pooling chain, concatenating each pooled feature
            # (scattered through the inverse index) onto its parent's feature.
            while "pooling_parent" in point_out.keys():
                assert "pooling_inverse" in point_out.keys()
                parent = point_out.pop("pooling_parent")
                inverse = point_out.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point_out.feat[inverse]], dim=-1)
                point_out = parent
            feat = point_out.feat
        else:
            print("no postprocess")
            feat = point_out.feat
        # ----- Per-level features with optional projection ----------
        feats_list = [proj(feat) for proj in self.proj_mlps]

        # ----- HEFM fusion (if configured) --------------------------
        if self.enable_hefm and len(self.hefms) > 0:
            # Snapshot of the un-fused projections; only needed (and therefore
            # only allocated) when fusion actually runs.
            feats_orig = [f.clone() for f in feats_list]
            # use point_out's coord/offset if available
            if isinstance(point_out, Point):
                coord = point_out.coord
                offset = point_out.offset
            else:
                coord = p0
                offset = o0
            num_pairs = len(self.hefms)
            refined_top_list = [[] for _ in range(num_pairs)]
            refined_bottom_list = [[] for _ in range(num_pairs)]
            idx_start = 0
            # Process each cloud separately; offset holds cumulative end indices.
            for b in range(offset.size(0)):
                idx_end = offset[b].item()
                # Neighbour indices are computed within the current cloud only
                # (query and support are the same slice).
                neighbor_idx, _ = pointops.knnquery(
                    self.k_neighbors,
                    coord[idx_start:idx_end], coord[idx_start:idx_end],
                    torch.tensor([idx_end - idx_start], dtype=torch.int32, device=coord.device),
                    torch.tensor([idx_end - idx_start], dtype=torch.int32, device=coord.device)
                )
                for pair_idx, hefm in enumerate(self.hefms):
                    z_top_orig = feats_orig[pair_idx][idx_start:idx_end]
                    z_bottom_orig = feats_orig[pair_idx + 1][idx_start:idx_end]
                    z_top_cur = feats_list[pair_idx][idx_start:idx_end]
                    z_bottom_ref = hefm.top_fusion(z_top_orig, z_bottom_orig)
                    z_top_ref = hefm.bottom_agg(z_top_cur, z_bottom_orig, neighbor_idx)
                    refined_top_list[pair_idx].append(z_top_ref)
                    refined_bottom_list[pair_idx].append(z_bottom_ref)
                idx_start = idx_end
            # Stitch the per-cloud results back together. Later pairs overwrite
            # the "bottom" entry written by the previous pair with their "top".
            for pair_idx in range(num_pairs):
                feats_list[pair_idx] = torch.cat(refined_top_list[pair_idx], dim=0)
                feats_list[pair_idx + 1] = torch.cat(refined_bottom_list[pair_idx], dim=0)

        # ----- Segmentation Heads with conditioning -----------------
        logits_list = []
        conditioned_feats = []
        for lvl in range(self.num_levels):
            feat_lvl = feats_list[lvl]
            if lvl > 0 and self.use_prev_softmax_guidance:
                # Condition this level on the softmax of the previous level.
                prev_logits = logits_list[lvl - 1]
                if self.detach_prev_guidance:
                    prev_logits = prev_logits.detach()
                probs = F.softmax(prev_logits / self.cond_temperature, dim=1)
                block = self.cond_blocks[lvl - 1]
                cond = block['proj'](probs)
                if self.cond_alpha != 1.0:
                    cond = cond * self.cond_alpha
                feat_lvl = block['fuse'](torch.cat([feat_lvl, cond], dim=1))
            conditioned_feats.append(feat_lvl)
            logits_list.append(self.seg_heads[lvl](feat_lvl))

        if not is_train:
            return logits_list

        # --------------------------------------------------
        # 3) Build prototype & losses (delegated to train loop)
        # --------------------------------------------------
        feats_norm_list = [F.normalize(ft, dim=1) for ft in conditioned_feats]

        num_pts = feats_norm_list[0].size(0)
        if labels is not None:
            # Case: labels incorrectly reshaped to (B*N, N_pts) producing
            # huge mismatch; detect by total elements count.
            total_elems = labels.numel()
            if self.num_levels > 1:
                expected = num_pts * self.num_levels
                if total_elems == expected and labels.shape[0] != num_pts:
                    labels = labels.view(num_pts, self.num_levels)
            else:
                if total_elems == num_pts and labels.shape[0] != num_pts:
                    labels = labels.view(num_pts)

        # ---------------- EMA updates per level ----------------
        if labels is not None and labels.numel() > 0:
            # ensure labels shape (..., L)
            if labels.dim() == 1:
                labels = labels.unsqueeze(-1)
            for lvl in range(self.num_levels):
                # Fall back to the last label column when fewer columns than
                # levels are available.
                lbl_lvl = labels[..., lvl] if labels.size(-1) > lvl else labels[..., -1]
                logits_lvl = logits_list[lvl]
                preds = logits_lvl.argmax(dim=1)
                # Only correctly classified points contribute to the prototypes.
                mask_true = preds == lbl_lvl
                # NOTE(review): _ema_update performs a cross-rank gather when
                # torch.distributed is initialised; skipping the call on a rank
                # without correct predictions may desynchronise that collective
                # - confirm that concat_all_gather_diff tolerates this.
                if mask_true.any():
                    self._ema_update(lvl, feats_norm_list[lvl].detach()[mask_true], lbl_lvl[mask_true])

        # ------------------------------------------------------
        # Return the registered buffers rather than the list cached at
        # construction: module conversions (.cuda()/.to()) rebind buffers to new
        # tensors, so the cached entries can be stale / on the wrong device.
        prior_buffers = [
            getattr(self, f"prior_ema_{lvl}") for lvl in range(len(self.prior_ema_list))
        ]

        return logits_list, feats_norm_list, prior_buffers
