import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
import random
import types
import copy

from models.Gaussian_MAE import Gaussian_MAE
from core.lie_neurons_layers import LNLinear, HatLayer

# ==============================================================================
# 1. Geometric Utils (Physics & Rotation)
# ==============================================================================
def set_seed(seed):
    """Seed every random number generator this script draws from.

    Args:
        seed: Integer seed handed, in order, to Python's ``random``, NumPy's
            global generator, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_random_so3(device=None):
    """Sample a random 3x3 rotation matrix (orthonormal, determinant +1).

    Args:
        device: Device of the returned matrix. When omitted, CUDA is used if
            it is available (the original behaviour); otherwise the matrix
            stays on the CPU instead of raising.

    Returns:
        A ``(3, 3)`` tensor ``Q`` with ``Q @ Q.T ~= I`` and ``det(Q) ~= +1``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # QR of a Gaussian matrix yields an orthonormal Q that may be a reflection.
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    if torch.det(q) < 0:
        # Negating one column flips the determinant to +1 (proper rotation).
        q[:, 0] *= -1
    return q.to(device)

def rotate_quaternion_simple(q, R):
    """Compose quaternions ``q`` with the rotation described by matrix ``R``.

    ``R`` is converted to a unit quaternion ``q_r`` and the Hamilton product
    ``q_r * q`` (``q_r`` on the left) is returned. All quaternions use
    (w, x, y, z) component order.

    The matrix-to-quaternion conversion picks its formula from the largest of
    the trace and the three diagonal entries, so the divisor never approaches
    zero. The previous single-formula version divided by ``sqrt(trace + 1)``,
    which collapses for rotations by 180 degrees (``trace == -1``).

    Args:
        q: Tensor of shape ``(..., 4)`` with quaternions in (w, x, y, z) order.
        R: ``(3, 3)`` rotation matrix.

    Returns:
        Tensor of shape ``(..., 4)`` holding ``q_r * q``.
    """
    trace = R.trace()
    if trace > 0.0:
        S = torch.sqrt(trace + 1.0) * 2
        qw = 0.25 * S
        qx = (R[2, 1] - R[1, 2]) / S
        qy = (R[0, 2] - R[2, 0]) / S
        qz = (R[1, 0] - R[0, 1]) / S
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        S = torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) * 2
        qw = (R[2, 1] - R[1, 2]) / S
        qx = 0.25 * S
        qy = (R[0, 1] + R[1, 0]) / S
        qz = (R[0, 2] + R[2, 0]) / S
    elif R[1, 1] > R[2, 2]:
        S = torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) * 2
        qw = (R[0, 2] - R[2, 0]) / S
        qx = (R[0, 1] + R[1, 0]) / S
        qy = 0.25 * S
        qz = (R[1, 2] + R[2, 1]) / S
    else:
        S = torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) * 2
        qw = (R[1, 0] - R[0, 1]) / S
        qx = (R[0, 2] + R[2, 0]) / S
        qy = (R[1, 2] + R[2, 1]) / S
        qz = 0.25 * S

    q_r = torch.stack([qw, qx, qy, qz]).to(q.device)
    # q_r and -q_r encode the same rotation; keep w >= 0 so the result matches
    # the sign convention of the original trace-based formula.
    if q_r[0] < 0:
        q_r = -q_r

    # Hamilton product q_r * q, broadcast over the leading dims of q.
    w1, x1, y1, z1 = q_r
    w2, x2, y2, z2 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    return torch.stack([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2
    ], dim=-1)

