import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist

from ..layers import furthest_point_sample, concat_all_gather_diff
from ..build import MODELS
from ..build import build_model_from_cfg


@MODELS.register_module()
class BaseSeg_Balance_Prior(nn.Module):
    """Class-balanced prototype branch backed by an EMA prototype bank.

    ``forward`` draws a fixed number of points for every class present in the
    batch, encodes each class subset on its own, projects and L2-normalises
    the encoder output and, in training mode, folds the class-wise mean
    embedding into the ``prior_ema`` buffer.
    """

    def __init__(self,
                 beta=0.999,
                 encoder_args=None,
                 cls_args=None,
                 num_classes=13,
                 points_per_class=1024,
                 **kwargs):
        """Build the encoder, the projection MLP and the prototype buffer.

        Args:
            beta: EMA momentum, i.e. the weight kept from the previous
                ``prior_ema`` at every update.
            encoder_args: config handed to ``build_model_from_cfg``; its
                ``mlps`` entry also supplies the projection layer widths.
            cls_args: accepted for config compatibility; not used here.
            num_classes: number of classes, i.e. rows of ``prior_ema``.
            points_per_class: number of points drawn per class before
                encoding (classes with fewer points are repeated).
            **kwargs: ignored.

        Raises:
            ValueError: if ``encoder_args`` is ``None``.
        """
        super().__init__()
        if encoder_args is None:
            raise ValueError('encoder_args is required to build the encoder')

        self.beta = beta
        self.num_classes = num_classes
        self.points_per_class = points_per_class
        self.encoder = build_model_from_cfg(encoder_args)

        # Projection widths are read from the last three entries of `mlps`.
        mlps = encoder_args.mlps
        enc_dim = mlps[-1][-1][-1]
        mid_dim = mlps[-2][-1][-1]
        out_dim = mlps[-3][-1][-1]
        self.projection = nn.Sequential(
            nn.Linear(enc_dim, mid_dim), nn.ReLU(inplace=True),
            nn.Linear(mid_dim, out_dim), nn.ReLU(inplace=True),
        )

        # EMA prototypes: one unit-norm row per class, kept as a buffer so
        # they are saved with the model but never touched by the optimizer.
        init_prior = F.normalize(torch.rand(self.num_classes, out_dim), dim=1)
        self.register_buffer('prior_ema', init_prior)

    @torch.no_grad()
    def _ema(self, prior):
        """Fold the class-wise mean of ``prior`` into ``prior_ema`` in place.

        Args:
            prior: ``(n, dim + 1)`` tensor; the last column holds the class
                label, the remaining columns the feature vector.
        """
        if dist.is_initialized():
            prior = concat_all_gather_diff(prior)

        feats, labels = prior[:, :-1], prior[:, -1]
        # Classes absent from this batch keep their current prototype.
        cur_status = self.prior_ema.clone()
        for label in range(self.num_classes):
            mask_c = labels == label
            if mask_c.any():
                cur_status[label] = feats[mask_c].mean(0)

        blended = self.prior_ema * self.beta + (1 - self.beta) * cur_status
        self.prior_ema.copy_(F.normalize(blended, dim=1, eps=1e-6))

    def _select_labels(self, labels, n_points):
        """Return a flat 1-D label tensor with one entry per point.

        A tensor holding exactly ``n_points`` entries is single-level and is
        simply flattened (this covers ``(B, N)`` as well as ``(B, N, 1)``).
        Otherwise the last axis is treated as label levels: each level's class
        count is estimated as ``max + 1`` and the level with the smallest count
        that still covers ``num_classes`` is chosen, falling back to the level
        whose count is closest to ``num_classes``.
        """
        if labels.dim() <= 1 or labels.numel() == n_points:
            return labels.reshape(-1)

        n_levels = labels.shape[-1]
        class_cnts = [labels[..., i].max().item() + 1 for i in range(n_levels)]

        def gap(i):
            return abs(class_cnts[i] - self.num_classes)

        candidates = [i for i, cnt in enumerate(class_cnts) if cnt >= self.num_classes]
        sel_idx = min(candidates or range(n_levels), key=gap)
        return labels[..., sel_idx].reshape(-1)

    def forward(self, data, is_train=False, mask=None, ignore_index=None):
        """Compute class-balanced embeddings and (when training) update the EMA.

        Args:
            data: mapping with ``'pos'``, ``'x'`` and ``'y'``.  ``'y'`` is
                required in every mode because sampling is label driven.
                Assumes ``pos`` is ``(B, N, 3)`` and ``x`` is ``(B, C, N)``
                -- TODO confirm against the data loader.
            is_train: when true, the EMA prototypes are updated from this
                batch; otherwise they are left untouched.
            mask: unused.
            ignore_index: unused.

        Returns:
            ``(current_prior, prior_ema)`` where ``current_prior`` is
            ``(n, dim + 1)`` with the class label in the last column (zero
            rows if no class in ``range(num_classes)`` is present).
        """
        p0, f0 = data['pos'], data['x']
        # Labels drive the per-class sampling below, so they are needed even
        # when is_train is False (previously this path hit an unbound local).
        labels = data['y']

        f0 = f0.transpose(1, 2).contiguous()
        f0 = f0.reshape(-1, f0.shape[-1])
        p0 = p0.reshape(-1, 3)
        labels = self._select_labels(labels, n_points=p0.shape[0])

        n_sample = self.points_per_class
        prior_chunks = []
        for c in range(self.num_classes):
            idx = torch.where(labels == c)[0]
            n_avail = idx.numel()
            if n_avail == 0:
                continue

            if n_avail < n_sample:
                # Too few points: tile the indices until n_sample are available.
                idx = idx.repeat(n_sample // n_avail + 1)[:n_sample]
            else:
                idx = idx[torch.randperm(n_avail, device=idx.device)[:n_sample]]

            feat_c = f0[idx, :]
            coord_c = p0[idx, :]

            # Each class subset is encoded as its own batch of size one.
            _, feat_enc_list = self.encoder.forward_all_features(
                coord_c.unsqueeze(0), feat_c.unsqueeze(0).transpose(1, 2))
            feat_enc = feat_enc_list[-1].squeeze(0).transpose(0, 1)  # (n_enc, C)

            feat_norm = F.normalize(self.projection(feat_enc), dim=1, eps=1e-6)

            lbl_c = torch.full((feat_norm.shape[0], 1), c,
                               device=feat_norm.device, dtype=torch.float)
            prior_chunks.append(torch.cat([feat_norm, lbl_c], dim=1))

        if not prior_chunks:
            empty = torch.empty(0, self.projection[2].out_features + 1, device=p0.device)
            return empty, self.prior_ema

        current_prior = torch.cat(prior_chunks, dim=0)
        if is_train:
            self._ema(current_prior.detach())

        return current_prior, self.prior_ema


@MODELS.register_module()
class BaseSeg_Balance_MainDual(nn.Module):
    """Shared encoder + *L* independent decoders and heads for hierarchical segmentation.

    Levels are processed in order.  From the second level on, the softmax of
    the previous level's logits can be projected and fused into the current
    level's features.  One EMA prototype matrix is kept per level.

    forward(is_train=False)  -> logits_list
    forward(is_train=True)   -> (logits_list_flat, feats_norm_list, prior_ema_list)
    """

    def __init__(self,
                 beta=0.999,
                 hier_levels: int = 2,
                 num_classes_per_level=None,
                 encoder_args=None,
                 decoder_args=None,
                 cls_head_args=None,
                 **kwargs):
        """Build encoder, per-level decoders/heads, prototypes and conditioning blocks.

        Args:
            beta: EMA momentum for the per-level prototypes.
            hier_levels: number of hierarchy levels (decoder/head pairs).
            num_classes_per_level: class count per level; defaults to
                ``[13, 4]`` and must have ``hier_levels`` entries.
            encoder_args: config for the shared encoder; also the base config
                every decoder config is derived from.
            decoder_args: one dict (shared by all levels) or a list of
                per-level dicts merged over a copy of ``encoder_args``.
            cls_head_args: one dict or a list of per-level head configs.
            **kwargs: optional ``use_prev_softmax_guidance``,
                ``cond_temperature``, ``cond_alpha``, ``detach_prev_guidance``
                and ``k_neighbors``.

        Raises:
            ValueError: if ``encoder_args`` is missing or the length of
                ``num_classes_per_level`` differs from ``hier_levels``.
        """
        super().__init__()
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if encoder_args is None:
            raise ValueError('encoder_args must be provided')

        if num_classes_per_level is None:
            num_classes_per_level = [13, 4]
        if len(num_classes_per_level) != hier_levels:
            raise ValueError(
                f'num_classes_per_level has {len(num_classes_per_level)} entries, '
                f'expected hier_levels={hier_levels}')

        # 1. shared encoder
        self.encoder = build_model_from_cfg(encoder_args)

        # ensure per-level lists (entries are only read or deep-copied below)
        if decoder_args is None:
            decoder_args = [{}] * hier_levels
        elif isinstance(decoder_args, dict):
            decoder_args = [decoder_args] * hier_levels

        if cls_head_args is None:
            cls_head_args = [{}] * hier_levels
        elif isinstance(cls_head_args, dict):
            cls_head_args = [cls_head_args] * hier_levels

        # 2. containers for per-level decoders & heads
        self.decoders = nn.ModuleList()
        self.heads = nn.ModuleList()
        self.priors = nn.ParameterList()
        self.hefms = nn.ModuleList()
        self.beta = beta
        self.hier_levels = hier_levels
        self.num_classes_per_level = num_classes_per_level
        self.in_channels_per_level = []

        # softmax-guided conditioning config
        self.use_prev_softmax_guidance = kwargs.get('use_prev_softmax_guidance', True)
        self.cond_temperature = float(kwargs.get('cond_temperature', 1.0))
        self.cond_alpha = float(kwargs.get('cond_alpha', 1.0))
        self.detach_prev_guidance = bool(kwargs.get('detach_prev_guidance', True))

        # K for KNN in BottomGuidedAggregation
        self.k_neighbors = kwargs.get('k_neighbors', 40)
        # NOTE(review): `KNN` is not bound by this file's visible imports --
        # confirm where it lives and import it, otherwise this line raises NameError.
        self.knn_layer = KNN(self.k_neighbors, farthest=False, sorted=False)

        enc_out_c = getattr(self.encoder, 'out_channels', None)

        for lvl in range(hier_levels):
            # decoder: encoder config overridden by the level's decoder config
            dec_cfg = copy.deepcopy(encoder_args)
            dec_cfg.update(decoder_args[lvl])
            dec_cfg.encoder_channel_list = getattr(self.encoder, 'channel_list', None)
            decoder = build_model_from_cfg(dec_cfg)
            self.decoders.append(decoder)

            # head: input width comes from the decoder, else the encoder,
            # else whatever the head config already specifies
            head_cfg = copy.deepcopy(cls_head_args[lvl])
            in_c = getattr(decoder, 'out_channels', None) or enc_out_c or head_cfg.get('in_channels')
            head_cfg['in_channels'] = in_c
            head_cfg['num_classes'] = num_classes_per_level[lvl]
            self.heads.append(build_model_from_cfg(head_cfg))
            self.in_channels_per_level.append(in_c)

            # EMA prototype for this level.  The ParameterList entry is an
            # explicit non-trainable Parameter: appending the raw tensor either
            # raises (older PyTorch) or is silently wrapped into a *trainable*
            # Parameter (newer PyTorch), letting the optimizer move prototypes
            # that are meant to be driven by the EMA only.
            prior = F.normalize(torch.randn(num_classes_per_level[lvl], in_c), dim=1)
            self.register_buffer(f'prior_ema_{lvl}', prior)
            self.priors.append(nn.Parameter(prior, requires_grad=False))

            if lvl > 0:
                # HEFM between the previous (top) and the current (bottom) level.
                # NOTE(review): `HEFM` is not bound by this file's visible
                # imports -- confirm and import it.
                self.hefms.append(HEFM(dim_top=self.in_channels_per_level[lvl - 1],
                                       dim_bottom=in_c,
                                       alpha_init=0.01,
                                       tau=1.0))

        # conditioning blocks for levels > 0 (dynamic by hier_levels)
        self.cond_blocks = nn.ModuleList()
        for lvl in range(1, hier_levels):
            prev_nc = num_classes_per_level[lvl - 1]
            curr_c = self.in_channels_per_level[lvl]
            block = nn.ModuleDict({
                # lifts previous-level class probabilities to the feature width
                'proj': nn.Conv1d(prev_nc, curr_c, kernel_size=1, bias=True),
                # merges [features ; projected probabilities] back to curr_c
                'fuse': nn.Sequential(
                    nn.Conv1d(curr_c * 2, curr_c, kernel_size=1, bias=False),
                    nn.BatchNorm1d(curr_c),
                    nn.ReLU(inplace=True)
                )
            })
            self.cond_blocks.append(block)

    # ------------------------------------------------------------------
    @torch.no_grad()
    def _ema(self, prior_buffer: torch.Tensor, feats: torch.Tensor, labels: torch.Tensor):
        """Update ``prior_buffer`` in place with the class-wise mean of ``feats``.

        Args:
            prior_buffer: ``(num_classes, C)`` prototypes, modified in place.
            feats: ``(n, C)`` features.
            labels: ``(n,)`` class index per feature row.
        """
        if dist.is_initialized():
            # Gather features and labels together so rows stay aligned.
            packed = torch.cat([feats, labels.unsqueeze(1).float()], dim=1)
            packed = concat_all_gather_diff(packed)
            labels = packed[:, -1].long()
            feats = packed[:, :-1]

        # Classes absent from this batch keep their current prototype.
        cur = prior_buffer.clone()
        for c in range(prior_buffer.size(0)):
            idx = labels == c
            if idx.any():
                cur[c] = feats[idx].mean(0)

        blended = prior_buffer * self.beta + cur * (1 - self.beta)
        prior_buffer.copy_(F.normalize(blended, dim=1))

    @staticmethod
    def _split_label_levels(y: torch.Tensor, pos: torch.Tensor):
        """Split ``y`` into a list with one label tensor per hierarchy level.

        ``y`` is treated as multi-level only when its last axis has more than
        one entry *and* it holds more values than there are points; a plain
        per-point tensor such as ``(B, N)`` therefore stays a single level
        instead of being split along ``N``.
        """
        n_points = pos.numel() // pos.shape[-1]
        multi_level = y.dim() > 1 and y.shape[-1] > 1 and y.numel() != n_points
        if multi_level:
            return [y[..., i] for i in range(y.shape[-1])]
        return [y]

    # ------------------------------------------------------------------
    def forward(self, data, is_train=False, mask=None, ignore_index=None):
        """Run all hierarchy levels.

        Args:
            data: mapping with ``'pos'`` and ``'x'``, plus ``'y'`` when
                ``is_train`` is true.
            is_train: selects the return signature (see below) and enables
                the EMA prototype update.
            mask: unused.
            ignore_index: unused.

        Returns:
            ``is_train=False``: list of per-level logits flattened to
            ``(B*N, num_classes)``.
            ``is_train=True``: ``(logits_flat_list, feats_norm_list, priors)``.
        """
        p0, x0 = data['pos'], data['x']  # [B,N,3] , [B,C,N]
        labels = self._split_label_levels(data['y'], p0) if is_train else None

        # encoder
        p_list, feats_enc = self.encoder.forward_all_features(p0, x0)

        feats_list = []
        logits_list = []
        # Run decoders in order, optionally conditioning each level on the
        # softmax of the previous level's logits.
        for lvl, decoder in enumerate(self.decoders):
            # Fresh copies per decoder -- presumably because a decoder may
            # modify the feature list it receives; TODO confirm.
            feats = [f.clone() for f in feats_enc]
            feat_lvl = decoder(p_list, feats).squeeze(-1)    # (B,C,N)

            if lvl > 0 and self.use_prev_softmax_guidance:
                prev_logits = logits_list[lvl - 1]
                if self.detach_prev_guidance:
                    # keep this level's loss from back-propagating into the previous head
                    prev_logits = prev_logits.detach()
                probs = F.softmax(prev_logits / self.cond_temperature, dim=1)
                block = self.cond_blocks[lvl - 1]
                cond = block['proj'](probs)
                if self.cond_alpha != 1.0:
                    cond = cond * self.cond_alpha
                feat_lvl = block['fuse'](torch.cat([feat_lvl, cond], dim=1))

            feats_list.append(feat_lvl)
            # logits are computed right away so the next level can condition on them
            logits_list.append(self.heads[lvl](feat_lvl))

        # NOTE: HEFM fusion is currently disabled.  When re-enabling it, refine
        # feats_list[i + 1] with self._apply_top_fusion_batch(self.hefms[i],
        # feats_list[i], feats_list[i + 1]) (neighbours via self._batch_knn).

        if not is_train:
            return [lg.transpose(1, 2).reshape(-1, lg.size(1)).contiguous() for lg in logits_list]

        logits_flat_list = []
        feats_norm_list = []
        for lvl, (lg, ft) in enumerate(zip(logits_list, feats_list)):
            lg_flat = lg.transpose(1, 2).reshape(-1, lg.size(1))
            feat_norm = F.normalize(ft.transpose(1, 2).reshape(-1, ft.size(1)), dim=1, eps=1e-6)
            logits_flat_list.append(lg_flat)
            feats_norm_list.append(feat_norm)

            # EMA update only for levels that actually have labels; fewer label
            # levels than hierarchy levels no longer raises IndexError.
            if lvl < len(labels):
                self._ema(self.priors[lvl], feat_norm.detach(), labels[lvl].reshape(-1))
                # Mirror into the registered buffer, which stops sharing
                # storage with the ParameterList entry once the module is moved.
                with torch.no_grad():
                    getattr(self, f'prior_ema_{lvl}').copy_(self.priors[lvl])

        return logits_flat_list, feats_norm_list, self.priors

    # ------------------------------------------------------------------
    def _batch_knn(self, pos: torch.Tensor, k: int = 40):
        """Return neighbour indices of ``pos`` w.r.t. itself via ``self.knn_layer``.

        ``k`` is accepted for interface compatibility only: the neighbour
        count is fixed by ``k_neighbors`` when ``self.knn_layer`` is built.
        """
        _, idx = self.knn_layer(pos, pos)
        return idx

    def _apply_hefm_batch(self, hefm: "HEFM", z_top: torch.Tensor, z_bottom: torch.Tensor,
                          neighbor_idx: torch.Tensor):
        """Apply HEFM sample by sample over a batch.

        Shapes:
            z_top    : (B,C_top,N)
            z_bottom : (B,C_bottom,N)
            neighbor_idx: (B,N,K)
        Returns:
            (z_top_refined, z_bottom_refined) with same shapes as inputs.
        """
        z_top_out = torch.zeros_like(z_top)
        z_bottom_out = torch.zeros_like(z_bottom)
        for b in range(z_top.size(0)):
            top_ref_b, bot_ref_b = hefm(z_top[b].transpose(0, 1),     # [N,C_top]
                                        z_bottom[b].transpose(0, 1),  # [N,C_bottom]
                                        neighbor_idx[b])              # [N,K]
            z_top_out[b] = top_ref_b.transpose(0, 1)
            z_bottom_out[b] = bot_ref_b.transpose(0, 1)
        return z_top_out, z_bottom_out

    # ------------------------------------------------------------------
    def _apply_top_fusion_batch(self, hefm: "HEFM", z_top: torch.Tensor, z_bottom: torch.Tensor):
        """Batched TopAwareFusion only (no bottom aggregation).

        Returns the refined bottom features with the same shape as ``z_bottom``.
        """
        z_bottom_out = torch.zeros_like(z_bottom)
        for b in range(z_top.size(0)):
            bot_ref = hefm.top_fusion(z_top[b].transpose(0, 1),     # [N,C_top]
                                      z_bottom[b].transpose(0, 1))  # [N,C_bot]
            z_bottom_out[b] = bot_ref.transpose(0, 1)
        return z_bottom_out

    def _apply_bottom_agg_batch(self, hefm: "HEFM", z_top: torch.Tensor, z_bottom: torch.Tensor,
                                neighbor_idx: torch.Tensor):
        """Batched BottomGuidedAggregation only (no top fusion).

        Returns the refined top features with the same shape as ``z_top``.
        """
        z_top_out = torch.zeros_like(z_top)
        for b in range(z_top.size(0)):
            top_ref = hefm.bottom_agg(z_top[b].transpose(0, 1),     # [N,C_top]
                                      z_bottom[b].transpose(0, 1),
                                      neighbor_idx[b])              # [N,K]
            z_top_out[b] = top_ref.transpose(0, 1)
        return z_top_out
