import torch
import torch.nn as nn
from einops import rearrange, repeat
from core.lie_neurons_layers import (
    LNLinear, 
    LNKillingRelu, 
    LNLieBracket, 
    LNBatchNorm, 
    HatLayer, 
    LNMaxPool
)
from core.lie_alg_util import killingform, vee_gl3
from timm.models.layers import trunc_normal_

# ==============================================================================
# [Part 1] Geometric Utilities & Normalization
# ==============================================================================

def build_covariance_flatten(scale, rotation):
    """
    Build flattened 3D covariance matrices Sigma = (R S)(R S)^T.

    Args:
        scale: [..., 3] per-axis scales, placed on the diagonal of S.
        rotation: [..., 4] quaternion (w, x, y, z). It is normalized here
            (with a small epsilon), so it need not be unit length.

    Returns:
        [..., 9] row-major flattening of each 3x3 covariance.
    """
    # Unit quaternion; the epsilon keeps a zero quaternion from dividing by 0.
    q = rotation / torch.sqrt(rotation.square().sum(-1, keepdim=True) + 1e-8)
    w, x, y, z = q.unbind(-1)

    # Rotation-matrix entries in row-major order.
    entries = [
        1 - 2*y**2 - 2*z**2, 2*x*y - 2*w*z, 2*x*z + 2*w*y,
        2*x*y + 2*w*z, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*w*x,
        2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x**2 - 2*y**2,
    ]
    R = torch.stack(entries, dim=-1).reshape(*w.shape, 3, 3)

    # Diagonal scale matrix with R's shape/dtype/device; write the scales
    # through a view of its main diagonal.
    S = torch.zeros_like(R)
    S.diagonal(dim1=-2, dim2=-1).copy_(scale)

    RS = R @ S
    return (RS @ RS.transpose(-1, -2)).flatten(-2)

def pool_vectors(x):
    """
    Max-pool vector features over the last axis, selecting by vector norm.

    For every (batch, channel) pair the whole vector ``x[b, c, :, n]`` with the
    largest L2 norm (taken over dim 2) is returned unchanged, so the output is
    always one of the input vectors. On ties ``torch.argmax`` picks the first
    index.

    Args:
        x: [B, C, K, N] tensor. K is the vector dimension (9 for flattened
           3x3 matrices, but any K is accepted); N is the pooled axis.

    Returns:
        [B, C, K] tensor.
    """
    # 1. Per-vector magnitudes: [B, C, N]
    norms = x.norm(dim=2)

    # 2. Index of the largest-magnitude vector along N: [B, C, 1]
    idx = norms.argmax(dim=-1, keepdim=True)

    # 3. Broadcast the index across the vector dimension (its real size, not a
    #    hard-coded 9) and gather whole vectors: [B, C, K, 1] -> [B, C, K]
    idx_exp = idx.unsqueeze(2).expand(-1, -1, x.size(2), -1)
    return torch.gather(x, -1, idx_exp).squeeze(-1)

class ReLNLayerNorm(nn.Module):
    """
    Magnitude-based layer normalization for Lie-algebra-valued features.

    Each K-dim feature vector is split into its magnitude (plain L2 norm over
    dim 2, i.e. the Frobenius norm of the 3x3 matrix when K = 9 -- not the
    Killing form) and its direction. The magnitudes are normalized across the
    feature axis with ``nn.LayerNorm`` and multiplied back onto the directions.
    The L2 norm is invariant under orthogonal conjugation R X R^T, but not
    under general GL(3) conjugation.

    NOTE(review): LayerNorm outputs can be negative, so the rescale factor can
    flip the sign of a feature vector.

    Args:
        num_features: size of the feature axis (dim 1).
        algebra_type: stored on the module; not used by ``forward``.
        eps: lower clamp on magnitudes before dividing by them.
    """
    def __init__(self, num_features, algebra_type='gl3', eps=1e-6):
        super().__init__()
        self.eps = eps
        self.ln = nn.LayerNorm(num_features)
        # Kept for interface parity with the other LN layers; norms are computed directly.
        self.algebra_type = algebra_type

    def forward(self, x):
        """
        x: [B, F, K, N] (batch, feature, Lie dim, sequence). Returns the same shape.
        """
        magnitude = x.norm(dim=2)  # [B, F, N]
        # nn.LayerNorm normalizes the last axis, so move F there and back.
        target = self.ln(magnitude.transpose(1, 2)).transpose(1, 2)  # [B, F, N]
        gain = target / magnitude.clamp(min=self.eps)
        return x * gain.unsqueeze(2)  # broadcast over K

