#!/usr/bin/env python
"""
Loss functions for ARQ optimization
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('..')
from quant_utils import ActQuantizer


class UniformityRegularizer:
    """
    Regularizer to encourage uniform distribution across quantization bins

    Activations are max-abs normalized to [-1, 1], assigned to ``num_bins``
    equal-width bins, and the loss is KL(P || U) between the empirical bin
    histogram P and the uniform target U (it is 0 when every bin holds the
    same share of the values).

    NOTE(review): bin assignment uses ``torch.bincount`` on hard integer
    indices, so the returned loss carries no gradient w.r.t. the activations.
    As written it changes the reported loss value but cannot steer the
    optimizer; a soft (differentiable) histogram would be needed for that.
    """

    def __init__(self, num_bins: int, gamma_uni: float = 0.05,
                 curriculum: bool = False, warmup_steps: int = 50,
                 total_steps: int = 200, target_type: str = 'uniform'):
        self.num_bins = num_bins
        self.gamma_uni = gamma_uni
        self.curriculum = curriculum
        self.warmup_steps = warmup_steps
        # total_steps and target_type are stored but not consulted by this class
        # (only a uniform target is implemented).
        self.total_steps = total_steps
        self.target_type = target_type

        # Target distribution (uniform)
        self.target_dist = torch.ones(num_bins) / num_bins

    def _curriculum_weight(self, step) -> float:
        """Return the loss weight for ``step``: linear warmup from 0 to gamma_uni.

        Without curriculum, with ``step=None`` (progress unknown), with a
        non-positive warmup length, or once warmup is over, the full
        ``gamma_uni`` is used. Negative steps are treated as step 0.
        """
        if (not self.curriculum or step is None or self.warmup_steps <= 0
                or step >= self.warmup_steps):
            return self.gamma_uni
        return self.gamma_uni * (max(step, 0) / self.warmup_steps)

    def compute_loss(self, activations: torch.Tensor, quantized: torch.Tensor, step: int = 0) -> torch.Tensor:
        """
        Compute uniformity loss based on bin distribution

        Args:
            activations: Original activations (not used by the computation;
                kept for interface compatibility)
            quantized: Activations normalized to [-1, 1]. Values outside that
                range are clamped into the first/last bin.
            step: Current training step (not used by the computation)

        Returns:
            Scalar KL(P || U) between the empirical bin histogram and the
            uniform target; 0 for an empty input.
        """
        values = quantized.reshape(-1)
        if values.numel() == 0:
            # Nothing to histogram; avoid a 0/0 in the normalization below.
            return torch.zeros((), device=quantized.device)

        # Equal-width bins over [-1, 1]: bin k covers [-1 + 2k/B, -1 + 2(k+1)/B).
        # The right edge (+1) is folded into the last bin by the clamp.
        bins = torch.floor((values.float() + 1.0) * (self.num_bins / 2)).long()
        bins = torch.clamp(bins, 0, self.num_bins - 1)

        # Compute empirical distribution
        bin_counts = torch.bincount(bins, minlength=self.num_bins).float()
        empirical_dist = bin_counts / bins.numel()

        # Move target distribution to same device
        target = self.target_dist.to(empirical_dist.device)

        # Add a small epsilon to avoid log(0), then renormalize
        eps = 1e-10
        empirical_dist = empirical_dist + eps
        empirical_dist = empirical_dist / empirical_dist.sum()
        target = target + eps
        target = target / target.sum()

        # KL(P||Q) = sum(P * log(P/Q))
        return torch.sum(empirical_dist * torch.log(empirical_dist / target))

    def __call__(self, x_rot: torch.Tensor, step: int = 0) -> dict:
        """
        Main call interface for uniformity regularization

        Args:
            x_rot: Rotated activations
            step: Current training step for curriculum learning; ``None``
                means "unknown" and applies the full weight

        Returns:
            Dict with 'total' (weighted loss) and 'uniformity' (unweighted loss)
        """
        weight = self._curriculum_weight(step)

        # Max-abs normalize to [-1, 1]; compute_loss performs the binning
        # itself, so the normalized values are passed directly (quantizing
        # here as well would scale the values twice).
        if x_rot.numel() > 0:
            x_normalized = x_rot / (torch.abs(x_rot).max() + 1e-8)
        else:
            x_normalized = x_rot

        uniformity_loss = self.compute_loss(x_rot, x_normalized, step)

        return {
            'total': weight * uniformity_loss,
            'uniformity': uniformity_loss
        }


class MultiObjectiveLoss:
    """
    Combined loss for ARQ optimization

    Loss = λ_q * L_quant + λ_o * L_ortho + λ_e * L_entropy + λ_s * L_sparsity + γ_uni * L_uniformity

    The γ_uni weighting (including any curriculum warmup) is applied inside
    the uniformity regularizer, so its 'total' is added here unscaled.
    """

    def __init__(self,
                 lambda_quant: float = 1.0,
                 lambda_ortho: float = 0.1,
                 lambda_entropy: float = 0.01,
                 lambda_sparsity: float = 0.0,
                 bits: int = 4,
                 sym: bool = True,
                 sparsity_type: str = 'l1',
                 # Uniformity regularization parameters
                 gamma_uni: float = 0.0,
                 curriculum: bool = False,
                 warmup_steps: int = 50,
                 total_steps: int = 200):
        self.lambda_quant = lambda_quant
        self.lambda_ortho = lambda_ortho
        self.lambda_entropy = lambda_entropy
        self.lambda_sparsity = lambda_sparsity
        self.bits = bits
        self.sym = sym
        self.sparsity_type = sparsity_type

        # Setup quantizer for loss computation
        self.quantizer = ActQuantizer()
        self.quantizer.configure(bits=bits, sym=sym, clip_ratio=1.0)

        # Setup uniformity regularizer only if its weight is non-zero
        if gamma_uni > 0:
            self.uniformity_regularizer = UniformityRegularizer(
                num_bins=2**bits,
                gamma_uni=gamma_uni,
                curriculum=curriculum,
                warmup_steps=warmup_steps,
                total_steps=total_steps
            )
        else:
            self.uniformity_regularizer = None

    def quantization_loss(self, x_orig: torch.Tensor, x_rot: torch.Tensor) -> torch.Tensor:
        """
        Compute quantization reconstruction error

        Args:
            x_orig: Original activations before rotation
            x_rot: Rotated activations

        Returns:
            MSE between fake-quantized and unquantized rotated activations.
            For bits == 16 (treated as "no quantization"), 0.1 * MSE between
            the per-row L2 norms of x_rot and x_orig instead.
        """
        if self.bits == 16:
            # For 16-bit, use a proxy loss: minimize activation magnitude change
            # This encourages the rotation to preserve activation statistics
            orig_norm = torch.norm(x_orig, dim=-1)
            rot_norm = torch.norm(x_rot, dim=-1)
            # Convert to float32 for MSE loss computation (CPU doesn't support half MSE)
            return F.mse_loss(rot_norm.float(), orig_norm.float()) * 0.1

        # Quantize rotated activations
        # (presumably find_params calibrates scales on x_rot — TODO confirm in quant_utils)
        self.quantizer.find_params(x_rot)
        x_rot_quant = self.quantizer(x_rot)

        # The error is measured in rotated space; since the rotation is
        # (approximately) orthogonal this is a proxy for the error after
        # rotating back with Q^T.
        # Convert to float32 for MSE loss computation (CPU doesn't support half MSE)
        return F.mse_loss(x_rot_quant.float(), x_rot.float())

    def orthogonality_loss(self, rotation_matrix: torch.Tensor) -> torch.Tensor:
        """
        Soft orthogonality constraint: mean((Q^T Q - I)^2)

        This is ||Q^T Q - I||_F^2 divided by the number of entries of Q^T Q.
        Q may be non-square (n x m); its Gram matrix Q^T Q is then m x m and
        the loss is zero iff Q has orthonormal columns.
        """
        Q = rotation_matrix
        gram = Q.transpose(-2, -1) @ Q
        # Size the identity from the Gram matrix (last dim of Q), not Q.shape[0]
        I = torch.eye(gram.shape[-1], device=Q.device)

        return torch.mean((gram - I) ** 2)

    def entropy_loss(self, x_rot: torch.Tensor) -> torch.Tensor:
        """
        Entropy regularization to encourage uniform distribution

        Values are normalized per row to [-1, 1], softly assigned to 2**bits
        equal-width bins (softmax over negative distances, temperature 10), and
        the entropy of the averaged soft histogram is computed. The negative
        entropy is returned so that minimizing the loss spreads values out.

        Returns:
            -entropy (scalar). For bits >= 16 a zero scalar is returned: this
            mirrors quantization_loss's treatment of 16-bit as unquantized and
            avoids materializing a (samples x 2**bits) soft-assignment matrix.
        """
        if self.bits >= 16:
            return x_rot.new_zeros(())

        # Normalize to [-1, 1] range for histogram
        if self.sym:
            x_norm = x_rot / (x_rot.abs().max(dim=-1, keepdim=True)[0] + 1e-8)
        else:
            x_min = x_rot.min(dim=-1, keepdim=True)[0]
            x_max = x_rot.max(dim=-1, keepdim=True)[0]
            x_norm = 2 * (x_rot - x_min) / (x_max - x_min + 1e-8) - 1

        num_bins = 2 ** self.bits
        bin_edges = torch.linspace(-1, 1, num_bins + 1, device=x_rot.device)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        # For efficiency, use a random subset of at most max_samples values
        max_samples = 10000
        x_flat = x_norm.reshape(-1)
        if x_flat.numel() > max_samples:
            indices = torch.randperm(x_flat.numel(), device=x_flat.device)[:max_samples]
            x_flat = x_flat[indices]
        x_sample = x_flat.unsqueeze(1)

        # Soft assignment to bins from (negative) distances to bin centers
        distances = -torch.abs(x_sample - bin_centers.unsqueeze(0))
        soft_hist = F.softmax(distances * 10, dim=1).mean(dim=0)  # Temperature = 10

        entropy = -(soft_hist * torch.log(soft_hist + 1e-10)).sum()

        # We want to maximize entropy, so return negative
        return -entropy

    def sparsity_loss(self, transform) -> torch.Tensor:
        """
        Sparsity loss to encourage angles to stay close to zero (identity)

        Args:
            transform: Object that may expose ``layer_angles``, an iterable of
                angle tensors (any shapes; they are flattened and concatenated).

        Returns:
            Mean L1 / L2 / Huber penalty on all angles according to
            ``sparsity_type`` (unknown types fall back to L1); zero when
            lambda_sparsity is 0 or no angles are available.
        """
        if self.lambda_sparsity == 0:
            return torch.tensor(0.0)

        # Only butterfly-style transforms expose layer_angles
        layer_angles = getattr(transform, 'layer_angles', None)
        if layer_angles is None:
            return torch.tensor(0.0)

        flat_angles = [angles.reshape(-1) for angles in layer_angles]
        if not flat_angles:
            return torch.tensor(0.0)
        all_angles = torch.cat(flat_angles)

        if self.sparsity_type == 'l2':
            # L2 penalty: smoother penalty
            return (all_angles ** 2).mean()
        if self.sparsity_type == 'huber':
            # Huber loss: robust to outliers
            return F.smooth_l1_loss(all_angles, torch.zeros_like(all_angles))
        # 'l1' and any unrecognized type: L1 penalty encourages exact zeros
        return torch.abs(all_angles).mean()

    def __call__(self, x_orig: torch.Tensor, x_rot: torch.Tensor,
                 rotation_matrix: torch.Tensor, transform=None, step: int = None) -> dict:
        """
        Compute all loss components
        Returns dict with individual losses and total loss

        Args:
            x_orig: Original activations
            x_rot: Rotated activations
            rotation_matrix: The rotation matrix Q
            transform: Optional transform object for sparsity loss
            step: Current training step for curriculum learning (may be None)

        Returns:
            Dict with keys 'total', 'quantization', 'orthogonality', 'entropy',
            'sparsity', 'uniformity' ('uniformity' is the unweighted term).
        """
        l_quant = self.quantization_loss(x_orig, x_rot)
        l_ortho = self.orthogonality_loss(rotation_matrix)
        l_entropy = self.entropy_loss(x_rot)

        # Compute sparsity loss if transform is provided
        if transform is not None and self.lambda_sparsity > 0:
            l_sparsity = self.sparsity_loss(transform)
        else:
            l_sparsity = torch.tensor(0.0)

        # Compute uniformity regularization if enabled
        if self.uniformity_regularizer is not None:
            uni_losses = self.uniformity_regularizer(x_rot, step=step)
            l_uniformity_total = uni_losses['total']
            l_uniformity = uni_losses['uniformity']
        else:
            l_uniformity_total = torch.tensor(0.0)
            l_uniformity = torch.tensor(0.0)

        total_loss = (self.lambda_quant * l_quant +
                      self.lambda_ortho * l_ortho +
                      self.lambda_entropy * l_entropy +
                      self.lambda_sparsity * l_sparsity +
                      l_uniformity_total)  # Uniformity regularization already weighted

        return {
            'total': total_loss,
            'quantization': l_quant,
            'orthogonality': l_ortho,
            'entropy': l_entropy,
            'sparsity': l_sparsity,
            'uniformity': l_uniformity
        }


class HardOrthogonalityConstraint:
    """
    Hard orthogonality constraint using Givens rotations or SVD projection
    """

    @staticmethod
    def project_to_orthogonal_svd(matrix: torch.Tensor) -> torch.Tensor:
        """
        Project matrix to nearest orthogonal matrix using SVD

        For matrix = U S V^T, returns U V^T (the polar factor), which is the
        closest matrix with orthonormal columns/rows in Frobenius norm.
        """
        U, _, Vt = torch.linalg.svd(matrix, full_matrices=False)
        return U @ Vt

    @staticmethod
    def project_to_orthogonal_givens(angles: list[torch.Tensor], dim: int) -> torch.Tensor:
        """
        Construct orthogonal matrix from Givens rotation angles

        This ensures exact orthogonality by construction. The k-th angle is
        applied to the k-th index pair (i, j), i < j, in row-major order
        (0,1), (0,2), ..., (dim-2, dim-1). Q is post-multiplied by each
        rotation G with G[i,i] = G[j,j] = cos, G[i,j] = -sin, G[j,i] = sin.
        Angles beyond dim*(dim-1)/2 are ignored; missing ones leave the
        remaining pairs unrotated.

        Args:
            angles: Sequence of scalar (0-dim) angle tensors
            dim: Size of the square output matrix

        Returns:
            dim x dim orthogonal matrix (identity when ``angles`` is empty)
        """
        if len(angles) == 0:
            return torch.eye(dim)

        Q = torch.eye(dim, device=angles[0].device)
        pairs = ((i, j) for i in range(dim) for j in range(i + 1, dim))

        for theta, (i, j) in zip(angles, pairs):
            c = torch.cos(theta)
            s = torch.sin(theta)

            # Q @ G only changes columns i and j, so update those two columns
            # directly instead of a dense dim x dim matmul per rotation.
            col_i = Q[:, i]
            col_j = Q[:, j]
            new_i = c * col_i + s * col_j
            new_j = c * col_j - s * col_i

            # Write into a fresh copy so tensors saved for autograd (views of
            # the previous Q) are never modified in place.
            Q = Q.clone()
            Q[:, i] = new_i
            Q[:, j] = new_j

        return Q


# Test the losses
if __name__ == "__main__":
    # Smoke test: build an orthogonal rotation, evaluate every loss term, then
    # show how SVD projection restores orthogonality of a perturbed matrix.
    torch.manual_seed(0)  # reproducible demo output
    batch_size = 32
    dim = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create test data
    x_orig = torch.randn(batch_size, dim, device=device)
    Q = torch.eye(dim, device=device) + 0.1 * torch.randn(dim, dim, device=device)

    # Make Q orthogonal (nearest orthogonal matrix via SVD)
    Q = HardOrthogonalityConstraint.project_to_orthogonal_svd(Q)
    x_rot = x_orig @ Q

    # Test losses
    loss_fn = MultiObjectiveLoss(lambda_ortho=0.1, lambda_entropy=0.01)
    losses = loss_fn(x_orig, x_rot, Q)

    print("Loss components:")
    for name, value in losses.items():
        print(f"  {name}: {value.item():.6f}")

    # Test hard constraint: perturb ONCE so "before" and "after" describe the
    # same matrix (previously each line drew its own random perturbation).
    Q_noisy = Q + 0.1 * torch.randn_like(Q)
    print(f"\nOrthogonality before projection: {loss_fn.orthogonality_loss(Q_noisy).item():.6f}")
    Q_proj = HardOrthogonalityConstraint.project_to_orthogonal_svd(Q_noisy)
    print(f"Orthogonality after projection: {loss_fn.orthogonality_loss(Q_proj).item():.6f}")