"""
Fixed Composite Butterfly Transform for dimension 5120
Mimics QuaRot's Had40 ⊗ Had128 approach but with learnable butterfly
"""

import torch
import torch.nn as nn
import numpy as np
import math
from typing import Optional

class CompositeButterflyFixed(nn.Module):
    """
    Learnable orthogonal transform for dimension 5120 = 40 × 128.

    The full rotation is the Kronecker product ``Q_small ⊗ Q_large``:
    - Small factor: 40×40 (not power of 2), dense skew-symmetric parameters
      mapped to an orthogonal matrix by a Cayley transform.
    - Large factor: 128×128 (power of 2), product of 7 butterfly layers of
      Givens rotations (64 angles per layer).

    This avoids the 5×1024 factorization that failed before.

    ``forward(x)`` applies the same operator that ``get_matrix()`` builds,
    i.e. ``forward(x) == x @ get_matrix().T`` up to float32 rounding.

    Args:
        n: Transform dimension; must be 5120.
        device: Device on which the parameters are created.
        init_mode: ``'identity'`` (all parameters zero, so the transform
            starts as the identity), ``'hadamard'`` (large factor starts as
            a normalized Hadamard-type matrix, small factor starts from a
            small random perturbation), anything else gives random
            initialization of both factors.
    """

    def __init__(self, n: int = 5120, device='cuda', init_mode='identity'):
        super().__init__()
        assert n == 5120, "This implementation is specifically for 5120"

        self.n = n
        # Kept for backward compatibility; the matrix builders below use the
        # parameters' own device so the module keeps working after `.to(...)`.
        self.device = device
        self.init_mode = init_mode

        # Factorization: 5120 = 40 × 128
        self.n_small = 40
        self.n_large = 128

        # Small matrix (40×40): dense parameterization, made skew-symmetric
        # and pushed through a Cayley transform in get_small_matrix().
        if init_mode == 'identity':
            # Zero skew part -> Cayley transform yields exactly I.
            small_init = torch.zeros(self.n_small, self.n_small, device=device)
        else:
            # Small random perturbation around the identity.
            small_init = torch.randn(
                self.n_small, self.n_small, device=device
            ) * 0.1
        self.small_skew = nn.Parameter(small_init)

        # Large butterfly (128×128): log2(128) = 7 layers, 64 angles each.
        self.n_layers_large = self.n_large.bit_length() - 1
        angles_per_layer = self.n_large // 2

        if init_mode == 'identity':
            # All angles zero -> every Givens rotation is the identity.
            large_init = torch.zeros(
                self.n_layers_large, angles_per_layer, device=device
            )
        elif init_mode == 'hadamard':
            large_init = self._compute_hadamard_angles_128().to(device)
        else:
            # Random initialization
            large_init = torch.randn(
                self.n_layers_large, angles_per_layer, device=device
            ) * 0.5
        self.angles_large = nn.Parameter(large_init)

    def _compute_hadamard_angles_128(self):
        """Return butterfly angles whose product is a normalized Hadamard matrix.

        With every angle equal to pi/4 each Givens block is
        ``(1/sqrt(2)) * [[1, -1], [1, 1]]``. The layers act on distinct strides,
        so their product is the 7-fold Kronecker power of that block: an
        orthogonal 128×128 matrix whose entries are all ±1/sqrt(128).

        Returns:
            Tensor of shape (n_layers_large, 64) on the default device.
        """
        return torch.full(
            (self.n_layers_large, self.n_large // 2), math.pi / 4
        )

    def get_small_matrix(self) -> torch.Tensor:
        """Get 40×40 orthogonal matrix via Cayley transform.

        Returns:
            float32 tensor ``(I + A)^{-1} (I - A)`` where
            ``A = small_skew - small_skew.T`` is skew-symmetric.
        """
        skew = self.small_skew.to(torch.float32)
        # Only the antisymmetric part of the parameter is used.
        A = skew - skew.T

        # Cayley transform; (I + A)^{-1} and (I - A) commute, so solving
        # (I + A) Q = (I - A) gives the same Q as (I - A)(I + A)^{-1}.
        I = torch.eye(self.n_small, device=A.device, dtype=torch.float32)
        return torch.linalg.solve(I + A, I - A)

    def get_large_matrix(self) -> torch.Tensor:
        """Get 128×128 butterfly matrix.

        Layer ``l`` pairs index ``p`` with ``p + stride`` where
        ``stride = 2 ** (n_layers_large - l - 1)``, and rotates each pair by
        its own angle. The returned matrix is the product of the layer
        matrices in layer order (layer 0 leftmost).

        Returns:
            float32 tensor of shape (128, 128).
        """
        n = self.n_large
        angles = self.angles_large.to(torch.float32)
        device = angles.device

        Q = torch.eye(n, device=device, dtype=torch.float32)
        positions = torch.arange(n, device=device)

        for layer_idx in range(self.n_layers_large):
            stride = 2 ** (self.n_layers_large - layer_idx - 1)

            # View indices as (block, half, offset): half 0 holds the lower
            # index of each pair, half 1 the partner `stride` above it.
            # Flattening in (block, offset) order gives pair number
            # block * stride + offset, which is the angle index for the pair.
            pairs = positions.view(-1, 2, stride)
            lo = pairs[:, 0, :].reshape(-1)
            hi = pairs[:, 1, :].reshape(-1)

            c = torch.cos(angles[layer_idx])
            s = torch.sin(angles[layer_idx])

            # Every index belongs to exactly one pair, so the whole diagonal
            # is overwritten and the layer can start from zeros.
            layer_matrix = torch.zeros(n, n, device=device, dtype=torch.float32)
            layer_matrix[lo, lo] = c
            layer_matrix[lo, hi] = -s
            layer_matrix[hi, lo] = s
            layer_matrix[hi, hi] = c

            Q = Q @ layer_matrix

        return Q

    def get_matrix(self) -> torch.Tensor:
        """Construct the full 5120×5120 rotation matrix via Kronecker product.

        Returns:
            float32 tensor ``Q_small ⊗ Q_large`` of shape (5120, 5120).
        """
        Q_small = self.get_small_matrix().contiguous()
        Q_large = self.get_large_matrix().contiguous()

        # Kronecker product: Q = Q_small ⊗ Q_large
        Q = torch.kron(Q_small, Q_large)

        return Q.to(torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply composite butterfly transform efficiently.

        Computes ``x @ get_matrix().T`` without materializing the 5120×5120
        matrix, by applying the two Kronecker factors separately.

        Args:
            x: Tensor whose last dimension is 5120; any leading shape
                (including none, or an empty batch) is accepted.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` is not 5120.
        """
        if x.dim() == 0 or x.shape[-1] != self.n:
            raise ValueError(
                f"Expected last dimension {self.n}, got shape {tuple(x.shape)}"
            )

        batch_shape = x.shape[:-1]
        x_dtype = x.dtype

        # Compute in float32; view the feature axis as 40 blocks of 128.
        blocks = x.to(torch.float32).reshape(
            *batch_shape, self.n_small, self.n_large
        )

        Q_large = self.get_large_matrix().to(x.device).to(torch.float32)
        Q_small = self.get_small_matrix().to(x.device).to(torch.float32)

        # Large factor acts inside each block:
        #   mixed[..., i, l] = sum_k Q_large[l, k] * x[..., i, k]
        mixed = blocks @ Q_large.T

        # Small factor mixes across blocks:
        #   out[..., j, l] = sum_i Q_small[j, i] * mixed[..., i, l]
        # Together this is (Q_small ⊗ Q_large) applied to each input vector.
        out = Q_small @ mixed

        # Reshape back and convert to original dtype
        y = out.reshape(*batch_shape, self.n)
        return y.to(x_dtype)

    def count_parameters(self) -> int:
        """Count effective learnable parameters.

        The 40×40 ``small_skew`` tensor is only used through its
        antisymmetric part, so it contributes 40*39/2 = 780 degrees of
        freedom rather than its 1600 stored entries. The butterfly
        contributes 7 layers × 64 angles = 448.
        """
        small_dof = self.n_small * (self.n_small - 1) // 2
        large_dof = self.n_layers_large * (self.n_large // 2)
        return small_dof + large_dof


if __name__ == "__main__":
    # Smoke test: build the identity-initialized transform, report its size,
    # run one forward pass and measure how far the full matrix is from
    # orthogonal.
    print("Testing Fixed Composite Butterfly Transform (40×128)")
    print("=" * 50)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CompositeButterflyFixed(5120, device=device, init_mode='identity')

    # Parameter summary.
    n_params = model.count_parameters()
    print(f"Parameters: {n_params}")
    for summary_line in (
        "  Small matrix (40×40): 780 params",
        "  Large butterfly (128×128): 448 params",
        "  Total: 1228 params (vs 5120×5120 = 26,214,400)",
    ):
        print(summary_line)
    dense_entries = 5120 * 5120
    print(f"  Compression: {dense_entries / n_params:.1f}x")

    # Forward pass on a small random batch.
    sample = torch.randn(2, 5120, device=device)
    result = model(sample)
    print(f"\nInput shape: {sample.shape}")
    print(f"Output shape: {result.shape}")

    # Orthogonality: norm of the Gram matrix's deviation from the identity.
    full_matrix = model.get_matrix()
    gram = full_matrix @ full_matrix.T
    deviation = gram - torch.eye(5120, device=device)
    ortho_error = torch.norm(deviation).item()
    print(f"\nOrthogonality check: ||Q^T Q - I|| = {ortho_error:.6f}")

    # Compare with QuaRot's approach
    print("\nQuaRot uses: Had_40 ⊗ Had_128 (composite Hadamard)")
    print("Our approach: Butterfly_40 ⊗ Butterfly_128 (learnable)")