import torch
from collections import defaultdict
import numpy as np

def get_row_norm(W):
    """Return the mean Euclidean norm of the rows of ``W`` as a Python float.

    ``W`` is upcast to float32 first, because squaring half-precision entries
    can overflow before the square root is taken.
    """
    W32 = W.float()
    row_norms = torch.norm(W32, dim=1)
    avg = row_norms.mean()
    return avg.item()

def calculate_alignment_metrics(model, batch, random_baseline=True):
    """
    Calculate normalized ||Wz||/||W||_o metrics for all weight matrices.
    For the random baseline, uses random unit vectors instead of the actual normalized input.

    Which modules are hooked: every module with a non-None ``weight``, except
    - ``nn.Embedding`` modules,
    - ``nn.LayerNorm`` modules,
    - modules whose name contains "norm" (case-insensitive),
    - modules whose name contains "lm_head".

    A single forward pass then runs under ``torch.no_grad()``.

    Computation per hooked module:
    - The input activation ``z`` is flattened to ``(num_tokens, hidden_dim)``,
      where ``hidden_dim`` is the last input dimension.
    - Each token row is scaled to unit norm.
    - ``W`` is divided by its RMS entry value.
    - The recorded value is ``mean_t ||W_normalized z_t|| / hidden_dim``.

    Args:
        model: Module invoked as ``model(**batch)``. ``batch`` is moved to the
            device of the model's first parameter when the devices differ.
        batch: Mapping of tensors; must contain ``'input_ids'``.
        random_baseline: When True, also record the same quantity computed on
            random unit vectors with the shape of the flattened input.

    Returns:
        ``defaultdict(dict)`` mapping module name to ``{'actual': float}``,
        plus ``'random': float`` when ``random_baseline`` is True.

    Raises:
        RuntimeError: re-raised from a hook (after printing the shapes) when
            the matrix product fails, e.g. for a weight that is not 2-D.
        Any exception from the forward pass propagates. Hooks are removed
        in every case.
    """
    print("Setting up metrics calculation...")
    # Add device check
    print(f"Input device: {batch['input_ids'].device}")

    # Move batch to the model's device if needed
    model_device = next(model.parameters()).device
    if model_device != batch['input_ids'].device:
        batch = {k: v.to(model_device) for k, v in batch.items()}

    metrics = defaultdict(dict)

    def hook_fn(name):
        def hook(module, input, output):
            print(f"Processing layer: {name}")
            if hasattr(module, 'weight') and module.weight is not None:
                z = input[0] if isinstance(input, tuple) else input  # input activation
                W = module.weight

                # Print shapes for debugging
                print(f"Input shape: {z.shape}, Weight shape: {W.shape}")

                # Convert to float32 for better numerical stability
                z = z.float()
                W = W.float()

                # Flatten all leading dims to (num_tokens, hidden_dim). Taking
                # hidden_dim from the last axis keeps it defined for inputs
                # that are already 2-D; it is needed below to scale the metric.
                hidden_dim = z.shape[-1]
                z = z.reshape(-1, hidden_dim)

                try:
                    # Scale W by its RMS entry magnitude and each token to unit norm
                    W_norm = (W**2).mean().sqrt()
                    z_norm = torch.norm(z, dim=1, keepdim=True)

                    W_normalized = W / (W_norm + 1e-8)
                    z_normalized = z / (z_norm + 1e-8)

                    # (tokens, in) @ (in, out) -> (tokens, out): W·z for every token
                    Wz = torch.mm(z_normalized, W_normalized.t())
                    metrics[name]['actual'] = torch.norm(Wz, dim=1).mean().item() / hidden_dim

                    # Baseline: same normalized W applied to random unit vectors
                    if random_baseline:
                        z_random = torch.randn_like(z_normalized)
                        z_random_norm = torch.norm(z_random, dim=1, keepdim=True)
                        z_random_normalized = z_random / (z_random_norm + 1e-8)

                        Wz_random = torch.mm(z_random_normalized, W_normalized.t())
                        metrics[name]['random'] = torch.norm(Wz_random, dim=1).mean().item() / hidden_dim

                except RuntimeError:
                    print(f"Error in layer {name}:")
                    print(f"Input shape: {z.shape}")
                    print(f"Weight shape: {W.shape}")
                    raise

        return hook

    # Register hooks
    print("Registering hooks...")
    hooks = []
    weight_count = 0

    for name, module in model.named_modules():
        embedding_filter = isinstance(module, torch.nn.Embedding)
        ln_filter = isinstance(module, torch.nn.LayerNorm) or 'norm' in name.lower()
        lm_head_filter = 'lm_head' in name
        if hasattr(module, 'weight') and (module.weight is not None) and (not embedding_filter) and (not ln_filter) and (not lm_head_filter):
            weight_count += 1
            hooks.append(module.register_forward_hook(hook_fn(name)))
            print(f"Registered hook for weight matrix in: {name}")

    print(f"Registered hooks for {weight_count} weight matrices")

    # Forward pass; hooks are removed even if it raises so the model is not
    # left with stale hooks that keep firing on later forward calls.
    print("Performing forward pass...")
    try:
        with torch.no_grad():
            model(**batch)
        print("Forward pass completed")
    finally:
        print("Cleaning up hooks...")
        for handle in hooks:
            handle.remove()

    return metrics

