"""GISA Decoder: Global Isotropy Semantic Aggregation"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .scan_utils import morton_sort, hilbert_sort, l2norm
from .gates import RMSNormGated


class GatedDeltaNetBlock(nn.Module):
    """Gated DeltaNet-style token mixer (layout based on Qwen3-Next).

    Input ``x`` is (N, dim): tokens along dim 0. Each token is projected to
    per-head q/k/v/z and to per-head beta/a scalars. A gated linear recurrence
    then runs along the token dimension, followed by a gated RMSNorm and an
    output projection. A residual + LayerNorm is applied at the end.

    NOTE: the recurrence in ``_gated_delta_rule`` is a simplified, diagonal
    variant. The state holds one scalar per (head, channel), is updated with
    the elementwise product ``beta * k * v``, and is read out with
    ``q * state``. It does not include the delta-rule correction term or the
    (head_dim x head_dim) matrix state of the full Gated DeltaNet.

    Args:
        dim: input/output feature size.
        hidden_dim: inner size; defaults to ``dim * expand_ratio``.
        num_heads: number of heads; must divide ``hidden_dim``.
        expand_ratio: only used when ``hidden_dim`` is not given.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``num_heads``.
    """

    # Max tokens projected at once in ``forward``; longer inputs are processed
    # in slices (to bound memory) with the recurrent state carried across.
    max_chunk_size = 30000
    # Block length of the numerically stable chunked scan in ``_gated_delta_rule``.
    scan_chunk_size = 128

    def __init__(self, dim, hidden_dim=None, num_heads=4, expand_ratio=2):
        super().__init__()

        self.dim = dim
        self.hidden_dim = hidden_dim or dim * expand_ratio
        self.num_heads = num_heads
        if self.hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.head_dim = self.hidden_dim // num_heads

        # QKV + Z projection
        self.qkv_proj = nn.Linear(dim, self.hidden_dim * 3, bias=False)
        self.z_proj = nn.Linear(dim, self.hidden_dim, bias=False)

        # Beta and A projection (one scalar of each per head)
        self.ba_proj = nn.Linear(dim, num_heads * 2, bias=False)

        # Time step bias
        self.dt_bias = nn.Parameter(torch.ones(num_heads))

        # A = exp(A_log) ~ U(0, 16). Clamp away from 0 so log() can never give
        # -inf, which would poison the parameter (e.g. under weight decay).
        A = torch.empty(num_heads).uniform_(0, 16).clamp_(min=1e-6)
        self.A_log = nn.Parameter(A.log())

        # Output projection
        self.out_proj = nn.Linear(self.hidden_dim, dim, bias=False)

        # Normalization
        self.norm = RMSNormGated(self.head_dim)
        self.final_norm = nn.LayerNorm(dim)

        self._init_weights()

    def _init_weights(self):
        """Small Xavier-uniform init (gain 0.1) for all projection weights."""
        nn.init.xavier_uniform_(self.qkv_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.z_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.ba_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)

    def forward(self, x):
        """Mix tokens of ``x`` (N, dim) and return (N, dim).

        Returns ``final_norm(mix(x) + x)``. Inputs longer than
        ``max_chunk_size`` are processed slice by slice. The recurrent state
        flows from one slice into the next, so the result does not depend on
        where the slice boundaries fall.
        """
        N = x.shape[0]

        if N <= self.max_chunk_size:
            out = self._forward_chunk(x)
        else:
            outputs = []
            state = None
            for start in range(0, N, self.max_chunk_size):
                chunk_out, state = self._forward_chunk(
                    x[start:start + self.max_chunk_size],
                    initial_state=state,
                    return_final_state=True,
                )
                outputs.append(chunk_out)
            out = torch.cat(outputs, dim=0)

        # Residual + LayerNorm
        return self.final_norm(out + x)

    def _forward_chunk(self, x, initial_state=None, return_final_state=False):
        """Run projections, recurrence, gated norm and output projection.

        Args:
            x: (M, dim) contiguous run of tokens.
            initial_state: optional (num_heads, head_dim) recurrent state left by
                the preceding tokens; ``None`` starts from zeros.
            return_final_state: if true, also return the state after the
                last token.

        Returns:
            (M, dim) output (before the residual/LayerNorm applied in
            ``forward``), or ``(output, final_state)``.
        """
        M = x.shape[0]
        H, D = self.num_heads, self.head_dim

        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)  # each (M, hidden_dim)
        z = self.z_proj(x).view(M, H, D)
        b, a = self.ba_proj(x).chunk(2, dim=-1)  # each (M, H)

        # Per-head L2 normalisation of queries and keys
        q = l2norm(q.view(M, H, D), dim=-1)
        k = l2norm(k.view(M, H, D), dim=-1)
        v = v.view(M, H, D)

        beta = torch.sigmoid(b)  # (M, H) update gate in (0, 1)
        # (M, H) log-decay; always <= 0 because exp(.) > 0 and softplus(.) > 0
        g = -self.A_log.exp() * F.softplus(a + self.dt_bias)

        q = q * (1.0 / math.sqrt(D))

        output, final_state = self._gated_delta_rule(
            q, k, v, beta, g,
            initial_state=initial_state,
            return_final_state=True,
        )

        # Gated RMSNorm per head, then merge heads and project back to dim.
        output = self.norm(output.reshape(-1, D), z.reshape(-1, D))
        out = self.out_proj(output.reshape(M, self.hidden_dim))

        if return_final_state:
            return out, final_state
        return out

    def _gated_delta_rule(self, q, k, v, beta, g, initial_state=None,
                          return_final_state=False):
        """Evaluate the gated linear recurrence along the token dimension.

        For every head h and channel d:

            state_t = exp(g_t) * state_{t-1} + beta_t * k_t * v_t
            out_t   = q_t * state_t

        Equivalently, ``state_t = sum_{s<=t} exp(G_t - G_s) * kv_s`` with
        ``G = cumsum(g)``. Exponentiating G over the whole sequence underflows
        in float32 once G is a few dozen below zero. Dividing by that
        underflowed value then destroys both the decay and the precision.

        Instead, tokens are processed in blocks of ``scan_chunk_size``:
        - Within a block, only the differences ``G_t - G_s`` for ``s <= t``
          are exponentiated. They are <= 0 when g <= 0, so the weights lie in
          (0, 1].
        - The state is carried from one block to the next.

        Args:
            q, k, v: (M, H, D) tensors.
            beta: (M, H) update gates.
            g: (M, H) log-decays (non-positive as produced by ``_forward_chunk``).
            initial_state: optional (H, D) state preceding token 0.
            return_final_state: if true, also return the float32 (H, D) state
                after the last token.

        Returns:
            (M, H, D) output in the dtype of ``q``, or ``(output, final_state)``.
        """
        M, H, D = q.shape
        dtype = q.dtype

        # Compute in float32 for stability.
        q = q.float()
        kv = k.float() * v.float() * beta.float().unsqueeze(-1)  # (M, H, D)
        g = g.float()  # (M, H)

        if initial_state is None:
            state = q.new_zeros(H, D)
        else:
            state = initial_state.to(device=q.device, dtype=torch.float32)

        C = self.scan_chunk_size
        pos = torch.arange(C, device=q.device)
        causal = pos.unsqueeze(1) >= pos.unsqueeze(0)  # (t, s): True where s <= t

        chunk_states = []
        for start in range(0, M, C):
            kv_c = kv[start:start + C]  # (n, H, D)
            n = kv_c.shape[0]
            G = torch.cumsum(g[start:start + C], dim=0)  # (n, H), local to block

            # weights[t, s, h] = exp(G_t - G_s) for s <= t, else 0. Mask before
            # exp so the (positive) upper-triangle differences never overflow.
            diff = G.unsqueeze(1) - G.unsqueeze(0)
            diff = diff.masked_fill(~causal[:n, :n].unsqueeze(-1), float("-inf"))
            weights = torch.exp(diff)  # (n, n, H)

            states = torch.einsum("tsh,shd->thd", weights, kv_c)
            # Contribution of everything before this block, decayed to each t.
            states = states + torch.exp(G).unsqueeze(-1) * state
            chunk_states.append(states)
            state = states[-1]

        all_states = torch.cat(chunk_states, dim=0) if chunk_states else torch.zeros_like(q)
        output = (q * all_states).to(dtype)

        if return_final_state:
            return output, state
        return output


class GISADecoder(nn.Module):
    """GISA: Global Isotropy Semantic Aggregation Decoder

    Per-point head on (N, in_dim) features:
    - fc1 -> fc2 (optional residual) -> fc3, each followed by BatchNorm, GELU
      and dropout. This reduces the width to ``hidden_dim // 2``.
    - A shared GatedDeltaNetBlock mixes information across points.
    - ``fc4`` produces per-point logits.

    Heterogeneous bi-directional scanning: Morton + Identity
    - Morton: points are reordered with ``morton_sort(xyz, ...)`` before the
      DeltaNet scan and restored with the returned inverse index afterwards
      (presumably a Z-order spatial serialization: anisotropic).
    - Identity: points are scanned in their input order (isotropic).

    ``scan_mode`` (only consulted when ``xyz`` is given):
    - 'identity': identity track only.
    - 'morton': Morton track only.
    - 'hilbert': NOTE(review): currently returns the identity track; no Hilbert
      ordering is applied. Confirm whether a Hilbert scan was intended.
    - 'average': mean of the two tracks.
    - 'gate' (default, and any unrecognized value): a learned sigmoid gate
      blends the two tracks, followed by LayerNorm.
    """

    def __init__(
        self,
        in_dim=256,
        hidden_dim=256,
        num_classes=20,
        dropout=0.55,
        use_residual=True,
        deltanet_hidden_dim=512,
        num_heads=4,
        scan_mode='gate',
    ):
        super().__init__()

        self.use_residual = use_residual
        self.hidden_dim = hidden_dim
        self.scan_mode = scan_mode
        feat_dim = hidden_dim // 2

        # MLP layers
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout2 = nn.Dropout(dropout)

        self.fc3 = nn.Linear(hidden_dim, feat_dim)
        self.bn3 = nn.BatchNorm1d(feat_dim)
        self.dropout3 = nn.Dropout(dropout)

        # DeltaNet (shared by both scan tracks)
        self.deltanet = GatedDeltaNetBlock(
            dim=feat_dim,
            hidden_dim=deltanet_hidden_dim,
            num_heads=num_heads,
            expand_ratio=2
        )

        # Gated fusion
        self.gate_proj = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.Sigmoid()
        )

        # Fusion norm
        self.fusion_norm = nn.LayerNorm(feat_dim)

        # Classifier
        self.fc4 = nn.Linear(feat_dim, num_classes)
        self.act = nn.GELU()

    def forward(self, point_features, xyz=None, return_features=False):
        """Predict per-point class logits.

        Args:
            point_features: (N, in_dim) per-point features.
            xyz: optional per-point coordinates passed to ``morton_sort``. When
                ``None``, only the identity scan is used regardless of
                ``scan_mode``.
            return_features: if true, also return the (N, hidden_dim // 2)
                features fed to the classifier.

        Returns:
            ``logits`` of shape (N, num_classes), or ``(logits, features)``.
        """
        x = self._encode(point_features)
        x = self._aggregate(x, xyz)

        # Classification
        logits = self.fc4(x)

        if return_features:
            return logits, x
        return logits

    def _encode(self, x):
        """Per-point MLP: fc1 -> fc2 (+ optional residual) -> fc3."""
        x = self.dropout1(self.act(self.bn1(self.fc1(x))))

        h = self.dropout2(self.act(self.bn2(self.fc2(x))))
        # fc2 preserves the width, so the residual applies whenever enabled.
        x = x + h if self.use_residual and x.shape[-1] == self.hidden_dim else h

        return self.dropout3(self.act(self.bn3(self.fc3(x))))

    def _aggregate(self, x, xyz):
        """Apply the DeltaNet scan(s) selected by ``scan_mode``.

        Only the tracks that contribute to the result are computed. The
        DeltaNet is deterministic, so skipping an unused track changes neither
        outputs nor gradients.
        """
        if xyz is None or self.scan_mode == 'identity':
            return self.deltanet(x)

        if self.scan_mode == 'morton':
            return self._morton_scan(x, xyz)

        if self.scan_mode == 'hilbert':
            # NOTE(review): this mode yields the identity-order track, not a
            # Hilbert-ordered scan. Behavior kept as-is; confirm intent.
            return self.deltanet(x)

        out_morton, out_identity = self._heterogeneous_biscan(x, xyz)

        if self.scan_mode == 'average':
            return (out_morton + out_identity) / 2

        # gate mode (also the fallback for unrecognized modes)
        gate = self.gate_proj(x)
        x_fused = gate * out_morton + (1 - gate) * out_identity
        return self.fusion_norm(x_fused)

    def _morton_scan(self, x, xyz):
        """DeltaNet over Morton-sorted points, returned in the original order."""
        _, x_morton, inv_morton = morton_sort(xyz, x)
        return self.deltanet(x_morton)[inv_morton]

    def _heterogeneous_biscan(self, x, xyz):
        """Heterogeneous bi-directional scanning: Morton + Identity.

        Returns:
            ``(out_morton, out_identity)``, both in the input point order.
        """
        # Track A: Morton
        out_morton = self._morton_scan(x, xyz)

        # Track B: Identity
        out_identity = self.deltanet(x)

        return out_morton, out_identity
