"""
Point Transformer V2 Mode 2 (recommended) + RoPE (4 modes, incl. the "none" baseline) + APE (for ablations)

RoPE modes:
- "none"  : baseline (original PTv2-m2 behavior)
- "axial" : 3D axial RoPE (split pairs across x/y/z)
- "mixed" : 3D mixed RoPE (learnable group-wise freqs)
- "proj"  : 3D projected RoPE (learnable group-wise direction, then 1D RoPE)

APE:
- use_ape=True adds a learned absolute position embedding (an MLP of xyz) to the Q/K inputs only; V is left unchanged (recommended).
"""

from copy import deepcopy
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import einops
from timm.layers import DropPath
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr

import pointops
from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch, batch2offset


# =========================================================
# RoPE Utils (group-as-head)
# =========================================================

def _apply_rotary_grouped(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    x: (N, G, D) real
    freqs_cis: (N, G, D/2) complex
    """
    N, G, D = x.shape
    assert D % 2 == 0
    x_ = torch.view_as_complex(x.float().reshape(N, G, D // 2, 2))
    out = torch.view_as_real(x_ * freqs_cis).reshape(N, G, D)
    return out.type_as(x)


def _axial_cis_3d(coord: torch.Tensor, G: int, head_dim: int, theta: float = 100.0) -> torch.Tensor:
    """
    coord: (N, 3)
    return: (N, G, head_dim/2) complex
    """
    device = coord.device
    N = coord.shape[0]
    assert head_dim % 2 == 0
    P = head_dim // 2

    px = P // 3
    py = P // 3
    pz = P - px - py

    def make(t: torch.Tensor, n_pairs: int):
        if n_pairs == 0:
            return None
        freqs = 1.0 / (theta ** (torch.arange(n_pairs, device=device).float() / n_pairs))
        phase = t.unsqueeze(-1) * freqs  # (N, n_pairs)
        return torch.polar(torch.ones_like(phase), phase)  # complex

    cis_x = make(coord[:, 0], px)
    cis_y = make(coord[:, 1], py)
    cis_z = make(coord[:, 2], pz)

    cis = torch.cat([c for c in (cis_x, cis_y, cis_z) if c is not None], dim=-1)  # (N, P)
    return cis.unsqueeze(1).expand(-1, G, -1)  # (N, G, P)


def _init_mixed_freqs_3d(head_dim: int, G: int, theta: float = 10.0) -> torch.Tensor:
    """
    return freqs: (3, G, head_dim/2) real
    learnable parameter init
    """
    assert head_dim % 2 == 0
    P = head_dim // 2
    mag = 1.0 / (theta ** (torch.arange(P).float() / P))  # (P,)

    fx_list, fy_list, fz_list = [], [], []
    for _ in range(G):
        a = torch.rand(1) * 2 * torch.pi
        fx_list.append(mag * torch.cos(a))                 # (P,)
        fy_list.append(mag * torch.sin(a))                 # (P,)
        fz_list.append(mag * torch.cos(a + torch.pi / 4))  # (P,)

    fx = torch.stack(fx_list, dim=0)  # (G,P)
    fy = torch.stack(fy_list, dim=0)
    fz = torch.stack(fz_list, dim=0)
    return torch.stack([fx, fy, fz], dim=0)  # (3,G,P)


def _mixed_cis_3d(coord: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """
    coord: (N, 3)
    freqs: (3, G, P)
    return: (N, G, P) complex
    """
    # (N,1) * (G,P) -> (N,G,P)
    px = coord[:, 0:1] * freqs[0]
    py = coord[:, 1:2] * freqs[1]
    pz = coord[:, 2:3] * freqs[2]
    phase = px + py + pz
    return torch.polar(torch.ones_like(phase), phase)


def _proj_cis_3d(coord: torch.Tensor, proj: torch.Tensor, head_dim: int, theta: float = 100.0) -> torch.Tensor:
    """
    coord: (N,3)
    proj: (G,3) learnable direction per group
    return: (N,G,head_dim/2) complex
    """
    device = coord.device
    assert head_dim % 2 == 0
    P = head_dim // 2
    freqs = 1.0 / (theta ** (torch.arange(P, device=device).float() / P))  # (P,)
    t = coord @ proj.t()  # (N,G)
    phase = t.unsqueeze(-1) * freqs.view(1, 1, P)  # (N,G,P)
    return torch.polar(torch.ones_like(phase), phase)


# =========================================================
# Norm
# =========================================================

class PointBatchNorm(nn.Module):
    """
    BatchNorm1d over the channel axis for point features.

    Accepts features shaped [B*N, C] or [B*N, L, C] (e.g. per-neighbour
    features); any other rank raises NotImplementedError.
    """
    def __init__(self, embed_channels: int):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ndim = x.dim()
        if ndim == 2:
            return self.norm(x)
        if ndim != 3:
            raise NotImplementedError
        # BatchNorm1d normalizes dim 1, so move channels there and back again.
        channels_first = x.transpose(1, 2).contiguous()
        return self.norm(channels_first).transpose(1, 2).contiguous()


# =========================================================
# GVA + (RoPE + APE)
# =========================================================

class GroupedVectorAttention(nn.Module):
    """
    PTv2 grouped vector attention with optional RoPE on Q/K and optional APE.

    Channels are split into ``groups`` groups. For RoPE, each group acts as a
    "head" of width ``embed_channels // groups``. Per-group attention weights
    come from an MLP over the relation ``k_j - q_i``, which may be modulated by
    relative-position encodings. They are softmax-normalized over the
    neighbour dimension and used to aggregate the grouped values.

    Args:
        embed_channels: feature width C; must be divisible by ``groups``.
        groups: number of groups G.
        attn_drop_rate: dropout on the attention weights.
        qkv_bias: bias for the Q/K/V linear layers.
        pe_multiplier: multiply the relation by an MLP of relative xyz.
        pe_bias: add an MLP of relative xyz to the relation and to V.
        rope_mode: "none" | "axial" | "mixed" | "proj".
        rope_theta: frequency base for "axial" / "proj".
        rope_mixed_theta: frequency base used to initialize "mixed" freqs.
        use_ape: add an MLP of absolute xyz to the Q/K input (V unchanged).

    Raises:
        ValueError: ``rope_mode`` is not one of the supported modes.
        AssertionError: ``embed_channels`` is not divisible by ``groups``, or
            a RoPE mode is active and the per-group width is odd.
    """

    _ROPE_MODES = ("none", "axial", "mixed", "proj")

    def __init__(
        self,
        embed_channels: int,
        groups: int,
        attn_drop_rate: float = 0.0,
        qkv_bias: bool = True,
        pe_multiplier: bool = False,
        pe_bias: bool = True,
        # --- New ---
        rope_mode: str = "none",         # none|axial|mixed|proj
        rope_theta: float = 100.0,       # axial/proj theta
        rope_mixed_theta: float = 10.0,  # mixed init theta
        use_ape: bool = False,           # APE for Q/K input only
    ):
        super().__init__()
        # Reject typos at construction time instead of on the first forward.
        if rope_mode not in self._ROPE_MODES:
            raise ValueError(f"Unknown rope_mode: {rope_mode}")

        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.head_dim = embed_channels // groups
        # Only RoPE rotates channel pairs; the "none" baseline keeps the
        # original PTv2 constraints (any per-group width).
        if rope_mode != "none":
            assert self.head_dim % 2 == 0, "RoPE needs even head_dim per group."

        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.rope_mode = rope_mode
        self.rope_theta = rope_theta
        self.rope_mixed_theta = rope_mixed_theta
        self.use_ape = use_ape

        # Q K V
        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        # APE for Q/K only (recommended)
        if self.use_ape:
            self.ape = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )

        # RoPE params ("axial" has none; its frequencies are fixed)
        if self.rope_mode == "mixed":
            freqs = _init_mixed_freqs_3d(self.head_dim, groups, theta=rope_mixed_theta)
            self.rope_freqs = nn.Parameter(freqs, requires_grad=True)  # (3,G,P)
        elif self.rope_mode == "proj":
            self.rope_proj = nn.Parameter(torch.randn(groups, 3) * 0.02, requires_grad=True)

        # Relative position enc (original PTv2)
        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )

        self.weight_encoding = nn.Sequential(
            nn.Linear(embed_channels, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        self.softmax = nn.Softmax(dim=1)  # neighbor dim
        self.attn_drop = nn.Dropout(attn_drop_rate)

    # String annotation: a bare `tuple[...]` would be evaluated at definition
    # time and fail on Python < 3.9.
    def _apply_rope_qk(self, q: torch.Tensor, k: torch.Tensor, coord: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Rotate per-point Q and K with the rotary factors of the active mode.

        q,k: (N,C), coord: (N,3)
        return: rotated (q, k), each (N,C)
        """
        N, C = q.shape
        G = self.groups
        D = C // G

        qg = q.view(N, G, D)
        kg = k.view(N, G, D)

        if self.rope_mode == "axial":
            cis = _axial_cis_3d(coord, G, D, theta=self.rope_theta)
        elif self.rope_mode == "mixed":
            cis = _mixed_cis_3d(coord, self.rope_freqs)
        elif self.rope_mode == "proj":
            cis = _proj_cis_3d(coord, self.rope_proj, D, theta=self.rope_theta)
        else:
            raise ValueError(f"Unknown rope_mode: {self.rope_mode}")

        q = _apply_rotary_grouped(qg, cis).reshape(N, C)
        k = _apply_rotary_grouped(kg, cis).reshape(N, C)
        return q, k

    def forward(self, feat: torch.Tensor, coord: torch.Tensor, reference_index: torch.Tensor) -> torch.Tensor:
        """
        feat: (N,C)
        coord: (N,3)
        reference_index: (N, nsamples) with -1 for invalid (pointops convention)
        return: (N,C) aggregated features
        """
        # ---- APE: only for Q/K input ----
        if self.use_ape:
            feat_qk = feat + self.ape(coord)
        else:
            feat_qk = feat

        # ---- Q K V ----
        query = self.linear_q(feat_qk)  # (N,C)
        key   = self.linear_k(feat_qk)  # (N,C)
        value = self.linear_v(feat)     # (N,C)  (V does NOT use APE)

        # ---- RoPE on Q/K (before grouping) ----
        if self.rope_mode != "none":
            query, key = self._apply_rope_qk(query, key, coord)

        # ---- Group neighbors ----
        # key with xyz: (N, ns, 3+C), value: (N, ns, C)
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)

        pos, key = key[:, :, 0:3], key[:, :, 3:]
        # NOTE(review): RoPE's relative-position guarantee comes from dot
        # products, <R(p_i) q_i, R(p_j) k_j> = <q_i, R(p_j - p_i) k_j>. Here the
        # rotated vectors are *subtracted* and fed to an MLP, so the relation
        # also depends on the absolute rotations R(p_i), R(p_j). Keep this in
        # mind when interpreting the ablation results.
        relation_qk = key - query.unsqueeze(1)

        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem

        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = value + peb

        # ---- Weights ----
        weight = self.weight_encoding(relation_qk)      # (N, ns, G)
        weight = self.attn_drop(self.softmax(weight))   # softmax over ns dim=1

        # keep original invalid-neighbor mask behavior (applied after softmax)
        mask = torch.sign(reference_index + 1)          # valid->1, invalid(-1)->0
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)

        # ---- Aggregate ----
        value = einops.rearrange(value, "n ns (g d) -> n ns g d", g=self.groups)
        out = torch.einsum("n s g d, n s g -> n g d", value, weight)
        out = einops.rearrange(out, "n g d -> n (g d)")
        return out