def calculate_alignment_metrics_seq(model, batch, random_baseline=True):
    """
    Calculate sequence-level alignment metrics with RMS-normalized weights.

    Which modules are hooked: every module with a non-None ``weight``, except
    - ``nn.Embedding`` modules,
    - ``nn.LayerNorm`` modules,
    - modules whose name contains "norm" (case-insensitive).

    A single forward pass then runs under ``torch.no_grad()``.

    Computation per hooked module whose input is 3-D
    ``(batch, seq_len, hidden_dim)`` (other ranks are skipped with a message):
    - ``W_n = W / (rms(W) + 1e-8)``.
    - ``ẑ_t = z_t / (||z_t|| + 1e-8)`` for each token.
    - For each sequence:
      ``mean_t ||W_n ẑ_t|| / (mean_t ||z_t|| + 1e-8) / hidden_dim``.
    - The stored value is the mean over the batch.

    The random baseline uses the same formula with random unit vectors
    replacing ``ẑ_t``; the denominator still uses the real token norms.

    NOTE(review): the numerator uses unit-normalized tokens while the
    denominator uses raw token norms. Confirm this mix is the intended metric.

    Args:
        model: Module invoked as ``model(**batch)``. ``batch`` is moved to the
            device of the model's first parameter when the devices differ.
        batch: Mapping of tensors; must contain ``'input_ids'``.
        random_baseline: When True, also record the random-vector baseline.

    Returns:
        ``defaultdict(dict)`` mapping module name to ``{'actual': float}``,
        plus ``'random': float`` when requested.

    Raises:
        RuntimeError: re-raised from a hook (after printing shapes) when the
            matrix product fails.
        Any exception from the forward pass propagates. Hooks are removed
        in every case.
    """
    print("Setting up sequence-level metrics calculation...")
    # Add device check
    print(f"Input device: {batch['input_ids'].device}")

    # Move batch to the model's device if needed
    model_device = next(model.parameters()).device
    if model_device != batch['input_ids'].device:
        batch = {k: v.to(model_device) for k, v in batch.items()}

    metrics = defaultdict(dict)

    def hook_fn(name):
        def hook(module, input, output):
            print(f"Processing layer: {name}")
            if hasattr(module, 'weight') and module.weight is not None:
                z = input[0] if isinstance(input, tuple) else input  # input activation
                W = module.weight

                # Print shapes for debugging
                print(f"Input shape: {z.shape}, Weight shape: {W.shape}")

                # Convert to float32 for better numerical stability
                z = z.float()
                W = W.float()

                # Per-sequence metrics need the batch and sequence axes intact
                if len(z.shape) == 3:  # [batch, seq_len, hidden_dim]
                    batch_size, seq_len, hidden_dim = z.shape

                    try:
                        # Only W is normalized globally (by its RMS entry value)
                        W_norm = (W**2).mean().sqrt()
                        W_normalized = W / (W_norm + 1e-8)

                        batch_metrics = []
                        batch_random_metrics = []

                        for b in range(batch_size):
                            z_seq = z[b]  # [seq_len, hidden_dim]

                            # Per-token norms and their mean over the sequence
                            z_norms = torch.norm(z_seq, dim=1)  # [seq_len]
                            z_norm_mean = z_norms.mean() + 1e-8

                            # Unit-normalize each token, reusing the norms above
                            z_seq_normalized = z_seq / (z_norms.unsqueeze(1) + 1e-8)

                            # ||W_normalized · ẑ_t|| for each token, averaged over the sequence
                            Wz_seq = torch.mm(z_seq_normalized, W_normalized.t())  # [seq_len, output_dim]
                            Wz_norm_avg = torch.norm(Wz_seq, dim=1).mean()

                            seq_metric = Wz_norm_avg / z_norm_mean / hidden_dim
                            batch_metrics.append(seq_metric.item())

                            # Baseline: random unit vectors, same W and same denominator
                            if random_baseline:
                                z_random = torch.randn_like(z_seq)
                                z_random_normalized = z_random / (torch.norm(z_random, dim=1, keepdim=True) + 1e-8)

                                Wz_random_seq = torch.mm(z_random_normalized, W_normalized.t())
                                Wz_random_norm_avg = torch.norm(Wz_random_seq, dim=1).mean()

                                seq_random_metric = Wz_random_norm_avg / z_norm_mean / hidden_dim
                                batch_random_metrics.append(seq_random_metric.item())

                        # Average across the batch
                        metrics[name]['actual'] = np.mean(batch_metrics)
                        if random_baseline:
                            metrics[name]['random'] = np.mean(batch_random_metrics)

                    except RuntimeError:
                        print(f"Error in layer {name}:")
                        print(f"Input shape: {z.shape}")
                        print(f"Weight shape: {W.shape}")
                        raise
                else:
                    print(f"Skipping layer {name}: Input shape {z.shape} is not 3D [batch, seq, hidden]")

        return hook

    # Register hooks
    print("Registering hooks...")
    hooks = []
    weight_count = 0

    for name, module in model.named_modules():
        embedding_filter = isinstance(module, torch.nn.Embedding)
        ln_filter = isinstance(module, torch.nn.LayerNorm) or 'norm' in name.lower()
        if hasattr(module, 'weight') and (module.weight is not None) and (not embedding_filter) and (not ln_filter):
            weight_count += 1
            hooks.append(module.register_forward_hook(hook_fn(name)))
            print(f"Registered hook for weight matrix in: {name}")

    print(f"Registered hooks for {weight_count} weight matrices")

    # Forward pass; hooks are removed even if it raises
    print("Performing forward pass...")
    try:
        with torch.no_grad():
            model(**batch)
        print("Forward pass completed")
    finally:
        print("Cleaning up hooks...")
        for handle in hooks:
            handle.remove()

    return metrics