def build_covariance_flatten(scale, rotation):
    """Build flattened 3x3 covariances ``R diag(scale)^2 R^T`` from scale/quaternion.

    Args:
        scale: ``(..., 3)`` per-axis scales, used as given (no activation is
            applied here).
        rotation: ``(..., 4)`` quaternions in (w, x, y, z) order. They are
            normalised inside this function, so unnormalised input is fine.

    Returns:
        ``(..., 9)`` tensor: each covariance matrix flattened row-major.
        Leading dimensions of ``scale`` and ``rotation`` broadcast against
        each other.
    """
    # The epsilon keeps the division finite for an all-zero quaternion.
    norm = torch.sqrt(rotation.square().sum(-1, keepdim=True) + 1e-8)
    w, x, y, z = (rotation / norm).unbind(-1)

    # Rotation matrix of the unit quaternion, assembled row by row.
    row0 = torch.stack([1 - 2*y**2 - 2*z**2, 2*x*y - 2*w*z, 2*x*z + 2*w*y], dim=-1)
    row1 = torch.stack([2*x*y + 2*w*z, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*w*x], dim=-1)
    row2 = torch.stack([2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x**2 - 2*y**2], dim=-1)
    R = torch.stack([row0, row1, row2], dim=-2)

    # R @ diag(scale) merely rescales the columns of R, so broadcast the
    # multiply instead of materialising a dense diagonal matrix.
    M = R * scale.unsqueeze(-2)
    Sigma = torch.matmul(M, M.transpose(-1, -2))
    return Sigma.flatten(-2)

