#!/usr/bin/env python
"""
Fake Butterfly Transform implementation for ARQ
Uses butterfly structure with O(n log n) parameters but dense matrix multiplication for inference
No SVD needed - orthogonality is guaranteed by parameterization
"""

import torch
import torch.nn as nn
import numpy as np
import math


class FakeButterflyTransform(nn.Module):
    """
    Butterfly-structured transform built from layers of Givens rotations.

    - Parameters: ``log2(dim)`` layers of ``dim / 2`` rotation angles each,
      i.e. O(n log n) parameters.
    - Training mode: layers are applied to the input one after another.
    - Eval mode: all layers are composed into one dense matrix which is
      applied with a single matmul ("fake" butterfly, O(n^2) compute).
      The composition order matches the training path, so both modes compute
      the same linear map.
    - Orthogonality: every layer is a product of disjoint 2x2 rotations, so
      the full ``dim x dim`` matrix is orthogonal by construction (no SVD).

    Non-power-of-2 sizes are handled by zero-padding the input up to the next
    power of two and truncating the output. NOTE: the resulting
    ``original_dim x original_dim`` map is a sub-block of an orthogonal matrix
    and is therefore NOT orthogonal in general.

    Args:
        dim: Size of the last dimension of the inputs (must be >= 1).
        device: Device on which parameters and index tensors are created.
        init_mode: ``'identity'`` (all angles 0), ``'random'`` (small Gaussian
            angles) or ``'hadamard'`` (all angles pi/4, a heuristic).

    Raises:
        ValueError: If ``dim < 1`` or ``init_mode`` is unknown.
    """

    def __init__(self, dim: int, device='cuda', init_mode='identity'):
        super().__init__()
        # Fail early with a clear message instead of an obscure log2 error,
        # and reject bad modes even when there are no layers to initialise.
        if dim < 1:
            raise ValueError(f"dim must be a positive integer, got {dim}")
        if init_mode not in ('identity', 'random', 'hadamard'):
            raise ValueError(f"Unknown init_mode: {init_mode}")

        self.original_dim = dim
        self.device = device

        # Handle non-power-of-2 dimensions by padding
        if dim & (dim - 1) != 0:
            # Next power of 2, computed with integer arithmetic (no float log2)
            self.padded_dim = 1 << (int(dim) - 1).bit_length()
            print(f"Warning: Dimension {dim} is not power of 2, padding to {self.padded_dim}")
            self.dim = self.padded_dim
        else:
            self.dim = dim
            self.padded_dim = dim

        # self.dim is an exact power of two here, so this is log2(self.dim).
        self.n_layers = int(self.dim).bit_length() - 1

        # Each layer has dim/2 rotation angles for Givens rotations
        self.layer_angles = nn.ParameterList()

        # Pre-compute and cache butterfly indices for all layers
        self.butterfly_indices = [
            self._compute_butterfly_indices(layer_idx)
            for layer_idx in range(self.n_layers)
        ]

        half = self.dim // 2
        for _ in range(self.n_layers):
            if init_mode == 'identity':
                # Zero angles -> every rotation is the identity
                angles = torch.zeros(half, device=device)
            elif init_mode == 'random':
                # Small random rotations around the identity
                angles = torch.randn(half, device=device) * 0.1
            else:
                # 'hadamard': heuristic - true Hadamard angles depend on the layer
                angles = torch.ones(half, device=device) * (math.pi / 4)

            self.layer_angles.append(nn.Parameter(angles))

    def _compute_butterfly_indices(self, layer_idx: int) -> tuple:
        """
        Compute the butterfly connection pattern for a given layer.

        Layer ``k`` pairs index ``i`` with ``i + stride`` where
        ``stride = 2 ** (n_layers - k - 1)``.

        Returns:
            ``(idx1, idx2)``: two int64 tensors of length ``dim / 2`` holding
            the lower and upper index of every rotated pair.

        Raises:
            IndexError: If ``layer_idx`` is outside ``[0, n_layers)``.
        """
        if not 0 <= layer_idx < self.n_layers:
            raise IndexError(
                f"layer_idx {layer_idx} out of range for {self.n_layers} layers"
            )

        n = int(self.dim)
        stride = 1 << (self.n_layers - layer_idx - 1)

        # Blocks of size 2*stride; inside each block the first `stride`
        # entries are paired with the following `stride` entries.
        block_starts = torch.arange(0, n, 2 * stride, device=self.device)
        offsets = torch.arange(stride, device=self.device)
        idx1 = (block_starts.unsqueeze(1) + offsets).reshape(-1)
        idx2 = idx1 + stride

        return idx1, idx2

    def get_butterfly_indices(self, layer_idx: int) -> tuple:
        """
        Get cached butterfly indices for a given layer
        """
        return self.butterfly_indices[layer_idx]

    def _indices_on(self, layer_idx: int, device) -> tuple:
        """Return the cached index pair for a layer, moved to ``device`` if needed."""
        idx1, idx2 = self.butterfly_indices[layer_idx]
        if idx1.device != device:
            # The index cache is a plain list, so Module.to() does not move
            # it; migrate lazily and remember the result.
            idx1, idx2 = idx1.to(device), idx2.to(device)
            self.butterfly_indices[layer_idx] = (idx1, idx2)
        return idx1, idx2

    def _param_device(self):
        """Device of the angle parameters (falls back to the constructor device)."""
        if len(self.layer_angles) > 0:
            return self.layer_angles[0].device
        return torch.device(self.device)

    def apply_butterfly_layer(self, x: torch.Tensor, angles: torch.Tensor, layer_idx: int) -> torch.Tensor:
        """
        Apply one butterfly layer of Givens rotations along the last dimension.

        The input tensor is left untouched; a new tensor is returned.

        Args:
            x: Tensor whose last dimension has size ``self.dim``.
            angles: ``dim / 2`` rotation angles for this layer.
            layer_idx: Which layer's pairing pattern to use.

        Returns:
            The rotated tensor, same shape and dtype as ``x``.
        """
        idx1, idx2 = self._indices_on(layer_idx, x.device)

        # Apply Givens rotations - ensure correct dtype
        cos_angles = torch.cos(angles).to(x.dtype)
        sin_angles = torch.sin(angles).to(x.dtype)

        # Advanced indexing already yields copies, so no extra clone is needed
        x1 = x[..., idx1]
        x2 = x[..., idx2]

        # Write into a copy so the caller's tensor is never modified.
        # [x1'] = [cos  -sin] [x1]
        # [x2']   [sin   cos] [x2]
        out = x.clone()
        out[..., idx1] = cos_angles * x1 - sin_angles * x2
        out[..., idx2] = sin_angles * x1 + cos_angles * x2

        return out

    def construct_layer_matrix(self, angles: torch.Tensor, layer_idx: int,
                               dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Construct the dense matrix representation of one butterfly layer.

        The returned matrix ``M`` satisfies ``M @ v ==
        apply_butterfly_layer(v, angles, layer_idx)`` for a vector ``v``.

        Args:
            angles: ``dim / 2`` rotation angles for this layer.
            layer_idx: Which layer's pairing pattern to use.
            dtype: dtype of the returned matrix (default ``torch.float32``).

        Returns:
            A ``dim x dim`` matrix on the same device as ``angles``.
        """
        device = angles.device
        idx1, idx2 = self._indices_on(layer_idx, device)

        cos_angles = torch.cos(angles).to(dtype)
        sin_angles = torch.sin(angles).to(dtype)

        layer_matrix = torch.eye(int(self.dim), device=device, dtype=dtype)

        # Fill in all 2x2 rotation blocks at once (pairs are disjoint)
        layer_matrix[idx1, idx1] = cos_angles
        layer_matrix[idx1, idx2] = -sin_angles
        layer_matrix[idx2, idx1] = sin_angles
        layer_matrix[idx2, idx2] = cos_angles

        return layer_matrix

    def _compose_full_matrix(self) -> torch.Tensor:
        """Compose all layers into the full ``dim x dim`` float64 matrix."""
        Q = torch.eye(int(self.dim), device=self._param_device(), dtype=torch.float64)

        for layer_idx, angles in enumerate(self.layer_angles):
            # Build in float64 so cos/sin are not rounded to float32 first.
            layer_matrix = self.construct_layer_matrix(
                angles.to(torch.float64), layer_idx, dtype=torch.float64
            )
            # Training applies layer 0 first, so later layers must multiply
            # from the left: Q = L_{n-1} @ ... @ L_1 @ L_0.
            Q = layer_matrix @ Q

        return Q

    def get_matrix(self) -> torch.Tensor:
        """
        Get the full transformation matrix by composing all butterfly layers.

        The result ``Q`` (float64) satisfies ``forward(x) == x @ Q.T`` up to
        rounding, in both training and eval mode.

        Returns:
            An ``original_dim x original_dim`` matrix. If the dimension was
            padded this is the top-left block of the padded orthogonal matrix
            and is generally not orthogonal itself.
        """
        Q = self._compose_full_matrix()

        # If we padded, extract only the original dimensions
        if self.original_dim != self.dim:
            Q = Q[:self.original_dim, :self.original_dim]

        return Q

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the butterfly transform along the last dimension of ``x``.

        Training: apply layers sequentially (simulating O(n log n)).
        Inference: use one dense matrix multiplication.

        Args:
            x: Tensor of shape ``(..., original_dim)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``original_dim``.
        """
        if x.shape[-1] != self.original_dim:
            raise ValueError(
                f"Expected last dimension {self.original_dim}, got {x.shape[-1]}"
            )

        # Zero-pad up to the power-of-2 working dimension if necessary
        if self.original_dim != self.dim:
            pad_shape = list(x.shape)
            pad_shape[-1] = self.dim - self.original_dim
            x = torch.cat([x, x.new_zeros(pad_shape)], dim=-1)

        if self.training:
            # apply_butterfly_layer never modifies its input, so the caller's
            # tensor is safe without an up-front clone.
            for layer_idx, angles in enumerate(self.layer_angles):
                x = self.apply_butterfly_layer(x, angles, layer_idx)
        else:
            # Inference mode: use dense matrix (fake butterfly)
            Q_full = self._compose_full_matrix().to(x.dtype)
            x = x @ Q_full.t()

        # Remove padding if we added it
        if self.original_dim != self.dim:
            x = x[..., :self.original_dim]

        return x

    def count_parameters(self) -> int:
        """Count total number of parameters"""
        return sum(p.numel() for p in self.parameters())

    def check_orthogonality(self) -> float:
        """
        Return the Frobenius norm of ``Q^T Q - I`` for ``Q = get_matrix()``.

        Close to zero for power-of-2 dimensions; for padded dimensions the
        truncated matrix is not orthogonal and the value can be large.
        """
        Q = self.get_matrix()
        # Size the identity from Q itself: Q is original_dim-sized when padded.
        identity = torch.eye(Q.shape[0], device=Q.device, dtype=Q.dtype)
        return torch.norm(Q.t() @ Q - identity).item()


def test_fake_butterfly():
    """
    Smoke-test FakeButterflyTransform for every initialization mode.

    For each mode this checks (and asserts) the parameter count, the
    orthogonality of the composed matrix and norm preservation of the forward
    pass in both training and eval mode. The relative difference between the
    training-mode and eval-mode outputs is printed per mode and summarised at
    the end instead of being asserted.

    Raises:
        AssertionError: If parameter count, orthogonality or norm
            preservation is outside tolerance.
    """
    print("Testing Fake Butterfly Transform")
    print("="*60)

    dim = 256
    batch_size = 10
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tol = 1e-3  # loose enough for float32 forward passes
    n_params = 0
    inconsistent_modes = []

    # Test different initialization modes
    for init_mode in ['identity', 'random', 'hadamard']:
        print(f"\nTesting with {init_mode} initialization:")

        # Create transform
        transform = FakeButterflyTransform(dim, device, init_mode)

        # Check parameter count
        n_params = transform.count_parameters()
        expected_params = dim // 2 * int(np.log2(dim))
        print(f"  Parameters: {n_params} (expected: {expected_params})")
        print(f"  Compression ratio vs dense: {dim*dim / n_params:.1f}x")
        assert n_params == expected_params, (
            f"{init_mode}: parameter count {n_params} != {expected_params}"
        )

        # Test orthogonality
        ortho_error = transform.check_orthogonality()
        print(f"  Orthogonality error: {ortho_error:.8f}")
        assert ortho_error < tol, (
            f"{init_mode}: orthogonality error {ortho_error} exceeds {tol}"
        )

        # Test forward pass (no gradients needed for these checks)
        x = torch.randn(batch_size, dim, device=device)
        with torch.no_grad():
            # Training mode
            transform.train()
            y_train = transform(x)

            # Inference mode
            transform.eval()
            y_eval = transform(x)

        # Check norm preservation
        x_norm = torch.norm(x)
        norm_ratio_train = (torch.norm(y_train) / x_norm).item()
        norm_ratio_eval = (torch.norm(y_eval) / x_norm).item()
        print(f"  Norm preservation (train): {norm_ratio_train:.6f}")
        print(f"  Norm preservation (eval): {norm_ratio_eval:.6f}")
        assert abs(norm_ratio_train - 1.0) < tol, (
            f"{init_mode}: training-mode norm ratio {norm_ratio_train}"
        )
        assert abs(norm_ratio_eval - 1.0) < tol, (
            f"{init_mode}: eval-mode norm ratio {norm_ratio_eval}"
        )

        # Check consistency between train and eval
        consistency_error = (torch.norm(y_train - y_eval) / torch.norm(y_train)).item()
        print(f"  Train/eval consistency error: {consistency_error:.8f}")
        if consistency_error > tol:
            inconsistent_modes.append(init_mode)

    # Only claim success when the two forward paths actually agree.
    if inconsistent_modes:
        print(f"\n✗ Train/eval outputs disagree for init modes: {', '.join(inconsistent_modes)}")
    else:
        print("\n✓ Fake butterfly transform working correctly!")
    print("✓ Orthogonality maintained through parameterization (no SVD needed)!")
    print(f"✓ Parameter compression: {dim*dim / n_params:.1f}x vs dense matrix!")


# Run the self-test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_fake_butterfly()