# =========================================================
# Block / Sequences / Pooling / Unpooling
# =========================================================

class Block(nn.Module):
    """
    PTv2 residual block.

    Computes ReLU(x + DropPath(BN3(fc3(ReLU(BN2(GVA(ReLU(BN1(fc1(x))))))))).
    The grouped vector attention can optionally run under activation
    checkpointing (``enable_checkpoint``). The RoPE/APE options are forwarded
    to GroupedVectorAttention unchanged.
    """
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        # --- new ---
        rope_mode="none",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            attn_drop_rate=attn_drop_rate,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            rope_mode=rope_mode,
            rope_theta=rope_theta,
            rope_mixed_theta=rope_mixed_theta,
            use_ape=use_ape,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, points, reference_index):
        coord, feat, offset = points
        shortcut = feat

        hidden = self.act(self.norm1(self.fc1(feat)))
        if self.enable_checkpoint:
            hidden = checkpoint(self.attn, hidden, coord, reference_index)
        else:
            hidden = self.attn(hidden, coord, reference_index)
        hidden = self.norm3(self.fc3(self.act(self.norm2(hidden))))

        # residual add builds a new tensor, so the in-place ReLU leaves shortcut intact
        out = self.act(shortcut + self.drop_path(hidden))
        return [coord, out, offset]