# ==============================================================================
# 2. ReLN Patch: Full-Spec Lifting Logic
# ==============================================================================
def forward_full_spec_lifting(self, point_groups):
    """Replacement encoder ``forward`` that lifts every attribute to a 3x3 matrix.

    Written to be bound onto an encoder module with ``types.MethodType``;
    ``self`` must provide ``hat``, ``geo_embed``, ``first_conv_rest`` and
    ``second_conv``.

    Lifting per attribute (each result is flattened to 9 values):
        * position   -> skew-symmetric matrix via ``self.hat``
        * covariance -> symmetric matrix via ``build_covariance_flatten``
        * opacity    -> scalar multiple of the identity
        * SH0        -> three scalar multiples of the identity
        * SH1        -> three skew-symmetric matrices via ``self.hat``
        * SH2        -> three symmetric, trace-free matrices (5 coefficients each)

    Args:
        point_groups: ``(B, G, M, D)`` tensor. The last axis is read as
            ``[0:3] xyz, [3:4] opacity, [4:7] scale, [7:11] rotation,
            [11:14] SH0, [14:23] SH1, [23:38] SH2``; columns past 38 are
            ignored. An SH band that is not fully present is replaced by
            zeros instead of failing with a shape error.

    Returns:
        Tensor of shape ``(B, C, 9, G)`` where ``C`` is the channel count
        produced by ``self.second_conv``.
    """
    B, G, M, D = point_groups.shape
    p = rearrange(point_groups, 'b g m d -> (b g) m d')
    device = p.device

    # 1. Parsing
    xyz = p[..., :3]
    opac = p[..., 3:4]
    scale = p[..., 4:7]
    rot = p[..., 7:11]

    # Each SH band is used only when all of its columns exist. Previously any
    # D > 14 sliced all three bands, so 14 < D < 38 crashed in `.view` below.
    # NOTE(review): D == 14 still zeroes SH0, exactly as before — confirm that
    # this is intended.
    sh0 = p[..., 11:14] if D > 14 else torch.zeros_like(xyz)
    sh1 = p[..., 14:23] if D >= 23 else p.new_zeros(B * G, M, 9)
    sh2 = p[..., 23:38] if D >= 38 else p.new_zeros(B * G, M, 15)

    # 2. Lifting
    # (A) Position -> Skew
    xyz_hat_mat = self.hat(xyz)
    xyz_flat = rearrange(xyz_hat_mat, 'bg m i j -> bg m (i j)')

    # (B) Covariance -> Sym
    cov_flat = build_covariance_flatten(scale, rot)

    # (C) Opacity & SH0 -> Identity (matches the input dtype)
    I = torch.eye(3, device=device, dtype=p.dtype).view(1, 1, 1, 3, 3)
    opac_mat = opac.view(-1, M, 1, 1, 1) * I
    opac_flat = rearrange(opac_mat, 'bg m c i j -> bg m c (i j)')

    sh0_mat = sh0.view(-1, M, 3, 1, 1) * I
    sh0_flat = rearrange(sh0_mat, 'bg m c i j -> bg m c (i j)')

    # (D) SH 1 -> Skew: the 9 values are read as 3 groups of 3-vectors; the
    # group axis is folded into the batch so self.hat sees plain 3-vectors.
    sh1_vecs = sh1.view(-1, M, 3, 3)
    sh1_flat_vec = rearrange(sh1_vecs, 'bg m c v -> (bg c) m v')
    sh1_hat = self.hat(sh1_flat_vec)
    sh1_flat = rearrange(sh1_hat, '(bg c) m i j -> bg m c (i j)', c=3)

    # (E) SH 2 -> Sym: 5 coefficients per group fill a symmetric matrix whose
    # diagonal (c3 + c4, -c3 + c4, -2*c4) sums to zero.
    c = sh2.view(-1, M, 3, 5)
    c0, c1, c2, c3, c4 = [c[..., i] for i in range(5)]
    row0 = torch.stack([c3 + c4, c0, c2], dim=-1)
    row1 = torch.stack([c0, -c3 + c4, c1], dim=-1)
    row2 = torch.stack([c2, c1, -2*c4], dim=-1)
    sh2_mat = torch.stack([row0, row1, row2], dim=-2)
    sh2_flat = rearrange(sh2_mat, 'bg m c i j -> bg m c (i j)')

    # 3. Stacking
    # Channels: Pos(1) + Cov(1) + Opac(1) + SH0(3) + SH1(3) + SH2(3) = 12
    all_feats = torch.cat([
        xyz_flat.unsqueeze(2),
        cov_flat.unsqueeze(2),
        opac_flat, sh0_flat, sh1_flat, sh2_flat
    ], dim=2)  # [BG, M, 12, 9]

    all_feats = rearrange(all_feats, 'bg m c k -> bg c k m')

    # 4. Processing
    x = self.geo_embed(all_feats)  # expected [BG, C_embed, 9, M]

    # Skip invariant weights for pure equivariance check
    feature = self.first_conv_rest(x)

    def pool_vectors(feats):
        # Per channel, keep the 9-vector of the point with the largest norm.
        norms = feats.norm(dim=2)
        idx = norms.argmax(dim=-1, keepdim=True)
        idx_exp = idx.unsqueeze(2).expand(-1, -1, 9, -1)
        return torch.gather(feats, -1, idx_exp).squeeze(-1)

    # Concatenate the pooled (global) feature onto every point's feature.
    feature_global = pool_vectors(feature)
    feature_global_exp = feature_global.unsqueeze(-1).expand(-1, -1, -1, M)
    feature_cat = torch.cat([feature_global_exp, feature], dim=1)

    feature = self.second_conv(feature_cat)
    feature_group = pool_vectors(feature)

    return rearrange(feature_group, '(b g) c k -> b c k g', b=B, g=G)

# ==============================================================================
# 3. Manual Run Helpers (User's Preferred Style)
# ==============================================================================
def run_reln_manual(model, neighborhood, center, noaug=False):
    """Run only the encoder and read one 3-vector per group from its output.

    The decoder is not executed; the equivariance property is read directly
    from the encoder features. The encoder output is averaged over dim 1, the
    9-wide axis is reshaped to a 3x3 matrix per group, and the skew-symmetric
    part ``S`` of that matrix is mapped to ``(S[2,1], S[0,2], S[1,0])``.

    Args:
        model: Object exposing ``MAE_encoder(neighborhood, center, noaug=...)``
            that returns ``(features, mask)``; features are expected to be
            ``[B, C, 9, G]``.
        neighborhood: Grouped input forwarded to the encoder.
        center: Group centers forwarded to the encoder.
        noaug: Forwarded to the encoder unchanged.

    Returns:
        Tensor of shape ``[B, G, 3]``.
    """
    encoded, _ = model.MAE_encoder(neighborhood, center, noaug=noaug)

    # Average over the channel axis, then view the 9 entries as a 3x3 matrix.
    pooled = encoded.mean(dim=1)  # [B, 9, G]
    mats = rearrange(pooled, 'b (i j) g -> b g i j', i=3, j=3)

    # Skew-symmetric part -> vector.
    antisym = (mats - mats.transpose(-1, -2)) * 0.5
    components = [antisym[..., row, col] for row, col in ((2, 1), (0, 2), (1, 0))]
    return torch.stack(components, dim=-1)