def calculate_alignment_metrics_seq_spectral(model, batch, random_baseline=True):
    """
    Calculate sequence-level alignment metrics with spectrally-normalized weights.

    This computes the same quantity as the RMS-normalized sequence metric,
    except that ``W`` is divided by its spectral norm (largest singular value).

    How the spectral norm is obtained:
    - Exact SVD when ``max(W.shape) <= 500``.
    - Otherwise 10 steps of power iteration from a random start. That is an
      estimate and may slightly underestimate the true value.

    Which modules are hooked: every module with a non-None ``weight``, except
    - ``nn.Embedding`` modules,
    - ``nn.LayerNorm`` modules,
    - modules whose name contains "norm" (case-insensitive).

    Computation per hooked module whose input is 3-D
    ``(batch, seq_len, hidden_dim)`` (other ranks are skipped with a message):
    - ``W_s = W / (sigma_max(W) + 1e-8)``.
    - ``ẑ_t = z_t / (||z_t|| + 1e-8)`` for each token.
    - For each sequence:
      ``mean_t ||W_s ẑ_t|| / (mean_t ||z_t|| + 1e-8) / hidden_dim``.
    - The stored value is the mean over the batch.

    The random baseline uses the same formula with random unit vectors
    replacing ``ẑ_t``; the denominator still uses the real token norms.

    NOTE(review): the numerator uses unit-normalized tokens while the
    denominator uses raw token norms. Confirm this mix is the intended metric.

    Args:
        model: Module invoked as ``model(**batch)``. ``batch`` is moved to the
            device of the model's first parameter when the devices differ.
        batch: Mapping of tensors; must contain ``'input_ids'``.
        random_baseline: When True, also record the random-vector baseline.

    Returns:
        ``defaultdict(dict)`` mapping module name to ``{'actual': float}``,
        plus ``'random': float`` when requested.

    Raises:
        RuntimeError: re-raised from a hook (after printing shapes) when the
            SVD or a matrix product fails.
        Any exception from the forward pass propagates. Hooks are removed
        in every case.
    """
    print("Setting up sequence-level metrics calculation with spectral norm...")
    # Add device check
    print(f"Input device: {batch['input_ids'].device}")

    # Move batch to the model's device if needed
    model_device = next(model.parameters()).device
    if model_device != batch['input_ids'].device:
        batch = {k: v.to(model_device) for k, v in batch.items()}

    metrics = defaultdict(dict)

    def hook_fn(name):
        def hook(module, input, output):
            print(f"Processing layer: {name}")
            if hasattr(module, 'weight') and module.weight is not None:
                z = input[0] if isinstance(input, tuple) else input  # input activation
                W = module.weight

                # Print shapes for debugging
                print(f"Input shape: {z.shape}, Weight shape: {W.shape}")

                # Convert to float32 for better numerical stability
                z = z.float()
                W = W.float()

                # Per-sequence metrics need the batch and sequence axes intact
                if len(z.shape) == 3:  # [batch, seq_len, hidden_dim]
                    batch_size, seq_len, hidden_dim = z.shape

                    try:
                        # Spectral norm (largest singular value) of W
                        with torch.no_grad():
                            if max(W.shape) > 500:  # power iteration for large matrices
                                u = torch.randn(W.shape[0], 1, device=W.device)
                                # v's initial value is overwritten on the first
                                # iteration; the draw is kept so seeded runs
                                # consume the RNG exactly as before.
                                v = torch.randn(W.shape[1], 1, device=W.device)

                                for _ in range(10):
                                    v = torch.matmul(W.t(), u)
                                    v = v / (torch.norm(v) + 1e-8)
                                    u = torch.matmul(W, v)
                                    u = u / (torch.norm(u) + 1e-8)

                                W_spectral_norm = torch.matmul(torch.matmul(u.t(), W), v).item()
                            else:
                                # Small matrices: exact SVD
                                _, s, _ = torch.linalg.svd(W, full_matrices=False)
                                W_spectral_norm = s[0].item()  # largest singular value

                        W_normalized = W / (W_spectral_norm + 1e-8)

                        batch_metrics = []
                        batch_random_metrics = []

                        for b in range(batch_size):
                            z_seq = z[b]  # [seq_len, hidden_dim]

                            # Per-token norms and their mean over the sequence
                            z_norms = torch.norm(z_seq, dim=1)  # [seq_len]
                            z_norm_mean = z_norms.mean() + 1e-8

                            # Unit-normalize each token, reusing the norms above
                            z_seq_normalized = z_seq / (z_norms.unsqueeze(1) + 1e-8)

                            # ||W_normalized · ẑ_t|| for each token, averaged over the sequence
                            Wz_seq = torch.mm(z_seq_normalized, W_normalized.t())  # [seq_len, output_dim]
                            Wz_norm_avg = torch.norm(Wz_seq, dim=1).mean()

                            seq_metric = Wz_norm_avg / z_norm_mean / hidden_dim
                            batch_metrics.append(seq_metric.item())

                            # Baseline: random unit vectors, same W and same denominator
                            if random_baseline:
                                z_random = torch.randn_like(z_seq)
                                z_random_normalized = z_random / (torch.norm(z_random, dim=1, keepdim=True) + 1e-8)

                                Wz_random_seq = torch.mm(z_random_normalized, W_normalized.t())
                                Wz_random_norm_avg = torch.norm(Wz_random_seq, dim=1).mean()

                                seq_random_metric = Wz_random_norm_avg / z_norm_mean / hidden_dim
                                batch_random_metrics.append(seq_random_metric.item())

                        # Average across the batch
                        metrics[name]['actual'] = np.mean(batch_metrics)
                        if random_baseline:
                            metrics[name]['random'] = np.mean(batch_random_metrics)

                    except RuntimeError:
                        print(f"Error in layer {name}:")
                        print(f"Input shape: {z.shape}")
                        print(f"Weight shape: {W.shape}")
                        raise
                else:
                    print(f"Skipping layer {name}: Input shape {z.shape} is not 3D [batch, seq, hidden]")

        return hook

    # Register hooks
    print("Registering hooks...")
    hooks = []
    weight_count = 0

    for name, module in model.named_modules():
        embedding_filter = isinstance(module, torch.nn.Embedding)
        ln_filter = isinstance(module, torch.nn.LayerNorm) or 'norm' in name.lower()
        if hasattr(module, 'weight') and (module.weight is not None) and (not embedding_filter) and (not ln_filter):
            weight_count += 1
            hooks.append(module.register_forward_hook(hook_fn(name)))
            print(f"Registered hook for weight matrix in: {name}")

    print(f"Registered hooks for {weight_count} weight matrices")

    # Forward pass; hooks are removed even if it raises
    print("Performing forward pass...")
    try:
        with torch.no_grad():
            model(**batch)
        print("Forward pass completed")
    finally:
        print("Cleaning up hooks...")
        for handle in hooks:
            handle.remove()

    return metrics