class BlockSequence(nn.Module):
    """
    A stack of ``depth`` Blocks that share one kNN neighbourhood.

    The neighbour indices are queried once per forward pass (``neighbours``
    nearest points per point, via pointops.knn_query) and reused by every block.

    drop_path_rate: either a per-block sequence (list or tuple of length
        ``depth``) or a single number applied to every block. Any other value
        (e.g. None) disables drop path. Tuples and ints are accepted here;
        previously they were silently replaced by 0.0.
    """
    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        # --- new ---
        rope_mode="none",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()

        if isinstance(drop_path_rate, (list, tuple)):
            drop_path_rates = list(drop_path_rate)
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, (int, float)):
            drop_path_rates = [float(drop_path_rate)] * depth
        else:
            drop_path_rates = [0.0] * depth

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for rate in drop_path_rates:
            self.blocks.append(
                Block(
                    embed_channels=embed_channels,
                    groups=groups,
                    qkv_bias=qkv_bias,
                    pe_multiplier=pe_multiplier,
                    pe_bias=pe_bias,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=rate,
                    enable_checkpoint=enable_checkpoint,
                    rope_mode=rope_mode,
                    rope_theta=rope_theta,
                    rope_mixed_theta=rope_mixed_theta,
                    use_ape=use_ape,
                )
            )

    def forward(self, points):
        coord, _, offset = points
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for blk in self.blocks:
            points = blk(points, reference_index)
        return points