# ==============================================================================
# [Part 2] ReLN Encoder (Lifting & Embedding)
# ==============================================================================

class ReLNEncoder(nn.Module):
    """
    Per-group encoder: lifts Gaussian attributes to 9-dim (flattened 3x3)
    Lie-algebra features and max-pools them by norm into one token per group.

    Geometric stream: xyz is lifted with ``HatLayer('so3')`` and a covariance is
    built from scale + quaternion; both are flattened 3x3 matrices. Invariant
    stream: norms of those two features plus optional raw attributes, mapped by
    a Conv1d MLP to sigmoid weights that multiply the embedded geometric
    features per channel and per point.

    Args:
        encoder_channel: number of output channels per group token.
        attribute: attribute names for the invariant stream; ``"opacity"``,
            ``"scale"`` and ``"sh"`` are recognised. Defaults to ``["xyz"]``
            (geometry only).
    """
    def __init__(self, encoder_channel, attribute=None):
        super().__init__()
        # None sentinel instead of a list literal default: a default list is
        # created once and would be shared (and stored) by every instance.
        if attribute is None:
            attribute = ["xyz"]
        self.encoder_channel = encoder_channel
        self.attribute = attribute
        self.hat = HatLayer(algebra_type='so3')

        # Width divisor applied to every hidden layer (1 = full width).
        S = 1

        # --- 1. Invariant stream setup ---
        inv_dim = 2  # norms of the lifted position and of the covariance
        if "opacity" in attribute: inv_dim += 1
        if "sh" in attribute: inv_dim += 3
        if "scale" in attribute: inv_dim += 3

        self.inv_mlp = nn.Sequential(
            nn.Conv1d(inv_dim, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 128//S, 1),
            nn.Sigmoid()
        )

        # --- 2. First block: 2 geometric channels -> 128 -> 256 ---
        self.geo_embed = LNLinear(2, 128//S)

        self.first_conv_rest = nn.Sequential(
            ReLNLayerNorm(128//S, algebra_type='gl3'),
            LNKillingRelu(128//S, algebra_type='gl3'),
            LNLinear(128//S, 256//S)
        )

        # --- 3. Second block: [global | local] 512 -> encoder_channel ---
        self.second_conv = nn.Sequential(
            LNLinear(512//S, 512//S),
            ReLNLayerNorm(512//S, algebra_type='gl3'),
            LNKillingRelu(512//S, algebra_type='gl3'),
            LNLinear(512//S, encoder_channel)
        )
        # Pooling is done in forward() with pool_vectors; there is no pool layer.

    def forward(self, point_groups):
        """
        Args:
            point_groups: [B, G, M, D] -- B batches of G groups of M points.
                Column layout read below: xyz 0:3, opacity 3:4, scale 4:7,
                quaternion 7:11, sh 11:14 (only the columns for enabled
                attributes are used).

        Returns:
            Group tokens [B, encoder_channel, 9, G].
        """
        B, G, M, D = point_groups.shape
        p = rearrange(point_groups, 'b g m d -> (b g) m d') # [BG, M, D]

        # [Lifting]
        xyz = p[..., :3]
        scale = p[..., 4:7]
        rot = p[..., 7:11]

        xyz_hat_mat = self.hat(xyz)
        xyz_flat = rearrange(xyz_hat_mat, 'bg m i j -> bg m (i j)')
        cov_flat = build_covariance_flatten(scale, rot)

        # [Geo Input]: [BG, 2, 9, M] -- channel 0 = lifted xyz, channel 1 = covariance
        geo_input = torch.stack([xyz_flat, cov_flat], dim=2).permute(0, 2, 3, 1)

        # [Invariant Input]: magnitudes of the geometric features + raw attributes.
        # Append order (opacity, scale, sh) differs from the counting order in
        # __init__; only the total width has to match inv_dim.
        xyz_norm = xyz_flat.norm(dim=-1, keepdim=True)
        cov_norm = cov_flat.norm(dim=-1, keepdim=True)
        inv_list = [xyz_norm, cov_norm]

        if "opacity" in self.attribute: inv_list.append(p[..., 3:4])
        if "scale" in self.attribute: inv_list.append(scale)
        if "sh" in self.attribute: inv_list.append(p[..., 11:14])

        x_inv = torch.cat(inv_list, dim=-1).transpose(1, 2) # [BG, Inv_Dim, M]

        # --- First Conv: embed geometry, weight it by the invariant stream ---
        x_geo = self.geo_embed(geo_input) # [BG, 128, 9, M]
        weights = self.inv_mlp(x_inv).unsqueeze(2) # [BG, 128, 1, M], broadcast over the 9 dims
        x = x_geo * weights

        feature = self.first_conv_rest(x) # [BG, 256, 9, M]

        # --- Global Pooling 1 (norm-based max over points) ---
        feature_global = pool_vectors(feature) # [BG, 256, 9]

        # Concat the pooled global feature onto every point
        feature_global_exp = feature_global.unsqueeze(-1).expand(-1, -1, -1, M)
        feature_cat = torch.cat([feature_global_exp, feature], dim=1) # [BG, 512, 9, M]

        # --- Second Conv ---
        feature = self.second_conv(feature_cat) # [BG, C, 9, M]

        # --- Global Pooling 2 (norm-based max over points) ---
        feature_group = pool_vectors(feature) # [BG, C, 9]

        return rearrange(feature_group, '(b g) c k -> b c k g', b=B, g=G)

# ==============================================================================
# [Part 3] ReLN Transformer Components (Same as before)
# ==============================================================================

class ReLNPosEmbed(nn.Module):
    """
    Positional embedding: lift group centers with ``HatLayer`` and project the
    single lifted channel to ``embed_dim`` channels with ``LNLinear``.

    center: [B, G, C] with the first 3 columns used as xyz.
    Output: ``self.proj`` applied to a [B, 1, 9, G] tensor (presumably
    [B, embed_dim, 9, G] -- depends on LNLinear's channel mapping).
    """
    def __init__(self, embed_dim, algebra_type='so3'):
        super().__init__()
        self.hat = HatLayer(algebra_type=algebra_type)
        self.proj = LNLinear(1, embed_dim)

    def forward(self, center):
        # Unpacking doubles as a check that center is 3-D.
        _, _, _ = center.shape
        skew = self.hat(center[..., :3])  # [B, G, i, j]
        # Move groups last, flatten each matrix, and add a singleton channel axis.
        lifted = rearrange(skew, 'b g i j -> b 1 (i j) g')
        return self.proj(lifted)

class ReLNMlp(nn.Module):
    """
    Two-layer channel MLP: LNLinear -> LNKillingRelu -> LNLinear.

    Args:
        in_features: input channel count.
        hidden_features: hidden channel count (falls back to ``in_features``).
        out_features: output channel count (falls back to ``in_features``).
        algebra_type: forwarded to the nonlinearity.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, algebra_type='gl3'):
        super().__init__()
        width = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = LNLinear(in_features, width)
        # Killing-form ReLU nonlinearity; LNLieBracket(width, algebra_type=...)
        # is the alternative bracket nonlinearity.
        self.act = LNKillingRelu(width, algebra_type=algebra_type)
        self.fc2 = LNLinear(width, width_out)

    def forward(self, x):
        h = self.fc1(x)
        h = self.act(h)
        return self.fc2(h)

class ReLNAttention(nn.Module):
    """
    Multi-head self-attention over tokens of 9-dim Lie-algebra features.

    Each 9-dim feature is viewed as a 3x3 matrix. Attention logits are
    ``killingform`` values between every query/key matrix pair, summed over
    the per-head channels and scaled by 1/sqrt(dim // num_heads). Softmax
    weights over keys then mix the value features along the token axis.

    Input and output: [B, C, K, N] with K = 9 (needed for the 3x3 view).
    """
    def __init__(self, dim, num_heads=8, algebra_type='gl3'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.algebra_type = algebra_type
        self.qkv = LNLinear(dim, dim * 3)
        self.proj = LNLinear(dim, dim)

    def forward(self, x):
        # Unpacking doubles as a check that x is 4-D.
        _, _, _, _ = x.shape
        heads = self.num_heads

        # Split channels as (qkv, head, head-channel) in one step: [b, h, c, d, n] each.
        q, k, v = rearrange(self.qkv(x), 'b (t h c) d n -> t b h c d n', t=3, h=heads)

        # View each 9-dim feature as a 3x3 matrix: [b, h, c, n, 3, 3].
        q_mat = rearrange(q, 'b h c (i j) n -> b h c n i j', i=3, j=3)
        k_mat = rearrange(k, 'b h c (i j) n -> b h c n i j', i=3, j=3)

        # Broadcast to every (query, key) pair:
        # q -> [b, h, c, n, 1, 3, 3], k -> [b, h, c, 1, n, 3, 3].
        pair_scores = killingform(q_mat.unsqueeze(4), k_mat.unsqueeze(3), algebra_type=self.algebra_type)
        # Drop the trailing singleton killingform returns (assumed), then sum over channels.
        logits = pair_scores.squeeze(-1).sum(dim=2) * self.scale  # [b, h, n, n]
        attn = torch.softmax(logits, dim=-1)

        # Mix values along the token axis.
        v_tok = rearrange(v, 'b h c d n -> b h c n d')
        mixed = torch.einsum('b h i j, b h c j k -> b h c i k', attn, v_tok)
        return self.proj(rearrange(mixed, 'b h c n d -> b (h c) d n'))

class ReLNBlock(nn.Module):
    """
    Pre-norm transformer block on Lie-algebra features:
    ``x + Attn(Norm(x))`` followed by ``+ MLP(Norm(.))`` with a 4x hidden width.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, dim, num_heads, algebra_type='gl3', **kwargs):
        super().__init__()
        hidden = int(dim * 4)
        self.norm1 = ReLNLayerNorm(dim, algebra_type=algebra_type)
        self.attn = ReLNAttention(dim, num_heads=num_heads, algebra_type=algebra_type)
        self.norm2 = ReLNLayerNorm(dim, algebra_type=algebra_type)
        self.mlp = ReLNMlp(dim, hidden, algebra_type=algebra_type)

    def forward(self, x):
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))

# ==============================================================================
# [Part 4] Main Encoder/Decoder Modules
# ==============================================================================

class ReLNTransformerEncoder(nn.Module):
    """
    Stack of ``depth`` ReLNBlocks.

    When ``pos`` is given it is added to the input of every block, not only the
    first one. Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, embed_dim=768, depth=4, num_heads=12, algebra_type='gl3', **kwargs):
        super().__init__()
        layers = [ReLNBlock(embed_dim, num_heads, algebra_type=algebra_type) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """nn.Linear: truncated normal (std 0.02), zero bias. nn.LayerNorm: weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, pos=None):
        for layer in self.blocks:
            x = layer(x if pos is None else x + pos)
        return x

class ReLNTransformerDecoder(nn.Module):
    """
    Stack of ``depth`` ReLNBlocks followed by ReLNLayerNorm and a head
    (``nn.Identity`` here).

    ``pos`` is added to the input of every block. Only the trailing
    ``return_token_num`` tokens (last axis) are normalized and returned.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, embed_dim=384, depth=4, num_heads=6, algebra_type='gl3', **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([
            ReLNBlock(embed_dim, num_heads, algebra_type=algebra_type)
            for _ in range(depth)
        ])
        self.norm = ReLNLayerNorm(embed_dim, algebra_type=algebra_type)
        self.head = nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """nn.Linear: truncated normal (std 0.02), zero bias. nn.LayerNorm: weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, pos, return_token_num):
        """
        Args:
            x: token features with tokens on the last axis (e.g. [B, C, 9, N]).
            pos: positional embedding broadcastable to ``x``.
            return_token_num: number of trailing tokens to return.

        Returns:
            The last ``return_token_num`` tokens after ``norm`` and ``head``;
            all tokens if it exceeds N, none if it is 0.
        """
        for block in self.blocks:
            x = block(x + pos)
        # Explicit start index: ``x[..., -n:]`` would return *every* token
        # when n == 0 (since -0 == 0) instead of none.
        start = max(x.shape[-1] - return_token_num, 0)
        x_out = x[..., start:]
        return self.head(self.norm(x_out))

# -----------------------------------------------------------------------------
# [Part 5] Group (imported from models.transformer)
# -----------------------------------------------------------------------------
from models.transformer import Group