def get_group_metrics(metrics, groups=('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj')):
    """
    Average per-layer alignment metrics over groups of weight matrices.

    A layer belongs to a group when the group string is a substring of the
    layer name, e.g. ``'model.layers.3.self_attn.q_proj'`` -> ``'q_proj'``.
    A layer whose name matches several groups contributes to each of them.

    Args:
        metrics: Mapping of layer name -> dict with optional ``'actual'`` and
            ``'random'`` floats. Missing keys count as 0.0.
        groups: Iterable of group substrings. Duplicate entries are ignored,
            so a layer is never counted twice for the same group.

    Returns:
        dict mapping each group (in first-listed order) to
        ``{'actual': float, 'random': float}``. Groups with no matching layer
        get 0.0 for both.
    """
    # dict.fromkeys drops duplicates while preserving first-seen order
    unique_groups = list(dict.fromkeys(groups))

    totals = {
        group: {'count': 0, 'actual_sum': 0.0, 'random_sum': 0.0}
        for group in unique_groups
    }

    # Accumulate metrics for each group
    for name, values in metrics.items():
        for group in unique_groups:
            if group in name:
                acc = totals[group]
                acc['count'] += 1
                acc['actual_sum'] += values.get('actual', 0.0)
                acc['random_sum'] += values.get('random', 0.0)

    # Convert sums to averages
    results = {}
    for group, acc in totals.items():
        count = acc['count']
        if count > 0:
            results[group] = {
                'actual': acc['actual_sum'] / count,
                'random': acc['random_sum'] / count,
            }
        else:
            results[group] = {'actual': 0.0, 'random': 0.0}

    return results