class GridPool(nn.Module):
    """
    Voxel-grid downsampling.

    Features are projected (Linear -> BN -> ReLU), points are bucketed into
    voxels of edge ``grid_size`` per sample, and each voxel is reduced to the
    mean coordinate and the channel-wise max feature.

    forward returns ``([coord, feat, offset], cluster)``, where ``cluster`` maps
    every input point to its voxel index (used later for unpooling).
    """
    def __init__(self, in_channels, out_channels, grid_size, bias=False):
        super().__init__()
        self.grid_size = grid_size
        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        coord, feat, offset = points
        batch = offset2batch(offset)
        feat = self.act(self.norm(self.fc(feat)))

        if start is None:
            # per-sample minimum xyz serves as the voxel-grid origin
            sample_ptr = torch.cat([batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)])
            start = segment_csr(coord, sample_ptr, reduce="min")

        cluster = voxel_grid(pos=coord - start[batch], size=self.grid_size, batch=batch, start=0)
        _, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)

        order = torch.sort(cluster)[1]
        voxel_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])

        pooled_coord = segment_csr(coord[order], voxel_ptr, reduce="mean")
        pooled_feat = segment_csr(feat[order], voxel_ptr, reduce="max")

        # NOTE(review): this indexes the *unsorted* `batch` with voxel start
        # positions. It is correct only if the sorted cluster ids are
        # sample-major and the input points are grouped by sample, which is
        # presumably guaranteed by voxel_grid's batch handling and the offset
        # layout. Confirm before changing either.
        pooled_batch = batch[voxel_ptr[:-1]]
        return [pooled_coord, pooled_feat, batch2offset(pooled_batch)], cluster


