from copy import deepcopy
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr

import einops
from timm.layers import DropPath
import pointops

from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch, batch2offset

def generate_simplex_vectors_with_projection(dimension: int, normalize_rows: bool = True):
    """Build simplex vertex directions expressed in ``dimension`` coordinates.

    The ``dimension + 1`` standard basis vectors are centred at their mean,
    then projected onto all but the last left-singular vector of the
    transposed centred matrix.

    Args:
        dimension: Number of coordinates of each returned vector.
        normalize_rows: When True, each row is divided by its L2 norm
            (plus 1e-12 to guard against division by zero).

    Returns:
        A float32 tensor of shape ``(dimension + 1, dimension)``. The special
        case ``dimension == 1`` returns the single row ``[[1.0]]``.
    """
    if dimension == 1:
        return torch.tensor([[1.0]], dtype=torch.float32)

    vertices = torch.eye(dimension + 1, dtype=torch.float32)
    centred = vertices - vertices.mean(dim=0, keepdim=True)

    # Only the left singular vectors are needed; the last column is dropped.
    left_basis = torch.linalg.svd(centred.T, full_matrices=False)[0]
    projected = centred @ left_basis[:, :-1]

    if not normalize_rows:
        return projected
    row_norms = projected.norm(dim=1, keepdim=True)
    return projected / (row_norms + 1e-12)

def init_nd_freqs(reduced_vectors: torch.Tensor, num_heads: int, rotate: bool = True,
                  device=None, dtype=torch.float32):
    """Create one set of direction vectors per attention head.

    Args:
        reduced_vectors: (M, d) base direction vectors.
        num_heads: Number of heads to generate vectors for.
        rotate: When True, every head gets the base vectors multiplied by its
            own random orthogonal matrix (sign-flipped to have a non-negative
            determinant); when False all heads share the base vectors.
        device: Device for the returned tensor and the random matrices.
        dtype: Dtype for the returned tensor and the random matrices.

    Returns:
        Tensor of shape (H, M, d) with ``H == num_heads``.
    """
    base = reduced_vectors.to(device=device, dtype=dtype)
    _, dim = base.shape

    if not rotate:
        return torch.stack([base for _ in range(num_heads)], dim=0)

    per_head = []
    for _ in range(num_heads):
        # QR of a Gaussian matrix yields a random orthogonal factor.
        gaussian = torch.randn(dim, dim, device=device, dtype=dtype)
        rotation, _ = torch.linalg.qr(gaussian)
        if torch.linalg.det(rotation) < 0:
            # Flip one column so the matrix is a proper rotation.
            rotation[:, 0] = -rotation[:, 0]
        per_head.append(base @ rotation.T)
    return torch.stack(per_head, dim=0)   # (H, M, d)

def compute_ndrope_cis_3d(freqs, position, head_dim, theta=100.0, phase_offset=None):
    """Convert 3D positions into unit-modulus complex rotation factors.

    Each position is projected onto every head's M direction vectors; each
    projection is then multiplied by ``head_dim // (2 * M)`` magnitudes
    ``theta ** (-i / Scales)`` to obtain the rotation angles.

    Args:
        freqs: (H, M, 3) per-head direction vectors.
        position: (N, S, 3) positions (relative positions in this file).
        head_dim: Channels per head; must be divisible by ``2 * M``.
        theta: Base of the magnitude schedule.
        phase_offset: Optional tensor added to the angles; it must broadcast
            against (N, H, S, head_dim / 2).

    Returns:
        Complex tensor of shape (N, H, S, head_dim / 2) with modulus 1.
    """
    device, dtype = position.device, position.dtype
    num_points, num_neighbours, pos_dim = position.shape
    num_heads, M, freq_dim = freqs.shape
    assert pos_dim == freq_dim == 3

    dim_per_scale = 2 * M
    assert head_dim % dim_per_scale == 0, f"head_dim must be divisible by 2*M. got {head_dim} vs {2*M}"
    num_scales = head_dim // dim_per_scale

    exponents = torch.arange(num_scales, device=device, dtype=dtype) / max(num_scales, 1)
    magnitudes = 1.0 / (theta ** exponents)

    directions = freqs.to(device=device, dtype=dtype)
    projection = torch.einsum("nsd,hmd->nhsm", position, directions)  # (N,H,S,M)

    # Scale index varies fastest: channel = m * num_scales + scale.
    angles = projection.unsqueeze(-1) * magnitudes.view(1, 1, 1, 1, num_scales)
    angles = angles.reshape(num_points, num_heads, num_neighbours, M * num_scales)  # (N,H,S,D/2)

    if phase_offset is not None:
        angles = angles + phase_offset.to(device=device, dtype=dtype)

    unit_modulus = torch.ones_like(angles)
    return torch.polar(unit_modulus, angles)