def calculate_alignment_metrics_spectral(model, batch, random_baseline=True):
    """
    Calculate normalized ||Wz||/||W||_spectral metrics for all weight matrices,
    where ||W||_spectral is the spectral norm (largest singular value) of W.
    For the random baseline, uses random unit vectors instead of the actual normalized input.

    Which modules are hooked: every module with a non-None ``weight``, except
    - ``nn.Embedding`` modules,
    - ``nn.LayerNorm`` modules,
    - modules whose name contains "norm" (case-insensitive).

    A single forward pass then runs under ``torch.no_grad()``.

    Computation per hooked module:
    - The input is flattened to ``(num_tokens, hidden_dim)``, where
      ``hidden_dim`` is the last input dimension.
    - Each token is scaled to unit norm.
    - ``W`` is divided by its spectral norm, obtained by exact SVD when
      ``max(W.shape) <= 500`` and otherwise estimated with 10 steps of power
      iteration from a random start.
    - The recorded value is ``mean_t ||W_normalized z_t|| / hidden_dim``.

    Args:
        model: Module invoked as ``model(**batch)``. ``batch`` is moved to the
            device of the model's first parameter when the devices differ.
        batch: Mapping of tensors; must contain ``'input_ids'``.
        random_baseline: When True, also record the same quantity for random
            unit vectors with the shape of the flattened input.

    Returns:
        ``defaultdict(dict)`` mapping module name to ``{'actual': float}``,
        plus ``'random': float`` when requested.

    Raises:
        RuntimeError: re-raised from a hook (after printing shapes) when the
            SVD or a matrix product fails.
        Any exception from the forward pass propagates. Hooks are removed
        in every case.
    """
    print("Setting up metrics calculation with spectral norm...")

    # Add device check
    print(f"Input device: {batch['input_ids'].device}")

    # Move batch to the model's device if needed
    model_device = next(model.parameters()).device
    if model_device != batch['input_ids'].device:
        batch = {k: v.to(model_device) for k, v in batch.items()}

    metrics = defaultdict(dict)

    def hook_fn(name):
        def hook(module, input, output):
            print(f"Processing layer: {name}")
            if hasattr(module, 'weight') and module.weight is not None:
                z = input[0] if isinstance(input, tuple) else input  # input activation
                W = module.weight

                # Print shapes for debugging
                print(f"Input shape: {z.shape}, Weight shape: {W.shape}")

                # Convert to float32 for better numerical stability
                z = z.float()
                W = W.float()

                # Flatten all leading dims to (num_tokens, hidden_dim). Taking
                # hidden_dim from the last axis keeps it defined for inputs
                # that are already 2-D; it is needed below to scale the metric.
                hidden_dim = z.shape[-1]
                z = z.reshape(-1, hidden_dim)

                try:
                    # Spectral norm (largest singular value) of W
                    with torch.no_grad():
                        if max(W.shape) > 500:  # power iteration for large matrices
                            u = torch.randn(W.shape[0], 1, device=W.device)
                            # v's initial value is overwritten on the first
                            # iteration; the draw is kept so seeded runs
                            # consume the RNG exactly as before.
                            v = torch.randn(W.shape[1], 1, device=W.device)

                            for _ in range(10):
                                v = torch.matmul(W.t(), u)
                                v = v / (torch.norm(v) + 1e-8)
                                u = torch.matmul(W, v)
                                u = u / (torch.norm(u) + 1e-8)

                            W_spectral_norm = torch.matmul(torch.matmul(u.t(), W), v).item()
                        else:
                            # Small matrices: exact SVD
                            _, s, _ = torch.linalg.svd(W, full_matrices=False)
                            W_spectral_norm = s[0].item()  # largest singular value

                    # Scale each token to unit norm; scale W by its spectral norm
                    z_norm = torch.norm(z, dim=1, keepdim=True)

                    W_normalized = W / (W_spectral_norm + 1e-8)
                    z_normalized = z / (z_norm + 1e-8)

                    # (tokens, in) @ (in, out) -> (tokens, out): W·z for every token
                    Wz = torch.mm(z_normalized, W_normalized.t())
                    metrics[name]['actual'] = torch.norm(Wz, dim=1).mean().item() / hidden_dim

                    # Baseline: same normalized W applied to random unit vectors
                    if random_baseline:
                        z_random = torch.randn_like(z_normalized)
                        z_random_norm = torch.norm(z_random, dim=1, keepdim=True)
                        z_random_normalized = z_random / (z_random_norm + 1e-8)

                        Wz_random = torch.mm(z_random_normalized, W_normalized.t())
                        metrics[name]['random'] = torch.norm(Wz_random, dim=1).mean().item() / hidden_dim

                except RuntimeError:
                    print(f"Error in layer {name}:")
                    print(f"Input shape: {z.shape}")
                    print(f"Weight shape: {W.shape}")
                    raise

        return hook

    # Register hooks
    print("Registering hooks...")
    hooks = []
    weight_count = 0

    for name, module in model.named_modules():
        embedding_filter = isinstance(module, torch.nn.Embedding)
        ln_filter = isinstance(module, torch.nn.LayerNorm) or 'norm' in name.lower()
        if hasattr(module, 'weight') and (module.weight is not None) and (not embedding_filter) and (not ln_filter):
            weight_count += 1
            hooks.append(module.register_forward_hook(hook_fn(name)))
            print(f"Registered hook for weight matrix in: {name}")

    print(f"Registered hooks for {weight_count} weight matrices")

    # Forward pass; hooks are removed even if it raises
    print("Performing forward pass...")
    try:
        with torch.no_grad():
            model(**batch)
        print("Forward pass completed")
    finally:
        print("Cleaning up hooks...")
        for handle in hooks:
            handle.remove()

    return metrics
    
    