class UnpoolWithSkip(nn.Module):
    """
    Upsample coarse features onto the skip (finer) point set.

    backend="map" scatters the projected coarse features back through the
    pooling ``cluster`` map when one is given. Otherwise (backend="interp", or
    no cluster) it falls back to pointops.interpolation. When ``skip`` is set,
    the projected skip features are added.
    """
    def __init__(self, in_channels, skip_channels, out_channels, bias=True, skip=True, backend="map"):
        super().__init__()
        assert backend in ["map", "interp"]
        self.skip = skip
        self.backend = backend

        self.proj = nn.Sequential(
            nn.Linear(in_channels, out_channels, bias=bias),
            PointBatchNorm(out_channels),
            nn.ReLU(inplace=True),
        )
        self.proj_skip = nn.Sequential(
            nn.Linear(skip_channels, out_channels, bias=bias),
            PointBatchNorm(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, points, skip_points, cluster=None):
        coord, feat, offset = points
        skip_coord, skip_feat, skip_offset = skip_points

        projected = self.proj(feat)
        use_map = self.backend == "map" and cluster is not None
        if use_map:
            upsampled = projected[cluster]
        else:
            upsampled = pointops.interpolation(coord, skip_coord, projected, offset, skip_offset)

        if self.skip:
            upsampled = upsampled + self.proj_skip(skip_feat)

        return [skip_coord, upsampled, skip_offset]


class Encoder(nn.Module):
    """
    Encoder stage: GridPool downsampling followed by a BlockSequence.

    forward returns ``(points, cluster)``. ``cluster`` is the pooling map that
    the matching decoder stage uses for unpooling.
    """
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        grid_size=None,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        # --- new ---
        rope_mode="none",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()
        self.down = GridPool(in_channels=in_channels, out_channels=embed_channels, grid_size=grid_size)
        block_kwargs = dict(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            rope_mode=rope_mode,
            rope_theta=rope_theta,
            rope_mixed_theta=rope_mixed_theta,
            use_ape=use_ape,
        )
        self.blocks = BlockSequence(**block_kwargs)

    def forward(self, points):
        pooled, cluster = self.down(points)
        encoded = self.blocks(pooled)
        return encoded, cluster


class Decoder(nn.Module):
    """
    Decoder stage: UnpoolWithSkip onto the skip points, then a BlockSequence.

    ``unpool_backend`` selects the UnpoolWithSkip backend ("map" or "interp").
    """
    def __init__(
        self,
        in_channels,
        skip_channels,
        embed_channels,
        groups,
        depth,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        unpool_backend="map",
        # --- new ---
        rope_mode="none",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()

        self.up = UnpoolWithSkip(
            in_channels=in_channels,
            skip_channels=skip_channels,
            out_channels=embed_channels,
            backend=unpool_backend,
        )

        block_kwargs = dict(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            rope_mode=rope_mode,
            rope_theta=rope_theta,
            rope_mixed_theta=rope_mixed_theta,
            use_ape=use_ape,
        )
        self.blocks = BlockSequence(**block_kwargs)

    def forward(self, points, skip_points, cluster):
        upsampled = self.up(points, skip_points, cluster)
        return self.blocks(upsampled)


class GVAPatchEmbed(nn.Module):
    """
    Stem: per-point Linear -> BN -> ReLU to ``embed_channels``, then a
    BlockSequence at full resolution (no pooling).
    """
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        # --- new ---
        rope_mode="none",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        block_kwargs = dict(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            rope_mode=rope_mode,
            rope_theta=rope_theta,
            rope_mixed_theta=rope_mixed_theta,
            use_ape=use_ape,
        )
        self.blocks = BlockSequence(**block_kwargs)

    def forward(self, points):
        coord, feat, offset = points
        return self.blocks([coord, self.proj(feat), offset])


# =========================================================
# Model
# =========================================================

@MODELS.register_module("PT-v2m2-rope-axial")
class PointTransformerV2(nn.Module):
    """
    PTv2-m2 segmentation network with optional RoPE / APE in every attention.

    The network is U-Net shaped:
    - a GVA patch embed;
    - ``len(enc_depths)`` encoder stages (grid pooling + GVA blocks);
    - the matching decoder stages (unpool with skip + GVA blocks);
    - a per-point segmentation head, or Identity when ``num_classes <= 0``.

    Although the registry name says "axial", ``rope_mode`` selects any
    supported mode and is shared by all stages. Stochastic-depth rates ramp
    linearly from 0 to ``drop_path_rate``, separately over all encoder blocks
    and over all decoder blocks.
    """
    def __init__(
        self,
        in_channels,
        num_classes,
        patch_embed_depth=1,
        patch_embed_channels=48,
        patch_embed_groups=6,
        patch_embed_neighbours=8,
        enc_depths=(2, 2, 6, 2),
        enc_channels=(96, 192, 384, 512),
        enc_groups=(12, 24, 48, 64),
        enc_neighbours=(16, 16, 16, 16),
        dec_depths=(1, 1, 1, 1),
        dec_channels=(48, 96, 192, 384),
        dec_groups=(6, 12, 24, 48),
        dec_neighbours=(16, 16, 16, 16),
        grid_sizes=(0.06, 0.12, 0.24, 0.48),
        attn_qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0,
        enable_checkpoint=False,
        unpool_backend="map",
        # --- new ---
        rope_mode="axial",
        rope_theta=100.0,
        rope_mixed_theta=10.0,
        use_ape=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_stages = len(enc_depths)

        # every per-stage setting must supply exactly one entry per stage
        for per_stage in (
            dec_depths,
            enc_channels,
            dec_channels,
            enc_groups,
            dec_groups,
            enc_neighbours,
            dec_neighbours,
            grid_sizes,
        ):
            assert self.num_stages == len(per_stage)

        # attention / block options shared by the patch embed and all stages
        common = dict(
            qkv_bias=attn_qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            enable_checkpoint=enable_checkpoint,
            rope_mode=rope_mode,
            rope_theta=rope_theta,
            rope_mixed_theta=rope_mixed_theta,
            use_ape=use_ape,
        )

        # patch embed
        self.patch_embed = GVAPatchEmbed(
            in_channels=in_channels,
            embed_channels=patch_embed_channels,
            groups=patch_embed_groups,
            depth=patch_embed_depth,
            neighbours=patch_embed_neighbours,
            **common,
        )

        enc_dp_rates = torch.linspace(0, drop_path_rate, sum(enc_depths)).tolist()
        dec_dp_rates = torch.linspace(0, drop_path_rate, sum(dec_depths)).tolist()

        # enc_dims[i] -> enc_dims[i + 1] is stage i's width change;
        # dec_dims[i + 1] is what decoder i receives from the stage below it
        enc_dims = [patch_embed_channels, *enc_channels]
        dec_dims = [*dec_channels, enc_dims[-1]]

        self.enc_stages = nn.ModuleList()
        self.dec_stages = nn.ModuleList()

        enc_start = dec_start = 0
        for i in range(self.num_stages):
            enc_stop = enc_start + enc_depths[i]
            dec_stop = dec_start + dec_depths[i]
            enc = Encoder(
                depth=enc_depths[i],
                in_channels=enc_dims[i],
                embed_channels=enc_dims[i + 1],
                groups=enc_groups[i],
                grid_size=grid_sizes[i],
                neighbours=enc_neighbours[i],
                drop_path_rate=enc_dp_rates[enc_start:enc_stop],
                **common,
            )
            dec = Decoder(
                depth=dec_depths[i],
                in_channels=dec_dims[i + 1],
                skip_channels=enc_dims[i],
                embed_channels=dec_dims[i],
                groups=dec_groups[i],
                neighbours=dec_neighbours[i],
                drop_path_rate=dec_dp_rates[dec_start:dec_stop],
                unpool_backend=unpool_backend,
                **common,
            )
            self.enc_stages.append(enc)
            self.dec_stages.append(dec)
            enc_start, dec_start = enc_stop, dec_stop

        if num_classes > 0:
            self.seg_head = nn.Sequential(
                nn.Linear(dec_dims[0], dec_dims[0]),
                PointBatchNorm(dec_dims[0]),
                nn.ReLU(inplace=True),
                nn.Linear(dec_dims[0], num_classes),
            )
        else:
            self.seg_head = nn.Identity()

    def forward(self, data_dict):
        offset = data_dict["offset"].int()
        points = self.patch_embed([data_dict["coord"], data_dict["feat"], offset])

        # Encoder: remember each stage's input together with the cluster map it produced.
        encoder_trace = []
        for stage in self.enc_stages:
            stage_input = points
            points, cluster = stage(points)
            encoder_trace.append((stage_input, cluster))

        # Decoder: deepest stage first, each paired with its encoder's input/cluster.
        for i in reversed(range(self.num_stages)):
            skip_points, cluster = encoder_trace[i]
            points = self.dec_stages[i](points, skip_points, cluster)

        return self.seg_head(points[1])