def run_original_manual(model, p, c, noaug=False):
    """Run encoder, decoder and output head by hand; return one 3-vector per group.

    Args:
        model: Object exposing ``MAE_encoder``, ``pos_feature_dim``,
            ``decoder_pos_embed``, ``MAE_decoder`` and ``increase_dim``.
        p: Grouped input forwarded to the encoder.
        c: Group centers; ``c[..., model.pos_feature_dim]`` feeds the decoder's
            positional embedding.
        noaug: Forwarded to the encoder unchanged.

    Returns:
        Tensor of shape ``[B, G, 3]``: the prediction's last axis is split into
        triples, which are then averaged.
    """
    encoded, _ = model.MAE_encoder(p, c, noaug=noaug)

    # The decoder needs a positional embedding built from the centers.
    pos_embedding = model.decoder_pos_embed(c[..., model.pos_feature_dim])
    decoded = model.MAE_decoder(encoded, pos_embedding, 0)

    # increase_dim works on the transposed layout; transpose back afterwards.
    pred = model.increase_dim(decoded.transpose(1, 2)).transpose(1, 2)

    batch, groups, _ = pred.shape
    triples = pred.reshape(batch, groups, -1, 3)
    return triples.mean(dim=2)

# ==============================================================================
# 4. Main Test
# ==============================================================================
def check_comparison():
    """Compare rotation equivariance of the baseline model and the patched ReLN encoder.

    Builds random Gaussian groups, applies one random rotation to the
    positions, quaternions and the SH1/SH2 columns, runs each model on the
    original and the rotated input, and prints the relative error between the
    rotated-input output and the rotated original output.

    Requires a CUDA device: inputs and both models are placed on the GPU.
    """
    # --- Data Setup ---
    B, G, M = 2, 16, 32
    device = torch.device('cuda')
    pts = torch.randn(B, G, M, 59).to(device)
    # Strictly positive scales and unit-norm quaternions.
    pts[..., 4:7] = pts[..., 4:7].abs() + 0.1
    pts[..., 7:11] /= (pts[..., 7:11].norm(dim=-1, keepdim=True) + 1e-8)
    center = pts.mean(dim=2)

    # --- Create Rotated Inputs ---
    R = get_random_so3()
    pts_rot = pts.clone()
    pts_rot[..., 0:3] = pts[..., 0:3] @ R.T
    pts_rot[..., 7:11] = rotate_quaternion_simple(pts[..., 7:11], R)

    # SH1 (Vector Rotation): three 3-vectors per point.
    sh1 = pts[..., 14:23].view(B, G, M, 3, 3)
    pts_rot[..., 14:23] = (sh1 @ R.T).view(B, G, M, 9)

    # SH2 (Tensor Rotation)
    def rotate_sh2_coeffs(coeffs, R):
        # Expand the 5 coefficients to the symmetric matrix used by the
        # lifting, conjugate it by R, then read the 5 coefficients back.
        c = [coeffs[..., i] for i in range(5)]
        mat = torch.stack([
            torch.stack([c[3]+c[4], c[0], c[2]], -1),
            torch.stack([c[0], -c[3]+c[4], c[1]], -1),
            torch.stack([c[2], c[1], -2*c[4]], -1)
        ], -2)
        R_bd = R.view(1, 1, 1, 1, 3, 3)
        mat_rot = R_bd @ mat @ R_bd.transpose(-1, -2)
        m00, m01, m02 = mat_rot[..., 0, 0], mat_rot[..., 0, 1], mat_rot[..., 0, 2]
        m11, m12 = mat_rot[..., 1, 1], mat_rot[..., 1, 2]
        return torch.stack([m01, m12, m02, (m00-m11)/2, (m00+m11)/2], -1)

    sh2 = pts[..., 23:38].view(B, G, M, 3, 5)
    pts_rot[..., 23:38] = rotate_sh2_coeffs(sh2, R).view(B, G, M, 15)

    center_rot = pts_rot.mean(dim=2)

    # --------------------------------------------------------------------------
    # 1. Baseline Test
    # --------------------------------------------------------------------------
    print("\n============================================================")
    print(">>> Testing [Baseline Model] (use_reln=False)")
    print("============================================================")

    class BaselineConfig:
        group_size = 32; num_group = 64
        attribute = ['xyz', 'opacity', 'scale', 'rotation'] + [f'f_rest_{i}' for i in range(45)]
        group_attribute = attribute; norm_attribute = []
        class TConfig:
            mask_ratio = 0.0; mask_type = "rand"; trans_dim = 128; depth = 2
            num_heads = 4; encoder_dims = 128; decoder_depth = 2
            decoder_num_heads = 4; drop_path_rate = 0.0; use_reln = False
        transformer_config = TConfig()

    model_base = Gaussian_MAE(BaselineConfig()).cuda().eval()

    with torch.no_grad():
        # Re-seed before each run so both forward passes draw the same
        # random numbers.
        set_seed(42)
        out_orig = run_original_manual(model_base, pts, center)
        set_seed(42)
        out_rot = run_original_manual(model_base, pts_rot, center_rot)

        out_aligned = out_orig @ R.T
        rel_err = (out_rot - out_aligned).abs().mean() / out_aligned.abs().mean()

        print(f"   - Relative Error: {rel_err:.6f}")
        print("   ❌ FAILURE (Expected)" if rel_err > 0.1 else "   ❓ PASS")

    # --------------------------------------------------------------------------
    # 2. ReLN Test
    # --------------------------------------------------------------------------
    print("\n============================================================")
    print(">>> Testing [ReLN Model] (Full-Spec Patch)")
    print("============================================================")

    class ReLNConfig(BaselineConfig):
        pass
    # Give ReLNConfig its own transformer config. The inherited attribute is
    # the single TConfig instance owned by BaselineConfig, so mutating it in
    # place would also switch ``use_reln`` on for the baseline configuration.
    ReLNConfig.transformer_config = copy.deepcopy(BaselineConfig.transformer_config)
    ReLNConfig.transformer_config.use_reln = True
    # Important: trans_dim 128 matches the ReLNEncoder internal dimension
    ReLNConfig.transformer_config.trans_dim = 128
    ReLNConfig.transformer_config.encoder_dims = 128

    model_reln = Gaussian_MAE(ReLNConfig()).cuda().eval()

    # [MONKEY PATCH]
    # 1. Replace the encoder's forward with the full-spec lifting version.
    model_reln.MAE_encoder.encoder.forward = types.MethodType(forward_full_spec_lifting, model_reln.MAE_encoder.encoder)

    # 2. Replace the input layer (12 lifted channels -> 128 dims).
    model_reln.MAE_encoder.encoder.geo_embed = LNLinear(12, 128).cuda()

    with torch.no_grad():
        set_seed(42)
        out_orig = run_reln_manual(model_reln, pts, center)
        set_seed(42)
        out_rot = run_reln_manual(model_reln, pts_rot, center_rot)

        out_aligned = out_orig @ R.T
        rel_err = (out_rot - out_aligned).abs().mean() / out_aligned.abs().mean()

        print(f"   - Relative Error: {rel_err:.8f}")
        print("   ✅ SUCCESS" if rel_err < 1e-5 else "   ❌ FAILURE")

# Script entry point: run the equivariance comparison only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    check_comparison()