# def apply_rotary_emb_relation(x: torch.Tensor, freqs_cis: torch.Tensor):
#     """
#     x: (N, H, S, D) real
#     freqs_cis: (N, H, S, D/2) complex
#     """
#     N, H, S, D = x.shape
#     assert D % 2 == 0
#     D2 = D // 2
#     x_c = torch.view_as_complex(x.float().reshape(N, H, S, D2, 2))
#     x_rot = x_c * freqs_cis
#     return torch.view_as_real(x_rot).reshape(N, H, S, D).type_as(x)
def apply_rotary_emb_relation(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate consecutive channel pairs of ``x`` by complex factors.

    Args:
        x: Real tensor of shape (N, H, S, D) with even D.
        freqs_cis: Complex tensor of shape (N, H, S, D/2).

    Returns:
        Real tensor of shape (N, H, S, D) with the dtype of ``x``.
    """
    num_points, num_heads, num_neighbours, channels = x.shape
    assert channels % 2 == 0
    half_channels = channels // 2

    # Core fix: force a contiguous layout before reinterpreting channel pairs
    # as complex numbers.
    x = x.contiguous()

    pairs = x.float().reshape(num_points, num_heads, num_neighbours, half_channels, 2)
    rotated = torch.view_as_complex(pairs) * freqs_cis
    real_pairs = torch.view_as_real(rotated)
    return real_pairs.reshape(num_points, num_heads, num_neighbours, channels).type_as(x)

class PointBatchNorm(nn.Module):
    """
    BatchNorm1d wrapper for point features shaped [B*N, C] or [B*N, L, C].

    3D inputs carry channels last, so they are transposed to channels-first
    for ``nn.BatchNorm1d`` and transposed back afterwards.
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` over its channel (last) dimension.

        Raises:
            NotImplementedError: if ``input`` is neither 2D nor 3D.
        """
        rank = input.dim()
        if rank == 2:
            return self.norm(input)
        if rank == 3:
            channels_first = input.transpose(1, 2).contiguous()
            normalized = self.norm(channels_first)
            return normalized.transpose(1, 2).contiguous()
        raise NotImplementedError


class GroupedVectorAttention(nn.Module):
    """Neighbourhood dot-product attention with optional nD-RoPE on q and k.

    Every point attends over the S neighbours listed in ``reference_index``.
    Channels are split into ``groups`` heads of ``head_dim`` channels each.

    Args:
        embed_channels: Feature width C; must be divisible by ``groups``.
        groups: Number of attention heads H.
        attn_drop_rate: Dropout probability applied to the attention weights.
        qkv_bias: Whether the q/k/v linear layers use a bias.
        pe_multiplier: When True, ``linear_p_multiplier`` is created.
            NOTE(review): ``forward`` never uses that module, so its
            parameters receive no gradient — confirm whether this is intended.
        pe_bias: When True, ``linear_p_bias(coord)`` is added to the output.
        use_ndrope: Enable the rotary position encoding of q and k.
        rope_theta: Base of the RoPE magnitude schedule.
        rope_rotate_freqs: Give each head a randomly rotated direction set.
        rope_phase_offset: Create a learnable phase offset of shape
            (1, H, 1, head_dim // 2).

    Raises:
        ValueError: if ``use_ndrope`` is set and ``head_dim`` is not divisible
            by twice the number of RoPE direction vectors.
    """

    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        use_ndrope=True,
        rope_theta=100.0,
        rope_rotate_freqs=True,
        rope_phase_offset=True,
    ):
        super().__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.head_dim = embed_channels // groups

        self.use_ndrope = use_ndrope
        self.rope_theta = rope_theta
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        # ---- qkv projection ----
        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        # ---- optional PE terms ----
        if self.pe_multiplier:
            # Kept so existing checkpoints still load; unused in forward().
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )

        self.attn_drop = nn.Dropout(attn_drop_rate)

        # ---- nD-RoPE setup ----
        if self.use_ndrope:
            reduced = generate_simplex_vectors_with_projection(3, normalize_rows=True)  # (4,3)
            freqs = init_nd_freqs(
                reduced,
                num_heads=groups,
                rotate=rope_rotate_freqs,
            )  # (H, M, 3)

            # Fail at construction time rather than on the first forward pass:
            # compute_ndrope_cis_3d requires head_dim to be a multiple of 2*M.
            num_directions = freqs.shape[1]
            if self.head_dim % (2 * num_directions) != 0:
                raise ValueError(
                    "nD-RoPE requires head_dim (embed_channels // groups) to be "
                    f"divisible by {2 * num_directions}; got head_dim={self.head_dim} "
                    f"(embed_channels={embed_channels}, groups={groups})"
                )

            # Persistent buffer: the random per-head rotations are saved with
            # the state dict.
            self.register_buffer("ndrope_freqs", freqs, persistent=True)

            if rope_phase_offset:
                # NOTE(review): the same offset is added to the angles of both
                # q and k, so it may cancel in the q.k product — confirm this
                # parameter receives a useful gradient.
                self.phase_offset = nn.Parameter(
                    torch.zeros(1, groups, 1, self.head_dim // 2)
                )
            else:
                self.phase_offset = None

    def forward(self, feat, coord, reference_index):
        """Aggregate neighbour values with softmax attention.

        Args:
            feat: (N, C) point features.
            coord: (N, 3) point coordinates.
            reference_index: (N, S) neighbour indices; entries equal to -1 are
                masked out after the softmax.

        Returns:
            (N, C) attended features, plus ``linear_p_bias(coord)`` when
            ``pe_bias`` is enabled.
        """
        # ---- project qkv ----
        q = self.linear_q(feat)    # (N, C)
        k = self.linear_k(feat)
        v = self.linear_v(feat)

        # ---- group neighbors ----
        # With with_xyz=True the first 3 channels of the grouped tensor are
        # treated as positions — presumably relative to the query point;
        # verify against pointops.grouping.
        k = pointops.grouping(reference_index, k, coord, with_xyz=True)
        v = pointops.grouping(reference_index, v, coord, with_xyz=False)

        pos, k = k[:, :, :3], k[:, :, 3:]  # pos: (N,S,3), k: (N,S,C)

        # ---- reshape to multi-head ----
        N, S, C = k.shape
        H, D = self.groups, self.head_dim

        q = q.view(N, H, D)                          # (N,H,D)
        k = k.view(N, S, H, D).permute(0, 2, 1, 3)   # (N,H,S,D)
        v = v.view(N, S, H, D).permute(0, 2, 1, 3)

        # ---- nD-RoPE ----
        if self.use_ndrope:
            freqs_cis = compute_ndrope_cis_3d(
                self.ndrope_freqs,   # (H,M,3)
                pos,                 # (N,S,3)
                head_dim=D,
                theta=self.rope_theta,
                phase_offset=self.phase_offset,
            )  # (N,H,S,D/2)

            k = apply_rotary_emb_relation(k, freqs_cis)
            # The query is rotated with the factors of the first neighbour slot.
            q = apply_rotary_emb_relation(
                q.unsqueeze(2), freqs_cis[:, :, :1]
            ).squeeze(2)

        # ---- standard dot-product attention ----
        attn = torch.einsum("nhd,nhsd->nhs", q, k) / math.sqrt(D)
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # ---- mask invalid neighbors ----
        # sign(index + 1) is 0 for index == -1 and 1 for any index >= 0.
        mask = torch.sign(reference_index + 1)       # (N,S)
        attn = attn * mask.unsqueeze(1)

        # ---- aggregate ----
        out = torch.einsum("nhs,nhsd->nhd", attn, v)
        out = out.reshape(N, C)

        # ---- optional PE bias ----
        # NOTE(review): the bias is computed from the absolute ``coord``, not
        # from the relative neighbour positions — confirm this is intended.
        if self.pe_bias:
            out = out + self.linear_p_bias(coord)

        return out


class Block(nn.Module):
    """Residual block: fc1 -> grouped attention -> fc3, each with BatchNorm.

    The attention call can optionally be wrapped in activation checkpointing,
    and the residual branch passes through drop-path before being added back.
    """

    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        use_ndrope=True,
        rope_theta=100.0,
    ):
        super().__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            use_ndrope=use_ndrope,
            rope_theta=rope_theta,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        # Drop-path is only instantiated for a strictly positive rate.
        if drop_path_rate > 0.0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, points, reference_index):
        """Apply the block to ``points = [coord, feat, offset]``.

        Only ``feat`` changes; ``coord`` and ``offset`` are passed through.
        """
        coord, feat, offset = points
        shortcut = feat

        hidden = self.act(self.norm1(self.fc1(feat)))
        if self.enable_checkpoint:
            hidden = checkpoint(self.attn, hidden, coord, reference_index)
        else:
            hidden = self.attn(hidden, coord, reference_index)
        hidden = self.act(self.norm2(hidden))
        hidden = self.norm3(self.fc3(hidden))

        merged = shortcut + self.drop_path(hidden)
        merged = self.act(merged)
        return [coord, merged, offset]


class BlockSequence(nn.Module):
    """A stack of ``depth`` attention blocks sharing one kNN neighbourhood.

    Args:
        depth: Number of ``Block`` modules.
        embed_channels: Feature width of every block.
        groups: Number of attention heads per block.
        neighbours: Number of nearest neighbours queried per point.
        qkv_bias, pe_multiplier, pe_bias, attn_drop_rate, enable_checkpoint,
            use_ndrope, rope_theta: Forwarded unchanged to each ``Block``.
        drop_path_rate: Either one rate per block (list or tuple of length
            ``depth``) or a single number shared by all blocks. Any other
            value (e.g. ``None``) disables drop-path.
    """

    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        use_ndrope=True, rope_theta=100.0
    ):
        super(BlockSequence, self).__init__()

        if isinstance(drop_path_rate, (list, tuple)):
            # Tuples used to be ignored silently; treat them like lists.
            drop_path_rates = list(drop_path_rate)
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, (int, float)) and not isinstance(
            drop_path_rate, bool
        ):
            # Integer rates used to fall through to 0.0; honour them instead.
            drop_path_rates = [float(drop_path_rate)] * depth
        else:
            drop_path_rates = [0.0] * depth

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for rate in drop_path_rates:
            self.blocks.append(
                Block(
                    embed_channels=embed_channels,
                    groups=groups,
                    qkv_bias=qkv_bias,
                    pe_multiplier=pe_multiplier,
                    pe_bias=pe_bias,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=rate,
                    enable_checkpoint=enable_checkpoint,
                    use_ndrope=use_ndrope,
                    rope_theta=rope_theta,
                )
            )

    def forward(self, points):
        """Run all blocks on ``points = [coord, feat, offset]``."""
        coord, feat, offset = points
        # reference index query of neighbourhood attention
        # for windows attention, modify reference index query method
        # The neighbourhood is queried once and reused by every block.
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for block in self.blocks:
            points = block(points, reference_index)
        return points


class GridPool(nn.Module):
    """
    Partition-based Pooling (Grid Pooling)

    Features are projected, points are bucketed into voxels of edge
    ``grid_size``, then each voxel is reduced to one point: mean of the
    coordinates and channel-wise max of the features.
    """

    def __init__(self, in_channels, out_channels, grid_size, bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size

        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        """Pool ``points = [coord, feat, offset]``.

        Args:
            points: List ``[coord, feat, offset]``.
            start: Optional per-batch grid origin; when None the per-batch
                coordinate minimum is used.

        Returns:
            ``([coord, feat, offset], cluster)`` where ``cluster`` holds, for
            every input point, the index of the pooled point it was merged
            into.
        """
        coord, feat, offset = points
        batch = offset2batch(offset)
        feat = self.act(self.norm(self.fc(feat)))

        if start is None:
            # CSR pointer over the batch segments: [0, n_0, n_0 + n_1, ...].
            batch_ptr = torch.cat(
                [batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)]
            )
            start = segment_csr(coord, batch_ptr, reduce="min")

        voxel_id = voxel_grid(
            pos=coord - start[batch], size=self.grid_size, batch=batch, start=0
        )
        # Relabel the voxel ids as consecutive indices 0..K-1.
        _, cluster, counts = torch.unique(
            voxel_id, sorted=True, return_inverse=True, return_counts=True
        )
        order = torch.sort(cluster)[1]
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])

        pooled_coord = segment_csr(coord[order], idx_ptr, reduce="mean")
        pooled_feat = segment_csr(feat[order], idx_ptr, reduce="max")
        pooled_batch = batch[idx_ptr[:-1]]
        pooled_offset = batch2offset(pooled_batch)
        return [pooled_coord, pooled_feat, pooled_offset], cluster


class UnpoolWithSkip(nn.Module):
    """
    Map Unpooling with skip connection

    Coarse features are projected and brought back to the skip resolution,
    either by indexing with the pooling ``cluster`` map or by
    ``pointops.interpolation``, and then optionally summed with the projected
    skip features.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        bias=True,
        skip=True,
        backend="map",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.skip_channels = skip_channels
        self.out_channels = out_channels
        self.skip = skip
        self.backend = backend
        assert self.backend in ["map", "interp"]

        self.proj = nn.Sequential(
            nn.Linear(in_channels, out_channels, bias=bias),
            PointBatchNorm(out_channels),
            nn.ReLU(inplace=True),
        )
        self.proj_skip = nn.Sequential(
            nn.Linear(skip_channels, out_channels, bias=bias),
            PointBatchNorm(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, points, skip_points, cluster=None):
        """Upsample ``points`` onto the layout of ``skip_points``.

        The "map" backend is only used when a ``cluster`` map is supplied;
        otherwise the features are interpolated.
        """
        coord, feat, offset = points
        skip_coord, skip_feat, skip_offset = skip_points

        projected = self.proj(feat)
        can_use_map = self.backend == "map" and cluster is not None
        if can_use_map:
            upsampled = projected[cluster]
        else:
            upsampled = pointops.interpolation(
                coord, skip_coord, projected, offset, skip_offset
            )

        if self.skip:
            upsampled = upsampled + self.proj_skip(skip_feat)
        return [skip_coord, upsampled, skip_offset]


class Encoder(nn.Module):
    """Encoder stage: grid pooling followed by a sequence of attention blocks.

    ``attn_drop_rate`` and ``drop_path_rate`` may be None, in which case 0.0
    is used.
    """

    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        grid_size=None,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=None,
        drop_path_rate=None,
        enable_checkpoint=False,
        use_ndrope=True,
        rope_theta=100.0
    ):
        super().__init__()

        if attn_drop_rate is None:
            attn_drop_rate = 0.0
        if drop_path_rate is None:
            drop_path_rate = 0.0

        self.down = GridPool(
            in_channels=in_channels,
            out_channels=embed_channels,
            grid_size=grid_size,
        )

        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            use_ndrope=use_ndrope,
            rope_theta=rope_theta,
        )

    def forward(self, points):
        """Return ``(encoded_points, cluster)``; ``cluster`` comes from pooling."""
        pooled, cluster = self.down(points)
        encoded = self.blocks(pooled)
        return encoded, cluster


class Decoder(nn.Module):
    """Decoder stage: unpooling with skip connection, then attention blocks.

    ``attn_drop_rate`` and ``drop_path_rate`` may be None, in which case 0.0
    is used.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        embed_channels,
        groups,
        depth,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=None,
        drop_path_rate=None,
        enable_checkpoint=False,
        unpool_backend="map",
        use_ndrope=True,
        rope_theta=100.0
    ):
        super().__init__()

        if attn_drop_rate is None:
            attn_drop_rate = 0.0
        if drop_path_rate is None:
            drop_path_rate = 0.0

        self.up = UnpoolWithSkip(
            in_channels=in_channels,
            out_channels=embed_channels,
            skip_channels=skip_channels,
            backend=unpool_backend,
        )

        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            use_ndrope=use_ndrope,
            rope_theta=rope_theta,
        )

    def forward(self, points, skip_points, cluster):
        """Upsample ``points`` onto ``skip_points`` and refine the result."""
        upsampled = self.up(points, skip_points, cluster)
        return self.blocks(upsampled)


class GVAPatchEmbed(nn.Module):
    """Input embedding: linear projection + BatchNorm + ReLU, then attention blocks."""

    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
        use_ndrope=True,
        rope_theta=100.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        projection_layers = [
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        ]
        self.proj = nn.Sequential(*projection_layers)

        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
            use_ndrope=use_ndrope,
            rope_theta=rope_theta,
        )

    def forward(self, points):
        """Embed the features of ``points = [coord, feat, offset]``."""
        coord, feat, offset = points
        embedded = self.proj(feat)
        return self.blocks([coord, embedded, offset])


@MODELS.register_module("PT-v2m2-ndrope-original")
class PointTransformerV2(nn.Module):
    """U-Net style point transformer built from the nD-RoPE attention blocks.

    A patch embedding is followed by ``num_stages`` encoder stages (grid
    pooling + blocks) and the same number of decoder stages (unpooling with
    skip + blocks). A small MLP head maps the final features to
    ``num_classes`` logits; with ``num_classes <= 0`` the head is an identity.

    All per-stage tuples must have the same length as ``enc_depths``.
    """

    def __init__(
        self,
        in_channels,
        num_classes,
        patch_embed_depth=1,
        patch_embed_channels=48,
        patch_embed_groups=6,
        patch_embed_neighbours=8,
        enc_depths=(2, 2, 6, 2),
        enc_channels=(96, 192, 384, 512),
        enc_groups=(12, 24, 48, 64),
        enc_neighbours=(16, 16, 16, 16),
        dec_depths=(1, 1, 1, 1),
        dec_channels=(48, 96, 192, 384),
        dec_groups=(6, 12, 24, 48),
        dec_neighbours=(16, 16, 16, 16),
        grid_sizes=(0.06, 0.12, 0.24, 0.48),
        attn_qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0,
        enable_checkpoint=False,
        unpool_backend="map",
        use_ndrope=True,
        rope_theta=100.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_stages = len(enc_depths)
        for per_stage_setting in (
            dec_depths,
            enc_channels,
            dec_channels,
            enc_groups,
            dec_groups,
            enc_neighbours,
            dec_neighbours,
            grid_sizes,
        ):
            assert self.num_stages == len(per_stage_setting)

        # Options shared by the patch embedding and every stage.
        shared_kwargs = dict(
            qkv_bias=attn_qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            enable_checkpoint=enable_checkpoint,
            use_ndrope=use_ndrope,
            rope_theta=rope_theta,
        )

        self.patch_embed = GVAPatchEmbed(
            in_channels=in_channels,
            embed_channels=patch_embed_channels,
            groups=patch_embed_groups,
            depth=patch_embed_depth,
            neighbours=patch_embed_neighbours,
            **shared_kwargs,
        )

        # Drop-path rates grow linearly from 0 to drop_path_rate over the
        # blocks of the encoder, and separately over those of the decoder.
        enc_dp_rates = torch.linspace(0, drop_path_rate, sum(enc_depths)).tolist()
        dec_dp_rates = torch.linspace(0, drop_path_rate, sum(dec_depths)).tolist()

        # enc_channels gains the patch-embed width in front; dec_channels
        # gains the deepest encoder width at the end.
        enc_channels = [patch_embed_channels] + list(enc_channels)
        dec_channels = list(dec_channels) + [enc_channels[-1]]

        self.enc_stages = nn.ModuleList()
        self.dec_stages = nn.ModuleList()
        enc_begin = 0
        dec_begin = 0
        for stage in range(self.num_stages):
            enc_end = enc_begin + enc_depths[stage]
            dec_end = dec_begin + dec_depths[stage]

            encoder = Encoder(
                depth=enc_depths[stage],
                in_channels=enc_channels[stage],
                embed_channels=enc_channels[stage + 1],
                groups=enc_groups[stage],
                grid_size=grid_sizes[stage],
                neighbours=enc_neighbours[stage],
                drop_path_rate=enc_dp_rates[enc_begin:enc_end],
                **shared_kwargs,
            )
            decoder = Decoder(
                depth=dec_depths[stage],
                in_channels=dec_channels[stage + 1],
                skip_channels=enc_channels[stage],
                embed_channels=dec_channels[stage],
                groups=dec_groups[stage],
                neighbours=dec_neighbours[stage],
                drop_path_rate=dec_dp_rates[dec_begin:dec_end],
                unpool_backend=unpool_backend,
                **shared_kwargs,
            )
            self.enc_stages.append(encoder)
            self.dec_stages.append(decoder)

            enc_begin, dec_begin = enc_end, dec_end

        if num_classes > 0:
            head_width = dec_channels[0]
            self.seg_head = nn.Sequential(
                nn.Linear(head_width, head_width),
                PointBatchNorm(head_width),
                nn.ReLU(inplace=True),
                nn.Linear(head_width, num_classes),
            )
        else:
            self.seg_head = nn.Identity()

    def forward(self, data_dict):
        """Compute per-point outputs.

        Args:
            data_dict: Mapping that provides "coord", "feat" and "offset".

        Returns:
            The output of ``seg_head`` applied to the final decoder features.
        """
        coord = data_dict["coord"]
        feat = data_dict["feat"]
        offset = data_dict["offset"].int()

        # a batch of point cloud is a list of coord, feat and offset
        points = self.patch_embed([coord, feat, offset])

        # Stack of (points before pooling, pooling cluster map) per stage.
        skip_stack = []
        for encoder in self.enc_stages:
            stage_input = points
            points, cluster = encoder(stage_input)
            skip_stack.append((stage_input, cluster))

        # Decode from the deepest stage back to the input resolution.
        for stage in reversed(range(self.num_stages)):
            skip_points, cluster = skip_stack.pop()
            points = self.dec_stages[stage](points, skip_points, cluster)

        coord, feat, offset = points
        seg_logits = self.seg_head(feat)
        return seg_logits
